From 116b7ce12fa90b1488fe5cde3773b9b96eb46873 Mon Sep 17 00:00:00 2001 From: binary-husky Date: Fri, 1 Sep 2023 10:34:26 +0800 Subject: [PATCH] =?UTF-8?q?=E6=94=AF=E6=8C=81=E6=98=9F=E7=81=AB=E8=AE=A4?= =?UTF-8?q?=E7=9F=A5=E5=A4=A7=E6=A8=A1=E5=9E=8Bv2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.py | 4 ++-- request_llm/bridge_all.py | 16 ++++++++++++++++ request_llm/bridge_spark.py | 1 + request_llm/com_sparkapi.py | 11 +++++++++-- 4 files changed, 28 insertions(+), 4 deletions(-) diff --git a/config.py b/config.py index 876a164..dc9ce88 100644 --- a/config.py +++ b/config.py @@ -73,7 +73,7 @@ LLM_MODEL = "gpt-3.5-turbo" # 可选 ↓↓↓ AVAIL_LLM_MODELS = ["gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5", "api2d-gpt-3.5-turbo", "gpt-4", "api2d-gpt-4", "chatglm", "moss", "newbing", "stack-claude"] # P.S. 其他可用的模型还包括 ["qianfan", "llama2", "qwen", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", -# "spark", "chatglm_onnx", "claude-1-100k", "claude-2", "internlm", "jittorllms_pangualpha", "jittorllms_llama"] +# "spark", "sparkv2", "chatglm_onnx", "claude-1-100k", "claude-2", "internlm", "jittorllms_pangualpha", "jittorllms_llama"] # 百度千帆(LLM_MODEL="qianfan") @@ -189,7 +189,7 @@ GROBID_URLS = [ │ ├── AZURE_ENGINE │ └── API_URL_REDIRECT │ -├── "spark" 星火认知大模型 +├── "spark" 星火认知大模型 spark & sparkv2 │ ├── XFYUN_APPID │ ├── XFYUN_API_SECRET │ └── XFYUN_API_KEY diff --git a/request_llm/bridge_all.py b/request_llm/bridge_all.py index e167825..bb325e4 100644 --- a/request_llm/bridge_all.py +++ b/request_llm/bridge_all.py @@ -398,6 +398,22 @@ if "spark" in AVAIL_LLM_MODELS: # 讯飞星火认知大模型 }) except: print(trimmed_format_exc()) +if "sparkv2" in AVAIL_LLM_MODELS: # 讯飞星火认知大模型 + try: + from .bridge_spark import predict_no_ui_long_connection as spark_noui + from .bridge_spark import predict as spark_ui + model_info.update({ + "sparkv2": { + "fn_with_ui": spark_ui, + "fn_without_ui": spark_noui, + "endpoint": None, + "max_token": 4096, + "tokenizer": tokenizer_gpt35, + "token_cnt": get_token_num_gpt35, + } + }) + except: + print(trimmed_format_exc()) if "llama2" in AVAIL_LLM_MODELS: # llama2 try: from .bridge_llama2 import predict_no_ui_long_connection as llama2_noui diff --git a/request_llm/bridge_spark.py b/request_llm/bridge_spark.py index 551b6f3..1a3d43d 100644 --- a/request_llm/bridge_spark.py +++ b/request_llm/bridge_spark.py @@ -30,6 +30,7 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp 函数的说明请见 request_llm/bridge_all.py """ chatbot.append((inputs, "")) + yield from update_ui(chatbot=chatbot, history=history) if additional_fn is not None: from core_functional import handle_core_functionality diff --git a/request_llm/com_sparkapi.py b/request_llm/com_sparkapi.py index c83710b..308aa64 100644 --- a/request_llm/com_sparkapi.py +++ b/request_llm/com_sparkapi.py @@ -63,6 +63,8 @@ class SparkRequestInstance(): self.api_secret = XFYUN_API_SECRET self.api_key = XFYUN_API_KEY self.gpt_url = "ws://spark-api.xf-yun.com/v1.1/chat" + self.gpt_url_v2 = "ws://spark-api.xf-yun.com/v2.1/chat" + self.time_to_yield_event = threading.Event() self.time_to_exit_event = threading.Event() @@ -83,7 +85,12 @@ class SparkRequestInstance(): def create_blocking_request(self, inputs, llm_kwargs, history, system_prompt): - wsParam = Ws_Param(self.appid, self.api_key, self.api_secret, self.gpt_url) + if llm_kwargs['llm_model'] == 'sparkv2': + gpt_url = self.gpt_url_v2 + else: + gpt_url = self.gpt_url + + wsParam = Ws_Param(self.appid, self.api_key, self.api_secret, gpt_url) websocket.enableTrace(False) wsUrl = wsParam.create_url() @@ -167,7 +174,7 @@ def gen_params(appid, inputs, llm_kwargs, history, system_prompt): }, "parameter": { "chat": { - "domain": "general", + "domain": "generalv2" if llm_kwargs['llm_model'] == 'sparkv2' else "general", "temperature": llm_kwargs["temperature"], "random_threshold": 0.5, "max_tokens": 4096,