format source code

This commit is contained in:
binary-husky 2024-01-13 18:00:52 +08:00
parent 1714116a89
commit 7ab379688e
12 changed files with 1049 additions and 595 deletions
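Nearly every hunk below is mechanical restyling: single-quoted strings become double-quoted, one-line conditionals are split, and long argument lists are wrapped. That output is consistent with black's default style, although no tool is named in the commit, so the attribution is an assumption. A minimal sketch of reproducing such a pass:

# Hypothetical reproduction of this commit's formatting pass; black is an
# assumption inferred from the output style, not named anywhere in the diff.
import black

src = "if (m==PRESERVE and current_node.preserve) \\\n    or (m==TRANSFORM and not current_node.preserve):\n    pass\n"
print(black.format_str(src, mode=black.Mode()))  # black's default line length is 88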


@@ -5,7 +5,7 @@ import glob, os, requests, time
pj = os.path.join
ARXIV_CACHE_DIR = os.path.expanduser(f"~/arxiv_cache/")
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- Utility functions =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
# 专业词汇声明 = 'If the term "agent" is used in this section, it should be translated to "智能体". '
def switch_prompt(pfg, mode, more_requirement):
    """
@@ -142,7 +142,7 @@ def arxiv_download(chatbot, history, txt, allow_cache=True):
    from toolbox import extract_archive
    extract_archive(file_path=dst, dest_dir=extract_dst)
    return extract_dst, arxiv_id
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= Plugin main program 1 =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
@CatchException
@@ -218,7 +218,7 @@ def Latex英文纠错加PDF对比(txt, llm_kwargs, plugin_kwargs, chatbot, histo
    # <-------------- we are done ------------->
    return success
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= Plugin main program 2 =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
@CatchException
def Latex翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):


@@ -1,15 +1,18 @@
import os, shutil
import re
import numpy as np
PRESERVE = 0
TRANSFORM = 1
pj = os.path.join
class LinkedListNode:
    """
    Linked List Node
    """
    def __init__(self, string, preserve=True) -> None:
        self.string = string
        self.preserve = preserve
@@ -18,12 +21,14 @@ class LinkedListNode():
        # self.begin_line = 0
        # self.begin_char = 0
def convert_to_linklist(text, mask):
    root = LinkedListNode("", preserve=True)
    current_node = root
    for c, m, i in zip(text, mask, range(len(text))):
        if (m == PRESERVE and current_node.preserve) or (
            m == TRANSFORM and not current_node.preserve
        ):
            # add
            current_node.string += c
        else:
@@ -31,6 +36,7 @@ def convert_to_linklist(text, mask):
            current_node = current_node.next
    return root
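A minimal sketch of what convert_to_linklist produces (the input values are illustrative, not from the commit): the mask cuts the text into alternating preserve/transform runs, one linked-list node per run.

text = "keep this EDIT THIS keep"
mask = [PRESERVE] * 10 + [TRANSFORM] * 9 + [PRESERVE] * 5
node = convert_to_linklist(text, mask)
while node is not None:
    print(node.preserve, repr(node.string))  # True 'keep this ', False 'EDIT THIS', True ' keep'
    node = node.next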
def post_process(root):
    # fix unbalanced braces
    node = root
@@ -38,21 +44,24 @@ def post_process(root):
        string = node.string
        if node.preserve:
            node = node.next
            if node is None:
                break
            continue
        def break_check(string):
            str_stack = [""]  # (lv, index)
            for i, c in enumerate(string):
                if c == "{":
                    str_stack.append("{")
                elif c == "}":
                    if len(str_stack) == 1:
                        print("stack fix")
                        return i
                    str_stack.pop(-1)
                else:
                    str_stack[-1] += c
            return -1
        bp = break_check(string)
        if bp == -1:
@@ -69,51 +78,66 @@ def post_process(root):
            node.next = q
        node = node.next
        if node is None:
            break
    # mask out empty lines and sentences that are too short
    node = root
    while True:
        if len(node.string.strip("\n").strip("")) == 0:
            node.preserve = True
        if len(node.string.strip("\n").strip("")) < 42:
            node.preserve = True
        node = node.next
        if node is None:
            break
    node = root
    while True:
        if node.next and node.preserve and node.next.preserve:
            node.string += node.next.string
            node.next = node.next.next
        node = node.next
        if node is None:
            break
    # detach leading and trailing line breaks
    node = root
    prev_node = None
    while True:
        if not node.preserve:
            lstriped_ = node.string.lstrip().lstrip("\n")
            if (
                (prev_node is not None)
                and (prev_node.preserve)
                and (len(lstriped_) != len(node.string))
            ):
                prev_node.string += node.string[: -len(lstriped_)]
                node.string = lstriped_
            rstriped_ = node.string.rstrip().rstrip("\n")
            if (
                (node.next is not None)
                and (node.next.preserve)
                and (len(rstriped_) != len(node.string))
            ):
                node.next.string = node.string[len(rstriped_) :] + node.next.string
                node.string = rstriped_
        # =-=-=
        prev_node = node
        node = node.next
        if node is None:
            break
    # annotate each node with its line-number range
    node = root
    n_line = 0
    expansion = 2
    while True:
        n_l = node.string.count("\n")
        node.range = [n_line - expansion, n_line + n_l + expansion]  # the range to roll back when a fix fails
        n_line = n_line + n_l
        node = node.next
        if node is None:
            break
    return root
@@ -131,12 +155,14 @@ def set_forbidden_text(text, mask, pattern, flags=0):
    you can mask out (mask = PRESERVE so that text become untouchable for GPT)
    everything between "\begin{equation}" and "\end{equation}"
    """
    if isinstance(pattern, list):
        pattern = "|".join(pattern)
    pattern_compile = re.compile(pattern, flags)
    for res in pattern_compile.finditer(text):
        mask[res.span()[0] : res.span()[1]] = PRESERVE
    return text, mask
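A sketch of set_forbidden_text in use, mirroring the equation example from its docstring (assumption: mask is a numpy array, so the slice assignment broadcasts PRESERVE over the span):

import numpy as np

demo_text = r"intro \begin{equation} E=mc^2 \end{equation} outro"
demo_mask = np.full(len(demo_text), TRANSFORM, dtype=np.uint8)  # everything editable at first
demo_text, demo_mask = set_forbidden_text(
    demo_text, demo_mask, r"\\begin\{equation\}.*?\\end\{equation\}", flags=re.DOTALL
)
# every character of the equation block is now PRESERVE and is withheld from GPT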
def reverse_forbidden_text(text, mask, pattern, flags=0, forbid_wrapper=True):
    """
    Move area out of preserve area (make text editable for GPT)
@@ -144,7 +170,8 @@ def reverse_forbidden_text(text, mask, pattern, flags=0, forbid_wrapper=True):
    e.g.
    \begin{abstract} blablablablablabla. \end{abstract}
    """
    if isinstance(pattern, list):
        pattern = "|".join(pattern)
    pattern_compile = re.compile(pattern, flags)
    for res in pattern_compile.finditer(text):
        if not forbid_wrapper:
@@ -155,6 +182,7 @@ def reverse_forbidden_text(text, mask, pattern, flags=0, forbid_wrapper=True):
            mask[res.regs[1][1] : res.regs[0][1]] = PRESERVE  # abstract
    return text, mask
def set_forbidden_text_careful_brace(text, mask, pattern, flags=0):
    """
    Add a preserve text area in this paper (text become untouchable for GPT).
@@ -167,15 +195,21 @@ def set_forbidden_text_careful_brace(text, mask, pattern, flags=0):
        brace_level = -1
        p = begin = end = res.regs[0][0]
        for _ in range(1024 * 16):
            if text[p] == "}" and brace_level == 0:
                break
            elif text[p] == "}":
                brace_level -= 1
            elif text[p] == "{":
                brace_level += 1
            p += 1
        end = p + 1
        mask[begin:end] = PRESERVE
    return text, mask
def reverse_forbidden_text_careful_brace(
    text, mask, pattern, flags=0, forbid_wrapper=True
):
    """
    Move area out of preserve area (make text editable for GPT)
    count the number of the braces so as to catch compelete text area.
@@ -187,9 +221,12 @@ def reverse_forbidden_text_careful_brace(text, mask, pattern, flags=0, forbid_wr
        brace_level = 0
        p = begin = end = res.regs[1][0]
        for _ in range(1024 * 16):
            if text[p] == "}" and brace_level == 0:
                break
            elif text[p] == "}":
                brace_level -= 1
            elif text[p] == "{":
                brace_level += 1
            p += 1
        end = p
        mask[begin:end] = TRANSFORM
@@ -198,27 +235,42 @@ def reverse_forbidden_text_careful_brace(text, mask, pattern, flags=0, forbid_wr
            mask[end : res.regs[0][1]] = PRESERVE
    return text, mask
def set_forbidden_text_begin_end(text, mask, pattern, flags=0, limit_n_lines=42):
    """
    Find all \begin{} ... \end{} text block that with less than limit_n_lines lines.
    Add it to preserve area
    """
    pattern_compile = re.compile(pattern, flags)
    def search_with_line_limit(text, mask):
        for res in pattern_compile.finditer(text):
            cmd = res.group(1)  # begin{what}
            this = res.group(2)  # content between begin and end
            this_mask = mask[res.regs[2][0] : res.regs[2][1]]
            white_list = [
                "document",
                "abstract",
                "lemma",
                "definition",
                "sproof",
                "em",
                "emph",
                "textit",
                "textbf",
                "itemize",
                "enumerate",
            ]
            if (cmd in white_list) or this.count(
                "\n"
            ) >= limit_n_lines:  # use a magical number 42
                this, this_mask = search_with_line_limit(this, this_mask)
                mask[res.regs[2][0] : res.regs[2][1]] = this_mask
            else:
                mask[res.regs[0][0] : res.regs[0][1]] = PRESERVE
        return text, mask
    return search_with_line_limit(text, mask)
""" """
@ -227,6 +279,7 @@ Latex Merge File
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
""" """
def find_main_tex_file(file_manifest, mode): def find_main_tex_file(file_manifest, mode):
""" """
在多Tex文档中寻找主文件必须包含documentclass返回找到的第一个 在多Tex文档中寻找主文件必须包含documentclass返回找到的第一个
@ -234,27 +287,36 @@ def find_main_tex_file(file_manifest, mode):
""" """
canidates = [] canidates = []
for texf in file_manifest: for texf in file_manifest:
if os.path.basename(texf).startswith('merge'): if os.path.basename(texf).startswith("merge"):
continue continue
with open(texf, 'r', encoding='utf8', errors='ignore') as f: with open(texf, "r", encoding="utf8", errors="ignore") as f:
file_content = f.read() file_content = f.read()
if r'\documentclass' in file_content: if r"\documentclass" in file_content:
canidates.append(texf) canidates.append(texf)
else: else:
continue continue
if len(canidates) == 0: if len(canidates) == 0:
raise RuntimeError('无法找到一个主Tex文件包含documentclass关键字') raise RuntimeError("无法找到一个主Tex文件包含documentclass关键字")
elif len(canidates) == 1: elif len(canidates) == 1:
return canidates[0] return canidates[0]
else: # if len(canidates) >= 2 通过一些Latex模板中常见但通常不会出现在正文的单词对不同latex源文件扣分取评分最高者返回 else: # if len(canidates) >= 2 通过一些Latex模板中常见但通常不会出现在正文的单词对不同latex源文件扣分取评分最高者返回
canidates_score = [] canidates_score = []
# 给出一些判定模板文档的词作为扣分项 # 给出一些判定模板文档的词作为扣分项
unexpected_words = ['\\LaTeX', 'manuscript', 'Guidelines', 'font', 'citations', 'rejected', 'blind review', 'reviewers'] unexpected_words = [
expected_words = ['\\input', '\\ref', '\\cite'] "\\LaTeX",
"manuscript",
"Guidelines",
"font",
"citations",
"rejected",
"blind review",
"reviewers",
]
expected_words = ["\\input", "\\ref", "\\cite"]
for texf in canidates: for texf in canidates:
canidates_score.append(0) canidates_score.append(0)
with open(texf, 'r', encoding='utf8', errors='ignore') as f: with open(texf, "r", encoding="utf8", errors="ignore") as f:
file_content = f.read() file_content = f.read()
file_content = rm_comments(file_content) file_content = rm_comments(file_content)
for uw in unexpected_words: for uw in unexpected_words:
@ -266,6 +328,7 @@ def find_main_tex_file(file_manifest, mode):
select = np.argmax(canidates_score) # 取评分最高者返回 select = np.argmax(canidates_score) # 取评分最高者返回
return canidates[select] return canidates[select]
def rm_comments(main_file): def rm_comments(main_file):
new_file_remove_comment_lines = [] new_file_remove_comment_lines = []
for l in main_file.splitlines(): for l in main_file.splitlines():
@ -274,30 +337,39 @@ def rm_comments(main_file):
pass pass
else: else:
new_file_remove_comment_lines.append(l) new_file_remove_comment_lines.append(l)
main_file = '\n'.join(new_file_remove_comment_lines) main_file = "\n".join(new_file_remove_comment_lines)
# main_file = re.sub(r"\\include{(.*?)}", r"\\input{\1}", main_file) # 将 \include 命令转换为 \input 命令 # main_file = re.sub(r"\\include{(.*?)}", r"\\input{\1}", main_file) # 将 \include 命令转换为 \input 命令
main_file = re.sub(r'(?<!\\)%.*', '', main_file) # 使用正则表达式查找半行注释, 并替换为空字符串 main_file = re.sub(r"(?<!\\)%.*", "", main_file) # 使用正则表达式查找半行注释, 并替换为空字符串
return main_file return main_file
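A quick, illustrative check of rm_comments (the sample is made up): full-line % comments disappear, inline % comments are stripped, and escaped \% survives the negative lookbehind.

sample = "% full-line comment\nkeep this % trailing comment\nrate is 5\\% per year"
print(rm_comments(sample))
# keep this
# rate is 5\% per year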
def find_tex_file_ignore_case(fp):
    dir_name = os.path.dirname(fp)
    base_name = os.path.basename(fp)
    # if the given file path is already correct
    if os.path.isfile(pj(dir_name, base_name)):
        return pj(dir_name, base_name)
    # if not, try appending a .tex suffix
    if not base_name.endswith(".tex"):
        base_name += ".tex"
    if os.path.isfile(pj(dir_name, base_name)):
        return pj(dir_name, base_name)
    # if still not found, drop the case restriction and try again
    import glob
    for f in glob.glob(dir_name + "/*.tex"):
        base_name_s = os.path.basename(fp)
        base_name_f = os.path.basename(f)
        if base_name_s.lower() == base_name_f.lower():
            return f
        # try appending a .tex suffix
        if not base_name_s.endswith(".tex"):
            base_name_s += ".tex"
        if base_name_s.lower() == base_name_f.lower():
            return f
    return None
def merge_tex_files_(project_foler, main_file, mode):
    """
    Merge Tex project recrusively
@@ -309,18 +381,18 @@ def merge_tex_files_(project_foler, main_file, mode):
        fp_ = find_tex_file_ignore_case(fp)
        if fp_:
            try:
                with open(fp_, "r", encoding="utf-8", errors="replace") as fx:
                    c = fx.read()
            except:
                c = f"\n\nWarning from GPT-Academic: LaTex source file is missing!\n\n"
        else:
            raise RuntimeError(f"找不到{fp},Tex源文件缺失!")
        c = merge_tex_files_(project_foler, c, mode)
        main_file = main_file[: s.span()[0]] + c + main_file[s.span()[1] :]
    return main_file
def find_title_and_abs(main_file):
    def extract_abstract_1(text):
        pattern = r"\\abstract\{(.*?)\}"
        match = re.search(pattern, text, re.DOTALL)
@@ -362,21 +434,30 @@ def merge_tex_files(project_foler, main_file, mode):
    main_file = merge_tex_files_(project_foler, main_file, mode)
    main_file = rm_comments(main_file)
    if mode == "translate_zh":
        # find paper documentclass
        pattern = re.compile(r"\\documentclass.*\n")
        match = pattern.search(main_file)
        assert match is not None, "Cannot find documentclass statement!"
        position = match.end()
        add_ctex = "\\usepackage{ctex}\n"
        add_url = "\\usepackage{url}\n" if "{url}" not in main_file else ""
        main_file = main_file[:position] + add_ctex + add_url + main_file[position:]
        # fontset=windows
        import platform
        main_file = re.sub(
            r"\\documentclass\[(.*?)\]{(.*?)}",
            r"\\documentclass[\1,fontset=windows,UTF8]{\2}",
            main_file,
        )
        main_file = re.sub(
            r"\\documentclass{(.*?)}",
            r"\\documentclass[fontset=windows,UTF8]{\1}",
            main_file,
        )
        # find paper abstract
        pattern_opt1 = re.compile(r"\\begin\{abstract\}.*\n")
        pattern_opt2 = re.compile(r"\\abstract\{(.*?)\}", flags=re.DOTALL)
        match_opt1 = pattern_opt1.search(main_file)
        match_opt2 = pattern_opt2.search(main_file)
@@ -385,7 +466,9 @@ def merge_tex_files(project_foler, main_file, mode):
            main_file = insert_abstract(main_file)
        match_opt1 = pattern_opt1.search(main_file)
        match_opt2 = pattern_opt2.search(main_file)
        assert (match_opt1 is not None) or (
            match_opt2 is not None
        ), "Cannot find paper abstract section!"
    return main_file
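An illustrative before/after for the two re.sub calls above (input lines made up):

print(re.sub(
    r"\\documentclass\[(.*?)\]{(.*?)}",
    r"\\documentclass[\1,fontset=windows,UTF8]{\2}",
    r"\documentclass[11pt]{article}",
))  # \documentclass[11pt,fontset=windows,UTF8]{article}
print(re.sub(
    r"\\documentclass{(.*?)}",
    r"\\documentclass[fontset=windows,UTF8]{\1}",
    r"\documentclass{report}",
))  # \documentclass[fontset=windows,UTF8]{report}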
@@ -395,6 +478,7 @@ The GPT-Academic program cannot find abstract section in this paper.
\end{abstract}
"""
def insert_abstract(tex_content):
    if "\\maketitle" in tex_content:
        # find the position of "\maketitle"
@@ -402,7 +486,13 @@ def insert_abstract(tex_content):
        # find the nearest ending line
        end_line_index = tex_content.find("\n", find_index)
        # insert "abs_str" on the next line
        modified_tex = (
            tex_content[: end_line_index + 1]
            + "\n\n"
            + insert_missing_abs_str
            + "\n\n"
            + tex_content[end_line_index + 1 :]
        )
        return modified_tex
    elif r"\begin{document}" in tex_content:
        # find the position of "\begin{document}"
@@ -410,16 +500,25 @@ def insert_abstract(tex_content):
        # find the nearest ending line
        end_line_index = tex_content.find("\n", find_index)
        # insert "abs_str" on the next line
        modified_tex = (
            tex_content[: end_line_index + 1]
            + "\n\n"
            + insert_missing_abs_str
            + "\n\n"
            + tex_content[end_line_index + 1 :]
        )
        return modified_tex
    else:
        return tex_content
"""
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
Post process
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
"""
def mod_inbraket(match):
    """
    Why does ChatGPT replace the commas inside \cite{} with Chinese commas?
@@ -428,11 +527,12 @@ def mod_inbraket(match):
    cmd = match.group(1)
    str_to_modify = match.group(2)
    # modify the matched string
    str_to_modify = str_to_modify.replace(":", ":")  # fullwidth colon -> ASCII colon
    str_to_modify = str_to_modify.replace(",", ",")  # fullwidth comma -> ASCII comma
    # str_to_modify = 'BOOM'
    return "\\" + cmd + "{" + str_to_modify + "}"
def fix_content(final_tex, node_string):
    """
    Fix common GPT errors to increase success rate
@@ -444,9 +544,9 @@ def fix_content(final_tex, node_string):
    if "Traceback" in final_tex and "[Local Message]" in final_tex:
        final_tex = node_string  # something went wrong, restore the original text
    if node_string.count("\\begin") != final_tex.count("\\begin"):
        final_tex = node_string  # something went wrong, restore the original text
    if node_string.count("\_") > 0 and node_string.count("\_") > final_tex.count("\_"):
        # walk and replace any _ without \
        final_tex = re.sub(r"(?<!\\)_", "\\_", final_tex)
@@ -454,24 +554,32 @@ def fix_content(final_tex, node_string):
        # this function count the number of { and }
        brace_level = 0
        for c in string:
            if c == "{":
                brace_level += 1
            elif c == "}":
                brace_level -= 1
        return brace_level
    def join_most(tex_t, tex_o):
        # this function join translated string and original string when something goes wrong
        p_t = 0
        p_o = 0
        def find_next(string, chars, begin):
            p = begin
            while p < len(string):
                if string[p] in chars:
                    return p, string[p]
                p += 1
            return None, None
        while True:
            res1, char = find_next(tex_o, ["{", "}"], p_o)
            if res1 is None:
                break
            res2, char = find_next(tex_t, [char], p_t)
            if res2 is None:
                break
            p_o = res1 + 1
            p_t = res2 + 1
        return tex_t[:p_t] + tex_o[p_o:]
@@ -481,9 +589,13 @@ def fix_content(final_tex, node_string):
        final_tex = join_most(final_tex, node_string)
    return final_tex
def compile_latex_with_timeout(command, cwd, timeout=60):
    import subprocess
    process = subprocess.Popen(
        command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=cwd
    )
    try:
        stdout, stderr = process.communicate(timeout=timeout)
    except subprocess.TimeoutExpired:
@@ -493,43 +605,52 @@ def compile_latex_with_timeout(command, cwd, timeout=60):
        return False
    return True
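Illustrative call (command and paths are made up): the helper returns False when the compiler is killed on timeout and True otherwise.

ok = compile_latex_with_timeout(
    "pdflatex -interaction=batchmode merge.tex", cwd="/tmp/latex_workdir", timeout=60
)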
def run_in_subprocess_wrapper_func(func, args, kwargs, return_dict, exception_dict):
    import sys
    try:
        result = func(*args, **kwargs)
        return_dict["result"] = result
    except Exception as e:
        exc_info = sys.exc_info()
        exception_dict["exception"] = exc_info
def run_in_subprocess(func):
    import multiprocessing
    def wrapper(*args, **kwargs):
        return_dict = multiprocessing.Manager().dict()
        exception_dict = multiprocessing.Manager().dict()
        process = multiprocessing.Process(
            target=run_in_subprocess_wrapper_func,
            args=(func, args, kwargs, return_dict, exception_dict),
        )
        process.start()
        process.join()
        process.close()
        if "exception" in exception_dict:
            # ooops, the subprocess ran into an exception
            exc_info = exception_dict["exception"]
            raise exc_info[1].with_traceback(exc_info[2])
        if "result" in return_dict.keys():
            # If the subprocess ran successfully, return the result
            return return_dict["result"]
    return wrapper
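A minimal sketch of the wrapper in use, mirroring how merge_pdfs is built below (the worker function here is illustrative): the call runs in a child process, and any exception is re-raised in the parent with its original traceback.

def _double(x):  # illustrative worker
    return x * 2

double = run_in_subprocess(_double)
print(double(21))  # 42, computed in a child process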
def _merge_pdfs(pdf1_path, pdf2_path, output_path):
    import PyPDF2  # PyPDF2 leaks memory badly; run it inside a subprocess so the memory can be reclaimed
    Percent = 0.95
    # raise RuntimeError('PyPDF2 has a serious memory leak problem, please use other tools to merge PDF files.')
    # Open the first PDF file
    with open(pdf1_path, "rb") as pdf1_file:
        pdf1_reader = PyPDF2.PdfFileReader(pdf1_file)
        # Open the second PDF file
        with open(pdf2_path, "rb") as pdf2_file:
            pdf2_reader = PyPDF2.PdfFileReader(pdf2_file)
            # Create a new PDF file to store the merged pages
            output_writer = PyPDF2.PdfFileWriter()
@@ -549,14 +670,25 @@ def _merge_pdfs(pdf1_path, pdf2_path, output_path):
                    page2 = PyPDF2.PageObject.createBlankPage(pdf1_reader)
                # Create a new empty page with double width
                new_page = PyPDF2.PageObject.createBlankPage(
                    width=int(
                        int(page1.mediaBox.getWidth())
                        + int(page2.mediaBox.getWidth()) * Percent
                    ),
                    height=max(page1.mediaBox.getHeight(), page2.mediaBox.getHeight()),
                )
                new_page.mergeTranslatedPage(page1, 0, 0)
                new_page.mergeTranslatedPage(
                    page2,
                    int(
                        int(page1.mediaBox.getWidth())
                        - int(page2.mediaBox.getWidth()) * (1 - Percent)
                    ),
                    0,
                )
                output_writer.addPage(new_page)
            # Save the merged PDF file
            with open(output_path, "wb") as output_file:
                output_writer.write(output_file)
merge_pdfs = run_in_subprocess(_merge_pdfs)  # PyPDF2 leaks memory badly; run it inside a subprocess so the memory can be reclaimed
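Illustrative call (file names made up): the subprocess-wrapped merge_pdfs lays the two documents out page by page on double-width pages, overlapping the right-hand page by 5% (Percent = 0.95).

merge_pdfs("original.pdf", "translated.pdf", "comparison.pdf")  # runs in a child process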


@@ -352,9 +352,9 @@ def step_1_core_key_translate():
            chinese_core_keys_norepeat_mapping.update({k:cached_translation[k]})
    chinese_core_keys_norepeat_mapping = dict(sorted(chinese_core_keys_norepeat_mapping.items(), key=lambda x: -len(x[0])))
    # =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    # copy
    # =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    def copy_source_code():
        from toolbox import get_conf
@@ -367,9 +367,9 @@ def step_1_core_key_translate():
        shutil.copytree('./', backup_dir, ignore=lambda x, y: blacklist)
    copy_source_code()
    # =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    # primary key replace
    # =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    directory_path = f'./multi-language/{LANG}/'
    for root, dirs, files in os.walk(directory_path):
        for file in files:
@@ -389,9 +389,9 @@ def step_1_core_key_translate():
def step_2_core_key_translate():
    # =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
    # step2
    # =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
    def load_string(strings, string_input):
        string_ = string_input.strip().strip(',').strip().strip('.').strip()
@@ -492,9 +492,9 @@ def step_2_core_key_translate():
        cached_translation.update(read_map_from_json(language=LANG_STD))
    cached_translation = dict(sorted(cached_translation.items(), key=lambda x: -len(x[0])))
    # =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    # literal key replace
    # =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    directory_path = f'./multi-language/{LANG}/'
    for root, dirs, files in os.walk(directory_path):
        for file in files:


@@ -244,7 +244,7 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
        if has_choices and not choice_valid:
            # some shoddy third-party endpoints produce errors like this
            continue
        if ('data: [DONE]' not in chunk_decoded) and len(chunk_decoded) > 0 and (chunkjson is None):
            # something strange was passed in
            raise ValueError(f'无法读取以下数据,请检查配置。\n\n{chunk_decoded}')
        # the former is API2D's stop condition, the latter is OpenAI's
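This hunk is the commit's one behavioral change: OpenAI-compatible streams terminate with a 'data: [DONE]' sentinel that carries no JSON payload, and the old guard raised on it. A minimal sketch of the fixed behavior (variable values illustrative):

chunk_decoded = "data: [DONE]"  # final SSE sentinel, no JSON body
chunkjson = None
if ('data: [DONE]' not in chunk_decoded) and len(chunk_decoded) > 0 and (chunkjson is None):
    raise ValueError(f"cannot parse: {chunk_decoded}")  # no longer triggered by the sentinel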


@@ -1,16 +1,17 @@
"""
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
Part 1: from EdgeGPT.py
https://github.com/acheong08/EdgeGPT
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
"""
from .edge_gpt_free import Chatbot as NewbingChatbot
load_message = "等待NewBing响应。"
"""
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
Part 2: the subprocess Worker (the calling body)
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
"""
import time
import json
@@ -22,19 +23,30 @@ import threading
from toolbox import update_ui, get_conf, trimmed_format_exc
from multiprocessing import Process, Pipe
def preprocess_newbing_out(s):
    pattern = r"\^(\d+)\^"  # match ^number^
    sub = lambda m: "(" + m.group(1) + ")"  # replace the matched number with (number)
    result = re.sub(pattern, sub, s)  # perform the substitution
    if "[1]" in result:
        result += (
            "\n\n```reference\n"
            + "\n".join([r for r in result.split("\n") if r.startswith("[")])
            + "\n```\n"
        )
    return result
def preprocess_newbing_out_simple(result):
    if "[1]" in result:
        result += (
            "\n\n```reference\n"
            + "\n".join([r for r in result.split("\n") if r.startswith("[")])
            + "\n```\n"
        )
    return result
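A quick illustration of preprocess_newbing_out (made-up input): citation markers like ^1^ become (1), and any lines starting with "[" are gathered into a reference block appended to the answer.

print(preprocess_newbing_out("Answer^1^.\n[1]: https://example.com"))
# prints "Answer(1)." plus the [1] line, followed by a ```reference block repeating the [1] line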
class NewBingHandle(Process):
    def __init__(self):
        super().__init__(daemon=True)
@@ -51,6 +63,7 @@ class NewBingHandle(Process):
        try:
            self.success = False
            import certifi, httpx, rich
            self.info = "依赖检测通过,等待NewBing响应。注意,目前不能多人同时调用NewBing接口(有线程锁),否则将导致每个人的NewBing问询历史互相渗透。调用NewBing时,会自动使用已配置的代理。"
            self.success = True
        except:
@@ -62,15 +75,16 @@ class NewBingHandle(Process):
    async def async_run(self):
        # read the configuration
        NEWBING_STYLE = get_conf("NEWBING_STYLE")
        from request_llms.bridge_all import model_info
        endpoint = model_info["newbing"]["endpoint"]
        while True:
            # wait
            kwargs = self.child.recv()
            question = kwargs["query"]
            history = kwargs["history"]
            system_prompt = kwargs["system_prompt"]
            # whether to reset
            if len(self.local_history) > 0 and len(history) == 0:
@@ -81,19 +95,19 @@ class NewBingHandle(Process):
            prompt = ""
            if system_prompt not in self.local_history:
                self.local_history.append(system_prompt)
                prompt += system_prompt + "\n"
            # append the history
            for ab in history:
                a, b = ab
                if a not in self.local_history:
                    self.local_history.append(a)
                    prompt += a + "\n"
            # the question
            prompt += question
            self.local_history.append(question)
            print("question:", prompt)
            # submit
            async for final, response in self.newbing_model.ask_stream(
                prompt=question,
@@ -104,11 +118,10 @@ class NewBingHandle(Process):
                    print(response)
                    self.child.send(str(response))
                else:
                    print("-------- receive final ---------")
                    self.child.send("[Finish]")
                    # self.local_history.append(response)
    def run(self):
        """
        This function runs in the subprocess.
@@ -118,32 +131,37 @@ class NewBingHandle(Process):
        self.local_history = []
        if (self.newbing_model is None) or (not self.success):
            # proxy settings
            proxies, NEWBING_COOKIES = get_conf("proxies", "NEWBING_COOKIES")
            if proxies is None:
                self.proxies_https = None
            else:
                self.proxies_https = proxies["https"]
            if (NEWBING_COOKIES is not None) and len(NEWBING_COOKIES) > 100:
                try:
                    cookies = json.loads(NEWBING_COOKIES)
                except:
                    self.success = False
                    tb_str = "\n```\n" + trimmed_format_exc() + "\n```\n"
                    self.child.send(f"[Local Message] NEWBING_COOKIES未填写或有格式错误。")
                    self.child.send("[Fail]")
                    self.child.send("[Finish]")
                    raise RuntimeError(f"NEWBING_COOKIES未填写或有格式错误。")
            else:
                cookies = None
            try:
                self.newbing_model = NewbingChatbot(
                    proxy=self.proxies_https, cookies=cookies
                )
            except:
                self.success = False
                tb_str = "\n```\n" + trimmed_format_exc() + "\n```\n"
                self.child.send(
                    f"[Local Message] 不能加载Newbing组件,请注意Newbing组件已不再维护。{tb_str}"
                )
                self.child.send("[Fail]")
                self.child.send("[Finish]")
                raise RuntimeError(f"不能加载Newbing组件,请注意Newbing组件已不再维护。")
        self.success = True
@@ -151,10 +169,12 @@ class NewBingHandle(Process):
            # enter the task-waiting state
            asyncio.run(self.async_run())
        except Exception:
            tb_str = "\n```\n" + trimmed_format_exc() + "\n```\n"
            self.child.send(
                f"[Local Message] Newbing 请求失败,报错信息如下. 如果是与网络相关的问题,建议更换代理协议(推荐http)或代理节点 {tb_str}."
            )
            self.child.send("[Fail]")
            self.child.send("[Finish]")
    def stream_chat(self, **kwargs):
        """
@@ -164,21 +184,33 @@ class NewBingHandle(Process):
        self.parent.send(kwargs)  # send the request to the subprocess
        while True:
            res = self.parent.recv()  # wait for a reply fragment from NewBing
            if res == "[Finish]":
                break  # finished
            elif res == "[Fail]":
                self.success = False
                break  # failed
            else:
                yield res  # a reply fragment from NewBing
        self.threadLock.release()  # release the thread lock
""" """
======================================================================== =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
第三部分主进程统一调用函数接口 第三部分主进程统一调用函数接口
======================================================================== =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
""" """
global newbingfree_handle global newbingfree_handle
newbingfree_handle = None newbingfree_handle = None
def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=[], console_slience=False):
def predict_no_ui_long_connection(
inputs,
llm_kwargs,
history=[],
sys_prompt="",
observe_window=[],
console_slience=False,
):
""" """
多线程方法 多线程方法
函数的说明请见 request_llms/bridge_all.py 函数的说明请见 request_llms/bridge_all.py
@ -186,7 +218,8 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
global newbingfree_handle global newbingfree_handle
if (newbingfree_handle is None) or (not newbingfree_handle.success): if (newbingfree_handle is None) or (not newbingfree_handle.success):
newbingfree_handle = NewBingHandle() newbingfree_handle = NewBingHandle()
if len(observe_window) >= 1: observe_window[0] = load_message + "\n\n" + newbingfree_handle.info if len(observe_window) >= 1:
observe_window[0] = load_message + "\n\n" + newbingfree_handle.info
if not newbingfree_handle.success: if not newbingfree_handle.success:
error = newbingfree_handle.info error = newbingfree_handle.info
newbingfree_handle = None newbingfree_handle = None
@ -199,15 +232,34 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
watch_dog_patience = 5 # 看门狗 (watchdog) 的耐心, 设置5秒即可 watch_dog_patience = 5 # 看门狗 (watchdog) 的耐心, 设置5秒即可
response = "" response = ""
if len(observe_window) >= 1: observe_window[0] = "[Local Message] 等待NewBing响应中 ..." if len(observe_window) >= 1:
for response in newbingfree_handle.stream_chat(query=inputs, history=history_feedin, system_prompt=sys_prompt, max_length=llm_kwargs['max_length'], top_p=llm_kwargs['top_p'], temperature=llm_kwargs['temperature']): observe_window[0] = "[Local Message] 等待NewBing响应中 ..."
if len(observe_window) >= 1: observe_window[0] = preprocess_newbing_out_simple(response) for response in newbingfree_handle.stream_chat(
query=inputs,
history=history_feedin,
system_prompt=sys_prompt,
max_length=llm_kwargs["max_length"],
top_p=llm_kwargs["top_p"],
temperature=llm_kwargs["temperature"],
):
if len(observe_window) >= 1:
observe_window[0] = preprocess_newbing_out_simple(response)
if len(observe_window) >= 2: if len(observe_window) >= 2:
if (time.time() - observe_window[1]) > watch_dog_patience: if (time.time() - observe_window[1]) > watch_dog_patience:
raise RuntimeError("程序终止。") raise RuntimeError("程序终止。")
return preprocess_newbing_out_simple(response) return preprocess_newbing_out_simple(response)
def predict(
    inputs,
    llm_kwargs,
    plugin_kwargs,
    chatbot,
    history=[],
    system_prompt="",
    stream=True,
    additional_fn=None,
):
    """
    Single-threaded method.
    See request_llms/bridge_all.py for the function's documentation.
@@ -225,7 +277,10 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
    if additional_fn is not None:
        from core_functional import handle_core_functionality
        inputs, history = handle_core_functionality(
            additional_fn, inputs, history, chatbot
        )
    history_feedin = []
    for i in range(len(history) // 2):
@@ -233,13 +288,24 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
    chatbot[-1] = (inputs, "[Local Message] 等待NewBing响应中 ...")
    response = "[Local Message] 等待NewBing响应中 ..."
    yield from update_ui(
        chatbot=chatbot, history=history, msg="NewBing响应缓慢,尚未完成全部响应,请耐心完成后再提交新问题。"
    )
    for response in newbingfree_handle.stream_chat(
        query=inputs,
        history=history_feedin,
        system_prompt=system_prompt,
        max_length=llm_kwargs["max_length"],
        top_p=llm_kwargs["top_p"],
        temperature=llm_kwargs["temperature"],
    ):
        chatbot[-1] = (inputs, preprocess_newbing_out(response))
        yield from update_ui(
            chatbot=chatbot, history=history, msg="NewBing响应缓慢,尚未完成全部响应,请耐心完成后再提交新问题。"
        )
    if response == "[Local Message] 等待NewBing响应中 ...":
        response = "[Local Message] NewBing响应异常,请刷新界面重试 ..."
    history.extend([inputs, response])
    logging.info(f"[raw_input] {inputs}")
    logging.info(f"[response] {response}")
    yield from update_ui(chatbot=chatbot, history=history, msg="完成全部响应,请提交新问题。")


@@ -7,14 +7,15 @@ import logging
import time
from toolbox import get_conf
import asyncio
load_message = "正在加载Claude组件,请稍候..."
try:
    """
    =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    Part 1: the Slack API Client
    https://github.com/yokonsan/claude-in-slack-api
    =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    """
    from slack_sdk.errors import SlackApiError
@@ -33,10 +34,13 @@ try:
        - get_reply(): an async method that polls the opened channel; a message still ending in "Typing…_" means Claude is not finished, otherwise the loop ends.
        """
        CHANNEL_ID = None
        async def open_channel(self):
            response = await self.conversations_open(
                users=get_conf("SLACK_CLAUDE_BOT_ID")
            )
            self.CHANNEL_ID = response["channel"]["id"]
        async def chat(self, text):
@@ -49,9 +53,14 @@ try:
        async def get_slack_messages(self):
            try:
                # TODO: history is not supported for now, because when several people share one channel their histories leak into each other
                resp = await self.conversations_history(
                    channel=self.CHANNEL_ID, oldest=self.LAST_TS, limit=1
                )
                msg = [
                    msg
                    for msg in resp["messages"]
                    if msg.get("user") == get_conf("SLACK_CLAUDE_BOT_ID")
                ]
                return msg
            except (SlackApiError, KeyError) as e:
                raise RuntimeError(f"获取Slack消息失败。")
@@ -69,13 +78,14 @@ try:
                    else:
                        yield True, msg["text"]
                        break
except:
    pass
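A simplified sketch of the polling contract described in the docstring above (illustrative, not part of the commit): keep reading messages until the newest reply no longer ends with the "Typing…_" marker.

async def poll_until_done(client):  # client: the SlackClient defined above
    while True:
        msgs = await client.get_slack_messages()
        if msgs and not msgs[-1]["text"].endswith("Typing…_"):
            return msgs[-1]["text"]  # Claude has finished this reply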
""" """
======================================================================== =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
第二部分子进程Worker调用主体 第二部分子进程Worker调用主体
======================================================================== =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
""" """
@ -96,6 +106,7 @@ class ClaudeHandle(Process):
try: try:
self.success = False self.success = False
import slack_sdk import slack_sdk
self.info = "依赖检测通过等待Claude响应。注意目前不能多人同时调用Claude接口有线程锁否则将导致每个人的Claude问询历史互相渗透。调用Claude时会自动使用已配置的代理。" self.info = "依赖检测通过等待Claude响应。注意目前不能多人同时调用Claude接口有线程锁否则将导致每个人的Claude问询历史互相渗透。调用Claude时会自动使用已配置的代理。"
self.success = True self.success = True
except: except:
@ -110,15 +121,15 @@ class ClaudeHandle(Process):
while True: while True:
# 等待 # 等待
kwargs = self.child.recv() kwargs = self.child.recv()
question = kwargs['query'] question = kwargs["query"]
history = kwargs['history'] history = kwargs["history"]
# 开始问问题 # 开始问问题
prompt = "" prompt = ""
# 问题 # 问题
prompt += question prompt += question
print('question:', prompt) print("question:", prompt)
# 提交 # 提交
await self.claude_model.chat(prompt) await self.claude_model.chat(prompt)
@ -131,11 +142,15 @@ class ClaudeHandle(Process):
else: else:
# 防止丢失最后一条消息 # 防止丢失最后一条消息
slack_msgs = await self.claude_model.get_slack_messages() slack_msgs = await self.claude_model.get_slack_messages()
last_msg = slack_msgs[-1]["text"] if slack_msgs and len(slack_msgs) > 0 else "" last_msg = (
slack_msgs[-1]["text"]
if slack_msgs and len(slack_msgs) > 0
else ""
)
if last_msg: if last_msg:
self.child.send(last_msg) self.child.send(last_msg)
print('-------- receive final ---------') print("-------- receive final ---------")
self.child.send('[Finish]') self.child.send("[Finish]")
    def run(self):
        """
@@ -146,22 +161,24 @@ class ClaudeHandle(Process):
        self.local_history = []
        if (self.claude_model is None) or (not self.success):
            # proxy settings
            proxies = get_conf("proxies")
            if proxies is None:
                self.proxies_https = None
            else:
                self.proxies_https = proxies["https"]
            try:
                SLACK_CLAUDE_USER_TOKEN = get_conf("SLACK_CLAUDE_USER_TOKEN")
                self.claude_model = SlackClient(
                    token=SLACK_CLAUDE_USER_TOKEN, proxy=self.proxies_https
                )
                print("Claude组件初始化成功。")
            except:
                self.success = False
                tb_str = "\n```\n" + trimmed_format_exc() + "\n```\n"
                self.child.send(f"[Local Message] 不能加载Claude组件。{tb_str}")
                self.child.send("[Fail]")
                self.child.send("[Finish]")
                raise RuntimeError(f"不能加载Claude组件。")
        self.success = True
@@ -169,10 +186,10 @@ class ClaudeHandle(Process):
            # enter the task-waiting state
            asyncio.run(self.async_run())
        except Exception:
            tb_str = "\n```\n" + trimmed_format_exc() + "\n```\n"
            self.child.send(f"[Local Message] Claude失败 {tb_str}.")
            self.child.send("[Fail]")
            self.child.send("[Finish]")
    def stream_chat(self, **kwargs):
        """
@@ -182,9 +199,9 @@ class ClaudeHandle(Process):
        self.parent.send(kwargs)  # send the request to the subprocess
        while True:
            res = self.parent.recv()  # wait for a reply fragment from Claude
            if res == "[Finish]":
                break  # finished
            elif res == "[Fail]":
                self.success = False
                break
            else:
@@ -193,15 +210,22 @@
"""
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
Part 3: the main process's unified calling interface
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
"""
global claude_handle
claude_handle = None
def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=None, console_slience=False): def predict_no_ui_long_connection(
inputs,
llm_kwargs,
history=[],
sys_prompt="",
observe_window=None,
console_slience=False,
):
""" """
多线程方法 多线程方法
函数的说明请见 request_llms/bridge_all.py 函数的说明请见 request_llms/bridge_all.py
@ -223,7 +247,14 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
watch_dog_patience = 5 # 看门狗 (watchdog) 的耐心, 设置5秒即可 watch_dog_patience = 5 # 看门狗 (watchdog) 的耐心, 设置5秒即可
response = "" response = ""
observe_window[0] = "[Local Message] 等待Claude响应中 ..." observe_window[0] = "[Local Message] 等待Claude响应中 ..."
for response in claude_handle.stream_chat(query=inputs, history=history_feedin, system_prompt=sys_prompt, max_length=llm_kwargs['max_length'], top_p=llm_kwargs['top_p'], temperature=llm_kwargs['temperature']): for response in claude_handle.stream_chat(
query=inputs,
history=history_feedin,
system_prompt=sys_prompt,
max_length=llm_kwargs["max_length"],
top_p=llm_kwargs["top_p"],
temperature=llm_kwargs["temperature"],
):
observe_window[0] = preprocess_newbing_out_simple(response) observe_window[0] = preprocess_newbing_out_simple(response)
if len(observe_window) >= 2: if len(observe_window) >= 2:
if (time.time() - observe_window[1]) > watch_dog_patience: if (time.time() - observe_window[1]) > watch_dog_patience:
@ -231,7 +262,16 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
return preprocess_newbing_out_simple(response) return preprocess_newbing_out_simple(response)
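The `observe_window` argument above doubles as a status channel and a watchdog: slot 0 carries the partial response out to the watching thread, while slot 1 is a heartbeat the watcher keeps refreshing; once the heartbeat goes stale for longer than `watch_dog_patience`, the worker gives up. A minimal standalone sketch of that contract (toy names, not the project's actual threading code):

```python
import time

def worker(observe_window, chunks, watch_dog_patience=5):
    # observe_window[0]: partial output for the watcher
    # observe_window[1]: heartbeat timestamp refreshed by the watcher
    for chunk in chunks:
        observe_window[0] = chunk
        if len(observe_window) >= 2 and (time.time() - observe_window[1]) > watch_dog_patience:
            raise RuntimeError("terminated by watchdog: the watcher stopped feeding the heartbeat")
    return observe_window[0]

window = ["", time.time()]  # a fresh heartbeat keeps the worker running
print(worker(window, ["Hel", "Hello", "Hello!"]))
```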
def predict(
    inputs,
    llm_kwargs,
    plugin_kwargs,
    chatbot,
    history=[],
    system_prompt="",
    stream=True,
    additional_fn=None,
):
    """
    单线程方法
    函数的说明请见 request_llms/bridge_all.py
@@ -249,7 +289,10 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
    if additional_fn is not None:
        from core_functional import handle_core_functionality

        inputs, history = handle_core_functionality(
            additional_fn, inputs, history, chatbot
        )

    history_feedin = []
    for i in range(len(history) // 2):
@@ -257,13 +300,19 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
    chatbot[-1] = (inputs, "[Local Message] 等待Claude响应中 ...")
    response = "[Local Message] 等待Claude响应中 ..."
    yield from update_ui(
        chatbot=chatbot, history=history, msg="Claude响应缓慢,尚未完成全部响应,请耐心完成后再提交新问题。"
    )
    for response in claude_handle.stream_chat(
        query=inputs, history=history_feedin, system_prompt=system_prompt
    ):
        chatbot[-1] = (inputs, preprocess_newbing_out(response))
        yield from update_ui(
            chatbot=chatbot, history=history, msg="Claude响应缓慢,尚未完成全部响应,请耐心完成后再提交新问题。"
        )
    if response == "[Local Message] 等待Claude响应中 ...":
        response = "[Local Message] Claude响应异常,请刷新界面重试 ..."
    history.extend([inputs, response])
    logging.info(f"[raw_input] {inputs}")
    logging.info(f"[response] {response}")
    yield from update_ui(chatbot=chatbot, history=history, msg="完成全部响应,请提交新问题。")
View File
@@ -12,7 +12,7 @@ from toolbox import get_conf, encode_image, get_pictures_list
proxies, TIMEOUT_SECONDS = get_conf("proxies", "TIMEOUT_SECONDS")

"""
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
第五部分 一些文件处理方法
files_filter_handler 根据type过滤文件
input_encode_handler 提取input中的文件并解析
@@ -21,6 +21,7 @@ link_mtime_to_md 文件增加本地时间参数,避免下载到缓存文件
html_view_blank 超链接
html_local_file 本地文件取相对路径
to_markdown_tabs 文件list 转换为 md tab
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
"""
View File
@@ -1,8 +1,8 @@
"""
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
第一部分:来自EdgeGPT.py
https://github.com/acheong08/EdgeGPT
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
"""
"""
Main.py
@@ -452,9 +452,11 @@ class _ChatHub:
        ws_cookies = []
        for cookie in self.cookies:
            ws_cookies.append(f"{cookie['name']}={cookie['value']}")
        req_header.update(
            {
                "Cookie": ";".join(ws_cookies),
            }
        )

        timeout = aiohttp.ClientTimeout(total=30)
        self.session = aiohttp.ClientSession(timeout=timeout)
View File
@@ -2,6 +2,7 @@ import markdown
import re
import os
import math
from textwrap import dedent
from latex2mathml.converter import convert as tex2mathml
from functools import wraps, lru_cache
from shared_utils.config_loader import get_conf as get_conf
@@ -32,26 +33,6 @@ def text_divide_paragraph(text):
    text = "</br>".join(lines)
    return pre + text + suf


def tex2mathml_catch_exception(content, *args, **kwargs):
    try:
        content = tex2mathml(content, *args, **kwargs)
@@ -121,7 +102,8 @@ def markdown_convertion(txt):
def fix_markdown_indent(txt):
    # fix markdown indent
    if (' - ' not in txt) or ('. ' not in txt):
        # do not need to fix, fast escape
        return txt
    # walk through the lines and fix non-standard indentation
    lines = txt.split("\n")
    pattern = re.compile(r'^\s+-')
@@ -137,7 +119,83 @@ def markdown_convertion(txt):
            lines[i] = ' ' * num_spaces_should_be + stripped_string
    return '\n'.join(lines)
FENCED_BLOCK_RE = re.compile(
    dedent(r'''
        (?P<fence>^[ \t]*(?:~{3,}|`{3,}))[ ]*                      # opening fence
        ((\{(?P<attrs>[^\}\n]*)\})|                                # (optional {attrs} or
        (\.?(?P<lang>[\w#.+-]*)[ ]*)?                              # optional (.)lang
        (hl_lines=(?P<quot>"|')(?P<hl_lines>.*?)(?P=quot)[ ]*)?)   # optional hl_lines)
        \n                                                         # newline (end of opening fence)
        (?P<code>.*?)(?<=\n)                                       # the code block
        (?P=fence)[ ]*$                                            # closing fence
    '''),
    re.MULTILINE | re.DOTALL | re.VERBOSE
)


def get_line_range(re_match_obj, txt):
    start_pos, end_pos = re_match_obj.regs[0]
    num_newlines_before = txt[:start_pos+1].count('\n')
    line_start = num_newlines_before
    line_end = num_newlines_before + txt[start_pos:end_pos].count('\n') + 1
    return line_start, line_end


def fix_code_segment_indent(txt):
    lines = []
    change_any = False
    txt_tmp = txt
    while True:
        re_match_obj = FENCED_BLOCK_RE.search(txt_tmp)
        if not re_match_obj: break
        if len(lines) == 0: lines = txt.split("\n")
        # 清空 txt_tmp 对应的位置方便下次搜索
        start_pos, end_pos = re_match_obj.regs[0]
        txt_tmp = txt_tmp[:start_pos] + ' ' * (end_pos - start_pos) + txt_tmp[end_pos:]
        line_start, line_end = get_line_range(re_match_obj, txt)
        # 获取公共缩进
        shared_indent_cnt = 1e5
        for i in range(line_start, line_end):
            stripped_string = lines[i].lstrip()
            num_spaces = len(lines[i]) - len(stripped_string)
            if num_spaces < shared_indent_cnt:
                shared_indent_cnt = num_spaces
        # 修复缩进
        if (shared_indent_cnt < 1e5) and (shared_indent_cnt % 4) == 3:
            num_spaces_should_be = math.ceil(shared_indent_cnt / 4) * 4
            for i in range(line_start, line_end):
                add_n = num_spaces_should_be - shared_indent_cnt
                lines[i] = ' ' * add_n + lines[i]
            if not change_any:  # 遇到第一个
                change_any = True

    if change_any:
        return '\n'.join(lines)
    else:
        return txt


@lru_cache(maxsize=128)  # 使用 lru缓存 加快转换速度
def markdown_convertion(txt):
    """
    将Markdown格式的文本转换为HTML格式。如果包含数学公式,则先将公式转换为HTML格式。
    """
    pre = '<div class="markdown-body">'
    suf = '</div>'
    if txt.startswith(pre) and txt.endswith(suf):
        # print('警告,输入了已经经过转化的字符串,二次转化可能出问题')
        return txt  # 已经被转化过,不需要再次转化

    markdown_extension_configs = {
        'mdx_math': {
            'enable_dollar_delimiter': True,
            'use_gitlab_delimiters': False,
        },
    }
    find_equation_pattern = r'<script type="math/tex(?:.*?)>(.*?)</script>'

    txt = fix_markdown_indent(txt)
    txt = fix_code_segment_indent(txt)
    if is_equation(txt):  # 有$标识的公式符号,且没有代码段```的标识
        # convert everything to html format
        split = markdown.markdown(text='---')
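A quick check of what the new `fix_code_segment_indent` buys (a sketch that assumes the definitions above are in scope): a fenced block indented by 3 spaces, which standard markdown renderers refuse to treat as code, is padded up to the next multiple of 4.

```python
demo = "\n".join([
    "steps:",
    "   ```python",   # 3-space indent: 3 % 4 == 3, so the fix applies
    "   print('hi')",
    "   ```",
])
# every line of the fenced block gains one leading space (3 -> 4)
print(fix_code_segment_indent(demo))
```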
View File
@@ -1,7 +1,7 @@
import os

"""
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
接驳void-terminal:
    - set_conf: 在运行过程中动态地修改配置
    - set_multi_conf: 在运行过程中动态地修改多个配置
@@ -9,17 +9,20 @@ import os
    - get_plugin_default_kwargs: 获取插件的默认参数
    - get_chat_handle: 获取简单聊天的句柄
    - get_chat_default_kwargs: 获取简单聊天的默认参数
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
"""


def get_plugin_handle(plugin_name):
    """
    e.g. plugin_name = 'crazy_functions.批量Markdown翻译->Markdown翻译指定语言'
    """
    import importlib

    assert (
        "->" in plugin_name
    ), "Example of plugin_name: crazy_functions.批量Markdown翻译->Markdown翻译指定语言"
    module, fn_name = plugin_name.split("->")
    f_hot_reload = getattr(importlib.import_module(module, fn_name), fn_name)
    return f_hot_reload
@@ -29,6 +32,7 @@ def get_chat_handle():
    Get chat function
    """
    from request_llms.bridge_all import predict_no_ui_long_connection

    return predict_no_ui_long_connection
@@ -37,13 +41,14 @@ def get_plugin_default_kwargs():
    Get Plugin Default Arguments
    """
    from toolbox import ChatBotWithCookies, load_chat_cookies

    cookies = load_chat_cookies()
    llm_kwargs = {
        "api_key": cookies["api_key"],
        "llm_model": cookies["llm_model"],
        "top_p": 1.0,
        "max_length": None,
        "temperature": 1.0,
    }
    chatbot = ChatBotWithCookies(llm_kwargs)
@@ -55,7 +60,7 @@ def get_plugin_default_kwargs():
        "chatbot_with_cookie": chatbot,
        "history": [],
        "system_prompt": "You are a good AI.",
        "web_port": None,
    }
    return DEFAULT_FN_GROUPS_kwargs
@@ -65,13 +70,14 @@ def get_chat_default_kwargs():
    Get Chat Default Arguments
    """
    from toolbox import load_chat_cookies

    cookies = load_chat_cookies()
    llm_kwargs = {
        "api_key": cookies["api_key"],
        "llm_model": cookies["llm_model"],
        "top_p": 1.0,
        "max_length": None,
        "temperature": 1.0,
    }
    default_chat_kwargs = {
        "inputs": "Hello there, are you ready?",
View File
@@ -1,32 +1,75 @@
md = """
要计算文件的哈希值,可以使用哈希算法(如MD5、SHA-1或SHA-256)对文件的内容进行计算。

以下是一个使用sha256算法计算文件哈希值的示例代码:

```python
import hashlib

def calculate_hash(file_path):
    sha256_hash = hashlib.sha256()
    with open(file_path, 'rb') as file:
        for chunk in iter(lambda: file.read(4096), b''):
            sha256_hash.update(chunk)
    return sha256_hash.hexdigest()

# 使用示例
file_path = 'path/to/file.txt'
hash_value = calculate_hash(file_path)
print('File hash:', hash_value)
```

在上面的示例中,`calculate_hash`函数接受一个文件路径作为参数,并打开文件以二进制读取模式读取文件内容。然后,使用哈希对象sha256初始化,并对文件内容进行分块读取并更新哈希值。最后,通过`hexdigest`方法获取哈希值的十六进制表示。

可以根据需要更改哈希算法(如使用`hashlib.md5()`来使用MD5算法)和块大小(这里使用4096字节)。
"""

md = """
要在Ubuntu中将NTFS格式转换为ext4格式,您需要进行以下步骤:

1. 首先,确保您已经安装了gparted软件。如果没有安装,请使用以下命令进行安装:

```
sudo apt update
sudo apt install gparted
```

2. 然后,打开GParted软件。您可以在"应用程序"菜单中搜索并启动它。

3. 在GParted界面中,选择您想要转换格式的NTFS分区。请小心选择,确保选择正确的分区。

4. 确保分区未挂载。如果分区当前正在使用,您需要首先卸载它。在命令行中,您可以使用以下命令卸载该分区:

```
sudo umount /dev/sdc1
```

注意:请将"/dev/sdc1"替换为您要卸载的分区的正确路径。

5. 在GParted界面中,单击菜单中的"设备"选项,然后选择"创建":

6. 在弹出的对话框中,选择要转换为的文件系统类型。在这种情况下,选择"ext4"。然后,单击"添加"按钮。

7. 在"操作"菜单中,选择"应用所有操作"。这将开始分区格式转换的过程。

8. 等待GParted完成转换操作。这可能需要一些时间,具体取决于分区的大小和系统性能。

9. 转换完成后,您将看到分区的文件系统已更改为ext4。

10. 最后,请确保挂载分区以便访问它。您可以使用以下命令挂载该分区:

```
sudo mount /dev/sdc1 /media/fuqingxu/eb63a8fa-cee9-48a5-9f05-b1388c3fda9e
```

注意:请将"/dev/sdc1"替换为已转换分区的正确路径,并将"/media/fuqingxu/eb63a8fa-cee9-48a5-9f05-b1388c3fda9e"替换为您要挂载的目标路径。

请注意,在执行任何分区操作之前,务必备份重要的数据。操作不当可能导致数据丢失。
"""
@@ -43,6 +86,6 @@ validate_path() # validate path so you can run from base directory
from toolbox import markdown_convertion

html = markdown_convertion(md)
# print(html)
with open("test.html", "w", encoding="utf-8") as f:
    f.write(html)
View File
@@ -11,6 +11,7 @@ from functools import wraps
from shared_utils.config_loader import get_conf
from shared_utils.config_loader import set_conf
from shared_utils.advanced_markdown_format import format_io
from shared_utils.advanced_markdown_format import markdown_convertion
from shared_utils.key_pattern_manager import select_api_key
from shared_utils.key_pattern_manager import is_any_api_key
from shared_utils.key_pattern_manager import what_keys
@@ -20,10 +21,10 @@ from shared_utils.connect_void_terminal import get_plugin_default_kwargs
from shared_utils.connect_void_terminal import get_chat_default_kwargs

pj = os.path.join
default_user_name = "default_user"

"""
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
第一部分
函数插件输入输出接驳区
    - ChatBotWithCookies: 带Cookies的Chatbot类,为实现更多强大的功能做基础
@@ -32,7 +33,7 @@ default_user_name = 'default_user'
    - CatchException: 将插件中出的所有问题显示在界面上
    - HotReload: 实现插件的热更新
    - trimmed_format_exc: 打印traceback,为了安全而隐藏绝对地址
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
"""
@@ -120,22 +121,30 @@ def ArgsGeneralWrapper(f):
    return decorated
def update_ui(chatbot, history, msg="正常", **kwargs):  # 刷新界面
    """
    刷新用户界面
    """
    assert isinstance(
        chatbot, ChatBotWithCookies
    ), "在传递chatbot的过程中不要将其丢弃。必要时, 可用clear将其清空, 然后用for+append循环重新赋值。"
    cookies = chatbot.get_cookies()
    # 备份一份History作为记录
    cookies.update({"history": history})
    # 解决插件锁定时的界面显示问题
    if cookies.get("lock_plugin", None):
        label = (
            cookies.get("llm_model", "")
            + " | "
            + "正在锁定插件"
            + cookies.get("lock_plugin", None)
        )
        chatbot_gr = gradio.update(value=chatbot, label=label)
        if cookies.get("label", "") != label:
            cookies["label"] = label  # 记住当前的label
    elif cookies.get("label", None):
        chatbot_gr = gradio.update(value=chatbot, label=cookies.get("llm_model", ""))
        cookies["label"] = None  # 清空label
    else:
        chatbot_gr = chatbot
@@ -146,7 +155,8 @@ def update_ui_lastest_msg(lastmsg, chatbot, history, delay=1):  # 刷新界面
    """
    刷新用户界面
    """
    if len(chatbot) == 0:
        chatbot.append(["update_ui_last_msg", lastmsg])
    chatbot[-1] = list(chatbot[-1])
    chatbot[-1][-1] = lastmsg
    yield from update_ui(chatbot=chatbot, history=history)
@@ -155,6 +165,7 @@ def update_ui_lastest_msg(lastmsg, chatbot, history, delay=1):  # 刷新界面
def trimmed_format_exc():
    import os, traceback

    str = traceback.format_exc()
    current_path = os.getcwd()
    replace_path = "."
@@ -194,19 +205,21 @@ def HotReload(f):
    最后,使用yield from语句返回重新加载过的函数,并在被装饰的函数上执行。
    最终,装饰器函数返回内部函数。这个内部函数可以将函数的原始定义更新为最新版本,并执行函数的新版本。
    """
    if get_conf("PLUGIN_HOT_RELOAD"):

        @wraps(f)
        def decorated(*args, **kwargs):
            fn_name = f.__name__
            f_hot_reload = getattr(importlib.reload(inspect.getmodule(f)), fn_name)
            yield from f_hot_reload(*args, **kwargs)

        return decorated
    else:
        return f
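The practical effect of `HotReload` when `PLUGIN_HOT_RELOAD` is enabled: every call re-imports the module that defines the wrapped generator, so plugin edits land without restarting the server. A hedged sketch with a hypothetical plugin function (assumed to live in its own importable module, since the decorator reloads that module):

```python
from toolbox import HotReload

@HotReload
def demo_plugin(txt, chatbot, history):
    # with hot reload on, this body is re-resolved from disk on every call
    yield f"echo: {txt}"

for msg in demo_plugin("hi", [], []):
    print(msg)
```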
""" """
======================================================================== =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
第二部分 第二部分
其他小工具: 其他小工具:
- write_history_to_file: 将结果写入markdown文件中 - write_history_to_file: 将结果写入markdown文件中
@ -220,7 +233,7 @@ def HotReload(f):
- clip_history: 当历史上下文过长时自动截断 - clip_history: 当历史上下文过长时自动截断
- get_conf: 获取设置 - get_conf: 获取设置
- select_api_key: 根据当前的模型类别抽取可用的api-key - select_api_key: 根据当前的模型类别抽取可用的api-key
======================================================================== =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
""" """
@@ -239,36 +252,40 @@ def get_reduce_token_percent(text):
        assert ratio > 0 and ratio < 1
        return ratio, str(int(current_tokens - max_limit))
    except:
        return 0.5, "不详"


def write_history_to_file(
    history, file_basename=None, file_fullname=None, auto_caption=True
):
    """
    将对话记录history以Markdown格式写入文件中。如果没有指定文件名,则使用当前时间生成文件名。
    """
    import os
    import time

    if file_fullname is None:
        if file_basename is not None:
            file_fullname = pj(get_log_folder(), file_basename)
        else:
            file_fullname = pj(get_log_folder(), f"GPT-Academic-{gen_time_str()}.md")
    os.makedirs(os.path.dirname(file_fullname), exist_ok=True)
    with open(file_fullname, "w", encoding="utf8") as f:
        f.write("# GPT-Academic Report\n")
        for i, content in enumerate(history):
            try:
                if type(content) != str:
                    content = str(content)
            except:
                continue
            if i % 2 == 0 and auto_caption:
                f.write("## ")
            try:
                f.write(content)
            except:
                # remove everything that cannot be handled by utf8
                f.write(content.encode("utf-8", "ignore").decode())
            f.write("\n\n")
    res = os.path.abspath(file_fullname)
    return res
@@ -277,9 +294,9 @@ def regular_txt_to_markdown(text):
    """
    将普通文本转换为Markdown格式的文本
    """
    text = text.replace("\n", "\n\n")
    text = text.replace("\n\n\n", "\n\n")
    text = text.replace("\n\n\n", "\n\n")
    return text
@@ -297,8 +314,9 @@ def find_free_port():
    """
    import socket
    from contextlib import closing

    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        s.bind(("", 0))
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return s.getsockname()[1]
@@ -307,45 +325,48 @@ def extract_archive(file_path, dest_dir):
    import zipfile
    import tarfile
    import os

    # Get the file extension of the input file
    file_extension = os.path.splitext(file_path)[1]

    # Extract the archive based on its extension
    if file_extension == ".zip":
        with zipfile.ZipFile(file_path, "r") as zipobj:
            zipobj.extractall(path=dest_dir)
            print("Successfully extracted zip archive to {}".format(dest_dir))
    elif file_extension in [".tar", ".gz", ".bz2"]:
        with tarfile.open(file_path, "r:*") as tarobj:
            tarobj.extractall(path=dest_dir)
            print("Successfully extracted tar archive to {}".format(dest_dir))
    # 第三方库,需要预先pip install rarfile
    # 此外,Windows上还需要安装winrar软件,配置其Path环境变量,如"C:\Program Files\WinRAR"才可以
    elif file_extension == ".rar":
        try:
            import rarfile

            with rarfile.RarFile(file_path) as rf:
                rf.extractall(path=dest_dir)
            print("Successfully extracted rar archive to {}".format(dest_dir))
        except:
            print("Rar format requires additional dependencies to install")
            return "\n\n解压失败! 需要安装pip install rarfile来解压rar文件。建议使用zip压缩格式。"
    # 第三方库,需要预先pip install py7zr
    elif file_extension == ".7z":
        try:
            import py7zr

            with py7zr.SevenZipFile(file_path, mode="r") as f:
                f.extractall(path=dest_dir)
            print("Successfully extracted 7z archive to {}".format(dest_dir))
        except:
            print("7z format requires additional dependencies to install")
            return "\n\n解压失败! 需要安装pip install py7zr来解压7z文件"
    else:
        return ""
    return ""
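Usage is deliberately forgiving: the return value is an empty string on success (or on an unsupported extension), and a user-facing hint only when an optional dependency is missing. A small sketch with hypothetical paths:

```python
msg = extract_archive(file_path="upload/paper.zip", dest_dir="upload/paper.extract")
if msg:
    print(msg)  # e.g. the "pip install rarfile / py7zr" hint for .rar / .7z inputs
```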
def find_recent_files(directory):
@@ -355,6 +376,7 @@ def find_recent_files(directory):
    """
    import os
    import time

    current_time = time.time()
    one_minute_ago = current_time - 60
    recent_files = []
@@ -362,7 +384,7 @@ def find_recent_files(directory):
        os.makedirs(directory, exist_ok=True)
    for filename in os.listdir(directory):
        file_path = pj(directory, filename)
        if file_path.endswith(".log"):
            continue
        created_time = os.path.getmtime(file_path)
        if created_time >= one_minute_ago:
@@ -388,49 +410,53 @@ def file_already_in_downloadzone(file, user_path):
def promote_file_to_downloadzone(file, rename_file=None, chatbot=None):
    # 将文件复制一份到下载区
    import shutil

    if chatbot is not None:
        user_name = get_user(chatbot)
    else:
        user_name = default_user_name

    if not os.path.exists(file):
        raise FileNotFoundError(f"文件{file}不存在")

    user_path = get_log_folder(user_name, plugin_name=None)
    if file_already_in_downloadzone(file, user_path):
        new_path = file
    else:
        user_path = get_log_folder(user_name, plugin_name="downloadzone")
        if rename_file is None:
            rename_file = f"{gen_time_str()}-{os.path.basename(file)}"
        new_path = pj(user_path, rename_file)
        # 如果已经存在,先删除
        if os.path.exists(new_path) and not os.path.samefile(new_path, file):
            os.remove(new_path)
        # 把文件复制过去
        if not os.path.exists(new_path):
            shutil.copyfile(file, new_path)
    # 将文件添加到chatbot cookie中
    if chatbot is not None:
        if "files_to_promote" in chatbot._cookies:
            current = chatbot._cookies["files_to_promote"]
        else:
            current = []
        if new_path not in current:  # 避免把同一个文件添加多次
            chatbot._cookies.update({"files_to_promote": [new_path] + current})
    return new_path
def disable_auto_promotion(chatbot):
    chatbot._cookies.update({"files_to_promote": []})
    return


def del_outdated_uploads(outdate_time_seconds, target_path_base=None):
    if target_path_base is None:
        user_upload_dir = get_conf("PATH_PRIVATE_UPLOAD")
    else:
        user_upload_dir = target_path_base
    current_time = time.time()
    one_hour_ago = current_time - outdate_time_seconds
    # Get a list of all subdirectories in the user_upload_dir folder
    # Remove subdirectories that are older than one hour
    for subdirectory in glob.glob(f"{user_upload_dir}/*"):
        subdirectory_time = os.path.getmtime(subdirectory)
        if subdirectory_time < one_hour_ago:
            try:
@@ -447,8 +473,8 @@ def html_local_file(file):
    return file


def html_local_img(__file, layout="left", max_width=None, max_height=None, md=True):
    style = ""
    if max_width is not None:
        style += f"max-width: {max_width};"
    if max_height is not None:
@@ -456,20 +482,23 @@ def html_local_img(__file, layout='left', max_width=None, max_height=None, md=Tr
    __file = html_local_file(__file)
    a = f'<div align="{layout}"><img src="{__file}" style="{style}"></div>'
    if md:
        a = f"![{__file}]({__file})"
    return a


def file_manifest_filter_type(file_list, filter_: list = None):
    new_list = []
    if not filter_:
        filter_ = ["png", "jpg", "jpeg"]
    for file in file_list:
        if str(os.path.basename(file)).split(".")[-1] in filter_:
            new_list.append(html_local_img(file, md=False))
        else:
            new_list.append(file)
    return new_list
def to_markdown_tabs(head: list, tabs: list, alignment=":---:", column=False):
    """
    Args:
        head: 表头:[]
@@ -487,17 +516,20 @@ def to_markdown_tabs(head: list, tabs: list, alignment=':---:', column=False):
    max_len = max(len(column) for column in transposed_tabs)

    tab_format = "| %s "
    tabs_list = "".join([tab_format % i for i in head]) + "|\n"
    tabs_list += "".join([tab_format % alignment for i in head]) + "|\n"

    for i in range(max_len):
        row_data = [tab[i] if i < len(tab) else "" for tab in transposed_tabs]
        row_data = file_manifest_filter_type(row_data, filter_=None)
        tabs_list += "".join([tab_format % i for i in row_data]) + "|\n"

    return tabs_list
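A worked example of `to_markdown_tabs` (a sketch that assumes the transposition logic truncated above leaves `tabs` as a list of columns when `column=False`); image extensions would be rewritten into `<img>` tags by `file_manifest_filter_type`, so plain names are used here:

```python
print(to_markdown_tabs(head=["文件", "大小"], tabs=[["a.txt", "b.md"], ["1KB"]]))
# | 文件 | 大小 |
# | :---: | :---: |
# | a.txt | 1KB |
# | b.md |  |
```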
def on_file_uploaded(
    request: gradio.Request, files, chatbot, txt, txt2, checkboxes, cookies
):
    """
    当文件被上传时的回调函数
    """
@@ -515,94 +547,118 @@ def on_file_uploaded(request: gradio.Request, files, chatbot, txt, txt2, checkbo
    del_outdated_uploads(outdate_time_seconds, get_upload_folder(user_name))

    # 逐个文件转移到目标路径
    upload_msg = ""
    for file in files:
        file_origin_name = os.path.basename(file.orig_name)
        this_file_path = pj(target_path_base, file_origin_name)
        shutil.move(file.name, this_file_path)
        upload_msg += extract_archive(
            file_path=this_file_path, dest_dir=this_file_path + ".extract"
        )

    # 整理文件集合 输出消息
    moved_files = [fp for fp in glob.glob(f"{target_path_base}/**/*", recursive=True)]
    moved_files_str = to_markdown_tabs(head=["文件"], tabs=[moved_files])
    chatbot.append(
        [
            "我上传了文件,请查收",
            f"[Local Message] 收到以下文件: \n\n{moved_files_str}"
            + f"\n\n调用路径参数已自动修正到: \n\n{txt}"
            + f"\n\n现在您点击任意函数插件时,以上文件将被作为输入参数"
            + upload_msg,
        ]
    )

    txt, txt2 = target_path_base, ""
    if "浮动输入区" in checkboxes:
        txt, txt2 = txt2, txt

    # 记录近期文件
    cookies.update(
        {
            "most_recent_uploaded": {
                "path": target_path_base,
                "time": time.time(),
                "time_str": time_tag,
            }
        }
    )
    return chatbot, txt, txt2, cookies
def on_report_generated(cookies, files, chatbot):
    # from toolbox import find_recent_files
    # PATH_LOGGING = get_conf('PATH_LOGGING')
    if "files_to_promote" in cookies:
        report_files = cookies["files_to_promote"]
        cookies.pop("files_to_promote")
    else:
        report_files = []
    # report_files = find_recent_files(PATH_LOGGING)
    if len(report_files) == 0:
        return cookies, None, chatbot
    # files.extend(report_files)
    file_links = ""
    for f in report_files:
        file_links += (
            f'<br/><a href="file={os.path.abspath(f)}" target="_blank">{f}</a>'
        )
    chatbot.append(["报告如何远程获取?", f"报告已经添加到右侧“文件上传区”(可能处于折叠状态),请查收。{file_links}"])
    return cookies, report_files, chatbot
def load_chat_cookies():
    API_KEY, LLM_MODEL, AZURE_API_KEY = get_conf(
        "API_KEY", "LLM_MODEL", "AZURE_API_KEY"
    )
    AZURE_CFG_ARRAY, NUM_CUSTOM_BASIC_BTN = get_conf(
        "AZURE_CFG_ARRAY", "NUM_CUSTOM_BASIC_BTN"
    )

    # deal with azure openai key
    if is_any_api_key(AZURE_API_KEY):
        if is_any_api_key(API_KEY):
            API_KEY = API_KEY + "," + AZURE_API_KEY
        else:
            API_KEY = AZURE_API_KEY
    if len(AZURE_CFG_ARRAY) > 0:
        for azure_model_name, azure_cfg_dict in AZURE_CFG_ARRAY.items():
            if not azure_model_name.startswith("azure"):
                raise ValueError("AZURE_CFG_ARRAY中配置的模型必须以azure开头")
            AZURE_API_KEY_ = azure_cfg_dict["AZURE_API_KEY"]
            if is_any_api_key(AZURE_API_KEY_):
                if is_any_api_key(API_KEY):
                    API_KEY = API_KEY + "," + AZURE_API_KEY_
                else:
                    API_KEY = AZURE_API_KEY_

    customize_fn_overwrite_ = {}
    for k in range(NUM_CUSTOM_BASIC_BTN):
        customize_fn_overwrite_.update(
            {
                "自定义按钮"
                + str(k + 1): {
                    "Title": r"",
                    "Prefix": r"请在自定义菜单中定义提示词前缀.",
                    "Suffix": r"请在自定义菜单中定义提示词后缀",
                }
            }
        )
    return {
        "api_key": API_KEY,
        "llm_model": LLM_MODEL,
        "customize_fn_overwrite": customize_fn_overwrite_,
    }
def clear_line_break(txt):
    txt = txt.replace("\n", " ")
    txt = txt.replace("  ", " ")
    txt = txt.replace("  ", " ")
    return txt


class DummyWith:
    """
    这段代码定义了一个名为DummyWith的空上下文管理器,
    它的作用是……就是不起作用,即在代码结构不变的情况下取代其他的上下文管理器。
    """
@@ -626,32 +682,45 @@ def run_gradio_in_subpath(demo, auth, port, custom_path):
    """

    def is_path_legal(path: str) -> bool:
        """
        check path for sub url
        path: path to check
        return value: do sub url wrap
        """
        if path == "/":
            return True
        if len(path) == 0:
            print(
                "ilegal custom path: {}\npath must not be empty\ndeploy on root url".format(
                    path
                )
            )
            return False
        if path[0] == "/":
            if path[1] != "/":
                print("deploy on sub-path {}".format(path))
                return True
            return False
        print(
            "ilegal custom path: {}\npath should begin with '/'\ndeploy on root url".format(
                path
            )
        )
        return False

    if not is_path_legal(custom_path):
        raise RuntimeError("Ilegal custom path")
    import uvicorn
    import gradio as gr
    from fastapi import FastAPI

    app = FastAPI()
    if custom_path != "/":

        @app.get("/")
        def read_main():
            return {"message": f"Gradio is running at: {custom_path}"}

    app = gr.mount_gradio_app(app, demo, path=custom_path)
    uvicorn.run(app, host="0.0.0.0", port=port)  # , auth=auth
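In other words, the demo is wrapped in a FastAPI app and mounted under `custom_path`, with the root URL answering a JSON pointer. A hedged launch sketch:

```python
import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("hello")

# serves the UI at http://localhost:7860/gpt ; "/" returns a JSON hint instead
run_gradio_in_subpath(demo, auth=None, port=7860, custom_path="/gpt")
```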
@@ -667,13 +736,18 @@ def clip_history(inputs, history, tokenizer, max_token_limit):
    """
    import numpy as np
    from request_llms.bridge_all import model_info

    def get_token_num(txt):
        return len(tokenizer.encode(txt, disallowed_special=()))

    input_token_num = get_token_num(inputs)

    if max_token_limit < 5000:
        output_token_expect = 256  # 4k & 2k models
    elif max_token_limit < 9000:
        output_token_expect = 512  # 8k models
    else:
        output_token_expect = 1024  # 16k & 32k models

    if input_token_num < max_token_limit * 3 / 4:
        # 当输入部分的token占比小于限制的3/4时,裁剪时
@@ -690,9 +764,9 @@ def clip_history(inputs, history, tokenizer, max_token_limit):
        history = []
        return history

    everything = [""]
    everything.extend(history)
    n_token = get_token_num("\n".join(everything))
    everything_token = [get_token_num(e) for e in everything]

    # 截断时的颗粒度
@@ -702,29 +776,32 @@ def clip_history(inputs, history, tokenizer, max_token_limit):
        where = np.argmax(everything_token)
        encoded = tokenizer.encode(everything[where], disallowed_special=())
        clipped_encoded = encoded[: len(encoded) - delta]
        everything[where] = tokenizer.decode(clipped_encoded)[
            :-1
        ]  # -1 to remove the may-be illegal char
        everything_token[where] = get_token_num(everything[where])
        n_token = get_token_num("\n".join(everything))

    history = everything[1:]
    return history
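The clipping rule is greedy: after reserving room for the expected output, the single longest history entry (found via `np.argmax`) keeps losing tokens off its tail until the whole context fits. A toy illustration with a character-level stand-in for the real tokenizer (assumes `request_llms` is importable, since `clip_history` imports `model_info` internally):

```python
class CharTokenizer:
    # hypothetical stand-in: one character == one token
    def encode(self, txt, disallowed_special=()):
        return list(txt)
    def decode(self, toks):
        return "".join(toks)

clipped = clip_history(
    inputs="question?",
    history=["short", "a" * 5000, "medium answer"],
    tokenizer=CharTokenizer(),
    max_token_limit=2000,
)
print([len(h) for h in clipped])  # the 5000-char entry is the one truncated
```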
""" """
======================================================================== =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
第三部分 第三部分
其他小工具: 其他小工具:
- zip_folder: 把某个路径下所有文件压缩然后转移到指定的另一个路径中gpt写的 - zip_folder: 把某个路径下所有文件压缩然后转移到指定的另一个路径中gpt写的
- gen_time_str: 生成时间戳 - gen_time_str: 生成时间戳
- ProxyNetworkActivate: 临时地启动代理网络如果有 - ProxyNetworkActivate: 临时地启动代理网络如果有
- objdump/objload: 快捷的调试函数 - objdump/objload: 快捷的调试函数
======================================================================== =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
""" """
def zip_folder(source_folder, dest_folder, zip_name):
    import zipfile
    import os

    # Make sure the source folder exists
    if not os.path.exists(source_folder):
        print(f"{source_folder} does not exist")
@@ -739,7 +816,7 @@ def zip_folder(source_folder, dest_folder, zip_name):
    zip_file = pj(dest_folder, zip_name)

    # Create a ZipFile object
    with zipfile.ZipFile(zip_file, "w", zipfile.ZIP_DEFLATED) as zipf:
        # Walk through the source folder and add files to the zip file
        for foldername, subfolders, filenames in os.walk(source_folder):
            for filename in filenames:
@@ -756,29 +833,33 @@ def zip_folder(source_folder, dest_folder, zip_name):
def zip_result(folder):
    t = gen_time_str()
    zip_folder(folder, get_log_folder(), f"{t}-result.zip")
    return pj(get_log_folder(), f"{t}-result.zip")


def gen_time_str():
    import time

    return time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())


def get_log_folder(user=default_user_name, plugin_name="shared"):
    if user is None:
        user = default_user_name
    PATH_LOGGING = get_conf("PATH_LOGGING")
    if plugin_name is None:
        _dir = pj(PATH_LOGGING, user)
    else:
        _dir = pj(PATH_LOGGING, user, plugin_name)
    if not os.path.exists(_dir):
        os.makedirs(_dir)
    return _dir


def get_upload_folder(user=default_user_name, tag=None):
    PATH_PRIVATE_UPLOAD = get_conf("PATH_PRIVATE_UPLOAD")
    if user is None:
        user = default_user_name
    if tag is None or len(tag) == 0:
        target_path_base = pj(PATH_PRIVATE_UPLOAD, user)
    else:
@@ -787,9 +868,9 @@ def get_upload_folder(user=default_user_name, tag=None):
def is_the_upload_folder(string):
    PATH_PRIVATE_UPLOAD = get_conf("PATH_PRIVATE_UPLOAD")
    pattern = r"^PATH_PRIVATE_UPLOAD[\\/][A-Za-z0-9_-]+[\\/]\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}$"
    pattern = pattern.replace("PATH_PRIVATE_UPLOAD", PATH_PRIVATE_UPLOAD)
    if re.match(pattern, string):
        return True
    else:
@@ -797,10 +878,10 @@ def is_the_upload_folder(string):
def get_user(chatbotwithcookies):
    return chatbotwithcookies._cookies.get("user_name", default_user_name)


class ProxyNetworkActivate:
    """
    这段代码定义了一个名为ProxyNetworkActivate的空上下文管理器, 用于给一小段代码上代理
    """
@@ -813,38 +894,48 @@ class ProxyNetworkActivate():
        else:
            # 给定了task, 我们检查一下
            from toolbox import get_conf

            WHEN_TO_USE_PROXY = get_conf("WHEN_TO_USE_PROXY")
            self.valid = task in WHEN_TO_USE_PROXY

    def __enter__(self):
        if not self.valid:
            return self
        from toolbox import get_conf

        proxies = get_conf("proxies")
        if "no_proxy" in os.environ:
            os.environ.pop("no_proxy")
        if proxies is not None:
            if "http" in proxies:
                os.environ["HTTP_PROXY"] = proxies["http"]
            if "https" in proxies:
                os.environ["HTTPS_PROXY"] = proxies["https"]
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        os.environ["no_proxy"] = "*"
        if "HTTP_PROXY" in os.environ:
            os.environ.pop("HTTP_PROXY")
        if "HTTPS_PROXY" in os.environ:
            os.environ.pop("HTTPS_PROXY")
        return
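Typical use scopes the proxy environment variables to a single network call; the task string must appear in the `WHEN_TO_USE_PROXY` config list for the proxy to activate ("Download_LLM" below is assumed to be such an entry):

```python
import requests
from toolbox import ProxyNetworkActivate

with ProxyNetworkActivate("Download_LLM"):
    r = requests.get("https://example.com", timeout=10)  # proxied if configured
print(r.status_code)  # HTTP(S)_PROXY are popped again on __exit__
```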
def objdump(obj, file="objdump.tmp"):
    import pickle

    with open(file, "wb+") as f:
        pickle.dump(obj, f)
    return


def objload(file="objdump.tmp"):
    import pickle, os

    if not os.path.exists(file):
        return
    with open(file, "rb") as f:
        return pickle.load(f)
@@ -863,22 +954,25 @@ def Singleton(cls):
def get_pictures_list(path):
    file_manifest = [f for f in glob.glob(f"{path}/**/*.jpg", recursive=True)]
    file_manifest += [f for f in glob.glob(f"{path}/**/*.jpeg", recursive=True)]
    file_manifest += [f for f in glob.glob(f"{path}/**/*.png", recursive=True)]
    return file_manifest


def have_any_recent_upload_image_files(chatbot):
    _5min = 5 * 60
    if chatbot is None:
        return False, None  # chatbot is None
    most_recent_uploaded = chatbot._cookies.get("most_recent_uploaded", None)
    if not most_recent_uploaded:
        return False, None  # most_recent_uploaded is None
    if time.time() - most_recent_uploaded["time"] < _5min:
        most_recent_uploaded = chatbot._cookies.get("most_recent_uploaded", None)
        path = most_recent_uploaded["path"]
        file_manifest = get_pictures_list(path)
        if len(file_manifest) == 0:
            return False, None
        return True, file_manifest  # most_recent_uploaded is new
    else:
        return False, None  # most_recent_uploaded is too old
@@ -887,16 +981,19 @@ def have_any_recent_upload_image_files(chatbot):
# Function to encode the image
def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")


def get_max_token(llm_kwargs):
    from request_llms.bridge_all import model_info

    return model_info[llm_kwargs["llm_model"]]["max_token"]


def check_packages(packages=[]):
    import importlib.util

    for p in packages:
        spam_spec = importlib.util.find_spec(p)
        if spam_spec is None:
            raise ModuleNotFoundError
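A hedged usage sketch: calling `check_packages` at the top of a plugin fails fast with a clear install hint instead of a late ImportError deep inside the plugin (package names below are illustrative):

```python
try:
    check_packages(["rarfile", "py7zr"])
except ModuleNotFoundError:
    print("missing dependencies, please run: pip install rarfile py7zr")
```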