More model switching

Your Name 2023-04-17 21:23:03 +08:00
parent 03ba072c16
commit 9bd8511ba4
5 changed files with 166 additions and 141 deletions

config.py

@@ -46,14 +46,12 @@ WEB_PORT = -1
 MAX_RETRY = 2
 # OpenAI model selection (gpt-4 is currently only available to users whose access request was approved)
-LLM_MODEL = "gpt-3.5-turbo" # options: "chatglm", "tgui:anymodel@localhost:7865"
+LLM_MODEL = "gpt-3.5-turbo" # options: "chatglm"
+AVAIL_LLM_MODELS = ["gpt-3.5-turbo", "chatglm", "gpt-4", "api2d-gpt-4", "api2d-gpt-3.5-turbo"]
 # Execution device for local LLMs such as ChatGLM (CPU/GPU)
 LOCAL_MODEL_DEVICE = "cpu" # options: "cuda"
-# OpenAI API_URL
-API_URL = "https://api.openai.com/v1/chat/completions"
 # Number of gradio worker threads (no need to modify)
 CONCURRENT_COUNT = 100
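The new AVAIL_LLM_MODELS list is what the UI offers in its model dropdown, while LLM_MODEL stays the default selection. As a sanity check (not part of this commit; a hypothetical snippet), the default should always be a member of the available list:

    # Hypothetical startup check: the default model must be one of the offered models.
    LLM_MODEL = "gpt-3.5-turbo"
    AVAIL_LLM_MODELS = ["gpt-3.5-turbo", "chatglm", "gpt-4", "api2d-gpt-4", "api2d-gpt-3.5-turbo"]
    assert LLM_MODEL in AVAIL_LLM_MODELS, f"default model {LLM_MODEL!r} is not in AVAIL_LLM_MODELS"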

main.py

@@ -5,8 +5,8 @@ def main():
     from request_llm.bridge_all import predict
     from toolbox import format_io, find_free_port, on_file_uploaded, on_report_generated, get_conf, ArgsGeneralWrapper, DummyWith
     # It is recommended to copy config.py to a config_private.py for your secrets (API keys, proxy URLs), so they are not accidentally pushed to GitHub
-    proxies, WEB_PORT, LLM_MODEL, CONCURRENT_COUNT, AUTHENTICATION, CHATBOT_HEIGHT, LAYOUT, API_KEY = \
-        get_conf('proxies', 'WEB_PORT', 'LLM_MODEL', 'CONCURRENT_COUNT', 'AUTHENTICATION', 'CHATBOT_HEIGHT', 'LAYOUT', 'API_KEY')
+    proxies, WEB_PORT, LLM_MODEL, CONCURRENT_COUNT, AUTHENTICATION, CHATBOT_HEIGHT, LAYOUT, API_KEY, AVAIL_LLM_MODELS = \
+        get_conf('proxies', 'WEB_PORT', 'LLM_MODEL', 'CONCURRENT_COUNT', 'AUTHENTICATION', 'CHATBOT_HEIGHT', 'LAYOUT', 'API_KEY', 'AVAIL_LLM_MODELS')
     # If WEB_PORT is -1, pick a random web port
     PORT = find_free_port() if WEB_PORT <= 0 else WEB_PORT
@@ -101,7 +101,7 @@ def main():
                 temperature = gr.Slider(minimum=-0, maximum=2.0, value=1.0, step=0.01, interactive=True, label="Temperature",)
                 max_length_sl = gr.Slider(minimum=256, maximum=4096, value=512, step=1, interactive=True, label="MaxLength",)
                 checkboxes = gr.CheckboxGroup(["基础功能区", "函数插件区", "底部输入区", "输入清除键"], value=["基础功能区", "函数插件区"], label="显示/隐藏功能区")
-                md_dropdown = gr.Dropdown(["gpt-3.5-turbo", "chatglm"], value=LLM_MODEL, label="").style(container=False)
+                md_dropdown = gr.Dropdown(AVAIL_LLM_MODELS, value=LLM_MODEL, label="").style(container=False)
             gr.Markdown(description)
         with gr.Accordion("备选输入区", open=True, visible=False) as area_input_secondary:
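The dropdown is now populated from AVAIL_LLM_MODELS instead of a hard-coded two-item list, so adding a model to config.py is enough to expose it in the UI. A minimal, self-contained sketch of the same wiring (simplified Gradio app under stated assumptions; the real main.py attaches a more elaborate handler, and widget styling such as .style(container=False) is version-dependent and omitted):

    import gradio as gr

    AVAIL_LLM_MODELS = ["gpt-3.5-turbo", "chatglm", "gpt-4", "api2d-gpt-4", "api2d-gpt-3.5-turbo"]

    with gr.Blocks() as demo:
        # Same idea as md_dropdown in main.py: the choices come from the config list.
        md_dropdown = gr.Dropdown(AVAIL_LLM_MODELS, value="gpt-3.5-turbo", label="LLM model")
        status = gr.Markdown()
        # Switching the dropdown changes which model later requests are routed to.
        md_dropdown.change(lambda m: f"switched to {m}", inputs=md_dropdown, outputs=status)

    if __name__ == "__main__":
        demo.launch()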

request_llm/bridge_all.py

@@ -21,38 +21,42 @@ from .bridge_chatglm import predict as chatglm_ui
 from .bridge_tgui import predict_no_ui_long_connection as tgui_noui
 from .bridge_tgui import predict as tgui_ui

-methods = {
-    "openai-no-ui": chatgpt_noui,
-    "openai-ui": chatgpt_ui,
-    "chatglm-no-ui": chatglm_noui,
-    "chatglm-ui": chatglm_ui,
-    "tgui-no-ui": tgui_noui,
-    "tgui-ui": tgui_ui,
-}
+colors = ['#FF00FF', '#00FFFF', '#FF0000', '#990099', '#009999', '#990044']

 model_info = {
     # openai
     "gpt-3.5-turbo": {
+        "fn_with_ui": chatgpt_ui,
+        "fn_without_ui": chatgpt_noui,
+        "endpoint": "https://api.openai.com/v1/chat/completions",
         "max_token": 4096,
         "tokenizer": tiktoken.encoding_for_model("gpt-3.5-turbo"),
         "token_cnt": lambda txt: len(tiktoken.encoding_for_model("gpt-3.5-turbo").encode(txt, disallowed_special=())),
     },
     "gpt-4": {
+        "fn_with_ui": chatgpt_ui,
+        "fn_without_ui": chatgpt_noui,
+        "endpoint": "https://api.openai.com/v1/chat/completions",
         "max_token": 4096,
         "tokenizer": tiktoken.encoding_for_model("gpt-4"),
         "token_cnt": lambda txt: len(tiktoken.encoding_for_model("gpt-4").encode(txt, disallowed_special=())),
     },
     # api_2d
-    "gpt-3.5-turbo-api2d": {
+    "api2d-gpt-3.5-turbo": {
+        "fn_with_ui": chatgpt_ui,
+        "fn_without_ui": chatgpt_noui,
+        "endpoint": "https://openai.api2d.net/v1/chat/completions",
         "max_token": 4096,
         "tokenizer": tiktoken.encoding_for_model("gpt-3.5-turbo"),
         "token_cnt": lambda txt: len(tiktoken.encoding_for_model("gpt-3.5-turbo").encode(txt, disallowed_special=())),
     },
-    "gpt-4-api2d": {
+    "api2d-gpt-4": {
+        "fn_with_ui": chatgpt_ui,
+        "fn_without_ui": chatgpt_noui,
+        "endpoint": "https://openai.api2d.net/v1/chat/completions",
         "max_token": 4096,
         "tokenizer": tiktoken.encoding_for_model("gpt-4"),
         "token_cnt": lambda txt: len(tiktoken.encoding_for_model("gpt-4").encode(txt, disallowed_special=())),
@@ -60,18 +64,20 @@ model_info = {
     # chatglm
     "chatglm": {
+        "fn_with_ui": chatglm_ui,
+        "fn_without_ui": chatglm_noui,
+        "endpoint": None,
         "max_token": 1024,
         "tokenizer": tiktoken.encoding_for_model("gpt-3.5-turbo"),
         "token_cnt": lambda txt: len(tiktoken.encoding_for_model("gpt-3.5-turbo").encode(txt, disallowed_special=())),
     },
 }

 def LLM_CATCH_EXCEPTION(f):
     """
     Decorator: display any error raised inside f
     """
     def decorated(inputs, llm_kwargs, history, sys_prompt, observe_window, console_slience):
         try:
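The old methods dict keyed dispatch on ad-hoc strings like "openai-no-ui"; the commit folds the handler functions into model_info, so each model name carries its UI handler, its non-UI handler, its endpoint, and its token accounting in one place. A stripped-down sketch of the pattern (with hypothetical stand-in handlers, not the project's real bridge functions):

    from typing import Any, Callable, Dict

    def demo_ui(inputs, llm_kwargs): ...    # stand-in for e.g. chatgpt_ui
    def demo_noui(inputs, llm_kwargs): ...  # stand-in for e.g. chatgpt_noui

    model_info: Dict[str, Dict[str, Any]] = {
        "demo-model": {
            "fn_with_ui": demo_ui,
            "fn_without_ui": demo_noui,
            "endpoint": "https://example.com/v1/chat/completions",
            "max_token": 4096,
        },
    }

    def dispatch(model: str, with_ui: bool) -> Callable:
        # One dict lookup replaces the old if/elif chains; an unknown
        # model name fails fast with a KeyError.
        entry = model_info[model]
        return entry["fn_with_ui"] if with_ui else entry["fn_without_ui"]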
@@ -85,21 +91,20 @@ def LLM_CATCH_EXCEPTION(f):
             return tb_str
     return decorated

-colors = ['#FF00FF', '#00FFFF', '#FF0000', '#990099', '#009999', '#990044']

 def predict_no_ui_long_connection(inputs, llm_kwargs, history, sys_prompt, observe_window, console_slience=False):
     """
     Send to the LLM and wait for the complete reply in one go, without showing intermediate output.
     Internally it still streams, so that the connection is not killed halfway through.
     inputs:
         the query of this request
     sys_prompt:
         the silent system prompt
     llm_kwargs:
         internal tuning parameters of the LLM
     history:
         the list of previous conversation turns
     observe_window = None:
         used to pass the partial output across threads; most of the time it only exists for a fancy visual effect and can be left empty. observe_window[0]: observation window, observe_window[1]: watchdog
     """
     import threading, time, copy
@@ -109,12 +114,7 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history, sys_prompt, obser
         assert not model.startswith("tgui"), "TGUI不支持函数插件的实现"
         # If only one LLM is being queried:
-        if model.startswith('gpt'):
-            method = methods['openai-no-ui']
-        elif model == 'chatglm':
-            method = methods['chatglm-no-ui']
-        elif model.startswith('tgui'):
-            method = methods['tgui-no-ui']
+        method = model_info[model]["fn_without_ui"]
         return method(inputs, llm_kwargs, history, sys_prompt, observe_window, console_slience)
     else:
         # If several LLMs are being queried at once:
@@ -129,12 +129,7 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history, sys_prompt, obser
         futures = []
         for i in range(n_model):
             model = models[i]
-            if model.startswith('gpt'):
-                method = methods['openai-no-ui']
-            elif model == 'chatglm':
-                method = methods['chatglm-no-ui']
-            elif model.startswith('tgui'):
-                method = methods['tgui-no-ui']
+            method = model_info[model]["fn_without_ui"]
             llm_kwargs_feedin = copy.deepcopy(llm_kwargs)
             llm_kwargs_feedin['llm_model'] = model
             future = executor.submit(LLM_CATCH_EXCEPTION(method), inputs, llm_kwargs_feedin, history, sys_prompt, window_mutex[i], console_slience)
@@ -176,20 +171,15 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history, sys_prompt, obser
 def predict(inputs, llm_kwargs, *args, **kwargs):
     """
     Send to the LLM and fetch the output as a stream, for the basic conversation feature.
     inputs: the query of this request
     top_p, temperature: internal tuning parameters of the LLM
     history: the list of previous turns; note that if either inputs or history gets too long, a token-overflow error will be raised
     chatbot: the conversation list shown in the WebUI; modify it and then yield to update the displayed conversation directly
     additional_fn: which button was clicked (see functional.py)
     """
-    if llm_kwargs['llm_model'].startswith('gpt'):
-        method = methods['openai-ui']
-    elif llm_kwargs['llm_model'] == 'chatglm':
-        method = methods['chatglm-ui']
-    elif llm_kwargs['llm_model'].startswith('tgui'):
-        method = methods['tgui-ui']
+    method = model_info[llm_kwargs['llm_model']]["fn_with_ui"]
     yield from method(inputs, llm_kwargs, *args, **kwargs)
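With both predict paths reduced to a model_info lookup, supporting another model no longer means editing every dispatch site; registering one more entry is enough. A hedged example of what such a registration could look like (hypothetical model name and endpoint, reusing the OpenAI-style bridge functions already imported in this file):

    import tiktoken

    model_info["my-openai-compatible-model"] = {
        "fn_with_ui": chatgpt_ui,
        "fn_without_ui": chatgpt_noui,
        "endpoint": "https://example.com/v1/chat/completions",
        "max_token": 4096,
        "tokenizer": tiktoken.encoding_for_model("gpt-3.5-turbo"),
        "token_cnt": lambda txt: len(tiktoken.encoding_for_model("gpt-3.5-turbo").encode(txt, disallowed_special=())),
    }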

request_llm/bridge_chatgpt.py

@@ -21,9 +21,9 @@ import importlib
 # config_private.py holds your secrets, such as API keys and proxy URLs
 # When reading the config, a private config_private file (not tracked by git) is checked first; if present, it overrides the original config file
-from toolbox import get_conf, update_ui
-proxies, API_URL, API_KEY, TIMEOUT_SECONDS, MAX_RETRY = \
-    get_conf('proxies', 'API_URL', 'API_KEY', 'TIMEOUT_SECONDS', 'MAX_RETRY')
+from toolbox import get_conf, update_ui, is_any_api_key, select_api_key
+proxies, API_KEY, TIMEOUT_SECONDS, MAX_RETRY = \
+    get_conf('proxies', 'API_KEY', 'TIMEOUT_SECONDS', 'MAX_RETRY')

 timeout_bot_msg = '[Local Message] Request timeout. Network error. Please check proxy settings in config.py.' + \
                   '网络错误,检查代理服务器是否可用,以及代理设置的格式是否正确,格式须是[协议]://[地址]:[端口],缺一不可。'
@@ -60,7 +60,7 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
     while True:
         try:
             # make a POST request to the API endpoint, stream=False
-            response = requests.post(API_URL, headers=headers, proxies=proxies,
+            response = requests.post(llm_kwargs['endpoint'], headers=headers, proxies=proxies,
                                      json=payload, stream=True, timeout=TIMEOUT_SECONDS); break
         except requests.exceptions.ReadTimeout as e:
             retry += 1
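Because ArgsGeneralWrapper (see toolbox.py below) copies model_info[llm_model]['endpoint'] into llm_kwargs, the request code no longer needs a module-level API_URL: the same POST serves api.openai.com and openai.api2d.net. A minimal sketch of the call under that assumption (post_chat, headers, and payload are illustrative placeholders, not names from the project):

    import requests

    def post_chat(llm_kwargs, headers, payload, proxies=None, timeout=30):
        # The endpoint travels with the request, so different models can
        # point at different OpenAI-compatible servers.
        return requests.post(llm_kwargs['endpoint'], headers=headers, proxies=proxies,
                             json=payload, stream=True, timeout=timeout)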
@@ -113,14 +113,14 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
     chatbot: the conversation list shown in the WebUI; modify it and then yield to update the displayed conversation directly
     additional_fn: which button was clicked (see functional.py)
     """
-    if inputs.startswith('sk-') and len(inputs) == 51:
+    if is_any_api_key(inputs):
         chatbot._cookies['api_key'] = inputs
         chatbot.append(("输入已识别为openai的api_key", "api_key已导入"))
         yield from update_ui(chatbot=chatbot, history=history, msg="api_key已导入")  # refresh the UI
         return
-    elif len(chatbot._cookies['api_key']) != 51:
+    elif not is_any_api_key(chatbot._cookies['api_key']):
         chatbot.append((inputs, "缺少api_key。\n\n1. 临时解决方案直接在输入区键入api_key然后回车提交。\n\n2. 长效解决方案在config.py中配置。"))
-        yield from update_ui(chatbot=chatbot, history=history, msg="api_key已导入")  # refresh the UI
+        yield from update_ui(chatbot=chatbot, history=history, msg="缺少api_key")  # refresh the UI
         return

     if additional_fn is not None:
@@ -143,7 +143,7 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
     while True:
         try:
             # make a POST request to the API endpoint, stream=True
-            response = requests.post(API_URL, headers=headers, proxies=proxies,
+            response = requests.post(llm_kwargs['endpoint'], headers=headers, proxies=proxies,
                                      json=payload, stream=True, timeout=TIMEOUT_SECONDS);break
         except:
             retry += 1
@@ -202,12 +202,14 @@ def generate_payload(inputs, llm_kwargs, history, system_prompt, stream):
     """
     Gather all the information, select the LLM model, and build the HTTP request, preparing it to be sent.
     """
-    if len(llm_kwargs['api_key']) != 51:
+    if not is_any_api_key(llm_kwargs['api_key']):
         raise AssertionError("你提供了错误的API_KEY。\n\n1. 临时解决方案直接在输入区键入api_key然后回车提交。\n\n2. 长效解决方案在config.py中配置。")

+    api_key = select_api_key(llm_kwargs['api_key'], llm_kwargs['llm_model'])
+
     headers = {
         "Content-Type": "application/json",
-        "Authorization": f"Bearer {llm_kwargs['api_key']}"
+        "Authorization": f"Bearer {api_key}"
     }

     conversation_cnt = len(history) // 2
@@ -235,7 +237,7 @@ def generate_payload(inputs, llm_kwargs, history, system_prompt, stream):
     messages.append(what_i_ask_now)

     payload = {
-        "model": llm_kwargs['llm_model'],
+        "model": llm_kwargs['llm_model'].strip('api2d-'),
         "messages": messages,
         "temperature": llm_kwargs['temperature'],  # 1.0,
         "top_p": llm_kwargs['top_p'],  # 1.0,

toolbox.py

@@ -1,13 +1,10 @@
 import markdown
-import mdtex2html
-import threading
 import importlib
 import traceback
 import inspect
 import re
 from latex2mathml.converter import convert as tex2mathml
 from functools import wraps, lru_cache

 ############################### Plugin input/output interface section #######################################
 class ChatBotWithCookies(list):
     def __init__(self, cookie):
@@ -25,9 +22,10 @@ class ChatBotWithCookies(list):
 def ArgsGeneralWrapper(f):
     """
     Decorator that regroups the input arguments, changing their order and structure.
     """
     def decorated(cookies, max_length, llm_model, txt, txt2, top_p, temperature, chatbot, history, system_prompt, *args):
+        from request_llm.bridge_all import model_info
         txt_passon = txt
         if txt == "" and txt2 != "": txt_passon = txt2
         # Bring in a chatbot that carries cookies
@@ -38,6 +36,7 @@ def ArgsGeneralWrapper(f):
         llm_kwargs = {
             'api_key': cookies['api_key'],
             'llm_model': llm_model,
+            'endpoint': model_info[llm_model]['endpoint'],
             'top_p': top_p,
             'max_length': max_length,
             'temperature': temperature,
@@ -56,69 +55,10 @@ def update_ui(chatbot, history, msg='正常', **kwargs):  # refresh the UI
     """
     assert isinstance(chatbot, ChatBotWithCookies), "在传递chatbot的过程中不要将其丢弃。必要时可用clear将其清空然后用for+append循环重新赋值。"
     yield chatbot.get_cookies(), chatbot, history, msg

-##########################################################################################
-def get_reduce_token_percent(text):
-    """
-    * This function will be deprecated in the future
-    """
-    try:
-        # text = "maximum context length is 4097 tokens. However, your messages resulted in 4870 tokens"
-        pattern = r"(\d+)\s+tokens\b"
-        match = re.findall(pattern, text)
-        EXCEED_ALLO = 500  # leave a little headroom, otherwise the reply will fail because too little margin remains
-        max_limit = float(match[0]) - EXCEED_ALLO
-        current_tokens = float(match[1])
-        ratio = max_limit/current_tokens
-        assert ratio > 0 and ratio < 1
-        return ratio, str(int(current_tokens-max_limit))
-    except:
-        return 0.5, '不详'
-
-def write_results_to_file(history, file_name=None):
-    """
-    Write the conversation history to a file in Markdown format. If no file name is given, generate one from the current time.
-    """
-    import os
-    import time
-    if file_name is None:
-        # file_name = time.strftime("chatGPT分析报告%Y-%m-%d-%H-%M-%S", time.localtime()) + '.md'
-        file_name = 'chatGPT分析报告' + \
-            time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + '.md'
-    os.makedirs('./gpt_log/', exist_ok=True)
-    with open(f'./gpt_log/{file_name}', 'w', encoding='utf8') as f:
-        f.write('# chatGPT 分析报告\n')
-        for i, content in enumerate(history):
-            try:  # the trigger condition for this bug has not been found; patch it over for now
-                if type(content) != str:
-                    content = str(content)
-            except:
-                continue
-            if i % 2 == 0:
-                f.write('## ')
-            f.write(content)
-            f.write('\n\n')
-    res = '以上材料已经被写入' + os.path.abspath(f'./gpt_log/{file_name}')
-    print(res)
-    return res
-
-def regular_txt_to_markdown(text):
-    """
-    Convert plain text to Markdown-formatted text.
-    """
-    text = text.replace('\n', '\n\n')
-    text = text.replace('\n\n\n', '\n\n')
-    text = text.replace('\n\n\n', '\n\n')
-    return text
-
 def CatchException(f):
     """
     Decorator that catches exceptions raised inside f, wraps them into a generator, and displays them in the chat.
     """
     @wraps(f)
     def decorated(txt, top_p, temperature, chatbot, history, systemPromptTxt, WEB_PORT):
@@ -155,9 +95,70 @@ def HotReload(f):
     return decorated

+####################################### Misc utilities #####################################
+
+def get_reduce_token_percent(text):
+    """
+    * This function will be deprecated in the future
+    """
+    try:
+        # text = "maximum context length is 4097 tokens. However, your messages resulted in 4870 tokens"
+        pattern = r"(\d+)\s+tokens\b"
+        match = re.findall(pattern, text)
+        EXCEED_ALLO = 500  # leave a little headroom, otherwise the reply will fail because too little margin remains
+        max_limit = float(match[0]) - EXCEED_ALLO
+        current_tokens = float(match[1])
+        ratio = max_limit/current_tokens
+        assert ratio > 0 and ratio < 1
+        return ratio, str(int(current_tokens-max_limit))
+    except:
+        return 0.5, '不详'
+
+def write_results_to_file(history, file_name=None):
+    """
+    Write the conversation history to a file in Markdown format. If no file name is given, generate one from the current time.
+    """
+    import os
+    import time
+    if file_name is None:
+        # file_name = time.strftime("chatGPT分析报告%Y-%m-%d-%H-%M-%S", time.localtime()) + '.md'
+        file_name = 'chatGPT分析报告' + \
+            time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + '.md'
+    os.makedirs('./gpt_log/', exist_ok=True)
+    with open(f'./gpt_log/{file_name}', 'w', encoding='utf8') as f:
+        f.write('# chatGPT 分析报告\n')
+        for i, content in enumerate(history):
+            try:  # the trigger condition for this bug has not been found; patch it over for now
+                if type(content) != str:
+                    content = str(content)
+            except:
+                continue
+            if i % 2 == 0:
+                f.write('## ')
+            f.write(content)
+            f.write('\n\n')
+    res = '以上材料已经被写入' + os.path.abspath(f'./gpt_log/{file_name}')
+    print(res)
+    return res
+
+def regular_txt_to_markdown(text):
+    """
+    Convert plain text to Markdown-formatted text.
+    """
+    text = text.replace('\n', '\n\n')
+    text = text.replace('\n\n\n', '\n\n')
+    text = text.replace('\n\n\n', '\n\n')
+    return text
+
 def report_execption(chatbot, history, a, b):
     """
     Append error information to the chatbot.
     """
     chatbot.append((a, b))
     history.append(a)
@@ -166,7 +167,7 @@ def report_execption(chatbot, history, a, b):
 def text_divide_paragraph(text):
     """
     Split the text at paragraph separators and generate HTML with paragraph tags.
     """
     if '```' in text:
         # careful input
@@ -182,7 +183,7 @@ def text_divide_paragraph(text):
 def markdown_convertion(txt):
     """
     Convert Markdown text into HTML. If it contains math formulas, convert the formulas to HTML first.
     """
     pre = '<div class="markdown-body">'
     suf = '</div>'
@@ -274,7 +275,7 @@ def close_up_code_segment_during_stream(gpt_reply):
 def format_io(self, y):
     """
     Parse input and output into HTML. Paragraphize the input part of the last item in y, and convert the Markdown and math formulas in the output part to HTML.
     """
     if y is None or y == []:
         return []
@@ -290,7 +291,7 @@ def format_io(self, y):
 def find_free_port():
     """
     Return an unused port that is currently available on this system.
     """
     import socket
     from contextlib import closing
@@ -410,9 +411,43 @@ def on_report_generated(files, chatbot):
     return report_files, chatbot

 def is_openai_api_key(key):
-    # A valid API_KEY is "sk-" plus a 48-character mix of letters and digits
     API_MATCH = re.match(r"sk-[a-zA-Z0-9]{48}$", key)
-    return API_MATCH
+    return bool(API_MATCH)
+
+def is_api2d_key(key):
+    if key.startswith('fk') and len(key) == 41:
+        return True
+    else:
+        return False
+
+def is_any_api_key(key):
+    if ',' in key:
+        keys = key.split(',')
+        for k in keys:
+            if is_any_api_key(k): return True
+        return False
+    else:
+        return is_openai_api_key(key) or is_api2d_key(key)
+
+def select_api_key(keys, llm_model):
+    import random
+    avail_key_list = []
+    key_list = keys.split(',')
+
+    if llm_model.startswith('gpt-'):
+        for k in key_list:
+            if is_openai_api_key(k): avail_key_list.append(k)
+
+    if llm_model.startswith('api2d-'):
+        for k in key_list:
+            if is_api2d_key(k): avail_key_list.append(k)
+
+    if len(avail_key_list) == 0:
+        raise RuntimeError(f"您提供的api-key不满足要求不包含任何可用于{llm_model}的api-key。")
+
+    api_key = random.choice(avail_key_list)  # random load balancing across keys
+    return api_key

 @lru_cache(maxsize=128)
 def read_single_conf_with_lru_cache(arg):
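Together, the three new helpers let API_KEY hold several comma-separated keys of mixed providers: is_any_api_key validates the combined string (recursively, one key at a time), and select_api_key filters by model family and picks one at random as a simple load balancer. A usage sketch with placeholder keys of the right shape (not real credentials):

    openai_key = "sk-" + "A" * 48   # matches r"sk-[a-zA-Z0-9]{48}$"
    api2d_key = "fk" + "0" * 39     # starts with 'fk', 41 characters long

    assert is_openai_api_key(openai_key)
    assert is_api2d_key(api2d_key)
    assert is_any_api_key(openai_key + "," + api2d_key)
    assert select_api_key(openai_key + "," + api2d_key, "gpt-4") == openai_key
    assert select_api_key(openai_key + "," + api2d_key, "api2d-gpt-4") == api2d_key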
@@ -423,7 +458,7 @@ def read_single_conf_with_lru_cache(arg):
     r = getattr(importlib.import_module('config'), arg)
     # When reading API_KEY, check whether the user forgot to edit config
     if arg == 'API_KEY':
-        if is_openai_api_key(r):
+        if is_any_api_key(r):
             print亮绿(f"[API_KEY] 您的 API_KEY 是: {r[:15]}*** API_KEY 导入成功")
         else:
             print亮红( "[API_KEY] 正确的 API_KEY 是 'sk-' + '48 位大小写字母数字' 的组合请在config文件中修改API密钥, 添加海外代理之后再运行。" + \
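With read_single_conf_with_lru_cache now calling is_any_api_key, a config.py that holds keys for both providers at once passes the startup check. A hypothetical example value (placeholder keys, shown as an expression to keep their shape obvious):

    # config.py: one OpenAI key and one api2d key, comma-separated.
    API_KEY = "sk-" + "A" * 48 + "," + "fk" + "0" * 39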