Improve efficiency (改进效率)

qingxu fu 2023-04-05 00:25:53 +08:00
parent 1dd1720d38
commit 0aeb5b28cd


@@ -10,16 +10,13 @@ def extract_code_block_carefully(txt):
     txt_out = '```'.join(splitted[1:-1])
     return txt_out

-def breakdown_txt_to_satisfy_token_limit(txt, limit, must_break_at_empty_line=True):
-    from transformers import GPT2TokenizerFast
-    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
-    get_token_cnt = lambda txt: len(tokenizer(txt)["input_ids"])
+def breakdown_txt_to_satisfy_token_limit(txt, get_token_fn, limit, must_break_at_empty_line=True):
     def cut(txt_tocut, must_break_at_empty_line): # 递归
-        if get_token_cnt(txt_tocut) <= limit:
+        if get_token_fn(txt_tocut) <= limit:
             return [txt_tocut]
         else:
             lines = txt_tocut.split('\n')
-            estimated_line_cut = limit / get_token_cnt(txt_tocut) * len(lines)
+            estimated_line_cut = limit / get_token_fn(txt_tocut) * len(lines)
             estimated_line_cut = int(estimated_line_cut)
             for cnt in reversed(range(estimated_line_cut)):
                 if must_break_at_empty_line:
@@ -27,7 +24,7 @@ def breakdown_txt_to_satisfy_token_limit(txt, limit, must_break_at_empty_line=Tr
                 print(cnt)
                 prev = "\n".join(lines[:cnt])
                 post = "\n".join(lines[cnt:])
-                if get_token_cnt(prev) < limit: break
+                if get_token_fn(prev) < limit: break
             if cnt == 0:
                 print('what the f?')
                 raise RuntimeError("存在一行极长的文本!")
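
Taken together, these two hunks turn the token counter into an injected dependency: breakdown_txt_to_satisfy_token_limit no longer loads its own GPT-2 tokenizer on every call but receives get_token_fn from the caller. Below is a minimal runnable sketch of the splitter in its new form. The recursion tail after the for loop is not visible in the diff and is reconstructed here as an assumption, and demo_token_fn is a hypothetical stand-in counter, not the repository's tokenizer.

    def breakdown_txt_to_satisfy_token_limit(txt, get_token_fn, limit, must_break_at_empty_line=True):
        def cut(txt_tocut, must_break_at_empty_line):  # recursive (the diff's 递归)
            if get_token_fn(txt_tocut) <= limit:
                return [txt_tocut]                     # already fits the budget
            lines = txt_tocut.split('\n')
            # Guess a cut point proportional to how far over the limit the text is.
            estimated_line_cut = int(limit / get_token_fn(txt_tocut) * len(lines))
            cnt = 0
            for cnt in reversed(range(estimated_line_cut)):
                if must_break_at_empty_line and lines[cnt] != "":
                    continue                           # strict mode: cut only at blank lines
                prev = "\n".join(lines[:cnt])
                post = "\n".join(lines[cnt:])
                if get_token_fn(prev) < limit:
                    break                              # head chunk now fits
            if cnt == 0:
                raise RuntimeError("存在一行极长的文本!")  # a single line exceeds the limit
            # Assumed tail (not shown in the diff): keep the head, recurse on the rest.
            return [prev] + cut(post, must_break_at_empty_line)
        return cut(txt, must_break_at_empty_line)

    demo_token_fn = lambda s: len(s.split())           # hypothetical: one token per word
    chunks = breakdown_txt_to_satisfy_token_limit("a b c\n\nd e f\n\ng h", demo_token_fn, limit=4)
    # chunks == ['a b c', '\nd e f', '\ng h']; each piece stays under the 4-token budget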
@@ -86,12 +83,12 @@ def 全项目切换英文(txt, top_p, temperature, chatbot, history, sys_prompt,
     # 第5步:Token限制下的截断与处理
-    MAX_TOKEN = 2500
-    # from transformers import GPT2TokenizerFast
-    # print('加载tokenizer中')
-    # tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
-    # get_token_cnt = lambda txt: len(tokenizer(txt)["input_ids"])
-    # print('加载tokenizer结束')
+    MAX_TOKEN = 3000
+    from transformers import GPT2TokenizerFast
+    print('加载tokenizer中')
+    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
+    get_token_fn = lambda txt: len(tokenizer(txt)["input_ids"])
+    print('加载tokenizer结束')

     # 第6步:任务函数
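
This hunk is the efficiency half of the commit: the tokenizer setup moves to the call site and runs once per request instead of once per split call, and MAX_TOKEN rises from 2500 to 3000. The shape of the change, restated as live code with the same names the diff uses:

    from transformers import GPT2TokenizerFast

    # GPT2TokenizerFast.from_pretrained deserializes (and may first download)
    # vocabulary files, so it is the expensive step worth paying exactly once.
    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    get_token_fn = lambda txt: len(tokenizer(txt)["input_ids"])  # cheap closure, passed around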
@@ -107,7 +104,7 @@ def 全项目切换英文(txt, top_p, temperature, chatbot, history, sys_prompt,
         try:
             gpt_say = ""
             # 分解代码文件
-            file_content_breakdown = breakdown_txt_to_satisfy_token_limit(file_content, MAX_TOKEN)
+            file_content_breakdown = breakdown_txt_to_satisfy_token_limit(file_content, get_token_fn, MAX_TOKEN)
             for file_content_partial in file_content_breakdown:
                 i_say = i_say_template(fp, file_content_partial)
                 # # ** gpt request **
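
For context, a hedged end-to-end sketch of the new call pattern across both changed functions; split_project, project_files, and read_file are illustrative stand-ins, not names from the repository:

    from transformers import GPT2TokenizerFast

    def split_project(project_files, read_file, max_token=3000):
        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")        # loaded once per run
        get_token_fn = lambda txt: len(tokenizer(txt)["input_ids"])
        for fp in project_files:
            file_content = read_file(fp)
            # Oversized files are broken into chunks that fit the token budget.
            for chunk in breakdown_txt_to_satisfy_token_limit(file_content, get_token_fn, max_token):
                yield fp, chunk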