更多模型切换
This commit is contained in:
		
							parent
							
								
									03ba072c16
								
							
						
					
					
						commit
						9bd8511ba4
					
				@ -46,14 +46,12 @@ WEB_PORT = -1
 | 
				
			|||||||
MAX_RETRY = 2
 | 
					MAX_RETRY = 2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# OpenAI模型选择是(gpt4现在只对申请成功的人开放)
 | 
					# OpenAI模型选择是(gpt4现在只对申请成功的人开放)
 | 
				
			||||||
LLM_MODEL = "gpt-3.5-turbo" # 可选 "chatglm", "tgui:anymodel@localhost:7865"
 | 
					LLM_MODEL = "gpt-3.5-turbo" # 可选 "chatglm"
 | 
				
			||||||
 | 
					AVAIL_LLM_MODELS = ["gpt-3.5-turbo", "chatglm", "gpt-4", "api2d-gpt-4", "api2d-gpt-3.5-turbo"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# 本地LLM模型如ChatGLM的执行方式 CPU/GPU
 | 
					# 本地LLM模型如ChatGLM的执行方式 CPU/GPU
 | 
				
			||||||
LOCAL_MODEL_DEVICE = "cpu" # 可选 "cuda"
 | 
					LOCAL_MODEL_DEVICE = "cpu" # 可选 "cuda"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# OpenAI的API_URL
 | 
					 | 
				
			||||||
API_URL = "https://api.openai.com/v1/chat/completions"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# 设置gradio的并行线程数(不需要修改)
 | 
					# 设置gradio的并行线程数(不需要修改)
 | 
				
			||||||
CONCURRENT_COUNT = 100
 | 
					CONCURRENT_COUNT = 100
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										6
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								main.py
									
									
									
									
									
								
							@ -5,8 +5,8 @@ def main():
 | 
				
			|||||||
    from request_llm.bridge_all import predict
 | 
					    from request_llm.bridge_all import predict
 | 
				
			||||||
    from toolbox import format_io, find_free_port, on_file_uploaded, on_report_generated, get_conf, ArgsGeneralWrapper, DummyWith
 | 
					    from toolbox import format_io, find_free_port, on_file_uploaded, on_report_generated, get_conf, ArgsGeneralWrapper, DummyWith
 | 
				
			||||||
    # 建议您复制一个config_private.py放自己的秘密, 如API和代理网址, 避免不小心传github被别人看到
 | 
					    # 建议您复制一个config_private.py放自己的秘密, 如API和代理网址, 避免不小心传github被别人看到
 | 
				
			||||||
    proxies, WEB_PORT, LLM_MODEL, CONCURRENT_COUNT, AUTHENTICATION, CHATBOT_HEIGHT, LAYOUT, API_KEY = \
 | 
					    proxies, WEB_PORT, LLM_MODEL, CONCURRENT_COUNT, AUTHENTICATION, CHATBOT_HEIGHT, LAYOUT, API_KEY, AVAIL_LLM_MODELS = \
 | 
				
			||||||
        get_conf('proxies', 'WEB_PORT', 'LLM_MODEL', 'CONCURRENT_COUNT', 'AUTHENTICATION', 'CHATBOT_HEIGHT', 'LAYOUT', 'API_KEY')
 | 
					        get_conf('proxies', 'WEB_PORT', 'LLM_MODEL', 'CONCURRENT_COUNT', 'AUTHENTICATION', 'CHATBOT_HEIGHT', 'LAYOUT', 'API_KEY', 'AVAIL_LLM_MODELS')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # 如果WEB_PORT是-1, 则随机选取WEB端口
 | 
					    # 如果WEB_PORT是-1, 则随机选取WEB端口
 | 
				
			||||||
    PORT = find_free_port() if WEB_PORT <= 0 else WEB_PORT
 | 
					    PORT = find_free_port() if WEB_PORT <= 0 else WEB_PORT
 | 
				
			||||||
@ -101,7 +101,7 @@ def main():
 | 
				
			|||||||
                    temperature = gr.Slider(minimum=-0, maximum=2.0, value=1.0, step=0.01, interactive=True, label="Temperature",)
 | 
					                    temperature = gr.Slider(minimum=-0, maximum=2.0, value=1.0, step=0.01, interactive=True, label="Temperature",)
 | 
				
			||||||
                    max_length_sl = gr.Slider(minimum=256, maximum=4096, value=512, step=1, interactive=True, label="MaxLength",)
 | 
					                    max_length_sl = gr.Slider(minimum=256, maximum=4096, value=512, step=1, interactive=True, label="MaxLength",)
 | 
				
			||||||
                    checkboxes = gr.CheckboxGroup(["基础功能区", "函数插件区", "底部输入区", "输入清除键"], value=["基础功能区", "函数插件区"], label="显示/隐藏功能区")
 | 
					                    checkboxes = gr.CheckboxGroup(["基础功能区", "函数插件区", "底部输入区", "输入清除键"], value=["基础功能区", "函数插件区"], label="显示/隐藏功能区")
 | 
				
			||||||
                    md_dropdown = gr.Dropdown(["gpt-3.5-turbo", "chatglm"], value=LLM_MODEL, label="").style(container=False)
 | 
					                    md_dropdown = gr.Dropdown(AVAIL_LLM_MODELS, value=LLM_MODEL, label="").style(container=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                    gr.Markdown(description)
 | 
					                    gr.Markdown(description)
 | 
				
			||||||
                with gr.Accordion("备选输入区", open=True, visible=False) as area_input_secondary:
 | 
					                with gr.Accordion("备选输入区", open=True, visible=False) as area_input_secondary:
 | 
				
			||||||
 | 
				
			|||||||
@ -21,38 +21,42 @@ from .bridge_chatglm import predict as chatglm_ui
 | 
				
			|||||||
from .bridge_tgui import predict_no_ui_long_connection as tgui_noui
 | 
					from .bridge_tgui import predict_no_ui_long_connection as tgui_noui
 | 
				
			||||||
from .bridge_tgui import predict as tgui_ui
 | 
					from .bridge_tgui import predict as tgui_ui
 | 
				
			||||||
 | 
					
 | 
				
			||||||
methods = {
 | 
					colors = ['#FF00FF', '#00FFFF', '#FF0000', '#990099', '#009999', '#990044']
 | 
				
			||||||
    "openai-no-ui": chatgpt_noui,
 | 
					 | 
				
			||||||
    "openai-ui": chatgpt_ui,
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    "chatglm-no-ui": chatglm_noui,
 | 
					 | 
				
			||||||
    "chatglm-ui": chatglm_ui,
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    "tgui-no-ui": tgui_noui,
 | 
					 | 
				
			||||||
    "tgui-ui": tgui_ui,
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
model_info = {
 | 
					model_info = {
 | 
				
			||||||
    # openai
 | 
					    # openai
 | 
				
			||||||
    "gpt-3.5-turbo": {
 | 
					    "gpt-3.5-turbo": {
 | 
				
			||||||
 | 
					        "fn_with_ui": chatgpt_ui,
 | 
				
			||||||
 | 
					        "fn_without_ui": chatgpt_noui,
 | 
				
			||||||
 | 
					        "endpoint": "https://api.openai.com/v1/chat/completions",
 | 
				
			||||||
        "max_token": 4096,
 | 
					        "max_token": 4096,
 | 
				
			||||||
        "tokenizer": tiktoken.encoding_for_model("gpt-3.5-turbo"),
 | 
					        "tokenizer": tiktoken.encoding_for_model("gpt-3.5-turbo"),
 | 
				
			||||||
        "token_cnt": lambda txt: len(tiktoken.encoding_for_model("gpt-3.5-turbo").encode(txt, disallowed_special=())),
 | 
					        "token_cnt": lambda txt: len(tiktoken.encoding_for_model("gpt-3.5-turbo").encode(txt, disallowed_special=())),
 | 
				
			||||||
    },
 | 
					    },
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    "gpt-4": {
 | 
					    "gpt-4": {
 | 
				
			||||||
 | 
					        "fn_with_ui": chatgpt_ui,
 | 
				
			||||||
 | 
					        "fn_without_ui": chatgpt_noui,
 | 
				
			||||||
 | 
					        "endpoint": "https://api.openai.com/v1/chat/completions",
 | 
				
			||||||
        "max_token": 4096,
 | 
					        "max_token": 4096,
 | 
				
			||||||
        "tokenizer": tiktoken.encoding_for_model("gpt-4"),
 | 
					        "tokenizer": tiktoken.encoding_for_model("gpt-4"),
 | 
				
			||||||
        "token_cnt": lambda txt: len(tiktoken.encoding_for_model("gpt-4").encode(txt, disallowed_special=())),
 | 
					        "token_cnt": lambda txt: len(tiktoken.encoding_for_model("gpt-4").encode(txt, disallowed_special=())),
 | 
				
			||||||
    },
 | 
					    },
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # api_2d
 | 
					    # api_2d
 | 
				
			||||||
    "gpt-3.5-turbo-api2d": {
 | 
					    "api2d-gpt-3.5-turbo": {
 | 
				
			||||||
 | 
					        "fn_with_ui": chatgpt_ui,
 | 
				
			||||||
 | 
					        "fn_without_ui": chatgpt_noui,
 | 
				
			||||||
 | 
					        "endpoint": "https://openai.api2d.net/v1/chat/completions",
 | 
				
			||||||
        "max_token": 4096,
 | 
					        "max_token": 4096,
 | 
				
			||||||
        "tokenizer": tiktoken.encoding_for_model("gpt-3.5-turbo"),
 | 
					        "tokenizer": tiktoken.encoding_for_model("gpt-3.5-turbo"),
 | 
				
			||||||
        "token_cnt": lambda txt: len(tiktoken.encoding_for_model("gpt-3.5-turbo").encode(txt, disallowed_special=())),
 | 
					        "token_cnt": lambda txt: len(tiktoken.encoding_for_model("gpt-3.5-turbo").encode(txt, disallowed_special=())),
 | 
				
			||||||
    },
 | 
					    },
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    "gpt-4-api2d": {
 | 
					    "api2d-gpt-4": {
 | 
				
			||||||
 | 
					        "fn_with_ui": chatgpt_ui,
 | 
				
			||||||
 | 
					        "fn_without_ui": chatgpt_noui,
 | 
				
			||||||
 | 
					        "endpoint": "https://openai.api2d.net/v1/chat/completions",
 | 
				
			||||||
        "max_token": 4096,
 | 
					        "max_token": 4096,
 | 
				
			||||||
        "tokenizer": tiktoken.encoding_for_model("gpt-4"),
 | 
					        "tokenizer": tiktoken.encoding_for_model("gpt-4"),
 | 
				
			||||||
        "token_cnt": lambda txt: len(tiktoken.encoding_for_model("gpt-4").encode(txt, disallowed_special=())),
 | 
					        "token_cnt": lambda txt: len(tiktoken.encoding_for_model("gpt-4").encode(txt, disallowed_special=())),
 | 
				
			||||||
@ -60,18 +64,20 @@ model_info = {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    # chatglm
 | 
					    # chatglm
 | 
				
			||||||
    "chatglm": {
 | 
					    "chatglm": {
 | 
				
			||||||
 | 
					        "fn_with_ui": chatglm_ui,
 | 
				
			||||||
 | 
					        "fn_without_ui": chatglm_noui,
 | 
				
			||||||
 | 
					        "endpoint": None,
 | 
				
			||||||
        "max_token": 1024,
 | 
					        "max_token": 1024,
 | 
				
			||||||
        "tokenizer": tiktoken.encoding_for_model("gpt-3.5-turbo"),
 | 
					        "tokenizer": tiktoken.encoding_for_model("gpt-3.5-turbo"),
 | 
				
			||||||
        "token_cnt": lambda txt: len(tiktoken.encoding_for_model("gpt-3.5-turbo").encode(txt, disallowed_special=())),
 | 
					        "token_cnt": lambda txt: len(tiktoken.encoding_for_model("gpt-3.5-turbo").encode(txt, disallowed_special=())),
 | 
				
			||||||
    },
 | 
					    },
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def LLM_CATCH_EXCEPTION(f):
 | 
					def LLM_CATCH_EXCEPTION(f):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
        装饰器函数,将错误显示出来
 | 
					    装饰器函数,将错误显示出来
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    def decorated(inputs, llm_kwargs, history, sys_prompt, observe_window, console_slience):
 | 
					    def decorated(inputs, llm_kwargs, history, sys_prompt, observe_window, console_slience):
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
@ -85,21 +91,20 @@ def LLM_CATCH_EXCEPTION(f):
 | 
				
			|||||||
            return tb_str
 | 
					            return tb_str
 | 
				
			||||||
    return decorated
 | 
					    return decorated
 | 
				
			||||||
 | 
					
 | 
				
			||||||
colors = ['#FF00FF', '#00FFFF', '#FF0000', '#990099', '#009999', '#990044']
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
def predict_no_ui_long_connection(inputs, llm_kwargs, history, sys_prompt, observe_window, console_slience=False):
 | 
					def predict_no_ui_long_connection(inputs, llm_kwargs, history, sys_prompt, observe_window, console_slience=False):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
        发送至LLM,等待回复,一次性完成,不显示中间过程。但内部用stream的方法避免中途网线被掐。
 | 
					    发送至LLM,等待回复,一次性完成,不显示中间过程。但内部用stream的方法避免中途网线被掐。
 | 
				
			||||||
        inputs:
 | 
					    inputs:
 | 
				
			||||||
            是本次问询的输入
 | 
					        是本次问询的输入
 | 
				
			||||||
        sys_prompt:
 | 
					    sys_prompt:
 | 
				
			||||||
            系统静默prompt
 | 
					        系统静默prompt
 | 
				
			||||||
        llm_kwargs:
 | 
					    llm_kwargs:
 | 
				
			||||||
            LLM的内部调优参数
 | 
					        LLM的内部调优参数
 | 
				
			||||||
        history:
 | 
					    history:
 | 
				
			||||||
            是之前的对话列表
 | 
					        是之前的对话列表
 | 
				
			||||||
        observe_window = None:
 | 
					    observe_window = None:
 | 
				
			||||||
            用于负责跨越线程传递已经输出的部分,大部分时候仅仅为了fancy的视觉效果,留空即可。observe_window[0]:观测窗。observe_window[1]:看门狗
 | 
					        用于负责跨越线程传递已经输出的部分,大部分时候仅仅为了fancy的视觉效果,留空即可。observe_window[0]:观测窗。observe_window[1]:看门狗
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    import threading, time, copy
 | 
					    import threading, time, copy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -109,12 +114,7 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history, sys_prompt, obser
 | 
				
			|||||||
        assert not model.startswith("tgui"), "TGUI不支持函数插件的实现"
 | 
					        assert not model.startswith("tgui"), "TGUI不支持函数插件的实现"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # 如果只询问1个大语言模型:
 | 
					        # 如果只询问1个大语言模型:
 | 
				
			||||||
        if model.startswith('gpt'):
 | 
					        method = model_info[model]["fn_without_ui"]
 | 
				
			||||||
            method = methods['openai-no-ui']
 | 
					 | 
				
			||||||
        elif model == 'chatglm':
 | 
					 | 
				
			||||||
            method = methods['chatglm-no-ui']
 | 
					 | 
				
			||||||
        elif model.startswith('tgui'):
 | 
					 | 
				
			||||||
            method = methods['tgui-no-ui']
 | 
					 | 
				
			||||||
        return method(inputs, llm_kwargs, history, sys_prompt, observe_window, console_slience)
 | 
					        return method(inputs, llm_kwargs, history, sys_prompt, observe_window, console_slience)
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        # 如果同时询问多个大语言模型:
 | 
					        # 如果同时询问多个大语言模型:
 | 
				
			||||||
@ -129,12 +129,7 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history, sys_prompt, obser
 | 
				
			|||||||
        futures = []
 | 
					        futures = []
 | 
				
			||||||
        for i in range(n_model):
 | 
					        for i in range(n_model):
 | 
				
			||||||
            model = models[i]
 | 
					            model = models[i]
 | 
				
			||||||
            if model.startswith('gpt'):
 | 
					            method = model_info[model]["fn_without_ui"]
 | 
				
			||||||
                method = methods['openai-no-ui']
 | 
					 | 
				
			||||||
            elif model == 'chatglm':
 | 
					 | 
				
			||||||
                method = methods['chatglm-no-ui']
 | 
					 | 
				
			||||||
            elif model.startswith('tgui'):
 | 
					 | 
				
			||||||
                method = methods['tgui-no-ui']
 | 
					 | 
				
			||||||
            llm_kwargs_feedin = copy.deepcopy(llm_kwargs)
 | 
					            llm_kwargs_feedin = copy.deepcopy(llm_kwargs)
 | 
				
			||||||
            llm_kwargs_feedin['llm_model'] = model
 | 
					            llm_kwargs_feedin['llm_model'] = model
 | 
				
			||||||
            future = executor.submit(LLM_CATCH_EXCEPTION(method), inputs, llm_kwargs_feedin, history, sys_prompt, window_mutex[i], console_slience)
 | 
					            future = executor.submit(LLM_CATCH_EXCEPTION(method), inputs, llm_kwargs_feedin, history, sys_prompt, window_mutex[i], console_slience)
 | 
				
			||||||
@ -176,20 +171,15 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history, sys_prompt, obser
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
def predict(inputs, llm_kwargs, *args, **kwargs):
 | 
					def predict(inputs, llm_kwargs, *args, **kwargs):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
        发送至LLM,流式获取输出。
 | 
					    发送至LLM,流式获取输出。
 | 
				
			||||||
        用于基础的对话功能。
 | 
					    用于基础的对话功能。
 | 
				
			||||||
        inputs 是本次问询的输入
 | 
					    inputs 是本次问询的输入
 | 
				
			||||||
        top_p, temperature是LLM的内部调优参数
 | 
					    top_p, temperature是LLM的内部调优参数
 | 
				
			||||||
        history 是之前的对话列表(注意无论是inputs还是history,内容太长了都会触发token数量溢出的错误)
 | 
					    history 是之前的对话列表(注意无论是inputs还是history,内容太长了都会触发token数量溢出的错误)
 | 
				
			||||||
        chatbot 为WebUI中显示的对话列表,修改它,然后yeild出去,可以直接修改对话界面内容
 | 
					    chatbot 为WebUI中显示的对话列表,修改它,然后yeild出去,可以直接修改对话界面内容
 | 
				
			||||||
        additional_fn代表点击的哪个按钮,按钮见functional.py
 | 
					    additional_fn代表点击的哪个按钮,按钮见functional.py
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    if llm_kwargs['llm_model'].startswith('gpt'):
 | 
					 | 
				
			||||||
        method = methods['openai-ui']
 | 
					 | 
				
			||||||
    elif llm_kwargs['llm_model'] == 'chatglm':
 | 
					 | 
				
			||||||
        method = methods['chatglm-ui']
 | 
					 | 
				
			||||||
    elif llm_kwargs['llm_model'].startswith('tgui'):
 | 
					 | 
				
			||||||
        method = methods['tgui-ui']
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    method = model_info[llm_kwargs['llm_model']]["fn_with_ui"]
 | 
				
			||||||
    yield from method(inputs, llm_kwargs, *args, **kwargs)
 | 
					    yield from method(inputs, llm_kwargs, *args, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -21,9 +21,9 @@ import importlib
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
# config_private.py放自己的秘密如API和代理网址
 | 
					# config_private.py放自己的秘密如API和代理网址
 | 
				
			||||||
# 读取时首先看是否存在私密的config_private配置文件(不受git管控),如果有,则覆盖原config文件
 | 
					# 读取时首先看是否存在私密的config_private配置文件(不受git管控),如果有,则覆盖原config文件
 | 
				
			||||||
from toolbox import get_conf, update_ui
 | 
					from toolbox import get_conf, update_ui, is_any_api_key, select_api_key
 | 
				
			||||||
proxies, API_URL, API_KEY, TIMEOUT_SECONDS, MAX_RETRY = \
 | 
					proxies, API_KEY, TIMEOUT_SECONDS, MAX_RETRY = \
 | 
				
			||||||
    get_conf('proxies', 'API_URL', 'API_KEY', 'TIMEOUT_SECONDS', 'MAX_RETRY')
 | 
					    get_conf('proxies', 'API_KEY', 'TIMEOUT_SECONDS', 'MAX_RETRY')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
timeout_bot_msg = '[Local Message] Request timeout. Network error. Please check proxy settings in config.py.' + \
 | 
					timeout_bot_msg = '[Local Message] Request timeout. Network error. Please check proxy settings in config.py.' + \
 | 
				
			||||||
                  '网络错误,检查代理服务器是否可用,以及代理设置的格式是否正确,格式须是[协议]://[地址]:[端口],缺一不可。'
 | 
					                  '网络错误,检查代理服务器是否可用,以及代理设置的格式是否正确,格式须是[协议]://[地址]:[端口],缺一不可。'
 | 
				
			||||||
@ -60,7 +60,7 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
 | 
				
			|||||||
    while True:
 | 
					    while True:
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            # make a POST request to the API endpoint, stream=False
 | 
					            # make a POST request to the API endpoint, stream=False
 | 
				
			||||||
            response = requests.post(API_URL, headers=headers, proxies=proxies,
 | 
					            response = requests.post(llm_kwargs['endpoint'], headers=headers, proxies=proxies,
 | 
				
			||||||
                                    json=payload, stream=True, timeout=TIMEOUT_SECONDS); break
 | 
					                                    json=payload, stream=True, timeout=TIMEOUT_SECONDS); break
 | 
				
			||||||
        except requests.exceptions.ReadTimeout as e:
 | 
					        except requests.exceptions.ReadTimeout as e:
 | 
				
			||||||
            retry += 1
 | 
					            retry += 1
 | 
				
			||||||
@ -113,14 +113,14 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
 | 
				
			|||||||
        chatbot 为WebUI中显示的对话列表,修改它,然后yeild出去,可以直接修改对话界面内容
 | 
					        chatbot 为WebUI中显示的对话列表,修改它,然后yeild出去,可以直接修改对话界面内容
 | 
				
			||||||
        additional_fn代表点击的哪个按钮,按钮见functional.py
 | 
					        additional_fn代表点击的哪个按钮,按钮见functional.py
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    if inputs.startswith('sk-') and len(inputs) == 51:
 | 
					    if is_any_api_key(inputs):
 | 
				
			||||||
        chatbot._cookies['api_key'] = inputs
 | 
					        chatbot._cookies['api_key'] = inputs
 | 
				
			||||||
        chatbot.append(("输入已识别为openai的api_key", "api_key已导入"))
 | 
					        chatbot.append(("输入已识别为openai的api_key", "api_key已导入"))
 | 
				
			||||||
        yield from update_ui(chatbot=chatbot, history=history, msg="api_key已导入") # 刷新界面
 | 
					        yield from update_ui(chatbot=chatbot, history=history, msg="api_key已导入") # 刷新界面
 | 
				
			||||||
        return
 | 
					        return
 | 
				
			||||||
    elif len(chatbot._cookies['api_key']) != 51:
 | 
					    elif not is_any_api_key(chatbot._cookies['api_key']):
 | 
				
			||||||
        chatbot.append((inputs, "缺少api_key。\n\n1. 临时解决方案:直接在输入区键入api_key,然后回车提交。\n\n2. 长效解决方案:在config.py中配置。"))
 | 
					        chatbot.append((inputs, "缺少api_key。\n\n1. 临时解决方案:直接在输入区键入api_key,然后回车提交。\n\n2. 长效解决方案:在config.py中配置。"))
 | 
				
			||||||
        yield from update_ui(chatbot=chatbot, history=history, msg="api_key已导入") # 刷新界面
 | 
					        yield from update_ui(chatbot=chatbot, history=history, msg="缺少api_key") # 刷新界面
 | 
				
			||||||
        return
 | 
					        return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if additional_fn is not None:
 | 
					    if additional_fn is not None:
 | 
				
			||||||
@ -143,7 +143,7 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
 | 
				
			|||||||
    while True:
 | 
					    while True:
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            # make a POST request to the API endpoint, stream=True
 | 
					            # make a POST request to the API endpoint, stream=True
 | 
				
			||||||
            response = requests.post(API_URL, headers=headers, proxies=proxies,
 | 
					            response = requests.post(llm_kwargs['endpoint'], headers=headers, proxies=proxies,
 | 
				
			||||||
                                    json=payload, stream=True, timeout=TIMEOUT_SECONDS);break
 | 
					                                    json=payload, stream=True, timeout=TIMEOUT_SECONDS);break
 | 
				
			||||||
        except:
 | 
					        except:
 | 
				
			||||||
            retry += 1
 | 
					            retry += 1
 | 
				
			||||||
@ -202,12 +202,14 @@ def generate_payload(inputs, llm_kwargs, history, system_prompt, stream):
 | 
				
			|||||||
    """
 | 
					    """
 | 
				
			||||||
        整合所有信息,选择LLM模型,生成http请求,为发送请求做准备
 | 
					        整合所有信息,选择LLM模型,生成http请求,为发送请求做准备
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    if len(llm_kwargs['api_key']) != 51:
 | 
					    if not is_any_api_key(llm_kwargs['api_key']):
 | 
				
			||||||
        raise AssertionError("你提供了错误的API_KEY。\n\n1. 临时解决方案:直接在输入区键入api_key,然后回车提交。\n\n2. 长效解决方案:在config.py中配置。")
 | 
					        raise AssertionError("你提供了错误的API_KEY。\n\n1. 临时解决方案:直接在输入区键入api_key,然后回车提交。\n\n2. 长效解决方案:在config.py中配置。")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    api_key = select_api_key(llm_kwargs['api_key'], llm_kwargs['llm_model'])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    headers = {
 | 
					    headers = {
 | 
				
			||||||
        "Content-Type": "application/json",
 | 
					        "Content-Type": "application/json",
 | 
				
			||||||
        "Authorization": f"Bearer {llm_kwargs['api_key']}"
 | 
					        "Authorization": f"Bearer {api_key}"
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    conversation_cnt = len(history) // 2
 | 
					    conversation_cnt = len(history) // 2
 | 
				
			||||||
@ -235,7 +237,7 @@ def generate_payload(inputs, llm_kwargs, history, system_prompt, stream):
 | 
				
			|||||||
    messages.append(what_i_ask_now)
 | 
					    messages.append(what_i_ask_now)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    payload = {
 | 
					    payload = {
 | 
				
			||||||
        "model": llm_kwargs['llm_model'],
 | 
					        "model": llm_kwargs['llm_model'].strip('api2d-'),
 | 
				
			||||||
        "messages": messages, 
 | 
					        "messages": messages, 
 | 
				
			||||||
        "temperature": llm_kwargs['temperature'],  # 1.0,
 | 
					        "temperature": llm_kwargs['temperature'],  # 1.0,
 | 
				
			||||||
        "top_p": llm_kwargs['top_p'],  # 1.0,
 | 
					        "top_p": llm_kwargs['top_p'],  # 1.0,
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										179
									
								
								toolbox.py
									
									
									
									
									
								
							
							
						
						
									
										179
									
								
								toolbox.py
									
									
									
									
									
								
							@ -1,13 +1,10 @@
 | 
				
			|||||||
import markdown
 | 
					import markdown
 | 
				
			||||||
import mdtex2html
 | 
					 | 
				
			||||||
import threading
 | 
					 | 
				
			||||||
import importlib
 | 
					import importlib
 | 
				
			||||||
import traceback
 | 
					import traceback
 | 
				
			||||||
import inspect
 | 
					import inspect
 | 
				
			||||||
import re
 | 
					import re
 | 
				
			||||||
from latex2mathml.converter import convert as tex2mathml
 | 
					from latex2mathml.converter import convert as tex2mathml
 | 
				
			||||||
from functools import wraps, lru_cache
 | 
					from functools import wraps, lru_cache
 | 
				
			||||||
 | 
					 | 
				
			||||||
############################### 插件输入输出接驳区 #######################################
 | 
					############################### 插件输入输出接驳区 #######################################
 | 
				
			||||||
class ChatBotWithCookies(list):
 | 
					class ChatBotWithCookies(list):
 | 
				
			||||||
    def __init__(self, cookie):
 | 
					    def __init__(self, cookie):
 | 
				
			||||||
@ -25,9 +22,10 @@ class ChatBotWithCookies(list):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
def ArgsGeneralWrapper(f):
 | 
					def ArgsGeneralWrapper(f):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
        装饰器函数,用于重组输入参数,改变输入参数的顺序与结构。
 | 
					    装饰器函数,用于重组输入参数,改变输入参数的顺序与结构。
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    def decorated(cookies, max_length, llm_model, txt, txt2, top_p, temperature, chatbot, history, system_prompt, *args):
 | 
					    def decorated(cookies, max_length, llm_model, txt, txt2, top_p, temperature, chatbot, history, system_prompt, *args):
 | 
				
			||||||
 | 
					        from request_llm.bridge_all import model_info
 | 
				
			||||||
        txt_passon = txt
 | 
					        txt_passon = txt
 | 
				
			||||||
        if txt == "" and txt2 != "": txt_passon = txt2
 | 
					        if txt == "" and txt2 != "": txt_passon = txt2
 | 
				
			||||||
        # 引入一个有cookie的chatbot
 | 
					        # 引入一个有cookie的chatbot
 | 
				
			||||||
@ -38,6 +36,7 @@ def ArgsGeneralWrapper(f):
 | 
				
			|||||||
        llm_kwargs = {
 | 
					        llm_kwargs = {
 | 
				
			||||||
            'api_key': cookies['api_key'],
 | 
					            'api_key': cookies['api_key'],
 | 
				
			||||||
            'llm_model': llm_model,
 | 
					            'llm_model': llm_model,
 | 
				
			||||||
 | 
					            'endpoint': model_info[llm_model]['endpoint'],
 | 
				
			||||||
            'top_p':top_p, 
 | 
					            'top_p':top_p, 
 | 
				
			||||||
            'max_length': max_length,
 | 
					            'max_length': max_length,
 | 
				
			||||||
            'temperature':temperature,
 | 
					            'temperature':temperature,
 | 
				
			||||||
@ -56,69 +55,10 @@ def update_ui(chatbot, history, msg='正常', **kwargs):  # 刷新界面
 | 
				
			|||||||
    """
 | 
					    """
 | 
				
			||||||
    assert isinstance(chatbot, ChatBotWithCookies), "在传递chatbot的过程中不要将其丢弃。必要时,可用clear将其清空,然后用for+append循环重新赋值。"
 | 
					    assert isinstance(chatbot, ChatBotWithCookies), "在传递chatbot的过程中不要将其丢弃。必要时,可用clear将其清空,然后用for+append循环重新赋值。"
 | 
				
			||||||
    yield chatbot.get_cookies(), chatbot, history, msg
 | 
					    yield chatbot.get_cookies(), chatbot, history, msg
 | 
				
			||||||
############################### ################## #######################################
 | 
					 | 
				
			||||||
##########################################################################################
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def get_reduce_token_percent(text):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
        * 此函数未来将被弃用
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    try:
 | 
					 | 
				
			||||||
        # text = "maximum context length is 4097 tokens. However, your messages resulted in 4870 tokens"
 | 
					 | 
				
			||||||
        pattern = r"(\d+)\s+tokens\b"
 | 
					 | 
				
			||||||
        match = re.findall(pattern, text)
 | 
					 | 
				
			||||||
        EXCEED_ALLO = 500  # 稍微留一点余地,否则在回复时会因余量太少出问题
 | 
					 | 
				
			||||||
        max_limit = float(match[0]) - EXCEED_ALLO
 | 
					 | 
				
			||||||
        current_tokens = float(match[1])
 | 
					 | 
				
			||||||
        ratio = max_limit/current_tokens
 | 
					 | 
				
			||||||
        assert ratio > 0 and ratio < 1
 | 
					 | 
				
			||||||
        return ratio, str(int(current_tokens-max_limit))
 | 
					 | 
				
			||||||
    except:
 | 
					 | 
				
			||||||
        return 0.5, '不详'
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def write_results_to_file(history, file_name=None):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
        将对话记录history以Markdown格式写入文件中。如果没有指定文件名,则使用当前时间生成文件名。
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    import os
 | 
					 | 
				
			||||||
    import time
 | 
					 | 
				
			||||||
    if file_name is None:
 | 
					 | 
				
			||||||
        # file_name = time.strftime("chatGPT分析报告%Y-%m-%d-%H-%M-%S", time.localtime()) + '.md'
 | 
					 | 
				
			||||||
        file_name = 'chatGPT分析报告' + \
 | 
					 | 
				
			||||||
            time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + '.md'
 | 
					 | 
				
			||||||
    os.makedirs('./gpt_log/', exist_ok=True)
 | 
					 | 
				
			||||||
    with open(f'./gpt_log/{file_name}', 'w', encoding='utf8') as f:
 | 
					 | 
				
			||||||
        f.write('# chatGPT 分析报告\n')
 | 
					 | 
				
			||||||
        for i, content in enumerate(history):
 | 
					 | 
				
			||||||
            try:    # 这个bug没找到触发条件,暂时先这样顶一下
 | 
					 | 
				
			||||||
                if type(content) != str:
 | 
					 | 
				
			||||||
                    content = str(content)
 | 
					 | 
				
			||||||
            except:
 | 
					 | 
				
			||||||
                continue
 | 
					 | 
				
			||||||
            if i % 2 == 0:
 | 
					 | 
				
			||||||
                f.write('## ')
 | 
					 | 
				
			||||||
            f.write(content)
 | 
					 | 
				
			||||||
            f.write('\n\n')
 | 
					 | 
				
			||||||
    res = '以上材料已经被写入' + os.path.abspath(f'./gpt_log/{file_name}')
 | 
					 | 
				
			||||||
    print(res)
 | 
					 | 
				
			||||||
    return res
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def regular_txt_to_markdown(text):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
        将普通文本转换为Markdown格式的文本。
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    text = text.replace('\n', '\n\n')
 | 
					 | 
				
			||||||
    text = text.replace('\n\n\n', '\n\n')
 | 
					 | 
				
			||||||
    text = text.replace('\n\n\n', '\n\n')
 | 
					 | 
				
			||||||
    return text
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
def CatchException(f):
 | 
					def CatchException(f):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
        装饰器函数,捕捉函数f中的异常并封装到一个生成器中返回,并显示到聊天当中。
 | 
					    装饰器函数,捕捉函数f中的异常并封装到一个生成器中返回,并显示到聊天当中。
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    @wraps(f)
 | 
					    @wraps(f)
 | 
				
			||||||
    def decorated(txt, top_p, temperature, chatbot, history, systemPromptTxt, WEB_PORT):
 | 
					    def decorated(txt, top_p, temperature, chatbot, history, systemPromptTxt, WEB_PORT):
 | 
				
			||||||
@ -155,9 +95,70 @@ def HotReload(f):
 | 
				
			|||||||
    return decorated
 | 
					    return decorated
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					####################################### 其他小工具 #####################################
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_reduce_token_percent(text):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					        * 此函数未来将被弃用
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        # text = "maximum context length is 4097 tokens. However, your messages resulted in 4870 tokens"
 | 
				
			||||||
 | 
					        pattern = r"(\d+)\s+tokens\b"
 | 
				
			||||||
 | 
					        match = re.findall(pattern, text)
 | 
				
			||||||
 | 
					        EXCEED_ALLO = 500  # 稍微留一点余地,否则在回复时会因余量太少出问题
 | 
				
			||||||
 | 
					        max_limit = float(match[0]) - EXCEED_ALLO
 | 
				
			||||||
 | 
					        current_tokens = float(match[1])
 | 
				
			||||||
 | 
					        ratio = max_limit/current_tokens
 | 
				
			||||||
 | 
					        assert ratio > 0 and ratio < 1
 | 
				
			||||||
 | 
					        return ratio, str(int(current_tokens-max_limit))
 | 
				
			||||||
 | 
					    except:
 | 
				
			||||||
 | 
					        return 0.5, '不详'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def write_results_to_file(history, file_name=None):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    将对话记录history以Markdown格式写入文件中。如果没有指定文件名,则使用当前时间生成文件名。
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    import os
 | 
				
			||||||
 | 
					    import time
 | 
				
			||||||
 | 
					    if file_name is None:
 | 
				
			||||||
 | 
					        # file_name = time.strftime("chatGPT分析报告%Y-%m-%d-%H-%M-%S", time.localtime()) + '.md'
 | 
				
			||||||
 | 
					        file_name = 'chatGPT分析报告' + \
 | 
				
			||||||
 | 
					            time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + '.md'
 | 
				
			||||||
 | 
					    os.makedirs('./gpt_log/', exist_ok=True)
 | 
				
			||||||
 | 
					    with open(f'./gpt_log/{file_name}', 'w', encoding='utf8') as f:
 | 
				
			||||||
 | 
					        f.write('# chatGPT 分析报告\n')
 | 
				
			||||||
 | 
					        for i, content in enumerate(history):
 | 
				
			||||||
 | 
					            try:    # 这个bug没找到触发条件,暂时先这样顶一下
 | 
				
			||||||
 | 
					                if type(content) != str:
 | 
				
			||||||
 | 
					                    content = str(content)
 | 
				
			||||||
 | 
					            except:
 | 
				
			||||||
 | 
					                continue
 | 
				
			||||||
 | 
					            if i % 2 == 0:
 | 
				
			||||||
 | 
					                f.write('## ')
 | 
				
			||||||
 | 
					            f.write(content)
 | 
				
			||||||
 | 
					            f.write('\n\n')
 | 
				
			||||||
 | 
					    res = '以上材料已经被写入' + os.path.abspath(f'./gpt_log/{file_name}')
 | 
				
			||||||
 | 
					    print(res)
 | 
				
			||||||
 | 
					    return res
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def regular_txt_to_markdown(text):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    将普通文本转换为Markdown格式的文本。
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    text = text.replace('\n', '\n\n')
 | 
				
			||||||
 | 
					    text = text.replace('\n\n\n', '\n\n')
 | 
				
			||||||
 | 
					    text = text.replace('\n\n\n', '\n\n')
 | 
				
			||||||
 | 
					    return text
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def report_execption(chatbot, history, a, b):
 | 
					def report_execption(chatbot, history, a, b):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
        向chatbot中添加错误信息
 | 
					    向chatbot中添加错误信息
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    chatbot.append((a, b))
 | 
					    chatbot.append((a, b))
 | 
				
			||||||
    history.append(a)
 | 
					    history.append(a)
 | 
				
			||||||
@ -166,7 +167,7 @@ def report_execption(chatbot, history, a, b):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
def text_divide_paragraph(text):
 | 
					def text_divide_paragraph(text):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
        将文本按照段落分隔符分割开,生成带有段落标签的HTML代码。
 | 
					    将文本按照段落分隔符分割开,生成带有段落标签的HTML代码。
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    if '```' in text:
 | 
					    if '```' in text:
 | 
				
			||||||
        # careful input
 | 
					        # careful input
 | 
				
			||||||
@ -182,7 +183,7 @@ def text_divide_paragraph(text):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
def markdown_convertion(txt):
 | 
					def markdown_convertion(txt):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
        将Markdown格式的文本转换为HTML格式。如果包含数学公式,则先将公式转换为HTML格式。
 | 
					    将Markdown格式的文本转换为HTML格式。如果包含数学公式,则先将公式转换为HTML格式。
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    pre = '<div class="markdown-body">'
 | 
					    pre = '<div class="markdown-body">'
 | 
				
			||||||
    suf = '</div>'
 | 
					    suf = '</div>'
 | 
				
			||||||
@ -274,7 +275,7 @@ def close_up_code_segment_during_stream(gpt_reply):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
def format_io(self, y):
 | 
					def format_io(self, y):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
        将输入和输出解析为HTML格式。将y中最后一项的输入部分段落化,并将输出部分的Markdown和数学公式转换为HTML格式。
 | 
					    将输入和输出解析为HTML格式。将y中最后一项的输入部分段落化,并将输出部分的Markdown和数学公式转换为HTML格式。
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    if y is None or y == []:
 | 
					    if y is None or y == []:
 | 
				
			||||||
        return []
 | 
					        return []
 | 
				
			||||||
@ -290,7 +291,7 @@ def format_io(self, y):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
def find_free_port():
 | 
					def find_free_port():
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
        返回当前系统中可用的未使用端口。
 | 
					    返回当前系统中可用的未使用端口。
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    import socket
 | 
					    import socket
 | 
				
			||||||
    from contextlib import closing
 | 
					    from contextlib import closing
 | 
				
			||||||
@ -410,9 +411,43 @@ def on_report_generated(files, chatbot):
 | 
				
			|||||||
    return report_files, chatbot
 | 
					    return report_files, chatbot
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def is_openai_api_key(key):
 | 
					def is_openai_api_key(key):
 | 
				
			||||||
    # 正确的 API_KEY 是 "sk-" + 48 位大小写字母数字的组合
 | 
					 | 
				
			||||||
    API_MATCH = re.match(r"sk-[a-zA-Z0-9]{48}$", key)
 | 
					    API_MATCH = re.match(r"sk-[a-zA-Z0-9]{48}$", key)
 | 
				
			||||||
    return API_MATCH
 | 
					    return bool(API_MATCH)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def is_api2d_key(key):
 | 
				
			||||||
 | 
					    if key.startswith('fk') and len(key) == 41:
 | 
				
			||||||
 | 
					        return True
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        return False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def is_any_api_key(key):
 | 
				
			||||||
 | 
					    if ',' in key:
 | 
				
			||||||
 | 
					        keys = key.split(',')
 | 
				
			||||||
 | 
					        for k in keys:
 | 
				
			||||||
 | 
					            if is_any_api_key(k): return True
 | 
				
			||||||
 | 
					        return False
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        return is_openai_api_key(key) or is_api2d_key(key)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def select_api_key(keys, llm_model):
 | 
				
			||||||
 | 
					    import random
 | 
				
			||||||
 | 
					    avail_key_list = []
 | 
				
			||||||
 | 
					    key_list = keys.split(',')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if llm_model.startswith('gpt-'):
 | 
				
			||||||
 | 
					        for k in key_list:
 | 
				
			||||||
 | 
					            if is_openai_api_key(k): avail_key_list.append(k)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if llm_model.startswith('api2d-'):
 | 
				
			||||||
 | 
					        for k in key_list:
 | 
				
			||||||
 | 
					            if is_api2d_key(k): avail_key_list.append(k)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if len(avail_key_list) == 0:
 | 
				
			||||||
 | 
					        raise RuntimeError(f"您提供的api-key不满足要求,不包含任何可用于{llm_model}的api-key。")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    api_key = random.choice(avail_key_list) # 随机负载均衡
 | 
				
			||||||
 | 
					    return api_key
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@lru_cache(maxsize=128)
 | 
					@lru_cache(maxsize=128)
 | 
				
			||||||
def read_single_conf_with_lru_cache(arg):
 | 
					def read_single_conf_with_lru_cache(arg):
 | 
				
			||||||
@ -423,7 +458,7 @@ def read_single_conf_with_lru_cache(arg):
 | 
				
			|||||||
        r = getattr(importlib.import_module('config'), arg)
 | 
					        r = getattr(importlib.import_module('config'), arg)
 | 
				
			||||||
    # 在读取API_KEY时,检查一下是不是忘了改config
 | 
					    # 在读取API_KEY时,检查一下是不是忘了改config
 | 
				
			||||||
    if arg == 'API_KEY':
 | 
					    if arg == 'API_KEY':
 | 
				
			||||||
        if is_openai_api_key(r):
 | 
					        if is_any_api_key(r):
 | 
				
			||||||
            print亮绿(f"[API_KEY] 您的 API_KEY 是: {r[:15]}*** API_KEY 导入成功")
 | 
					            print亮绿(f"[API_KEY] 您的 API_KEY 是: {r[:15]}*** API_KEY 导入成功")
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            print亮红( "[API_KEY] 正确的 API_KEY 是 'sk-' + '48 位大小写字母数字' 的组合,请在config文件中修改API密钥, 添加海外代理之后再运行。" + \
 | 
					            print亮红( "[API_KEY] 正确的 API_KEY 是 'sk-' + '48 位大小写字母数字' 的组合,请在config文件中修改API密钥, 添加海外代理之后再运行。" + \
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user