Fix several concurrency issues in local models

This commit is contained in:
  parent 17cf47dcd6
  commit 09857ea455
@@ -3,11 +3,32 @@ import threading
 from toolbox import update_ui
 from multiprocessing import Process, Pipe
 from contextlib import redirect_stdout
+from request_llms.queued_pipe import create_queue_pipe
+
+class DebugLock(object):
+    def __init__(self):
+        self._lock = threading.Lock()
+
+    def acquire(self):
+        print("acquiring", self)
+        #traceback.print_tb
+        self._lock.acquire()
+        print("acquired", self)
+
+    def release(self):
+        print("released", self)
+        #traceback.print_tb
+        self._lock.release()
+
+    def __enter__(self):
+        self.acquire()
+
+    def __exit__(self, type, value, traceback):
+        self.release()
+
 def SingletonLocalLLM(cls):
     """
     A singleton decorator
     Singleton Decorator for LocalLLMHandle
     """
     _instance = {}
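The DebugLock added above is a drop-in replacement for threading.Lock that prints on every acquire/release, which makes the lock contention this commit targets visible in the console. A minimal usage sketch (the worker function and thread count are illustrative, assuming DebugLock from the hunk above is in scope):

import threading

lock = DebugLock()  # the class introduced in the hunk above

def worker(n):
    with lock:  # __enter__ prints "acquiring ..." then "acquired ..."
        print(f"worker {n} holds the lock")
    # __exit__ prints "released ..." and frees the lock

threads = [threading.Thread(target=worker, args=(n,)) for n in range(3)]
for t in threads: t.start()
for t in threads: t.join()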
@@ -46,24 +67,41 @@ def reset_tqdm_output():

 class LocalLLMHandle(Process):
     def __init__(self):
-        # ⭐主进程执行
+        # ⭐run in main process
         super().__init__(daemon=True)
+        self.is_main_process = True # init
         self.corrupted = False
         self.load_model_info()
-        self.parent, self.child = Pipe()
+        self.parent, self.child = create_queue_pipe()
+        self.parent_state, self.child_state = create_queue_pipe()
         # allow redirect_stdout
         self.std_tag = "[Subprocess Message] "
         self.child.write = lambda x: self.child.send(self.std_tag + x)
         self.running = True
         self._model = None
         self._tokenizer = None
-        self.info = ""
+        self.state = ""
         self.check_dependency()
+        self.is_main_process = False    # state wrap for child process
         self.start()
-        self.threadLock = threading.Lock()
+        self.is_main_process = True     # state wrap for child process
+        self.threadLock = DebugLock()
+
+    def get_state(self):
+        # ⭐run in main process
+        while self.parent_state.poll():
+            self.state = self.parent_state.recv()
+        return self.state
+
+    def set_state(self, new_state):
+        # ⭐run in main process or 🏃‍♂️🏃‍♂️🏃‍♂️ run in child process
+        if self.is_main_process:
+            self.state = new_state
+        else:
+            self.child_state.send(new_state)

     def load_model_info(self):
-        # 🏃‍♂️🏃‍♂️🏃‍♂️ 子进程执行
+        # 🏃‍♂️🏃‍♂️🏃‍♂️ run in child process
         raise NotImplementedError("Method not implemented yet")
         self.model_name = ""
         self.cmd_to_install = ""
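The new get_state()/set_state() pair replaces the plain self.info attribute with a queue-backed state channel: the child process pushes every state change, and the parent drains the channel so it always reads the most recent value. A self-contained sketch of the same pattern with a bare multiprocessing.Queue (all names here are illustrative, not part of the commit):

from multiprocessing import Process, Queue

def child_main(state_q):
    state_q.put("loading")   # report each state transition
    state_q.put("ready")

def latest_state(state_q, current=""):
    # mirrors get_state(): drain the queue, keep only the newest entry
    while not state_q.empty():
        current = state_q.get()
    return current

if __name__ == "__main__":
    q = Queue()
    p = Process(target=child_main, args=(q,), daemon=True)
    p.start(); p.join()
    print(latest_state(q))   # -> "ready"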
@@ -72,40 +110,40 @@ class LocalLLMHandle(Process):
         """
         This function should return the model and the tokenizer
         """
-        # 🏃‍♂️🏃‍♂️🏃‍♂️ 子进程执行
+        # 🏃‍♂️🏃‍♂️🏃‍♂️ run in child process
         raise NotImplementedError("Method not implemented yet")

     def llm_stream_generator(self, **kwargs):
-        # 🏃‍♂️🏃‍♂️🏃‍♂️ 子进程执行
+        # 🏃‍♂️🏃‍♂️🏃‍♂️ run in child process
         raise NotImplementedError("Method not implemented yet")

     def try_to_import_special_deps(self, **kwargs):
         """
         import something that will raise error if the user does not install requirement_*.txt
         """
-        # ⭐主进程执行
+        # ⭐run in main process
         raise NotImplementedError("Method not implemented yet")

     def check_dependency(self):
-        # ⭐主进程执行
+        # ⭐run in main process
         try:
             self.try_to_import_special_deps()
-            self.info = "`依赖检测通过`"
+            self.set_state("`依赖检测通过`")
             self.running = True
         except:
-            self.info = f"缺少{self.model_name}的依赖,如果要使用{self.model_name},除了基础的pip依赖以外,您还需要运行{self.cmd_to_install}安装{self.model_name}的依赖。"
+            self.set_state(f"缺少{self.model_name}的依赖,如果要使用{self.model_name},除了基础的pip依赖以外,您还需要运行{self.cmd_to_install}安装{self.model_name}的依赖。")
             self.running = False

     def run(self):
-        # 🏃‍♂️🏃‍♂️🏃‍♂️ 子进程执行
+        # 🏃‍♂️🏃‍♂️🏃‍♂️ run in child process
         # on first run, load the model parameters
         reset_tqdm_output()
-        self.info = "`尝试加载模型`"
+        self.set_state("`尝试加载模型`")
         try:
             with redirect_stdout(self.child):
                 self._model, self._tokenizer = self.load_model_and_tokenizer()
         except:
-            self.info = "`加载模型失败`"
+            self.set_state("`加载模型失败`")
             self.running = False
             from toolbox import trimmed_format_exc
             self.child.send(
@@ -113,7 +151,7 @@ class LocalLLMHandle(Process):
             self.child.send('[FinishBad]')
             raise RuntimeError(f"不能正常加载{self.model_name}的参数!")

-        self.info = "`准备就绪`"
+        self.set_state("`准备就绪`")
         while True:
             # enter the task-waiting state
             kwargs = self.child.recv()
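The loop entered here waits for a request (kwargs) and streams partial responses back through the pipe; the stream is terminated by the sentinel strings '[FinishBad]' (sent above on load failure) and '[Finish]' (sent in the next hunk on completion). A hedged sketch of a consumer written against that protocol (read_until_finish is illustrative; the actual consumer is stream_chat further below):

def read_until_finish(pipe):
    # yield streamed chunks until a terminating sentinel arrives
    while True:
        res = pipe.recv()
        if res == '[Finish]':
            break
        if res == '[FinishBad]':
            raise RuntimeError("subprocess could not load the model")
        yield res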
@@ -121,6 +159,7 @@ class LocalLLMHandle(Process):
             try:
                 for response_full in self.llm_stream_generator(**kwargs):
                     self.child.send(response_full)
+                    print('debug' + response_full)
                 self.child.send('[Finish]')
                 # request handled; start the next loop iteration
             except:
@@ -129,18 +168,35 @@ class LocalLLMHandle(Process):
                     f'[Local Message] 调用{self.model_name}失败.' + '\n```\n' + trimmed_format_exc() + '\n```\n')
                 self.child.send('[Finish]')

+    def clear_pending_messages(self):
+        # ⭐run in main process
+        while True:
+            if self.parent.poll():
+                self.parent.recv()
+                continue
+            for _ in range(5):
+                time.sleep(0.5)
+                if self.parent.poll():
+                    r = self.parent.recv()
+                    continue
+            break
+        return
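clear_pending_messages() above replaces the old one-shot drain (`while self.parent.poll(): self.parent.recv()`, removed in the hunk just below): after emptying the queue it keeps polling for up to another 2.5 s (5 × 0.5 s) so that stragglers from an abandoned request still in flight are discarded too. A condensed sketch of the same drain-with-grace-period idea:

import time

def drain_with_grace(pipe, retries=5, interval=0.5):
    # discard everything queued, then wait briefly for late arrivals
    while pipe.poll():
        pipe.recv()
    for _ in range(retries):
        time.sleep(interval)
        if pipe.poll():
            pipe.recv()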
+
     def stream_chat(self, **kwargs):
-        # ⭐主进程执行
-        if self.info == "`准备就绪`":
+        # ⭐run in main process
+        if self.get_state() == "`准备就绪`":
             yield "`正在等待线程锁,排队中请稍后 ...`"

         with self.threadLock:
             if self.parent.poll():
-                while self.parent.poll(): self.parent.recv()
                 yield "`排队中请稍后 ...`"
+                self.clear_pending_messages()
             self.parent.send(kwargs)
             std_out = ""
             std_out_clip_len = 4096
             while True:
                 res = self.parent.recv()
                 # pipe_watch_dog.feed()
                 if res.startswith(self.std_tag):
                     new_output = res[len(self.std_tag):]
                     std_out = std_out[:std_out_clip_len]
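The self.std_tag branch above is the consumer side of the stdout forwarding set up in __init__: contextlib.redirect_stdout only needs an object with a write() method, so the child's pipe end is given a write() that sends anything printed in the subprocess to the parent, prefixed with std_tag. A self-contained illustration of the trick (FakePipe is a stand-in, not part of the commit):

from contextlib import redirect_stdout

class FakePipe:                    # stand-in for the child's end of the pipe
    def __init__(self):
        self.sent = []
    def send(self, msg):
        self.sent.append(msg)

std_tag = "[Subprocess Message] "
pipe = FakePipe()
pipe.write = lambda x: pipe.send(std_tag + x)   # same trick as in __init__

with redirect_stdout(pipe):
    print("loading checkpoint shards ...")      # routed through pipe.send()

print(pipe.sent)   # note: print() writes the text and the newline separately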
@@ -157,20 +213,18 @@ class LocalLLMHandle(Process):
                     std_out = ""
                     yield res


 def get_local_llm_predict_fns(LLMSingletonClass, model_name, history_format='classic'):
     load_message = f"{model_name}尚未加载,加载需要一段时间。注意,取决于`config.py`的配置,{model_name}消耗大量的内存(CPU)或显存(GPU),也许会导致低配计算机卡死 ……"

     def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=[], console_slience=False):
         """
             ⭐multi-threaded method
-            函数的说明请见 request_llms/bridge_all.py
+            refer to request_llms/bridge_all.py
         """
         _llm_handle = LLMSingletonClass()
         if len(observe_window) >= 1:
-            observe_window[0] = load_message + "\n\n" + _llm_handle.info
+            observe_window[0] = load_message + "\n\n" + _llm_handle.get_state()
         if not _llm_handle.running:
-            raise RuntimeError(_llm_handle.info)
+            raise RuntimeError(_llm_handle.get_state())

         if history_format == 'classic':
             # there is no sys_prompt interface, so the prompt is appended to history
@@ -210,16 +264,15 @@ def get_local_llm_predict_fns(LLMSingletonClass, model_name, history_format='classic'):

     def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_prompt='', stream=True, additional_fn=None):
         """
             ⭐single-threaded method
-            函数的说明请见 request_llms/bridge_all.py
+            refer to request_llms/bridge_all.py
         """
         chatbot.append((inputs, ""))

         _llm_handle = LLMSingletonClass()
-        chatbot[-1] = (inputs, load_message + "\n\n" + _llm_handle.info)
+        chatbot[-1] = (inputs, load_message + "\n\n" + _llm_handle.get_state())
         yield from update_ui(chatbot=chatbot, history=[])
         if not _llm_handle.running:
-            raise RuntimeError(_llm_handle.info)
+            raise RuntimeError(_llm_handle.get_state())

         if additional_fn is not None:
             from core_functional import handle_core_functionality
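The switch from _llm_handle.info to _llm_handle.get_state() in both predict functions follows from the state-channel change above: an attribute assigned inside the child process only mutates that process's copy of the object, so the parent's .info silently stayed stale. A short demonstration of that cross-process invisibility (Handle is illustrative):

from multiprocessing import Process

class Handle:
    def __init__(self):
        self.info = ""
    def run_in_child(self):
        self.info = "ready"   # mutates the child's copy only

if __name__ == "__main__":
    h = Handle()
    p = Process(target=h.run_in_child)
    p.start(); p.join()
    print(repr(h.info))   # -> '' : the parent never sees the update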
request_llms/queued_pipe.py (new file, 24 lines)

@@ -0,0 +1,24 @@
+from multiprocessing import Pipe, Queue
+import time
+import threading
+
+class PipeSide(object):
+    def __init__(self, q_2remote, q_2local) -> None:
+        self.q_2remote = q_2remote
+        self.q_2local = q_2local
+
+    def recv(self):
+        return self.q_2local.get()
+
+    def send(self, buf):
+        self.q_2remote.put(buf)
+
+    def poll(self):
+        return not self.q_2local.empty()
+
+def create_queue_pipe():
+    q_p2c = Queue()
+    q_c2p = Queue()
+    pipe_c = PipeSide(q_2local=q_p2c, q_2remote=q_c2p)
+    pipe_p = PipeSide(q_2local=q_c2p, q_2remote=q_p2c)
+    return pipe_c, pipe_p
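The new queued_pipe.py gives each side a Pipe-like API (send/recv/poll) backed by a pair of multiprocessing.Queue objects. Unlike the two ends of a raw multiprocessing.Pipe(), whose data can become corrupted when several threads use the same end at once, a Queue is thread- and process-safe, which appears to be the concurrency issue the commit title refers to. A minimal usage sketch (the two returned sides are symmetric, so which one plays parent or child is arbitrary):

from multiprocessing import Process
from request_llms.queued_pipe import create_queue_pipe

def child_main(pipe):
    request = pipe.recv()           # blocks on the underlying Queue.get()
    pipe.send("echo: " + request)

if __name__ == "__main__":
    side_a, side_b = create_queue_pipe()
    p = Process(target=child_main, args=(side_b,), daemon=True)
    p.start()
    side_a.send("hello")
    print(side_a.recv())            # -> "echo: hello"
    p.join()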