From 0079733bfd31716ba7fdd14027006d9761113a4f Mon Sep 17 00:00:00 2001
From: qingxu fu <505030475@qq.com>
Date: Thu, 6 Apr 2023 18:29:49 +0800
Subject: [PATCH] =?UTF-8?q?=E4=B8=BB=E8=A6=81=E4=BB=A3=E7=A0=81=E8=A7=84?=
 =?UTF-8?q?=E6=95=B4=E5=8C=96?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 check_proxy.py      |  18 +++---
 config.py           |  11 ++--
 core_functional.py  |   5 +-
 crazy_functional.py |  10 ++-
 theme.py            |  24 ++++---
 toolbox.py          | 149 ++++++++++++++++++++++++++++++--------------
 6 files changed, 140 insertions(+), 77 deletions(-)

diff --git a/check_proxy.py b/check_proxy.py
index abc75d0..95a439e 100644
--- a/check_proxy.py
+++ b/check_proxy.py
@@ -3,7 +3,8 @@ def check_proxy(proxies):
     import requests
     proxies_https = proxies['https'] if proxies is not None else '无'
     try:
-        response = requests.get("https://ipapi.co/json/", proxies=proxies, timeout=4)
+        response = requests.get("https://ipapi.co/json/",
+                                proxies=proxies, timeout=4)
         data = response.json()
         print(f'查询代理的地理位置,返回的结果是{data}')
         if 'country_name' in data:
@@ -21,9 +22,11 @@ def check_proxy(proxies):
 
 def auto_update():
     from toolbox import get_conf
-    import requests, time, json
+    import requests
+    import time
+    import json
     proxies, = get_conf('proxies')
-    response = requests.get("https://raw.githubusercontent.com/binary-husky/chatgpt_academic/master/version", 
+    response = requests.get("https://raw.githubusercontent.com/binary-husky/chatgpt_academic/master/version",
                             proxies=proxies, timeout=1)
     remote_json_data = json.loads(response.text)
     remote_version = remote_json_data['version']
@@ -31,11 +34,12 @@ def auto_update():
         new_feature = "新功能:" + remote_json_data["new_feature"]
     else:
         new_feature = ""
-    with open('./version', 'r', encoding='utf8') as f: 
+    with open('./version', 'r', encoding='utf8') as f:
        current_version = f.read()
        current_version = json.loads(current_version)['version']
     if (remote_version - current_version) >= 0.05:
-        print(f'\n新版本可用。新版本:{remote_version},当前版本:{current_version}。{new_feature}')
+        print(
+            f'\n新版本可用。新版本:{remote_version},当前版本:{current_version}。{new_feature}')
         print('Github更新地址:\nhttps://github.com/binary-husky/chatgpt_academic\n')
         time.sleep(3)
         return
@@ -44,8 +48,8 @@ def auto_update():
 
 
 if __name__ == '__main__':
-    import os; os.environ['no_proxy'] = '*'  # 避免代理网络产生意外污染
+    import os
+    os.environ['no_proxy'] = '*'  # 避免代理网络产生意外污染
     from toolbox import get_conf
     proxies, = get_conf('proxies')
     check_proxy(proxies)
-    
\ No newline at end of file
diff --git a/config.py b/config.py
index 8c98657..f94f183 100644
--- a/config.py
+++ b/config.py
@@ -11,10 +11,10 @@ if USE_PROXY:
     # [端口] 在代理软件的设置里找。虽然不同的代理软件界面不一样,但端口号都应该在最显眼的位置上
 
     # 代理网络的地址,打开你的科学上网软件查看代理的协议(socks5/http)、地址(localhost)和端口(11284)
-    proxies = { 
+    proxies = {
         # [协议]:// [地址] :[端口]
-        "http":  "socks5h://localhost:11284",
-        "https": "socks5h://localhost:11284",
+        "http": "socks5h://localhost:11284",
+        "https": "socks5h://localhost:11284",
     }
 else:
     proxies = None
@@ -25,7 +25,7 @@ else:
 CHATBOT_HEIGHT = 1115
 
 # 窗口布局
-LAYOUT = "LEFT-RIGHT" # "LEFT-RIGHT"(左右布局) # "TOP-DOWN"(上下布局)
+LAYOUT = "LEFT-RIGHT"  # "LEFT-RIGHT"(左右布局) # "TOP-DOWN"(上下布局)
 
 # 发送请求到OpenAI后,等待多久判定为超时
 TIMEOUT_SECONDS = 25
@@ -46,4 +46,5 @@ API_URL = "https://api.openai.com/v1/chat/completions"
 CONCURRENT_COUNT = 100
 
 # 设置用户名和密码(相关功能不稳定,与gradio版本和网络都相关,如果本地使用不建议加这个)
-AUTHENTICATION = [] # [("username", "password"), ("username2", "password2"), ...]
+# [("username", "password"), ("username2", "password2"), ...]
+AUTHENTICATION = []
diff --git a/core_functional.py b/core_functional.py
index 22d2c2b..722abc1 100644
--- a/core_functional.py
+++ b/core_functional.py
@@ -4,6 +4,7 @@
 # 默认按钮颜色是 secondary
 from toolbox import clear_line_break
 
+
 def get_core_functions():
     return {
         "英语学术润色": {
@@ -11,12 +12,12 @@ def get_core_functions():
             "Prefix": r"Below is a paragraph from an academic paper. Polish the writing to meet the academic style, " +
                       r"improve the spelling, grammar, clarity, concision and overall readability. When necessary, rewrite the whole sentence. " +
                       r"Furthermore, list all modification and explain the reasons to do so in markdown table." + "\n\n",
-            # 后语 
+            # 后语
             "Suffix": r"",
             "Color": r"secondary",    # 按钮颜色
         },
         "中文学术润色": {
-            "Prefix": r"作为一名中文学术论文写作改进助理,你的任务是改进所提供文本的拼写、语法、清晰、简洁和整体可读性," + 
+            "Prefix": r"作为一名中文学术论文写作改进助理,你的任务是改进所提供文本的拼写、语法、清晰、简洁和整体可读性," +
                       r"同时分解长句,减少重复,并提供改进建议。请只提供文本的更正版本,避免包括解释。请编辑以下文本" + "\n\n",
             "Suffix": r"",
         },
diff --git a/crazy_functional.py b/crazy_functional.py
index 44b0918..3e53f54 100644
--- a/crazy_functional.py
+++ b/crazy_functional.py
@@ -1,4 +1,5 @@
-from toolbox import HotReload # HotReload 的意思是热更新,修改函数插件后,不需要重启程序,代码直接生效
+from toolbox import HotReload  # HotReload 的意思是热更新,修改函数插件后,不需要重启程序,代码直接生效
 
+
 def get_crazy_functions():
     ###################### 第一组插件 ###########################
@@ -81,7 +82,8 @@ def get_crazy_functions():
         "[仅供开发调试] 批量总结PDF文档": {
             "Color": "stop",
             "AsButton": False,  # 加入下拉菜单中
-            "Function": HotReload(批量总结PDF文档) # HotReload 的意思是热更新,修改函数插件代码后,不需要重启程序,代码直接生效
+            # HotReload 的意思是热更新,修改函数插件代码后,不需要重启程序,代码直接生效
+            "Function": HotReload(批量总结PDF文档)
         },
         "[仅供开发调试] 批量总结PDF文档pdfminer": {
             "Color": "stop",
@@ -109,9 +111,5 @@ def get_crazy_functions():
     except Exception as err:
         print(f'[下载arxiv论文并翻译摘要] 插件导入失败 {str(err)}')
 
-
-
     ###################### 第n组插件 ###########################
     return function_plugins
-
-
diff --git a/theme.py b/theme.py
index 4ddae5a..0c368c4 100644
--- a/theme.py
+++ b/theme.py
@@ -1,4 +1,4 @@
-import gradio as gr 
+import gradio as gr
 
 # gradio可用颜色列表
 # gr.themes.utils.colors.slate (石板色)
@@ -24,14 +24,16 @@ import gradio as gr
 # gr.themes.utils.colors.pink (粉红色)
 # gr.themes.utils.colors.rose (玫瑰色)
 
+
 def adjust_theme():
-    try: 
+    try:
         color_er = gr.themes.utils.colors.fuchsia
-        set_theme = gr.themes.Default( 
-            primary_hue=gr.themes.utils.colors.orange,
-            neutral_hue=gr.themes.utils.colors.gray,
-            font=["sans-serif", "Microsoft YaHei", "ui-sans-serif", "system-ui", "sans-serif", gr.themes.utils.fonts.GoogleFont("Source Sans Pro")],
-            font_mono=["ui-monospace", "Consolas", "monospace", gr.themes.utils.fonts.GoogleFont("IBM Plex Mono")])
+        set_theme = gr.themes.Default(
+            primary_hue=gr.themes.utils.colors.orange,
+            neutral_hue=gr.themes.utils.colors.gray,
+            font=["sans-serif", "Microsoft YaHei", "ui-sans-serif", "system-ui",
+                  "sans-serif", gr.themes.utils.fonts.GoogleFont("Source Sans Pro")],
+            font_mono=["ui-monospace", "Consolas", "monospace", gr.themes.utils.fonts.GoogleFont("IBM Plex Mono")])
         set_theme.set(
             # Colors
             input_background_fill_dark="*neutral_800",
@@ -77,10 +79,12 @@ def adjust_theme():
             button_cancel_text_color=color_er.c600,
             button_cancel_text_color_dark="white",
         )
-    except: 
-        set_theme = None; print('gradio版本较旧, 不能自定义字体和颜色')
+    except:
+        set_theme = None
+        print('gradio版本较旧, 不能自定义字体和颜色')
     return set_theme
 
+
 advanced_css = """
 /* 设置表格的外边距为1em,内部单元格之间边框合并,空单元格显示. */
 .markdown-body table {
@@ -149,4 +153,4 @@ advanced_css = """
     padding: 1em;
     margin: 1em 2em 1em 0.5em;
 }
-"""
\ No newline at end of file
+"""
diff --git a/toolbox.py b/toolbox.py
index 6f89a30..0b1c85c 100644
--- a/toolbox.py
+++ b/toolbox.py
@@ -1,14 +1,23 @@
-import markdown, mdtex2html, threading, importlib, traceback, importlib, inspect, re
+import markdown
+import mdtex2html
+import threading
+import importlib
+import traceback
+import importlib
+import inspect
+import re
 from show_math import convert as convert_math
 from functools import wraps, lru_cache
 
+
 def ArgsGeneralWrapper(f):
     """
     装饰器函数,用于重组输入参数,改变输入参数的顺序与结构。
     """
     def decorated(txt, txt2, *args, **kwargs):
         txt_passon = txt
-        if txt == "" and txt2 != "": txt_passon = txt2
+        if txt == "" and txt2 != "":
+            txt_passon = txt2
         yield from f(txt_passon, *args, **kwargs)
     return decorated
 
@@ -18,7 +27,7 @@ def get_reduce_token_percent(text):
         # text = "maximum context length is 4097 tokens. However, your messages resulted in 4870 tokens"
         pattern = r"(\d+)\s+tokens\b"
         match = re.findall(pattern, text)
-        EXCEED_ALLO = 500 # 稍微留一点余地,否则在回复时会因余量太少出问题
+        EXCEED_ALLO = 500  # 稍微留一点余地,否则在回复时会因余量太少出问题
         max_limit = float(match[0]) - EXCEED_ALLO
         current_tokens = float(match[1])
         ratio = max_limit/current_tokens
@@ -27,6 +36,7 @@ def get_reduce_token_percent(text):
     except:
         return 0.5, '不详'
 
+
 def predict_no_ui_but_counting_down(i_say, i_say_show_user, chatbot, top_p, temperature, history=[], sys_prompt='', long_connection=True):
     """
     调用简单的predict_no_ui接口,但是依然保留了些许界面心跳功能,当对话太长时,会自动采用二分法截断
@@ -46,21 +56,26 @@ def predict_no_ui_but_counting_down(i_say, i_say_show_user, chatbot, top_p, temp
     # list就是最简单的mutable结构,我们第一个位置放gpt输出,第二个位置传递报错信息
     mutable = [None, '']
     # multi-threading worker
+
     def mt(i_say, history):
         while True:
             try:
                 if long_connection:
-                    mutable[0] = predict_no_ui_long_connection(inputs=i_say, top_p=top_p, temperature=temperature, history=history, sys_prompt=sys_prompt)
+                    mutable[0] = predict_no_ui_long_connection(
+                        inputs=i_say, top_p=top_p, temperature=temperature, history=history, sys_prompt=sys_prompt)
                 else:
-                    mutable[0] = predict_no_ui(inputs=i_say, top_p=top_p, temperature=temperature, history=history, sys_prompt=sys_prompt)
+                    mutable[0] = predict_no_ui(
+                        inputs=i_say, top_p=top_p, temperature=temperature, history=history, sys_prompt=sys_prompt)
                 break
             except ConnectionAbortedError as token_exceeded_error:
                 # 尝试计算比例,尽可能多地保留文本
-                p_ratio, n_exceed = get_reduce_token_percent(str(token_exceeded_error))
+                p_ratio, n_exceed = get_reduce_token_percent(
+                    str(token_exceeded_error))
                 if len(history) > 0:
-                    history = [his[ int(len(his) *p_ratio): ] for his in history if his is not None]
+                    history = [his[int(len(his) * p_ratio):]
+                               for his in history if his is not None]
                 else:
-                    i_say = i_say[: int(len(i_say) *p_ratio) ]
+                    i_say = i_say[: int(len(i_say) * p_ratio)]
                 mutable[1] = f'警告,文本过长将进行截断,Token溢出数:{n_exceed},截断比例:{(1-p_ratio):.0%}。'
             except TimeoutError as e:
                 mutable[0] = '[Local Message] 请求超时。'
@@ -69,42 +84,51 @@ def predict_no_ui_but_counting_down(i_say, i_say_show_user, chatbot, top_p, temp
                 mutable[0] = f'[Local Message] 异常:{str(e)}.'
                 raise RuntimeError(f'[Local Message] 异常:{str(e)}.')
     # 创建新线程发出http请求
-    thread_name = threading.Thread(target=mt, args=(i_say, history)); thread_name.start()
+    thread_name = threading.Thread(target=mt, args=(i_say, history))
+    thread_name.start()
     # 原来的线程则负责持续更新UI,实现一个超时倒计时,并等待新线程的任务完成
     cnt = 0
     while thread_name.is_alive():
         cnt += 1
-        chatbot[-1] = (i_say_show_user, f"[Local Message] {mutable[1]}waiting gpt response {cnt}/{TIMEOUT_SECONDS*2*(MAX_RETRY+1)}"+''.join(['.']*(cnt%4)))
+        chatbot[-1] = (i_say_show_user,
+                       f"[Local Message] {mutable[1]}waiting gpt response {cnt}/{TIMEOUT_SECONDS*2*(MAX_RETRY+1)}"+''.join(['.']*(cnt % 4)))
         yield chatbot, history, '正常'
         time.sleep(1)
     # 把gpt的输出从mutable中取出来
     gpt_say = mutable[0]
-    if gpt_say=='[Local Message] Failed with timeout.': raise TimeoutError
+    if gpt_say == '[Local Message] Failed with timeout.':
+        raise TimeoutError
     return gpt_say
 
+
 def write_results_to_file(history, file_name=None):
     """
     将对话记录history以Markdown格式写入文件中。如果没有指定文件名,则使用当前时间生成文件名。
     """
-    import os, time
+    import os
+    import time
     if file_name is None:
         # file_name = time.strftime("chatGPT分析报告%Y-%m-%d-%H-%M-%S", time.localtime()) + '.md'
-        file_name = 'chatGPT分析报告' + time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + '.md'
+        file_name = 'chatGPT分析报告' + \
+            time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + '.md'
     os.makedirs('./gpt_log/', exist_ok=True)
-    with open(f'./gpt_log/{file_name}', 'w', encoding = 'utf8') as f:
+    with open(f'./gpt_log/{file_name}', 'w', encoding='utf8') as f:
         f.write('# chatGPT 分析报告\n')
         for i, content in enumerate(history):
             try:    # 这个bug没找到触发条件,暂时先这样顶一下
-                if type(content) != str: content = str(content)
+                if type(content) != str:
+                    content = str(content)
             except:
                 continue
-            if i%2==0: f.write('## ')
+            if i % 2 == 0:
+                f.write('## ')
             f.write(content)
             f.write('\n\n')
     res = '以上材料已经被写入' + os.path.abspath(f'./gpt_log/{file_name}')
     print(res)
     return res
 
+
 def regular_txt_to_markdown(text):
     """
     将普通文本转换为Markdown格式的文本。
@@ -114,6 +138,7 @@ def regular_txt_to_markdown(text):
     text = text.replace('\n', '<br/>')
     text = text.replace('\n\n\n', '\n\n')
     return text
 
+
 def CatchException(f):
     """
     装饰器函数,捕捉函数f中的异常并封装到一个生成器中返回,并显示到聊天当中。
@@ -127,11 +152,14 @@ def CatchException(f):
             from toolbox import get_conf
             proxies, = get_conf('proxies')
             tb_str = '```\n' + traceback.format_exc() + '```'
-            if chatbot is None or len(chatbot) == 0: chatbot = [["插件调度异常","异常原因"]]
-            chatbot[-1] = (chatbot[-1][0], f"[Local Message] 实验性函数调用出错: \n\n{tb_str} \n\n当前代理可用性: \n\n{check_proxy(proxies)}")
+            if chatbot is None or len(chatbot) == 0:
+                chatbot = [["插件调度异常", "异常原因"]]
+            chatbot[-1] = (chatbot[-1][0],
+                           f"[Local Message] 实验性函数调用出错: \n\n{tb_str} \n\n当前代理可用性: \n\n{check_proxy(proxies)}")
             yield chatbot, history, f'异常 {e}'
     return decorated
 
+
 def HotReload(f):
     """
     装饰器函数,实现函数插件热更新
@@ -143,12 +171,15 @@ def HotReload(f):
         yield from f_hot_reload(*args, **kwargs)
     return decorated
 
+
 def report_execption(chatbot, history, a, b):
     """
     向chatbot中添加错误信息
     """
     chatbot.append((a, b))
-    history.append(a); history.append(b)
+    history.append(a)
+    history.append(b)
 
+
 def text_divide_paragraph(text):
     """
@@ -165,6 +196,7 @@ def text_divide_paragraph(text):
         text = "".join(lines)
     return text
 
+
 def markdown_convertion(txt):
     """
     将Markdown格式的文本转换为HTML格式。如果包含数学公式,则先将公式转换为HTML格式。
     """
     pre = '