修复AZURE_CFG_ARRAY使用时不给定apikey报错的问题
This commit is contained in:
parent
50ecb45d63
commit
e596bb6fff
@ -233,7 +233,7 @@ NUM_CUSTOM_BASIC_BTN = 4
|
|||||||
│ ├── AZURE_ENGINE
|
│ ├── AZURE_ENGINE
|
||||||
│ └── API_URL_REDIRECT
|
│ └── API_URL_REDIRECT
|
||||||
│
|
│
|
||||||
├── "azure-gpt-3.5" 等azure模型(多个azure模型,需要动态切换)
|
├── "azure-gpt-3.5" 等azure模型(多个azure模型,需要动态切换,高优先级)
|
||||||
│ └── AZURE_CFG_ARRAY
|
│ └── AZURE_CFG_ARRAY
|
||||||
│
|
│
|
||||||
├── "spark" 星火认知大模型 spark & sparkv2
|
├── "spark" 星火认知大模型 spark & sparkv2
|
||||||
|
@ -56,6 +56,17 @@ def decode_chunk(chunk):
|
|||||||
pass
|
pass
|
||||||
return chunk_decoded, chunkjson, has_choices, has_content, has_role
|
return chunk_decoded, chunkjson, has_choices, has_content, has_role
|
||||||
|
|
||||||
|
from functools import lru_cache
|
||||||
|
@lru_cache(maxsize=32)
|
||||||
|
def verify_endpoint(endpoint):
|
||||||
|
"""
|
||||||
|
检查endpoint是否可用
|
||||||
|
"""
|
||||||
|
if "你亲手写的api名称" in endpoint:
|
||||||
|
raise ValueError("Endpoint不正确, 请检查AZURE_ENDPOINT的配置! 当前的Endpoint为:" + endpoint)
|
||||||
|
print(endpoint)
|
||||||
|
return endpoint
|
||||||
|
|
||||||
def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=None, console_slience=False):
|
def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=None, console_slience=False):
|
||||||
"""
|
"""
|
||||||
发送至chatGPT,等待回复,一次性完成,不显示中间过程。但内部用stream的方法避免中途网线被掐。
|
发送至chatGPT,等待回复,一次性完成,不显示中间过程。但内部用stream的方法避免中途网线被掐。
|
||||||
@ -77,7 +88,7 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
|
|||||||
try:
|
try:
|
||||||
# make a POST request to the API endpoint, stream=False
|
# make a POST request to the API endpoint, stream=False
|
||||||
from .bridge_all import model_info
|
from .bridge_all import model_info
|
||||||
endpoint = model_info[llm_kwargs['llm_model']]['endpoint']
|
endpoint = verify_endpoint(model_info[llm_kwargs['llm_model']]['endpoint'])
|
||||||
response = requests.post(endpoint, headers=headers, proxies=proxies,
|
response = requests.post(endpoint, headers=headers, proxies=proxies,
|
||||||
json=payload, stream=True, timeout=TIMEOUT_SECONDS); break
|
json=payload, stream=True, timeout=TIMEOUT_SECONDS); break
|
||||||
except requests.exceptions.ReadTimeout as e:
|
except requests.exceptions.ReadTimeout as e:
|
||||||
@ -169,14 +180,22 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
|
|||||||
yield from update_ui(chatbot=chatbot, history=history, msg="api-key不满足要求") # 刷新界面
|
yield from update_ui(chatbot=chatbot, history=history, msg="api-key不满足要求") # 刷新界面
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# 检查endpoint是否合法
|
||||||
|
try:
|
||||||
|
from .bridge_all import model_info
|
||||||
|
endpoint = verify_endpoint(model_info[llm_kwargs['llm_model']]['endpoint'])
|
||||||
|
except:
|
||||||
|
tb_str = '```\n' + trimmed_format_exc() + '```'
|
||||||
|
chatbot[-1] = (inputs, tb_str)
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history, msg="Endpoint不满足要求") # 刷新界面
|
||||||
|
return
|
||||||
|
|
||||||
history.append(inputs); history.append("")
|
history.append(inputs); history.append("")
|
||||||
|
|
||||||
retry = 0
|
retry = 0
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
# make a POST request to the API endpoint, stream=True
|
# make a POST request to the API endpoint, stream=True
|
||||||
from .bridge_all import model_info
|
|
||||||
endpoint = model_info[llm_kwargs['llm_model']]['endpoint']
|
|
||||||
response = requests.post(endpoint, headers=headers, proxies=proxies,
|
response = requests.post(endpoint, headers=headers, proxies=proxies,
|
||||||
json=payload, stream=True, timeout=TIMEOUT_SECONDS);break
|
json=payload, stream=True, timeout=TIMEOUT_SECONDS);break
|
||||||
except:
|
except:
|
||||||
|
13
toolbox.py
13
toolbox.py
@ -621,10 +621,21 @@ def on_report_generated(cookies, files, chatbot):
|
|||||||
|
|
||||||
def load_chat_cookies():
|
def load_chat_cookies():
|
||||||
API_KEY, LLM_MODEL, AZURE_API_KEY = get_conf('API_KEY', 'LLM_MODEL', 'AZURE_API_KEY')
|
API_KEY, LLM_MODEL, AZURE_API_KEY = get_conf('API_KEY', 'LLM_MODEL', 'AZURE_API_KEY')
|
||||||
DARK_MODE, NUM_CUSTOM_BASIC_BTN = get_conf('DARK_MODE', 'NUM_CUSTOM_BASIC_BTN')
|
AZURE_CFG_ARRAY, NUM_CUSTOM_BASIC_BTN = get_conf('AZURE_CFG_ARRAY', 'NUM_CUSTOM_BASIC_BTN')
|
||||||
|
|
||||||
|
# deal with azure openai key
|
||||||
if is_any_api_key(AZURE_API_KEY):
|
if is_any_api_key(AZURE_API_KEY):
|
||||||
if is_any_api_key(API_KEY): API_KEY = API_KEY + ',' + AZURE_API_KEY
|
if is_any_api_key(API_KEY): API_KEY = API_KEY + ',' + AZURE_API_KEY
|
||||||
else: API_KEY = AZURE_API_KEY
|
else: API_KEY = AZURE_API_KEY
|
||||||
|
if len(AZURE_CFG_ARRAY) > 0:
|
||||||
|
for azure_model_name, azure_cfg_dict in AZURE_CFG_ARRAY.items():
|
||||||
|
if not azure_model_name.startswith('azure'):
|
||||||
|
raise ValueError("AZURE_CFG_ARRAY中配置的模型必须以azure开头")
|
||||||
|
AZURE_API_KEY_ = azure_cfg_dict["AZURE_API_KEY"]
|
||||||
|
if is_any_api_key(AZURE_API_KEY_):
|
||||||
|
if is_any_api_key(API_KEY): API_KEY = API_KEY + ',' + AZURE_API_KEY_
|
||||||
|
else: API_KEY = AZURE_API_KEY_
|
||||||
|
|
||||||
customize_fn_overwrite_ = {}
|
customize_fn_overwrite_ = {}
|
||||||
for k in range(NUM_CUSTOM_BASIC_BTN):
|
for k in range(NUM_CUSTOM_BASIC_BTN):
|
||||||
customize_fn_overwrite_.update({
|
customize_fn_overwrite_.update({
|
||||||
|
Loading…
x
Reference in New Issue
Block a user