improve success rate

This commit is contained in:
qingxu fu 2023-06-03 19:39:19 +08:00
parent 68fea9e79b
commit 70ee810133

View File

@ -5,6 +5,10 @@ import re
pj = os.path.join pj = os.path.join
def 寻找Latex主文件(file_manifest, mode): def 寻找Latex主文件(file_manifest, mode):
"""
在多Tex文档中寻找主文件必须包含documentclass返回找到的第一个
P.S. 但愿没人把latex模板放在里面传进来
"""
for texf in file_manifest: for texf in file_manifest:
if os.path.basename(texf).startswith('merge'): if os.path.basename(texf).startswith('merge'):
continue continue
@ -17,6 +21,9 @@ def 寻找Latex主文件(file_manifest, mode):
raise RuntimeError('无法找到一个主Tex文件包含documentclass关键字') raise RuntimeError('无法找到一个主Tex文件包含documentclass关键字')
def merge_tex_files_(project_foler, main_file, mode): def merge_tex_files_(project_foler, main_file, mode):
"""
递归地把多Tex工程整合为一个Tex文档
"""
for s in reversed([q for q in re.finditer(r"\\input\{(.*?)\}", main_file, re.M)]): for s in reversed([q for q in re.finditer(r"\\input\{(.*?)\}", main_file, re.M)]):
f = s.group(1) f = s.group(1)
fp = os.path.join(project_foler, f) fp = os.path.join(project_foler, f)
@ -33,38 +40,56 @@ def merge_tex_files_(project_foler, main_file, mode):
return main_file return main_file
def merge_tex_files(project_foler, main_file, mode): def merge_tex_files(project_foler, main_file, mode):
"""
递归地把多Tex工程整合为一个Tex文档递归外层
P.S. 顺便把CTEX塞进去以支持中文
P.S. 顺便把Latex的注释去除
"""
main_file = merge_tex_files_(project_foler, main_file, mode) main_file = merge_tex_files_(project_foler, main_file, mode)
if mode == 'translate_zh': if mode == 'translate_zh':
pattern = re.compile(r'\\documentclass.*\n') pattern = re.compile(r'\\documentclass.*\n')
match = pattern.search(main_file) match = pattern.search(main_file)
position = match.end() position = match.end()
main_file = main_file[:position] + '\\usepackage{CTEX}\n\\usepackage{url}\n' + main_file[position:] main_file = main_file[:position] + '\\usepackage{CTEX}\n\\usepackage{url}\n' + main_file[position:]
new_file_remove_comment_lines = []
for l in main_file.splitlines():
# 删除整行的空注释
if l.startswith("%") or (l.startswith(" ") and l.lstrip().startswith("%")):
pass
else:
new_file_remove_comment_lines.append(l)
main_file = '\n'.join(new_file_remove_comment_lines)
main_file = re.sub(r'(?<!\\)%.*', '', main_file) # 使用正则表达式查找半行注释, 并替换为空字符串
return main_file return main_file
class LinkTable():
class LinkedListNode():
"""
链表单元
"""
def __init__(self, string, preserve=True) -> None: def __init__(self, string, preserve=True) -> None:
self.string = string self.string = string
self.preserve = preserve self.preserve = preserve
self.next = None self.next = None
def mod_inbraket(match): def mod_inbraket(match):
"""
为啥chatgpt会把cite里面的逗号换成中文逗号呀
"""
# get the matched string # get the matched string
cmd = match.group(1) cmd = match.group(1)
str_to_modify = match.group(2) str_to_modify = match.group(2)
# modify the matched string # modify the matched string
str_to_modify = str_to_modify.replace('', ':') # 前面是中文冒号,后面是英文冒号 str_to_modify = str_to_modify.replace('', ':') # 前面是中文冒号,后面是英文冒号
str_to_modify = str_to_modify.replace('', ',') # 前面是中文逗号,后面是英文逗号 str_to_modify = str_to_modify.replace('', ',') # 前面是中文逗号,后面是英文逗号
# str_to_modify = 'BOOM' # str_to_modify = 'BOOM'
# return the modified string as the replacement
return "\\" + cmd + "{" + str_to_modify + "}" return "\\" + cmd + "{" + str_to_modify + "}"
def fix_content(final_tex, node_string): def fix_content(final_tex, node_string):
""" """
fix common GPT errors to increase success rate Fix common GPT errors to increase success rate
""" """
final_tex = final_tex.replace('%', r'\%') final_tex = final_tex.replace('%', r'\%')
final_tex = final_tex.replace(r'\%', r'\\%') final_tex = final_tex.replace(r'\%', r'\\%')
@ -74,10 +99,19 @@ def fix_content(final_tex, node_string):
return final_tex return final_tex
class LatexPaperSplit(): class LatexPaperSplit():
"""
将Latex文档分解到一个链表中每个链表节点用preserve的标志位提示它是否应当被GPT处理
"""
def __init__(self) -> None: def __init__(self) -> None:
"""
root是链表的根节点
"""
self.root = None self.root = None
def merge_result(self, arr, mode, msg): def merge_result(self, arr, mode, msg):
"""
将GPT处理后的结果融合
"""
result_string = "" result_string = ""
node = self.root node = self.root
p = 0 p = 0
@ -105,8 +139,10 @@ class LatexPaperSplit():
return result_string return result_string
def split(self, txt): def split(self, txt):
# def replace_with_hash() """
root = LinkTable(txt, False) 将Latex文档分解到一个链表中每个链表节点用preserve的标志位提示它是否应当被GPT处理
"""
root = LinkedListNode(txt, False)
def split_worker(root, pattern, flags=0): def split_worker(root, pattern, flags=0):
lt = root lt = root
cnt = 0 cnt = 0
@ -131,10 +167,10 @@ class LatexPaperSplit():
lt.string = before lt.string = before
tmp = lt.next tmp = lt.next
# ====== # ======
mid = LinkTable(this, True) mid = LinkedListNode(this, True)
lt.next = mid lt.next = mid
# ====== # ======
aft = LinkTable(after, False) aft = LinkedListNode(after, False)
mid.next = aft mid.next = aft
aft.next = tmp aft.next = tmp
# ====== # ======
@ -152,6 +188,8 @@ class LatexPaperSplit():
split_worker(root, r"\\subsubsection\{(.*?)\}") split_worker(root, r"\\subsubsection\{(.*?)\}")
split_worker(root, r"\\bibliography\{(.*?)\}") split_worker(root, r"\\bibliography\{(.*?)\}")
split_worker(root, r"\\bibliographystyle\{(.*?)\}") split_worker(root, r"\\bibliographystyle\{(.*?)\}")
split_worker(root, r"\\begin\{lstlisting\}(.*?)\\end\{lstlisting\}", re.DOTALL)
split_worker(root, r"\\begin\{algorithm\}(.*?)\\end\{algorithm\}", re.DOTALL)
split_worker(root, r"\\begin\{wrapfigure\}(.*?)\\end\{wrapfigure\}", re.DOTALL) split_worker(root, r"\\begin\{wrapfigure\}(.*?)\\end\{wrapfigure\}", re.DOTALL)
split_worker(root, r"\\begin\{wrapfigure\*\}(.*?)\\end\{wrapfigure\*\}", re.DOTALL) split_worker(root, r"\\begin\{wrapfigure\*\}(.*?)\\end\{wrapfigure\*\}", re.DOTALL)
split_worker(root, r"\\begin\{figure\}(.*?)\\end\{figure\}", re.DOTALL) split_worker(root, r"\\begin\{figure\}(.*?)\\end\{figure\}", re.DOTALL)
@ -178,13 +216,17 @@ class LatexPaperSplit():
node = node.next node = node.next
if node is None: break if node is None: break
with open('debug_log', 'w', encoding='utf8') as f: # 将分解结果返回 res_to_t
with open('debug_log.html', 'w', encoding='utf8') as f:
res_to_t = [] res_to_t = []
node = root node = root
while True: while True:
show_html = node.string.replace('\n','<br/>')
if not node.preserve: if not node.preserve:
res_to_t.append(node.string) res_to_t.append(node.string)
f.write(node.string) f.write(f'<p style="color:black;">{show_html}</p>')
else:
f.write(f'<p style="color:red;">{show_html}</p>')
node = node.next node = node.next
if node is None: break if node is None: break
@ -260,7 +302,6 @@ def Latex精细分解与转化(file_manifest, project_folder, llm_kwargs, plugin
with open(maintex, 'r', encoding='utf-8', errors='replace') as f: with open(maintex, 'r', encoding='utf-8', errors='replace') as f:
content = f.read() content = f.read()
merged_content = merge_tex_files(project_folder, content, mode) merged_content = merge_tex_files(project_folder, content, mode)
merged_content = re.sub(r'(?<!\\)%.*', '', merged_content) # 使用正则表达式查找注释, 并替换为空字符串
with open(project_folder + '/merge.tex', 'w', encoding='utf-8', errors='replace') as f: with open(project_folder + '/merge.tex', 'w', encoding='utf-8', errors='replace') as f:
f.write(merged_content) f.write(merged_content)
@ -362,7 +403,7 @@ def 编译Latex差别(chatbot, history, main_file_original, main_file_modified,
import os, time import os, time
current_dir = os.getcwd() current_dir = os.getcwd()
n_fix = 0 n_fix = 0
chatbot.append([f"正在编译PDF文档", '编译已经开始。当前工作路径为{work_folder}如果程序停顿5分钟以上则大概率是卡死在Latex里面了。不幸卡死时请直接去该路径下取回翻译结果或者重启之后再度尝试 ...']); yield from update_ui(chatbot=chatbot, history=history) chatbot.append([f"正在编译PDF文档", f'编译已经开始。当前工作路径为{work_folder}如果程序停顿5分钟以上则大概率是卡死在Latex里面了。不幸卡死时请直接去该路径下取回翻译结果或者重启之后再度尝试 ...']); yield from update_ui(chatbot=chatbot, history=history)
chatbot.append([f"正在编译PDF文档", '...']); yield from update_ui(chatbot=chatbot, history=history); time.sleep(1); chatbot[-1] = list(chatbot[-1]) # 刷新界面 chatbot.append([f"正在编译PDF文档", '...']); yield from update_ui(chatbot=chatbot, history=history); time.sleep(1); chatbot[-1] = list(chatbot[-1]) # 刷新界面
yield from update_ui_lastest_msg('编译已经开始...', chatbot, history) # 刷新Gradio前端界面 yield from update_ui_lastest_msg('编译已经开始...', chatbot, history) # 刷新Gradio前端界面