From 5813d65e523545209f89df170c9b316a0e66dc46 Mon Sep 17 00:00:00 2001
From: fenglui
Date: Sat, 22 Jul 2023 08:29:15 +0800
Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0chatGLM=20int4=E9=85=8D?=
 =?UTF-8?q?=E7=BD=AE=E6=94=AF=E6=8C=81=20=E5=B0=8F=E6=98=BE=E5=AD=98?=
 =?UTF-8?q?=E4=B9=9F=E5=8F=AF=E4=BB=A5=E9=80=89=E6=8B=A9chatGLM?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 config.py                     |  1 +
 request_llm/bridge_chatglm.py | 10 +++++++---
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/config.py b/config.py
index 553ae70..f39fef3 100644
--- a/config.py
+++ b/config.py
@@ -80,6 +80,7 @@ ChatGLM_PTUNING_CHECKPOINT = "" # 例如"/home/hmp/ChatGLM2-6B/ptuning/output/6b
 
 # 本地LLM模型如ChatGLM的执行方式 CPU/GPU
 LOCAL_MODEL_DEVICE = "cpu" # 可选 "cuda"
+LOCAL_MODEL_QUANT = "INT4" # 默认 "" "INT4" 启用量化INT4版本 "INT8" 启用量化INT8版本
 
 
 # 设置gradio的并行线程数(不需要修改)
diff --git a/request_llm/bridge_chatglm.py b/request_llm/bridge_chatglm.py
index deaacd2..c7ec42b 100644
--- a/request_llm/bridge_chatglm.py
+++ b/request_llm/bridge_chatglm.py
@@ -37,15 +37,19 @@ class GetGLMHandle(Process):
         # 子进程执行
         # 第一次运行,加载参数
         retry = 0
+        pretrained_model_name_or_path = "THUDM/chatglm2-6b"
+        LOCAL_MODEL_QUANT = get_conf('LOCAL_MODEL_QUANT')
+        if LOCAL_MODEL_QUANT and len(LOCAL_MODEL_QUANT) > 0 and LOCAL_MODEL_QUANT[0] == "INT4":
+            pretrained_model_name_or_path = "THUDM/chatglm2-6b-int4"
         while True:
             try:
                 if self.chatglm_model is None:
-                    self.chatglm_tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
+                    self.chatglm_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
                     device, = get_conf('LOCAL_MODEL_DEVICE')
                     if device=='cpu':
-                        self.chatglm_model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).float()
+                        self.chatglm_model = AutoModel.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True).float()
                     else:
-                        self.chatglm_model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).half().cuda()
+                        self.chatglm_model = AutoModel.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True).half().cuda()
                     self.chatglm_model = self.chatglm_model.eval()
                     break
                 else: