From 2d91e438d658220c6b366aecf3aaa81e09eb75c4 Mon Sep 17 00:00:00 2001
From: qingxu fu <505030475@qq.com>
Date: Sat, 11 Nov 2023 23:22:50 +0800
Subject: [PATCH] Fix internlm input device bug
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 request_llms/bridge_internlm.py |  3 ++-
 request_llms/local_llm_class.py | 33 +++++++++++++++------------------
 2 files changed, 17 insertions(+), 19 deletions(-)

diff --git a/request_llms/bridge_internlm.py b/request_llms/bridge_internlm.py
index 20b53b4..b2be36a 100644
--- a/request_llms/bridge_internlm.py
+++ b/request_llms/bridge_internlm.py
@@ -94,8 +94,9 @@ class GetInternlmHandle(LocalLLMHandle):
 
         inputs = tokenizer([prompt], padding=True, return_tensors="pt")
         input_length = len(inputs["input_ids"][0])
+        device = get_conf('LOCAL_MODEL_DEVICE')
         for k, v in inputs.items():
-            inputs[k] = v.cuda()
+            inputs[k] = v.to(device)
         input_ids = inputs["input_ids"]
         batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
         if generation_config is None:
diff --git a/request_llms/local_llm_class.py b/request_llms/local_llm_class.py
index 38fcfc9..413df03 100644
--- a/request_llms/local_llm_class.py
+++ b/request_llms/local_llm_class.py
@@ -1,6 +1,6 @@
 import time
 import threading
-from toolbox import update_ui
+from toolbox import update_ui, Singleton
 from multiprocessing import Process, Pipe
 from contextlib import redirect_stdout
 from request_llms.queued_pipe import create_queue_pipe
@@ -26,23 +26,20 @@ class ThreadLock(object):
     def __exit__(self, type, value, traceback):
         self.release()
 
-def SingletonLocalLLM(cls):
-    """
-    Singleton Decroator for LocalLLMHandle
-    """
-    _instance = {}
+@Singleton
+class GetSingletonHandle():
+    def __init__(self):
+        self.llm_model_already_running = {}
 
-    def _singleton(*args, **kargs):
-        if cls not in _instance:
-            _instance[cls] = cls(*args, **kargs)
-            return _instance[cls]
-        elif _instance[cls].corrupted:
-            _instance[cls] = cls(*args, **kargs)
-            return _instance[cls]
+    def get_llm_model_instance(self, cls, *args, **kargs):
+        if cls not in self.llm_model_already_running:
+            self.llm_model_already_running[cls] = cls(*args, **kargs)
+            return self.llm_model_already_running[cls]
+        elif self.llm_model_already_running[cls].corrupted:
+            self.llm_model_already_running[cls] = cls(*args, **kargs)
+            return self.llm_model_already_running[cls]
         else:
-            return _instance[cls]
-    return _singleton
-
+            return self.llm_model_already_running[cls]
 
 def reset_tqdm_output():
     import sys, tqdm
@@ -221,7 +218,7 @@ def get_local_llm_predict_fns(LLMSingletonClass, model_name, history_format='cla
         """
            refer to request_llms/bridge_all.py
        """
-        _llm_handle = SingletonLocalLLM(LLMSingletonClass)()
+        _llm_handle = GetSingletonHandle().get_llm_model_instance(LLMSingletonClass)
         if len(observe_window) >= 1:
             observe_window[0] = load_message + "\n\n" + _llm_handle.get_state()
         if not _llm_handle.running:
@@ -269,7 +266,7 @@ def get_local_llm_predict_fns(LLMSingletonClass, model_name, history_format='cla
         """
         chatbot.append((inputs, ""))
 
-        _llm_handle = SingletonLocalLLM(LLMSingletonClass)()
+        _llm_handle = GetSingletonHandle().get_llm_model_instance(LLMSingletonClass)
         chatbot[-1] = (inputs, load_message + "\n\n" + _llm_handle.get_state())
         yield from update_ui(chatbot=chatbot, history=[])
         if not _llm_handle.running:
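
Reviewer notes (not part of the patch):

1. Device handling. The first hunk replaces the hard-coded inputs[k] = v.cuda() with v.to(device), where device comes from get_conf('LOCAL_MODEL_DEVICE'). Because torch.Tensor.to() accepts 'cpu' as well as 'cuda', the InternLM bridge no longer requires a GPU just to move the tokenizer output. A minimal sketch of the pattern, assuming device is a plain string such as 'cpu' or 'cuda'; the helper name below is illustrative and does not exist in the repository:

    import torch

    def move_inputs_to_device(inputs: dict, device: str) -> dict:
        # Illustrative helper, not part of the patch: move every tensor
        # produced by the tokenizer to the configured device.
        # Unlike .cuda(), .to(device) also works when device == 'cpu'.
        for k, v in inputs.items():
            inputs[k] = v.to(device)
        return inputs

    batch = {"input_ids": torch.tensor([[1, 2, 3]]),
             "attention_mask": torch.tensor([[1, 1, 1]])}
    batch = move_inputs_to_device(batch, "cpu")  # or 'cuda', per LOCAL_MODEL_DEVICE

2. Singleton refactor. The second file drops the local SingletonLocalLLM decorator and instead decorates a holder class, GetSingletonHandle, with Singleton imported from toolbox. The holder keeps one handle per local-LLM class and rebuilds it when handle.corrupted is set. The implementation of toolbox.Singleton is not shown in this patch; the sketch below assumes the common class-decorator form, under which every call to GetSingletonHandle() returns the same holder object:

    def Singleton(cls):
        # Assumed behaviour of toolbox.Singleton: cache one instance per class.
        _instances = {}

        def _get_instance(*args, **kwargs):
            if cls not in _instances:
                _instances[cls] = cls(*args, **kwargs)
            return _instances[cls]
        return _get_instance

    @Singleton
    class Holder:
        def __init__(self):
            self.cache = {}

    assert Holder() is Holder()  # every call yields the same shared instance

With that behaviour, GetSingletonHandle().get_llm_model_instance(LLMSingletonClass) at both call sites resolves to the same cached handle, matching what SingletonLocalLLM(LLMSingletonClass)() did before, while keeping the per-class bookkeeping and the corrupted-handle rebuild logic in one place.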