diff --git a/.gitignore b/.gitignore
index 7a9c92b..0dd68f8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -146,3 +146,4 @@ debug*
 private*
 crazy_functions/test_project/pdf_and_word
 crazy_functions/test_samples
+request_llm/jittorllms
\ No newline at end of file
diff --git a/request_llm/bridge_jittorllms.py b/request_llm/bridge_jittorllms.py
new file mode 100644
index 0000000..28d0a7a
--- /dev/null
+++ b/request_llm/bridge_jittorllms.py
@@ -0,0 +1,153 @@

import time
import threading
import importlib
from toolbox import update_ui, get_conf
from multiprocessing import Process, Pipe

load_message = "jittorllms has not been loaded yet; loading takes a while. Note that, depending on the settings in `config.py`, jittorllms consumes a large amount of memory (CPU) or VRAM (GPU), which may freeze low-end machines ..."

#################################################################################
class GetGLMHandle(Process):
    def __init__(self):
        super().__init__(daemon=True)
        self.parent, self.child = Pipe()
        self.jittorllms_model = None
        self.info = ""
        self.success = True
        self.check_dependency()
        self.start()
        self.threadLock = threading.Lock()

    def check_dependency(self):
        try:
            import jittor
            from .jittorllms.models import get_model
            self.info = "Dependency check passed"
            self.success = True
        except Exception:
            self.info = r"Missing jittorllms dependencies. To use jittorllms, besides the basic pip requirements, you also need to run `pip install -r request_llm/requirements_jittorllms.txt` "+\
                        r"and `git clone https://gitlink.org.cn/jittor/JittorLLMs.git --depth 1 request_llm/jittorllms` (run both commands from the project root)."
            self.success = False

    def ready(self):
        return self.jittorllms_model is not None

    def run(self):
        # Executed in the child process.
        # On the first run, load the model parameters.
        def load_model():
            import types
            try:
                if self.jittorllms_model is None:
                    device, = get_conf('LOCAL_MODEL_DEVICE')
                    from .jittorllms.models import get_model
                    # available_models = ["chatglm", "pangualpha", "llama", "chatrwkv"]
                    # note: LOCAL_MODEL_DEVICE is read above, but RUN_DEVICE is currently pinned to 'cpu'
                    args_dict = {'model': 'chatglm', 'RUN_DEVICE': 'cpu'}
                    self.jittorllms_model = get_model(types.SimpleNamespace(**args_dict))
            except Exception:
                self.child.send('[Local Message] Call jittorllms fail: could not load the jittorllms parameters.')
                raise RuntimeError("Could not load the jittorllms parameters!")

        load_model()

        # Enter the task-waiting loop
        while True:
            # Block until a request arrives
            kwargs = self.child.recv()
            # Request received; start generating
            try:
                for response, history in self.jittorllms_model.run_web_demo(kwargs['query'], kwargs['history']):
                    self.child.send(response)
            except Exception:
                self.child.send('[Local Message] Call jittorllms fail.')
            # Request finished; signal completion and loop again
            self.child.send('[Finish]')

    def stream_chat(self, **kwargs):
        # Executed in the main process.
        with self.threadLock:
            self.parent.send(kwargs)
            while True:
                res = self.parent.recv()
                if res == '[Finish]':
                    break
                yield res

glm_handle = None
#################################################################################
def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=[], console_slience=False):
    """
    Multithreaded method.
    Refer to request_llm/bridge_all.py for the function's documentation.
    """
    global glm_handle
    if glm_handle is None:
        glm_handle = GetGLMHandle()
        if len(observe_window) >= 1: observe_window[0] = load_message + "\n\n" + glm_handle.info
        if not glm_handle.success:
            error = glm_handle.info
            glm_handle = None
            raise RuntimeError(error)

    # jittorllms has no sys_prompt interface, so the prompt is injected into the history instead
    history_feedin = []
    history_feedin.append(["What can I do?", sys_prompt])
    for i in range(len(history)//2):
        history_feedin.append([history[2*i], history[2*i+1]])

    watch_dog_patience = 5 # watchdog patience; 5 seconds is enough
    response = ""
    for response in glm_handle.stream_chat(query=inputs, history=history_feedin, max_length=llm_kwargs['max_length'], top_p=llm_kwargs['top_p'], temperature=llm_kwargs['temperature']):
        if len(observe_window) >= 1: observe_window[0] = response
        if len(observe_window) >= 2:
            if (time.time()-observe_window[1]) > watch_dog_patience:
                raise RuntimeError("Program terminated.")
    return response



def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_prompt='', stream=True, additional_fn=None):
    """
    Single-threaded method.
    Refer to request_llm/bridge_all.py for the function's documentation.
    """
    chatbot.append((inputs, ""))

    global glm_handle
    if glm_handle is None:
        glm_handle = GetGLMHandle()
        chatbot[-1] = (inputs, load_message + "\n\n" + glm_handle.info)
        yield from update_ui(chatbot=chatbot, history=[])
        if not glm_handle.success:
            glm_handle = None
            return

    if additional_fn is not None:
        import core_functional
        importlib.reload(core_functional)    # hot-reload the prompt definitions
        core_functional = core_functional.get_core_functions()
        if "PreProcess" in core_functional[additional_fn]: inputs = core_functional[additional_fn]["PreProcess"](inputs)  # apply the preprocessing function, if one exists
        inputs = core_functional[additional_fn]["Prefix"] + inputs + core_functional[additional_fn]["Suffix"]

    # Assemble the history
    history_feedin = []
    history_feedin.append(["What can I do?", system_prompt])
    for i in range(len(history)//2):
        history_feedin.append([history[2*i], history[2*i+1]])

    # Start receiving the reply from jittorllms
    response = "[Local Message]: Waiting for jittorllms to respond ..."
    for response in glm_handle.stream_chat(query=inputs, history=history_feedin, max_length=llm_kwargs['max_length'], top_p=llm_kwargs['top_p'], temperature=llm_kwargs['temperature']):
        chatbot[-1] = (inputs, response)
        yield from update_ui(chatbot=chatbot, history=history)

    # Finalize the output
    if response == "[Local Message]: Waiting for jittorllms to respond ...":
        response = "[Local Message]: jittorllms response error ..."
    history.extend([inputs, response])
    yield from update_ui(chatbot=chatbot, history=history)
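Note: the Process/Pipe plumbing in bridge_jittorllms.py is the part most worth understanding before reading the two predict functions. Below is a minimal, self-contained sketch of the same streaming protocol: the parent sends one kwargs dict per request, the child streams partial responses back, and a '[Finish]' sentinel closes the request. `EchoWorker` and its whitespace "tokenizer" are hypothetical stand-ins for the jittorllms model, not code from this diff:

from multiprocessing import Process, Pipe

class EchoWorker(Process):
    """Hypothetical stand-in for GetGLMHandle's worker process."""
    def __init__(self):
        super().__init__(daemon=True)
        self.parent, self.child = Pipe()
        self.start()

    def run(self):
        # Child process: serve requests forever.
        while True:
            kwargs = self.child.recv()                # one kwargs dict per request
            for token in kwargs['query'].split():     # pretend to stream tokens
                self.child.send(token)
            self.child.send('[Finish]')               # end-of-request sentinel

    def stream_chat(self, **kwargs):
        # Main process: relay partial responses until the sentinel arrives.
        self.parent.send(kwargs)
        while True:
            res = self.parent.recv()
            if res == '[Finish]':
                break
            yield res

if __name__ == '__main__':
    worker = EchoWorker()
    for chunk in worker.stream_chat(query="hello streaming world"):
        print(chunk)

The sentinel-based handshake lets stream_chat stay a plain generator on the caller's side while the model itself lives in an isolated child process.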
diff --git a/request_llm/requirements_jittorllms.txt b/request_llm/requirements_jittorllms.txt
new file mode 100644
index 0000000..3713ce8
--- /dev/null
+++ b/request_llm/requirements_jittorllms.txt
@@ -0,0 +1,4 @@
jittor >= 1.3.7.9
jtorch >= 0.1.3
torch
torchvision
\ No newline at end of file
diff --git a/request_llm/test_llms.py b/request_llm/test_llms.py
new file mode 100644
index 0000000..d043d62
--- /dev/null
+++ b/request_llm/test_llms.py
@@ -0,0 +1,26 @@
"""
Unit tests for the individual LLM bridges.
"""
def validate_path():
    import os, sys
    root_dir_assume = os.path.abspath(os.path.dirname(__file__) + '/..')
    os.chdir(root_dir_assume)
    sys.path.append(root_dir_assume)

validate_path()  # validate path so you can run from the base directory

from request_llm.bridge_jittorllms import predict_no_ui_long_connection

llm_kwargs = {
    'max_length': 512,
    'top_p': 1,
    'temperature': 1,
}

result = predict_no_ui_long_connection(inputs="你好",
                                       llm_kwargs=llm_kwargs,
                                       history=[],
                                       sys_prompt="")

print(result)
\ No newline at end of file
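The test above exercises only the unsupervised blocking path. For reference, here is a hedged sketch of how a caller can drive the watchdog in predict_no_ui_long_connection: slot 0 of observe_window receives the partial response, and slot 1 must hold a fresh timestamp, otherwise the bridge raises RuntimeError once the timestamp is more than watch_dog_patience (5 s) stale. The heartbeat thread below is illustrative, not part of this diff:

import time, threading
from request_llm.bridge_jittorllms import predict_no_ui_long_connection

observe_window = ["", time.time()]   # [partial_response, last_heartbeat]

def heartbeat():
    # Keep the watchdog timestamp fresh for as long as we still want the answer;
    # stopping this loop lets the bridge abort the generation after ~5 seconds.
    while True:
        observe_window[1] = time.time()
        time.sleep(1)

threading.Thread(target=heartbeat, daemon=True).start()

result = predict_no_ui_long_connection(
    inputs="你好",
    llm_kwargs={'max_length': 512, 'top_p': 1, 'temperature': 1},
    history=[],
    sys_prompt="",
    observe_window=observe_window)
print(result)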