From 00076cc6f4dfa3643e45633e99f4d7fa44294627 Mon Sep 17 00:00:00 2001 From: qingxu fu <505030475@qq.com> Date: Wed, 25 Oct 2023 11:48:28 +0800 Subject: [PATCH] =?UTF-8?q?=E6=94=AF=E6=8C=81=E8=AE=AF=E9=A3=9E=E6=98=9F?= =?UTF-8?q?=E7=81=ABv3=20=EF=BC=88sparkv3=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.py | 2 +- request_llm/bridge_all.py | 16 ++++++++++++++++ request_llm/com_sparkapi.py | 10 +++++++++- 3 files changed, 26 insertions(+), 2 deletions(-) diff --git a/config.py b/config.py index f25b119..0b4d119 100644 --- a/config.py +++ b/config.py @@ -86,7 +86,7 @@ LLM_MODEL = "gpt-3.5-turbo" # 可选 ↓↓↓ AVAIL_LLM_MODELS = ["gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5", "api2d-gpt-3.5-turbo", "gpt-4", "gpt-4-32k", "azure-gpt-4", "api2d-gpt-4", "chatglm", "moss", "newbing", "stack-claude"] # P.S. 其他可用的模型还包括 ["qianfan", "llama2", "qwen", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "gpt-3.5-random" -# "spark", "sparkv2", "chatglm_onnx", "claude-1-100k", "claude-2", "internlm", "jittorllms_pangualpha", "jittorllms_llama"] +# "spark", "sparkv2", "sparkv3", "chatglm_onnx", "claude-1-100k", "claude-2", "internlm", "jittorllms_pangualpha", "jittorllms_llama"] # 百度千帆(LLM_MODEL="qianfan") diff --git a/request_llm/bridge_all.py b/request_llm/bridge_all.py index 0639951..8e2bacb 100644 --- a/request_llm/bridge_all.py +++ b/request_llm/bridge_all.py @@ -442,6 +442,22 @@ if "sparkv2" in AVAIL_LLM_MODELS: # 讯飞星火认知大模型 }) except: print(trimmed_format_exc()) +if "sparkv3" in AVAIL_LLM_MODELS: # 讯飞星火认知大模型 + try: + from .bridge_spark import predict_no_ui_long_connection as spark_noui + from .bridge_spark import predict as spark_ui + model_info.update({ + "sparkv3": { + "fn_with_ui": spark_ui, + "fn_without_ui": spark_noui, + "endpoint": None, + "max_token": 4096, + "tokenizer": tokenizer_gpt35, + "token_cnt": get_token_num_gpt35, + } + }) + except: + print(trimmed_format_exc()) if "llama2" in AVAIL_LLM_MODELS: # llama2 try: from .bridge_llama2 import predict_no_ui_long_connection as llama2_noui diff --git a/request_llm/com_sparkapi.py b/request_llm/com_sparkapi.py index ae970b9..5c1a3a4 100644 --- a/request_llm/com_sparkapi.py +++ b/request_llm/com_sparkapi.py @@ -64,6 +64,7 @@ class SparkRequestInstance(): self.api_key = XFYUN_API_KEY self.gpt_url = "ws://spark-api.xf-yun.com/v1.1/chat" self.gpt_url_v2 = "ws://spark-api.xf-yun.com/v2.1/chat" + self.gpt_url_v3 = "ws://spark-api.xf-yun.com/v3.1/chat" self.time_to_yield_event = threading.Event() self.time_to_exit_event = threading.Event() @@ -87,6 +88,8 @@ class SparkRequestInstance(): def create_blocking_request(self, inputs, llm_kwargs, history, system_prompt): if llm_kwargs['llm_model'] == 'sparkv2': gpt_url = self.gpt_url_v2 + elif llm_kwargs['llm_model'] == 'sparkv3': + gpt_url = self.gpt_url_v3 else: gpt_url = self.gpt_url @@ -168,6 +171,11 @@ def gen_params(appid, inputs, llm_kwargs, history, system_prompt): """ 通过appid和用户的提问来生成请参数 """ + domains = { + "spark": "general", + "sparkv2": "generalv2", + "sparkv3": "generalv3", + } data = { "header": { "app_id": appid, @@ -175,7 +183,7 @@ def gen_params(appid, inputs, llm_kwargs, history, system_prompt): }, "parameter": { "chat": { - "domain": "generalv2" if llm_kwargs['llm_model'] == 'sparkv2' else "general", + "domain": domains[llm_kwargs['llm_model']], "temperature": llm_kwargs["temperature"], "random_threshold": 0.5, "max_tokens": 4096,