diff --git a/crazy_functions/chatglm微调工具.py b/crazy_functions/chatglm微调工具.py
index 0c8f5d2..336d7cf 100644
--- a/crazy_functions/chatglm微调工具.py
+++ b/crazy_functions/chatglm微调工具.py
@@ -18,6 +18,13 @@ def string_to_options(arguments):
     parser.add_argument("--prompt_prefix", type=str, help="Prompt prefix", default='')
     parser.add_argument("--system_prompt", type=str, help="System prompt", default='')
     parser.add_argument("--batch", type=int, help="System prompt", default=50)
+    parser.add_argument("--pre_seq_len", type=int, help="pre_seq_len", default=50)
+    parser.add_argument("--learning_rate", type=float, help="learning_rate", default=2e-2)
+    parser.add_argument("--num_gpus", type=int, help="num_gpus", default=1)
+    parser.add_argument("--json_dataset", type=str, help="json_dataset", default="")
+    parser.add_argument("--ptuning_directory", type=str, help="ptuning_directory", default="")
+
+
     # Parse the arguments
     args = parser.parse_args(shlex.split(arguments))

@@ -72,7 +79,8 @@ def 微调数据集生成(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst



-def 启动微调(arguments):
+@CatchException
+def 启动微调(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
     """
     txt             输入栏用户输入的文本,例如需要翻译的一段话,再例如一个包含了待处理文件的路径
     llm_kwargs      gpt模型参数,如温度和top_p等,一般原样传递下去就行
@@ -82,24 +90,35 @@ def 启动微调(arguments):
     system_prompt   给gpt的静默提醒
     web_port        当前软件运行的端口号
     """
-    history = []    # 清空历史,以免输入溢出
     import subprocess
-    PRE_SEQ_LEN = 128
-    LR = 2e-2
-    NUM_GPUS = 1
-    JSON_FILE = 't_code.json'
-    tune_work_path = '/home/hmp/ChatGLM2-6B/ptuning'
+    history = []    # 清空历史,以免输入溢出
+    chatbot.append(("这是什么功能?", "[Local Message] 启动微调"))
+    if ("advanced_arg" in plugin_kwargs) and (plugin_kwargs["advanced_arg"] == ""): plugin_kwargs.pop("advanced_arg")
+    args = plugin_kwargs.get("advanced_arg", None)
+    if args is None:
+        chatbot.append(("没给定指令", "退出"))
+        yield from update_ui(chatbot=chatbot, history=history); return
+    else:
+        arguments = string_to_options(arguments=args)
+

-    command = f"torchrun --standalone --nnodes=1 --nproc-per-node={NUM_GPUS} main.py \
+
+    pre_seq_len = arguments.pre_seq_len                 # 128
+    learning_rate = arguments.learning_rate             # 2e-2
+    num_gpus = arguments.num_gpus                       # 1
+    json_dataset = arguments.json_dataset               # 't_code.json'
+    ptuning_directory = arguments.ptuning_directory     # '/home/hmp/ChatGLM2-6B/ptuning'
+
+    command = f"torchrun --standalone --nnodes=1 --nproc-per-node={num_gpus} main.py \
     --do_train \
-    --train_file AdvertiseGen/{JSON_FILE} \
-    --validation_file AdvertiseGen/{JSON_FILE} \
+    --train_file AdvertiseGen/{json_dataset} \
+    --validation_file AdvertiseGen/{json_dataset} \
     --preprocessing_num_workers 20 \
     --prompt_column content \
     --response_column summary \
     --overwrite_cache \
     --model_name_or_path THUDM/chatglm2-6b \
-    --output_dir output/clothgen-chatglm2-6b-pt-{PRE_SEQ_LEN}-{LR} \
+    --output_dir output/clothgen-chatglm2-6b-pt-{pre_seq_len}-{learning_rate} \
     --overwrite_output_dir \
     --max_source_length 256 \
     --max_target_length 256 \
@@ -110,16 +129,13 @@
     --max_steps 100 \
     --logging_steps 10 \
     --save_steps 20 \
-    --learning_rate {LR} \
-    --pre_seq_len {PRE_SEQ_LEN} \
+    --learning_rate {learning_rate} \
+    --pre_seq_len {pre_seq_len} \
     --quantization_bit 4"
-    process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=tune_work_path)
+    process = subprocess.Popen(command, shell=True, cwd=ptuning_directory)

     try:
-        stdout, stderr = process.communicate(timeout=3600*5)
+        process.communicate(timeout=3600*24)
     except subprocess.TimeoutExpired:
         process.kill()
-        stdout, stderr = process.communicate()
-        print("Process timed out!")
-        return False
     return
diff --git a/crazy_functions/crazy_functions_test.py b/crazy_functions/crazy_functions_test.py
index a614aac..8b6b540 100644
--- a/crazy_functions/crazy_functions_test.py
+++ b/crazy_functions/crazy_functions_test.py
@@ -212,11 +212,17 @@ def test_Latex():
 #         cli_printer.print(cb)
 #     print(cb)

 def test_chatglm_finetune():
-    from crazy_functions.chatglm微调工具 import 微调数据集生成
+    from crazy_functions.chatglm微调工具 import 微调数据集生成, 启动微调
     txt = 'build/dev.json'
     plugin_kwargs = {"advanced_arg":"--llm_to_learn=gpt-3.5-turbo --prompt_prefix='根据下面的服装类型提示,想象一个穿着者,对这个人外貌、身处的环境、内心世界、人设进行描写。要求:100字以内,用第二人称。' --system_prompt=''" }
-    for cookies, cb, hist, msg in (微调数据集生成)(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
+    # for cookies, cb, hist, msg in (微调数据集生成)(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
+    #     cli_printer.print(cb)
+
+    plugin_kwargs = {"advanced_arg":
+        " --pre_seq_len=128 --learning_rate=2e-2 --num_gpus=1 --json_dataset='t_code.json' --ptuning_directory='/home/hmp/ChatGLM2-6B/ptuning' " }
+
+    for cookies, cb, hist, msg in (启动微调)(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
         cli_printer.print(cb)
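
For reference, a minimal standalone sketch (not part of the diff) of the round trip this patch wires up: the plugin's "advanced_arg" string is split with shlex and parsed by argparse, as string_to_options does, and the values are substituted into the torchrun command that 启动微调 hands to subprocess. Only the five options added here are declared, the values are the illustrative ones from crazy_functions_test.py, and nothing is actually launched.

# Standalone sketch: parse an "advanced_arg" string the way string_to_options does
# (shlex.split + argparse), then show the torchrun command 启动微调 would build.
import shlex
import argparse

advanced_arg = (" --pre_seq_len=128 --learning_rate=2e-2 --num_gpus=1"
                " --json_dataset='t_code.json'"
                " --ptuning_directory='/home/hmp/ChatGLM2-6B/ptuning' ")

parser = argparse.ArgumentParser()
parser.add_argument("--pre_seq_len", type=int, default=50)
parser.add_argument("--learning_rate", type=float, default=2e-2)
parser.add_argument("--num_gpus", type=int, default=1)
parser.add_argument("--json_dataset", type=str, default="")
parser.add_argument("--ptuning_directory", type=str, default="")
args, _ = parser.parse_known_args(shlex.split(advanced_arg))  # shlex strips the quotes

# Abbreviated version of the command assembled in 启动微调 ("..." elides the fixed flags).
command = (f"torchrun --standalone --nnodes=1 --nproc-per-node={args.num_gpus} main.py "
           f"--train_file AdvertiseGen/{args.json_dataset} "
           f"--learning_rate {args.learning_rate} --pre_seq_len {args.pre_seq_len} ...")
print(command)  # 启动微调 runs this via subprocess.Popen(..., cwd=args.ptuning_directory)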