From af23730f8fb0831fd2d0e113948cc50657eefdeb Mon Sep 17 00:00:00 2001
From: binary-husky
Date: Mon, 14 Aug 2023 03:08:15 +0800
Subject: [PATCH] Integrate the iFlytek Spark (讯飞星火) large model
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 config.py                   |   9 ++-
 request_llm/bridge_all.py   |  26 +++++-
 request_llm/bridge_spark.py |  47 +++++++++++
 request_llm/com_sparkapi.py | 153 ++++++++++++++++++------------------
 tests/test_llms.py          |   5 +-
 5 files changed, 158 insertions(+), 82 deletions(-)
 create mode 100644 request_llm/bridge_spark.py

diff --git a/config.py b/config.py
index bfa4a3a..a5ae33d 100644
--- a/config.py
+++ b/config.py
@@ -71,7 +71,7 @@ MAX_RETRY = 2
 # 模型选择是 (注意: LLM_MODEL是默认选中的模型, 它*必须*被包含在AVAIL_LLM_MODELS列表中 )
 LLM_MODEL = "gpt-3.5-turbo" # 可选 ↓↓↓
 AVAIL_LLM_MODELS = ["gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5", "api2d-gpt-3.5-turbo", "gpt-4", "api2d-gpt-4", "chatglm", "moss", "newbing", "stack-claude"]
-# P.S. 其他可用的模型还包括 ["qwen", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "chatglm_onnx", "claude-1-100k", "claude-2", "internlm", "jittorllms_rwkv", "jittorllms_pangualpha", "jittorllms_llama"]
+# P.S. 其他可用的模型还包括 ["qwen", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "spark", "chatglm_onnx", "claude-1-100k", "claude-2", "internlm", "jittorllms_rwkv", "jittorllms_pangualpha", "jittorllms_llama"]
 
 
 # ChatGLM(2) Finetune Model Path (如果使用ChatGLM2微调模型,需要把"chatglmft"加入AVAIL_LLM_MODELS中)
@@ -137,6 +137,13 @@ ALIYUN_APPKEY="" # 例如 RoPlZrM88DnAFkZK
 ALIYUN_ACCESSKEY="" # (无需填写)
 ALIYUN_SECRET=""    # (无需填写)
 
+
+# iFlytek Spark (讯飞星火) credentials: obtain APPID, APISecret and APIKey from https://console.xfyun.cn/services/cbm
+XFYUN_APPID = "00000000"
+XFYUN_API_SECRET = "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"
+XFYUN_API_KEY = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
+
+
 # Claude API KEY
 ANTHROPIC_API_KEY = ""
 
diff --git a/request_llm/bridge_all.py b/request_llm/bridge_all.py
index 7f612bb..e7fc830 100644
--- a/request_llm/bridge_all.py
+++ b/request_llm/bridge_all.py
@@ -68,6 +68,10 @@ get_token_num_gpt35 = lambda txt: len(tokenizer_gpt35.encode(txt, disallowed_special=()))
 get_token_num_gpt4 = lambda txt: len(tokenizer_gpt4.encode(txt, disallowed_special=()))
 
+# Read the model selection from the user config
+AVAIL_LLM_MODELS, LLM_MODEL = get_conf("AVAIL_LLM_MODELS", "LLM_MODEL")
+AVAIL_LLM_MODELS = AVAIL_LLM_MODELS + [LLM_MODEL]
+# -=-=-=-=-=-=- The models below were added first and are the most stable -=-=-=-=-=-=-
 model_info = {
     # openai
     "gpt-3.5-turbo": {
@@ -164,9 +168,7 @@ model_info = {
 }
 
 
-
-AVAIL_LLM_MODELS, LLM_MODEL = get_conf("AVAIL_LLM_MODELS", "LLM_MODEL")
-AVAIL_LLM_MODELS = AVAIL_LLM_MODELS + [LLM_MODEL]
+# -=-=-=-=-=-=- The models below were added later and may carry extra dependencies -=-=-=-=-=-=-
 if "claude-1-100k" in AVAIL_LLM_MODELS or "claude-2" in AVAIL_LLM_MODELS:
     from .bridge_claude import predict_no_ui_long_connection as claude_noui
     from .bridge_claude import predict as claude_ui
@@ -367,6 +369,24 @@ if "chatgpt_website" in AVAIL_LLM_MODELS:   # 接入一些逆向工程https://github.com/acheong08/ChatGPT-to-API/
         })
     except:
         print(trimmed_format_exc())
+if "spark" in AVAIL_LLM_MODELS:   # iFlytek Spark (讯飞星火)
+    try:
+        from .bridge_spark import predict_no_ui_long_connection as spark_noui
+        from .bridge_spark import predict as spark_ui
+        model_info.update({
+            "spark": {
+                "fn_with_ui": spark_ui,
+                "fn_without_ui": spark_noui,
+                "endpoint": None,
+                "max_token": 4096,
+                "tokenizer": tokenizer_gpt35,
+                "token_cnt": get_token_num_gpt35,
+            }
+        })
+    except:
+        print(trimmed_format_exc())
+
+
 
 def LLM_CATCH_EXCEPTION(f):
     """
diff --git a/request_llm/bridge_spark.py b/request_llm/bridge_spark.py
new file mode 100644
index 0000000..551b6f3
--- /dev/null
+++ b/request_llm/bridge_spark.py
@@ -0,0 +1,47 @@
+
+import time
+from toolbox import update_ui
+
+model_name = '星火认知大模型'
+
+def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=[], console_slience=False):
+    """
+    ⭐ Multi-threaded entry point
+    See request_llm/bridge_all.py for a description of the interface
+    """
+    watch_dog_patience = 5  # watchdog patience: give up if the caller's heartbeat stalls for 5 seconds
+    response = ""
+
+    from .com_sparkapi import SparkRequestInstance
+    sri = SparkRequestInstance()
+    for response in sri.generate(inputs, llm_kwargs, history, sys_prompt):
+        if len(observe_window) >= 1:
+            observe_window[0] = response
+        if len(observe_window) >= 2:
+            if (time.time()-observe_window[1]) > watch_dog_patience: raise RuntimeError("程序终止。")
+    return response
+
+def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_prompt='', stream = True, additional_fn=None):
+    """
+    ⭐ Single-threaded entry point (drives the UI)
+    See request_llm/bridge_all.py for a description of the interface
+    """
+    chatbot.append((inputs, ""))
+    response = f"[Local Message]: 等待{model_name}响应中 ..."
+
+    if additional_fn is not None:
+        from core_functional import handle_core_functionality
+        inputs, history = handle_core_functionality(additional_fn, inputs, history, chatbot)
+
+    # Start receiving the streamed reply
+    from .com_sparkapi import SparkRequestInstance
+    sri = SparkRequestInstance()
+    for response in sri.generate(inputs, llm_kwargs, history, system_prompt):
+        chatbot[-1] = (inputs, response)
+        yield from update_ui(chatbot=chatbot, history=history)
+
+    # Wrap up: if no content ever arrived, report an abnormal response
+    if response == f"[Local Message]: 等待{model_name}响应中 ...":
+        response = f"[Local Message]: {model_name}响应异常 ..."
+    history.extend([inputs, response])
+    yield from update_ui(chatbot=chatbot, history=history)
\ No newline at end of file
diff --git a/request_llm/com_sparkapi.py b/request_llm/com_sparkapi.py
index bce39c6..07f2853 100644
--- a/request_llm/com_sparkapi.py
+++ b/request_llm/com_sparkapi.py
@@ -1,4 +1,4 @@
-import _thread as thread
+from toolbox import get_conf
 import base64
 import datetime
 import hashlib
@@ -10,8 +10,8 @@ from datetime import datetime
 from time import mktime
 from urllib.parse import urlencode
 from wsgiref.handlers import format_date_time
-
 import websocket
+import threading, time
 
 timeout_bot_msg = '[Local Message] Request timeout. Network error.'
@@ -37,13 +37,9 @@ class Ws_Param(object):
         signature_origin += "GET " + self.path + " HTTP/1.1"
 
         # 进行hmac-sha256进行加密
-        signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
-                                 digestmod=hashlib.sha256).digest()
-
+        signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'), digestmod=hashlib.sha256).digest()
         signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
-
         authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
-
         authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
 
         # 将请求的鉴权参数组合为字典
@@ -58,18 +54,81 @@ class Ws_Param(object):
         return url
 
 
-# 收到websocket错误的处理
-def on_error(ws, error):
-    print("### error:", error)
+
+class SparkRequestInstance():
+    def __init__(self):
+        XFYUN_APPID, XFYUN_API_SECRET, XFYUN_API_KEY = get_conf('XFYUN_APPID', 'XFYUN_API_SECRET', 'XFYUN_API_KEY')
+
+        self.appid = XFYUN_APPID
+        self.api_secret = XFYUN_API_SECRET
+        self.api_key = XFYUN_API_KEY
+        self.gpt_url = "ws://spark-api.xf-yun.com/v1.1/chat"
+        self.time_to_yield_event = threading.Event()
+        self.time_to_exit_event = threading.Event()
+
+        self.result_buf = ""
+
+    def generate(self, inputs, llm_kwargs, history, system_prompt):
+        # Run the blocking websocket request in a worker thread, then stream the growing buffer
+        import _thread as thread
+        thread.start_new_thread(self.create_blocking_request, (inputs, llm_kwargs, history, system_prompt))
+        while True:
+            self.time_to_yield_event.wait(timeout=1)
+            if self.time_to_yield_event.is_set():
+                yield self.result_buf
+            if self.time_to_exit_event.is_set():
+                return self.result_buf
 
-# 收到websocket关闭的处理
-def on_close(ws):
-    print("### closed ###")
+    def create_blocking_request(self, inputs, llm_kwargs, history, system_prompt):
+        wsParam = Ws_Param(self.appid, self.api_key, self.api_secret, self.gpt_url)
+        websocket.enableTrace(False)
+        wsUrl = wsParam.create_url()
 
+        # Called once the websocket connection is established
+        def on_open(ws):
+            import _thread as thread
+            thread.start_new_thread(run, (ws,))
 
+        def run(ws, *args):
+            data = json.dumps(gen_params(ws.appid, *ws.all_args))
+            ws.send(data)
 
+        # Called on each incoming websocket message
+        def on_message(ws, message):
+            data = json.loads(message)
+            code = data['header']['code']
+            if code != 0:
+                print(f'Request error: {code}, {data}')
+                ws.close()
+                self.time_to_exit_event.set()
+            else:
+                choices = data["payload"]["choices"]
+                status = choices["status"]
+                content = choices["text"][0]["content"]
+                ws.content += content
+                self.result_buf += content
+                if status == 2:
+                    ws.close()
+                    self.time_to_exit_event.set()
+                self.time_to_yield_event.set()
+
+        # Called on websocket errors
+        def on_error(ws, error):
+            print("error:", error)
+            self.time_to_exit_event.set()
+
+        # Called when the websocket is closed
+        def on_close(ws):
+            self.time_to_exit_event.set()
+
+        # Create the websocket client and block until the conversation finishes
+        ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
+        ws.appid = self.appid
+        ws.content = ""
+        ws.all_args = (inputs, llm_kwargs, history, system_prompt)
+        ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
+
-def generate_message_payload(inputs, llm_kwargs, history, system_prompt, stream):
+def generate_message_payload(inputs, llm_kwargs, history, system_prompt):
     conversation_cnt = len(history) // 2
     messages = [{"role": "system", "content": system_prompt}]
     if conversation_cnt:
@@ -94,7 +153,7 @@ def generate_message_payload(inputs, llm_kwargs, history, system_prompt, stream)
     return messages
 
 
-def gen_params(appid, inputs, llm_kwargs, history, system_prompt, stream):
+def gen_params(appid, inputs, llm_kwargs, history, system_prompt):
     """
     通过appid和用户的提问来生成请参数
     """
@@ -106,75 +165,17 @@
         "parameter": {
             "chat": {
                 "domain": "general",
+                "temperature": llm_kwargs["temperature"],
                 "random_threshold": 0.5,
-                "max_tokens": 2048,
+                "max_tokens": 4096,
                 "auditing": "default"
             }
         },
         "payload": {
             "message": {
-                "text": generate_message_payload(inputs, llm_kwargs, history, system_prompt, stream)
+                "text": generate_message_payload(inputs, llm_kwargs, history, system_prompt)
             }
         }
     }
     return data
-
-
-# 收到websocket消息的处理
-def on_message(ws, message):
-    print(message)
-    data = json.loads(message)
-    code = data['header']['code']
-    if code != 0:
-        print(f'请求错误: {code}, {data}')
-        ws.close()
-    else:
-        choices = data["payload"]["choices"]
-        status = choices["status"]
-        content = choices["text"][0]["content"]
-        ws.content += content
-        print(content, end='')
-        if status == 2:
-            ws.close()
-
-def commit_request(appid, api_key, api_secret, gpt_url, question):
-    inputs = question
-    llm_kwargs = {}
-    history = []
-    system_prompt = ""
-    stream = True
-
-    wsParam = Ws_Param(appid, api_key, api_secret, gpt_url)
-    websocket.enableTrace(False)
-    wsUrl = wsParam.create_url()
-
-    # 收到websocket连接建立的处理
-    def on_open(ws):
-        thread.start_new_thread(run, (ws,))
-
-    def run(ws, *args):
-        data = json.dumps(gen_params(ws.appid, *ws.all_args))
-        ws.send(data)
-
-    ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
-    ws.appid = appid
-    ws.content = ""
-
-    ws.all_args = (inputs, llm_kwargs, history, system_prompt, stream)
-    ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
-
-    return ws.content
-
-
-
-"""
-    1、配置好python、pip的环境变量
-    2、执行 pip install websocket 与 pip3 install websocket-client
-    3、去控制台https://console.xfyun.cn/services/cbm获取appid等信息填写即可
-"""
-if __name__ == "__main__":
-    # 测试时候在此处正确填写相关信息即可运行
-    commit_request(appid="929e7d53",
-            api_secret="NmFjNGE1ZDNhZWE2MzBkYzg5ZjNkZDcx",
-            api_key="896f9c5cf1b2b8669f523507f96e6c99",
-            gpt_url="ws://spark-api.xf-yun.com/v1.1/chat",
-            question="你是谁?你能做什么")
\ No newline at end of file
diff --git a/tests/test_llms.py b/tests/test_llms.py
index ce52589..75e2303 100644
--- a/tests/test_llms.py
+++ b/tests/test_llms.py
@@ -16,10 +16,11 @@ if __name__ == "__main__":
     # from request_llm.bridge_jittorllms_llama import predict_no_ui_long_connection
     # from request_llm.bridge_claude import predict_no_ui_long_connection
     # from request_llm.bridge_internlm import predict_no_ui_long_connection
-    from request_llm.bridge_qwen import predict_no_ui_long_connection
+    # from request_llm.bridge_qwen import predict_no_ui_long_connection
+    from request_llm.bridge_spark import predict_no_ui_long_connection
 
     llm_kwargs = {
-        'max_length': 512,
+        'max_length': 4096,
         'top_p': 1,
         'temperature': 1,
     }
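
For reviewers: below is a minimal smoke test for the new bridge, mirroring the tests/test_llms.py change above. It is a sketch rather than part of the patch; it assumes the XFYUN_APPID / XFYUN_API_SECRET / XFYUN_API_KEY placeholders in config.py have been replaced with real credentials from the Xfyun console, that the script is run from the repository root, and the prompt string is purely illustrative.

    # Minimal smoke test for the Spark bridge (assumes real XFYUN_* credentials in config.py).
    from request_llm.bridge_spark import predict_no_ui_long_connection

    llm_kwargs = {
        'max_length': 4096,   # consistent with the "max_token" registered for "spark" in bridge_all.py
        'top_p': 1,
        'temperature': 1,     # required: gen_params() reads llm_kwargs["temperature"]
    }

    result = predict_no_ui_long_connection(
        inputs="你好,请介绍一下你自己",   # illustrative prompt
        llm_kwargs=llm_kwargs,
        history=[],
        sys_prompt="",
    )
    print(result)

With the default observe_window=[], the watchdog check inside predict_no_ui_long_connection is skipped, so the call simply blocks until the stream finishes.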
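The part of com_sparkapi.py most worth reviewing is the handoff between the blocking websocket client and the generator the UI consumes: the websocket callbacks append each chunk to the cumulative result_buf and set two threading.Event flags, while generate() wakes at most once per second, yields the buffer, and returns once the exit flag is set. The following self-contained sketch reproduces that handshake with a fake chunk source; all names and data here are illustrative and not part of the patch.

    import threading
    import time

    result_buf = ""
    yield_event = threading.Event()   # mirrors self.time_to_yield_event
    exit_event = threading.Event()    # mirrors self.time_to_exit_event

    def worker():
        # Stand-in for the websocket on_message callback.
        global result_buf
        for chunk in ["星火", "认知", "大模型"]:
            time.sleep(0.2)
            result_buf += chunk       # the buffer is cumulative
            yield_event.set()
        exit_event.set()

    def generate():
        # Same loop shape as SparkRequestInstance.generate().
        while True:
            yield_event.wait(timeout=1)
            if yield_event.is_set():
                yield result_buf      # always the full text so far, so repeated yields are harmless
            if exit_event.is_set():
                return

    threading.Thread(target=worker, daemon=True).start()
    last = None
    for partial in generate():
        if partial != last:           # only print when new content has arrived
            print(partial)
            last = partial

Because the buffer is cumulative, every yield carries the complete text so far, so the UI can redraw the whole message on each update instead of stitching deltas together.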