Integrate the iFlytek Spark (讯飞星火) large model
parent c0c4834cfc
commit af23730f8f
config.py
@@ -71,7 +71,7 @@ MAX_RETRY = 2
 # Model selection (note: LLM_MODEL is the default model; it *must* be included in the AVAIL_LLM_MODELS list)
 LLM_MODEL = "gpt-3.5-turbo" # options ↓↓↓
 AVAIL_LLM_MODELS = ["gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5", "api2d-gpt-3.5-turbo", "gpt-4", "api2d-gpt-4", "chatglm", "moss", "newbing", "stack-claude"]
-# P.S. Other available models include ["qwen", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "chatglm_onnx", "claude-1-100k", "claude-2", "internlm", "jittorllms_rwkv", "jittorllms_pangualpha", "jittorllms_llama"]
+# P.S. Other available models include ["qwen", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "spark", "chatglm_onnx", "claude-1-100k", "claude-2", "internlm", "jittorllms_rwkv", "jittorllms_pangualpha", "jittorllms_llama"]


 # ChatGLM(2) fine-tune model path (to use a ChatGLM2 fine-tuned model, add "chatglmft" to AVAIL_LLM_MODELS)
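Adding "spark" to the P.S. comment is informational only; at runtime the model must actually appear in AVAIL_LLM_MODELS (bridge_all.py also appends LLM_MODEL to that list, see below). A minimal sketch of enabling Spark in config.py, assuming the defaults above are otherwise kept:

    # Option A: make Spark the default; it is appended to AVAIL_LLM_MODELS automatically
    LLM_MODEL = "spark"
    # Option B: keep the default model and extend the list instead
    AVAIL_LLM_MODELS = ["gpt-3.5-turbo-16k", "gpt-3.5-turbo", "spark"]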
@@ -137,6 +137,13 @@ ALIYUN_APPKEY="" # e.g. RoPlZrM88DnAFkZK
 ALIYUN_ACCESSKEY=""     # (no need to fill in)
 ALIYUN_SECRET=""        # (no need to fill in)


+# iFlytek Spark (讯飞星火) model access, see https://console.xfyun.cn/services/iat
+XFYUN_APPID = "00000000"
+XFYUN_API_SECRET = "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"
+XFYUN_API_KEY = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
+
+
+
 # Claude API KEY
 ANTHROPIC_API_KEY = ""
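These placeholder credentials are read back through the project's get_conf helper in request_llm/com_sparkapi.py (further down in this commit). A quick sanity check, sketched; the assert is illustrative and not part of the commit:

    from toolbox import get_conf
    XFYUN_APPID, XFYUN_API_SECRET, XFYUN_API_KEY = get_conf('XFYUN_APPID', 'XFYUN_API_SECRET', 'XFYUN_API_KEY')
    assert XFYUN_APPID != "00000000", "fill in your own appid from https://console.xfyun.cn"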
request_llm/bridge_all.py
@@ -68,6 +68,10 @@ get_token_num_gpt35 = lambda txt: len(tokenizer_gpt35.encode(txt, disallowed_special=()))
 get_token_num_gpt4 = lambda txt: len(tokenizer_gpt4.encode(txt, disallowed_special=()))


+# Begin initializing models
+AVAIL_LLM_MODELS, LLM_MODEL = get_conf("AVAIL_LLM_MODELS", "LLM_MODEL")
+AVAIL_LLM_MODELS = AVAIL_LLM_MODELS + [LLM_MODEL]
+# -=-=-=-=-=-=- The earliest-added and most stable models -=-=-=-=-=-=-
 model_info = {
     # openai
     "gpt-3.5-turbo": {
@@ -164,9 +168,7 @@ model_info = {

 }

+# -=-=-=-=-=-=- Newly added models, which may carry extra dependencies -=-=-=-=-=-=-
-AVAIL_LLM_MODELS, LLM_MODEL = get_conf("AVAIL_LLM_MODELS", "LLM_MODEL")
-AVAIL_LLM_MODELS = AVAIL_LLM_MODELS + [LLM_MODEL]
 if "claude-1-100k" in AVAIL_LLM_MODELS or "claude-2" in AVAIL_LLM_MODELS:
     from .bridge_claude import predict_no_ui_long_connection as claude_noui
     from .bridge_claude import predict as claude_ui
@@ -367,6 +369,24 @@ if "chatgpt_website" in AVAIL_LLM_MODELS: # reverse-engineered access, https://github.com/acheong08/ChatGPT-to-API/
         })
     except:
         print(trimmed_format_exc())
+if "spark" in AVAIL_LLM_MODELS:   # iFlytek Spark (讯飞星火) cognitive large model
+    try:
+        from .bridge_spark import predict_no_ui_long_connection as spark_noui
+        from .bridge_spark import predict as spark_ui
+        model_info.update({
+            "spark": {
+                "fn_with_ui": spark_ui,
+                "fn_without_ui": spark_noui,
+                "endpoint": None,
+                "max_token": 4096,
+                "tokenizer": tokenizer_gpt35,
+                "token_cnt": get_token_num_gpt35,
+            }
+        })
+    except:
+        print(trimmed_format_exc())
+
+

 def LLM_CATCH_EXCEPTION(f):
     """
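model_info is the dispatch table the rest of the project consults: each entry pairs the streaming UI function with the blocking one, plus a token budget and a tokenizer (Spark reuses the gpt-3.5 tokenizer purely for counting). A hedged illustration of a lookup; the real dispatch code lies outside this diff:

    method = model_info["spark"]["fn_without_ui"]        # -> bridge_spark.predict_no_ui_long_connection
    budget = model_info["spark"]["max_token"]            # 4096
    n_tokens = model_info["spark"]["token_cnt"]("你好")   # counted with tokenizer_gpt35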
request_llm/bridge_spark.py (new file, 49 lines)
@@ -0,0 +1,49 @@
+
+import time
+import threading
+import importlib
+from toolbox import update_ui, get_conf
+from multiprocessing import Process, Pipe
+
+model_name = '星火认知大模型'
+
+def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=[], console_slience=False):
+    """
+    ⭐ Multi-threaded method
+    See request_llm/bridge_all.py for the function description
+    """
+    watch_dog_patience = 5
+    response = ""
+
+    from .com_sparkapi import SparkRequestInstance
+    sri = SparkRequestInstance()
+    for response in sri.generate(inputs, llm_kwargs, history, sys_prompt):
+        if len(observe_window) >= 1:
+            observe_window[0] = response
+        if len(observe_window) >= 2:
+            if (time.time()-observe_window[1]) > watch_dog_patience: raise RuntimeError("程序终止。")
+    return response
+
+def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_prompt='', stream = True, additional_fn=None):
+    """
+    ⭐ Single-threaded method
+    See request_llm/bridge_all.py for the function description
+    """
+    chatbot.append((inputs, ""))
+
+    if additional_fn is not None:
+        from core_functional import handle_core_functionality
+        inputs, history = handle_core_functionality(additional_fn, inputs, history, chatbot)
+
+    # Start receiving the reply
+    from .com_sparkapi import SparkRequestInstance
+    sri = SparkRequestInstance()
+    for response in sri.generate(inputs, llm_kwargs, history, system_prompt):
+        chatbot[-1] = (inputs, response)
+        yield from update_ui(chatbot=chatbot, history=history)
+
+    # Finalize the output
+    if response == f"[Local Message]: 等待{model_name}响应中 ...":
+        response = f"[Local Message]: {model_name}响应异常 ..."
+    history.extend([inputs, response])
+    yield from update_ui(chatbot=chatbot, history=history)
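A note on the observe_window protocol: slot 0 carries the partial reply out to a watching thread, slot 1 carries a watchdog heartbeat in; if the caller stops refreshing slot 1 for more than watch_dog_patience (5 s), the worker raises RuntimeError. A minimal caller, sketched with placeholder llm_kwargs values rather than the project's defaults:

    import time
    from request_llm.bridge_spark import predict_no_ui_long_connection

    llm_kwargs = {'temperature': 1.0, 'top_p': 1, 'max_length': 4096}   # placeholders
    observe_window = ["", time.time()]   # [partial reply out, watchdog heartbeat in]
    # for replies that take longer than 5 s, a supervising thread must keep refreshing observe_window[1]
    reply = predict_no_ui_long_connection("你好", llm_kwargs, history=[], sys_prompt="",
                                          observe_window=observe_window)
    print(reply)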
request_llm/com_sparkapi.py
@@ -1,4 +1,4 @@
-import _thread as thread
+from toolbox import get_conf
 import base64
 import datetime
 import hashlib
@@ -10,8 +10,8 @@ from datetime import datetime
 from time import mktime
 from urllib.parse import urlencode
 from wsgiref.handlers import format_date_time

 import websocket
+import threading, time

 timeout_bot_msg = '[Local Message] Request timeout. Network error.'
@@ -37,13 +37,9 @@ class Ws_Param(object):
     signature_origin += "GET " + self.path + " HTTP/1.1"

     # Sign with hmac-sha256
-    signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
-                             digestmod=hashlib.sha256).digest()
+    signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'), digestmod=hashlib.sha256).digest()
-
     signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
-
     authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
-
     authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

     # Assemble the authentication parameters into a dict
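The collapsed hmac call above implements iFlytek's websocket handshake: sign the "host / date / request-line" string with the API secret, base64 the digest, wrap it in an authorization header, base64 that again, and pass everything as query parameters. A self-contained sketch of the same scheme, with host, path, api_key, and api_secret standing in for the Ws_Param attributes:

    import base64, hashlib, hmac
    from datetime import datetime
    from time import mktime
    from urllib.parse import urlencode
    from wsgiref.handlers import format_date_time

    def build_auth_url(host, path, api_key, api_secret):
        date = format_date_time(mktime(datetime.now().timetuple()))   # RFC 1123 date
        signature_origin = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
        signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()
        signature = base64.b64encode(signature_sha).decode('utf-8')
        authorization_origin = (f'api_key="{api_key}", algorithm="hmac-sha256", '
                                f'headers="host date request-line", signature="{signature}"')
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode('utf-8')
        return f"ws://{host}{path}?" + urlencode({"authorization": authorization, "date": date, "host": host})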
@@ -58,18 +54,84 @@ class Ws_Param(object):
         return url


-# Handle websocket errors
-def on_error(ws, error):
-    print("### error:", error)
-
-# Handle websocket close
-def on_close(ws):
-    print("### closed ###")
-
-def generate_message_payload(inputs, llm_kwargs, history, system_prompt, stream):
+class SparkRequestInstance():
+    def __init__(self):
+        XFYUN_APPID, XFYUN_API_SECRET, XFYUN_API_KEY = get_conf('XFYUN_APPID', 'XFYUN_API_SECRET', 'XFYUN_API_KEY')
+
+        self.appid = XFYUN_APPID
+        self.api_secret = XFYUN_API_SECRET
+        self.api_key = XFYUN_API_KEY
+        self.gpt_url = "ws://spark-api.xf-yun.com/v1.1/chat"
+        self.time_to_yield_event = threading.Event()
+        self.time_to_exit_event = threading.Event()
+
+        self.result_buf = ""
+
+    def generate(self, inputs, llm_kwargs, history, system_prompt):
+        llm_kwargs = llm_kwargs
+        history = history
+        system_prompt = system_prompt
+        import _thread as thread
+        thread.start_new_thread(self.create_blocking_request, (inputs, llm_kwargs, history, system_prompt))
+        while True:
+            self.time_to_yield_event.wait(timeout=1)
+            if self.time_to_yield_event.is_set():
+                yield self.result_buf
+            if self.time_to_exit_event.is_set():
+                return self.result_buf
+
+
+    def create_blocking_request(self, inputs, llm_kwargs, history, system_prompt):
+        wsParam = Ws_Param(self.appid, self.api_key, self.api_secret, self.gpt_url)
+        websocket.enableTrace(False)
+        wsUrl = wsParam.create_url()
+
+        # Handle websocket connection established
+        def on_open(ws):
+            import _thread as thread
+            thread.start_new_thread(run, (ws,))
+
+        def run(ws, *args):
+            data = json.dumps(gen_params(ws.appid, *ws.all_args))
+            ws.send(data)
+
+        # Handle incoming websocket messages
+        def on_message(ws, message):
+            data = json.loads(message)
+            code = data['header']['code']
+            if code != 0:
+                print(f'请求错误: {code}, {data}')
+                ws.close()
+                self.time_to_exit_event.set()
+            else:
+                choices = data["payload"]["choices"]
+                status = choices["status"]
+                content = choices["text"][0]["content"]
+                ws.content += content
+                self.result_buf += content
+                if status == 2:
+                    ws.close()
+                    self.time_to_exit_event.set()
+                self.time_to_yield_event.set()
+
+        # Handle websocket errors
+        def on_error(ws, error):
+            print("error:", error)
+            self.time_to_exit_event.set()
+
+        # Handle websocket close
+        def on_close(ws):
+            self.time_to_exit_event.set()
+
+        # websocket
+        ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
+        ws.appid = self.appid
+        ws.content = ""
+        ws.all_args = (inputs, llm_kwargs, history, system_prompt)
+        ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
+
+def generate_message_payload(inputs, llm_kwargs, history, system_prompt):
     conversation_cnt = len(history) // 2
     messages = [{"role": "system", "content": system_prompt}]
     if conversation_cnt:
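SparkRequestInstance.generate() is a producer/consumer bridge: the websocket callbacks run on a _thread worker and append chunks to result_buf, while the generator polls two threading.Event flags, one to yield and one to stop. A standalone sketch of the same pattern, with a fake producer standing in for the websocket (note that, like the original, it re-yields the whole buffer on every pass once data has arrived):

    import threading, time

    class StreamBuffer:
        def __init__(self):
            self.buf = ""
            self.have_data = threading.Event()   # set by the producer after each chunk
            self.finished = threading.Event()    # set by the producer at end of stream

        def generate(self, produce):
            threading.Thread(target=produce, args=(self,), daemon=True).start()
            while True:
                self.have_data.wait(timeout=1)   # wake up at least once per second
                if self.have_data.is_set():
                    yield self.buf               # the accumulated reply so far
                if self.finished.is_set():
                    return

    def produce(sb):
        for chunk in ["星", "火", "认", "知"]:
            time.sleep(0.1); sb.buf += chunk; sb.have_data.set()
        sb.finished.set()

    for partial in StreamBuffer().generate(produce):
        print(partial)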
@@ -94,7 +156,7 @@ def generate_message_payload(inputs, llm_kwargs, history, system_prompt, stream):
     return messages


-def gen_params(appid, inputs, llm_kwargs, history, system_prompt, stream):
+def gen_params(appid, inputs, llm_kwargs, history, system_prompt):
     """
     Generate the request parameters from the appid and the user's question
     """
@@ -106,75 +168,17 @@ def gen_params(appid, inputs, llm_kwargs, history, system_prompt, stream):
         "parameter": {
             "chat": {
                 "domain": "general",
+                "temperature": llm_kwargs["temperature"],
                 "random_threshold": 0.5,
-                "max_tokens": 2048,
+                "max_tokens": 4096,
                 "auditing": "default"
             }
         },
         "payload": {
             "message": {
-                "text": generate_message_payload(inputs, llm_kwargs, history, system_prompt, stream)
+                "text": generate_message_payload(inputs, llm_kwargs, history, system_prompt)
             }
         }
     }
     return data

-
-# Handle incoming websocket messages
-def on_message(ws, message):
-    print(message)
-    data = json.loads(message)
-    code = data['header']['code']
-    if code != 0:
-        print(f'请求错误: {code}, {data}')
-        ws.close()
-    else:
-        choices = data["payload"]["choices"]
-        status = choices["status"]
-        content = choices["text"][0]["content"]
-        ws.content += content
-        print(content, end='')
-        if status == 2:
-            ws.close()
-
-
-def commit_request(appid, api_key, api_secret, gpt_url, question):
-    inputs = question
-    llm_kwargs = {}
-    history = []
-    system_prompt = ""
-    stream = True
-
-    wsParam = Ws_Param(appid, api_key, api_secret, gpt_url)
-    websocket.enableTrace(False)
-    wsUrl = wsParam.create_url()
-
-    # Handle websocket connection established
-    def on_open(ws):
-        thread.start_new_thread(run, (ws,))
-
-    def run(ws, *args):
-        data = json.dumps(gen_params(ws.appid, *ws.all_args))
-        ws.send(data)
-
-    ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
-    ws.appid = appid
-    ws.content = ""
-    ws.all_args = (inputs, llm_kwargs, history, system_prompt, stream)
-    ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
-
-    return ws.content
-
-
-"""
-1. Set up the python and pip environment variables
-2. Run pip install websocket and pip3 install websocket-client
-3. Get the appid and keys from the console at https://console.xfyun.cn/services/cbm and fill them in
-"""
-if __name__ == "__main__":
-    # Fill in the credentials here to run a standalone test
-    commit_request(appid="929e7d53",
-                   api_secret="NmFjNGE1ZDNhZWE2MzBkYzg5ZjNkZDcx",
-                   api_key="896f9c5cf1b2b8669f523507f96e6c99",
-                   gpt_url="ws://spark-api.xf-yun.com/v1.1/chat",
-                   question="你是谁?你能做什么")
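For reference, the request body gen_params() now builds looks roughly like the dict below; the "header" block is part of the Spark API but sits outside this hunk, so its exact shape here is an assumption:

    example_request = {
        "header": {"app_id": "00000000"},      # assumed; not shown in the diff
        "parameter": {
            "chat": {
                "domain": "general",
                "temperature": 1.0,            # now taken from llm_kwargs["temperature"]
                "random_threshold": 0.5,
                "max_tokens": 4096,            # raised from 2048
                "auditing": "default",
            }
        },
        "payload": {
            "message": {
                "text": [                      # built by generate_message_payload()
                    {"role": "system", "content": ""},
                    {"role": "user", "content": "你是谁?"},
                ]
            }
        },
    }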
@@ -16,10 +16,11 @@ if __name__ == "__main__":
     # from request_llm.bridge_jittorllms_llama import predict_no_ui_long_connection
     # from request_llm.bridge_claude import predict_no_ui_long_connection
     # from request_llm.bridge_internlm import predict_no_ui_long_connection
-    from request_llm.bridge_qwen import predict_no_ui_long_connection
+    # from request_llm.bridge_qwen import predict_no_ui_long_connection
+    from request_llm.bridge_spark import predict_no_ui_long_connection

     llm_kwargs = {
-        'max_length': 512,
+        'max_length': 4096,
         'top_p': 1,
         'temperature': 1,
     }
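With the import switched over, this test harness exercises the blocking path end to end. A sketch of the call it makes next; the exact prompt and history are placeholders, since that part of the script is not shown in this hunk:

    result = predict_no_ui_long_connection(inputs="请问什么是质子?",
                                           llm_kwargs=llm_kwargs,
                                           history=["你好", "我好!"],
                                           sys_prompt="")
    print('final result:', result)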