diff --git a/config.py b/config.py index 803129f..ad9cfd9 100644 --- a/config.py +++ b/config.py @@ -34,7 +34,7 @@ WEB_PORT = -1 MAX_RETRY = 2 # OpenAI模型选择是(gpt4现在只对申请成功的人开放) -LLM_MODEL = "pygmalion-1.3b@localhost@7860" # "gpt-3.5-turbo" +LLM_MODEL = "TGUI:galactica-1.3b@localhost:7860" # "gpt-3.5-turbo" # OpenAI的API_URL API_URL = "https://api.openai.com/v1/chat/completions" diff --git a/request_llm/README.md b/request_llm/README.md index 26f0dde..a539f1f 100644 --- a/request_llm/README.md +++ b/request_llm/README.md @@ -2,7 +2,7 @@ ## 1. 先运行text-generation ``` sh -# 下载模型 +# 下载模型( text-generation 这么牛的项目,别忘了给人家star ) git clone https://github.com/oobabooga/text-generation-webui.git # 安装text-generation的额外依赖 @@ -12,28 +12,25 @@ pip install accelerate bitsandbytes flexgen gradio llamacpp markdown numpy peft cd text-generation-webui # 下载模型 -python download-model.py facebook/opt-1.3b - -# 其他可选如 facebook/galactica-1.3b +python download-model.py facebook/galactica-1.3b +# 其他可选如 facebook/opt-1.3b # facebook/galactica-6.7b # facebook/galactica-120b - -# Pymalion 6B is a proof-of-concept dialogue model based on EleutherAI's GPT-J-6B. -# facebook/pygmalion-1.3b +# facebook/pygmalion-1.3b 等 +# 详情见 https://github.com/oobabooga/text-generation-webui # 启动text-generation,注意把模型的斜杠改成下划线 python server.py --cpu --listen --listen-port 7860 --model facebook_galactica-1.3b ``` ## 2. 修改config.py +``` sh +# LLM_MODEL格式较复杂 TGUI:[模型]@[ws地址]:[ws端口] , 端口要和上面给定的端口一致 +LLM_MODEL = "TGUI:galactica-1.3b@localhost:7860" ``` -# LLM_MODEL格式为 [模型]@[ws地址] @[ws端口] -LLM_MODEL = "pygmalion-1.3b@localhost@7860" -``` - ## 3. 运行! -``` +``` sh cd chatgpt-academic python main.py ``` diff --git a/request_llm/bridge_tgui.py b/request_llm/bridge_tgui.py index 1c7103f..916416b 100644 --- a/request_llm/bridge_tgui.py +++ b/request_llm/bridge_tgui.py @@ -15,7 +15,10 @@ import importlib from toolbox import get_conf LLM_MODEL, = get_conf('LLM_MODEL') -model_name, addr, port = LLM_MODEL.split('@') +# "TGUI:galactica-1.3b@localhost:7860" +model_name, addr_port = LLM_MODEL.split('@') +assert ':' in addr_port, "LLM_MODEL 格式不正确!" + LLM_MODEL +addr, port = addr_port.split(':') def random_hash(): letters = string.ascii_lowercase + string.digits @@ -117,11 +120,11 @@ def predict_tgui(inputs, top_p, temperature, chatbot=[], history=[], system_prom def run_coorotine(mutable): async def get_result(mutable): async for response in run(prompt): - # Print intermediate steps + print(response[len(mutable[0]):]) mutable[0] = response asyncio.run(get_result(mutable)) - thread_listen = threading.Thread(target=run_coorotine, args=(mutable,)) + thread_listen = threading.Thread(target=run_coorotine, args=(mutable,), daemon=True) thread_listen.start() while thread_listen.is_alive(): @@ -145,7 +148,7 @@ def predict_tgui_no_ui(inputs, top_p, temperature, history=[], sys_prompt=""): def run_coorotine(mutable): async def get_result(mutable): async for response in run(prompt): - # Print intermediate steps + print(response[len(mutable[0]):]) mutable[0] = response asyncio.run(get_result(mutable)) thread_listen = threading.Thread(target=run_coorotine, args=(mutable,))