merge success

This commit is contained in:
parent babb775cfb
commit fd549fb986
@@ -70,7 +70,7 @@ MAX_RETRY = 2


 # Model selection (note: LLM_MODEL is the default selected model; it *must* be included in the AVAIL_LLM_MODELS list)
 LLM_MODEL = "gpt-3.5-turbo" # options ↓↓↓
-AVAIL_LLM_MODELS = ["gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5", "api2d-gpt-3.5-turbo", "gpt-4", "api2d-gpt-4", "chatglm", "moss", "newbing", "stack-claude"]
+AVAIL_LLM_MODELS = ["gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5", "api2d-gpt-3.5-turbo", "gpt-4", "api2d-gpt-4", "chatglm", "moss", "internlm", "newbing", "stack-claude"]
 # P.S. other available models also include ["gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "claude-1-100k", "claude-2", "jittorllms_rwkv", "jittorllms_pangualpha", "jittorllms_llama"]

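For context, enabling the newly listed backend is just a config edit. A minimal sketch follows; the choice of "internlm" as LLM_MODEL and the sanity check are illustrative, not part of this commit:

```python
# Minimal sketch: selecting the backend added above (illustrative, not part of this commit).
LLM_MODEL = "internlm"
AVAIL_LLM_MODELS = ["gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5", "api2d-gpt-3.5-turbo",
                    "gpt-4", "api2d-gpt-4", "chatglm", "moss", "internlm", "newbing", "stack-claude"]

# As the comment in the hunk notes, the default model must be one of the available models.
assert LLM_MODEL in AVAIL_LLM_MODELS, "LLM_MODEL must be included in AVAIL_LLM_MODELS"
```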
@@ -248,7 +248,6 @@ if "moss" in AVAIL_LLM_MODELS:
 if "stack-claude" in AVAIL_LLM_MODELS:
     from .bridge_stackclaude import predict_no_ui_long_connection as claude_noui
     from .bridge_stackclaude import predict as claude_ui
-    # claude
     model_info.update({
         "stack-claude": {
             "fn_with_ui": claude_ui,
@@ -263,7 +262,6 @@ if "newbing-free" in AVAIL_LLM_MODELS:
     try:
         from .bridge_newbingfree import predict_no_ui_long_connection as newbingfree_noui
         from .bridge_newbingfree import predict as newbingfree_ui
-        # claude
         model_info.update({
             "newbing-free": {
                 "fn_with_ui": newbingfree_ui,
@@ -280,7 +278,6 @@ if "newbing" in AVAIL_LLM_MODELS: # same with newbing-free
     try:
         from .bridge_newbingfree import predict_no_ui_long_connection as newbingfree_noui
         from .bridge_newbingfree import predict as newbingfree_ui
-        # claude
         model_info.update({
             "newbing": {
                 "fn_with_ui": newbingfree_ui,
@@ -297,7 +294,6 @@ if "chatglmft" in AVAIL_LLM_MODELS: # same with newbing-free
     try:
         from .bridge_chatglmft import predict_no_ui_long_connection as chatglmft_noui
         from .bridge_chatglmft import predict as chatglmft_ui
-        # claude
         model_info.update({
             "chatglmft": {
                 "fn_with_ui": chatglmft_ui,
@@ -310,7 +306,22 @@ if "chatglmft" in AVAIL_LLM_MODELS: # same with newbing-free
         })
     except:
         print(trimmed_format_exc())
+if "internlm" in AVAIL_LLM_MODELS:
+    try:
+        from .bridge_internlm import predict_no_ui_long_connection as internlm_noui
+        from .bridge_internlm import predict as internlm_ui
+        model_info.update({
+            "internlm": {
+                "fn_with_ui": internlm_ui,
+                "fn_without_ui": internlm_noui,
+                "endpoint": None,
+                "max_token": 4096,
+                "tokenizer": tokenizer_gpt35,
+                "token_cnt": get_token_num_gpt35,
+            }
+        })
+    except:
+        print(trimmed_format_exc())

 def LLM_CATCH_EXCEPTION(f):
     """
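The new model_info entry is what the rest of the bridge uses to route requests to InternLM. Below is a minimal dispatch sketch, assuming callers look the entry up by name and invoke fn_without_ui with the same signature the smoke test at the bottom of this commit uses; the helper name is hypothetical:

```python
# Hypothetical helper illustrating how the "internlm" entry added above could be used.
def call_model_no_ui(model_info, model_name, inputs, llm_kwargs, history, sys_prompt):
    entry = model_info[model_name]        # e.g. model_info["internlm"]
    fn = entry["fn_without_ui"]           # -> internlm_noui (predict_no_ui_long_connection)
    # entry["max_token"] (4096 here), entry["tokenizer"] and entry["token_cnt"] are
    # available for trimming `history` so the prompt stays within the token budget.
    return fn(inputs, llm_kwargs, history=history, sys_prompt=sys_prompt)
```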
@@ -12,6 +12,22 @@ load_message = f"{model_name}尚未加载,加载需要一段时间。注意,
 def try_to_import_special_deps():
     import sentencepiece

+user_prompt = "<|User|>:{user}<eoh>\n"
+robot_prompt = "<|Bot|>:{robot}<eoa>\n"
+cur_query_prompt = "<|User|>:{user}<eoh>\n<|Bot|>:"
+
+
+def combine_history(prompt, hist):
+    messages = hist
+    total_prompt = ""
+    for message in messages:
+        cur_content = message
+        cur_prompt = user_prompt.replace("{user}", cur_content[0])
+        total_prompt += cur_prompt
+        cur_prompt = robot_prompt.replace("{robot}", cur_content[1])
+        total_prompt += cur_prompt
+    total_prompt = total_prompt + cur_query_prompt.replace("{user}", prompt)
+    return total_prompt
+
+
 @Singleton
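combine_history flattens the (user, bot) history pairs into InternLM's chat markup and appends the pending query. A condensed, self-contained copy with a worked example is shown below; the history values are taken from the smoke test at the end of this commit, and each history entry is treated as a pair, per the cur_content[0]/cur_content[1] indexing above:

```python
# Condensed copy of the templates and combine_history above, for a worked example.
user_prompt = "<|User|>:{user}<eoh>\n"
robot_prompt = "<|Bot|>:{robot}<eoa>\n"
cur_query_prompt = "<|User|>:{user}<eoh>\n<|Bot|>:"

def combine_history(prompt, hist):
    total_prompt = ""
    for user_turn, bot_turn in hist:      # each history entry is a (user, bot) pair
        total_prompt += user_prompt.replace("{user}", user_turn)
        total_prompt += robot_prompt.replace("{robot}", bot_turn)
    return total_prompt + cur_query_prompt.replace("{user}", prompt)

print(combine_history("请问什么是质子?", [("你好", "我好!")]))
# <|User|>:你好<eoh>
# <|Bot|>:我好!<eoa>
# <|User|>:请问什么是质子?<eoh>
# <|Bot|>:
```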
@@ -44,10 +60,10 @@ class GetInternlmHandle(Process):
         else:
             model = AutoModelForCausalLM.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True).to(torch.bfloat16).cuda()

-        self._model = self._model.eval()
+        model = model.eval()
         return model, tokenizer

-    def llm_stream_generator(self, kwargs):
+    def llm_stream_generator(self, **kwargs):
         import torch
         import logging
         import copy
@@ -63,12 +79,16 @@ class GetInternlmHandle(Process):
             max_length = kwargs['max_length']
             top_p = kwargs['top_p']
             temperature = kwargs['temperature']
-            return model, tokenizer, prompt, max_length, top_p, temperature
+            history = kwargs['history']
+            real_prompt = combine_history(prompt, history)
+            return model, tokenizer, real_prompt, max_length, top_p, temperature

         model, tokenizer, prompt, max_length, top_p, temperature = adaptor()
         prefix_allowed_tokens_fn = None
-        logger = logging.get_logger(__name__)
+        logits_processor = None
+        stopping_criteria = None
         additional_eos_token_id = 103028
+        generation_config = None
         # 🏃‍♂️🏃‍♂️🏃‍♂️ run in the child process
         # 🏃‍♂️🏃‍♂️🏃‍♂️ https://github.com/InternLM/InternLM/blob/efbf5335709a8c8faeac6eaf07193973ff1d56a1/web_demo.py#L25
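The adaptor change above is what threads the conversation history into the prompt: besides max_length/top_p/temperature, kwargs now carries history, which is merged with the raw prompt via combine_history before generation. A minimal sketch of that flow with the class context stripped away; the function name is hypothetical, it reuses combine_history from the earlier hunk, and the generation step itself is elided:

```python
# Hypothetical standalone version of the reworked adaptor (generation itself elided).
# Relies on combine_history() as added earlier in this file.
def adapt_request(prompt, **kwargs):
    max_length = kwargs['max_length']
    top_p = kwargs['top_p']
    temperature = kwargs['temperature']
    history = kwargs['history']                        # list of (user, bot) pairs
    real_prompt = combine_history(prompt, history)     # new in this commit
    return real_prompt, max_length, top_p, temperature
```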
@@ -98,7 +118,7 @@ class GetInternlmHandle(Process):
         elif generation_config.max_new_tokens is not None:
             generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
             if not has_default_max_length:
-                logger.warn(
+                logging.warn(
                     f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                     f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                     "Please refer to the documentation for more information. "
@@ -108,7 +128,7 @@ class GetInternlmHandle(Process):

         if input_ids_seq_length >= generation_config.max_length:
             input_ids_string = "input_ids"
-            logger.warning(
+            logging.warning(
                 f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
                 f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                 " increasing `max_new_tokens`."
@@ -23,45 +23,8 @@ if __name__ == "__main__":
         'temperature': 1,
     }

-    result = predict_no_ui_long_connection(inputs="你好",
+    result = predict_no_ui_long_connection( inputs="请问什么是质子?",
                                             llm_kwargs=llm_kwargs,
-                                            history=[],
+                                            history=["你好", "我好!"],
                                             sys_prompt="")
     print('final result:', result)
-
-
-
-    # # print(result)
-    # from multiprocessing import Process, Pipe
-    # class GetGLMHandle(Process):
-    #     def __init__(self):
-    #         super().__init__(daemon=True)
-    #         pass
-    #     def run(self):
-    #         # run in the child process
-    #         # first run: load the parameters
-    #         def validate_path():
-    #             import os, sys
-    #             dir_name = os.path.dirname(__file__)
-    #             root_dir_assume = os.path.abspath(os.path.dirname(__file__) + '/..')
-    #             os.chdir(root_dir_assume + '/request_llm/jittorllms')
-    #             sys.path.append(root_dir_assume + '/request_llm/jittorllms')
-    #         validate_path() # validate path so you can run from base directory
-    #         jittorllms_model = None
-    #         import types
-    #         try:
-    #             if jittorllms_model is None:
-    #                 from models import get_model
-    #                 # availabel_models = ["chatglm", "pangualpha", "llama", "chatrwkv"]
-    #                 args_dict = {'model': 'chatrwkv'}
-    #                 print('self.jittorllms_model = get_model(types.SimpleNamespace(**args_dict))')
-    #                 jittorllms_model = get_model(types.SimpleNamespace(**args_dict))
-    #                 print('done get model')
-    #         except:
-    #             # self.child.send('[Local Message] Call jittorllms fail 不能正常加载jittorllms的参数。')
-    #             raise RuntimeError("不能正常加载jittorllms的参数!")
-    # x = GetGLMHandle()
-    # x.start()
-
-
-    # input()
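The updated __main__ block above doubles as a usage example: a one-turn history plus a new question routed through predict_no_ui_long_connection. A hedged sketch of the equivalent direct call follows; the import path and the extra llm_kwargs keys are assumptions, since only 'temperature': 1 is visible in the hunk and the diff does not name the file:

```python
# Assumed import path; the diff shows only the __main__ block, not the filename.
from request_llm.bridge_internlm import predict_no_ui_long_connection

llm_kwargs = {
    'max_length': 512,   # assumed; only 'temperature' appears in the hunk above
    'top_p': 1,          # assumed
    'temperature': 1,
}
result = predict_no_ui_long_connection(inputs="请问什么是质子?",
                                        llm_kwargs=llm_kwargs,
                                        history=["你好", "我好!"],
                                        sys_prompt="")
print('final result:', result)
```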