修复AZURE_CFG_ARRAY使用时不给定apikey报错的问题

This commit is contained in:
binary-husky 2023-10-28 00:29:49 +08:00
parent 50ecb45d63
commit e596bb6fff
3 changed files with 35 additions and 5 deletions

View File

@ -233,7 +233,7 @@ NUM_CUSTOM_BASIC_BTN = 4
AZURE_ENGINE
API_URL_REDIRECT
"azure-gpt-3.5" 等azure模型多个azure模型需要动态切换
"azure-gpt-3.5" 等azure模型多个azure模型需要动态切换高优先级
AZURE_CFG_ARRAY
"spark" 星火认知大模型 spark & sparkv2

View File

@ -56,6 +56,17 @@ def decode_chunk(chunk):
pass
return chunk_decoded, chunkjson, has_choices, has_content, has_role
from functools import lru_cache
@lru_cache(maxsize=32)
def verify_endpoint(endpoint):
"""
检查endpoint是否可用
"""
if "你亲手写的api名称" in endpoint:
raise ValueError("Endpoint不正确, 请检查AZURE_ENDPOINT的配置! 当前的Endpoint为:" + endpoint)
print(endpoint)
return endpoint
def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=None, console_slience=False):
"""
发送至chatGPT等待回复一次性完成不显示中间过程但内部用stream的方法避免中途网线被掐
@ -77,7 +88,7 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
try:
# make a POST request to the API endpoint, stream=False
from .bridge_all import model_info
endpoint = model_info[llm_kwargs['llm_model']]['endpoint']
endpoint = verify_endpoint(model_info[llm_kwargs['llm_model']]['endpoint'])
response = requests.post(endpoint, headers=headers, proxies=proxies,
json=payload, stream=True, timeout=TIMEOUT_SECONDS); break
except requests.exceptions.ReadTimeout as e:
@ -169,14 +180,22 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
yield from update_ui(chatbot=chatbot, history=history, msg="api-key不满足要求") # 刷新界面
return
# 检查endpoint是否合法
try:
from .bridge_all import model_info
endpoint = verify_endpoint(model_info[llm_kwargs['llm_model']]['endpoint'])
except:
tb_str = '```\n' + trimmed_format_exc() + '```'
chatbot[-1] = (inputs, tb_str)
yield from update_ui(chatbot=chatbot, history=history, msg="Endpoint不满足要求") # 刷新界面
return
history.append(inputs); history.append("")
retry = 0
while True:
try:
# make a POST request to the API endpoint, stream=True
from .bridge_all import model_info
endpoint = model_info[llm_kwargs['llm_model']]['endpoint']
response = requests.post(endpoint, headers=headers, proxies=proxies,
json=payload, stream=True, timeout=TIMEOUT_SECONDS);break
except:

View File

@ -621,10 +621,21 @@ def on_report_generated(cookies, files, chatbot):
def load_chat_cookies():
API_KEY, LLM_MODEL, AZURE_API_KEY = get_conf('API_KEY', 'LLM_MODEL', 'AZURE_API_KEY')
DARK_MODE, NUM_CUSTOM_BASIC_BTN = get_conf('DARK_MODE', 'NUM_CUSTOM_BASIC_BTN')
AZURE_CFG_ARRAY, NUM_CUSTOM_BASIC_BTN = get_conf('AZURE_CFG_ARRAY', 'NUM_CUSTOM_BASIC_BTN')
# deal with azure openai key
if is_any_api_key(AZURE_API_KEY):
if is_any_api_key(API_KEY): API_KEY = API_KEY + ',' + AZURE_API_KEY
else: API_KEY = AZURE_API_KEY
if len(AZURE_CFG_ARRAY) > 0:
for azure_model_name, azure_cfg_dict in AZURE_CFG_ARRAY.items():
if not azure_model_name.startswith('azure'):
raise ValueError("AZURE_CFG_ARRAY中配置的模型必须以azure开头")
AZURE_API_KEY_ = azure_cfg_dict["AZURE_API_KEY"]
if is_any_api_key(AZURE_API_KEY_):
if is_any_api_key(API_KEY): API_KEY = API_KEY + ',' + AZURE_API_KEY_
else: API_KEY = AZURE_API_KEY_
customize_fn_overwrite_ = {}
for k in range(NUM_CUSTOM_BASIC_BTN):
customize_fn_overwrite_.update({