Interface with LLaMa2 from huggingface

binary-husky 2023-08-21 21:54:21 +08:00
parent 8b3b883fce
commit 9720bec5e5
4 changed files with 113 additions and 2 deletions

README.md

@@ -27,7 +27,7 @@ To translate this project to arbitary language with GPT, read and run [`multi_la
Feature (⭐ = recently added) | Description
--- | ---
⭐[Access new models](https://github.com/binary-husky/gpt_academic/wiki/%E5%A6%82%E4%BD%95%E5%88%87%E6%8D%A2%E6%A8%A1%E5%9E%8B) | ⭐Alibaba DAMO Academy [Tongyi Qianwen](https://modelscope.cn/models/qwen/Qwen-7B-Chat/summary), Shanghai AI-Lab [InternLM](https://github.com/InternLM/InternLM), iFlytek [Spark](https://xinghuo.xfyun.cn/), [LLaMa2](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
One-click polishing | One-click polishing and one-click checking of grammar errors in papers
One-click Chinese-English translation | One-click translation between Chinese and English
One-click code explanation | Display, explain, generate, and comment code

config.py

@@ -150,3 +150,7 @@ ANTHROPIC_API_KEY = ""
# Custom API KEY format
CUSTOM_API_KEY_PATTERN = ""
# HUGGINGFACE token, used when downloading LLAMA https://huggingface.co/docs/hub/security-tokens
HUGGINGFACE_ACCESS_TOKEN = "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"
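For the token to have any effect, the new model also has to be enabled in the model list that bridge_all.py checks in the next hunk. A minimal config sketch, assuming the stock OpenAI models stay as defaults (the exact default list is an assumption, not part of this diff):

# config.py (sketch): enable the LLaMA2 bridge next to the usual models
AVAIL_LLM_MODELS = ["gpt-3.5-turbo", "gpt-4", "llama2"]
LOCAL_MODEL_DEVICE = "cuda"  # read by bridge_llama2.py; any value not starting with "cuda" keeps the model on CPU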

request_llm/bridge_all.py

@@ -385,6 +385,22 @@ if "spark" in AVAIL_LLM_MODELS: # iFlytek Spark cognitive large model
        })
    except:
        print(trimmed_format_exc())
if "llama2" in AVAIL_LLM_MODELS: # 讯飞星火认知大模型
try:
from .bridge_llama2 import predict_no_ui_long_connection as llama2_noui
from .bridge_llama2 import predict as llama2_ui
model_info.update({
"llama2": {
"fn_with_ui": llama2_ui,
"fn_without_ui": llama2_noui,
"endpoint": None,
"max_token": 4096,
"tokenizer": tokenizer_gpt35,
"token_cnt": get_token_num_gpt35,
}
})
except:
print(trimmed_format_exc())
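The new entry follows the same registry convention as the other bridges, so the rest of the project can address LLaMa2 purely by name. A minimal lookup sketch (the exact signature behind fn_without_ui is defined by the project's bridge convention and is not shown in this diff):

# Illustrative only: fetch the bridge registered above
bridge = model_info["llama2"]
predict_fn = bridge["fn_without_ui"]           # resolves to bridge_llama2.predict_no_ui_long_connection
n_tokens = bridge["token_cnt"]("hello world")  # token budgeting reuses the GPT-3.5 tokenizer, per the entry above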

request_llm/bridge_llama2.py (new file)

@@ -0,0 +1,91 @@
model_name = "LLaMA"
cmd_to_install = "`pip install -r request_llm/requirements_chatglm.txt`"
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from toolbox import update_ui, get_conf, ProxyNetworkActivate
from multiprocessing import Process, Pipe
from .local_llm_class import LocalLLMHandle, get_local_llm_predict_fns, SingletonLocalLLM
from threading import Thread
# ------------------------------------------------------------------------------------------------------------------------
# 🔌💻 Local Model
# ------------------------------------------------------------------------------------------------------------------------
@SingletonLocalLLM
class GetONNXGLMHandle(LocalLLMHandle):
    def load_model_info(self):
        # 🏃‍♂️🏃‍♂️🏃‍♂️ runs in the child process
        self.model_name = model_name
        self.cmd_to_install = cmd_to_install

    def load_model_and_tokenizer(self):
        # 🏃‍♂️🏃‍♂️🏃‍♂️ runs in the child process
        import os
        huggingface_token, device = get_conf('HUGGINGFACE_ACCESS_TOKEN', 'LOCAL_MODEL_DEVICE')
        assert len(huggingface_token) != 0, "HUGGINGFACE_ACCESS_TOKEN is not set"
        # write the token into the huggingface cache so transformers can authenticate the download
        with open(os.path.expanduser('~/.cache/huggingface/token'), 'w') as f:
            f.write(huggingface_token)
        model_id = 'meta-llama/Llama-2-7b-chat-hf'
        with ProxyNetworkActivate():
            self._tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=huggingface_token)
            # use fp16
            model = AutoModelForCausalLM.from_pretrained(model_id, use_auth_token=huggingface_token).eval()
            if device.startswith('cuda'): model = model.half().to(device)
            self._model = model
        return self._model, self._tokenizer
    def llm_stream_generator(self, **kwargs):
        # 🏃‍♂️🏃‍♂️🏃‍♂️ runs in the child process
        def adaptor(kwargs):
            query = kwargs['query']
            max_length = kwargs['max_length']
            top_p = kwargs['top_p']
            temperature = kwargs['temperature']
            history = kwargs['history']
            console_slience = kwargs.get('console_slience', True)
            return query, max_length, top_p, temperature, history, console_slience

        def convert_messages_to_prompt(query, history):
            # wrap each user turn in [INST]...[/INST], following the Llama-2 chat convention
            prompt = ""
            for a, b in history:
                prompt += f"\n[INST]{a}[/INST]"
                prompt += f"\n{b}"
            prompt += f"\n[INST]{query}[/INST]"
            return prompt

        query, max_length, top_p, temperature, history, console_slience = adaptor(kwargs)
        prompt = convert_messages_to_prompt(query, history)
        # =-=-=-=-=-=-=-=-=-=-=-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=--=-=-
        # code from transformers.llama
        streamer = TextIteratorStreamer(self._tokenizer)
        # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
        inputs = self._tokenizer([prompt], return_tensors="pt")
        prompt_tk_back = self._tokenizer.batch_decode(inputs['input_ids'])[0]
        generation_kwargs = dict(inputs.to(self._model.device), streamer=streamer, max_new_tokens=max_length)
        thread = Thread(target=self._model.generate, kwargs=generation_kwargs)
        thread.start()
        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
            if not console_slience: print(new_text, end='')
            # drop the echoed prompt prefix and the trailing end-of-sequence marker before yielding
            clean_text = generated_text
            if clean_text.startswith(prompt_tk_back): clean_text = clean_text[len(prompt_tk_back):]
            if clean_text.endswith("</s>"): clean_text = clean_text[:-len("</s>")]
            yield clean_text
        if not console_slience: print()
        # =-=-=-=-=-=-=-=-=-=-=-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=--=-=-
    def try_to_import_special_deps(self, **kwargs):
        # import something that will raise an error if the user has not installed requirement_*.txt
        # 🏃‍♂️🏃‍♂️🏃‍♂️ runs in the main process
        import importlib
        importlib.import_module('transformers')
# ------------------------------------------------------------------------------------------------------------------------
# 🔌💻 GPT-Academic Interface
# ------------------------------------------------------------------------------------------------------------------------
predict_no_ui_long_connection, predict = get_local_llm_predict_fns(GetONNXGLMHandle, model_name)
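For reference, a sketch of the prompt string that convert_messages_to_prompt assembles, with illustrative values:

# Illustrative only: history = [("Hi", "Hello!")], query = "Summarize attention"
# convert_messages_to_prompt(query, history) returns:
#
#   "\n[INST]Hi[/INST]\nHello!\n[INST]Summarize attention[/INST]"
#
# Each user turn is wrapped in [INST]...[/INST] and assistant replies are appended verbatim.
# This loosely follows the Llama-2 chat convention; the official chat template additionally uses
# <s> tokens and <<SYS>> blocks for the system prompt, which this bridge does not emit.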