改进效率
This commit is contained in:
parent
1dd1720d38
commit
0aeb5b28cd
@ -10,16 +10,13 @@ def extract_code_block_carefully(txt):
|
|||||||
txt_out = '```'.join(splitted[1:-1])
|
txt_out = '```'.join(splitted[1:-1])
|
||||||
return txt_out
|
return txt_out
|
||||||
|
|
||||||
def breakdown_txt_to_satisfy_token_limit(txt, limit, must_break_at_empty_line=True):
|
def breakdown_txt_to_satisfy_token_limit(txt, get_token_fn, limit, must_break_at_empty_line=True):
|
||||||
from transformers import GPT2TokenizerFast
|
|
||||||
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
|
|
||||||
get_token_cnt = lambda txt: len(tokenizer(txt)["input_ids"])
|
|
||||||
def cut(txt_tocut, must_break_at_empty_line): # 递归
|
def cut(txt_tocut, must_break_at_empty_line): # 递归
|
||||||
if get_token_cnt(txt_tocut) <= limit:
|
if get_token_fn(txt_tocut) <= limit:
|
||||||
return [txt_tocut]
|
return [txt_tocut]
|
||||||
else:
|
else:
|
||||||
lines = txt_tocut.split('\n')
|
lines = txt_tocut.split('\n')
|
||||||
estimated_line_cut = limit / get_token_cnt(txt_tocut) * len(lines)
|
estimated_line_cut = limit / get_token_fn(txt_tocut) * len(lines)
|
||||||
estimated_line_cut = int(estimated_line_cut)
|
estimated_line_cut = int(estimated_line_cut)
|
||||||
for cnt in reversed(range(estimated_line_cut)):
|
for cnt in reversed(range(estimated_line_cut)):
|
||||||
if must_break_at_empty_line:
|
if must_break_at_empty_line:
|
||||||
@ -27,7 +24,7 @@ def breakdown_txt_to_satisfy_token_limit(txt, limit, must_break_at_empty_line=Tr
|
|||||||
print(cnt)
|
print(cnt)
|
||||||
prev = "\n".join(lines[:cnt])
|
prev = "\n".join(lines[:cnt])
|
||||||
post = "\n".join(lines[cnt:])
|
post = "\n".join(lines[cnt:])
|
||||||
if get_token_cnt(prev) < limit: break
|
if get_token_fn(prev) < limit: break
|
||||||
if cnt == 0:
|
if cnt == 0:
|
||||||
print('what the f?')
|
print('what the f?')
|
||||||
raise RuntimeError("存在一行极长的文本!")
|
raise RuntimeError("存在一行极长的文本!")
|
||||||
@ -86,12 +83,12 @@ def 全项目切换英文(txt, top_p, temperature, chatbot, history, sys_prompt,
|
|||||||
|
|
||||||
|
|
||||||
# 第5步:Token限制下的截断与处理
|
# 第5步:Token限制下的截断与处理
|
||||||
MAX_TOKEN = 2500
|
MAX_TOKEN = 3000
|
||||||
# from transformers import GPT2TokenizerFast
|
from transformers import GPT2TokenizerFast
|
||||||
# print('加载tokenizer中')
|
print('加载tokenizer中')
|
||||||
# tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
|
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
|
||||||
# get_token_cnt = lambda txt: len(tokenizer(txt)["input_ids"])
|
get_token_fn = lambda txt: len(tokenizer(txt)["input_ids"])
|
||||||
# print('加载tokenizer结束')
|
print('加载tokenizer结束')
|
||||||
|
|
||||||
|
|
||||||
# 第6步:任务函数
|
# 第6步:任务函数
|
||||||
@ -107,7 +104,7 @@ def 全项目切换英文(txt, top_p, temperature, chatbot, history, sys_prompt,
|
|||||||
try:
|
try:
|
||||||
gpt_say = ""
|
gpt_say = ""
|
||||||
# 分解代码文件
|
# 分解代码文件
|
||||||
file_content_breakdown = breakdown_txt_to_satisfy_token_limit(file_content, MAX_TOKEN)
|
file_content_breakdown = breakdown_txt_to_satisfy_token_limit(file_content, get_token_fn, MAX_TOKEN)
|
||||||
for file_content_partial in file_content_breakdown:
|
for file_content_partial in file_content_breakdown:
|
||||||
i_say = i_say_template(fp, file_content_partial)
|
i_say = i_say_template(fp, file_content_partial)
|
||||||
# # ** gpt request **
|
# # ** gpt request **
|
||||||
|
Loading…
x
Reference in New Issue
Block a user