diff --git a/config.py b/config.py
index f4e1bc8..1b397ba 100644
--- a/config.py
+++ b/config.py
@@ -34,7 +34,7 @@ WEB_PORT = -1
 MAX_RETRY = 2
 
 # OpenAI model selection (GPT-4 is currently only open to users whose access application has been approved)
-LLM_MODEL = "gpt-3.5-turbo"
+LLM_MODEL = "pygmalion-1.3b@localhost@7860" # "gpt-3.5-turbo"
 
 # OpenAI API_URL
 API_URL = "https://api.openai.com/v1/chat/completions"
diff --git a/predict.py b/predict.py
index 31a5861..10e58bb 100644
--- a/predict.py
+++ b/predict.py
@@ -112,8 +112,7 @@ def predict_no_ui_long_connection(inputs, top_p, temperature, history=[], sys_pr
     return result
 
-def predict(inputs, top_p, temperature, chatbot=[], history=[], system_prompt='',
-            stream = True, additional_fn=None):
+def predict(inputs, top_p, temperature, chatbot=[], history=[], system_prompt='', stream = True, additional_fn=None):
     """
         Send the query to chatGPT and fetch the output as a stream.
         Used for the basic chat function.
@@ -244,3 +243,8 @@ def generate_payload(inputs, top_p, temperature, history, system_prompt, stream)
     return headers,payload
 
+
+# Non-"gpt" models are served by the text-generation-webui bridge instead of the OpenAI API.
+if not LLM_MODEL.startswith('gpt'):
+    from request_llm.bridge_tgui import predict_tgui
+    predict = predict_tgui
+
\ No newline at end of file
diff --git a/request_llm/bridge_tgui.py b/request_llm/bridge_tgui.py
new file mode 100644
index 0000000..37f3826
--- /dev/null
+++ b/request_llm/bridge_tgui.py
@@ -0,0 +1,137 @@
+'''
+Contributed by SagsMug. Modified by binary-husky
+https://github.com/oobabooga/text-generation-webui/pull/175
+'''
+
+import asyncio
+import json
+import random
+import string
+import websockets
+import logging
+import time
+import threading
+import importlib
+from toolbox import get_conf
+LLM_MODEL, = get_conf('LLM_MODEL')
+
+# LLM_MODEL is expected in the form "model_name@server_address@server_port"
+model_name, addr, port = LLM_MODEL.split('@')
+
+def random_hash():
+    letters = string.ascii_lowercase + string.digits
+    return ''.join(random.choice(letters) for i in range(9))
+
+async def run(context):
+    params = {
+        'max_new_tokens': 200,
+        'do_sample': True,
+        'temperature': 0.5,
+        'top_p': 0.9,
+        'typical_p': 1,
+        'repetition_penalty': 1.05,
+        'encoder_repetition_penalty': 1.0,
+        'top_k': 0,
+        'min_length': 0,
+        'no_repeat_ngram_size': 0,
+        'num_beams': 1,
+        'penalty_alpha': 0,
+        'length_penalty': 1,
+        'early_stopping': False,
+        'seed': -1,
+    }
+    session = random_hash()
+
+    async with websockets.connect(f"ws://{addr}:{port}/queue/join") as websocket:
+        # The walrus operator needs Python 3.8+; the upstream version used a
+        # Python 3.10 match statement, replaced here with if/elif.
+        while content := json.loads(await websocket.recv()):
+            if content["msg"] == "send_hash":
+                await websocket.send(json.dumps({
+                    "session_hash": session,
+                    "fn_index": 12
+                }))
+            elif content["msg"] == "estimation":
+                pass
+            elif content["msg"] == "send_data":
+                await websocket.send(json.dumps({
+                    "session_hash": session,
+                    "fn_index": 12,
+                    "data": [
+                        context,
+                        params['max_new_tokens'],
+                        params['do_sample'],
+                        params['temperature'],
+                        params['top_p'],
+                        params['typical_p'],
+                        params['repetition_penalty'],
+                        params['encoder_repetition_penalty'],
+                        params['top_k'],
+                        params['min_length'],
+                        params['no_repeat_ngram_size'],
+                        params['num_beams'],
+                        params['penalty_alpha'],
+                        params['length_penalty'],
+                        params['early_stopping'],
+                        params['seed'],
+                    ]
+                }))
+            elif content["msg"] == "process_starts":
+                pass
+            elif content["msg"] in ["process_generating", "process_completed"]:
+                # The server sends the full text generated so far on every update.
+                yield content["output"]["data"][0]
+                # You can search for your desired end indicator and
+                # stop generation by closing the websocket here
+                if (content["msg"] == "process_completed"):
+                    break
+
+
+def predict_tgui(inputs, top_p, temperature, chatbot=[], history=[], system_prompt='', stream = True, additional_fn=None):
+    """
+        Send the query to the text-generation-webui backend and fetch the output as a stream.
+        Used for the basic chat function.
+        inputs is the input of the current query
+        top_p, temperature are internal tuning parameters of the model
+        history is the list of previous conversation rounds (note that if either inputs or history is too long, a token-overflow error will be triggered)
+        chatbot is the conversation list shown in the WebUI; modify it and then yield it out to update the chat interface directly
+        additional_fn indicates which button was clicked; the buttons are defined in functional.py
+    """
+    if additional_fn is not None:
+        import functional
+        importlib.reload(functional)    # hot-reload the prompt definitions
+        functional = functional.get_functionals()
+        if "PreProcess" in functional[additional_fn]: inputs = functional[additional_fn]["PreProcess"](inputs)  # apply the pre-processing function, if one is defined
+        inputs = functional[additional_fn]["Prefix"] + inputs + functional[additional_fn]["Suffix"]
+
+    raw_input = inputs
+    logging.info(f'[raw_input] {raw_input}')
+    history.extend([inputs, ""])    # register this round so history[-2]/history[-1] can be updated below
+    chatbot.append((inputs, ""))
+    status_text = "等待响应"
+    yield chatbot, history, status_text
+
+    prompt = inputs
+    tgui_say = ""
+
+    mutable = [""]
+    def run_coroutine(mutable):
+        async def get_result():
+            async for response in run(prompt):
+                # run() yields the full text generated so far, so keep only the latest snapshot
+                mutable[0] = response
+        asyncio.run(get_result())
+
+    thread_listen = threading.Thread(target=run_coroutine, args=(mutable,))
+    thread_listen.start()
+
+    while thread_listen.is_alive():
+        time.sleep(1)
+        # push intermediate output to the UI whenever it has changed
+        if tgui_say != mutable[0]:
+            tgui_say = mutable[0]
+            history[-1] = tgui_say
+            chatbot[-1] = (history[-2], history[-1])
+            yield chatbot, history, status_text
+
+    logging.info(f'[response] {tgui_say}')
+
\ No newline at end of file
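
The new LLM_MODEL value follows a "model_name@server_address@server_port" convention that request_llm/bridge_tgui.py splits at import time, and the last hunk of predict.py rebinds predict to predict_tgui whenever the configured model name does not start with "gpt". Below is a minimal sketch of driving the bridge's streaming client directly; it assumes a text-generation-webui instance is serving its gradio queue on localhost:7860 and that config.py is set up exactly as in the hunk above, and the demo() helper is illustrative rather than part of the patch.

import asyncio
from request_llm.bridge_tgui import run

async def demo():
    # run() yields the full text generated so far each time the gradio queue
    # reports "process_generating" or "process_completed".
    async for partial in run("Hello, who are you?"):
        print(partial)

asyncio.run(demo())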
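
predict_tgui itself must remain a plain synchronous generator (the shape the gradio UI consumes), so it runs the async websocket stream on a worker thread that writes into a shared one-element list, which the generator polls once per second. The following self-contained sketch shows the same pattern; fake_stream(), sync_consumer(), worker() and collect() are illustrative stand-ins, not names from the patch.

import asyncio
import threading
import time

async def fake_stream():
    # stand-in for run(prompt): yields the text generated so far
    text = ""
    for word in ["hello", " world", "!"]:
        await asyncio.sleep(0.5)
        text += word
        yield text

def sync_consumer():
    mutable = [""]                      # shared cell the worker thread writes into

    def worker():
        async def collect():
            async for partial in fake_stream():
                mutable[0] = partial    # keep only the latest snapshot
        asyncio.run(collect())

    thread = threading.Thread(target=worker)
    thread.start()

    shown = ""
    while thread.is_alive():
        time.sleep(0.2)
        if mutable[0] != shown:         # re-render only when something changed
            shown = mutable[0]
            yield shown                 # predict_tgui yields (chatbot, history, status) here

    thread.join()
    if mutable[0] != shown:             # flush anything that arrived after the last poll
        yield mutable[0]

for update in sync_consumer():
    print(update)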