主要代码规整化
This commit is contained in:
		
							parent
							
								
									1055fdaab7
								
							
						
					
					
						commit
						0079733bfd
					
				@ -3,7 +3,8 @@ def check_proxy(proxies):
 | 
			
		||||
    import requests
 | 
			
		||||
    proxies_https = proxies['https'] if proxies is not None else '无'
 | 
			
		||||
    try:
 | 
			
		||||
        response = requests.get("https://ipapi.co/json/", proxies=proxies, timeout=4)
 | 
			
		||||
        response = requests.get("https://ipapi.co/json/",
 | 
			
		||||
                                proxies=proxies, timeout=4)
 | 
			
		||||
        data = response.json()
 | 
			
		||||
        print(f'查询代理的地理位置,返回的结果是{data}')
 | 
			
		||||
        if 'country_name' in data:
 | 
			
		||||
@ -21,9 +22,11 @@ def check_proxy(proxies):
 | 
			
		||||
 | 
			
		||||
def auto_update():
 | 
			
		||||
    from toolbox import get_conf
 | 
			
		||||
    import requests, time, json
 | 
			
		||||
    import requests
 | 
			
		||||
    import time
 | 
			
		||||
    import json
 | 
			
		||||
    proxies, = get_conf('proxies')
 | 
			
		||||
    response = requests.get("https://raw.githubusercontent.com/binary-husky/chatgpt_academic/master/version", 
 | 
			
		||||
    response = requests.get("https://raw.githubusercontent.com/binary-husky/chatgpt_academic/master/version",
 | 
			
		||||
                            proxies=proxies, timeout=1)
 | 
			
		||||
    remote_json_data = json.loads(response.text)
 | 
			
		||||
    remote_version = remote_json_data['version']
 | 
			
		||||
@ -31,11 +34,12 @@ def auto_update():
 | 
			
		||||
        new_feature = "新功能:" + remote_json_data["new_feature"]
 | 
			
		||||
    else:
 | 
			
		||||
        new_feature = ""
 | 
			
		||||
    with open('./version', 'r', encoding='utf8') as f: 
 | 
			
		||||
    with open('./version', 'r', encoding='utf8') as f:
 | 
			
		||||
        current_version = f.read()
 | 
			
		||||
        current_version = json.loads(current_version)['version']
 | 
			
		||||
    if (remote_version - current_version) >= 0.05:
 | 
			
		||||
        print(f'\n新版本可用。新版本:{remote_version},当前版本:{current_version}。{new_feature}')
 | 
			
		||||
        print(
 | 
			
		||||
            f'\n新版本可用。新版本:{remote_version},当前版本:{current_version}。{new_feature}')
 | 
			
		||||
        print('Github更新地址:\nhttps://github.com/binary-husky/chatgpt_academic\n')
 | 
			
		||||
        time.sleep(3)
 | 
			
		||||
        return
 | 
			
		||||
@ -44,8 +48,8 @@ def auto_update():
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    import os; os.environ['no_proxy'] = '*' # 避免代理网络产生意外污染
 | 
			
		||||
    import os
 | 
			
		||||
    os.environ['no_proxy'] = '*'  # 避免代理网络产生意外污染
 | 
			
		||||
    from toolbox import get_conf
 | 
			
		||||
    proxies, = get_conf('proxies')
 | 
			
		||||
    check_proxy(proxies)
 | 
			
		||||
    
 | 
			
		||||
							
								
								
									
										11
									
								
								config.py
									
									
									
									
									
								
							
							
						
						
									
										11
									
								
								config.py
									
									
									
									
									
								
							@ -11,10 +11,10 @@ if USE_PROXY:
 | 
			
		||||
    # [端口] 在代理软件的设置里找。虽然不同的代理软件界面不一样,但端口号都应该在最显眼的位置上
 | 
			
		||||
 | 
			
		||||
    # 代理网络的地址,打开你的科学上网软件查看代理的协议(socks5/http)、地址(localhost)和端口(11284)
 | 
			
		||||
    proxies = { 
 | 
			
		||||
    proxies = {
 | 
			
		||||
        #          [协议]://  [地址]  :[端口]
 | 
			
		||||
        "http":  "socks5h://localhost:11284", 
 | 
			
		||||
        "https": "socks5h://localhost:11284", 
 | 
			
		||||
        "http":  "socks5h://localhost:11284",
 | 
			
		||||
        "https": "socks5h://localhost:11284",
 | 
			
		||||
    }
 | 
			
		||||
else:
 | 
			
		||||
    proxies = None
 | 
			
		||||
@ -25,7 +25,7 @@ else:
 | 
			
		||||
CHATBOT_HEIGHT = 1115
 | 
			
		||||
 | 
			
		||||
# 窗口布局
 | 
			
		||||
LAYOUT = "LEFT-RIGHT" # "LEFT-RIGHT"(左右布局) # "TOP-DOWN"(上下布局)
 | 
			
		||||
LAYOUT = "LEFT-RIGHT"  # "LEFT-RIGHT"(左右布局) # "TOP-DOWN"(上下布局)
 | 
			
		||||
 | 
			
		||||
# 发送请求到OpenAI后,等待多久判定为超时
 | 
			
		||||
TIMEOUT_SECONDS = 25
 | 
			
		||||
@ -46,4 +46,5 @@ API_URL = "https://api.openai.com/v1/chat/completions"
 | 
			
		||||
CONCURRENT_COUNT = 100
 | 
			
		||||
 | 
			
		||||
# 设置用户名和密码(相关功能不稳定,与gradio版本和网络都相关,如果本地使用不建议加这个)
 | 
			
		||||
AUTHENTICATION = [] # [("username", "password"), ("username2", "password2"), ...]
 | 
			
		||||
# [("username", "password"), ("username2", "password2"), ...]
 | 
			
		||||
AUTHENTICATION = []
 | 
			
		||||
 | 
			
		||||
@ -4,6 +4,7 @@
 | 
			
		||||
# 默认按钮颜色是 secondary
 | 
			
		||||
from toolbox import clear_line_break
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_core_functions():
 | 
			
		||||
    return {
 | 
			
		||||
        "英语学术润色": {
 | 
			
		||||
@ -11,12 +12,12 @@ def get_core_functions():
 | 
			
		||||
            "Prefix":   r"Below is a paragraph from an academic paper. Polish the writing to meet the academic style, " +
 | 
			
		||||
                        r"improve the spelling, grammar, clarity, concision and overall readability. When necessary, rewrite the whole sentence. " +
 | 
			
		||||
                        r"Furthermore, list all modification and explain the reasons to do so in markdown table." + "\n\n",
 | 
			
		||||
            # 后语 
 | 
			
		||||
            # 后语
 | 
			
		||||
            "Suffix":   r"",
 | 
			
		||||
            "Color":    r"secondary",    # 按钮颜色
 | 
			
		||||
        },
 | 
			
		||||
        "中文学术润色": {
 | 
			
		||||
            "Prefix":   r"作为一名中文学术论文写作改进助理,你的任务是改进所提供文本的拼写、语法、清晰、简洁和整体可读性," + 
 | 
			
		||||
            "Prefix":   r"作为一名中文学术论文写作改进助理,你的任务是改进所提供文本的拼写、语法、清晰、简洁和整体可读性," +
 | 
			
		||||
                        r"同时分解长句,减少重复,并提供改进建议。请只提供文本的更正版本,避免包括解释。请编辑以下文本" + "\n\n",
 | 
			
		||||
            "Suffix":   r"",
 | 
			
		||||
        },
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,5 @@
 | 
			
		||||
from toolbox import HotReload # HotReload 的意思是热更新,修改函数插件后,不需要重启程序,代码直接生效
 | 
			
		||||
from toolbox import HotReload  # HotReload 的意思是热更新,修改函数插件后,不需要重启程序,代码直接生效
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_crazy_functions():
 | 
			
		||||
    ###################### 第一组插件 ###########################
 | 
			
		||||
@ -81,7 +82,8 @@ def get_crazy_functions():
 | 
			
		||||
        "[仅供开发调试] 批量总结PDF文档": {
 | 
			
		||||
            "Color": "stop",
 | 
			
		||||
            "AsButton": False,  # 加入下拉菜单中
 | 
			
		||||
            "Function": HotReload(批量总结PDF文档) # HotReload 的意思是热更新,修改函数插件代码后,不需要重启程序,代码直接生效
 | 
			
		||||
            # HotReload 的意思是热更新,修改函数插件代码后,不需要重启程序,代码直接生效
 | 
			
		||||
            "Function": HotReload(批量总结PDF文档)
 | 
			
		||||
        },
 | 
			
		||||
        "[仅供开发调试] 批量总结PDF文档pdfminer": {
 | 
			
		||||
            "Color": "stop",
 | 
			
		||||
@ -109,9 +111,5 @@ def get_crazy_functions():
 | 
			
		||||
    except Exception as err:
 | 
			
		||||
        print(f'[下载arxiv论文并翻译摘要] 插件导入失败 {str(err)}')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    ###################### 第n组插件 ###########################
 | 
			
		||||
    return function_plugins
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										24
									
								
								theme.py
									
									
									
									
									
								
							
							
						
						
									
										24
									
								
								theme.py
									
									
									
									
									
								
							@ -1,4 +1,4 @@
 | 
			
		||||
import gradio as gr 
 | 
			
		||||
import gradio as gr
 | 
			
		||||
 | 
			
		||||
# gradio可用颜色列表
 | 
			
		||||
# gr.themes.utils.colors.slate (石板色)
 | 
			
		||||
@ -24,14 +24,16 @@ import gradio as gr
 | 
			
		||||
# gr.themes.utils.colors.pink (粉红色)
 | 
			
		||||
# gr.themes.utils.colors.rose (玫瑰色)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def adjust_theme():
 | 
			
		||||
    try: 
 | 
			
		||||
    try:
 | 
			
		||||
        color_er = gr.themes.utils.colors.fuchsia
 | 
			
		||||
        set_theme = gr.themes.Default( 
 | 
			
		||||
                        primary_hue=gr.themes.utils.colors.orange,
 | 
			
		||||
                        neutral_hue=gr.themes.utils.colors.gray,
 | 
			
		||||
                        font=["sans-serif", "Microsoft YaHei", "ui-sans-serif", "system-ui", "sans-serif", gr.themes.utils.fonts.GoogleFont("Source Sans Pro")], 
 | 
			
		||||
                        font_mono=["ui-monospace", "Consolas", "monospace", gr.themes.utils.fonts.GoogleFont("IBM Plex Mono")])
 | 
			
		||||
        set_theme = gr.themes.Default(
 | 
			
		||||
            primary_hue=gr.themes.utils.colors.orange,
 | 
			
		||||
            neutral_hue=gr.themes.utils.colors.gray,
 | 
			
		||||
            font=["sans-serif", "Microsoft YaHei", "ui-sans-serif", "system-ui",
 | 
			
		||||
                  "sans-serif", gr.themes.utils.fonts.GoogleFont("Source Sans Pro")],
 | 
			
		||||
            font_mono=["ui-monospace", "Consolas", "monospace", gr.themes.utils.fonts.GoogleFont("IBM Plex Mono")])
 | 
			
		||||
        set_theme.set(
 | 
			
		||||
            # Colors
 | 
			
		||||
            input_background_fill_dark="*neutral_800",
 | 
			
		||||
@ -77,10 +79,12 @@ def adjust_theme():
 | 
			
		||||
            button_cancel_text_color=color_er.c600,
 | 
			
		||||
            button_cancel_text_color_dark="white",
 | 
			
		||||
        )
 | 
			
		||||
    except: 
 | 
			
		||||
        set_theme = None; print('gradio版本较旧, 不能自定义字体和颜色')
 | 
			
		||||
    except:
 | 
			
		||||
        set_theme = None
 | 
			
		||||
        print('gradio版本较旧, 不能自定义字体和颜色')
 | 
			
		||||
    return set_theme
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
advanced_css = """
 | 
			
		||||
/* 设置表格的外边距为1em,内部单元格之间边框合并,空单元格显示. */
 | 
			
		||||
.markdown-body table {
 | 
			
		||||
@ -149,4 +153,4 @@ advanced_css = """
 | 
			
		||||
    padding: 1em;
 | 
			
		||||
    margin: 1em 2em 1em 0.5em;
 | 
			
		||||
}
 | 
			
		||||
"""
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										149
									
								
								toolbox.py
									
									
									
									
									
								
							
							
						
						
									
										149
									
								
								toolbox.py
									
									
									
									
									
								
							@ -1,14 +1,23 @@
 | 
			
		||||
import markdown, mdtex2html, threading, importlib, traceback, importlib, inspect, re
 | 
			
		||||
import markdown
 | 
			
		||||
import mdtex2html
 | 
			
		||||
import threading
 | 
			
		||||
import importlib
 | 
			
		||||
import traceback
 | 
			
		||||
import importlib
 | 
			
		||||
import inspect
 | 
			
		||||
import re
 | 
			
		||||
from show_math import convert as convert_math
 | 
			
		||||
from functools import wraps, lru_cache
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def ArgsGeneralWrapper(f):
 | 
			
		||||
    """
 | 
			
		||||
        装饰器函数,用于重组输入参数,改变输入参数的顺序与结构。
 | 
			
		||||
    """
 | 
			
		||||
    def decorated(txt, txt2, *args, **kwargs):
 | 
			
		||||
        txt_passon = txt
 | 
			
		||||
        if txt == "" and txt2 != "": txt_passon = txt2
 | 
			
		||||
        if txt == "" and txt2 != "":
 | 
			
		||||
            txt_passon = txt2
 | 
			
		||||
        yield from f(txt_passon, *args, **kwargs)
 | 
			
		||||
    return decorated
 | 
			
		||||
 | 
			
		||||
@ -18,7 +27,7 @@ def get_reduce_token_percent(text):
 | 
			
		||||
        # text = "maximum context length is 4097 tokens. However, your messages resulted in 4870 tokens"
 | 
			
		||||
        pattern = r"(\d+)\s+tokens\b"
 | 
			
		||||
        match = re.findall(pattern, text)
 | 
			
		||||
        EXCEED_ALLO = 500 # 稍微留一点余地,否则在回复时会因余量太少出问题
 | 
			
		||||
        EXCEED_ALLO = 500  # 稍微留一点余地,否则在回复时会因余量太少出问题
 | 
			
		||||
        max_limit = float(match[0]) - EXCEED_ALLO
 | 
			
		||||
        current_tokens = float(match[1])
 | 
			
		||||
        ratio = max_limit/current_tokens
 | 
			
		||||
@ -27,6 +36,7 @@ def get_reduce_token_percent(text):
 | 
			
		||||
    except:
 | 
			
		||||
        return 0.5, '不详'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def predict_no_ui_but_counting_down(i_say, i_say_show_user, chatbot, top_p, temperature, history=[], sys_prompt='', long_connection=True):
 | 
			
		||||
    """
 | 
			
		||||
        调用简单的predict_no_ui接口,但是依然保留了些许界面心跳功能,当对话太长时,会自动采用二分法截断
 | 
			
		||||
@ -46,21 +56,26 @@ def predict_no_ui_but_counting_down(i_say, i_say_show_user, chatbot, top_p, temp
 | 
			
		||||
    # list就是最简单的mutable结构,我们第一个位置放gpt输出,第二个位置传递报错信息
 | 
			
		||||
    mutable = [None, '']
 | 
			
		||||
    # multi-threading worker
 | 
			
		||||
 | 
			
		||||
    def mt(i_say, history):
 | 
			
		||||
        while True:
 | 
			
		||||
            try:
 | 
			
		||||
                if long_connection:
 | 
			
		||||
                    mutable[0] = predict_no_ui_long_connection(inputs=i_say, top_p=top_p, temperature=temperature, history=history, sys_prompt=sys_prompt)
 | 
			
		||||
                    mutable[0] = predict_no_ui_long_connection(
 | 
			
		||||
                        inputs=i_say, top_p=top_p, temperature=temperature, history=history, sys_prompt=sys_prompt)
 | 
			
		||||
                else:
 | 
			
		||||
                    mutable[0] = predict_no_ui(inputs=i_say, top_p=top_p, temperature=temperature, history=history, sys_prompt=sys_prompt)
 | 
			
		||||
                    mutable[0] = predict_no_ui(
 | 
			
		||||
                        inputs=i_say, top_p=top_p, temperature=temperature, history=history, sys_prompt=sys_prompt)
 | 
			
		||||
                break
 | 
			
		||||
            except ConnectionAbortedError as token_exceeded_error:
 | 
			
		||||
                # 尝试计算比例,尽可能多地保留文本
 | 
			
		||||
                p_ratio, n_exceed = get_reduce_token_percent(str(token_exceeded_error))
 | 
			
		||||
                p_ratio, n_exceed = get_reduce_token_percent(
 | 
			
		||||
                    str(token_exceeded_error))
 | 
			
		||||
                if len(history) > 0:
 | 
			
		||||
                    history = [his[     int(len(his)    *p_ratio):      ] for his in history if his is not None]
 | 
			
		||||
                    history = [his[int(len(his) * p_ratio):]
 | 
			
		||||
                               for his in history if his is not None]
 | 
			
		||||
                else:
 | 
			
		||||
                    i_say = i_say[:     int(len(i_say)  *p_ratio)     ]
 | 
			
		||||
                    i_say = i_say[:     int(len(i_say) * p_ratio)]
 | 
			
		||||
                mutable[1] = f'警告,文本过长将进行截断,Token溢出数:{n_exceed},截断比例:{(1-p_ratio):.0%}。'
 | 
			
		||||
            except TimeoutError as e:
 | 
			
		||||
                mutable[0] = '[Local Message] 请求超时。'
 | 
			
		||||
@ -69,42 +84,51 @@ def predict_no_ui_but_counting_down(i_say, i_say_show_user, chatbot, top_p, temp
 | 
			
		||||
                mutable[0] = f'[Local Message] 异常:{str(e)}.'
 | 
			
		||||
                raise RuntimeError(f'[Local Message] 异常:{str(e)}.')
 | 
			
		||||
    # 创建新线程发出http请求
 | 
			
		||||
    thread_name = threading.Thread(target=mt, args=(i_say, history)); thread_name.start()
 | 
			
		||||
    thread_name = threading.Thread(target=mt, args=(i_say, history))
 | 
			
		||||
    thread_name.start()
 | 
			
		||||
    # 原来的线程则负责持续更新UI,实现一个超时倒计时,并等待新线程的任务完成
 | 
			
		||||
    cnt = 0
 | 
			
		||||
    while thread_name.is_alive():
 | 
			
		||||
        cnt += 1
 | 
			
		||||
        chatbot[-1] = (i_say_show_user, f"[Local Message] {mutable[1]}waiting gpt response {cnt}/{TIMEOUT_SECONDS*2*(MAX_RETRY+1)}"+''.join(['.']*(cnt%4)))
 | 
			
		||||
        chatbot[-1] = (i_say_show_user,
 | 
			
		||||
                       f"[Local Message] {mutable[1]}waiting gpt response {cnt}/{TIMEOUT_SECONDS*2*(MAX_RETRY+1)}"+''.join(['.']*(cnt % 4)))
 | 
			
		||||
        yield chatbot, history, '正常'
 | 
			
		||||
        time.sleep(1)
 | 
			
		||||
    # 把gpt的输出从mutable中取出来
 | 
			
		||||
    gpt_say = mutable[0]
 | 
			
		||||
    if gpt_say=='[Local Message] Failed with timeout.': raise TimeoutError
 | 
			
		||||
    if gpt_say == '[Local Message] Failed with timeout.':
 | 
			
		||||
        raise TimeoutError
 | 
			
		||||
    return gpt_say
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def write_results_to_file(history, file_name=None):
 | 
			
		||||
    """
 | 
			
		||||
        将对话记录history以Markdown格式写入文件中。如果没有指定文件名,则使用当前时间生成文件名。
 | 
			
		||||
    """
 | 
			
		||||
    import os, time
 | 
			
		||||
    import os
 | 
			
		||||
    import time
 | 
			
		||||
    if file_name is None:
 | 
			
		||||
        # file_name = time.strftime("chatGPT分析报告%Y-%m-%d-%H-%M-%S", time.localtime()) + '.md'
 | 
			
		||||
        file_name = 'chatGPT分析报告' + time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + '.md'
 | 
			
		||||
        file_name = 'chatGPT分析报告' + \
 | 
			
		||||
            time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + '.md'
 | 
			
		||||
    os.makedirs('./gpt_log/', exist_ok=True)
 | 
			
		||||
    with open(f'./gpt_log/{file_name}', 'w', encoding = 'utf8') as f:
 | 
			
		||||
    with open(f'./gpt_log/{file_name}', 'w', encoding='utf8') as f:
 | 
			
		||||
        f.write('# chatGPT 分析报告\n')
 | 
			
		||||
        for i, content in enumerate(history):
 | 
			
		||||
            try:    # 这个bug没找到触发条件,暂时先这样顶一下
 | 
			
		||||
                if type(content) != str: content = str(content)
 | 
			
		||||
                if type(content) != str:
 | 
			
		||||
                    content = str(content)
 | 
			
		||||
            except:
 | 
			
		||||
                continue
 | 
			
		||||
            if i%2==0: f.write('## ')
 | 
			
		||||
            if i % 2 == 0:
 | 
			
		||||
                f.write('## ')
 | 
			
		||||
            f.write(content)
 | 
			
		||||
            f.write('\n\n')
 | 
			
		||||
    res = '以上材料已经被写入' + os.path.abspath(f'./gpt_log/{file_name}')
 | 
			
		||||
    print(res)
 | 
			
		||||
    return res
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def regular_txt_to_markdown(text):
 | 
			
		||||
    """
 | 
			
		||||
        将普通文本转换为Markdown格式的文本。
 | 
			
		||||
@ -114,6 +138,7 @@ def regular_txt_to_markdown(text):
 | 
			
		||||
    text = text.replace('\n\n\n', '\n\n')
 | 
			
		||||
    return text
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def CatchException(f):
 | 
			
		||||
    """
 | 
			
		||||
        装饰器函数,捕捉函数f中的异常并封装到一个生成器中返回,并显示到聊天当中。
 | 
			
		||||
@ -127,11 +152,14 @@ def CatchException(f):
 | 
			
		||||
            from toolbox import get_conf
 | 
			
		||||
            proxies, = get_conf('proxies')
 | 
			
		||||
            tb_str = '```\n' + traceback.format_exc() + '```'
 | 
			
		||||
            if chatbot is None or len(chatbot) == 0: chatbot = [["插件调度异常","异常原因"]]
 | 
			
		||||
            chatbot[-1] = (chatbot[-1][0], f"[Local Message] 实验性函数调用出错: \n\n{tb_str} \n\n当前代理可用性: \n\n{check_proxy(proxies)}")
 | 
			
		||||
            if chatbot is None or len(chatbot) == 0:
 | 
			
		||||
                chatbot = [["插件调度异常", "异常原因"]]
 | 
			
		||||
            chatbot[-1] = (chatbot[-1][0],
 | 
			
		||||
                           f"[Local Message] 实验性函数调用出错: \n\n{tb_str} \n\n当前代理可用性: \n\n{check_proxy(proxies)}")
 | 
			
		||||
            yield chatbot, history, f'异常 {e}'
 | 
			
		||||
    return decorated
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def HotReload(f):
 | 
			
		||||
    """
 | 
			
		||||
        装饰器函数,实现函数插件热更新
 | 
			
		||||
@ -143,12 +171,15 @@ def HotReload(f):
 | 
			
		||||
        yield from f_hot_reload(*args, **kwargs)
 | 
			
		||||
    return decorated
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def report_execption(chatbot, history, a, b):
 | 
			
		||||
    """
 | 
			
		||||
        向chatbot中添加错误信息
 | 
			
		||||
    """
 | 
			
		||||
    chatbot.append((a, b))
 | 
			
		||||
    history.append(a); history.append(b)
 | 
			
		||||
    history.append(a)
 | 
			
		||||
    history.append(b)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def text_divide_paragraph(text):
 | 
			
		||||
    """
 | 
			
		||||
@ -165,6 +196,7 @@ def text_divide_paragraph(text):
 | 
			
		||||
        text = "</br>".join(lines)
 | 
			
		||||
        return text
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def markdown_convertion(txt):
 | 
			
		||||
    """
 | 
			
		||||
        将Markdown格式的文本转换为HTML格式。如果包含数学公式,则先将公式转换为HTML格式。
 | 
			
		||||
@ -172,16 +204,19 @@ def markdown_convertion(txt):
 | 
			
		||||
    pre = '<div class="markdown-body">'
 | 
			
		||||
    suf = '</div>'
 | 
			
		||||
    if ('$' in txt) and ('```' not in txt):
 | 
			
		||||
        return pre + markdown.markdown(txt,extensions=['fenced_code','tables']) + '<br><br>' + markdown.markdown(convert_math(txt, splitParagraphs=False),extensions=['fenced_code','tables']) + suf
 | 
			
		||||
        return pre + markdown.markdown(txt, extensions=['fenced_code', 'tables']) + '<br><br>' + markdown.markdown(convert_math(txt, splitParagraphs=False), extensions=['fenced_code', 'tables']) + suf
 | 
			
		||||
    else:
 | 
			
		||||
        return pre + markdown.markdown(txt,extensions=['fenced_code','tables']) + suf
 | 
			
		||||
        return pre + markdown.markdown(txt, extensions=['fenced_code', 'tables']) + suf
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def close_up_code_segment_during_stream(gpt_reply):
 | 
			
		||||
    """
 | 
			
		||||
        在gpt输出代码的中途(输出了前面的```,但还没输出完后面的```),补上后面的```
 | 
			
		||||
    """
 | 
			
		||||
    if '```' not in gpt_reply: return gpt_reply
 | 
			
		||||
    if gpt_reply.endswith('```'): return gpt_reply
 | 
			
		||||
    if '```' not in gpt_reply:
 | 
			
		||||
        return gpt_reply
 | 
			
		||||
    if gpt_reply.endswith('```'):
 | 
			
		||||
        return gpt_reply
 | 
			
		||||
 | 
			
		||||
    # 排除了以上两个情况,我们
 | 
			
		||||
    segments = gpt_reply.split('```')
 | 
			
		||||
@ -191,19 +226,21 @@ def close_up_code_segment_during_stream(gpt_reply):
 | 
			
		||||
        return gpt_reply+'\n```'
 | 
			
		||||
    else:
 | 
			
		||||
        return gpt_reply
 | 
			
		||||
 
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def format_io(self, y):
 | 
			
		||||
    """
 | 
			
		||||
        将输入和输出解析为HTML格式。将y中最后一项的输入部分段落化,并将输出部分的Markdown和数学公式转换为HTML格式。
 | 
			
		||||
    """
 | 
			
		||||
    if y is None or y == []: return []
 | 
			
		||||
    if y is None or y == []:
 | 
			
		||||
        return []
 | 
			
		||||
    i_ask, gpt_reply = y[-1]
 | 
			
		||||
    i_ask = text_divide_paragraph(i_ask) # 输入部分太自由,预处理一波
 | 
			
		||||
    gpt_reply = close_up_code_segment_during_stream(gpt_reply)  # 当代码输出半截的时候,试着补上后个```
 | 
			
		||||
    i_ask = text_divide_paragraph(i_ask)  # 输入部分太自由,预处理一波
 | 
			
		||||
    gpt_reply = close_up_code_segment_during_stream(
 | 
			
		||||
        gpt_reply)  # 当代码输出半截的时候,试着补上后个```
 | 
			
		||||
    y[-1] = (
 | 
			
		||||
        None if i_ask is None else markdown.markdown(i_ask, extensions=['fenced_code','tables']),
 | 
			
		||||
        None if i_ask is None else markdown.markdown(
 | 
			
		||||
            i_ask, extensions=['fenced_code', 'tables']),
 | 
			
		||||
        None if gpt_reply is None else markdown_convertion(gpt_reply)
 | 
			
		||||
    )
 | 
			
		||||
    return y
 | 
			
		||||
@ -265,6 +302,7 @@ def extract_archive(file_path, dest_dir):
 | 
			
		||||
        return ''
 | 
			
		||||
    return ''
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def find_recent_files(directory):
 | 
			
		||||
    """
 | 
			
		||||
        me: find files that is created with in one minutes under a directory with python, write a function
 | 
			
		||||
@ -278,21 +316,29 @@ def find_recent_files(directory):
 | 
			
		||||
 | 
			
		||||
    for filename in os.listdir(directory):
 | 
			
		||||
        file_path = os.path.join(directory, filename)
 | 
			
		||||
        if file_path.endswith('.log'): continue
 | 
			
		||||
        if file_path.endswith('.log'):
 | 
			
		||||
            continue
 | 
			
		||||
        created_time = os.path.getmtime(file_path)
 | 
			
		||||
        if created_time >= one_minute_ago:
 | 
			
		||||
            if os.path.isdir(file_path): continue
 | 
			
		||||
            if os.path.isdir(file_path):
 | 
			
		||||
                continue
 | 
			
		||||
            recent_files.append(file_path)
 | 
			
		||||
 | 
			
		||||
    return recent_files
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def on_file_uploaded(files, chatbot, txt):
 | 
			
		||||
    if len(files) == 0: return chatbot, txt
 | 
			
		||||
    import shutil, os, time, glob
 | 
			
		||||
    if len(files) == 0:
 | 
			
		||||
        return chatbot, txt
 | 
			
		||||
    import shutil
 | 
			
		||||
    import os
 | 
			
		||||
    import time
 | 
			
		||||
    import glob
 | 
			
		||||
    from toolbox import extract_archive
 | 
			
		||||
    try: shutil.rmtree('./private_upload/')
 | 
			
		||||
    except: pass
 | 
			
		||||
    try:
 | 
			
		||||
        shutil.rmtree('./private_upload/')
 | 
			
		||||
    except:
 | 
			
		||||
        pass
 | 
			
		||||
    time_tag = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
 | 
			
		||||
    os.makedirs(f'private_upload/{time_tag}', exist_ok=True)
 | 
			
		||||
    err_msg = ''
 | 
			
		||||
@ -300,13 +346,14 @@ def on_file_uploaded(files, chatbot, txt):
 | 
			
		||||
        file_origin_name = os.path.basename(file.orig_name)
 | 
			
		||||
        shutil.copy(file.name, f'private_upload/{time_tag}/{file_origin_name}')
 | 
			
		||||
        err_msg += extract_archive(f'private_upload/{time_tag}/{file_origin_name}',
 | 
			
		||||
                        dest_dir=f'private_upload/{time_tag}/{file_origin_name}.extract')
 | 
			
		||||
    moved_files = [fp for fp in glob.glob('private_upload/**/*', recursive=True)]
 | 
			
		||||
                                   dest_dir=f'private_upload/{time_tag}/{file_origin_name}.extract')
 | 
			
		||||
    moved_files = [fp for fp in glob.glob(
 | 
			
		||||
        'private_upload/**/*', recursive=True)]
 | 
			
		||||
    txt = f'private_upload/{time_tag}'
 | 
			
		||||
    moved_files_str = '\t\n\n'.join(moved_files)
 | 
			
		||||
    chatbot.append(['我上传了文件,请查收',
 | 
			
		||||
                    f'[Local Message] 收到以下文件: \n\n{moved_files_str}'+
 | 
			
		||||
                    f'\n\n调用路径参数已自动修正到: \n\n{txt}'+
 | 
			
		||||
                    f'[Local Message] 收到以下文件: \n\n{moved_files_str}' +
 | 
			
		||||
                    f'\n\n调用路径参数已自动修正到: \n\n{txt}' +
 | 
			
		||||
                    f'\n\n现在您点击任意实验功能时,以上文件将被作为输入参数'+err_msg])
 | 
			
		||||
    return chatbot, txt
 | 
			
		||||
 | 
			
		||||
@ -314,32 +361,37 @@ def on_file_uploaded(files, chatbot, txt):
 | 
			
		||||
def on_report_generated(files, chatbot):
 | 
			
		||||
    from toolbox import find_recent_files
 | 
			
		||||
    report_files = find_recent_files('gpt_log')
 | 
			
		||||
    if len(report_files) == 0: return None, chatbot
 | 
			
		||||
    if len(report_files) == 0:
 | 
			
		||||
        return None, chatbot
 | 
			
		||||
    # files.extend(report_files)
 | 
			
		||||
    chatbot.append(['汇总报告如何远程获取?', '汇总报告已经添加到右侧“文件上传区”(可能处于折叠状态),请查收。'])
 | 
			
		||||
    return report_files, chatbot
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@lru_cache(maxsize=128)
 | 
			
		||||
def read_single_conf_with_lru_cache(arg):
 | 
			
		||||
    try: r = getattr(importlib.import_module('config_private'), arg)
 | 
			
		||||
    except: r = getattr(importlib.import_module('config'), arg)
 | 
			
		||||
    try:
 | 
			
		||||
        r = getattr(importlib.import_module('config_private'), arg)
 | 
			
		||||
    except:
 | 
			
		||||
        r = getattr(importlib.import_module('config'), arg)
 | 
			
		||||
    # 在读取API_KEY时,检查一下是不是忘了改config
 | 
			
		||||
    if arg=='API_KEY':
 | 
			
		||||
    if arg == 'API_KEY':
 | 
			
		||||
        # 正确的 API_KEY 是 "sk-" + 48 位大小写字母数字的组合
 | 
			
		||||
        API_MATCH = re.match(r"sk-[a-zA-Z0-9]{48}$", r)
 | 
			
		||||
        if API_MATCH:
 | 
			
		||||
            print(f"[API_KEY] 您的 API_KEY 是: {r[:15]}*** API_KEY 导入成功")
 | 
			
		||||
        else:
 | 
			
		||||
            assert False, "正确的 API_KEY 是 'sk-' + '48 位大小写字母数字' 的组合,请在config文件中修改API密钥, 添加海外代理之后再运行。" + \
 | 
			
		||||
                        "(如果您刚更新过代码,请确保旧版config_private文件中没有遗留任何新增键值)"
 | 
			
		||||
    if arg=='proxies':
 | 
			
		||||
        if r is None: 
 | 
			
		||||
                "(如果您刚更新过代码,请确保旧版config_private文件中没有遗留任何新增键值)"
 | 
			
		||||
    if arg == 'proxies':
 | 
			
		||||
        if r is None:
 | 
			
		||||
            print('[PROXY] 网络代理状态:未配置。无代理状态下很可能无法访问。建议:检查USE_PROXY选项是否修改。')
 | 
			
		||||
        else: 
 | 
			
		||||
        else:
 | 
			
		||||
            print('[PROXY] 网络代理状态:已配置。配置信息如下:', r)
 | 
			
		||||
            assert isinstance(r, dict), 'proxies格式错误,请注意proxies选项的格式,不要遗漏括号。'
 | 
			
		||||
    return r
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_conf(*args):
 | 
			
		||||
    # 建议您复制一个config_private.py放自己的秘密, 如API和代理网址, 避免不小心传github被别人看到
 | 
			
		||||
    res = []
 | 
			
		||||
@ -348,14 +400,17 @@ def get_conf(*args):
 | 
			
		||||
        res.append(r)
 | 
			
		||||
    return res
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def clear_line_break(txt):
 | 
			
		||||
    txt = txt.replace('\n', ' ')
 | 
			
		||||
    txt = txt.replace('  ', ' ')
 | 
			
		||||
    txt = txt.replace('  ', ' ')
 | 
			
		||||
    return txt
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DummyWith():
 | 
			
		||||
    def __enter__(self):
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    def __exit__(self, exc_type, exc_value, traceback):
 | 
			
		||||
        return
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user