diff --git a/config.py b/config.py
index 2a8687d..56c8ea3 100644
--- a/config.py
+++ b/config.py
@@ -233,7 +233,7 @@ NUM_CUSTOM_BASIC_BTN = 4
 │   ├── AZURE_ENGINE
 │   └── API_URL_REDIRECT
 │
-├── "azure-gpt-3.5" 等azure模型(多个azure模型,需要动态切换)
+├── "azure-gpt-3.5" 等azure模型(多个azure模型,需要动态切换,高优先级)
 │   └── AZURE_CFG_ARRAY
 │
 ├── "spark" 星火认知大模型 spark & sparkv2
diff --git a/request_llm/bridge_chatgpt.py b/request_llm/bridge_chatgpt.py
index 5568b1a..ef81dd8 100644
--- a/request_llm/bridge_chatgpt.py
+++ b/request_llm/bridge_chatgpt.py
@@ -56,6 +56,17 @@ def decode_chunk(chunk):
         pass
     return chunk_decoded, chunkjson, has_choices, has_content, has_role
 
+from functools import lru_cache
+@lru_cache(maxsize=32)
+def verify_endpoint(endpoint):
+    """
+        检查endpoint是否可用
+    """
+    if "你亲手写的api名称" in endpoint:
+        raise ValueError("Endpoint不正确, 请检查AZURE_ENDPOINT的配置! 当前的Endpoint为:" + endpoint)
+    print(endpoint)
+    return endpoint
+
 def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=None, console_slience=False):
     """
     发送至chatGPT,等待回复,一次性完成,不显示中间过程。但内部用stream的方法避免中途网线被掐。
@@ -77,7 +88,7 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
         try:
             # make a POST request to the API endpoint, stream=False
             from .bridge_all import model_info
-            endpoint = model_info[llm_kwargs['llm_model']]['endpoint']
+            endpoint = verify_endpoint(model_info[llm_kwargs['llm_model']]['endpoint'])
             response = requests.post(endpoint, headers=headers, proxies=proxies,
                                     json=payload, stream=True, timeout=TIMEOUT_SECONDS); break
         except requests.exceptions.ReadTimeout as e:
@@ -169,14 +180,22 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
             yield from update_ui(chatbot=chatbot, history=history, msg="api-key不满足要求") # 刷新界面
             return
 
+    # 检查endpoint是否合法
+    try:
+        from .bridge_all import model_info
+        endpoint = verify_endpoint(model_info[llm_kwargs['llm_model']]['endpoint'])
+    except:
+        tb_str = '```\n' + trimmed_format_exc() + '```'
+        chatbot[-1] = (inputs, tb_str)
+        yield from update_ui(chatbot=chatbot, history=history, msg="Endpoint不满足要求") # 刷新界面
+        return
+
     history.append(inputs); history.append("")
 
     retry = 0
     while True:
         try:
             # make a POST request to the API endpoint, stream=True
-            from .bridge_all import model_info
-            endpoint = model_info[llm_kwargs['llm_model']]['endpoint']
             response = requests.post(endpoint, headers=headers, proxies=proxies,
                                     json=payload, stream=True, timeout=TIMEOUT_SECONDS);break
         except:
diff --git a/toolbox.py b/toolbox.py
index cd6cd1c..07a9fda 100644
--- a/toolbox.py
+++ b/toolbox.py
@@ -621,10 +621,21 @@ def on_report_generated(cookies, files, chatbot):
 
 def load_chat_cookies():
     API_KEY, LLM_MODEL, AZURE_API_KEY = get_conf('API_KEY', 'LLM_MODEL', 'AZURE_API_KEY')
-    DARK_MODE, NUM_CUSTOM_BASIC_BTN = get_conf('DARK_MODE', 'NUM_CUSTOM_BASIC_BTN')
+    AZURE_CFG_ARRAY, NUM_CUSTOM_BASIC_BTN = get_conf('AZURE_CFG_ARRAY', 'NUM_CUSTOM_BASIC_BTN')
+
+    # deal with azure openai key
     if is_any_api_key(AZURE_API_KEY):
         if is_any_api_key(API_KEY): API_KEY = API_KEY + ',' + AZURE_API_KEY
         else: API_KEY = AZURE_API_KEY
+    if len(AZURE_CFG_ARRAY) > 0:
+        for azure_model_name, azure_cfg_dict in AZURE_CFG_ARRAY.items():
+            if not azure_model_name.startswith('azure'):
+                raise ValueError("AZURE_CFG_ARRAY中配置的模型必须以azure开头")
+            AZURE_API_KEY_ = azure_cfg_dict["AZURE_API_KEY"]
+            if is_any_api_key(AZURE_API_KEY_):
+                if is_any_api_key(API_KEY): API_KEY = API_KEY + ',' + AZURE_API_KEY_
+                else: API_KEY = AZURE_API_KEY_
+
     customize_fn_overwrite_ = {}
     for k in range(NUM_CUSTOM_BASIC_BTN):
         customize_fn_overwrite_.update({