Integrate TGUI

Your Name 2023-04-02 00:40:05 +08:00
parent 3af0bbdbe4
commit 2420d62a33
3 changed files with 17 additions and 17 deletions

View File

@@ -34,7 +34,7 @@ WEB_PORT = -1
MAX_RETRY = 2
# OpenAI model selection (gpt-4 is currently only available to users whose access application has been approved)
LLM_MODEL = "pygmalion-1.3b@localhost@7860" # "gpt-3.5-turbo"
LLM_MODEL = "TGUI:galactica-1.3b@localhost:7860" # "gpt-3.5-turbo"
# OpenAI的API_URL
API_URL = "https://api.openai.com/v1/chat/completions"
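For reference, the new value follows the pattern TGUI:[model]@[address]:[port], and the bridge file further down in this commit splits it on '@' and ':'. A minimal sketch of that split (the TGUI: prefix handling via removeprefix is illustrative, not taken from the repo):

``` python
# Minimal sketch of splitting the new LLM_MODEL value (the real parsing is in
# the bridge file further down in this commit).
LLM_MODEL = "TGUI:galactica-1.3b@localhost:7860"

model_name, addr_port = LLM_MODEL.split('@')   # "TGUI:galactica-1.3b", "localhost:7860"
assert ':' in addr_port, "Expected host:port after '@': " + LLM_MODEL
addr, port = addr_port.split(':')              # "localhost", "7860"
model_name = model_name.removeprefix("TGUI:")  # illustrative prefix strip (assumes Python 3.9+)
print(model_name, addr, port)                  # galactica-1.3b localhost 7860
```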

View File

@@ -2,7 +2,7 @@
## 1. First, run text-generation
``` sh
# Download the model
# Download the model (text-generation is such an awesome project, don't forget to give it a star)
git clone https://github.com/oobabooga/text-generation-webui.git
# Install text-generation's extra dependencies
@@ -12,28 +12,25 @@ pip install accelerate bitsandbytes flexgen gradio llamacpp markdown numpy peft
cd text-generation-webui
# Download the model
python download-model.py facebook/opt-1.3b
# Other options, e.g. facebook/galactica-1.3b
python download-model.py facebook/galactica-1.3b
# Other options, e.g. facebook/opt-1.3b
# facebook/galactica-6.7b
# facebook/galactica-120b
# Pygmalion 6B is a proof-of-concept dialogue model based on EleutherAI's GPT-J-6B.
# facebook/pygmalion-1.3b
# facebook/pygmalion-1.3b etc.
# For details, see https://github.com/oobabooga/text-generation-webui
# Launch text-generation (note: change the slash in the model name to an underscore)
python server.py --cpu --listen --listen-port 7860 --model facebook_galactica-1.3b
```
## 2. Modify config.py
``` sh
# The LLM_MODEL format is a bit involved: TGUI:[model]@[ws address]:[ws port]; the port must match the one given above
LLM_MODEL = "TGUI:galactica-1.3b@localhost:7860"
```
# LLM_MODEL format is [model]@[ws address]@[ws port]
LLM_MODEL = "pygmalion-1.3b@localhost@7860"
```
## 3. Run!
```
``` sh
cd chatgpt-academic
python main.py
```
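If main.py cannot reach the model, one quick sanity check (not part of the repo, just an illustrative helper) is to confirm that something is actually listening on the WebSocket port configured above:

``` python
# Hypothetical helper: confirm text-generation-webui is listening on the
# host/port configured in LLM_MODEL before launching main.py.
import socket

def port_is_open(host: str, port: int, timeout: float = 2.0) -> bool:
    # Try a plain TCP connect; a refused connection means the server is not up yet.
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False

print(port_is_open("localhost", 7860))  # expect True once server.py from step 1 is running
```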

View File

@@ -15,7 +15,10 @@ import importlib
from toolbox import get_conf
LLM_MODEL, = get_conf('LLM_MODEL')
model_name, addr, port = LLM_MODEL.split('@')
# "TGUI:galactica-1.3b@localhost:7860"
model_name, addr_port = LLM_MODEL.split('@')
assert ':' in addr_port, "Incorrect LLM_MODEL format!" + LLM_MODEL
addr, port = addr_port.split(':')
def random_hash():
letters = string.ascii_lowercase + string.digits
@@ -117,11 +120,11 @@ def predict_tgui(inputs, top_p, temperature, chatbot=[], history=[], system_prom
def run_coorotine(mutable):
async def get_result(mutable):
async for response in run(prompt):
# Print intermediate steps
print(response[len(mutable[0]):])
mutable[0] = response
asyncio.run(get_result(mutable))
thread_listen = threading.Thread(target=run_coorotine, args=(mutable,))
thread_listen = threading.Thread(target=run_coorotine, args=(mutable,), daemon=True)
thread_listen.start()
while thread_listen.is_alive():
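The surrounding code streams partial responses by consuming an async generator inside a daemon thread and letting the caller poll a shared one-item list. A self-contained sketch of that pattern, with a stand-in generator in place of the real TGUI stream:

``` python
# Illustrative stand-alone version of the pattern above: an async generator is
# consumed inside a daemon thread, and the main thread polls a shared one-item
# list for the latest partial result. fake_stream is a stand-in for the real
# TGUI websocket stream.
import asyncio
import threading
import time

async def fake_stream():
    text = ""
    for word in ["hello", " world", "!"]:
        await asyncio.sleep(0.1)
        text += word
        yield text  # yields progressively longer responses, like the TGUI stream

def run_coroutine(mutable):
    async def get_result(mutable):
        async for response in fake_stream():
            mutable[0] = response  # overwrite with the newest partial response
    asyncio.run(get_result(mutable))

mutable = [""]
thread_listen = threading.Thread(target=run_coroutine, args=(mutable,), daemon=True)
thread_listen.start()
while thread_listen.is_alive():
    time.sleep(0.05)
    print("partial:", mutable[0])  # the UI would render this partial text
print("final:", mutable[0])
```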
@@ -145,7 +148,7 @@ def predict_tgui_no_ui(inputs, top_p, temperature, history=[], sys_prompt=""):
def run_coorotine(mutable):
async def get_result(mutable):
async for response in run(prompt):
# Print intermediate steps
print(response[len(mutable[0]):])
mutable[0] = response
asyncio.run(get_result(mutable))
thread_listen = threading.Thread(target=run_coorotine, args=(mutable,))