Merge branch 'master' of https://github.com/OverKit/gpt_academic into OverKit-master

This commit is contained in:
qingxu fu 2023-06-27 16:14:12 +08:00
commit e90048a671

View File

@ -8,24 +8,30 @@ pj = os.path.join
""" """
======================================================================== ========================================================================
Part One Part One
Latex segmentation to a linklist Latex segmentation with a binary mask (PRESERVE=0, TRANSFORM=1)
======================================================================== ========================================================================
""" """
PRESERVE = 0 PRESERVE = 0
TRANSFORM = 1 TRANSFORM = 1
def split_worker(text, mask, pattern, flags=0): def set_forbidden_text(text, mask, pattern, flags=0):
""" """
Add a preserve text area in this paper Add a preserve text area in this paper
e.g. with pattern = r"\\begin\{algorithm\}(.*?)\\end\{algorithm\}"
you can mask out (mask = PRESERVE, so that the text becomes untouchable for GPT)
everything from "\begin{algorithm}" through "\end{algorithm}"
""" """
pattern_compile = re.compile(pattern, flags) pattern_compile = re.compile(pattern, flags)
for res in pattern_compile.finditer(text): for res in pattern_compile.finditer(text):
mask[res.span()[0]:res.span()[1]] = PRESERVE mask[res.span()[0]:res.span()[1]] = PRESERVE
return text, mask return text, mask
def split_worker_careful_brace(text, mask, pattern, flags=0): def set_forbidden_text_careful_brace(text, mask, pattern, flags=0):
""" """
Move area into preserve area Add a preserve text area in this paper (text become untouchable for GPT).
Count the braces so as to capture the complete text area.
e.g.
\caption{blablablablabla\textbf{blablabla}blablabla.}
""" """
pattern_compile = re.compile(pattern, flags) pattern_compile = re.compile(pattern, flags)
for res in pattern_compile.finditer(text): for res in pattern_compile.finditer(text):
@ -40,9 +46,12 @@ def split_worker_careful_brace(text, mask, pattern, flags=0):
mask[begin:end] = PRESERVE mask[begin:end] = PRESERVE
return text, mask return text, mask
def split_worker_reverse_careful_brace(text, mask, pattern, flags=0): def reverse_forbidden_text_careful_brace(text, mask, pattern, flags=0):
""" """
Move area out of preserve area Move area out of preserve area (make text editable for GPT)
Count the braces so as to capture the complete text area.
e.g.
\caption{blablablablabla\textbf{blablabla}blablabla.}
""" """
pattern_compile = re.compile(pattern, flags) pattern_compile = re.compile(pattern, flags)
for res in pattern_compile.finditer(text): for res in pattern_compile.finditer(text):
@ -57,7 +66,7 @@ def split_worker_reverse_careful_brace(text, mask, pattern, flags=0):
mask[begin:end] = TRANSFORM mask[begin:end] = TRANSFORM
return text, mask return text, mask
def split_worker_begin_end(text, mask, pattern, flags=0, limit_n_lines=42): def set_forbidden_text_begin_end(text, mask, pattern, flags=0, limit_n_lines=42):
""" """
Find all \begin{} ... \end{} text block that with less than limit_n_lines lines. Find all \begin{} ... \end{} text block that with less than limit_n_lines lines.
Add it to preserve area Add it to preserve area
@ -283,46 +292,49 @@ def split_subprocess(txt, project_folder, return_dict, opts):
mask = np.zeros(len(txt), dtype=np.uint8) + TRANSFORM mask = np.zeros(len(txt), dtype=np.uint8) + TRANSFORM
# 吸收title与作者以上的部分 # 吸收title与作者以上的部分
text, mask = split_worker(text, mask, r"(.*?)\\maketitle", re.DOTALL) text, mask = set_forbidden_text(text, mask, r"(.*?)\\maketitle", re.DOTALL)
# 删除iffalse注释 # 删除iffalse注释
text, mask = split_worker(text, mask, r"\\iffalse(.*?)\\fi", re.DOTALL) text, mask = set_forbidden_text(text, mask, r"\\iffalse(.*?)\\fi", re.DOTALL)
# 吸收在25行以内的begin-end组合 # 吸收在25行以内的begin-end组合
text, mask = split_worker_begin_end(text, mask, r"\\begin\{([a-z\*]*)\}(.*?)\\end\{\1\}", re.DOTALL, limit_n_lines=25) text, mask = set_forbidden_text_begin_end(text, mask, r"\\begin\{([a-z\*]*)\}(.*?)\\end\{\1\}", re.DOTALL, limit_n_lines=42)
# 吸收匿名公式 # 吸收匿名公式
text, mask = split_worker(text, mask, r"\$\$(.*?)\$\$", re.DOTALL) text, mask = set_forbidden_text(text, mask, r"\$\$(.*?)\$\$", re.DOTALL)
text, mask = set_forbidden_text(text, mask, r"\\\[.*?\\\]", re.DOTALL)
# 吸收其他杂项 # 吸收其他杂项
text, mask = split_worker(text, mask, r"\\section\{(.*?)\}") text, mask = set_forbidden_text(text, mask, r"\\section\{(.*?)\}")
text, mask = split_worker(text, mask, r"\\section\*\{(.*?)\}") text, mask = set_forbidden_text(text, mask, r"\\section\*\{(.*?)\}")
text, mask = split_worker(text, mask, r"\\subsection\{(.*?)\}") text, mask = set_forbidden_text(text, mask, r"\\subsection\{(.*?)\}")
text, mask = split_worker(text, mask, r"\\subsubsection\{(.*?)\}") text, mask = set_forbidden_text(text, mask, r"\\subsubsection\{(.*?)\}")
text, mask = split_worker(text, mask, r"\\bibliography\{(.*?)\}") text, mask = set_forbidden_text(text, mask, r"\\bibliography\{(.*?)\}")
text, mask = split_worker(text, mask, r"\\bibliographystyle\{(.*?)\}") text, mask = set_forbidden_text(text, mask, r"\\bibliographystyle\{(.*?)\}")
text, mask = split_worker(text, mask, r"\\begin\{lstlisting\}(.*?)\\end\{lstlisting\}", re.DOTALL) text, mask = set_forbidden_text(text, mask, r"\\begin\{thebibliography\}.*?\\end\{thebibliography\}", re.DOTALL)
text, mask = split_worker(text, mask, r"\\begin\{wraptable\}(.*?)\\end\{wraptable\}", re.DOTALL) text, mask = set_forbidden_text(text, mask, r"\\begin\{lstlisting\}(.*?)\\end\{lstlisting\}", re.DOTALL)
text, mask = split_worker(text, mask, r"\\begin\{algorithm\}(.*?)\\end\{algorithm\}", re.DOTALL) text, mask = set_forbidden_text(text, mask, r"\\begin\{wraptable\}(.*?)\\end\{wraptable\}", re.DOTALL)
text, mask = split_worker(text, mask, r"\\begin\{wrapfigure\}(.*?)\\end\{wrapfigure\}", re.DOTALL) text, mask = set_forbidden_text(text, mask, r"\\begin\{algorithm\}(.*?)\\end\{algorithm\}", re.DOTALL)
text, mask = split_worker(text, mask, r"\\begin\{wrapfigure\*\}(.*?)\\end\{wrapfigure\*\}", re.DOTALL) text, mask = set_forbidden_text(text, mask, r"\\begin\{wrapfigure\}(.*?)\\end\{wrapfigure\}", re.DOTALL)
text, mask = split_worker(text, mask, r"\\begin\{figure\}(.*?)\\end\{figure\}", re.DOTALL) text, mask = set_forbidden_text(text, mask, r"\\begin\{wrapfigure\*\}(.*?)\\end\{wrapfigure\*\}", re.DOTALL)
text, mask = split_worker(text, mask, r"\\begin\{figure\*\}(.*?)\\end\{figure\*\}", re.DOTALL) text, mask = set_forbidden_text(text, mask, r"\\begin\{figure\}(.*?)\\end\{figure\}", re.DOTALL)
text, mask = split_worker(text, mask, r"\\begin\{multline\}(.*?)\\end\{multline\}", re.DOTALL) text, mask = set_forbidden_text(text, mask, r"\\begin\{figure\*\}(.*?)\\end\{figure\*\}", re.DOTALL)
text, mask = split_worker(text, mask, r"\\begin\{multline\*\}(.*?)\\end\{multline\*\}", re.DOTALL) text, mask = set_forbidden_text(text, mask, r"\\begin\{multline\}(.*?)\\end\{multline\}", re.DOTALL)
text, mask = split_worker(text, mask, r"\\begin\{table\}(.*?)\\end\{table\}", re.DOTALL) text, mask = set_forbidden_text(text, mask, r"\\begin\{multline\*\}(.*?)\\end\{multline\*\}", re.DOTALL)
text, mask = split_worker(text, mask, r"\\begin\{table\*\}(.*?)\\end\{table\*\}", re.DOTALL) text, mask = set_forbidden_text(text, mask, r"\\begin\{table\}(.*?)\\end\{table\}", re.DOTALL)
text, mask = split_worker(text, mask, r"\\begin\{minipage\}(.*?)\\end\{minipage\}", re.DOTALL) text, mask = set_forbidden_text(text, mask, r"\\begin\{table\*\}(.*?)\\end\{table\*\}", re.DOTALL)
text, mask = split_worker(text, mask, r"\\begin\{minipage\*\}(.*?)\\end\{minipage\*\}", re.DOTALL) text, mask = set_forbidden_text(text, mask, r"\\begin\{minipage\}(.*?)\\end\{minipage\}", re.DOTALL)
text, mask = split_worker(text, mask, r"\\begin\{align\*\}(.*?)\\end\{align\*\}", re.DOTALL) text, mask = set_forbidden_text(text, mask, r"\\begin\{minipage\*\}(.*?)\\end\{minipage\*\}", re.DOTALL)
text, mask = split_worker(text, mask, r"\\begin\{align\}(.*?)\\end\{align\}", re.DOTALL) text, mask = set_forbidden_text(text, mask, r"\\begin\{align\*\}(.*?)\\end\{align\*\}", re.DOTALL)
text, mask = split_worker(text, mask, r"\\begin\{equation\}(.*?)\\end\{equation\}", re.DOTALL) text, mask = set_forbidden_text(text, mask, r"\\begin\{align\}(.*?)\\end\{align\}", re.DOTALL)
text, mask = split_worker(text, mask, r"\\begin\{equation\*\}(.*?)\\end\{equation\*\}", re.DOTALL) text, mask = set_forbidden_text(text, mask, r"\\begin\{equation\}(.*?)\\end\{equation\}", re.DOTALL)
text, mask = split_worker(text, mask, r"\\includepdf\[(.*?)\]\{(.*?)\}") text, mask = set_forbidden_text(text, mask, r"\\begin\{equation\*\}(.*?)\\end\{equation\*\}", re.DOTALL)
text, mask = split_worker(text, mask, r"\\item ") text, mask = set_forbidden_text(text, mask, r"\\includepdf\[(.*?)\]\{(.*?)\}")
text, mask = split_worker(text, mask, r"\\label\{(.*?)\}") text, mask = set_forbidden_text(text, mask, r"\\item ")
text, mask = split_worker(text, mask, r"\\begin\{(.*?)\}") text, mask = set_forbidden_text(text, mask, r"\\label\{(.*?)\}")
text, mask = split_worker(text, mask, r"\\vspace\{(.*?)\}") text, mask = set_forbidden_text(text, mask, r"\\begin\{(.*?)\}")
text, mask = split_worker(text, mask, r"\\hspace\{(.*?)\}") text, mask = set_forbidden_text(text, mask, r"\\vspace\{(.*?)\}")
text, mask = split_worker(text, mask, r"\\end\{(.*?)\}") text, mask = set_forbidden_text(text, mask, r"\\hspace\{(.*?)\}")
text, mask = split_worker_careful_brace(text, mask, r"\\hl\{(.*?)\}", re.DOTALL) text, mask = set_forbidden_text(text, mask, r"\\end\{(.*?)\}")
text, mask = split_worker_reverse_careful_brace(text, mask, r"\\caption\{(.*?)\}", re.DOTALL) text, mask = set_forbidden_text_careful_brace(text, mask, r"\\hl\{(.*?)\}", re.DOTALL)
# reverse 操作必须放在最后
text, mask = reverse_forbidden_text_careful_brace(text, mask, r"\\caption\{(.*?)\}", re.DOTALL)
root = convert_to_linklist(text, mask) root = convert_to_linklist(text, mask)
# 修复括号 # 修复括号
@ -448,7 +460,9 @@ class LatexPaperSplit():
if mode == 'translate_zh': if mode == 'translate_zh':
pattern = re.compile(r'\\begin\{abstract\}.*\n') pattern = re.compile(r'\\begin\{abstract\}.*\n')
match = pattern.search(result_string) match = pattern.search(result_string)
assert match is not None, "Cannot find paper abstract section!" if not match:
pattern = re.compile(r'\\abstract\{')
match = pattern.search(result_string)
position = match.end() position = match.end()
result_string = result_string[:position] + self.msg + msg + self.msg_declare + result_string[position:] result_string = result_string[:position] + self.msg + msg + self.msg_declare + result_string[position:]
return result_string return result_string