From 0aeb5b28cd48579db983db5460c624ce7ace0182 Mon Sep 17 00:00:00 2001
From: qingxu fu <505030475@qq.com>
Date: Wed, 5 Apr 2023 00:25:53 +0800
Subject: [PATCH] =?UTF-8?q?=E6=94=B9=E8=BF=9B=E6=95=88=E7=8E=87?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 crazy_functions/代码重写为全英文_多线程.py | 25 ++++++++++------------
 1 file changed, 11 insertions(+), 14 deletions(-)

diff --git a/crazy_functions/代码重写为全英文_多线程.py b/crazy_functions/代码重写为全英文_多线程.py
index bfcbec3..7f62088 100644
--- a/crazy_functions/代码重写为全英文_多线程.py
+++ b/crazy_functions/代码重写为全英文_多线程.py
@@ -10,16 +10,13 @@ def extract_code_block_carefully(txt):
     txt_out = '```'.join(splitted[1:-1])
     return txt_out
 
-def breakdown_txt_to_satisfy_token_limit(txt, limit, must_break_at_empty_line=True):
-    from transformers import GPT2TokenizerFast
-    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
-    get_token_cnt = lambda txt: len(tokenizer(txt)["input_ids"])
+def breakdown_txt_to_satisfy_token_limit(txt, get_token_fn, limit, must_break_at_empty_line=True):
     def cut(txt_tocut, must_break_at_empty_line): # 递归
-        if get_token_cnt(txt_tocut) <= limit:
+        if get_token_fn(txt_tocut) <= limit:
             return [txt_tocut]
         else:
             lines = txt_tocut.split('\n')
-            estimated_line_cut = limit / get_token_cnt(txt_tocut) * len(lines)
+            estimated_line_cut = limit / get_token_fn(txt_tocut) * len(lines)
             estimated_line_cut = int(estimated_line_cut)
             for cnt in reversed(range(estimated_line_cut)):
                 if must_break_at_empty_line:
@@ -27,7 +24,7 @@ def breakdown_txt_to_satisfy_token_limit(txt, limit, must_break_at_empty_line=Tr
                 print(cnt)
                 prev = "\n".join(lines[:cnt])
                 post = "\n".join(lines[cnt:])
-                if get_token_cnt(prev) < limit: break
+                if get_token_fn(prev) < limit: break
             if cnt == 0:
                 print('what the f?')
                 raise RuntimeError("存在一行极长的文本!")
@@ -86,12 +83,12 @@ def 全项目切换英文(txt, top_p, temperature, chatbot, history, sys_prompt,
 
 
     # 第5步:Token限制下的截断与处理
-    MAX_TOKEN = 2500
-    # from transformers import GPT2TokenizerFast
-    # print('加载tokenizer中')
-    # tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
-    # get_token_cnt = lambda txt: len(tokenizer(txt)["input_ids"])
-    # print('加载tokenizer结束')
+    MAX_TOKEN = 3000
+    from transformers import GPT2TokenizerFast
+    print('加载tokenizer中')
+    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
+    get_token_fn = lambda txt: len(tokenizer(txt)["input_ids"])
+    print('加载tokenizer结束')
 
 
     # 第6步:任务函数
@@ -107,7 +104,7 @@ def 全项目切换英文(txt, top_p, temperature, chatbot, history, sys_prompt,
         try:
             gpt_say = ""
             # 分解代码文件
-            file_content_breakdown = breakdown_txt_to_satisfy_token_limit(file_content, MAX_TOKEN)
+            file_content_breakdown = breakdown_txt_to_satisfy_token_limit(file_content, get_token_fn, MAX_TOKEN)
             for file_content_partial in file_content_breakdown:
                 i_say = i_say_template(fp, file_content_partial)
                 # # ** gpt request **