diff --git a/config.py b/config.py index f25b119..0b4d119 100644 --- a/config.py +++ b/config.py @@ -86,7 +86,7 @@ LLM_MODEL = "gpt-3.5-turbo" # 可选 ↓↓↓ AVAIL_LLM_MODELS = ["gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5", "api2d-gpt-3.5-turbo", "gpt-4", "gpt-4-32k", "azure-gpt-4", "api2d-gpt-4", "chatglm", "moss", "newbing", "stack-claude"] # P.S. 其他可用的模型还包括 ["qianfan", "llama2", "qwen", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "gpt-3.5-random" -# "spark", "sparkv2", "chatglm_onnx", "claude-1-100k", "claude-2", "internlm", "jittorllms_pangualpha", "jittorllms_llama"] +# "spark", "sparkv2", "sparkv3", "chatglm_onnx", "claude-1-100k", "claude-2", "internlm", "jittorllms_pangualpha", "jittorllms_llama"] # 百度千帆(LLM_MODEL="qianfan") diff --git a/request_llm/bridge_all.py b/request_llm/bridge_all.py index 0639951..8e2bacb 100644 --- a/request_llm/bridge_all.py +++ b/request_llm/bridge_all.py @@ -442,6 +442,22 @@ if "sparkv2" in AVAIL_LLM_MODELS: # 讯飞星火认知大模型 }) except: print(trimmed_format_exc()) +if "sparkv3" in AVAIL_LLM_MODELS: # 讯飞星火认知大模型 + try: + from .bridge_spark import predict_no_ui_long_connection as spark_noui + from .bridge_spark import predict as spark_ui + model_info.update({ + "sparkv3": { + "fn_with_ui": spark_ui, + "fn_without_ui": spark_noui, + "endpoint": None, + "max_token": 4096, + "tokenizer": tokenizer_gpt35, + "token_cnt": get_token_num_gpt35, + } + }) + except: + print(trimmed_format_exc()) if "llama2" in AVAIL_LLM_MODELS: # llama2 try: from .bridge_llama2 import predict_no_ui_long_connection as llama2_noui diff --git a/request_llm/com_sparkapi.py b/request_llm/com_sparkapi.py index ae970b9..5c1a3a4 100644 --- a/request_llm/com_sparkapi.py +++ b/request_llm/com_sparkapi.py @@ -64,6 +64,7 @@ class SparkRequestInstance(): self.api_key = XFYUN_API_KEY self.gpt_url = "ws://spark-api.xf-yun.com/v1.1/chat" self.gpt_url_v2 = "ws://spark-api.xf-yun.com/v2.1/chat" + self.gpt_url_v3 = "ws://spark-api.xf-yun.com/v3.1/chat" self.time_to_yield_event = threading.Event() self.time_to_exit_event = threading.Event() @@ -87,6 +88,8 @@ class SparkRequestInstance(): def create_blocking_request(self, inputs, llm_kwargs, history, system_prompt): if llm_kwargs['llm_model'] == 'sparkv2': gpt_url = self.gpt_url_v2 + elif llm_kwargs['llm_model'] == 'sparkv3': + gpt_url = self.gpt_url_v3 else: gpt_url = self.gpt_url @@ -168,6 +171,11 @@ def gen_params(appid, inputs, llm_kwargs, history, system_prompt): """ 通过appid和用户的提问来生成请参数 """ + domains = { + "spark": "general", + "sparkv2": "generalv2", + "sparkv3": "generalv3", + } data = { "header": { "app_id": appid, @@ -175,7 +183,7 @@ def gen_params(appid, inputs, llm_kwargs, history, system_prompt): }, "parameter": { "chat": { - "domain": "generalv2" if llm_kwargs['llm_model'] == 'sparkv2' else "general", + "domain": domains[llm_kwargs['llm_model']], "temperature": llm_kwargs["temperature"], "random_threshold": 0.5, "max_tokens": 4096,