Integrate the iFlytek Spark (讯飞星火) large model
This commit is contained in:
parent c0c4834cfc
commit af23730f8f
@@ -71,7 +71,7 @@ MAX_RETRY = 2

 # Model selection (note: LLM_MODEL is the default model and *must* be included in the AVAIL_LLM_MODELS list)
 LLM_MODEL = "gpt-3.5-turbo" # options ↓↓↓
 AVAIL_LLM_MODELS = ["gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5", "api2d-gpt-3.5-turbo", "gpt-4", "api2d-gpt-4", "chatglm", "moss", "newbing", "stack-claude"]
-# P.S. Other available models include ["qwen", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "chatglm_onnx", "claude-1-100k", "claude-2", "internlm", "jittorllms_rwkv", "jittorllms_pangualpha", "jittorllms_llama"]
+# P.S. Other available models include ["qwen", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "spark", "chatglm_onnx", "claude-1-100k", "claude-2", "internlm", "jittorllms_rwkv", "jittorllms_pangualpha", "jittorllms_llama"]


 # ChatGLM(2) Finetune Model Path (if using a fine-tuned ChatGLM2 model, add "chatglmft" to AVAIL_LLM_MODELS)
@@ -137,6 +137,13 @@ ALIYUN_APPKEY="" # e.g. RoPlZrM88DnAFkZK
 ALIYUN_ACCESSKEY="" # (no need to fill in)
 ALIYUN_SECRET=""    # (no need to fill in)


+# iFlytek Spark (讯飞星火) large model credentials: https://console.xfyun.cn/services/iat
+XFYUN_APPID = "00000000"
+XFYUN_API_SECRET = "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"
+XFYUN_API_KEY = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
+
+
 # Claude API KEY
 ANTHROPIC_API_KEY = ""

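These three credentials are read back at runtime through toolbox.get_conf, the same helper the rest of the config uses. A minimal sketch of how a bridge might fetch and sanity-check them (the emptiness check is an illustrative addition, not part of this commit):

# Sketch only: get_conf is assumed to return values in argument order,
# as it is used elsewhere in this commit; the validation is illustrative.
from toolbox import get_conf

def load_xfyun_credentials():
    appid, api_secret, api_key = get_conf('XFYUN_APPID', 'XFYUN_API_SECRET', 'XFYUN_API_KEY')
    if not all((appid, api_secret, api_key)):
        raise RuntimeError("XFYUN_APPID / XFYUN_API_SECRET / XFYUN_API_KEY must be set in config.py")
    return appid, api_secret, api_key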
@@ -68,6 +68,10 @@ get_token_num_gpt35 = lambda txt: len(tokenizer_gpt35.encode(txt, disallowed_special=()))
 get_token_num_gpt4 = lambda txt: len(tokenizer_gpt4.encode(txt, disallowed_special=()))


+# Begin model initialization
+AVAIL_LLM_MODELS, LLM_MODEL = get_conf("AVAIL_LLM_MODELS", "LLM_MODEL")
+AVAIL_LLM_MODELS = AVAIL_LLM_MODELS + [LLM_MODEL]
+# -=-=-=-=-=-=- The models below were added earliest and are the most stable -=-=-=-=-=-=-
 model_info = {
     # openai
     "gpt-3.5-turbo": {
@@ -164,9 +168,7 @@ model_info = {

 }

-AVAIL_LLM_MODELS, LLM_MODEL = get_conf("AVAIL_LLM_MODELS", "LLM_MODEL")
-AVAIL_LLM_MODELS = AVAIL_LLM_MODELS + [LLM_MODEL]
 # -=-=-=-=-=-=- The models below were added later and may require extra dependencies -=-=-=-=-=-=-
 if "claude-1-100k" in AVAIL_LLM_MODELS or "claude-2" in AVAIL_LLM_MODELS:
     from .bridge_claude import predict_no_ui_long_connection as claude_noui
     from .bridge_claude import predict as claude_ui
@@ -367,6 +369,24 @@ if "chatgpt_website" in AVAIL_LLM_MODELS: # reverse-engineered interface, https://github.com/acheong08/ChatGPT-to-API/
         })
     except:
         print(trimmed_format_exc())
+if "spark" in AVAIL_LLM_MODELS:   # iFlytek Spark (讯飞星火) cognitive model
+    try:
+        from .bridge_spark import predict_no_ui_long_connection as spark_noui
+        from .bridge_spark import predict as spark_ui
+        model_info.update({
+            "spark": {
+                "fn_with_ui": spark_ui,
+                "fn_without_ui": spark_noui,
+                "endpoint": None,
+                "max_token": 4096,
+                "tokenizer": tokenizer_gpt35,
+                "token_cnt": get_token_num_gpt35,
+            }
+        })
+    except:
+        print(trimmed_format_exc())


 def LLM_CATCH_EXCEPTION(f):
     """
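The model_info registry maps each model name to its streaming UI handler (fn_with_ui), its blocking handler (fn_without_ui), and its tokenizer. A hedged sketch of how a caller could dispatch through the registry once the entry above is registered (the llm_kwargs fields are taken from the test snippet at the end of this commit; the helper name is illustrative):

# Illustrative dispatch through the model_info registry defined above.
def ask(model, prompt):
    entry = model_info[model]   # e.g. model_info["spark"]
    llm_kwargs = {'max_length': 4096, 'top_p': 1, 'temperature': 1}
    return entry["fn_without_ui"](prompt, llm_kwargs, history=[], sys_prompt="")

# ask("spark", "你好")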
request_llm/bridge_spark.py (new file, 49 lines)
@@ -0,0 +1,49 @@
+import time
+import threading
+import importlib
+from toolbox import update_ui, get_conf
+from multiprocessing import Process, Pipe
+
+model_name = '星火认知大模型'
+
+def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=[], console_slience=False):
+    """
+    ⭐ Multi-threaded method
+    See request_llm/bridge_all.py for documentation of this function
+    """
+    watch_dog_patience = 5
+    response = ""
+
+    from .com_sparkapi import SparkRequestInstance
+    sri = SparkRequestInstance()
+    for response in sri.generate(inputs, llm_kwargs, history, sys_prompt):
+        if len(observe_window) >= 1:
+            observe_window[0] = response
+        if len(observe_window) >= 2:
+            if (time.time()-observe_window[1]) > watch_dog_patience: raise RuntimeError("Program terminated.")
+    return response
+
+
+def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_prompt='', stream=True, additional_fn=None):
+    """
+    ⭐ Single-threaded method
+    See request_llm/bridge_all.py for documentation of this function
+    """
+    chatbot.append((inputs, ""))
+
+    if additional_fn is not None:
+        from core_functional import handle_core_functionality
+        inputs, history = handle_core_functionality(additional_fn, inputs, history, chatbot)
+
+    # Start receiving the reply
+    from .com_sparkapi import SparkRequestInstance
+    sri = SparkRequestInstance()
+    for response in sri.generate(inputs, llm_kwargs, history, system_prompt):
+        chatbot[-1] = (inputs, response)
+        yield from update_ui(chatbot=chatbot, history=history)
+
+    # Summarize the output
+    if response == f"[Local Message]: 等待{model_name}响应中 ...":
+        response = f"[Local Message]: {model_name}响应异常 ..."
+    history.extend([inputs, response])
+    yield from update_ui(chatbot=chatbot, history=history)
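observe_window is the bridge's side channel: slot 0 carries the partial response back to the caller, and slot 1 holds a heartbeat timestamp that the caller must keep refreshing, otherwise the bridge aborts after watch_dog_patience seconds. A sketch of a caller honoring that contract (the thread wiring here is illustrative, not part of the commit; XFYUN credentials must already be configured):

import time, threading
from request_llm.bridge_spark import predict_no_ui_long_connection

observe_window = ["", time.time()]  # [partial response, last heartbeat]

def worker():
    result = predict_no_ui_long_connection("你好", llm_kwargs={'temperature': 1},
                                           history=[], sys_prompt="",
                                           observe_window=observe_window)
    observe_window[0] = result

t = threading.Thread(target=worker, daemon=True)
t.start()
while t.is_alive():
    observe_window[1] = time.time()  # refresh the heartbeat so the watchdog does not fire
    time.sleep(1)
print(observe_window[0])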
@@ -1,4 +1,4 @@
-import _thread as thread
+from toolbox import get_conf
 import base64
 import datetime
 import hashlib
@@ -10,8 +10,8 @@ from datetime import datetime
 from time import mktime
 from urllib.parse import urlencode
 from wsgiref.handlers import format_date_time

 import websocket
+import threading, time

 timeout_bot_msg = '[Local Message] Request timeout. Network error.'
@@ -37,13 +37,9 @@ class Ws_Param(object):
         signature_origin += "GET " + self.path + " HTTP/1.1"

         # Sign with HMAC-SHA256
-        signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
-                                 digestmod=hashlib.sha256).digest()
-
+        signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'), digestmod=hashlib.sha256).digest()
         signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')

         authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'

         authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

         # Assemble the request's authentication parameters into a dict
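The compacted hmac.new call above is the heart of iFlytek's websocket authentication: an HMAC-SHA256 signature over the host, an RFC 1123 date, and the request line, wrapped in base64 twice to form an authorization query parameter. A self-contained sketch of the same scheme (host, path, and keys below are placeholders):

# Self-contained sketch of the signing scheme used by Ws_Param.create_url.
import base64, hashlib, hmac
from datetime import datetime
from time import mktime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time

def sign_url(host, path, api_key, api_secret):
    date = format_date_time(mktime(datetime.now().timetuple()))   # RFC 1123 date
    signature_origin = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
    signature_sha = hmac.new(api_secret.encode(), signature_origin.encode(),
                             digestmod=hashlib.sha256).digest()
    signature = base64.b64encode(signature_sha).decode()
    authorization_origin = (f'api_key="{api_key}", algorithm="hmac-sha256", '
                            f'headers="host date request-line", signature="{signature}"')
    authorization = base64.b64encode(authorization_origin.encode()).decode()
    return f"ws://{host}{path}?" + urlencode({"authorization": authorization, "date": date, "host": host})

# sign_url("spark-api.xf-yun.com", "/v1.1/chat", "my_api_key", "my_api_secret")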
@@ -58,18 +54,84 @@ class Ws_Param(object):
         return url


-# Handle a websocket error
-def on_error(ws, error):
-    print("### error:", error)
+class SparkRequestInstance():
+    def __init__(self):
+        XFYUN_APPID, XFYUN_API_SECRET, XFYUN_API_KEY = get_conf('XFYUN_APPID', 'XFYUN_API_SECRET', 'XFYUN_API_KEY')
+
+        self.appid = XFYUN_APPID
+        self.api_secret = XFYUN_API_SECRET
+        self.api_key = XFYUN_API_KEY
+        self.gpt_url = "ws://spark-api.xf-yun.com/v1.1/chat"
+        self.time_to_yield_event = threading.Event()
+        self.time_to_exit_event = threading.Event()
+
+        self.result_buf = ""
+
+    def generate(self, inputs, llm_kwargs, history, system_prompt):
+        import _thread as thread
+        thread.start_new_thread(self.create_blocking_request, (inputs, llm_kwargs, history, system_prompt))
+        while True:
+            self.time_to_yield_event.wait(timeout=1)
+            if self.time_to_yield_event.is_set():
+                yield self.result_buf
+            if self.time_to_exit_event.is_set():
+                return self.result_buf
+
+
-# Handle the websocket closing
-def on_close(ws):
-    print("### closed ###")
+    def create_blocking_request(self, inputs, llm_kwargs, history, system_prompt):
+        wsParam = Ws_Param(self.appid, self.api_key, self.api_secret, self.gpt_url)
+        websocket.enableTrace(False)
+        wsUrl = wsParam.create_url()

+        # Handle the websocket connection opening
+        def on_open(ws):
+            import _thread as thread
+            thread.start_new_thread(run, (ws,))
+
+        def run(ws, *args):
+            data = json.dumps(gen_params(ws.appid, *ws.all_args))
+            ws.send(data)

-def generate_message_payload(inputs, llm_kwargs, history, system_prompt, stream):
+        # Handle an incoming websocket message
+        def on_message(ws, message):
+            data = json.loads(message)
+            code = data['header']['code']
+            if code != 0:
+                print(f'Request error: {code}, {data}')
+                ws.close()
+                self.time_to_exit_event.set()
+            else:
+                choices = data["payload"]["choices"]
+                status = choices["status"]
+                content = choices["text"][0]["content"]
+                ws.content += content
+                self.result_buf += content
+                if status == 2:
+                    ws.close()
+                    self.time_to_exit_event.set()
+                self.time_to_yield_event.set()
+
+        # Handle a websocket error
+        def on_error(ws, error):
+            print("error:", error)
+            self.time_to_exit_event.set()
+
+        # Handle the websocket closing
+        def on_close(ws):
+            self.time_to_exit_event.set()
+
+        # websocket
+        ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
+        ws.appid = self.appid
+        ws.content = ""
+        ws.all_args = (inputs, llm_kwargs, history, system_prompt)
+        ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
+
+def generate_message_payload(inputs, llm_kwargs, history, system_prompt):
     conversation_cnt = len(history) // 2
     messages = [{"role": "system", "content": system_prompt}]
     if conversation_cnt:
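SparkRequestInstance.generate bridges the callback-driven websocket thread into a plain Python generator: the producer thread appends to result_buf and sets time_to_yield_event, while the consumer loop wakes, yields the buffer, and stops once time_to_exit_event is set. The same pattern in isolation (the dummy producer below stands in for the websocket callbacks and is not part of this commit):

import threading, time

class EventBridge:
    def __init__(self):
        self.buf = ""
        self.yield_event = threading.Event()
        self.exit_event = threading.Event()

    def producer(self):
        for chunk in ["Hello", ", ", "world"]:
            time.sleep(0.1)
            self.buf += chunk          # like on_message appending to result_buf
            self.yield_event.set()
        self.exit_event.set()          # like status == 2 closing the stream

    def generate(self):
        threading.Thread(target=self.producer, daemon=True).start()
        while True:
            self.yield_event.wait(timeout=1)
            if self.yield_event.is_set():
                self.yield_event.clear()   # cleared here for tidiness; the commit's version leaves it set and re-yields each loop
                yield self.buf
            if self.exit_event.is_set():
                return

for partial in EventBridge().generate():
    print(partial)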
@@ -94,7 +156,7 @@ def generate_message_payload(inputs, llm_kwargs, history, system_prompt, stream)
     return messages


-def gen_params(appid, inputs, llm_kwargs, history, system_prompt, stream):
+def gen_params(appid, inputs, llm_kwargs, history, system_prompt):
     """
     Generate the request parameters from the appid and the user's question
     """
@@ -106,75 +168,17 @@ def gen_params(appid, inputs, llm_kwargs, history, system_prompt, stream):
         "parameter": {
             "chat": {
                 "domain": "general",
                 "temperature": llm_kwargs["temperature"],
                 "random_threshold": 0.5,
-                "max_tokens": 2048,
+                "max_tokens": 4096,
                 "auditing": "default"
             }
         },
         "payload": {
             "message": {
-                "text": generate_message_payload(inputs, llm_kwargs, history, system_prompt, stream)
+                "text": generate_message_payload(inputs, llm_kwargs, history, system_prompt)
             }
         }
     }
     return data

-# Handle an incoming websocket message
-def on_message(ws, message):
-    print(message)
-    data = json.loads(message)
-    code = data['header']['code']
-    if code != 0:
-        print(f'Request error: {code}, {data}')
-        ws.close()
-    else:
-        choices = data["payload"]["choices"]
-        status = choices["status"]
-        content = choices["text"][0]["content"]
-        ws.content += content
-        print(content, end='')
-        if status == 2:
-            ws.close()
-
-def commit_request(appid, api_key, api_secret, gpt_url, question):
-    inputs = question
-    llm_kwargs = {}
-    history = []
-    system_prompt = ""
-    stream = True
-
-    wsParam = Ws_Param(appid, api_key, api_secret, gpt_url)
-    websocket.enableTrace(False)
-    wsUrl = wsParam.create_url()
-
-    # Handle the websocket connection opening
-    def on_open(ws):
-        thread.start_new_thread(run, (ws,))
-
-    def run(ws, *args):
-        data = json.dumps(gen_params(ws.appid, *ws.all_args))
-        ws.send(data)
-
-    ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
-    ws.appid = appid
-    ws.content = ""
-    ws.all_args = (inputs, llm_kwargs, history, system_prompt, stream)
-    ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
-    return ws.content
-
-
-"""
-1. Set up the python and pip environment variables
-2. Run pip install websocket and pip3 install websocket-client
-3. Get the appid and related credentials from the console at https://console.xfyun.cn/services/cbm and fill them in
-"""
-if __name__ == "__main__":
-    # Fill in valid credentials here to run the test
-    commit_request(appid="929e7d53",
-                   api_secret="NmFjNGE1ZDNhZWE2MzBkYzg5ZjNkZDcx",
-                   api_key="896f9c5cf1b2b8669f523507f96e6c99",
-                   gpt_url="ws://spark-api.xf-yun.com/v1.1/chat",
-                   question="你是谁?你能做什么")
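Put together, gen_params produces the complete Spark chat request: chat parameters plus the message list from generate_message_payload. For reference, a request produced by the new code would look roughly like this (a sketch: the header block is outside the hunk shown above and follows iFlytek's documented schema; all values are placeholders):

# Illustrative request body sent over the websocket by run()/ws.send().
example_request = {
    "header": {"app_id": "00000000", "uid": "1234"},   # schema assumption, not shown in the diff
    "parameter": {
        "chat": {
            "domain": "general",
            "temperature": 1.0,
            "random_threshold": 0.5,
            "max_tokens": 4096,
            "auditing": "default",
        }
    },
    "payload": {
        "message": {
            "text": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "你好"},
            ]
        }
    },
}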
@@ -16,10 +16,11 @@ if __name__ == "__main__":
     # from request_llm.bridge_jittorllms_llama import predict_no_ui_long_connection
     # from request_llm.bridge_claude import predict_no_ui_long_connection
    # from request_llm.bridge_internlm import predict_no_ui_long_connection
-    from request_llm.bridge_qwen import predict_no_ui_long_connection
+    # from request_llm.bridge_qwen import predict_no_ui_long_connection
+    from request_llm.bridge_spark import predict_no_ui_long_connection

     llm_kwargs = {
-        'max_length': 512,
+        'max_length': 4096,
         'top_p': 1,
         'temperature': 1,
     }