From 0079733bfd31716ba7fdd14027006d9761113a4f Mon Sep 17 00:00:00 2001 From: qingxu fu <505030475@qq.com> Date: Thu, 6 Apr 2023 18:29:49 +0800 Subject: [PATCH] =?UTF-8?q?=E4=B8=BB=E8=A6=81=E4=BB=A3=E7=A0=81=E8=A7=84?= =?UTF-8?q?=E6=95=B4=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- check_proxy.py | 18 +++--- config.py | 11 ++-- core_functional.py | 5 +- crazy_functional.py | 10 ++- theme.py | 24 ++++--- toolbox.py | 149 ++++++++++++++++++++++++++++++-------------- 6 files changed, 140 insertions(+), 77 deletions(-) diff --git a/check_proxy.py b/check_proxy.py index abc75d0..95a439e 100644 --- a/check_proxy.py +++ b/check_proxy.py @@ -3,7 +3,8 @@ def check_proxy(proxies): import requests proxies_https = proxies['https'] if proxies is not None else '无' try: - response = requests.get("https://ipapi.co/json/", proxies=proxies, timeout=4) + response = requests.get("https://ipapi.co/json/", + proxies=proxies, timeout=4) data = response.json() print(f'查询代理的地理位置,返回的结果是{data}') if 'country_name' in data: @@ -21,9 +22,11 @@ def check_proxy(proxies): def auto_update(): from toolbox import get_conf - import requests, time, json + import requests + import time + import json proxies, = get_conf('proxies') - response = requests.get("https://raw.githubusercontent.com/binary-husky/chatgpt_academic/master/version", + response = requests.get("https://raw.githubusercontent.com/binary-husky/chatgpt_academic/master/version", proxies=proxies, timeout=1) remote_json_data = json.loads(response.text) remote_version = remote_json_data['version'] @@ -31,11 +34,12 @@ def auto_update(): new_feature = "新功能:" + remote_json_data["new_feature"] else: new_feature = "" - with open('./version', 'r', encoding='utf8') as f: + with open('./version', 'r', encoding='utf8') as f: current_version = f.read() current_version = json.loads(current_version)['version'] if (remote_version - current_version) >= 0.05: - print(f'\n新版本可用。新版本:{remote_version},当前版本:{current_version}。{new_feature}') + print( + f'\n新版本可用。新版本:{remote_version},当前版本:{current_version}。{new_feature}') print('Github更新地址:\nhttps://github.com/binary-husky/chatgpt_academic\n') time.sleep(3) return @@ -44,8 +48,8 @@ def auto_update(): if __name__ == '__main__': - import os; os.environ['no_proxy'] = '*' # 避免代理网络产生意外污染 + import os + os.environ['no_proxy'] = '*' # 避免代理网络产生意外污染 from toolbox import get_conf proxies, = get_conf('proxies') check_proxy(proxies) - \ No newline at end of file diff --git a/config.py b/config.py index 8c98657..f94f183 100644 --- a/config.py +++ b/config.py @@ -11,10 +11,10 @@ if USE_PROXY: # [端口] 在代理软件的设置里找。虽然不同的代理软件界面不一样,但端口号都应该在最显眼的位置上 # 代理网络的地址,打开你的科学上网软件查看代理的协议(socks5/http)、地址(localhost)和端口(11284) - proxies = { + proxies = { # [协议]:// [地址] :[端口] - "http": "socks5h://localhost:11284", - "https": "socks5h://localhost:11284", + "http": "socks5h://localhost:11284", + "https": "socks5h://localhost:11284", } else: proxies = None @@ -25,7 +25,7 @@ else: CHATBOT_HEIGHT = 1115 # 窗口布局 -LAYOUT = "LEFT-RIGHT" # "LEFT-RIGHT"(左右布局) # "TOP-DOWN"(上下布局) +LAYOUT = "LEFT-RIGHT" # "LEFT-RIGHT"(左右布局) # "TOP-DOWN"(上下布局) # 发送请求到OpenAI后,等待多久判定为超时 TIMEOUT_SECONDS = 25 @@ -46,4 +46,5 @@ API_URL = "https://api.openai.com/v1/chat/completions" CONCURRENT_COUNT = 100 # 设置用户名和密码(相关功能不稳定,与gradio版本和网络都相关,如果本地使用不建议加这个) -AUTHENTICATION = [] # [("username", "password"), ("username2", "password2"), ...] +# [("username", "password"), ("username2", "password2"), ...] 
+AUTHENTICATION = [] diff --git a/core_functional.py b/core_functional.py index 22d2c2b..722abc1 100644 --- a/core_functional.py +++ b/core_functional.py @@ -4,6 +4,7 @@ # 默认按钮颜色是 secondary from toolbox import clear_line_break + def get_core_functions(): return { "英语学术润色": { @@ -11,12 +12,12 @@ def get_core_functions(): "Prefix": r"Below is a paragraph from an academic paper. Polish the writing to meet the academic style, " + r"improve the spelling, grammar, clarity, concision and overall readability. When necessary, rewrite the whole sentence. " + r"Furthermore, list all modification and explain the reasons to do so in markdown table." + "\n\n", - # 后语 + # 后语 "Suffix": r"", "Color": r"secondary", # 按钮颜色 }, "中文学术润色": { - "Prefix": r"作为一名中文学术论文写作改进助理,你的任务是改进所提供文本的拼写、语法、清晰、简洁和整体可读性," + + "Prefix": r"作为一名中文学术论文写作改进助理,你的任务是改进所提供文本的拼写、语法、清晰、简洁和整体可读性," + r"同时分解长句,减少重复,并提供改进建议。请只提供文本的更正版本,避免包括解释。请编辑以下文本" + "\n\n", "Suffix": r"", }, diff --git a/crazy_functional.py b/crazy_functional.py index 44b0918..3e53f54 100644 --- a/crazy_functional.py +++ b/crazy_functional.py @@ -1,4 +1,5 @@ -from toolbox import HotReload # HotReload 的意思是热更新,修改函数插件后,不需要重启程序,代码直接生效 +from toolbox import HotReload # HotReload 的意思是热更新,修改函数插件后,不需要重启程序,代码直接生效 + def get_crazy_functions(): ###################### 第一组插件 ########################### @@ -81,7 +82,8 @@ def get_crazy_functions(): "[仅供开发调试] 批量总结PDF文档": { "Color": "stop", "AsButton": False, # 加入下拉菜单中 - "Function": HotReload(批量总结PDF文档) # HotReload 的意思是热更新,修改函数插件代码后,不需要重启程序,代码直接生效 + # HotReload 的意思是热更新,修改函数插件代码后,不需要重启程序,代码直接生效 + "Function": HotReload(批量总结PDF文档) }, "[仅供开发调试] 批量总结PDF文档pdfminer": { "Color": "stop", @@ -109,9 +111,5 @@ def get_crazy_functions(): except Exception as err: print(f'[下载arxiv论文并翻译摘要] 插件导入失败 {str(err)}') - - ###################### 第n组插件 ########################### return function_plugins - - diff --git a/theme.py b/theme.py index 4ddae5a..0c368c4 100644 --- a/theme.py +++ b/theme.py @@ -1,4 +1,4 @@ -import gradio as gr +import gradio as gr # gradio可用颜色列表 # gr.themes.utils.colors.slate (石板色) @@ -24,14 +24,16 @@ import gradio as gr # gr.themes.utils.colors.pink (粉红色) # gr.themes.utils.colors.rose (玫瑰色) + def adjust_theme(): - try: + try: color_er = gr.themes.utils.colors.fuchsia - set_theme = gr.themes.Default( - primary_hue=gr.themes.utils.colors.orange, - neutral_hue=gr.themes.utils.colors.gray, - font=["sans-serif", "Microsoft YaHei", "ui-sans-serif", "system-ui", "sans-serif", gr.themes.utils.fonts.GoogleFont("Source Sans Pro")], - font_mono=["ui-monospace", "Consolas", "monospace", gr.themes.utils.fonts.GoogleFont("IBM Plex Mono")]) + set_theme = gr.themes.Default( + primary_hue=gr.themes.utils.colors.orange, + neutral_hue=gr.themes.utils.colors.gray, + font=["sans-serif", "Microsoft YaHei", "ui-sans-serif", "system-ui", + "sans-serif", gr.themes.utils.fonts.GoogleFont("Source Sans Pro")], + font_mono=["ui-monospace", "Consolas", "monospace", gr.themes.utils.fonts.GoogleFont("IBM Plex Mono")]) set_theme.set( # Colors input_background_fill_dark="*neutral_800", @@ -77,10 +79,12 @@ def adjust_theme(): button_cancel_text_color=color_er.c600, button_cancel_text_color_dark="white", ) - except: - set_theme = None; print('gradio版本较旧, 不能自定义字体和颜色') + except: + set_theme = None + print('gradio版本较旧, 不能自定义字体和颜色') return set_theme + advanced_css = """ /* 设置表格的外边距为1em,内部单元格之间边框合并,空单元格显示. 
*/ .markdown-body table { @@ -149,4 +153,4 @@ advanced_css = """ padding: 1em; margin: 1em 2em 1em 0.5em; } -""" \ No newline at end of file +""" diff --git a/toolbox.py b/toolbox.py index 6f89a30..0b1c85c 100644 --- a/toolbox.py +++ b/toolbox.py @@ -1,14 +1,23 @@ -import markdown, mdtex2html, threading, importlib, traceback, importlib, inspect, re +import markdown +import mdtex2html +import threading +import importlib +import traceback +import importlib +import inspect +import re from show_math import convert as convert_math from functools import wraps, lru_cache + def ArgsGeneralWrapper(f): """ 装饰器函数,用于重组输入参数,改变输入参数的顺序与结构。 """ def decorated(txt, txt2, *args, **kwargs): txt_passon = txt - if txt == "" and txt2 != "": txt_passon = txt2 + if txt == "" and txt2 != "": + txt_passon = txt2 yield from f(txt_passon, *args, **kwargs) return decorated @@ -18,7 +27,7 @@ def get_reduce_token_percent(text): # text = "maximum context length is 4097 tokens. However, your messages resulted in 4870 tokens" pattern = r"(\d+)\s+tokens\b" match = re.findall(pattern, text) - EXCEED_ALLO = 500 # 稍微留一点余地,否则在回复时会因余量太少出问题 + EXCEED_ALLO = 500 # 稍微留一点余地,否则在回复时会因余量太少出问题 max_limit = float(match[0]) - EXCEED_ALLO current_tokens = float(match[1]) ratio = max_limit/current_tokens @@ -27,6 +36,7 @@ def get_reduce_token_percent(text): except: return 0.5, '不详' + def predict_no_ui_but_counting_down(i_say, i_say_show_user, chatbot, top_p, temperature, history=[], sys_prompt='', long_connection=True): """ 调用简单的predict_no_ui接口,但是依然保留了些许界面心跳功能,当对话太长时,会自动采用二分法截断 @@ -46,21 +56,26 @@ def predict_no_ui_but_counting_down(i_say, i_say_show_user, chatbot, top_p, temp # list就是最简单的mutable结构,我们第一个位置放gpt输出,第二个位置传递报错信息 mutable = [None, ''] # multi-threading worker + def mt(i_say, history): while True: try: if long_connection: - mutable[0] = predict_no_ui_long_connection(inputs=i_say, top_p=top_p, temperature=temperature, history=history, sys_prompt=sys_prompt) + mutable[0] = predict_no_ui_long_connection( + inputs=i_say, top_p=top_p, temperature=temperature, history=history, sys_prompt=sys_prompt) else: - mutable[0] = predict_no_ui(inputs=i_say, top_p=top_p, temperature=temperature, history=history, sys_prompt=sys_prompt) + mutable[0] = predict_no_ui( + inputs=i_say, top_p=top_p, temperature=temperature, history=history, sys_prompt=sys_prompt) break except ConnectionAbortedError as token_exceeded_error: # 尝试计算比例,尽可能多地保留文本 - p_ratio, n_exceed = get_reduce_token_percent(str(token_exceeded_error)) + p_ratio, n_exceed = get_reduce_token_percent( + str(token_exceeded_error)) if len(history) > 0: - history = [his[ int(len(his) *p_ratio): ] for his in history if his is not None] + history = [his[int(len(his) * p_ratio):] + for his in history if his is not None] else: - i_say = i_say[: int(len(i_say) *p_ratio) ] + i_say = i_say[: int(len(i_say) * p_ratio)] mutable[1] = f'警告,文本过长将进行截断,Token溢出数:{n_exceed},截断比例:{(1-p_ratio):.0%}。' except TimeoutError as e: mutable[0] = '[Local Message] 请求超时。' @@ -69,42 +84,51 @@ def predict_no_ui_but_counting_down(i_say, i_say_show_user, chatbot, top_p, temp mutable[0] = f'[Local Message] 异常:{str(e)}.' 
raise RuntimeError(f'[Local Message] 异常:{str(e)}.') # 创建新线程发出http请求 - thread_name = threading.Thread(target=mt, args=(i_say, history)); thread_name.start() + thread_name = threading.Thread(target=mt, args=(i_say, history)) + thread_name.start() # 原来的线程则负责持续更新UI,实现一个超时倒计时,并等待新线程的任务完成 cnt = 0 while thread_name.is_alive(): cnt += 1 - chatbot[-1] = (i_say_show_user, f"[Local Message] {mutable[1]}waiting gpt response {cnt}/{TIMEOUT_SECONDS*2*(MAX_RETRY+1)}"+''.join(['.']*(cnt%4))) + chatbot[-1] = (i_say_show_user, + f"[Local Message] {mutable[1]}waiting gpt response {cnt}/{TIMEOUT_SECONDS*2*(MAX_RETRY+1)}"+''.join(['.']*(cnt % 4))) yield chatbot, history, '正常' time.sleep(1) # 把gpt的输出从mutable中取出来 gpt_say = mutable[0] - if gpt_say=='[Local Message] Failed with timeout.': raise TimeoutError + if gpt_say == '[Local Message] Failed with timeout.': + raise TimeoutError return gpt_say + def write_results_to_file(history, file_name=None): """ 将对话记录history以Markdown格式写入文件中。如果没有指定文件名,则使用当前时间生成文件名。 """ - import os, time + import os + import time if file_name is None: # file_name = time.strftime("chatGPT分析报告%Y-%m-%d-%H-%M-%S", time.localtime()) + '.md' - file_name = 'chatGPT分析报告' + time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + '.md' + file_name = 'chatGPT分析报告' + \ + time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + '.md' os.makedirs('./gpt_log/', exist_ok=True) - with open(f'./gpt_log/{file_name}', 'w', encoding = 'utf8') as f: + with open(f'./gpt_log/{file_name}', 'w', encoding='utf8') as f: f.write('# chatGPT 分析报告\n') for i, content in enumerate(history): try: # 这个bug没找到触发条件,暂时先这样顶一下 - if type(content) != str: content = str(content) + if type(content) != str: + content = str(content) except: continue - if i%2==0: f.write('## ') + if i % 2 == 0: + f.write('## ') f.write(content) f.write('\n\n') res = '以上材料已经被写入' + os.path.abspath(f'./gpt_log/{file_name}') print(res) return res + def regular_txt_to_markdown(text): """ 将普通文本转换为Markdown格式的文本。 @@ -114,6 +138,7 @@ def regular_txt_to_markdown(text): text = text.replace('\n\n\n', '\n\n') return text + def CatchException(f): """ 装饰器函数,捕捉函数f中的异常并封装到一个生成器中返回,并显示到聊天当中。 @@ -127,11 +152,14 @@ def CatchException(f): from toolbox import get_conf proxies, = get_conf('proxies') tb_str = '```\n' + traceback.format_exc() + '```' - if chatbot is None or len(chatbot) == 0: chatbot = [["插件调度异常","异常原因"]] - chatbot[-1] = (chatbot[-1][0], f"[Local Message] 实验性函数调用出错: \n\n{tb_str} \n\n当前代理可用性: \n\n{check_proxy(proxies)}") + if chatbot is None or len(chatbot) == 0: + chatbot = [["插件调度异常", "异常原因"]] + chatbot[-1] = (chatbot[-1][0], + f"[Local Message] 实验性函数调用出错: \n\n{tb_str} \n\n当前代理可用性: \n\n{check_proxy(proxies)}") yield chatbot, history, f'异常 {e}' return decorated + def HotReload(f): """ 装饰器函数,实现函数插件热更新 @@ -143,12 +171,15 @@ def HotReload(f): yield from f_hot_reload(*args, **kwargs) return decorated + def report_execption(chatbot, history, a, b): """ 向chatbot中添加错误信息 """ chatbot.append((a, b)) - history.append(a); history.append(b) + history.append(a) + history.append(b) + def text_divide_paragraph(text): """ @@ -165,6 +196,7 @@ def text_divide_paragraph(text): text = "
".join(lines) return text + def markdown_convertion(txt): """ 将Markdown格式的文本转换为HTML格式。如果包含数学公式,则先将公式转换为HTML格式。 @@ -172,16 +204,19 @@ def markdown_convertion(txt): pre = '
' suf = '
' if ('$' in txt) and ('```' not in txt): - return pre + markdown.markdown(txt,extensions=['fenced_code','tables']) + '

' + markdown.markdown(convert_math(txt, splitParagraphs=False),extensions=['fenced_code','tables']) + suf + return pre + markdown.markdown(txt, extensions=['fenced_code', 'tables']) + '

' + markdown.markdown(convert_math(txt, splitParagraphs=False), extensions=['fenced_code', 'tables']) + suf else: - return pre + markdown.markdown(txt,extensions=['fenced_code','tables']) + suf + return pre + markdown.markdown(txt, extensions=['fenced_code', 'tables']) + suf + def close_up_code_segment_during_stream(gpt_reply): """ 在gpt输出代码的中途(输出了前面的```,但还没输出完后面的```),补上后面的``` """ - if '```' not in gpt_reply: return gpt_reply - if gpt_reply.endswith('```'): return gpt_reply + if '```' not in gpt_reply: + return gpt_reply + if gpt_reply.endswith('```'): + return gpt_reply # 排除了以上两个情况,我们 segments = gpt_reply.split('```') @@ -191,19 +226,21 @@ def close_up_code_segment_during_stream(gpt_reply): return gpt_reply+'\n```' else: return gpt_reply - def format_io(self, y): """ 将输入和输出解析为HTML格式。将y中最后一项的输入部分段落化,并将输出部分的Markdown和数学公式转换为HTML格式。 """ - if y is None or y == []: return [] + if y is None or y == []: + return [] i_ask, gpt_reply = y[-1] - i_ask = text_divide_paragraph(i_ask) # 输入部分太自由,预处理一波 - gpt_reply = close_up_code_segment_during_stream(gpt_reply) # 当代码输出半截的时候,试着补上后个``` + i_ask = text_divide_paragraph(i_ask) # 输入部分太自由,预处理一波 + gpt_reply = close_up_code_segment_during_stream( + gpt_reply) # 当代码输出半截的时候,试着补上后个``` y[-1] = ( - None if i_ask is None else markdown.markdown(i_ask, extensions=['fenced_code','tables']), + None if i_ask is None else markdown.markdown( + i_ask, extensions=['fenced_code', 'tables']), None if gpt_reply is None else markdown_convertion(gpt_reply) ) return y @@ -265,6 +302,7 @@ def extract_archive(file_path, dest_dir): return '' return '' + def find_recent_files(directory): """ me: find files that is created with in one minutes under a directory with python, write a function @@ -278,21 +316,29 @@ def find_recent_files(directory): for filename in os.listdir(directory): file_path = os.path.join(directory, filename) - if file_path.endswith('.log'): continue + if file_path.endswith('.log'): + continue created_time = os.path.getmtime(file_path) if created_time >= one_minute_ago: - if os.path.isdir(file_path): continue + if os.path.isdir(file_path): + continue recent_files.append(file_path) return recent_files def on_file_uploaded(files, chatbot, txt): - if len(files) == 0: return chatbot, txt - import shutil, os, time, glob + if len(files) == 0: + return chatbot, txt + import shutil + import os + import time + import glob from toolbox import extract_archive - try: shutil.rmtree('./private_upload/') - except: pass + try: + shutil.rmtree('./private_upload/') + except: + pass time_tag = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) os.makedirs(f'private_upload/{time_tag}', exist_ok=True) err_msg = '' @@ -300,13 +346,14 @@ def on_file_uploaded(files, chatbot, txt): file_origin_name = os.path.basename(file.orig_name) shutil.copy(file.name, f'private_upload/{time_tag}/{file_origin_name}') err_msg += extract_archive(f'private_upload/{time_tag}/{file_origin_name}', - dest_dir=f'private_upload/{time_tag}/{file_origin_name}.extract') - moved_files = [fp for fp in glob.glob('private_upload/**/*', recursive=True)] + dest_dir=f'private_upload/{time_tag}/{file_origin_name}.extract') + moved_files = [fp for fp in glob.glob( + 'private_upload/**/*', recursive=True)] txt = f'private_upload/{time_tag}' moved_files_str = '\t\n\n'.join(moved_files) chatbot.append(['我上传了文件,请查收', - f'[Local Message] 收到以下文件: \n\n{moved_files_str}'+ - f'\n\n调用路径参数已自动修正到: \n\n{txt}'+ + f'[Local Message] 收到以下文件: \n\n{moved_files_str}' + + f'\n\n调用路径参数已自动修正到: \n\n{txt}' + f'\n\n现在您点击任意实验功能时,以上文件将被作为输入参数'+err_msg]) 
return chatbot, txt @@ -314,32 +361,37 @@ def on_file_uploaded(files, chatbot, txt): def on_report_generated(files, chatbot): from toolbox import find_recent_files report_files = find_recent_files('gpt_log') - if len(report_files) == 0: return None, chatbot + if len(report_files) == 0: + return None, chatbot # files.extend(report_files) chatbot.append(['汇总报告如何远程获取?', '汇总报告已经添加到右侧“文件上传区”(可能处于折叠状态),请查收。']) return report_files, chatbot + @lru_cache(maxsize=128) def read_single_conf_with_lru_cache(arg): - try: r = getattr(importlib.import_module('config_private'), arg) - except: r = getattr(importlib.import_module('config'), arg) + try: + r = getattr(importlib.import_module('config_private'), arg) + except: + r = getattr(importlib.import_module('config'), arg) # 在读取API_KEY时,检查一下是不是忘了改config - if arg=='API_KEY': + if arg == 'API_KEY': # 正确的 API_KEY 是 "sk-" + 48 位大小写字母数字的组合 API_MATCH = re.match(r"sk-[a-zA-Z0-9]{48}$", r) if API_MATCH: print(f"[API_KEY] 您的 API_KEY 是: {r[:15]}*** API_KEY 导入成功") else: assert False, "正确的 API_KEY 是 'sk-' + '48 位大小写字母数字' 的组合,请在config文件中修改API密钥, 添加海外代理之后再运行。" + \ - "(如果您刚更新过代码,请确保旧版config_private文件中没有遗留任何新增键值)" - if arg=='proxies': - if r is None: + "(如果您刚更新过代码,请确保旧版config_private文件中没有遗留任何新增键值)" + if arg == 'proxies': + if r is None: print('[PROXY] 网络代理状态:未配置。无代理状态下很可能无法访问。建议:检查USE_PROXY选项是否修改。') - else: + else: print('[PROXY] 网络代理状态:已配置。配置信息如下:', r) assert isinstance(r, dict), 'proxies格式错误,请注意proxies选项的格式,不要遗漏括号。' return r + def get_conf(*args): # 建议您复制一个config_private.py放自己的秘密, 如API和代理网址, 避免不小心传github被别人看到 res = [] @@ -348,14 +400,17 @@ def get_conf(*args): res.append(r) return res + def clear_line_break(txt): txt = txt.replace('\n', ' ') txt = txt.replace(' ', ' ') txt = txt.replace(' ', ' ') return txt + class DummyWith(): def __enter__(self): return self + def __exit__(self, exc_type, exc_value, traceback): - return \ No newline at end of file + return
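A minimal usage sketch of the no-op DummyWith context manager defined at the end of toolbox.py above. This is illustrative only and not part of the diff; the `capture` flag and the pairing with contextlib.redirect_stdout are assumptions made for this example.

import io
from contextlib import redirect_stdout

class DummyWith():
    # Same no-op context manager as in toolbox.py: __enter__ returns self,
    # __exit__ does nothing, so the `with` block runs unchanged.
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        return

capture = False                    # hypothetical switch for this sketch
buf = io.StringIO()
# Keep a single `with` statement; swap in a real context manager only when wanted.
with redirect_stdout(buf) if capture else DummyWith():
    print('hello')                 # printed normally when capture is False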