339 lines
12 KiB
Python
339 lines
12 KiB
Python
import markdown
|
||
import re
|
||
import os
|
||
import math
|
||
from textwrap import dedent
|
||
from functools import lru_cache
|
||
from pymdownx.superfences import fence_code_format
|
||
from latex2mathml.converter import convert as tex2mathml
|
||
from shared_utils.config_loader import get_conf as get_conf
|
||
from shared_utils.text_mask import apply_gpt_academic_string_mask
|
||
|
||
markdown_extension_configs = {
|
||
"mdx_math": {
|
||
"enable_dollar_delimiter": True,
|
||
"use_gitlab_delimiters": False,
|
||
},
|
||
}
|
||
|
||
code_highlight_configs = {
|
||
"pymdownx.superfences": {
|
||
"css_class": "codehilite",
|
||
"custom_fences": [
|
||
{"name": "mermaid", "class": "mermaid", "format": fence_code_format}
|
||
],
|
||
},
|
||
"pymdownx.highlight": {
|
||
"css_class": "codehilite",
|
||
"guess_lang": True,
|
||
# 'auto_title': True,
|
||
# 'linenums': True
|
||
},
|
||
}
|
||
|
||
|
||
def text_divide_paragraph(text):
|
||
"""
|
||
将文本按照段落分隔符分割开,生成带有段落标签的HTML代码。
|
||
"""
|
||
pre = '<div class="markdown-body">'
|
||
suf = "</div>"
|
||
if text.startswith(pre) and text.endswith(suf):
|
||
return text
|
||
|
||
if "```" in text:
|
||
# careful input
|
||
return text
|
||
elif "</div>" in text:
|
||
# careful input
|
||
return text
|
||
else:
|
||
# whatever input
|
||
lines = text.split("\n")
|
||
for i, line in enumerate(lines):
|
||
lines[i] = lines[i].replace(" ", " ")
|
||
text = "</br>".join(lines)
|
||
return pre + text + suf
|
||
|
||
|
||
def tex2mathml_catch_exception(content, *args, **kwargs):
|
||
try:
|
||
content = tex2mathml(content, *args, **kwargs)
|
||
except:
|
||
content = content
|
||
return content
|
||
|
||
|
||
def replace_math_no_render(match):
|
||
content = match.group(1)
|
||
if "mode=display" in match.group(0):
|
||
content = content.replace("\n", "</br>")
|
||
return f'<font color="#00FF00">$$</font><font color="#FF00FF">{content}</font><font color="#00FF00">$$</font>'
|
||
else:
|
||
return f'<font color="#00FF00">$</font><font color="#FF00FF">{content}</font><font color="#00FF00">$</font>'
|
||
|
||
|
||
def replace_math_render(match):
|
||
content = match.group(1)
|
||
if "mode=display" in match.group(0):
|
||
if "\\begin{aligned}" in content:
|
||
content = content.replace("\\begin{aligned}", "\\begin{array}")
|
||
content = content.replace("\\end{aligned}", "\\end{array}")
|
||
content = content.replace("&", " ")
|
||
content = tex2mathml_catch_exception(content, display="block")
|
||
return content
|
||
else:
|
||
return tex2mathml_catch_exception(content)
|
||
|
||
|
||
def markdown_bug_hunt(content):
|
||
"""
|
||
解决一个mdx_math的bug(单$包裹begin命令时多余<script>)
|
||
"""
|
||
content = content.replace(
|
||
'<script type="math/tex">\n<script type="math/tex; mode=display">',
|
||
'<script type="math/tex; mode=display">',
|
||
)
|
||
content = content.replace("</script>\n</script>", "</script>")
|
||
return content
|
||
|
||
|
||
def is_equation(txt):
|
||
"""
|
||
判定是否为公式 | 测试1 写出洛伦兹定律,使用tex格式公式 测试2 给出柯西不等式,使用latex格式 测试3 写出麦克斯韦方程组
|
||
"""
|
||
if "```" in txt and "```reference" not in txt:
|
||
return False
|
||
if "$" not in txt and "\\[" not in txt:
|
||
return False
|
||
mathpatterns = {
|
||
r"(?<!\\|\$)(\$)([^\$]+)(\$)": {"allow_multi_lines": False}, # $...$
|
||
r"(?<!\\)(\$\$)([^\$]+)(\$\$)": {"allow_multi_lines": True}, # $$...$$
|
||
r"(?<!\\)(\\\[)(.+?)(\\\])": {"allow_multi_lines": False}, # \[...\]
|
||
# r'(?<!\\)(\\\()(.+?)(\\\))': {'allow_multi_lines': False}, # \(...\)
|
||
# r'(?<!\\)(\\begin{([a-z]+?\*?)})(.+?)(\\end{\2})': {'allow_multi_lines': True}, # \begin...\end
|
||
# r'(?<!\\)(\$`)([^`]+)(`\$)': {'allow_multi_lines': False}, # $`...`$
|
||
}
|
||
matches = []
|
||
for pattern, property in mathpatterns.items():
|
||
flags = re.ASCII | re.DOTALL if property["allow_multi_lines"] else re.ASCII
|
||
matches.extend(re.findall(pattern, txt, flags))
|
||
if len(matches) == 0:
|
||
return False
|
||
contain_any_eq = False
|
||
illegal_pattern = re.compile(r"[^\x00-\x7F]|echo")
|
||
for match in matches:
|
||
if len(match) != 3:
|
||
return False
|
||
eq_canidate = match[1]
|
||
if illegal_pattern.search(eq_canidate):
|
||
return False
|
||
else:
|
||
contain_any_eq = True
|
||
return contain_any_eq
|
||
|
||
|
||
def fix_markdown_indent(txt):
|
||
# fix markdown indent
|
||
if (" - " not in txt) or (". " not in txt):
|
||
# do not need to fix, fast escape
|
||
return txt
|
||
# walk through the lines and fix non-standard indentation
|
||
lines = txt.split("\n")
|
||
pattern = re.compile(r"^\s+-")
|
||
activated = False
|
||
for i, line in enumerate(lines):
|
||
if line.startswith("- ") or line.startswith("1. "):
|
||
activated = True
|
||
if activated and pattern.match(line):
|
||
stripped_string = line.lstrip()
|
||
num_spaces = len(line) - len(stripped_string)
|
||
if (num_spaces % 4) == 3:
|
||
num_spaces_should_be = math.ceil(num_spaces / 4) * 4
|
||
lines[i] = " " * num_spaces_should_be + stripped_string
|
||
return "\n".join(lines)
|
||
|
||
|
||
FENCED_BLOCK_RE = re.compile(
|
||
dedent(
|
||
r"""
|
||
(?P<fence>^[ \t]*(?:~{3,}|`{3,}))[ ]* # opening fence
|
||
((\{(?P<attrs>[^\}\n]*)\})| # (optional {attrs} or
|
||
(\.?(?P<lang>[\w#.+-]*)[ ]*)? # optional (.)lang
|
||
(hl_lines=(?P<quot>"|')(?P<hl_lines>.*?)(?P=quot)[ ]*)?) # optional hl_lines)
|
||
\n # newline (end of opening fence)
|
||
(?P<code>.*?)(?<=\n) # the code block
|
||
(?P=fence)[ ]*$ # closing fence
|
||
"""
|
||
),
|
||
re.MULTILINE | re.DOTALL | re.VERBOSE,
|
||
)
|
||
|
||
|
||
def get_line_range(re_match_obj, txt):
|
||
start_pos, end_pos = re_match_obj.regs[0]
|
||
num_newlines_before = txt[: start_pos + 1].count("\n")
|
||
line_start = num_newlines_before
|
||
line_end = num_newlines_before + txt[start_pos:end_pos].count("\n") + 1
|
||
return line_start, line_end
|
||
|
||
|
||
def fix_code_segment_indent(txt):
|
||
lines = []
|
||
change_any = False
|
||
txt_tmp = txt
|
||
while True:
|
||
re_match_obj = FENCED_BLOCK_RE.search(txt_tmp)
|
||
if not re_match_obj:
|
||
break
|
||
if len(lines) == 0:
|
||
lines = txt.split("\n")
|
||
|
||
# 清空 txt_tmp 对应的位置方便下次搜索
|
||
start_pos, end_pos = re_match_obj.regs[0]
|
||
txt_tmp = txt_tmp[:start_pos] + " " * (end_pos - start_pos) + txt_tmp[end_pos:]
|
||
line_start, line_end = get_line_range(re_match_obj, txt)
|
||
|
||
# 获取公共缩进
|
||
shared_indent_cnt = 1e5
|
||
for i in range(line_start, line_end):
|
||
stripped_string = lines[i].lstrip()
|
||
num_spaces = len(lines[i]) - len(stripped_string)
|
||
if num_spaces < shared_indent_cnt:
|
||
shared_indent_cnt = num_spaces
|
||
|
||
# 修复缩进
|
||
if (shared_indent_cnt < 1e5) and (shared_indent_cnt % 4) == 3:
|
||
num_spaces_should_be = math.ceil(shared_indent_cnt / 4) * 4
|
||
for i in range(line_start, line_end):
|
||
add_n = num_spaces_should_be - shared_indent_cnt
|
||
lines[i] = " " * add_n + lines[i]
|
||
if not change_any: # 遇到第一个
|
||
change_any = True
|
||
|
||
if change_any:
|
||
return "\n".join(lines)
|
||
else:
|
||
return txt
|
||
|
||
|
||
@lru_cache(maxsize=128) # 使用 lru缓存 加快转换速度
|
||
def markdown_convertion(txt):
|
||
"""
|
||
将Markdown格式的文本转换为HTML格式。如果包含数学公式,则先将公式转换为HTML格式。
|
||
"""
|
||
pre = '<div class="markdown-body">'
|
||
suf = "</div>"
|
||
if txt.startswith(pre) and txt.endswith(suf):
|
||
# print('警告,输入了已经经过转化的字符串,二次转化可能出问题')
|
||
return txt # 已经被转化过,不需要再次转化
|
||
|
||
find_equation_pattern = r'<script type="math/tex(?:.*?)>(.*?)</script>'
|
||
|
||
txt = fix_markdown_indent(txt)
|
||
# txt = fix_code_segment_indent(txt)
|
||
if is_equation(txt): # 有$标识的公式符号,且没有代码段```的标识
|
||
# convert everything to html format
|
||
split = markdown.markdown(text="---")
|
||
convert_stage_1 = markdown.markdown(
|
||
text=txt,
|
||
extensions=[
|
||
"sane_lists",
|
||
"tables",
|
||
"mdx_math",
|
||
"pymdownx.superfences",
|
||
"pymdownx.highlight",
|
||
],
|
||
extension_configs={**markdown_extension_configs, **code_highlight_configs},
|
||
)
|
||
convert_stage_1 = markdown_bug_hunt(convert_stage_1)
|
||
# 1. convert to easy-to-copy tex (do not render math)
|
||
convert_stage_2_1, n = re.subn(
|
||
find_equation_pattern,
|
||
replace_math_no_render,
|
||
convert_stage_1,
|
||
flags=re.DOTALL,
|
||
)
|
||
# 2. convert to rendered equation
|
||
convert_stage_2_2, n = re.subn(
|
||
find_equation_pattern, replace_math_render, convert_stage_1, flags=re.DOTALL
|
||
)
|
||
# cat them together
|
||
return pre + convert_stage_2_1 + f"{split}" + convert_stage_2_2 + suf
|
||
else:
|
||
return (
|
||
pre
|
||
+ markdown.markdown(
|
||
txt,
|
||
extensions=[
|
||
"sane_lists",
|
||
"tables",
|
||
"pymdownx.superfences",
|
||
"pymdownx.highlight",
|
||
],
|
||
extension_configs=code_highlight_configs,
|
||
)
|
||
+ suf
|
||
)
|
||
|
||
|
||
def close_up_code_segment_during_stream(gpt_reply):
|
||
"""
|
||
在gpt输出代码的中途(输出了前面的```,但还没输出完后面的```),补上后面的```
|
||
|
||
Args:
|
||
gpt_reply (str): GPT模型返回的回复字符串。
|
||
|
||
Returns:
|
||
str: 返回一个新的字符串,将输出代码片段的“后面的```”补上。
|
||
|
||
"""
|
||
if "```" not in gpt_reply:
|
||
return gpt_reply
|
||
if gpt_reply.endswith("```"):
|
||
return gpt_reply
|
||
|
||
# 排除了以上两个情况,我们
|
||
segments = gpt_reply.split("```")
|
||
n_mark = len(segments) - 1
|
||
if n_mark % 2 == 1:
|
||
return gpt_reply + "\n```" # 输出代码片段中!
|
||
else:
|
||
return gpt_reply
|
||
|
||
def simple_markdown_convertion(txt):
|
||
pre = '<div class="markdown-body">'
|
||
suf = "</div>"
|
||
if txt.startswith(pre) and txt.endswith(suf):
|
||
return txt # 已经被转化过,不需要再次转化
|
||
txt = markdown.markdown(
|
||
txt,
|
||
extensions=["pymdownx.superfences", "tables", "pymdownx.highlight"],
|
||
extension_configs=code_highlight_configs,
|
||
)
|
||
return pre + txt + suf
|
||
|
||
def format_io(self, y):
|
||
"""
|
||
将输入和输出解析为HTML格式。将y中最后一项的输入部分段落化,并将输出部分的Markdown和数学公式转换为HTML格式。
|
||
"""
|
||
if y is None or y == []:
|
||
return []
|
||
i_ask, gpt_reply = y[-1]
|
||
i_ask = apply_gpt_academic_string_mask(i_ask, mode="show_render")
|
||
gpt_reply = apply_gpt_academic_string_mask(gpt_reply, mode="show_render")
|
||
# 输入部分太自由,预处理一波
|
||
if i_ask is not None:
|
||
i_ask = text_divide_paragraph(i_ask)
|
||
# 当代码输出半截的时候,试着补上后个```
|
||
if gpt_reply is not None:
|
||
gpt_reply = close_up_code_segment_during_stream(gpt_reply)
|
||
# 处理提问与输出
|
||
y[-1] = (
|
||
# 输入部分
|
||
None if i_ask is None else simple_markdown_convertion(i_ask),
|
||
# 输出部分
|
||
None if gpt_reply is None else markdown_convertion(gpt_reply),
|
||
)
|
||
return y
|