主要代码规整化
This commit is contained in:
		
							parent
							
								
									1055fdaab7
								
							
						
					
					
						commit
						0079733bfd
					
				@ -3,7 +3,8 @@ def check_proxy(proxies):
 | 
				
			|||||||
    import requests
 | 
					    import requests
 | 
				
			||||||
    proxies_https = proxies['https'] if proxies is not None else '无'
 | 
					    proxies_https = proxies['https'] if proxies is not None else '无'
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
        response = requests.get("https://ipapi.co/json/", proxies=proxies, timeout=4)
 | 
					        response = requests.get("https://ipapi.co/json/",
 | 
				
			||||||
 | 
					                                proxies=proxies, timeout=4)
 | 
				
			||||||
        data = response.json()
 | 
					        data = response.json()
 | 
				
			||||||
        print(f'查询代理的地理位置,返回的结果是{data}')
 | 
					        print(f'查询代理的地理位置,返回的结果是{data}')
 | 
				
			||||||
        if 'country_name' in data:
 | 
					        if 'country_name' in data:
 | 
				
			||||||
@ -21,7 +22,9 @@ def check_proxy(proxies):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
def auto_update():
 | 
					def auto_update():
 | 
				
			||||||
    from toolbox import get_conf
 | 
					    from toolbox import get_conf
 | 
				
			||||||
    import requests, time, json
 | 
					    import requests
 | 
				
			||||||
 | 
					    import time
 | 
				
			||||||
 | 
					    import json
 | 
				
			||||||
    proxies, = get_conf('proxies')
 | 
					    proxies, = get_conf('proxies')
 | 
				
			||||||
    response = requests.get("https://raw.githubusercontent.com/binary-husky/chatgpt_academic/master/version",
 | 
					    response = requests.get("https://raw.githubusercontent.com/binary-husky/chatgpt_academic/master/version",
 | 
				
			||||||
                            proxies=proxies, timeout=1)
 | 
					                            proxies=proxies, timeout=1)
 | 
				
			||||||
@ -35,7 +38,8 @@ def auto_update():
 | 
				
			|||||||
        current_version = f.read()
 | 
					        current_version = f.read()
 | 
				
			||||||
        current_version = json.loads(current_version)['version']
 | 
					        current_version = json.loads(current_version)['version']
 | 
				
			||||||
    if (remote_version - current_version) >= 0.05:
 | 
					    if (remote_version - current_version) >= 0.05:
 | 
				
			||||||
        print(f'\n新版本可用。新版本:{remote_version},当前版本:{current_version}。{new_feature}')
 | 
					        print(
 | 
				
			||||||
 | 
					            f'\n新版本可用。新版本:{remote_version},当前版本:{current_version}。{new_feature}')
 | 
				
			||||||
        print('Github更新地址:\nhttps://github.com/binary-husky/chatgpt_academic\n')
 | 
					        print('Github更新地址:\nhttps://github.com/binary-husky/chatgpt_academic\n')
 | 
				
			||||||
        time.sleep(3)
 | 
					        time.sleep(3)
 | 
				
			||||||
        return
 | 
					        return
 | 
				
			||||||
@ -44,8 +48,8 @@ def auto_update():
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == '__main__':
 | 
					if __name__ == '__main__':
 | 
				
			||||||
    import os; os.environ['no_proxy'] = '*' # 避免代理网络产生意外污染
 | 
					    import os
 | 
				
			||||||
 | 
					    os.environ['no_proxy'] = '*'  # 避免代理网络产生意外污染
 | 
				
			||||||
    from toolbox import get_conf
 | 
					    from toolbox import get_conf
 | 
				
			||||||
    proxies, = get_conf('proxies')
 | 
					    proxies, = get_conf('proxies')
 | 
				
			||||||
    check_proxy(proxies)
 | 
					    check_proxy(proxies)
 | 
				
			||||||
    
 | 
					 | 
				
			||||||
@ -46,4 +46,5 @@ API_URL = "https://api.openai.com/v1/chat/completions"
 | 
				
			|||||||
CONCURRENT_COUNT = 100
 | 
					CONCURRENT_COUNT = 100
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# 设置用户名和密码(相关功能不稳定,与gradio版本和网络都相关,如果本地使用不建议加这个)
 | 
					# 设置用户名和密码(相关功能不稳定,与gradio版本和网络都相关,如果本地使用不建议加这个)
 | 
				
			||||||
AUTHENTICATION = [] # [("username", "password"), ("username2", "password2"), ...]
 | 
					# [("username", "password"), ("username2", "password2"), ...]
 | 
				
			||||||
 | 
					AUTHENTICATION = []
 | 
				
			||||||
 | 
				
			|||||||
@ -4,6 +4,7 @@
 | 
				
			|||||||
# 默认按钮颜色是 secondary
 | 
					# 默认按钮颜色是 secondary
 | 
				
			||||||
from toolbox import clear_line_break
 | 
					from toolbox import clear_line_break
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_core_functions():
 | 
					def get_core_functions():
 | 
				
			||||||
    return {
 | 
					    return {
 | 
				
			||||||
        "英语学术润色": {
 | 
					        "英语学术润色": {
 | 
				
			||||||
 | 
				
			|||||||
@ -1,5 +1,6 @@
 | 
				
			|||||||
from toolbox import HotReload  # HotReload 的意思是热更新,修改函数插件后,不需要重启程序,代码直接生效
 | 
					from toolbox import HotReload  # HotReload 的意思是热更新,修改函数插件后,不需要重启程序,代码直接生效
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_crazy_functions():
 | 
					def get_crazy_functions():
 | 
				
			||||||
    ###################### 第一组插件 ###########################
 | 
					    ###################### 第一组插件 ###########################
 | 
				
			||||||
    # [第一组插件]: 最早期编写的项目插件和一些demo
 | 
					    # [第一组插件]: 最早期编写的项目插件和一些demo
 | 
				
			||||||
@ -81,7 +82,8 @@ def get_crazy_functions():
 | 
				
			|||||||
        "[仅供开发调试] 批量总结PDF文档": {
 | 
					        "[仅供开发调试] 批量总结PDF文档": {
 | 
				
			||||||
            "Color": "stop",
 | 
					            "Color": "stop",
 | 
				
			||||||
            "AsButton": False,  # 加入下拉菜单中
 | 
					            "AsButton": False,  # 加入下拉菜单中
 | 
				
			||||||
            "Function": HotReload(批量总结PDF文档) # HotReload 的意思是热更新,修改函数插件代码后,不需要重启程序,代码直接生效
 | 
					            # HotReload 的意思是热更新,修改函数插件代码后,不需要重启程序,代码直接生效
 | 
				
			||||||
 | 
					            "Function": HotReload(批量总结PDF文档)
 | 
				
			||||||
        },
 | 
					        },
 | 
				
			||||||
        "[仅供开发调试] 批量总结PDF文档pdfminer": {
 | 
					        "[仅供开发调试] 批量总结PDF文档pdfminer": {
 | 
				
			||||||
            "Color": "stop",
 | 
					            "Color": "stop",
 | 
				
			||||||
@ -109,9 +111,5 @@ def get_crazy_functions():
 | 
				
			|||||||
    except Exception as err:
 | 
					    except Exception as err:
 | 
				
			||||||
        print(f'[下载arxiv论文并翻译摘要] 插件导入失败 {str(err)}')
 | 
					        print(f'[下载arxiv论文并翻译摘要] 插件导入失败 {str(err)}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    ###################### 第n组插件 ###########################
 | 
					    ###################### 第n组插件 ###########################
 | 
				
			||||||
    return function_plugins
 | 
					    return function_plugins
 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										8
									
								
								theme.py
									
									
									
									
									
								
							
							
						
						
									
										8
									
								
								theme.py
									
									
									
									
									
								
							@ -24,13 +24,15 @@ import gradio as gr
 | 
				
			|||||||
# gr.themes.utils.colors.pink (粉红色)
 | 
					# gr.themes.utils.colors.pink (粉红色)
 | 
				
			||||||
# gr.themes.utils.colors.rose (玫瑰色)
 | 
					# gr.themes.utils.colors.rose (玫瑰色)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def adjust_theme():
 | 
					def adjust_theme():
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
        color_er = gr.themes.utils.colors.fuchsia
 | 
					        color_er = gr.themes.utils.colors.fuchsia
 | 
				
			||||||
        set_theme = gr.themes.Default(
 | 
					        set_theme = gr.themes.Default(
 | 
				
			||||||
            primary_hue=gr.themes.utils.colors.orange,
 | 
					            primary_hue=gr.themes.utils.colors.orange,
 | 
				
			||||||
            neutral_hue=gr.themes.utils.colors.gray,
 | 
					            neutral_hue=gr.themes.utils.colors.gray,
 | 
				
			||||||
                        font=["sans-serif", "Microsoft YaHei", "ui-sans-serif", "system-ui", "sans-serif", gr.themes.utils.fonts.GoogleFont("Source Sans Pro")], 
 | 
					            font=["sans-serif", "Microsoft YaHei", "ui-sans-serif", "system-ui",
 | 
				
			||||||
 | 
					                  "sans-serif", gr.themes.utils.fonts.GoogleFont("Source Sans Pro")],
 | 
				
			||||||
            font_mono=["ui-monospace", "Consolas", "monospace", gr.themes.utils.fonts.GoogleFont("IBM Plex Mono")])
 | 
					            font_mono=["ui-monospace", "Consolas", "monospace", gr.themes.utils.fonts.GoogleFont("IBM Plex Mono")])
 | 
				
			||||||
        set_theme.set(
 | 
					        set_theme.set(
 | 
				
			||||||
            # Colors
 | 
					            # Colors
 | 
				
			||||||
@ -78,9 +80,11 @@ def adjust_theme():
 | 
				
			|||||||
            button_cancel_text_color_dark="white",
 | 
					            button_cancel_text_color_dark="white",
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
    except:
 | 
					    except:
 | 
				
			||||||
        set_theme = None; print('gradio版本较旧, 不能自定义字体和颜色')
 | 
					        set_theme = None
 | 
				
			||||||
 | 
					        print('gradio版本较旧, 不能自定义字体和颜色')
 | 
				
			||||||
    return set_theme
 | 
					    return set_theme
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
advanced_css = """
 | 
					advanced_css = """
 | 
				
			||||||
/* 设置表格的外边距为1em,内部单元格之间边框合并,空单元格显示. */
 | 
					/* 设置表格的外边距为1em,内部单元格之间边框合并,空单元格显示. */
 | 
				
			||||||
.markdown-body table {
 | 
					.markdown-body table {
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										135
									
								
								toolbox.py
									
									
									
									
									
								
							
							
						
						
									
										135
									
								
								toolbox.py
									
									
									
									
									
								
							@ -1,14 +1,23 @@
 | 
				
			|||||||
import markdown, mdtex2html, threading, importlib, traceback, importlib, inspect, re
 | 
					import markdown
 | 
				
			||||||
 | 
					import mdtex2html
 | 
				
			||||||
 | 
					import threading
 | 
				
			||||||
 | 
					import importlib
 | 
				
			||||||
 | 
					import traceback
 | 
				
			||||||
 | 
					import importlib
 | 
				
			||||||
 | 
					import inspect
 | 
				
			||||||
 | 
					import re
 | 
				
			||||||
from show_math import convert as convert_math
 | 
					from show_math import convert as convert_math
 | 
				
			||||||
from functools import wraps, lru_cache
 | 
					from functools import wraps, lru_cache
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def ArgsGeneralWrapper(f):
 | 
					def ArgsGeneralWrapper(f):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
        装饰器函数,用于重组输入参数,改变输入参数的顺序与结构。
 | 
					        装饰器函数,用于重组输入参数,改变输入参数的顺序与结构。
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    def decorated(txt, txt2, *args, **kwargs):
 | 
					    def decorated(txt, txt2, *args, **kwargs):
 | 
				
			||||||
        txt_passon = txt
 | 
					        txt_passon = txt
 | 
				
			||||||
        if txt == "" and txt2 != "": txt_passon = txt2
 | 
					        if txt == "" and txt2 != "":
 | 
				
			||||||
 | 
					            txt_passon = txt2
 | 
				
			||||||
        yield from f(txt_passon, *args, **kwargs)
 | 
					        yield from f(txt_passon, *args, **kwargs)
 | 
				
			||||||
    return decorated
 | 
					    return decorated
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -27,6 +36,7 @@ def get_reduce_token_percent(text):
 | 
				
			|||||||
    except:
 | 
					    except:
 | 
				
			||||||
        return 0.5, '不详'
 | 
					        return 0.5, '不详'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def predict_no_ui_but_counting_down(i_say, i_say_show_user, chatbot, top_p, temperature, history=[], sys_prompt='', long_connection=True):
 | 
					def predict_no_ui_but_counting_down(i_say, i_say_show_user, chatbot, top_p, temperature, history=[], sys_prompt='', long_connection=True):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
        调用简单的predict_no_ui接口,但是依然保留了些许界面心跳功能,当对话太长时,会自动采用二分法截断
 | 
					        调用简单的predict_no_ui接口,但是依然保留了些许界面心跳功能,当对话太长时,会自动采用二分法截断
 | 
				
			||||||
@ -46,21 +56,26 @@ def predict_no_ui_but_counting_down(i_say, i_say_show_user, chatbot, top_p, temp
 | 
				
			|||||||
    # list就是最简单的mutable结构,我们第一个位置放gpt输出,第二个位置传递报错信息
 | 
					    # list就是最简单的mutable结构,我们第一个位置放gpt输出,第二个位置传递报错信息
 | 
				
			||||||
    mutable = [None, '']
 | 
					    mutable = [None, '']
 | 
				
			||||||
    # multi-threading worker
 | 
					    # multi-threading worker
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def mt(i_say, history):
 | 
					    def mt(i_say, history):
 | 
				
			||||||
        while True:
 | 
					        while True:
 | 
				
			||||||
            try:
 | 
					            try:
 | 
				
			||||||
                if long_connection:
 | 
					                if long_connection:
 | 
				
			||||||
                    mutable[0] = predict_no_ui_long_connection(inputs=i_say, top_p=top_p, temperature=temperature, history=history, sys_prompt=sys_prompt)
 | 
					                    mutable[0] = predict_no_ui_long_connection(
 | 
				
			||||||
 | 
					                        inputs=i_say, top_p=top_p, temperature=temperature, history=history, sys_prompt=sys_prompt)
 | 
				
			||||||
                else:
 | 
					                else:
 | 
				
			||||||
                    mutable[0] = predict_no_ui(inputs=i_say, top_p=top_p, temperature=temperature, history=history, sys_prompt=sys_prompt)
 | 
					                    mutable[0] = predict_no_ui(
 | 
				
			||||||
 | 
					                        inputs=i_say, top_p=top_p, temperature=temperature, history=history, sys_prompt=sys_prompt)
 | 
				
			||||||
                break
 | 
					                break
 | 
				
			||||||
            except ConnectionAbortedError as token_exceeded_error:
 | 
					            except ConnectionAbortedError as token_exceeded_error:
 | 
				
			||||||
                # 尝试计算比例,尽可能多地保留文本
 | 
					                # 尝试计算比例,尽可能多地保留文本
 | 
				
			||||||
                p_ratio, n_exceed = get_reduce_token_percent(str(token_exceeded_error))
 | 
					                p_ratio, n_exceed = get_reduce_token_percent(
 | 
				
			||||||
 | 
					                    str(token_exceeded_error))
 | 
				
			||||||
                if len(history) > 0:
 | 
					                if len(history) > 0:
 | 
				
			||||||
                    history = [his[     int(len(his)    *p_ratio):      ] for his in history if his is not None]
 | 
					                    history = [his[int(len(his) * p_ratio):]
 | 
				
			||||||
 | 
					                               for his in history if his is not None]
 | 
				
			||||||
                else:
 | 
					                else:
 | 
				
			||||||
                    i_say = i_say[:     int(len(i_say)  *p_ratio)     ]
 | 
					                    i_say = i_say[:     int(len(i_say) * p_ratio)]
 | 
				
			||||||
                mutable[1] = f'警告,文本过长将进行截断,Token溢出数:{n_exceed},截断比例:{(1-p_ratio):.0%}。'
 | 
					                mutable[1] = f'警告,文本过长将进行截断,Token溢出数:{n_exceed},截断比例:{(1-p_ratio):.0%}。'
 | 
				
			||||||
            except TimeoutError as e:
 | 
					            except TimeoutError as e:
 | 
				
			||||||
                mutable[0] = '[Local Message] 请求超时。'
 | 
					                mutable[0] = '[Local Message] 请求超时。'
 | 
				
			||||||
@ -69,42 +84,51 @@ def predict_no_ui_but_counting_down(i_say, i_say_show_user, chatbot, top_p, temp
 | 
				
			|||||||
                mutable[0] = f'[Local Message] 异常:{str(e)}.'
 | 
					                mutable[0] = f'[Local Message] 异常:{str(e)}.'
 | 
				
			||||||
                raise RuntimeError(f'[Local Message] 异常:{str(e)}.')
 | 
					                raise RuntimeError(f'[Local Message] 异常:{str(e)}.')
 | 
				
			||||||
    # 创建新线程发出http请求
 | 
					    # 创建新线程发出http请求
 | 
				
			||||||
    thread_name = threading.Thread(target=mt, args=(i_say, history)); thread_name.start()
 | 
					    thread_name = threading.Thread(target=mt, args=(i_say, history))
 | 
				
			||||||
 | 
					    thread_name.start()
 | 
				
			||||||
    # 原来的线程则负责持续更新UI,实现一个超时倒计时,并等待新线程的任务完成
 | 
					    # 原来的线程则负责持续更新UI,实现一个超时倒计时,并等待新线程的任务完成
 | 
				
			||||||
    cnt = 0
 | 
					    cnt = 0
 | 
				
			||||||
    while thread_name.is_alive():
 | 
					    while thread_name.is_alive():
 | 
				
			||||||
        cnt += 1
 | 
					        cnt += 1
 | 
				
			||||||
        chatbot[-1] = (i_say_show_user, f"[Local Message] {mutable[1]}waiting gpt response {cnt}/{TIMEOUT_SECONDS*2*(MAX_RETRY+1)}"+''.join(['.']*(cnt%4)))
 | 
					        chatbot[-1] = (i_say_show_user,
 | 
				
			||||||
 | 
					                       f"[Local Message] {mutable[1]}waiting gpt response {cnt}/{TIMEOUT_SECONDS*2*(MAX_RETRY+1)}"+''.join(['.']*(cnt % 4)))
 | 
				
			||||||
        yield chatbot, history, '正常'
 | 
					        yield chatbot, history, '正常'
 | 
				
			||||||
        time.sleep(1)
 | 
					        time.sleep(1)
 | 
				
			||||||
    # 把gpt的输出从mutable中取出来
 | 
					    # 把gpt的输出从mutable中取出来
 | 
				
			||||||
    gpt_say = mutable[0]
 | 
					    gpt_say = mutable[0]
 | 
				
			||||||
    if gpt_say=='[Local Message] Failed with timeout.': raise TimeoutError
 | 
					    if gpt_say == '[Local Message] Failed with timeout.':
 | 
				
			||||||
 | 
					        raise TimeoutError
 | 
				
			||||||
    return gpt_say
 | 
					    return gpt_say
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def write_results_to_file(history, file_name=None):
 | 
					def write_results_to_file(history, file_name=None):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
        将对话记录history以Markdown格式写入文件中。如果没有指定文件名,则使用当前时间生成文件名。
 | 
					        将对话记录history以Markdown格式写入文件中。如果没有指定文件名,则使用当前时间生成文件名。
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    import os, time
 | 
					    import os
 | 
				
			||||||
 | 
					    import time
 | 
				
			||||||
    if file_name is None:
 | 
					    if file_name is None:
 | 
				
			||||||
        # file_name = time.strftime("chatGPT分析报告%Y-%m-%d-%H-%M-%S", time.localtime()) + '.md'
 | 
					        # file_name = time.strftime("chatGPT分析报告%Y-%m-%d-%H-%M-%S", time.localtime()) + '.md'
 | 
				
			||||||
        file_name = 'chatGPT分析报告' + time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + '.md'
 | 
					        file_name = 'chatGPT分析报告' + \
 | 
				
			||||||
 | 
					            time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + '.md'
 | 
				
			||||||
    os.makedirs('./gpt_log/', exist_ok=True)
 | 
					    os.makedirs('./gpt_log/', exist_ok=True)
 | 
				
			||||||
    with open(f'./gpt_log/{file_name}', 'w', encoding = 'utf8') as f:
 | 
					    with open(f'./gpt_log/{file_name}', 'w', encoding='utf8') as f:
 | 
				
			||||||
        f.write('# chatGPT 分析报告\n')
 | 
					        f.write('# chatGPT 分析报告\n')
 | 
				
			||||||
        for i, content in enumerate(history):
 | 
					        for i, content in enumerate(history):
 | 
				
			||||||
            try:    # 这个bug没找到触发条件,暂时先这样顶一下
 | 
					            try:    # 这个bug没找到触发条件,暂时先这样顶一下
 | 
				
			||||||
                if type(content) != str: content = str(content)
 | 
					                if type(content) != str:
 | 
				
			||||||
 | 
					                    content = str(content)
 | 
				
			||||||
            except:
 | 
					            except:
 | 
				
			||||||
                continue
 | 
					                continue
 | 
				
			||||||
            if i%2==0: f.write('## ')
 | 
					            if i % 2 == 0:
 | 
				
			||||||
 | 
					                f.write('## ')
 | 
				
			||||||
            f.write(content)
 | 
					            f.write(content)
 | 
				
			||||||
            f.write('\n\n')
 | 
					            f.write('\n\n')
 | 
				
			||||||
    res = '以上材料已经被写入' + os.path.abspath(f'./gpt_log/{file_name}')
 | 
					    res = '以上材料已经被写入' + os.path.abspath(f'./gpt_log/{file_name}')
 | 
				
			||||||
    print(res)
 | 
					    print(res)
 | 
				
			||||||
    return res
 | 
					    return res
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def regular_txt_to_markdown(text):
 | 
					def regular_txt_to_markdown(text):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
        将普通文本转换为Markdown格式的文本。
 | 
					        将普通文本转换为Markdown格式的文本。
 | 
				
			||||||
@ -114,6 +138,7 @@ def regular_txt_to_markdown(text):
 | 
				
			|||||||
    text = text.replace('\n\n\n', '\n\n')
 | 
					    text = text.replace('\n\n\n', '\n\n')
 | 
				
			||||||
    return text
 | 
					    return text
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def CatchException(f):
 | 
					def CatchException(f):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
        装饰器函数,捕捉函数f中的异常并封装到一个生成器中返回,并显示到聊天当中。
 | 
					        装饰器函数,捕捉函数f中的异常并封装到一个生成器中返回,并显示到聊天当中。
 | 
				
			||||||
@ -127,11 +152,14 @@ def CatchException(f):
 | 
				
			|||||||
            from toolbox import get_conf
 | 
					            from toolbox import get_conf
 | 
				
			||||||
            proxies, = get_conf('proxies')
 | 
					            proxies, = get_conf('proxies')
 | 
				
			||||||
            tb_str = '```\n' + traceback.format_exc() + '```'
 | 
					            tb_str = '```\n' + traceback.format_exc() + '```'
 | 
				
			||||||
            if chatbot is None or len(chatbot) == 0: chatbot = [["插件调度异常","异常原因"]]
 | 
					            if chatbot is None or len(chatbot) == 0:
 | 
				
			||||||
            chatbot[-1] = (chatbot[-1][0], f"[Local Message] 实验性函数调用出错: \n\n{tb_str} \n\n当前代理可用性: \n\n{check_proxy(proxies)}")
 | 
					                chatbot = [["插件调度异常", "异常原因"]]
 | 
				
			||||||
 | 
					            chatbot[-1] = (chatbot[-1][0],
 | 
				
			||||||
 | 
					                           f"[Local Message] 实验性函数调用出错: \n\n{tb_str} \n\n当前代理可用性: \n\n{check_proxy(proxies)}")
 | 
				
			||||||
            yield chatbot, history, f'异常 {e}'
 | 
					            yield chatbot, history, f'异常 {e}'
 | 
				
			||||||
    return decorated
 | 
					    return decorated
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def HotReload(f):
 | 
					def HotReload(f):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
        装饰器函数,实现函数插件热更新
 | 
					        装饰器函数,实现函数插件热更新
 | 
				
			||||||
@ -143,12 +171,15 @@ def HotReload(f):
 | 
				
			|||||||
        yield from f_hot_reload(*args, **kwargs)
 | 
					        yield from f_hot_reload(*args, **kwargs)
 | 
				
			||||||
    return decorated
 | 
					    return decorated
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def report_execption(chatbot, history, a, b):
 | 
					def report_execption(chatbot, history, a, b):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
        向chatbot中添加错误信息
 | 
					        向chatbot中添加错误信息
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    chatbot.append((a, b))
 | 
					    chatbot.append((a, b))
 | 
				
			||||||
    history.append(a); history.append(b)
 | 
					    history.append(a)
 | 
				
			||||||
 | 
					    history.append(b)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def text_divide_paragraph(text):
 | 
					def text_divide_paragraph(text):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
@ -165,6 +196,7 @@ def text_divide_paragraph(text):
 | 
				
			|||||||
        text = "</br>".join(lines)
 | 
					        text = "</br>".join(lines)
 | 
				
			||||||
        return text
 | 
					        return text
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def markdown_convertion(txt):
 | 
					def markdown_convertion(txt):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
        将Markdown格式的文本转换为HTML格式。如果包含数学公式,则先将公式转换为HTML格式。
 | 
					        将Markdown格式的文本转换为HTML格式。如果包含数学公式,则先将公式转换为HTML格式。
 | 
				
			||||||
@ -172,16 +204,19 @@ def markdown_convertion(txt):
 | 
				
			|||||||
    pre = '<div class="markdown-body">'
 | 
					    pre = '<div class="markdown-body">'
 | 
				
			||||||
    suf = '</div>'
 | 
					    suf = '</div>'
 | 
				
			||||||
    if ('$' in txt) and ('```' not in txt):
 | 
					    if ('$' in txt) and ('```' not in txt):
 | 
				
			||||||
        return pre + markdown.markdown(txt,extensions=['fenced_code','tables']) + '<br><br>' + markdown.markdown(convert_math(txt, splitParagraphs=False),extensions=['fenced_code','tables']) + suf
 | 
					        return pre + markdown.markdown(txt, extensions=['fenced_code', 'tables']) + '<br><br>' + markdown.markdown(convert_math(txt, splitParagraphs=False), extensions=['fenced_code', 'tables']) + suf
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        return pre + markdown.markdown(txt,extensions=['fenced_code','tables']) + suf
 | 
					        return pre + markdown.markdown(txt, extensions=['fenced_code', 'tables']) + suf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def close_up_code_segment_during_stream(gpt_reply):
 | 
					def close_up_code_segment_during_stream(gpt_reply):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
        在gpt输出代码的中途(输出了前面的```,但还没输出完后面的```),补上后面的```
 | 
					        在gpt输出代码的中途(输出了前面的```,但还没输出完后面的```),补上后面的```
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    if '```' not in gpt_reply: return gpt_reply
 | 
					    if '```' not in gpt_reply:
 | 
				
			||||||
    if gpt_reply.endswith('```'): return gpt_reply
 | 
					        return gpt_reply
 | 
				
			||||||
 | 
					    if gpt_reply.endswith('```'):
 | 
				
			||||||
 | 
					        return gpt_reply
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # 排除了以上两个情况,我们
 | 
					    # 排除了以上两个情况,我们
 | 
				
			||||||
    segments = gpt_reply.split('```')
 | 
					    segments = gpt_reply.split('```')
 | 
				
			||||||
@ -193,17 +228,19 @@ def close_up_code_segment_during_stream(gpt_reply):
 | 
				
			|||||||
        return gpt_reply
 | 
					        return gpt_reply
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
def format_io(self, y):
 | 
					def format_io(self, y):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
        将输入和输出解析为HTML格式。将y中最后一项的输入部分段落化,并将输出部分的Markdown和数学公式转换为HTML格式。
 | 
					        将输入和输出解析为HTML格式。将y中最后一项的输入部分段落化,并将输出部分的Markdown和数学公式转换为HTML格式。
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    if y is None or y == []: return []
 | 
					    if y is None or y == []:
 | 
				
			||||||
 | 
					        return []
 | 
				
			||||||
    i_ask, gpt_reply = y[-1]
 | 
					    i_ask, gpt_reply = y[-1]
 | 
				
			||||||
    i_ask = text_divide_paragraph(i_ask)  # 输入部分太自由,预处理一波
 | 
					    i_ask = text_divide_paragraph(i_ask)  # 输入部分太自由,预处理一波
 | 
				
			||||||
    gpt_reply = close_up_code_segment_during_stream(gpt_reply)  # 当代码输出半截的时候,试着补上后个```
 | 
					    gpt_reply = close_up_code_segment_during_stream(
 | 
				
			||||||
 | 
					        gpt_reply)  # 当代码输出半截的时候,试着补上后个```
 | 
				
			||||||
    y[-1] = (
 | 
					    y[-1] = (
 | 
				
			||||||
        None if i_ask is None else markdown.markdown(i_ask, extensions=['fenced_code','tables']),
 | 
					        None if i_ask is None else markdown.markdown(
 | 
				
			||||||
 | 
					            i_ask, extensions=['fenced_code', 'tables']),
 | 
				
			||||||
        None if gpt_reply is None else markdown_convertion(gpt_reply)
 | 
					        None if gpt_reply is None else markdown_convertion(gpt_reply)
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    return y
 | 
					    return y
 | 
				
			||||||
@ -265,6 +302,7 @@ def extract_archive(file_path, dest_dir):
 | 
				
			|||||||
        return ''
 | 
					        return ''
 | 
				
			||||||
    return ''
 | 
					    return ''
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def find_recent_files(directory):
 | 
					def find_recent_files(directory):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
        me: find files that is created with in one minutes under a directory with python, write a function
 | 
					        me: find files that is created with in one minutes under a directory with python, write a function
 | 
				
			||||||
@ -278,21 +316,29 @@ def find_recent_files(directory):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    for filename in os.listdir(directory):
 | 
					    for filename in os.listdir(directory):
 | 
				
			||||||
        file_path = os.path.join(directory, filename)
 | 
					        file_path = os.path.join(directory, filename)
 | 
				
			||||||
        if file_path.endswith('.log'): continue
 | 
					        if file_path.endswith('.log'):
 | 
				
			||||||
 | 
					            continue
 | 
				
			||||||
        created_time = os.path.getmtime(file_path)
 | 
					        created_time = os.path.getmtime(file_path)
 | 
				
			||||||
        if created_time >= one_minute_ago:
 | 
					        if created_time >= one_minute_ago:
 | 
				
			||||||
            if os.path.isdir(file_path): continue
 | 
					            if os.path.isdir(file_path):
 | 
				
			||||||
 | 
					                continue
 | 
				
			||||||
            recent_files.append(file_path)
 | 
					            recent_files.append(file_path)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return recent_files
 | 
					    return recent_files
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def on_file_uploaded(files, chatbot, txt):
 | 
					def on_file_uploaded(files, chatbot, txt):
 | 
				
			||||||
    if len(files) == 0: return chatbot, txt
 | 
					    if len(files) == 0:
 | 
				
			||||||
    import shutil, os, time, glob
 | 
					        return chatbot, txt
 | 
				
			||||||
 | 
					    import shutil
 | 
				
			||||||
 | 
					    import os
 | 
				
			||||||
 | 
					    import time
 | 
				
			||||||
 | 
					    import glob
 | 
				
			||||||
    from toolbox import extract_archive
 | 
					    from toolbox import extract_archive
 | 
				
			||||||
    try: shutil.rmtree('./private_upload/')
 | 
					    try:
 | 
				
			||||||
    except: pass
 | 
					        shutil.rmtree('./private_upload/')
 | 
				
			||||||
 | 
					    except:
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
    time_tag = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
 | 
					    time_tag = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
 | 
				
			||||||
    os.makedirs(f'private_upload/{time_tag}', exist_ok=True)
 | 
					    os.makedirs(f'private_upload/{time_tag}', exist_ok=True)
 | 
				
			||||||
    err_msg = ''
 | 
					    err_msg = ''
 | 
				
			||||||
@ -301,12 +347,13 @@ def on_file_uploaded(files, chatbot, txt):
 | 
				
			|||||||
        shutil.copy(file.name, f'private_upload/{time_tag}/{file_origin_name}')
 | 
					        shutil.copy(file.name, f'private_upload/{time_tag}/{file_origin_name}')
 | 
				
			||||||
        err_msg += extract_archive(f'private_upload/{time_tag}/{file_origin_name}',
 | 
					        err_msg += extract_archive(f'private_upload/{time_tag}/{file_origin_name}',
 | 
				
			||||||
                                   dest_dir=f'private_upload/{time_tag}/{file_origin_name}.extract')
 | 
					                                   dest_dir=f'private_upload/{time_tag}/{file_origin_name}.extract')
 | 
				
			||||||
    moved_files = [fp for fp in glob.glob('private_upload/**/*', recursive=True)]
 | 
					    moved_files = [fp for fp in glob.glob(
 | 
				
			||||||
 | 
					        'private_upload/**/*', recursive=True)]
 | 
				
			||||||
    txt = f'private_upload/{time_tag}'
 | 
					    txt = f'private_upload/{time_tag}'
 | 
				
			||||||
    moved_files_str = '\t\n\n'.join(moved_files)
 | 
					    moved_files_str = '\t\n\n'.join(moved_files)
 | 
				
			||||||
    chatbot.append(['我上传了文件,请查收',
 | 
					    chatbot.append(['我上传了文件,请查收',
 | 
				
			||||||
                    f'[Local Message] 收到以下文件: \n\n{moved_files_str}'+
 | 
					                    f'[Local Message] 收到以下文件: \n\n{moved_files_str}' +
 | 
				
			||||||
                    f'\n\n调用路径参数已自动修正到: \n\n{txt}'+
 | 
					                    f'\n\n调用路径参数已自动修正到: \n\n{txt}' +
 | 
				
			||||||
                    f'\n\n现在您点击任意实验功能时,以上文件将被作为输入参数'+err_msg])
 | 
					                    f'\n\n现在您点击任意实验功能时,以上文件将被作为输入参数'+err_msg])
 | 
				
			||||||
    return chatbot, txt
 | 
					    return chatbot, txt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -314,17 +361,21 @@ def on_file_uploaded(files, chatbot, txt):
 | 
				
			|||||||
def on_report_generated(files, chatbot):
 | 
					def on_report_generated(files, chatbot):
 | 
				
			||||||
    from toolbox import find_recent_files
 | 
					    from toolbox import find_recent_files
 | 
				
			||||||
    report_files = find_recent_files('gpt_log')
 | 
					    report_files = find_recent_files('gpt_log')
 | 
				
			||||||
    if len(report_files) == 0: return None, chatbot
 | 
					    if len(report_files) == 0:
 | 
				
			||||||
 | 
					        return None, chatbot
 | 
				
			||||||
    # files.extend(report_files)
 | 
					    # files.extend(report_files)
 | 
				
			||||||
    chatbot.append(['汇总报告如何远程获取?', '汇总报告已经添加到右侧“文件上传区”(可能处于折叠状态),请查收。'])
 | 
					    chatbot.append(['汇总报告如何远程获取?', '汇总报告已经添加到右侧“文件上传区”(可能处于折叠状态),请查收。'])
 | 
				
			||||||
    return report_files, chatbot
 | 
					    return report_files, chatbot
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@lru_cache(maxsize=128)
 | 
					@lru_cache(maxsize=128)
 | 
				
			||||||
def read_single_conf_with_lru_cache(arg):
 | 
					def read_single_conf_with_lru_cache(arg):
 | 
				
			||||||
    try: r = getattr(importlib.import_module('config_private'), arg)
 | 
					    try:
 | 
				
			||||||
    except: r = getattr(importlib.import_module('config'), arg)
 | 
					        r = getattr(importlib.import_module('config_private'), arg)
 | 
				
			||||||
 | 
					    except:
 | 
				
			||||||
 | 
					        r = getattr(importlib.import_module('config'), arg)
 | 
				
			||||||
    # 在读取API_KEY时,检查一下是不是忘了改config
 | 
					    # 在读取API_KEY时,检查一下是不是忘了改config
 | 
				
			||||||
    if arg=='API_KEY':
 | 
					    if arg == 'API_KEY':
 | 
				
			||||||
        # 正确的 API_KEY 是 "sk-" + 48 位大小写字母数字的组合
 | 
					        # 正确的 API_KEY 是 "sk-" + 48 位大小写字母数字的组合
 | 
				
			||||||
        API_MATCH = re.match(r"sk-[a-zA-Z0-9]{48}$", r)
 | 
					        API_MATCH = re.match(r"sk-[a-zA-Z0-9]{48}$", r)
 | 
				
			||||||
        if API_MATCH:
 | 
					        if API_MATCH:
 | 
				
			||||||
@ -332,7 +383,7 @@ def read_single_conf_with_lru_cache(arg):
 | 
				
			|||||||
        else:
 | 
					        else:
 | 
				
			||||||
            assert False, "正确的 API_KEY 是 'sk-' + '48 位大小写字母数字' 的组合,请在config文件中修改API密钥, 添加海外代理之后再运行。" + \
 | 
					            assert False, "正确的 API_KEY 是 'sk-' + '48 位大小写字母数字' 的组合,请在config文件中修改API密钥, 添加海外代理之后再运行。" + \
 | 
				
			||||||
                "(如果您刚更新过代码,请确保旧版config_private文件中没有遗留任何新增键值)"
 | 
					                "(如果您刚更新过代码,请确保旧版config_private文件中没有遗留任何新增键值)"
 | 
				
			||||||
    if arg=='proxies':
 | 
					    if arg == 'proxies':
 | 
				
			||||||
        if r is None:
 | 
					        if r is None:
 | 
				
			||||||
            print('[PROXY] 网络代理状态:未配置。无代理状态下很可能无法访问。建议:检查USE_PROXY选项是否修改。')
 | 
					            print('[PROXY] 网络代理状态:未配置。无代理状态下很可能无法访问。建议:检查USE_PROXY选项是否修改。')
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
@ -340,6 +391,7 @@ def read_single_conf_with_lru_cache(arg):
 | 
				
			|||||||
            assert isinstance(r, dict), 'proxies格式错误,请注意proxies选项的格式,不要遗漏括号。'
 | 
					            assert isinstance(r, dict), 'proxies格式错误,请注意proxies选项的格式,不要遗漏括号。'
 | 
				
			||||||
    return r
 | 
					    return r
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_conf(*args):
 | 
					def get_conf(*args):
 | 
				
			||||||
    # 建议您复制一个config_private.py放自己的秘密, 如API和代理网址, 避免不小心传github被别人看到
 | 
					    # 建议您复制一个config_private.py放自己的秘密, 如API和代理网址, 避免不小心传github被别人看到
 | 
				
			||||||
    res = []
 | 
					    res = []
 | 
				
			||||||
@ -348,14 +400,17 @@ def get_conf(*args):
 | 
				
			|||||||
        res.append(r)
 | 
					        res.append(r)
 | 
				
			||||||
    return res
 | 
					    return res
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def clear_line_break(txt):
 | 
					def clear_line_break(txt):
 | 
				
			||||||
    txt = txt.replace('\n', ' ')
 | 
					    txt = txt.replace('\n', ' ')
 | 
				
			||||||
    txt = txt.replace('  ', ' ')
 | 
					    txt = txt.replace('  ', ' ')
 | 
				
			||||||
    txt = txt.replace('  ', ' ')
 | 
					    txt = txt.replace('  ', ' ')
 | 
				
			||||||
    return txt
 | 
					    return txt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class DummyWith():
 | 
					class DummyWith():
 | 
				
			||||||
    def __enter__(self):
 | 
					    def __enter__(self):
 | 
				
			||||||
        return self
 | 
					        return self
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __exit__(self, exc_type, exc_value, traceback):
 | 
					    def __exit__(self, exc_type, exc_value, traceback):
 | 
				
			||||||
        return
 | 
					        return
 | 
				
			||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user