Merge branch 'frontier' into master_autogen

qingxu fu 2023-11-11 23:24:21 +08:00
commit 1335da4f45
2 changed files with 17 additions and 19 deletions

View File

@@ -94,8 +94,9 @@ class GetInternlmHandle(LocalLLMHandle):
         inputs = tokenizer([prompt], padding=True, return_tensors="pt")
         input_length = len(inputs["input_ids"][0])
+        device = get_conf('LOCAL_MODEL_DEVICE')
         for k, v in inputs.items():
-            inputs[k] = v.cuda()
+            inputs[k] = v.to(device)
         input_ids = inputs["input_ids"]
         batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
         if generation_config is None:
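
The hunk above replaces a hard-coded .cuda() call with device placement read from the LOCAL_MODEL_DEVICE configuration key, so the tensors land on whatever device the configuration names instead of requiring CUDA. A minimal standalone sketch of that pattern (the dictionary below is a hypothetical stand-in for a tokenizer's output; in the project the device string comes from get_conf('LOCAL_MODEL_DEVICE')):

import torch

def move_inputs_to_device(inputs: dict, device_str: str = "cpu") -> dict:
    # Mirror of the change above: place every tensor on the configured device
    # instead of calling .cuda() unconditionally, which fails without a GPU.
    for k, v in inputs.items():
        inputs[k] = v.to(device_str)
    return inputs

batch = {"input_ids": torch.tensor([[1, 2, 3]]), "attention_mask": torch.tensor([[1, 1, 1]])}
batch = move_inputs_to_device(batch, "cpu")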

View File

@@ -1,6 +1,6 @@
 import time
 import threading
-from toolbox import update_ui
+from toolbox import update_ui, Singleton
 from multiprocessing import Process, Pipe
 from contextlib import redirect_stdout
 from request_llms.queued_pipe import create_queue_pipe
@@ -26,23 +26,20 @@ class ThreadLock(object):
     def __exit__(self, type, value, traceback):
         self.release()

-def SingletonLocalLLM(cls):
-    """
-    Singleton Decroator for LocalLLMHandle
-    """
-    _instance = {}
-    def _singleton(*args, **kargs):
-        if cls not in _instance:
-            _instance[cls] = cls(*args, **kargs)
-            return _instance[cls]
-        elif _instance[cls].corrupted:
-            _instance[cls] = cls(*args, **kargs)
-            return _instance[cls]
-        else:
-            return _instance[cls]
-    return _singleton
+@Singleton
+class GetSingletonHandle():
+    def __init__(self):
+        self.llm_model_already_running = {}
+
+    def get_llm_model_instance(self, cls, *args, **kargs):
+        if cls not in self.llm_model_already_running:
+            self.llm_model_already_running[cls] = cls(*args, **kargs)
+            return self.llm_model_already_running[cls]
+        elif self.llm_model_already_running[cls].corrupted:
+            self.llm_model_already_running[cls] = cls(*args, **kargs)
+            return self.llm_model_already_running[cls]
+        else:
+            return self.llm_model_already_running[cls]

 def reset_tqdm_output():
     import sys, tqdm
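
The rewritten block drops the closure-based SingletonLocalLLM decorator in favour of a registry object, GetSingletonHandle, which is itself made a singleton by the Singleton decorator imported from toolbox. That decorator's implementation is not part of this diff; a class decorator of that shape typically looks like the sketch below (an assumption about toolbox.Singleton, not its actual code):

def Singleton(cls):
    # Cache one instance per decorated class and return it on every call,
    # so GetSingletonHandle() always yields the same registry object.
    _instance = {}

    def _get_instance(*args, **kwargs):
        if cls not in _instance:
            _instance[cls] = cls(*args, **kwargs)
        return _instance[cls]

    return _get_instance

The registry keeps one handle per model class in llm_model_already_running and rebuilds a handle only when its corrupted flag is set, which is the same rule the old decorator applied per wrapped class.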
@@ -221,7 +218,7 @@ def get_local_llm_predict_fns(LLMSingletonClass, model_name, history_format='cla
         """
         refer to request_llms/bridge_all.py
         """
-        _llm_handle = SingletonLocalLLM(LLMSingletonClass)()
+        _llm_handle = GetSingletonHandle().get_llm_model_instance(LLMSingletonClass)
         if len(observe_window) >= 1:
             observe_window[0] = load_message + "\n\n" + _llm_handle.get_state()
         if not _llm_handle.running:
@@ -269,7 +266,7 @@ def get_local_llm_predict_fns(LLMSingletonClass, model_name, history_format='cla
         """
         chatbot.append((inputs, ""))
-        _llm_handle = SingletonLocalLLM(LLMSingletonClass)()
+        _llm_handle = GetSingletonHandle().get_llm_model_instance(LLMSingletonClass)
         chatbot[-1] = (inputs, load_message + "\n\n" + _llm_handle.get_state())
         yield from update_ui(chatbot=chatbot, history=[])
         if not _llm_handle.running:
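
The last two hunks update both call sites from SingletonLocalLLM(LLMSingletonClass)() to GetSingletonHandle().get_llm_model_instance(LLMSingletonClass), so the UI and non-UI entry points now share one cached handle per model class. A self-contained sketch of that call-site behaviour, using hypothetical stand-in classes rather than the real LocalLLMHandle subclasses:

class FakeRegistry:
    # Stand-in with the same caching rule as GetSingletonHandle above.
    def __init__(self):
        self.llm_model_already_running = {}

    def get_llm_model_instance(self, cls, *args, **kwargs):
        cached = self.llm_model_already_running.get(cls)
        if cached is None or cached.corrupted:
            self.llm_model_already_running[cls] = cls(*args, **kwargs)
        return self.llm_model_already_running[cls]

class FakeHandle:
    # Real handles flip .corrupted when their worker process dies.
    def __init__(self):
        self.corrupted = False

registry = FakeRegistry()
first = registry.get_llm_model_instance(FakeHandle)
assert registry.get_llm_model_instance(FakeHandle) is first      # reused while healthy
first.corrupted = True
assert registry.get_llm_model_instance(FakeHandle) is not first  # rebuilt once corrupted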