接入讯飞星火Spark大模型 (Integrate the iFlytek Spark LLM)

This commit is contained in:
binary-husky 2023-08-14 03:08:15 +08:00
parent c0c4834cfc
commit af23730f8f
5 changed files with 163 additions and 82 deletions

View File: config.py

@@ -71,7 +71,7 @@ MAX_RETRY = 2
# Model selection (note: LLM_MODEL is the model selected by default, and it *must* be included in the AVAIL_LLM_MODELS list)
LLM_MODEL = "gpt-3.5-turbo" # options listed below ↓↓↓
AVAIL_LLM_MODELS = ["gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5", "api2d-gpt-3.5-turbo", "gpt-4", "api2d-gpt-4", "chatglm", "moss", "newbing", "stack-claude"]
# P.S. other models that are also available include ["qwen", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "chatglm_onnx", "claude-1-100k", "claude-2", "internlm", "jittorllms_rwkv", "jittorllms_pangualpha", "jittorllms_llama"]
# P.S. other models that are also available include ["qwen", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "spark", "chatglm_onnx", "claude-1-100k", "claude-2", "internlm", "jittorllms_rwkv", "jittorllms_pangualpha", "jittorllms_llama"]
# ChatGLM(2) finetune model path; to use a finetuned ChatGLM2 model, add "chatglmft" to AVAIL_LLM_MODELS
@@ -137,6 +137,13 @@ ALIYUN_APPKEY="" # e.g. RoPlZrM88DnAFkZK
ALIYUN_ACCESSKEY="" # (no need to fill in)
ALIYUN_SECRET="" # (no need to fill in)
# Credentials for the iFlytek Spark (讯飞星火) LLM https://console.xfyun.cn/services/iat
XFYUN_APPID = "00000000"
XFYUN_API_SECRET = "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"
XFYUN_API_KEY = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
# Claude API KEY
ANTHROPIC_API_KEY = ""
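
Taken together with the model lists above, a filled-in config for the Spark backend would look roughly like the sketch below. The values are placeholders, not real keys; all three credentials come from the xfyun console:

# Minimal sketch of a filled-in config.py for the Spark backend (placeholder values).
LLM_MODEL = "spark"                    # bridge_all.py appends LLM_MODEL to AVAIL_LLM_MODELS,
                                       # so selecting "spark" here is enough to activate it
XFYUN_APPID = "your-appid"             # from https://console.xfyun.cn/
XFYUN_API_SECRET = "your-api-secret"
XFYUN_API_KEY = "your-api-key"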

View File: request_llm/bridge_all.py

@@ -68,6 +68,10 @@ get_token_num_gpt35 = lambda txt: len(tokenizer_gpt35.encode(txt, disallowed_spe
get_token_num_gpt4 = lambda txt: len(tokenizer_gpt4.encode(txt, disallowed_special=()))
# Begin initializing the models
AVAIL_LLM_MODELS, LLM_MODEL = get_conf("AVAIL_LLM_MODELS", "LLM_MODEL")
AVAIL_LLM_MODELS = AVAIL_LLM_MODELS + [LLM_MODEL]
# -=-=-=-=-=-=- The section below covers the earliest-added, most stable models -=-=-=-=-=-=-
model_info = {
# openai
"gpt-3.5-turbo": {
@@ -164,9 +168,7 @@ }
AVAIL_LLM_MODELS, LLM_MODEL = get_conf("AVAIL_LLM_MODELS", "LLM_MODEL")
AVAIL_LLM_MODELS = AVAIL_LLM_MODELS + [LLM_MODEL]
# -=-=-=-=-=-=- The models below were added more recently and may carry extra dependencies -=-=-=-=-=-=-
if "claude-1-100k" in AVAIL_LLM_MODELS or "claude-2" in AVAIL_LLM_MODELS:
    from .bridge_claude import predict_no_ui_long_connection as claude_noui
    from .bridge_claude import predict as claude_ui
@@ -367,6 +369,24 @@ if "chatgpt_website" in AVAIL_LLM_MODELS: # integrates a reverse-engineered endpoint https://gi
        })
    except:
        print(trimmed_format_exc())
if "spark" in AVAIL_LLM_MODELS: # 接入一些逆向工程https://github.com/acheong08/ChatGPT-to-API/
try:
from .bridge_spark import predict_no_ui_long_connection as spark_noui
from .bridge_spark import predict as spark_ui
model_info.update({
"spark": {
"fn_with_ui": spark_ui,
"fn_without_ui": spark_noui,
"endpoint": None,
"max_token": 4096,
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
}
})
except:
print(trimmed_format_exc())
def LLM_CATCH_EXCEPTION(f):
"""

View File: request_llm/bridge_spark.py (new file)

@@ -0,0 +1,49 @@
import time
import threading
import importlib
from toolbox import update_ui, get_conf
from multiprocessing import Process, Pipe
model_name = '星火认知大模型'
def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=[], console_slience=False):
    """
    Multi-threaded method
    For documentation of this function, see request_llm/bridge_all.py
    """
    watch_dog_patience = 5
    response = ""

    from .com_sparkapi import SparkRequestInstance
    sri = SparkRequestInstance()
    for response in sri.generate(inputs, llm_kwargs, history, sys_prompt):
        if len(observe_window) >= 1:
            observe_window[0] = response          # stream the partial reply out to the caller
        if len(observe_window) >= 2:
            if (time.time() - observe_window[1]) > watch_dog_patience:
                raise RuntimeError("Program terminated.")
    return response
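
observe_window is the only channel between this worker and its caller: slot 0 streams the partial reply out, slot 1 carries a heartbeat timestamp in, and if the heartbeat goes stale for watch_dog_patience seconds the worker raises and stops. A hypothetical caller-side sketch (not part of the commit; llm_kwargs is an assumed minimal dict):

import time, threading

llm_kwargs = {"temperature": 1, "top_p": 1}       # assumed minimal kwargs
observe_window = ["", time.time()]                # [partial reply, heartbeat]
worker = threading.Thread(
    target=predict_no_ui_long_connection,
    args=("你好", llm_kwargs, [], "", observe_window))
worker.start()
while worker.is_alive():
    observe_window[1] = time.time()               # refresh heartbeat so the watchdog stays quiet
    print("partial:", observe_window[0])
    time.sleep(1)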
def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_prompt='', stream=True, additional_fn=None):
    """
    Single-threaded method
    For documentation of this function, see request_llm/bridge_all.py
    """
    chatbot.append((inputs, ""))

    if additional_fn is not None:
        from core_functional import handle_core_functionality
        inputs, history = handle_core_functionality(additional_fn, inputs, history, chatbot)

    # Begin receiving the reply
    from .com_sparkapi import SparkRequestInstance
    sri = SparkRequestInstance()
    for response in sri.generate(inputs, llm_kwargs, history, system_prompt):
        chatbot[-1] = (inputs, response)
        yield from update_ui(chatbot=chatbot, history=history)

    # Finalize the output; if the sentinel "waiting" message is still the whole
    # response, the model never answered (the literals below are matched elsewhere,
    # so they stay in Chinese)
    if response == f"[Local Message]: 等待{model_name}响应中 ...":
        response = f"[Local Message]: {model_name}响应异常 ..."
    history.extend([inputs, response])
    yield from update_ui(chatbot=chatbot, history=history)

View File: request_llm/com_sparkapi.py

@@ -1,4 +1,4 @@
import _thread as thread
from toolbox import get_conf
import base64
import datetime
import hashlib
@@ -10,8 +10,8 @@ from datetime import datetime
from time import mktime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time
import websocket
import threading, time
timeout_bot_msg = '[Local Message] Request timeout. Network error.'
@@ -37,13 +37,9 @@ class Ws_Param(object):
        signature_origin += "GET " + self.path + " HTTP/1.1"
        # sign with HMAC-SHA256
        signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()
        signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'), digestmod=hashlib.sha256).digest()
        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
        authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
        # assemble the authentication parameters into a dict
# 将请求的鉴权参数组合为字典
@@ -58,18 +54,84 @@ class Ws_Param(object):
        return url
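
For reference, the whole of create_url condenses to the self-contained helper below: an RFC 1123 date plus the request line are signed with the API secret, and the result is folded into an authorization query parameter. A sketch following iFlytek's documented HMAC scheme (host and path taken from the default gpt_url above):

import base64, hashlib, hmac
from datetime import datetime
from time import mktime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time

def build_auth_url(api_key, api_secret, host="spark-api.xf-yun.com", path="/v1.1/chat"):
    date = format_date_time(mktime(datetime.now().timetuple()))   # RFC 1123 timestamp
    signature_origin = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
    signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
                             digestmod=hashlib.sha256).digest()
    signature = base64.b64encode(signature_sha).decode('utf-8')
    authorization_origin = (f'api_key="{api_key}", algorithm="hmac-sha256", '
                            f'headers="host date request-line", signature="{signature}"')
    authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode('utf-8')
    return f"ws://{host}{path}?" + urlencode({"authorization": authorization, "date": date, "host": host})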
# Handle websocket errors
def on_error(ws, error):
    print("### error:", error)
class SparkRequestInstance():
    def __init__(self):
        XFYUN_APPID, XFYUN_API_SECRET, XFYUN_API_KEY = get_conf('XFYUN_APPID', 'XFYUN_API_SECRET', 'XFYUN_API_KEY')
        self.appid = XFYUN_APPID
        self.api_secret = XFYUN_API_SECRET
        self.api_key = XFYUN_API_KEY
        self.gpt_url = "ws://spark-api.xf-yun.com/v1.1/chat"
        self.time_to_yield_event = threading.Event()
        self.time_to_exit_event = threading.Event()
        self.result_buf = ""
    def generate(self, inputs, llm_kwargs, history, system_prompt):
        # Run the blocking websocket request in a helper thread and stream the
        # growing result buffer back to the caller as a generator.
        import _thread as thread
        thread.start_new_thread(self.create_blocking_request, (inputs, llm_kwargs, history, system_prompt))
        while True:
            self.time_to_yield_event.wait(timeout=1)
            if self.time_to_yield_event.is_set():
                yield self.result_buf        # yield everything received so far
            if self.time_to_exit_event.is_set():
                return self.result_buf
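
The two events implement a tiny producer/consumer handshake: the websocket thread owns result_buf, flips time_to_yield_event whenever new text lands, and time_to_exit_event when the stream ends. The same pattern in isolation (a toy sketch with hypothetical names):

import threading, time

class Streamer:
    def __init__(self):
        self.buf = ""
        self.new_data = threading.Event()   # plays the role of time_to_yield_event
        self.done = threading.Event()       # plays the role of time_to_exit_event

    def worker(self):
        for chunk in ["Hello", ", ", "world"]:
            time.sleep(0.2)                 # simulate network latency
            self.buf += chunk
            self.new_data.set()
        self.done.set()

    def generate(self):
        threading.Thread(target=self.worker, daemon=True).start()
        while True:
            self.new_data.wait(timeout=1)   # wake on new data, or at least once per second
            if self.new_data.is_set():
                yield self.buf              # may repeat the same buffer, as in the original
            if self.done.is_set():
                return

for partial in Streamer().generate():
    print(partial)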
# Handle the websocket being closed
def on_close(ws):
    print("### closed ###")
    def create_blocking_request(self, inputs, llm_kwargs, history, system_prompt):
        wsParam = Ws_Param(self.appid, self.api_key, self.api_secret, self.gpt_url)
        websocket.enableTrace(False)
        wsUrl = wsParam.create_url()

        # Handle the websocket connection being established
        def on_open(ws):
            import _thread as thread
            thread.start_new_thread(run, (ws,))

        def run(ws, *args):
            data = json.dumps(gen_params(ws.appid, *ws.all_args))
            ws.send(data)
def generate_message_payload(inputs, llm_kwargs, history, system_prompt, stream):
        # Handle incoming websocket messages
        def on_message(ws, message):
            data = json.loads(message)
            code = data['header']['code']
            if code != 0:
                print(f'Request error: {code}, {data}')
                ws.close()
                self.time_to_exit_event.set()
            else:
                choices = data["payload"]["choices"]
                status = choices["status"]
                content = choices["text"][0]["content"]
                ws.content += content
                self.result_buf += content
                if status == 2:   # status == 2 marks the final frame of the reply
                    ws.close()
                    self.time_to_exit_event.set()
                self.time_to_yield_event.set()
        # Handle websocket errors
        def on_error(ws, error):
            print("error:", error)
            self.time_to_exit_event.set()

        # Handle the websocket being closed
        def on_close(ws):
            self.time_to_exit_event.set()
        # assemble the websocket client and block until the connection closes
        ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
        ws.appid = self.appid
        ws.content = ""
        ws.all_args = (inputs, llm_kwargs, history, system_prompt)
        ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
def generate_message_payload(inputs, llm_kwargs, history, system_prompt):
    conversation_cnt = len(history) // 2
    messages = [{"role": "system", "content": system_prompt}]
    if conversation_cnt:
@@ -94,7 +156,7 @@ def generate_message_payload(inputs, llm_kwargs, history, system_prompt, stream)
    return messages
def gen_params(appid, inputs, llm_kwargs, history, system_prompt, stream):
def gen_params(appid, inputs, llm_kwargs, history, system_prompt):
    """
    Generate the request parameters from the appid and the user's question
    """
@@ -106,75 +168,17 @@ def gen_params(appid, inputs, llm_kwargs, history, system_prompt, stream):
        "parameter": {
            "chat": {
                "domain": "general",
                "temperature": llm_kwargs["temperature"],
                "random_threshold": 0.5,
                "max_tokens": 2048,
                "max_tokens": 4096,
                "auditing": "default"
            }
        },
        "payload": {
            "message": {
                "text": generate_message_payload(inputs, llm_kwargs, history, system_prompt, stream)
                "text": generate_message_payload(inputs, llm_kwargs, history, system_prompt)
            }
        }
    }
    return data
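
Concretely, json.dumps of the structure above for a single user turn with an empty system prompt comes out roughly like this (the app_id value is a placeholder, and any extra header fields such as a uid are assumptions to verify against the xfyun docs):

# Approximate request frame for one user turn (illustrative):
{
    "header": {"app_id": "00000000"},
    "parameter": {"chat": {"domain": "general",
                           "random_threshold": 0.5,
                           "max_tokens": 4096,
                           "auditing": "default"}},
    "payload": {"message": {"text": [
        {"role": "system", "content": ""},
        {"role": "user", "content": "你好"}
    ]}}
}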
# Handle incoming websocket messages
def on_message(ws, message):
    print(message)
    data = json.loads(message)
    code = data['header']['code']
    if code != 0:
        print(f'Request error: {code}, {data}')
        ws.close()
    else:
        choices = data["payload"]["choices"]
        status = choices["status"]
        content = choices["text"][0]["content"]
        ws.content += content
        print(content, end='')
        if status == 2:
            ws.close()
def commit_request(appid, api_key, api_secret, gpt_url, question):
    inputs = question
    llm_kwargs = {}
    history = []
    system_prompt = ""
    stream = True
    wsParam = Ws_Param(appid, api_key, api_secret, gpt_url)
    websocket.enableTrace(False)
    wsUrl = wsParam.create_url()

    # Handle the websocket connection being established
    def on_open(ws):
        thread.start_new_thread(run, (ws,))

    def run(ws, *args):
        data = json.dumps(gen_params(ws.appid, *ws.all_args))
        ws.send(data)

    ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
    ws.appid = appid
    ws.content = ""
    ws.all_args = (inputs, llm_kwargs, history, system_prompt, stream)
    ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
    return ws.content
"""
1配置好pythonpip的环境变量
2执行 pip install websocket pip3 install websocket-client
3去控制台https://console.xfyun.cn/services/cbm获取appid等信息填写即可
"""
if __name__ == "__main__":
    # For testing, fill in the correct credentials here and run this file directly
    commit_request(appid="929e7d53",
                   api_secret="NmFjNGE1ZDNhZWE2MzBkYzg5ZjNkZDcx",
                   api_key="896f9c5cf1b2b8669f523507f96e6c99",
                   gpt_url="ws://spark-api.xf-yun.com/v1.1/chat",
                   question="你是谁?你能做什么")

View File: request_llm/test_llms.py

@@ -16,10 +16,11 @@ if __name__ == "__main__":
    # from request_llm.bridge_jittorllms_llama import predict_no_ui_long_connection
    # from request_llm.bridge_claude import predict_no_ui_long_connection
    # from request_llm.bridge_internlm import predict_no_ui_long_connection
    # from request_llm.bridge_qwen import predict_no_ui_long_connection
    from request_llm.bridge_spark import predict_no_ui_long_connection

    llm_kwargs = {
        'max_length': 512,
        'max_length': 4096,
        'top_p': 1,
        'temperature': 1,
    }
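
The harness presumably goes on to call the imported bridge directly, along these lines (a hedged sketch; the remainder of test_llms.py is not shown in this diff):

    result = predict_no_ui_long_connection(inputs="你好",
                                           llm_kwargs=llm_kwargs,
                                           history=[],
                                           sys_prompt="")
    print(result)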