66 lines
2.6 KiB
Python
66 lines
2.6 KiB
Python
|
|
|
|
|
|
def breakdown_txt_to_satisfy_token_limit(txt, get_token_fn, limit):
|
|
def cut(txt_tocut, must_break_at_empty_line): # 递归
|
|
if get_token_fn(txt_tocut) <= limit:
|
|
return [txt_tocut]
|
|
else:
|
|
lines = txt_tocut.split('\n')
|
|
estimated_line_cut = limit / get_token_fn(txt_tocut) * len(lines)
|
|
estimated_line_cut = int(estimated_line_cut)
|
|
for cnt in reversed(range(estimated_line_cut)):
|
|
if must_break_at_empty_line:
|
|
if lines[cnt] != "": continue
|
|
print(cnt)
|
|
prev = "\n".join(lines[:cnt])
|
|
post = "\n".join(lines[cnt:])
|
|
if get_token_fn(prev) < limit: break
|
|
if cnt == 0:
|
|
print('what the fuck ?')
|
|
raise RuntimeError("存在一行极长的文本!")
|
|
# print(len(post))
|
|
# 列表递归接龙
|
|
result = [prev]
|
|
result.extend(cut(post, must_break_at_empty_line))
|
|
return result
|
|
try:
|
|
return cut(txt, must_break_at_empty_line=True)
|
|
except RuntimeError:
|
|
return cut(txt, must_break_at_empty_line=False)
|
|
|
|
def breakdown_txt_to_satisfy_token_limit_for_pdf(txt, get_token_fn, limit):
|
|
def cut(txt_tocut, must_break_at_empty_line): # 递归
|
|
if get_token_fn(txt_tocut) <= limit:
|
|
return [txt_tocut]
|
|
else:
|
|
lines = txt_tocut.split('\n')
|
|
estimated_line_cut = limit / get_token_fn(txt_tocut) * len(lines)
|
|
estimated_line_cut = int(estimated_line_cut)
|
|
cnt = 0
|
|
for cnt in reversed(range(estimated_line_cut)):
|
|
if must_break_at_empty_line:
|
|
if lines[cnt] != "": continue
|
|
print(cnt)
|
|
prev = "\n".join(lines[:cnt])
|
|
post = "\n".join(lines[cnt:])
|
|
if get_token_fn(prev) < limit: break
|
|
if cnt == 0:
|
|
# print('what the fuck ? 存在一行极长的文本!')
|
|
raise RuntimeError("存在一行极长的文本!")
|
|
# print(len(post))
|
|
# 列表递归接龙
|
|
result = [prev]
|
|
result.extend(cut(post, must_break_at_empty_line))
|
|
return result
|
|
try:
|
|
return cut(txt, must_break_at_empty_line=True)
|
|
except RuntimeError:
|
|
try:
|
|
return cut(txt, must_break_at_empty_line=False)
|
|
except RuntimeError:
|
|
# 这个中文的句号是故意的,作为一个标识而存在
|
|
res = cut(txt.replace('.', '。\n'), must_break_at_empty_line=False)
|
|
return [r.replace('。\n', '.') for r in res]
|
|
|