From c010d50716683d6a61d884ca3d538e9efb8370ff Mon Sep 17 00:00:00 2001
From: binary-husky
Date: Mon, 10 Jul 2023 03:17:09 +0800
Subject: [PATCH] Allow adding fine-tuned ChatGLM models
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 config.py                          |   4 +
 crazy_functions/chatglm微调工具.py |  54 ++++++++
 request_llm/bridge_all.py          |  20 ++-
 request_llm/bridge_chatglmft.py    | 210 +++++++++++++++++++++++++++++
 4 files changed, 287 insertions(+), 1 deletion(-)
 create mode 100644 request_llm/bridge_chatglmft.py

diff --git a/config.py b/config.py
index 621e575..8e624e6 100644
--- a/config.py
+++ b/config.py
@@ -74,6 +74,10 @@ AVAIL_LLM_MODELS = ["gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5", "api2
 # P.S. 其他可用的模型还包括 ["gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "newbing-free", "jittorllms_rwkv", "jittorllms_pangualpha", "jittorllms_llama"]
 
 
+# ChatGLM(2) Finetune Model Path (to use a fine-tuned ChatGLM2 model, add "chatglmft" to AVAIL_LLM_MODELS)
+ChatGLM_PTUNING_CHECKPOINT = "" # e.g. "/home/hmp/ChatGLM2-6B/ptuning/output/6b-pt-128-1e-2/checkpoint-100"
+
+
 # 本地LLM模型如ChatGLM的执行方式 CPU/GPU
 LOCAL_MODEL_DEVICE = "cpu" # 可选 "cuda"
 
diff --git a/crazy_functions/chatglm微调工具.py b/crazy_functions/chatglm微调工具.py
index 58a9208..0c8f5d2 100644
--- a/crazy_functions/chatglm微调工具.py
+++ b/crazy_functions/chatglm微调工具.py
@@ -69,3 +69,57 @@ def 微调数据集生成(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
     promote_file_to_downloadzone(txt+'.generated.json', rename_file='generated.json', chatbot=chatbot)
     return
 
+
+
+
+def 启动微调(arguments):
+    """
+    Launch ChatGLM2 p-tuning fine-tuning in a subprocess via torchrun.
+
+    arguments      option string for this plugin; currently unused, because the
+                   training settings below are still hard-coded:
+    PRE_SEQ_LEN    P-Tuning v2 prefix length
+    LR / NUM_GPUS  learning rate and number of GPUs passed to torchrun
+    JSON_FILE      training/validation data file under AdvertiseGen/
+    """
+    history = []    # clear the history to avoid input overflow
+    import subprocess
+    PRE_SEQ_LEN = 128
+    LR = 2e-2
+    NUM_GPUS = 1
+    JSON_FILE = 't_code.json'
+    tune_work_path = '/home/hmp/ChatGLM2-6B/ptuning'
+
+    command = f"torchrun --standalone --nnodes=1 --nproc-per-node={NUM_GPUS} main.py \
+        --do_train \
+        --train_file AdvertiseGen/{JSON_FILE} \
+        --validation_file AdvertiseGen/{JSON_FILE} \
+        --preprocessing_num_workers 20 \
+        --prompt_column content \
+        --response_column summary \
+        --overwrite_cache \
+        --model_name_or_path THUDM/chatglm2-6b \
+        --output_dir output/clothgen-chatglm2-6b-pt-{PRE_SEQ_LEN}-{LR} \
+        --overwrite_output_dir \
+        --max_source_length 256 \
+        --max_target_length 256 \
+        --per_device_train_batch_size 1 \
+        --per_device_eval_batch_size 1 \
+        --gradient_accumulation_steps 16 \
+        --predict_with_generate \
+        --max_steps 100 \
+        --logging_steps 10 \
+        --save_steps 20 \
+        --learning_rate {LR} \
+        --pre_seq_len {PRE_SEQ_LEN} \
+        --quantization_bit 4"
+
+    process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=tune_work_path)
+    try:
+        stdout, stderr = process.communicate(timeout=3600*5)
+    except subprocess.TimeoutExpired:
+        process.kill()
+        stdout, stderr = process.communicate()
+        print("Process timed out!")
+        return False
+    return True
diff --git a/request_llm/bridge_all.py b/request_llm/bridge_all.py
index 13f49bd..ed9ceb0 100644
--- a/request_llm/bridge_all.py
+++ b/request_llm/bridge_all.py
@@ -269,6 +269,24 @@ if "newbing" in AVAIL_LLM_MODELS: # same with newbing-free
         })
     except:
         print(trimmed_format_exc())
+if "chatglmft" in AVAIL_LLM_MODELS: # ChatGLM2 p-tuning fine-tuned model
+    try:
+        from .bridge_chatglmft import predict_no_ui_long_connection as chatglmft_noui
+        from .bridge_chatglmft import predict as chatglmft_ui
+        # register the fine-tuned ChatGLM model (token counting falls back to the gpt-3.5 tokenizer)
+        model_info.update({
+            "chatglmft": {
+                "fn_with_ui": chatglmft_ui,
+                "fn_without_ui": chatglmft_noui,
+                "endpoint": None,
+                "max_token": 4096,
+                "tokenizer": tokenizer_gpt35,
+                "token_cnt": get_token_num_gpt35,
+            }
+        })
+    except:
+        print(trimmed_format_exc())
+
 
 def LLM_CATCH_EXCEPTION(f):
     """
@@ -372,6 +390,6 @@ def predict(inputs, llm_kwargs, *args, **kwargs):
 
     additional_fn代表点击的哪个按钮,按钮见functional.py
     """
-    method = model_info[llm_kwargs['llm_model']]["fn_with_ui"]
+    method = model_info[llm_kwargs['llm_model']]["fn_with_ui"]  # if this line raises an error, check the AVAIL_LLM_MODELS option in config
     yield from method(inputs, llm_kwargs, *args, **kwargs)
 
diff --git a/request_llm/bridge_chatglmft.py b/request_llm/bridge_chatglmft.py
new file mode 100644
index 0000000..cd0ccb9
--- /dev/null
+++ b/request_llm/bridge_chatglmft.py
@@ -0,0 +1,210 @@
+
+from transformers import AutoModel, AutoTokenizer
+import time
+import os
+import json
+import threading
+import importlib
+from toolbox import update_ui, get_conf
+from multiprocessing import Process, Pipe
+
+load_message = "ChatGLMFT尚未加载,加载需要一段时间。注意,取决于`config.py`的配置,ChatGLMFT消耗大量的内存(CPU)或显存(GPU),也许会导致低配计算机卡死 ……"
+
+def string_to_options(arguments):
+    import argparse
+    import shlex
+    # Create an argparse.ArgumentParser instance
+    parser = argparse.ArgumentParser()
+    # Add command-line arguments
+    parser.add_argument("--llm_to_learn", type=str, help="LLM model to learn", default="gpt-3.5-turbo")
+    parser.add_argument("--prompt_prefix", type=str, help="Prompt prefix", default='')
+    parser.add_argument("--system_prompt", type=str, help="System prompt", default='')
+    parser.add_argument("--batch", type=int, help="Batch size", default=50)
+    # Parse the arguments
+    args = parser.parse_args(shlex.split(arguments))
+    return args
+
+
+#################################################################################
+class GetGLMFTHandle(Process):
+    def __init__(self):
+        super().__init__(daemon=True)
+        self.parent, self.child = Pipe()
+        self.chatglmft_model = None
+        self.chatglmft_tokenizer = None
+        self.info = ""
+        self.success = True
+        self.check_dependency()
+        self.start()
+        self.threadLock = threading.Lock()
+
+    def check_dependency(self):
+        try:
+            import sentencepiece
+            self.info = "依赖检测通过"
+            self.success = True
+        except:
+            self.info = "缺少ChatGLMFT的依赖,如果要使用ChatGLMFT,除了基础的pip依赖以外,您还需要运行`pip install -r request_llm/requirements_chatglm.txt`安装ChatGLM的依赖。"
+            self.success = False
+
+    def ready(self):
+        return self.chatglmft_model is not None
+
+    def run(self):
+        # runs in the child process
+        # on the first run, load the fine-tuned model parameters
+        retry = 0
+        while True:
+            try:
+                if self.chatglmft_model is None:
+                    from transformers import AutoConfig
+                    import torch
+                    # conf = 'request_llm/current_ptune_model.json'
+                    # if not os.path.exists(conf): raise RuntimeError('找不到微调模型信息')
+                    # with open(conf, 'r', encoding='utf8') as f:
+                    #     model_args = json.loads(f.read())
+                    ChatGLM_PTUNING_CHECKPOINT, = get_conf('ChatGLM_PTUNING_CHECKPOINT')
+                    assert os.path.exists(ChatGLM_PTUNING_CHECKPOINT), "找不到微调模型检查点"
+                    conf = os.path.join(ChatGLM_PTUNING_CHECKPOINT, "config.json")
+                    with open(conf, 'r', encoding='utf8') as f:
+                        model_args = json.loads(f.read())
+                    if 'model_name_or_path' not in model_args:
+                        model_args['model_name_or_path'] = model_args['_name_or_path']
+                    self.chatglmft_tokenizer = AutoTokenizer.from_pretrained(
+                        model_args['model_name_or_path'], trust_remote_code=True)
+                    config = AutoConfig.from_pretrained(
+                        model_args['model_name_or_path'], trust_remote_code=True)
+
+                    config.pre_seq_len = model_args['pre_seq_len']
+                    config.prefix_projection = model_args['prefix_projection']
+
+                    print(f"Loading prefix_encoder weight from {ChatGLM_PTUNING_CHECKPOINT}")
+                    model = AutoModel.from_pretrained(model_args['model_name_or_path'], config=config, trust_remote_code=True)
+                    prefix_state_dict = torch.load(os.path.join(ChatGLM_PTUNING_CHECKPOINT, "pytorch_model.bin"))
+                    new_prefix_state_dict = {}
+                    for k, v in prefix_state_dict.items():
+                        if k.startswith("transformer.prefix_encoder."):
+                            new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
+                    model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
+
+                    if model_args['quantization_bit'] is not None:
+                        print(f"Quantized to {model_args['quantization_bit']} bit")
+                        model = model.quantize(model_args['quantization_bit'])
+                    model = model.cuda()
+                    if model_args['pre_seq_len'] is not None:
+                        # P-tuning v2
+                        model.transformer.prefix_encoder.float()
+                    self.chatglmft_model = model.eval()
+
+                    break
+                else:
+                    break
+            except Exception as e:
+                retry += 1
+                if retry > 3:
+                    self.child.send('[Local Message] Call ChatGLMFT fail 不能正常加载ChatGLMFT的参数。')
+                    raise RuntimeError("不能正常加载ChatGLMFT的参数!")
+
+        while True:
+            # wait for the next task
+            kwargs = self.child.recv()
+            # request received, start generating
+            try:
+                for response, history in self.chatglmft_model.stream_chat(self.chatglmft_tokenizer, **kwargs):
+                    self.child.send(response)
+                    # # receive a possible termination command midway (if any)
+                    # if self.child.poll():
+                    #     command = self.child.recv()
+                    #     if command == '[Terminate]': break
+            except:
+                from toolbox import trimmed_format_exc
+                self.child.send('[Local Message] Call ChatGLMFT fail.' + '\n```\n' + trimmed_format_exc() + '\n```\n')
+            # request finished, start the next loop
+            self.child.send('[Finish]')
+
+    def stream_chat(self, **kwargs):
+        # runs in the main process
+        self.threadLock.acquire()
+        self.parent.send(kwargs)
+        while True:
+            res = self.parent.recv()
+            if res != '[Finish]':
+                yield res
+            else:
+                break
+        self.threadLock.release()
+
+global glmft_handle
+glmft_handle = None
+#################################################################################
+def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=[], console_slience=False):
+    """
+    Multi-threaded entry point.
+    See request_llm/bridge_all.py for a description of the interface.
+    """
+    global glmft_handle
+    if glmft_handle is None:
+        glmft_handle = GetGLMFTHandle()
+        if len(observe_window) >= 1: observe_window[0] = load_message + "\n\n" + glmft_handle.info
+        if not glmft_handle.success:
+            error = glmft_handle.info
+            glmft_handle = None
+            raise RuntimeError(error)
+
+    # chatglmft has no sys_prompt interface, so the system prompt is folded into the history
+    history_feedin = []
+    history_feedin.append(["What can I do?", sys_prompt])
+    for i in range(len(history)//2):
+        history_feedin.append([history[2*i], history[2*i+1]] )
+
+    watch_dog_patience = 5 # watchdog patience, 5 seconds is enough
+    response = ""
+    for response in glmft_handle.stream_chat(query=inputs, history=history_feedin, max_length=llm_kwargs['max_length'], top_p=llm_kwargs['top_p'], temperature=llm_kwargs['temperature']):
+        if len(observe_window) >= 1: observe_window[0] = response
+        if len(observe_window) >= 2:
+            if (time.time()-observe_window[1]) > watch_dog_patience:
+                raise RuntimeError("程序终止。")
+    return response
+
+
+
+def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_prompt='', stream = True, additional_fn=None):
+    """
+    Single-threaded (UI) entry point.
+    See request_llm/bridge_all.py for a description of the interface.
+    """
+    chatbot.append((inputs, ""))
+
+    global glmft_handle
+    if glmft_handle is None:
+        glmft_handle = GetGLMFTHandle()
+        chatbot[-1] = (inputs, load_message + "\n\n" + glmft_handle.info)
+        yield from update_ui(chatbot=chatbot, history=[])
+        if not glmft_handle.success:
+            glmft_handle = None
+            return
+
+    if additional_fn is not None:
+        import core_functional
+        importlib.reload(core_functional)    # hot-reload the prompt definitions
+        core_functional = core_functional.get_core_functions()
+        if "PreProcess" in core_functional[additional_fn]: inputs = core_functional[additional_fn]["PreProcess"](inputs)  # apply the pre-processing function (if any)
+        inputs = core_functional[additional_fn]["Prefix"] + inputs + core_functional[additional_fn]["Suffix"]
+
+    # assemble the chat history
+    history_feedin = []
+    history_feedin.append(["What can I do?", system_prompt] )
+    for i in range(len(history)//2):
+        history_feedin.append([history[2*i], history[2*i+1]] )
+
+    # start receiving the streamed reply from chatglmft
+    response = "[Local Message]: 等待ChatGLMFT响应中 ..."
+    for response in glmft_handle.stream_chat(query=inputs, history=history_feedin, max_length=llm_kwargs['max_length'], top_p=llm_kwargs['top_p'], temperature=llm_kwargs['temperature']):
+        chatbot[-1] = (inputs, response)
+        yield from update_ui(chatbot=chatbot, history=history)
+
+    # summarize the output
+    if response == "[Local Message]: 等待ChatGLMFT响应中 ...":
+        response = "[Local Message]: ChatGLMFT响应异常 ..."
+    history.extend([inputs, response])
+    yield from update_ui(chatbot=chatbot, history=history)
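
Usage note: a minimal config.py sketch for enabling the model registered by this patch. The checkpoint path reuses the example from the config comment and must be replaced with your own p-tuning output directory (the folder that contains config.json and pytorch_model.bin).

    # config.py -- settings needed before "chatglmft" can be selected
    AVAIL_LLM_MODELS = ["gpt-3.5-turbo", "chatglmft"]   # "chatglmft" must be listed here
    ChatGLM_PTUNING_CHECKPOINT = "/home/hmp/ChatGLM2-6B/ptuning/output/6b-pt-128-1e-2/checkpoint-100"

Note that bridge_chatglmft.py loads the checkpoint with model.cuda(), so a CUDA-capable GPU is required regardless of LOCAL_MODEL_DEVICE.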