import re
import time
import threading
import importlib
from toolbox import update_ui, get_conf
from multiprocessing import Process, Pipe
import numpy as np
from onnxruntime import InferenceSession, SessionOptions
from sentencepiece import SentencePieceProcessor

# Model source: K024/ChatGLM-6b-onnx-u8s8
global glm_onnx_handle
glm_onnx_handle = None

load_message = "ChatGLM_onnx尚未加载,加载需要一段时间。注意,取决于`config.py`的配置,ChatGLM_onnx消耗大量的内存(CPU)或显存(GPU),也许会导致低配(内存<8GB)计算机卡死 ……"

# Default paths
tokenizer_path = "YOUR/TOKENIZER_PATH/sentencepiece.model"
onnx_model_path = "YOUR/TOKENIZER_PATH/chatglm-6b-int8.onnx"

# Currently `MatMulInteger` and `DynamicQuantizeLinear` are only supported on CPU,
# although they are documented as supported on CUDA.
providers = ["CPUExecutionProvider"]
# if torch.cuda.is_available():
#     providers = ["CUDAExecutionProvider"] + providers


#################################################################################
class GetGLMHandle(Process):
    def __init__(self):
        super().__init__(daemon=True)
        self.parent, self.child = Pipe()
        self.ChatGLM_onnx_model = None       # loaded from onnx_model_path
        self.ChatGLM_onnx_tokenizer = None   # loaded from tokenizer_path
        self.info = ""
        self.success = True
        self.check_dependency()
        self.start()
        self.threadLock = threading.Lock()

    def check_dependency(self):
        try:
            import sentencepiece
            self.info = "依赖检测通过"
            self.success = True
        except:
            self.info = "缺少ChatGLM_onnx的依赖,如果要使用ChatGLM_onnx,除了基础的pip依赖以外,您还需要运行`pip install -r request_llm/requirements_ChatGLM_onnx.txt`安装ChatGLM_onnx的依赖。"
            self.success = False

    def ready(self):
        return self.ChatGLM_onnx_model is not None

    def run(self):
        # Runs in the child process.
        # First run: load the model parameters.
        retry = 0
        while True:
            try:
                if self.ChatGLM_onnx_model is None:
                    # Initialize the ChatGLMModel and ChatGLMTokenizer
                    self.ChatGLM_onnx_model = ChatGLMModel()
                    self.ChatGLM_onnx_tokenizer = ChatGLMTokenizer(tokenizer_path)
                    break
                else:
                    break
            except:
                retry += 1
                if retry > 3:
                    self.child.send('[Local Message] Call ChatGLM_onnx fail 不能正常加载ChatGLM_onnx的参数。')
                    raise RuntimeError("不能正常加载ChatGLM_onnx的参数!")

        while True:
            # Wait for the next task.
            kwargs = self.child.recv()
            # A request arrived; start processing it.
            try:
                # Use the ChatGLMModel and ChatGLMTokenizer to generate a response
                response = tuple(self.ChatGLM_onnx_model.generate_iterate(kwargs['query']))
                # Send the final piece of the output
                self.child.send(response[-1])
            except:
                from toolbox import trimmed_format_exc
                self.child.send('[Local Message] Call ChatGLM_onnx fail.' + '\n```\n' + trimmed_format_exc() + '\n```\n')
            # Request finished; loop back and wait for the next one.
            self.child.send('[Finish]')

    def stream_chat(self, **kwargs):
        # Runs in the main process.
        self.threadLock.acquire()
        self.parent.send(kwargs)
        while True:
            res = self.parent.recv()
            if res != '[Finish]':
                yield res
            else:
                break
        self.threadLock.release()
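

# ----------------------------------------------------------------------------
# Usage sketch (illustrative only; the helper below is hypothetical and not
# part of the original interface): GetGLMHandle loads the ONNX model inside a
# child process, and the main process reads replies from the stream_chat
# generator until the '[Finish]' sentinel arrives. This assumes that
# tokenizer_path / onnx_model_path above point to real files; the regular
# entry points remain predict and predict_no_ui_long_connection.
def _demo_stream_chat(query="你好"):
    handle = GetGLMHandle()
    if not handle.success:
        raise RuntimeError(handle.info)
    last = ""
    for partial in handle.stream_chat(query=query, history=[]):
        last = partial  # in this implementation only the final reply text is sent back
    return last
# ----------------------------------------------------------------------------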


#################################################################################
class ChatGLMModel():

    def __init__(self, onnx_model_path=onnx_model_path, tokenizer_path=tokenizer_path, profile=False) -> None:
        self.tokenizer = ChatGLMTokenizer(tokenizer_path)
        options = SessionOptions()
        options.enable_profiling = profile
        self.session = InferenceSession(onnx_model_path, options, providers=providers)
        self.eop_token_id = self.tokenizer["<eop>"]
        # input & output names
        self.past_names = [f"past_{name}_{i}" for i in range(28) for name in ["key", "value"]]
        self.present_names = [f"present_{name}_{i}" for i in range(28) for name in ["key", "value"]]
        self.output_names = ["logits"] + self.present_names

        # default kv_cache for first inference
        self.default_past_key_values = {
            k: np.zeros((1, 0, 32, 128), dtype=np.float32) for k in self.past_names
        }

    def prepare_input(self, prompt: str):
        input_ids, prefix_mask = self.tokenizer.encode(prompt)

        input_ids = np.array([input_ids], dtype=np.longlong)
        prefix_mask = np.array([prefix_mask], dtype=np.longlong)

        return input_ids, prefix_mask, self.default_past_key_values

    def sample_next_token(self, logits: np.ndarray, top_k=50, top_p=0.7, temperature=1):
        # softmax with temperature
        exp_logits = np.exp(logits / temperature)
        probs = exp_logits / np.sum(exp_logits)

        # top k
        top_k_idx = np.argsort(-probs)[:top_k]
        top_k_probs = probs[top_k_idx]

        # top p
        cumsum_probs = np.cumsum(top_k_probs)
        top_k_probs[(cumsum_probs - top_k_probs) > top_p] = 0.0
        top_k_probs = top_k_probs / np.sum(top_k_probs)

        # sample
        next_token = np.random.choice(top_k_idx, size=1, p=top_k_probs)
        return next_token[0].item()

    def generate_iterate(self, prompt: str, max_generated_tokens=100, top_k=50, top_p=0.7, temperature=1):
        input_ids, prefix_mask, past_key_values = self.prepare_input(prompt)
        output_tokens = []

        while True:
            inputs = {
                "input_ids": input_ids,
                "prefix_mask": prefix_mask,
                "use_past": np.array(len(output_tokens) > 0),
            }
            inputs.update(past_key_values)

            logits, *past_key_values = self.session.run(self.output_names, inputs)
            past_key_values = {k: v for k, v in zip(self.past_names, past_key_values)}

            next_token = self.sample_next_token(logits[0, -1], top_k=top_k, top_p=top_p, temperature=temperature)

            output_tokens += [next_token]

            if next_token == self.eop_token_id or len(output_tokens) > max_generated_tokens:
                break

            input_ids = np.array([[next_token]], dtype=np.longlong)
            prefix_mask = np.concatenate([prefix_mask, np.array([[0]], dtype=np.longlong)], axis=1)

            yield process_response(self.tokenizer.decode(output_tokens))

        return process_response(self.tokenizer.decode(output_tokens))
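

# ----------------------------------------------------------------------------
# Sampling sketch (illustrative only; the helper below is hypothetical and not
# part of the original interface): sample_next_token applies a temperature
# softmax, then top-k, then top-p filtering before drawing one token id.
# It never reads `self`, so a toy logits vector is enough to exercise it
# without building the ONNX session.
def _demo_sample_next_token():
    toy_logits = np.array([2.0, 1.0, 0.5, -1.0, -3.0])
    # `self` is unused inside sample_next_token, hence the None placeholder.
    return ChatGLMModel.sample_next_token(None, toy_logits, top_k=3, top_p=0.9, temperature=1.0)
# ----------------------------------------------------------------------------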


class ChatGLMTokenizer:
    def __init__(self, vocab_file):
        assert vocab_file is not None
        self.vocab_file = vocab_file
        self.special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "<unused_0>", "<sop>", "<eop>", "<ENC>", "<dBLOCK>"]
        self.text_tokenizer = SentencePieceProcessor(str(vocab_file))

    def __len__(self):
        return len(self.text_tokenizer)

    def __getitem__(self, key: str):
        return self.text_tokenizer[key]

    def preprocess(self, text: str, linebreak=True, whitespaces=True):
        if linebreak:
            text = text.replace("\n", "<n>")
        if whitespaces:
            text = text.replace("\t", "<|tab|>")
            text = re.sub(r" {2,80}", self.replace_spaces_with_blank, text)
        return text

    def encode(
        self, text: str, text_pair: str = None,
        linebreak=True, whitespaces=True,
        add_dummy_prefix=True, special_tokens=True,
    ) -> tuple[list[int], list[int]]:
        """
        text: Text to encode. Bidirectional part with a [gMASK] and an <sop> for causal LM.
        text_pair: causal LM part.
        linebreak: Whether to encode newline (\n) in text.
        whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding.
        special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text.
        add_dummy_prefix: Whether to add dummy blank space in the beginning.
        """
        text = self.preprocess(text, linebreak, whitespaces)
        if not add_dummy_prefix:
            text = "<n>" + text

        tokens = self.text_tokenizer.encode(text)
        prefix_mask = [1] * len(tokens)
        if special_tokens:
            tokens += [self.text_tokenizer["[gMASK]"], self.text_tokenizer["<sop>"]]
            prefix_mask += [1, 0]

        if text_pair is not None:
            text_pair = self.preprocess(text_pair, linebreak, whitespaces)
            pair_tokens = self.text_tokenizer.encode(text_pair)
            tokens += pair_tokens
            prefix_mask += [0] * len(pair_tokens)
            if special_tokens:
                tokens += [self.text_tokenizer["<eop>"]]
                prefix_mask += [0]

        return (tokens if add_dummy_prefix else tokens[2:]), prefix_mask

    def decode(self, text_ids: list[int]) -> str:
        text = self.text_tokenizer.decode(text_ids)
        text = text.replace("<n>", "\n")
        text = text.replace("<|tab|>", "\t")
        text = re.sub(r"<\|blank_(\d\d?)\|>", self.replace_blank_with_spaces, text)
        return text

    @staticmethod
    def replace_spaces_with_blank(match: re.Match[str]):
        return f"<|blank_{len(match.group())}|>"

    @staticmethod
    def replace_blank_with_spaces(match: re.Match[str]):
        return " " * int(match.group(1))


#################################################################################
def chat_template(history: list[tuple[str, str]], current: str):
    prompt = ""
    chat_round = 0
    for question, answer in history:
        prompt += f"[Round {chat_round}]\n问:{question}\n答:{answer}\n"
        chat_round += 1
    prompt += f"[Round {chat_round}]\n问:{current}\n答:"
    return prompt


def process_response(response: str):
    response = response.strip()
    response = response.replace("[[训练时间]]", "2023年")
    punkts = [
        [",", ","],
        ["!", "!"],
        [":", ":"],
        [";", ";"],
        ["\?", "?"],
    ]
    for item in punkts:
        response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
        response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
    return response
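

# ----------------------------------------------------------------------------
# Pipeline sketch (illustrative only; the helper below is hypothetical and not
# part of the original interface): a chat history is first flattened by
# chat_template into the "[Round n]\n问:...\n答:..." layout ChatGLM expects,
# then tokenized by ChatGLMTokenizer before being fed to ChatGLMModel; model
# output is cleaned up by process_response. This assumes tokenizer_path points
# to a real sentencepiece.model file.
def _demo_build_model_input():
    history = [("你好", "你好,我是ChatGLM_onnx。")]
    prompt = chat_template(history, "介绍一下你自己")
    tokenizer = ChatGLMTokenizer(tokenizer_path)
    input_ids, prefix_mask = tokenizer.encode(prompt)
    return prompt, input_ids, prefix_mask
# ----------------------------------------------------------------------------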


#################################################################################
def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=[], console_slience=False):
    """
    Multi-threaded method.
    See request_llm/bridge_all.py for the description of this function.
    """
    global glm_onnx_handle
    if glm_onnx_handle is None:
        glm_onnx_handle = GetGLMHandle()
        if len(observe_window) >= 1:
            observe_window[0] = load_message + "\n\n" + glm_onnx_handle.info
        if not glm_onnx_handle.success:
            error = glm_onnx_handle.info
            glm_onnx_handle = None
            raise RuntimeError(error)

    # ChatGLM_onnx doesn't have a sys_prompt interface, so add the prompt to history
    history_feedin = []
    history_feedin.append(["What can I do?", sys_prompt])
    for i in range(len(history) // 2):
        history_feedin.append([history[2 * i], history[2 * i + 1]])

    watch_dog_patience = 5  # Watchdog patience, set to 5 seconds
    response = ""
    for response in glm_onnx_handle.stream_chat(query=inputs, history=history_feedin):
        print(response)
        if len(observe_window) >= 1:
            observe_window[0] = response
        if len(observe_window) >= 2:
            if (time.time() - observe_window[1]) > watch_dog_patience:
                raise RuntimeError("程序终止。")
    return response


def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_prompt='', stream=True, additional_fn=None):
    """
    Single-threaded method.
    See request_llm/bridge_all.py for the description of this function.
    """
    chatbot.append((inputs, ""))

    global glm_onnx_handle
    if glm_onnx_handle is None:
        glm_onnx_handle = GetGLMHandle()
        chatbot[-1] = (inputs, load_message + "\n\n" + glm_onnx_handle.info)
        yield from update_ui(chatbot=chatbot, history=[])
        if not glm_onnx_handle.success:
            glm_onnx_handle = None
            return

    if additional_fn is not None:
        import core_functional
        importlib.reload(core_functional)  # Hot-reload prompt
        core_functional = core_functional.get_core_functions()
        if "PreProcess" in core_functional[additional_fn]:
            inputs = core_functional[additional_fn]["PreProcess"](inputs)
        inputs = core_functional[additional_fn]["Prefix"] + inputs + core_functional[additional_fn]["Suffix"]

    history_feedin = []
    history_feedin.append(["What can I do?", system_prompt])
    for i in range(len(history) // 2):
        history_feedin.append([history[2 * i], history[2 * i + 1]])

    response = "[Local Message]: 等待ChatGLM_onnx响应中 ..."
    for response in glm_onnx_handle.stream_chat(query=inputs, history=history_feedin):
        chatbot[-1] = (inputs, response)
        yield from update_ui(chatbot=chatbot, history=history)

    if response == "[Local Message]: 等待ChatGLM_onnx响应中 ...":
        response = "[Local Message]: ChatGLM_onnx响应异常 ..."
    history.extend([inputs, response])
    yield from update_ui(chatbot=chatbot, history=history)
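

# ----------------------------------------------------------------------------
# Caller-side sketch (illustrative only; the helper below is hypothetical and
# not part of this module): predict_no_ui_long_connection reports progress via
# observe_window[0] and, when observe_window[1] holds a timestamp, aborts if
# that timestamp is not refreshed within the 5-second watchdog window.
# Passing a single-element window therefore disables the watchdog for a simple
# blocking call.
def _demo_blocking_query(query="你好", sys_prompt=""):
    observe_window = [""]
    return predict_no_ui_long_connection(query, llm_kwargs={}, history=[],
                                         sys_prompt=sys_prompt, observe_window=observe_window)
# ----------------------------------------------------------------------------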