format source code
This commit is contained in:
parent
1714116a89
commit
7ab379688e
@ -5,7 +5,7 @@ import glob, os, requests, time
|
|||||||
pj = os.path.join
|
pj = os.path.join
|
||||||
ARXIV_CACHE_DIR = os.path.expanduser(f"~/arxiv_cache/")
|
ARXIV_CACHE_DIR = os.path.expanduser(f"~/arxiv_cache/")
|
||||||
|
|
||||||
# =================================== 工具函数 ===============================================
|
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- 工具函数 =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
|
||||||
# 专业词汇声明 = 'If the term "agent" is used in this section, it should be translated to "智能体". '
|
# 专业词汇声明 = 'If the term "agent" is used in this section, it should be translated to "智能体". '
|
||||||
def switch_prompt(pfg, mode, more_requirement):
|
def switch_prompt(pfg, mode, more_requirement):
|
||||||
"""
|
"""
|
||||||
@ -142,7 +142,7 @@ def arxiv_download(chatbot, history, txt, allow_cache=True):
|
|||||||
from toolbox import extract_archive
|
from toolbox import extract_archive
|
||||||
extract_archive(file_path=dst, dest_dir=extract_dst)
|
extract_archive(file_path=dst, dest_dir=extract_dst)
|
||||||
return extract_dst, arxiv_id
|
return extract_dst, arxiv_id
|
||||||
# ========================================= 插件主程序1 =====================================================
|
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= 插件主程序1 =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
|
||||||
|
|
||||||
|
|
||||||
@CatchException
|
@CatchException
|
||||||
@ -218,7 +218,7 @@ def Latex英文纠错加PDF对比(txt, llm_kwargs, plugin_kwargs, chatbot, histo
|
|||||||
# <-------------- we are done ------------->
|
# <-------------- we are done ------------->
|
||||||
return success
|
return success
|
||||||
|
|
||||||
# ========================================= 插件主程序2 =====================================================
|
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= 插件主程序2 =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
|
||||||
|
|
||||||
@CatchException
|
@CatchException
|
||||||
def Latex翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
|
def Latex翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
|
||||||
|
@ -1,15 +1,18 @@
|
|||||||
import os, shutil
|
import os, shutil
|
||||||
import re
|
import re
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
PRESERVE = 0
|
PRESERVE = 0
|
||||||
TRANSFORM = 1
|
TRANSFORM = 1
|
||||||
|
|
||||||
pj = os.path.join
|
pj = os.path.join
|
||||||
|
|
||||||
class LinkedListNode():
|
|
||||||
|
class LinkedListNode:
|
||||||
"""
|
"""
|
||||||
Linked List Node
|
Linked List Node
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, string, preserve=True) -> None:
|
def __init__(self, string, preserve=True) -> None:
|
||||||
self.string = string
|
self.string = string
|
||||||
self.preserve = preserve
|
self.preserve = preserve
|
||||||
@ -18,12 +21,14 @@ class LinkedListNode():
|
|||||||
# self.begin_line = 0
|
# self.begin_line = 0
|
||||||
# self.begin_char = 0
|
# self.begin_char = 0
|
||||||
|
|
||||||
|
|
||||||
def convert_to_linklist(text, mask):
|
def convert_to_linklist(text, mask):
|
||||||
root = LinkedListNode("", preserve=True)
|
root = LinkedListNode("", preserve=True)
|
||||||
current_node = root
|
current_node = root
|
||||||
for c, m, i in zip(text, mask, range(len(text))):
|
for c, m, i in zip(text, mask, range(len(text))):
|
||||||
if (m==PRESERVE and current_node.preserve) \
|
if (m == PRESERVE and current_node.preserve) or (
|
||||||
or (m==TRANSFORM and not current_node.preserve):
|
m == TRANSFORM and not current_node.preserve
|
||||||
|
):
|
||||||
# add
|
# add
|
||||||
current_node.string += c
|
current_node.string += c
|
||||||
else:
|
else:
|
||||||
@ -31,6 +36,7 @@ def convert_to_linklist(text, mask):
|
|||||||
current_node = current_node.next
|
current_node = current_node.next
|
||||||
return root
|
return root
|
||||||
|
|
||||||
|
|
||||||
def post_process(root):
|
def post_process(root):
|
||||||
# 修复括号
|
# 修复括号
|
||||||
node = root
|
node = root
|
||||||
@ -38,21 +44,24 @@ def post_process(root):
|
|||||||
string = node.string
|
string = node.string
|
||||||
if node.preserve:
|
if node.preserve:
|
||||||
node = node.next
|
node = node.next
|
||||||
if node is None: break
|
if node is None:
|
||||||
|
break
|
||||||
continue
|
continue
|
||||||
|
|
||||||
def break_check(string):
|
def break_check(string):
|
||||||
str_stack = [""] # (lv, index)
|
str_stack = [""] # (lv, index)
|
||||||
for i, c in enumerate(string):
|
for i, c in enumerate(string):
|
||||||
if c == '{':
|
if c == "{":
|
||||||
str_stack.append('{')
|
str_stack.append("{")
|
||||||
elif c == '}':
|
elif c == "}":
|
||||||
if len(str_stack) == 1:
|
if len(str_stack) == 1:
|
||||||
print('stack fix')
|
print("stack fix")
|
||||||
return i
|
return i
|
||||||
str_stack.pop(-1)
|
str_stack.pop(-1)
|
||||||
else:
|
else:
|
||||||
str_stack[-1] += c
|
str_stack[-1] += c
|
||||||
return -1
|
return -1
|
||||||
|
|
||||||
bp = break_check(string)
|
bp = break_check(string)
|
||||||
|
|
||||||
if bp == -1:
|
if bp == -1:
|
||||||
@ -69,51 +78,66 @@ def post_process(root):
|
|||||||
node.next = q
|
node.next = q
|
||||||
|
|
||||||
node = node.next
|
node = node.next
|
||||||
if node is None: break
|
if node is None:
|
||||||
|
break
|
||||||
|
|
||||||
# 屏蔽空行和太短的句子
|
# 屏蔽空行和太短的句子
|
||||||
node = root
|
node = root
|
||||||
while True:
|
while True:
|
||||||
if len(node.string.strip('\n').strip(''))==0: node.preserve = True
|
if len(node.string.strip("\n").strip("")) == 0:
|
||||||
if len(node.string.strip('\n').strip(''))<42: node.preserve = True
|
node.preserve = True
|
||||||
|
if len(node.string.strip("\n").strip("")) < 42:
|
||||||
|
node.preserve = True
|
||||||
node = node.next
|
node = node.next
|
||||||
if node is None: break
|
if node is None:
|
||||||
|
break
|
||||||
node = root
|
node = root
|
||||||
while True:
|
while True:
|
||||||
if node.next and node.preserve and node.next.preserve:
|
if node.next and node.preserve and node.next.preserve:
|
||||||
node.string += node.next.string
|
node.string += node.next.string
|
||||||
node.next = node.next.next
|
node.next = node.next.next
|
||||||
node = node.next
|
node = node.next
|
||||||
if node is None: break
|
if node is None:
|
||||||
|
break
|
||||||
|
|
||||||
# 将前后断行符脱离
|
# 将前后断行符脱离
|
||||||
node = root
|
node = root
|
||||||
prev_node = None
|
prev_node = None
|
||||||
while True:
|
while True:
|
||||||
if not node.preserve:
|
if not node.preserve:
|
||||||
lstriped_ = node.string.lstrip().lstrip('\n')
|
lstriped_ = node.string.lstrip().lstrip("\n")
|
||||||
if (prev_node is not None) and (prev_node.preserve) and (len(lstriped_)!=len(node.string)):
|
if (
|
||||||
|
(prev_node is not None)
|
||||||
|
and (prev_node.preserve)
|
||||||
|
and (len(lstriped_) != len(node.string))
|
||||||
|
):
|
||||||
prev_node.string += node.string[: -len(lstriped_)]
|
prev_node.string += node.string[: -len(lstriped_)]
|
||||||
node.string = lstriped_
|
node.string = lstriped_
|
||||||
rstriped_ = node.string.rstrip().rstrip('\n')
|
rstriped_ = node.string.rstrip().rstrip("\n")
|
||||||
if (node.next is not None) and (node.next.preserve) and (len(rstriped_)!=len(node.string)):
|
if (
|
||||||
|
(node.next is not None)
|
||||||
|
and (node.next.preserve)
|
||||||
|
and (len(rstriped_) != len(node.string))
|
||||||
|
):
|
||||||
node.next.string = node.string[len(rstriped_) :] + node.next.string
|
node.next.string = node.string[len(rstriped_) :] + node.next.string
|
||||||
node.string = rstriped_
|
node.string = rstriped_
|
||||||
# =====
|
# =-=-=
|
||||||
prev_node = node
|
prev_node = node
|
||||||
node = node.next
|
node = node.next
|
||||||
if node is None: break
|
if node is None:
|
||||||
|
break
|
||||||
|
|
||||||
# 标注节点的行数范围
|
# 标注节点的行数范围
|
||||||
node = root
|
node = root
|
||||||
n_line = 0
|
n_line = 0
|
||||||
expansion = 2
|
expansion = 2
|
||||||
while True:
|
while True:
|
||||||
n_l = node.string.count('\n')
|
n_l = node.string.count("\n")
|
||||||
node.range = [n_line - expansion, n_line + n_l + expansion] # 失败时,扭转的范围
|
node.range = [n_line - expansion, n_line + n_l + expansion] # 失败时,扭转的范围
|
||||||
n_line = n_line + n_l
|
n_line = n_line + n_l
|
||||||
node = node.next
|
node = node.next
|
||||||
if node is None: break
|
if node is None:
|
||||||
|
break
|
||||||
return root
|
return root
|
||||||
|
|
||||||
|
|
||||||
@ -131,12 +155,14 @@ def set_forbidden_text(text, mask, pattern, flags=0):
|
|||||||
you can mask out (mask = PRESERVE so that text become untouchable for GPT)
|
you can mask out (mask = PRESERVE so that text become untouchable for GPT)
|
||||||
everything between "\begin{equation}" and "\end{equation}"
|
everything between "\begin{equation}" and "\end{equation}"
|
||||||
"""
|
"""
|
||||||
if isinstance(pattern, list): pattern = '|'.join(pattern)
|
if isinstance(pattern, list):
|
||||||
|
pattern = "|".join(pattern)
|
||||||
pattern_compile = re.compile(pattern, flags)
|
pattern_compile = re.compile(pattern, flags)
|
||||||
for res in pattern_compile.finditer(text):
|
for res in pattern_compile.finditer(text):
|
||||||
mask[res.span()[0] : res.span()[1]] = PRESERVE
|
mask[res.span()[0] : res.span()[1]] = PRESERVE
|
||||||
return text, mask
|
return text, mask
|
||||||
|
|
||||||
|
|
||||||
def reverse_forbidden_text(text, mask, pattern, flags=0, forbid_wrapper=True):
|
def reverse_forbidden_text(text, mask, pattern, flags=0, forbid_wrapper=True):
|
||||||
"""
|
"""
|
||||||
Move area out of preserve area (make text editable for GPT)
|
Move area out of preserve area (make text editable for GPT)
|
||||||
@ -144,7 +170,8 @@ def reverse_forbidden_text(text, mask, pattern, flags=0, forbid_wrapper=True):
|
|||||||
e.g.
|
e.g.
|
||||||
\begin{abstract} blablablablablabla. \end{abstract}
|
\begin{abstract} blablablablablabla. \end{abstract}
|
||||||
"""
|
"""
|
||||||
if isinstance(pattern, list): pattern = '|'.join(pattern)
|
if isinstance(pattern, list):
|
||||||
|
pattern = "|".join(pattern)
|
||||||
pattern_compile = re.compile(pattern, flags)
|
pattern_compile = re.compile(pattern, flags)
|
||||||
for res in pattern_compile.finditer(text):
|
for res in pattern_compile.finditer(text):
|
||||||
if not forbid_wrapper:
|
if not forbid_wrapper:
|
||||||
@ -155,6 +182,7 @@ def reverse_forbidden_text(text, mask, pattern, flags=0, forbid_wrapper=True):
|
|||||||
mask[res.regs[1][1] : res.regs[0][1]] = PRESERVE # abstract
|
mask[res.regs[1][1] : res.regs[0][1]] = PRESERVE # abstract
|
||||||
return text, mask
|
return text, mask
|
||||||
|
|
||||||
|
|
||||||
def set_forbidden_text_careful_brace(text, mask, pattern, flags=0):
|
def set_forbidden_text_careful_brace(text, mask, pattern, flags=0):
|
||||||
"""
|
"""
|
||||||
Add a preserve text area in this paper (text become untouchable for GPT).
|
Add a preserve text area in this paper (text become untouchable for GPT).
|
||||||
@ -167,15 +195,21 @@ def set_forbidden_text_careful_brace(text, mask, pattern, flags=0):
|
|||||||
brace_level = -1
|
brace_level = -1
|
||||||
p = begin = end = res.regs[0][0]
|
p = begin = end = res.regs[0][0]
|
||||||
for _ in range(1024 * 16):
|
for _ in range(1024 * 16):
|
||||||
if text[p] == '}' and brace_level == 0: break
|
if text[p] == "}" and brace_level == 0:
|
||||||
elif text[p] == '}': brace_level -= 1
|
break
|
||||||
elif text[p] == '{': brace_level += 1
|
elif text[p] == "}":
|
||||||
|
brace_level -= 1
|
||||||
|
elif text[p] == "{":
|
||||||
|
brace_level += 1
|
||||||
p += 1
|
p += 1
|
||||||
end = p + 1
|
end = p + 1
|
||||||
mask[begin:end] = PRESERVE
|
mask[begin:end] = PRESERVE
|
||||||
return text, mask
|
return text, mask
|
||||||
|
|
||||||
def reverse_forbidden_text_careful_brace(text, mask, pattern, flags=0, forbid_wrapper=True):
|
|
||||||
|
def reverse_forbidden_text_careful_brace(
|
||||||
|
text, mask, pattern, flags=0, forbid_wrapper=True
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Move area out of preserve area (make text editable for GPT)
|
Move area out of preserve area (make text editable for GPT)
|
||||||
count the number of the braces so as to catch compelete text area.
|
count the number of the braces so as to catch compelete text area.
|
||||||
@ -187,9 +221,12 @@ def reverse_forbidden_text_careful_brace(text, mask, pattern, flags=0, forbid_wr
|
|||||||
brace_level = 0
|
brace_level = 0
|
||||||
p = begin = end = res.regs[1][0]
|
p = begin = end = res.regs[1][0]
|
||||||
for _ in range(1024 * 16):
|
for _ in range(1024 * 16):
|
||||||
if text[p] == '}' and brace_level == 0: break
|
if text[p] == "}" and brace_level == 0:
|
||||||
elif text[p] == '}': brace_level -= 1
|
break
|
||||||
elif text[p] == '{': brace_level += 1
|
elif text[p] == "}":
|
||||||
|
brace_level -= 1
|
||||||
|
elif text[p] == "{":
|
||||||
|
brace_level += 1
|
||||||
p += 1
|
p += 1
|
||||||
end = p
|
end = p
|
||||||
mask[begin:end] = TRANSFORM
|
mask[begin:end] = TRANSFORM
|
||||||
@ -198,27 +235,42 @@ def reverse_forbidden_text_careful_brace(text, mask, pattern, flags=0, forbid_wr
|
|||||||
mask[end : res.regs[0][1]] = PRESERVE
|
mask[end : res.regs[0][1]] = PRESERVE
|
||||||
return text, mask
|
return text, mask
|
||||||
|
|
||||||
|
|
||||||
def set_forbidden_text_begin_end(text, mask, pattern, flags=0, limit_n_lines=42):
|
def set_forbidden_text_begin_end(text, mask, pattern, flags=0, limit_n_lines=42):
|
||||||
"""
|
"""
|
||||||
Find all \begin{} ... \end{} text block that with less than limit_n_lines lines.
|
Find all \begin{} ... \end{} text block that with less than limit_n_lines lines.
|
||||||
Add it to preserve area
|
Add it to preserve area
|
||||||
"""
|
"""
|
||||||
pattern_compile = re.compile(pattern, flags)
|
pattern_compile = re.compile(pattern, flags)
|
||||||
|
|
||||||
def search_with_line_limit(text, mask):
|
def search_with_line_limit(text, mask):
|
||||||
for res in pattern_compile.finditer(text):
|
for res in pattern_compile.finditer(text):
|
||||||
cmd = res.group(1) # begin{what}
|
cmd = res.group(1) # begin{what}
|
||||||
this = res.group(2) # content between begin and end
|
this = res.group(2) # content between begin and end
|
||||||
this_mask = mask[res.regs[2][0] : res.regs[2][1]]
|
this_mask = mask[res.regs[2][0] : res.regs[2][1]]
|
||||||
white_list = ['document', 'abstract', 'lemma', 'definition', 'sproof',
|
white_list = [
|
||||||
'em', 'emph', 'textit', 'textbf', 'itemize', 'enumerate']
|
"document",
|
||||||
if (cmd in white_list) or this.count('\n') >= limit_n_lines: # use a magical number 42
|
"abstract",
|
||||||
|
"lemma",
|
||||||
|
"definition",
|
||||||
|
"sproof",
|
||||||
|
"em",
|
||||||
|
"emph",
|
||||||
|
"textit",
|
||||||
|
"textbf",
|
||||||
|
"itemize",
|
||||||
|
"enumerate",
|
||||||
|
]
|
||||||
|
if (cmd in white_list) or this.count(
|
||||||
|
"\n"
|
||||||
|
) >= limit_n_lines: # use a magical number 42
|
||||||
this, this_mask = search_with_line_limit(this, this_mask)
|
this, this_mask = search_with_line_limit(this, this_mask)
|
||||||
mask[res.regs[2][0] : res.regs[2][1]] = this_mask
|
mask[res.regs[2][0] : res.regs[2][1]] = this_mask
|
||||||
else:
|
else:
|
||||||
mask[res.regs[0][0] : res.regs[0][1]] = PRESERVE
|
mask[res.regs[0][0] : res.regs[0][1]] = PRESERVE
|
||||||
return text, mask
|
return text, mask
|
||||||
return search_with_line_limit(text, mask)
|
|
||||||
|
|
||||||
|
return search_with_line_limit(text, mask)
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -227,6 +279,7 @@ Latex Merge File
|
|||||||
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
|
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def find_main_tex_file(file_manifest, mode):
|
def find_main_tex_file(file_manifest, mode):
|
||||||
"""
|
"""
|
||||||
在多Tex文档中,寻找主文件,必须包含documentclass,返回找到的第一个。
|
在多Tex文档中,寻找主文件,必须包含documentclass,返回找到的第一个。
|
||||||
@ -234,27 +287,36 @@ def find_main_tex_file(file_manifest, mode):
|
|||||||
"""
|
"""
|
||||||
canidates = []
|
canidates = []
|
||||||
for texf in file_manifest:
|
for texf in file_manifest:
|
||||||
if os.path.basename(texf).startswith('merge'):
|
if os.path.basename(texf).startswith("merge"):
|
||||||
continue
|
continue
|
||||||
with open(texf, 'r', encoding='utf8', errors='ignore') as f:
|
with open(texf, "r", encoding="utf8", errors="ignore") as f:
|
||||||
file_content = f.read()
|
file_content = f.read()
|
||||||
if r'\documentclass' in file_content:
|
if r"\documentclass" in file_content:
|
||||||
canidates.append(texf)
|
canidates.append(texf)
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if len(canidates) == 0:
|
if len(canidates) == 0:
|
||||||
raise RuntimeError('无法找到一个主Tex文件(包含documentclass关键字)')
|
raise RuntimeError("无法找到一个主Tex文件(包含documentclass关键字)")
|
||||||
elif len(canidates) == 1:
|
elif len(canidates) == 1:
|
||||||
return canidates[0]
|
return canidates[0]
|
||||||
else: # if len(canidates) >= 2 通过一些Latex模板中常见(但通常不会出现在正文)的单词,对不同latex源文件扣分,取评分最高者返回
|
else: # if len(canidates) >= 2 通过一些Latex模板中常见(但通常不会出现在正文)的单词,对不同latex源文件扣分,取评分最高者返回
|
||||||
canidates_score = []
|
canidates_score = []
|
||||||
# 给出一些判定模板文档的词作为扣分项
|
# 给出一些判定模板文档的词作为扣分项
|
||||||
unexpected_words = ['\\LaTeX', 'manuscript', 'Guidelines', 'font', 'citations', 'rejected', 'blind review', 'reviewers']
|
unexpected_words = [
|
||||||
expected_words = ['\\input', '\\ref', '\\cite']
|
"\\LaTeX",
|
||||||
|
"manuscript",
|
||||||
|
"Guidelines",
|
||||||
|
"font",
|
||||||
|
"citations",
|
||||||
|
"rejected",
|
||||||
|
"blind review",
|
||||||
|
"reviewers",
|
||||||
|
]
|
||||||
|
expected_words = ["\\input", "\\ref", "\\cite"]
|
||||||
for texf in canidates:
|
for texf in canidates:
|
||||||
canidates_score.append(0)
|
canidates_score.append(0)
|
||||||
with open(texf, 'r', encoding='utf8', errors='ignore') as f:
|
with open(texf, "r", encoding="utf8", errors="ignore") as f:
|
||||||
file_content = f.read()
|
file_content = f.read()
|
||||||
file_content = rm_comments(file_content)
|
file_content = rm_comments(file_content)
|
||||||
for uw in unexpected_words:
|
for uw in unexpected_words:
|
||||||
@ -266,6 +328,7 @@ def find_main_tex_file(file_manifest, mode):
|
|||||||
select = np.argmax(canidates_score) # 取评分最高者返回
|
select = np.argmax(canidates_score) # 取评分最高者返回
|
||||||
return canidates[select]
|
return canidates[select]
|
||||||
|
|
||||||
|
|
||||||
def rm_comments(main_file):
|
def rm_comments(main_file):
|
||||||
new_file_remove_comment_lines = []
|
new_file_remove_comment_lines = []
|
||||||
for l in main_file.splitlines():
|
for l in main_file.splitlines():
|
||||||
@ -274,30 +337,39 @@ def rm_comments(main_file):
|
|||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
new_file_remove_comment_lines.append(l)
|
new_file_remove_comment_lines.append(l)
|
||||||
main_file = '\n'.join(new_file_remove_comment_lines)
|
main_file = "\n".join(new_file_remove_comment_lines)
|
||||||
# main_file = re.sub(r"\\include{(.*?)}", r"\\input{\1}", main_file) # 将 \include 命令转换为 \input 命令
|
# main_file = re.sub(r"\\include{(.*?)}", r"\\input{\1}", main_file) # 将 \include 命令转换为 \input 命令
|
||||||
main_file = re.sub(r'(?<!\\)%.*', '', main_file) # 使用正则表达式查找半行注释, 并替换为空字符串
|
main_file = re.sub(r"(?<!\\)%.*", "", main_file) # 使用正则表达式查找半行注释, 并替换为空字符串
|
||||||
return main_file
|
return main_file
|
||||||
|
|
||||||
|
|
||||||
def find_tex_file_ignore_case(fp):
|
def find_tex_file_ignore_case(fp):
|
||||||
dir_name = os.path.dirname(fp)
|
dir_name = os.path.dirname(fp)
|
||||||
base_name = os.path.basename(fp)
|
base_name = os.path.basename(fp)
|
||||||
# 如果输入的文件路径是正确的
|
# 如果输入的文件路径是正确的
|
||||||
if os.path.isfile(pj(dir_name, base_name)): return pj(dir_name, base_name)
|
if os.path.isfile(pj(dir_name, base_name)):
|
||||||
|
return pj(dir_name, base_name)
|
||||||
# 如果不正确,试着加上.tex后缀试试
|
# 如果不正确,试着加上.tex后缀试试
|
||||||
if not base_name.endswith('.tex'): base_name+='.tex'
|
if not base_name.endswith(".tex"):
|
||||||
if os.path.isfile(pj(dir_name, base_name)): return pj(dir_name, base_name)
|
base_name += ".tex"
|
||||||
|
if os.path.isfile(pj(dir_name, base_name)):
|
||||||
|
return pj(dir_name, base_name)
|
||||||
# 如果还找不到,解除大小写限制,再试一次
|
# 如果还找不到,解除大小写限制,再试一次
|
||||||
import glob
|
import glob
|
||||||
for f in glob.glob(dir_name+'/*.tex'):
|
|
||||||
|
for f in glob.glob(dir_name + "/*.tex"):
|
||||||
base_name_s = os.path.basename(fp)
|
base_name_s = os.path.basename(fp)
|
||||||
base_name_f = os.path.basename(f)
|
base_name_f = os.path.basename(f)
|
||||||
if base_name_s.lower() == base_name_f.lower(): return f
|
if base_name_s.lower() == base_name_f.lower():
|
||||||
|
return f
|
||||||
# 试着加上.tex后缀试试
|
# 试着加上.tex后缀试试
|
||||||
if not base_name_s.endswith('.tex'): base_name_s+='.tex'
|
if not base_name_s.endswith(".tex"):
|
||||||
if base_name_s.lower() == base_name_f.lower(): return f
|
base_name_s += ".tex"
|
||||||
|
if base_name_s.lower() == base_name_f.lower():
|
||||||
|
return f
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def merge_tex_files_(project_foler, main_file, mode):
|
def merge_tex_files_(project_foler, main_file, mode):
|
||||||
"""
|
"""
|
||||||
Merge Tex project recrusively
|
Merge Tex project recrusively
|
||||||
@ -309,18 +381,18 @@ def merge_tex_files_(project_foler, main_file, mode):
|
|||||||
fp_ = find_tex_file_ignore_case(fp)
|
fp_ = find_tex_file_ignore_case(fp)
|
||||||
if fp_:
|
if fp_:
|
||||||
try:
|
try:
|
||||||
with open(fp_, 'r', encoding='utf-8', errors='replace') as fx: c = fx.read()
|
with open(fp_, "r", encoding="utf-8", errors="replace") as fx:
|
||||||
|
c = fx.read()
|
||||||
except:
|
except:
|
||||||
c = f"\n\nWarning from GPT-Academic: LaTex source file is missing!\n\n"
|
c = f"\n\nWarning from GPT-Academic: LaTex source file is missing!\n\n"
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f'找不到{fp},Tex源文件缺失!')
|
raise RuntimeError(f"找不到{fp},Tex源文件缺失!")
|
||||||
c = merge_tex_files_(project_foler, c, mode)
|
c = merge_tex_files_(project_foler, c, mode)
|
||||||
main_file = main_file[: s.span()[0]] + c + main_file[s.span()[1] :]
|
main_file = main_file[: s.span()[0]] + c + main_file[s.span()[1] :]
|
||||||
return main_file
|
return main_file
|
||||||
|
|
||||||
|
|
||||||
def find_title_and_abs(main_file):
|
def find_title_and_abs(main_file):
|
||||||
|
|
||||||
def extract_abstract_1(text):
|
def extract_abstract_1(text):
|
||||||
pattern = r"\\abstract\{(.*?)\}"
|
pattern = r"\\abstract\{(.*?)\}"
|
||||||
match = re.search(pattern, text, re.DOTALL)
|
match = re.search(pattern, text, re.DOTALL)
|
||||||
@ -362,21 +434,30 @@ def merge_tex_files(project_foler, main_file, mode):
|
|||||||
main_file = merge_tex_files_(project_foler, main_file, mode)
|
main_file = merge_tex_files_(project_foler, main_file, mode)
|
||||||
main_file = rm_comments(main_file)
|
main_file = rm_comments(main_file)
|
||||||
|
|
||||||
if mode == 'translate_zh':
|
if mode == "translate_zh":
|
||||||
# find paper documentclass
|
# find paper documentclass
|
||||||
pattern = re.compile(r'\\documentclass.*\n')
|
pattern = re.compile(r"\\documentclass.*\n")
|
||||||
match = pattern.search(main_file)
|
match = pattern.search(main_file)
|
||||||
assert match is not None, "Cannot find documentclass statement!"
|
assert match is not None, "Cannot find documentclass statement!"
|
||||||
position = match.end()
|
position = match.end()
|
||||||
add_ctex = '\\usepackage{ctex}\n'
|
add_ctex = "\\usepackage{ctex}\n"
|
||||||
add_url = '\\usepackage{url}\n' if '{url}' not in main_file else ''
|
add_url = "\\usepackage{url}\n" if "{url}" not in main_file else ""
|
||||||
main_file = main_file[:position] + add_ctex + add_url + main_file[position:]
|
main_file = main_file[:position] + add_ctex + add_url + main_file[position:]
|
||||||
# fontset=windows
|
# fontset=windows
|
||||||
import platform
|
import platform
|
||||||
main_file = re.sub(r"\\documentclass\[(.*?)\]{(.*?)}", r"\\documentclass[\1,fontset=windows,UTF8]{\2}",main_file)
|
|
||||||
main_file = re.sub(r"\\documentclass{(.*?)}", r"\\documentclass[fontset=windows,UTF8]{\1}",main_file)
|
main_file = re.sub(
|
||||||
|
r"\\documentclass\[(.*?)\]{(.*?)}",
|
||||||
|
r"\\documentclass[\1,fontset=windows,UTF8]{\2}",
|
||||||
|
main_file,
|
||||||
|
)
|
||||||
|
main_file = re.sub(
|
||||||
|
r"\\documentclass{(.*?)}",
|
||||||
|
r"\\documentclass[fontset=windows,UTF8]{\1}",
|
||||||
|
main_file,
|
||||||
|
)
|
||||||
# find paper abstract
|
# find paper abstract
|
||||||
pattern_opt1 = re.compile(r'\\begin\{abstract\}.*\n')
|
pattern_opt1 = re.compile(r"\\begin\{abstract\}.*\n")
|
||||||
pattern_opt2 = re.compile(r"\\abstract\{(.*?)\}", flags=re.DOTALL)
|
pattern_opt2 = re.compile(r"\\abstract\{(.*?)\}", flags=re.DOTALL)
|
||||||
match_opt1 = pattern_opt1.search(main_file)
|
match_opt1 = pattern_opt1.search(main_file)
|
||||||
match_opt2 = pattern_opt2.search(main_file)
|
match_opt2 = pattern_opt2.search(main_file)
|
||||||
@ -385,7 +466,9 @@ def merge_tex_files(project_foler, main_file, mode):
|
|||||||
main_file = insert_abstract(main_file)
|
main_file = insert_abstract(main_file)
|
||||||
match_opt1 = pattern_opt1.search(main_file)
|
match_opt1 = pattern_opt1.search(main_file)
|
||||||
match_opt2 = pattern_opt2.search(main_file)
|
match_opt2 = pattern_opt2.search(main_file)
|
||||||
assert (match_opt1 is not None) or (match_opt2 is not None), "Cannot find paper abstract section!"
|
assert (match_opt1 is not None) or (
|
||||||
|
match_opt2 is not None
|
||||||
|
), "Cannot find paper abstract section!"
|
||||||
return main_file
|
return main_file
|
||||||
|
|
||||||
|
|
||||||
@ -395,6 +478,7 @@ The GPT-Academic program cannot find abstract section in this paper.
|
|||||||
\end{abstract}
|
\end{abstract}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def insert_abstract(tex_content):
|
def insert_abstract(tex_content):
|
||||||
if "\\maketitle" in tex_content:
|
if "\\maketitle" in tex_content:
|
||||||
# find the position of "\maketitle"
|
# find the position of "\maketitle"
|
||||||
@ -402,7 +486,13 @@ def insert_abstract(tex_content):
|
|||||||
# find the nearest ending line
|
# find the nearest ending line
|
||||||
end_line_index = tex_content.find("\n", find_index)
|
end_line_index = tex_content.find("\n", find_index)
|
||||||
# insert "abs_str" on the next line
|
# insert "abs_str" on the next line
|
||||||
modified_tex = tex_content[:end_line_index+1] + '\n\n' + insert_missing_abs_str + '\n\n' + tex_content[end_line_index+1:]
|
modified_tex = (
|
||||||
|
tex_content[: end_line_index + 1]
|
||||||
|
+ "\n\n"
|
||||||
|
+ insert_missing_abs_str
|
||||||
|
+ "\n\n"
|
||||||
|
+ tex_content[end_line_index + 1 :]
|
||||||
|
)
|
||||||
return modified_tex
|
return modified_tex
|
||||||
elif r"\begin{document}" in tex_content:
|
elif r"\begin{document}" in tex_content:
|
||||||
# find the position of "\maketitle"
|
# find the position of "\maketitle"
|
||||||
@ -410,16 +500,25 @@ def insert_abstract(tex_content):
|
|||||||
# find the nearest ending line
|
# find the nearest ending line
|
||||||
end_line_index = tex_content.find("\n", find_index)
|
end_line_index = tex_content.find("\n", find_index)
|
||||||
# insert "abs_str" on the next line
|
# insert "abs_str" on the next line
|
||||||
modified_tex = tex_content[:end_line_index+1] + '\n\n' + insert_missing_abs_str + '\n\n' + tex_content[end_line_index+1:]
|
modified_tex = (
|
||||||
|
tex_content[: end_line_index + 1]
|
||||||
|
+ "\n\n"
|
||||||
|
+ insert_missing_abs_str
|
||||||
|
+ "\n\n"
|
||||||
|
+ tex_content[end_line_index + 1 :]
|
||||||
|
)
|
||||||
return modified_tex
|
return modified_tex
|
||||||
else:
|
else:
|
||||||
return tex_content
|
return tex_content
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
|
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
|
||||||
Post process
|
Post process
|
||||||
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
|
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def mod_inbraket(match):
|
def mod_inbraket(match):
|
||||||
"""
|
"""
|
||||||
为啥chatgpt会把cite里面的逗号换成中文逗号呀
|
为啥chatgpt会把cite里面的逗号换成中文逗号呀
|
||||||
@ -428,11 +527,12 @@ def mod_inbraket(match):
|
|||||||
cmd = match.group(1)
|
cmd = match.group(1)
|
||||||
str_to_modify = match.group(2)
|
str_to_modify = match.group(2)
|
||||||
# modify the matched string
|
# modify the matched string
|
||||||
str_to_modify = str_to_modify.replace(':', ':') # 前面是中文冒号,后面是英文冒号
|
str_to_modify = str_to_modify.replace(":", ":") # 前面是中文冒号,后面是英文冒号
|
||||||
str_to_modify = str_to_modify.replace(',', ',') # 前面是中文逗号,后面是英文逗号
|
str_to_modify = str_to_modify.replace(",", ",") # 前面是中文逗号,后面是英文逗号
|
||||||
# str_to_modify = 'BOOM'
|
# str_to_modify = 'BOOM'
|
||||||
return "\\" + cmd + "{" + str_to_modify + "}"
|
return "\\" + cmd + "{" + str_to_modify + "}"
|
||||||
|
|
||||||
|
|
||||||
def fix_content(final_tex, node_string):
|
def fix_content(final_tex, node_string):
|
||||||
"""
|
"""
|
||||||
Fix common GPT errors to increase success rate
|
Fix common GPT errors to increase success rate
|
||||||
@ -444,9 +544,9 @@ def fix_content(final_tex, node_string):
|
|||||||
|
|
||||||
if "Traceback" in final_tex and "[Local Message]" in final_tex:
|
if "Traceback" in final_tex and "[Local Message]" in final_tex:
|
||||||
final_tex = node_string # 出问题了,还原原文
|
final_tex = node_string # 出问题了,还原原文
|
||||||
if node_string.count('\\begin') != final_tex.count('\\begin'):
|
if node_string.count("\\begin") != final_tex.count("\\begin"):
|
||||||
final_tex = node_string # 出问题了,还原原文
|
final_tex = node_string # 出问题了,还原原文
|
||||||
if node_string.count('\_') > 0 and node_string.count('\_') > final_tex.count('\_'):
|
if node_string.count("\_") > 0 and node_string.count("\_") > final_tex.count("\_"):
|
||||||
# walk and replace any _ without \
|
# walk and replace any _ without \
|
||||||
final_tex = re.sub(r"(?<!\\)_", "\\_", final_tex)
|
final_tex = re.sub(r"(?<!\\)_", "\\_", final_tex)
|
||||||
|
|
||||||
@ -454,24 +554,32 @@ def fix_content(final_tex, node_string):
|
|||||||
# this function count the number of { and }
|
# this function count the number of { and }
|
||||||
brace_level = 0
|
brace_level = 0
|
||||||
for c in string:
|
for c in string:
|
||||||
if c == "{": brace_level += 1
|
if c == "{":
|
||||||
elif c == "}": brace_level -= 1
|
brace_level += 1
|
||||||
|
elif c == "}":
|
||||||
|
brace_level -= 1
|
||||||
return brace_level
|
return brace_level
|
||||||
|
|
||||||
def join_most(tex_t, tex_o):
|
def join_most(tex_t, tex_o):
|
||||||
# this function join translated string and original string when something goes wrong
|
# this function join translated string and original string when something goes wrong
|
||||||
p_t = 0
|
p_t = 0
|
||||||
p_o = 0
|
p_o = 0
|
||||||
|
|
||||||
def find_next(string, chars, begin):
|
def find_next(string, chars, begin):
|
||||||
p = begin
|
p = begin
|
||||||
while p < len(string):
|
while p < len(string):
|
||||||
if string[p] in chars: return p, string[p]
|
if string[p] in chars:
|
||||||
|
return p, string[p]
|
||||||
p += 1
|
p += 1
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
res1, char = find_next(tex_o, ['{','}'], p_o)
|
res1, char = find_next(tex_o, ["{", "}"], p_o)
|
||||||
if res1 is None: break
|
if res1 is None:
|
||||||
|
break
|
||||||
res2, char = find_next(tex_t, [char], p_t)
|
res2, char = find_next(tex_t, [char], p_t)
|
||||||
if res2 is None: break
|
if res2 is None:
|
||||||
|
break
|
||||||
p_o = res1 + 1
|
p_o = res1 + 1
|
||||||
p_t = res2 + 1
|
p_t = res2 + 1
|
||||||
return tex_t[:p_t] + tex_o[p_o:]
|
return tex_t[:p_t] + tex_o[p_o:]
|
||||||
@ -481,9 +589,13 @@ def fix_content(final_tex, node_string):
|
|||||||
final_tex = join_most(final_tex, node_string)
|
final_tex = join_most(final_tex, node_string)
|
||||||
return final_tex
|
return final_tex
|
||||||
|
|
||||||
|
|
||||||
def compile_latex_with_timeout(command, cwd, timeout=60):
|
def compile_latex_with_timeout(command, cwd, timeout=60):
|
||||||
import subprocess
|
import subprocess
|
||||||
process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=cwd)
|
|
||||||
|
process = subprocess.Popen(
|
||||||
|
command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=cwd
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
stdout, stderr = process.communicate(timeout=timeout)
|
stdout, stderr = process.communicate(timeout=timeout)
|
||||||
except subprocess.TimeoutExpired:
|
except subprocess.TimeoutExpired:
|
||||||
@ -493,43 +605,52 @@ def compile_latex_with_timeout(command, cwd, timeout=60):
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def run_in_subprocess_wrapper_func(func, args, kwargs, return_dict, exception_dict):
|
def run_in_subprocess_wrapper_func(func, args, kwargs, return_dict, exception_dict):
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = func(*args, **kwargs)
|
result = func(*args, **kwargs)
|
||||||
return_dict['result'] = result
|
return_dict["result"] = result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
exc_info = sys.exc_info()
|
exc_info = sys.exc_info()
|
||||||
exception_dict['exception'] = exc_info
|
exception_dict["exception"] = exc_info
|
||||||
|
|
||||||
|
|
||||||
def run_in_subprocess(func):
|
def run_in_subprocess(func):
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
return_dict = multiprocessing.Manager().dict()
|
return_dict = multiprocessing.Manager().dict()
|
||||||
exception_dict = multiprocessing.Manager().dict()
|
exception_dict = multiprocessing.Manager().dict()
|
||||||
process = multiprocessing.Process(target=run_in_subprocess_wrapper_func,
|
process = multiprocessing.Process(
|
||||||
args=(func, args, kwargs, return_dict, exception_dict))
|
target=run_in_subprocess_wrapper_func,
|
||||||
|
args=(func, args, kwargs, return_dict, exception_dict),
|
||||||
|
)
|
||||||
process.start()
|
process.start()
|
||||||
process.join()
|
process.join()
|
||||||
process.close()
|
process.close()
|
||||||
if 'exception' in exception_dict:
|
if "exception" in exception_dict:
|
||||||
# ooops, the subprocess ran into an exception
|
# ooops, the subprocess ran into an exception
|
||||||
exc_info = exception_dict['exception']
|
exc_info = exception_dict["exception"]
|
||||||
raise exc_info[1].with_traceback(exc_info[2])
|
raise exc_info[1].with_traceback(exc_info[2])
|
||||||
if 'result' in return_dict.keys():
|
if "result" in return_dict.keys():
|
||||||
# If the subprocess ran successfully, return the result
|
# If the subprocess ran successfully, return the result
|
||||||
return return_dict['result']
|
return return_dict["result"]
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def _merge_pdfs(pdf1_path, pdf2_path, output_path):
|
def _merge_pdfs(pdf1_path, pdf2_path, output_path):
|
||||||
import PyPDF2 # PyPDF2这个库有严重的内存泄露问题,把它放到子进程中运行,从而方便内存的释放
|
import PyPDF2 # PyPDF2这个库有严重的内存泄露问题,把它放到子进程中运行,从而方便内存的释放
|
||||||
|
|
||||||
Percent = 0.95
|
Percent = 0.95
|
||||||
# raise RuntimeError('PyPDF2 has a serious memory leak problem, please use other tools to merge PDF files.')
|
# raise RuntimeError('PyPDF2 has a serious memory leak problem, please use other tools to merge PDF files.')
|
||||||
# Open the first PDF file
|
# Open the first PDF file
|
||||||
with open(pdf1_path, 'rb') as pdf1_file:
|
with open(pdf1_path, "rb") as pdf1_file:
|
||||||
pdf1_reader = PyPDF2.PdfFileReader(pdf1_file)
|
pdf1_reader = PyPDF2.PdfFileReader(pdf1_file)
|
||||||
# Open the second PDF file
|
# Open the second PDF file
|
||||||
with open(pdf2_path, 'rb') as pdf2_file:
|
with open(pdf2_path, "rb") as pdf2_file:
|
||||||
pdf2_reader = PyPDF2.PdfFileReader(pdf2_file)
|
pdf2_reader = PyPDF2.PdfFileReader(pdf2_file)
|
||||||
# Create a new PDF file to store the merged pages
|
# Create a new PDF file to store the merged pages
|
||||||
output_writer = PyPDF2.PdfFileWriter()
|
output_writer = PyPDF2.PdfFileWriter()
|
||||||
@ -549,14 +670,25 @@ def _merge_pdfs(pdf1_path, pdf2_path, output_path):
|
|||||||
page2 = PyPDF2.PageObject.createBlankPage(pdf1_reader)
|
page2 = PyPDF2.PageObject.createBlankPage(pdf1_reader)
|
||||||
# Create a new empty page with double width
|
# Create a new empty page with double width
|
||||||
new_page = PyPDF2.PageObject.createBlankPage(
|
new_page = PyPDF2.PageObject.createBlankPage(
|
||||||
width = int(int(page1.mediaBox.getWidth()) + int(page2.mediaBox.getWidth()) * Percent),
|
width=int(
|
||||||
height = max(page1.mediaBox.getHeight(), page2.mediaBox.getHeight())
|
int(page1.mediaBox.getWidth())
|
||||||
|
+ int(page2.mediaBox.getWidth()) * Percent
|
||||||
|
),
|
||||||
|
height=max(page1.mediaBox.getHeight(), page2.mediaBox.getHeight()),
|
||||||
)
|
)
|
||||||
new_page.mergeTranslatedPage(page1, 0, 0)
|
new_page.mergeTranslatedPage(page1, 0, 0)
|
||||||
new_page.mergeTranslatedPage(page2, int(int(page1.mediaBox.getWidth())-int(page2.mediaBox.getWidth())* (1-Percent)), 0)
|
new_page.mergeTranslatedPage(
|
||||||
|
page2,
|
||||||
|
int(
|
||||||
|
int(page1.mediaBox.getWidth())
|
||||||
|
- int(page2.mediaBox.getWidth()) * (1 - Percent)
|
||||||
|
),
|
||||||
|
0,
|
||||||
|
)
|
||||||
output_writer.addPage(new_page)
|
output_writer.addPage(new_page)
|
||||||
# Save the merged PDF file
|
# Save the merged PDF file
|
||||||
with open(output_path, 'wb') as output_file:
|
with open(output_path, "wb") as output_file:
|
||||||
output_writer.write(output_file)
|
output_writer.write(output_file)
|
||||||
|
|
||||||
|
|
||||||
merge_pdfs = run_in_subprocess(_merge_pdfs) # PyPDF2这个库有严重的内存泄露问题,把它放到子进程中运行,从而方便内存的释放
|
merge_pdfs = run_in_subprocess(_merge_pdfs) # PyPDF2这个库有严重的内存泄露问题,把它放到子进程中运行,从而方便内存的释放
|
||||||
|
@ -352,9 +352,9 @@ def step_1_core_key_translate():
|
|||||||
chinese_core_keys_norepeat_mapping.update({k:cached_translation[k]})
|
chinese_core_keys_norepeat_mapping.update({k:cached_translation[k]})
|
||||||
chinese_core_keys_norepeat_mapping = dict(sorted(chinese_core_keys_norepeat_mapping.items(), key=lambda x: -len(x[0])))
|
chinese_core_keys_norepeat_mapping = dict(sorted(chinese_core_keys_norepeat_mapping.items(), key=lambda x: -len(x[0])))
|
||||||
|
|
||||||
# ===============================================
|
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
|
||||||
# copy
|
# copy
|
||||||
# ===============================================
|
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
|
||||||
def copy_source_code():
|
def copy_source_code():
|
||||||
|
|
||||||
from toolbox import get_conf
|
from toolbox import get_conf
|
||||||
@ -367,9 +367,9 @@ def step_1_core_key_translate():
|
|||||||
shutil.copytree('./', backup_dir, ignore=lambda x, y: blacklist)
|
shutil.copytree('./', backup_dir, ignore=lambda x, y: blacklist)
|
||||||
copy_source_code()
|
copy_source_code()
|
||||||
|
|
||||||
# ===============================================
|
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
|
||||||
# primary key replace
|
# primary key replace
|
||||||
# ===============================================
|
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
|
||||||
directory_path = f'./multi-language/{LANG}/'
|
directory_path = f'./multi-language/{LANG}/'
|
||||||
for root, dirs, files in os.walk(directory_path):
|
for root, dirs, files in os.walk(directory_path):
|
||||||
for file in files:
|
for file in files:
|
||||||
@ -389,9 +389,9 @@ def step_1_core_key_translate():
|
|||||||
|
|
||||||
def step_2_core_key_translate():
|
def step_2_core_key_translate():
|
||||||
|
|
||||||
# =================================================================================================
|
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
|
||||||
# step2
|
# step2
|
||||||
# =================================================================================================
|
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
|
||||||
|
|
||||||
def load_string(strings, string_input):
|
def load_string(strings, string_input):
|
||||||
string_ = string_input.strip().strip(',').strip().strip('.').strip()
|
string_ = string_input.strip().strip(',').strip().strip('.').strip()
|
||||||
@ -492,9 +492,9 @@ def step_2_core_key_translate():
|
|||||||
cached_translation.update(read_map_from_json(language=LANG_STD))
|
cached_translation.update(read_map_from_json(language=LANG_STD))
|
||||||
cached_translation = dict(sorted(cached_translation.items(), key=lambda x: -len(x[0])))
|
cached_translation = dict(sorted(cached_translation.items(), key=lambda x: -len(x[0])))
|
||||||
|
|
||||||
# ===============================================
|
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
|
||||||
# literal key replace
|
# literal key replace
|
||||||
# ===============================================
|
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
|
||||||
directory_path = f'./multi-language/{LANG}/'
|
directory_path = f'./multi-language/{LANG}/'
|
||||||
for root, dirs, files in os.walk(directory_path):
|
for root, dirs, files in os.walk(directory_path):
|
||||||
for file in files:
|
for file in files:
|
||||||
|
@ -244,7 +244,7 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
|
|||||||
if has_choices and not choice_valid:
|
if has_choices and not choice_valid:
|
||||||
# 一些垃圾第三方接口的出现这样的错误
|
# 一些垃圾第三方接口的出现这样的错误
|
||||||
continue
|
continue
|
||||||
if len(chunk_decoded) > 0 and (chunkjson is None):
|
if ('data: [DONE]' not in chunk_decoded) and len(chunk_decoded) > 0 and (chunkjson is None):
|
||||||
# 传递进来一些奇怪的东西
|
# 传递进来一些奇怪的东西
|
||||||
raise ValueError(f'无法读取以下数据,请检查配置。\n\n{chunk_decoded}')
|
raise ValueError(f'无法读取以下数据,请检查配置。\n\n{chunk_decoded}')
|
||||||
# 前者是API2D的结束条件,后者是OPENAI的结束条件
|
# 前者是API2D的结束条件,后者是OPENAI的结束条件
|
||||||
|
@ -1,16 +1,17 @@
|
|||||||
"""
|
"""
|
||||||
========================================================================
|
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
|
||||||
第一部分:来自EdgeGPT.py
|
第一部分:来自EdgeGPT.py
|
||||||
https://github.com/acheong08/EdgeGPT
|
https://github.com/acheong08/EdgeGPT
|
||||||
========================================================================
|
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
|
||||||
"""
|
"""
|
||||||
from .edge_gpt_free import Chatbot as NewbingChatbot
|
from .edge_gpt_free import Chatbot as NewbingChatbot
|
||||||
|
|
||||||
load_message = "等待NewBing响应。"
|
load_message = "等待NewBing响应。"
|
||||||
|
|
||||||
"""
|
"""
|
||||||
========================================================================
|
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
|
||||||
第二部分:子进程Worker(调用主体)
|
第二部分:子进程Worker(调用主体)
|
||||||
========================================================================
|
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
|
||||||
"""
|
"""
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
@ -22,19 +23,30 @@ import threading
|
|||||||
from toolbox import update_ui, get_conf, trimmed_format_exc
|
from toolbox import update_ui, get_conf, trimmed_format_exc
|
||||||
from multiprocessing import Process, Pipe
|
from multiprocessing import Process, Pipe
|
||||||
|
|
||||||
|
|
||||||
def preprocess_newbing_out(s):
|
def preprocess_newbing_out(s):
|
||||||
pattern = r'\^(\d+)\^' # 匹配^数字^
|
pattern = r"\^(\d+)\^" # 匹配^数字^
|
||||||
sub = lambda m: '('+m.group(1)+')' # 将匹配到的数字作为替换值
|
sub = lambda m: "(" + m.group(1) + ")" # 将匹配到的数字作为替换值
|
||||||
result = re.sub(pattern, sub, s) # 替换操作
|
result = re.sub(pattern, sub, s) # 替换操作
|
||||||
if '[1]' in result:
|
if "[1]" in result:
|
||||||
result += '\n\n```reference\n' + "\n".join([r for r in result.split('\n') if r.startswith('[')]) + '\n```\n'
|
result += (
|
||||||
|
"\n\n```reference\n"
|
||||||
|
+ "\n".join([r for r in result.split("\n") if r.startswith("[")])
|
||||||
|
+ "\n```\n"
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def preprocess_newbing_out_simple(result):
|
def preprocess_newbing_out_simple(result):
|
||||||
if '[1]' in result:
|
if "[1]" in result:
|
||||||
result += '\n\n```reference\n' + "\n".join([r for r in result.split('\n') if r.startswith('[')]) + '\n```\n'
|
result += (
|
||||||
|
"\n\n```reference\n"
|
||||||
|
+ "\n".join([r for r in result.split("\n") if r.startswith("[")])
|
||||||
|
+ "\n```\n"
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
class NewBingHandle(Process):
|
class NewBingHandle(Process):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(daemon=True)
|
super().__init__(daemon=True)
|
||||||
@ -51,6 +63,7 @@ class NewBingHandle(Process):
|
|||||||
try:
|
try:
|
||||||
self.success = False
|
self.success = False
|
||||||
import certifi, httpx, rich
|
import certifi, httpx, rich
|
||||||
|
|
||||||
self.info = "依赖检测通过,等待NewBing响应。注意目前不能多人同时调用NewBing接口(有线程锁),否则将导致每个人的NewBing问询历史互相渗透。调用NewBing时,会自动使用已配置的代理。"
|
self.info = "依赖检测通过,等待NewBing响应。注意目前不能多人同时调用NewBing接口(有线程锁),否则将导致每个人的NewBing问询历史互相渗透。调用NewBing时,会自动使用已配置的代理。"
|
||||||
self.success = True
|
self.success = True
|
||||||
except:
|
except:
|
||||||
@ -62,15 +75,16 @@ class NewBingHandle(Process):
|
|||||||
|
|
||||||
async def async_run(self):
|
async def async_run(self):
|
||||||
# 读取配置
|
# 读取配置
|
||||||
NEWBING_STYLE = get_conf('NEWBING_STYLE')
|
NEWBING_STYLE = get_conf("NEWBING_STYLE")
|
||||||
from request_llms.bridge_all import model_info
|
from request_llms.bridge_all import model_info
|
||||||
endpoint = model_info['newbing']['endpoint']
|
|
||||||
|
endpoint = model_info["newbing"]["endpoint"]
|
||||||
while True:
|
while True:
|
||||||
# 等待
|
# 等待
|
||||||
kwargs = self.child.recv()
|
kwargs = self.child.recv()
|
||||||
question=kwargs['query']
|
question = kwargs["query"]
|
||||||
history=kwargs['history']
|
history = kwargs["history"]
|
||||||
system_prompt=kwargs['system_prompt']
|
system_prompt = kwargs["system_prompt"]
|
||||||
|
|
||||||
# 是否重置
|
# 是否重置
|
||||||
if len(self.local_history) > 0 and len(history) == 0:
|
if len(self.local_history) > 0 and len(history) == 0:
|
||||||
@ -81,19 +95,19 @@ class NewBingHandle(Process):
|
|||||||
prompt = ""
|
prompt = ""
|
||||||
if system_prompt not in self.local_history:
|
if system_prompt not in self.local_history:
|
||||||
self.local_history.append(system_prompt)
|
self.local_history.append(system_prompt)
|
||||||
prompt += system_prompt + '\n'
|
prompt += system_prompt + "\n"
|
||||||
|
|
||||||
# 追加历史
|
# 追加历史
|
||||||
for ab in history:
|
for ab in history:
|
||||||
a, b = ab
|
a, b = ab
|
||||||
if a not in self.local_history:
|
if a not in self.local_history:
|
||||||
self.local_history.append(a)
|
self.local_history.append(a)
|
||||||
prompt += a + '\n'
|
prompt += a + "\n"
|
||||||
|
|
||||||
# 问题
|
# 问题
|
||||||
prompt += question
|
prompt += question
|
||||||
self.local_history.append(question)
|
self.local_history.append(question)
|
||||||
print('question:', prompt)
|
print("question:", prompt)
|
||||||
# 提交
|
# 提交
|
||||||
async for final, response in self.newbing_model.ask_stream(
|
async for final, response in self.newbing_model.ask_stream(
|
||||||
prompt=question,
|
prompt=question,
|
||||||
@ -104,11 +118,10 @@ class NewBingHandle(Process):
|
|||||||
print(response)
|
print(response)
|
||||||
self.child.send(str(response))
|
self.child.send(str(response))
|
||||||
else:
|
else:
|
||||||
print('-------- receive final ---------')
|
print("-------- receive final ---------")
|
||||||
self.child.send('[Finish]')
|
self.child.send("[Finish]")
|
||||||
# self.local_history.append(response)
|
# self.local_history.append(response)
|
||||||
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
"""
|
"""
|
||||||
这个函数运行在子进程
|
这个函数运行在子进程
|
||||||
@ -118,32 +131,37 @@ class NewBingHandle(Process):
|
|||||||
self.local_history = []
|
self.local_history = []
|
||||||
if (self.newbing_model is None) or (not self.success):
|
if (self.newbing_model is None) or (not self.success):
|
||||||
# 代理设置
|
# 代理设置
|
||||||
proxies, NEWBING_COOKIES = get_conf('proxies', 'NEWBING_COOKIES')
|
proxies, NEWBING_COOKIES = get_conf("proxies", "NEWBING_COOKIES")
|
||||||
if proxies is None:
|
if proxies is None:
|
||||||
self.proxies_https = None
|
self.proxies_https = None
|
||||||
else:
|
else:
|
||||||
self.proxies_https = proxies['https']
|
self.proxies_https = proxies["https"]
|
||||||
|
|
||||||
if (NEWBING_COOKIES is not None) and len(NEWBING_COOKIES) > 100:
|
if (NEWBING_COOKIES is not None) and len(NEWBING_COOKIES) > 100:
|
||||||
try:
|
try:
|
||||||
cookies = json.loads(NEWBING_COOKIES)
|
cookies = json.loads(NEWBING_COOKIES)
|
||||||
except:
|
except:
|
||||||
self.success = False
|
self.success = False
|
||||||
tb_str = '\n```\n' + trimmed_format_exc() + '\n```\n'
|
tb_str = "\n```\n" + trimmed_format_exc() + "\n```\n"
|
||||||
self.child.send(f'[Local Message] NEWBING_COOKIES未填写或有格式错误。')
|
self.child.send(f"[Local Message] NEWBING_COOKIES未填写或有格式错误。")
|
||||||
self.child.send('[Fail]'); self.child.send('[Finish]')
|
self.child.send("[Fail]")
|
||||||
|
self.child.send("[Finish]")
|
||||||
raise RuntimeError(f"NEWBING_COOKIES未填写或有格式错误。")
|
raise RuntimeError(f"NEWBING_COOKIES未填写或有格式错误。")
|
||||||
else:
|
else:
|
||||||
cookies = None
|
cookies = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.newbing_model = NewbingChatbot(proxy=self.proxies_https, cookies=cookies)
|
self.newbing_model = NewbingChatbot(
|
||||||
|
proxy=self.proxies_https, cookies=cookies
|
||||||
|
)
|
||||||
except:
|
except:
|
||||||
self.success = False
|
self.success = False
|
||||||
tb_str = '\n```\n' + trimmed_format_exc() + '\n```\n'
|
tb_str = "\n```\n" + trimmed_format_exc() + "\n```\n"
|
||||||
self.child.send(f'[Local Message] 不能加载Newbing组件,请注意Newbing组件已不再维护。{tb_str}')
|
self.child.send(
|
||||||
self.child.send('[Fail]')
|
f"[Local Message] 不能加载Newbing组件,请注意Newbing组件已不再维护。{tb_str}"
|
||||||
self.child.send('[Finish]')
|
)
|
||||||
|
self.child.send("[Fail]")
|
||||||
|
self.child.send("[Finish]")
|
||||||
raise RuntimeError(f"不能加载Newbing组件,请注意Newbing组件已不再维护。")
|
raise RuntimeError(f"不能加载Newbing组件,请注意Newbing组件已不再维护。")
|
||||||
|
|
||||||
self.success = True
|
self.success = True
|
||||||
@ -151,10 +169,12 @@ class NewBingHandle(Process):
|
|||||||
# 进入任务等待状态
|
# 进入任务等待状态
|
||||||
asyncio.run(self.async_run())
|
asyncio.run(self.async_run())
|
||||||
except Exception:
|
except Exception:
|
||||||
tb_str = '\n```\n' + trimmed_format_exc() + '\n```\n'
|
tb_str = "\n```\n" + trimmed_format_exc() + "\n```\n"
|
||||||
self.child.send(f'[Local Message] Newbing 请求失败,报错信息如下. 如果是与网络相关的问题,建议更换代理协议(推荐http)或代理节点 {tb_str}.')
|
self.child.send(
|
||||||
self.child.send('[Fail]')
|
f"[Local Message] Newbing 请求失败,报错信息如下. 如果是与网络相关的问题,建议更换代理协议(推荐http)或代理节点 {tb_str}."
|
||||||
self.child.send('[Finish]')
|
)
|
||||||
|
self.child.send("[Fail]")
|
||||||
|
self.child.send("[Finish]")
|
||||||
|
|
||||||
def stream_chat(self, **kwargs):
|
def stream_chat(self, **kwargs):
|
||||||
"""
|
"""
|
||||||
@ -164,21 +184,33 @@ class NewBingHandle(Process):
|
|||||||
self.parent.send(kwargs) # 请求子进程
|
self.parent.send(kwargs) # 请求子进程
|
||||||
while True:
|
while True:
|
||||||
res = self.parent.recv() # 等待newbing回复的片段
|
res = self.parent.recv() # 等待newbing回复的片段
|
||||||
if res == '[Finish]': break # 结束
|
if res == "[Finish]":
|
||||||
elif res == '[Fail]': self.success = False; break # 失败
|
break # 结束
|
||||||
else: yield res # newbing回复的片段
|
elif res == "[Fail]":
|
||||||
|
self.success = False
|
||||||
|
break # 失败
|
||||||
|
else:
|
||||||
|
yield res # newbing回复的片段
|
||||||
self.threadLock.release() # 释放线程锁
|
self.threadLock.release() # 释放线程锁
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
========================================================================
|
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
|
||||||
第三部分:主进程统一调用函数接口
|
第三部分:主进程统一调用函数接口
|
||||||
========================================================================
|
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
|
||||||
"""
|
"""
|
||||||
global newbingfree_handle
|
global newbingfree_handle
|
||||||
newbingfree_handle = None
|
newbingfree_handle = None
|
||||||
|
|
||||||
def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=[], console_slience=False):
|
|
||||||
|
def predict_no_ui_long_connection(
|
||||||
|
inputs,
|
||||||
|
llm_kwargs,
|
||||||
|
history=[],
|
||||||
|
sys_prompt="",
|
||||||
|
observe_window=[],
|
||||||
|
console_slience=False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
多线程方法
|
多线程方法
|
||||||
函数的说明请见 request_llms/bridge_all.py
|
函数的说明请见 request_llms/bridge_all.py
|
||||||
@ -186,7 +218,8 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
|
|||||||
global newbingfree_handle
|
global newbingfree_handle
|
||||||
if (newbingfree_handle is None) or (not newbingfree_handle.success):
|
if (newbingfree_handle is None) or (not newbingfree_handle.success):
|
||||||
newbingfree_handle = NewBingHandle()
|
newbingfree_handle = NewBingHandle()
|
||||||
if len(observe_window) >= 1: observe_window[0] = load_message + "\n\n" + newbingfree_handle.info
|
if len(observe_window) >= 1:
|
||||||
|
observe_window[0] = load_message + "\n\n" + newbingfree_handle.info
|
||||||
if not newbingfree_handle.success:
|
if not newbingfree_handle.success:
|
||||||
error = newbingfree_handle.info
|
error = newbingfree_handle.info
|
||||||
newbingfree_handle = None
|
newbingfree_handle = None
|
||||||
@ -199,15 +232,34 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
|
|||||||
|
|
||||||
watch_dog_patience = 5 # 看门狗 (watchdog) 的耐心, 设置5秒即可
|
watch_dog_patience = 5 # 看门狗 (watchdog) 的耐心, 设置5秒即可
|
||||||
response = ""
|
response = ""
|
||||||
if len(observe_window) >= 1: observe_window[0] = "[Local Message] 等待NewBing响应中 ..."
|
if len(observe_window) >= 1:
|
||||||
for response in newbingfree_handle.stream_chat(query=inputs, history=history_feedin, system_prompt=sys_prompt, max_length=llm_kwargs['max_length'], top_p=llm_kwargs['top_p'], temperature=llm_kwargs['temperature']):
|
observe_window[0] = "[Local Message] 等待NewBing响应中 ..."
|
||||||
if len(observe_window) >= 1: observe_window[0] = preprocess_newbing_out_simple(response)
|
for response in newbingfree_handle.stream_chat(
|
||||||
|
query=inputs,
|
||||||
|
history=history_feedin,
|
||||||
|
system_prompt=sys_prompt,
|
||||||
|
max_length=llm_kwargs["max_length"],
|
||||||
|
top_p=llm_kwargs["top_p"],
|
||||||
|
temperature=llm_kwargs["temperature"],
|
||||||
|
):
|
||||||
|
if len(observe_window) >= 1:
|
||||||
|
observe_window[0] = preprocess_newbing_out_simple(response)
|
||||||
if len(observe_window) >= 2:
|
if len(observe_window) >= 2:
|
||||||
if (time.time() - observe_window[1]) > watch_dog_patience:
|
if (time.time() - observe_window[1]) > watch_dog_patience:
|
||||||
raise RuntimeError("程序终止。")
|
raise RuntimeError("程序终止。")
|
||||||
return preprocess_newbing_out_simple(response)
|
return preprocess_newbing_out_simple(response)
|
||||||
|
|
||||||
def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_prompt='', stream = True, additional_fn=None):
|
|
||||||
|
def predict(
|
||||||
|
inputs,
|
||||||
|
llm_kwargs,
|
||||||
|
plugin_kwargs,
|
||||||
|
chatbot,
|
||||||
|
history=[],
|
||||||
|
system_prompt="",
|
||||||
|
stream=True,
|
||||||
|
additional_fn=None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
单线程方法
|
单线程方法
|
||||||
函数的说明请见 request_llms/bridge_all.py
|
函数的说明请见 request_llms/bridge_all.py
|
||||||
@ -225,7 +277,10 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
|
|||||||
|
|
||||||
if additional_fn is not None:
|
if additional_fn is not None:
|
||||||
from core_functional import handle_core_functionality
|
from core_functional import handle_core_functionality
|
||||||
inputs, history = handle_core_functionality(additional_fn, inputs, history, chatbot)
|
|
||||||
|
inputs, history = handle_core_functionality(
|
||||||
|
additional_fn, inputs, history, chatbot
|
||||||
|
)
|
||||||
|
|
||||||
history_feedin = []
|
history_feedin = []
|
||||||
for i in range(len(history) // 2):
|
for i in range(len(history) // 2):
|
||||||
@ -233,13 +288,24 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
|
|||||||
|
|
||||||
chatbot[-1] = (inputs, "[Local Message] 等待NewBing响应中 ...")
|
chatbot[-1] = (inputs, "[Local Message] 等待NewBing响应中 ...")
|
||||||
response = "[Local Message] 等待NewBing响应中 ..."
|
response = "[Local Message] 等待NewBing响应中 ..."
|
||||||
yield from update_ui(chatbot=chatbot, history=history, msg="NewBing响应缓慢,尚未完成全部响应,请耐心完成后再提交新问题。")
|
yield from update_ui(
|
||||||
for response in newbingfree_handle.stream_chat(query=inputs, history=history_feedin, system_prompt=system_prompt, max_length=llm_kwargs['max_length'], top_p=llm_kwargs['top_p'], temperature=llm_kwargs['temperature']):
|
chatbot=chatbot, history=history, msg="NewBing响应缓慢,尚未完成全部响应,请耐心完成后再提交新问题。"
|
||||||
|
)
|
||||||
|
for response in newbingfree_handle.stream_chat(
|
||||||
|
query=inputs,
|
||||||
|
history=history_feedin,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
max_length=llm_kwargs["max_length"],
|
||||||
|
top_p=llm_kwargs["top_p"],
|
||||||
|
temperature=llm_kwargs["temperature"],
|
||||||
|
):
|
||||||
chatbot[-1] = (inputs, preprocess_newbing_out(response))
|
chatbot[-1] = (inputs, preprocess_newbing_out(response))
|
||||||
yield from update_ui(chatbot=chatbot, history=history, msg="NewBing响应缓慢,尚未完成全部响应,请耐心完成后再提交新问题。")
|
yield from update_ui(
|
||||||
if response == "[Local Message] 等待NewBing响应中 ...": response = "[Local Message] NewBing响应异常,请刷新界面重试 ..."
|
chatbot=chatbot, history=history, msg="NewBing响应缓慢,尚未完成全部响应,请耐心完成后再提交新问题。"
|
||||||
|
)
|
||||||
|
if response == "[Local Message] 等待NewBing响应中 ...":
|
||||||
|
response = "[Local Message] NewBing响应异常,请刷新界面重试 ..."
|
||||||
history.extend([inputs, response])
|
history.extend([inputs, response])
|
||||||
logging.info(f'[raw_input] {inputs}')
|
logging.info(f"[raw_input] {inputs}")
|
||||||
logging.info(f'[response] {response}')
|
logging.info(f"[response] {response}")
|
||||||
yield from update_ui(chatbot=chatbot, history=history, msg="完成全部响应,请提交新问题。")
|
yield from update_ui(chatbot=chatbot, history=history, msg="完成全部响应,请提交新问题。")
|
||||||
|
|
||||||
|
@ -7,14 +7,15 @@ import logging
|
|||||||
import time
|
import time
|
||||||
from toolbox import get_conf
|
from toolbox import get_conf
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
load_message = "正在加载Claude组件,请稍候..."
|
load_message = "正在加载Claude组件,请稍候..."
|
||||||
|
|
||||||
try:
|
try:
|
||||||
"""
|
"""
|
||||||
========================================================================
|
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
|
||||||
第一部分:Slack API Client
|
第一部分:Slack API Client
|
||||||
https://github.com/yokonsan/claude-in-slack-api
|
https://github.com/yokonsan/claude-in-slack-api
|
||||||
========================================================================
|
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from slack_sdk.errors import SlackApiError
|
from slack_sdk.errors import SlackApiError
|
||||||
@ -33,10 +34,13 @@ try:
|
|||||||
- get_reply():异步方法。循环监听已打开频道的消息,如果收到"Typing…_"结尾的消息说明Claude还在继续输出,否则结束循环。
|
- get_reply():异步方法。循环监听已打开频道的消息,如果收到"Typing…_"结尾的消息说明Claude还在继续输出,否则结束循环。
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
CHANNEL_ID = None
|
CHANNEL_ID = None
|
||||||
|
|
||||||
async def open_channel(self):
|
async def open_channel(self):
|
||||||
response = await self.conversations_open(users=get_conf('SLACK_CLAUDE_BOT_ID'))
|
response = await self.conversations_open(
|
||||||
|
users=get_conf("SLACK_CLAUDE_BOT_ID")
|
||||||
|
)
|
||||||
self.CHANNEL_ID = response["channel"]["id"]
|
self.CHANNEL_ID = response["channel"]["id"]
|
||||||
|
|
||||||
async def chat(self, text):
|
async def chat(self, text):
|
||||||
@ -49,9 +53,14 @@ try:
|
|||||||
async def get_slack_messages(self):
|
async def get_slack_messages(self):
|
||||||
try:
|
try:
|
||||||
# TODO:暂时不支持历史消息,因为在同一个频道里存在多人使用时历史消息渗透问题
|
# TODO:暂时不支持历史消息,因为在同一个频道里存在多人使用时历史消息渗透问题
|
||||||
resp = await self.conversations_history(channel=self.CHANNEL_ID, oldest=self.LAST_TS, limit=1)
|
resp = await self.conversations_history(
|
||||||
msg = [msg for msg in resp["messages"]
|
channel=self.CHANNEL_ID, oldest=self.LAST_TS, limit=1
|
||||||
if msg.get("user") == get_conf('SLACK_CLAUDE_BOT_ID')]
|
)
|
||||||
|
msg = [
|
||||||
|
msg
|
||||||
|
for msg in resp["messages"]
|
||||||
|
if msg.get("user") == get_conf("SLACK_CLAUDE_BOT_ID")
|
||||||
|
]
|
||||||
return msg
|
return msg
|
||||||
except (SlackApiError, KeyError) as e:
|
except (SlackApiError, KeyError) as e:
|
||||||
raise RuntimeError(f"获取Slack消息失败。")
|
raise RuntimeError(f"获取Slack消息失败。")
|
||||||
@ -69,13 +78,14 @@ try:
|
|||||||
else:
|
else:
|
||||||
yield True, msg["text"]
|
yield True, msg["text"]
|
||||||
break
|
break
|
||||||
|
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
"""
|
"""
|
||||||
========================================================================
|
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
|
||||||
第二部分:子进程Worker(调用主体)
|
第二部分:子进程Worker(调用主体)
|
||||||
========================================================================
|
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -96,6 +106,7 @@ class ClaudeHandle(Process):
|
|||||||
try:
|
try:
|
||||||
self.success = False
|
self.success = False
|
||||||
import slack_sdk
|
import slack_sdk
|
||||||
|
|
||||||
self.info = "依赖检测通过,等待Claude响应。注意目前不能多人同时调用Claude接口(有线程锁),否则将导致每个人的Claude问询历史互相渗透。调用Claude时,会自动使用已配置的代理。"
|
self.info = "依赖检测通过,等待Claude响应。注意目前不能多人同时调用Claude接口(有线程锁),否则将导致每个人的Claude问询历史互相渗透。调用Claude时,会自动使用已配置的代理。"
|
||||||
self.success = True
|
self.success = True
|
||||||
except:
|
except:
|
||||||
@ -110,15 +121,15 @@ class ClaudeHandle(Process):
|
|||||||
while True:
|
while True:
|
||||||
# 等待
|
# 等待
|
||||||
kwargs = self.child.recv()
|
kwargs = self.child.recv()
|
||||||
question = kwargs['query']
|
question = kwargs["query"]
|
||||||
history = kwargs['history']
|
history = kwargs["history"]
|
||||||
|
|
||||||
# 开始问问题
|
# 开始问问题
|
||||||
prompt = ""
|
prompt = ""
|
||||||
|
|
||||||
# 问题
|
# 问题
|
||||||
prompt += question
|
prompt += question
|
||||||
print('question:', prompt)
|
print("question:", prompt)
|
||||||
|
|
||||||
# 提交
|
# 提交
|
||||||
await self.claude_model.chat(prompt)
|
await self.claude_model.chat(prompt)
|
||||||
@ -131,11 +142,15 @@ class ClaudeHandle(Process):
|
|||||||
else:
|
else:
|
||||||
# 防止丢失最后一条消息
|
# 防止丢失最后一条消息
|
||||||
slack_msgs = await self.claude_model.get_slack_messages()
|
slack_msgs = await self.claude_model.get_slack_messages()
|
||||||
last_msg = slack_msgs[-1]["text"] if slack_msgs and len(slack_msgs) > 0 else ""
|
last_msg = (
|
||||||
|
slack_msgs[-1]["text"]
|
||||||
|
if slack_msgs and len(slack_msgs) > 0
|
||||||
|
else ""
|
||||||
|
)
|
||||||
if last_msg:
|
if last_msg:
|
||||||
self.child.send(last_msg)
|
self.child.send(last_msg)
|
||||||
print('-------- receive final ---------')
|
print("-------- receive final ---------")
|
||||||
self.child.send('[Finish]')
|
self.child.send("[Finish]")
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
"""
|
"""
|
||||||
@ -146,22 +161,24 @@ class ClaudeHandle(Process):
|
|||||||
self.local_history = []
|
self.local_history = []
|
||||||
if (self.claude_model is None) or (not self.success):
|
if (self.claude_model is None) or (not self.success):
|
||||||
# 代理设置
|
# 代理设置
|
||||||
proxies = get_conf('proxies')
|
proxies = get_conf("proxies")
|
||||||
if proxies is None:
|
if proxies is None:
|
||||||
self.proxies_https = None
|
self.proxies_https = None
|
||||||
else:
|
else:
|
||||||
self.proxies_https = proxies['https']
|
self.proxies_https = proxies["https"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
SLACK_CLAUDE_USER_TOKEN = get_conf('SLACK_CLAUDE_USER_TOKEN')
|
SLACK_CLAUDE_USER_TOKEN = get_conf("SLACK_CLAUDE_USER_TOKEN")
|
||||||
self.claude_model = SlackClient(token=SLACK_CLAUDE_USER_TOKEN, proxy=self.proxies_https)
|
self.claude_model = SlackClient(
|
||||||
print('Claude组件初始化成功。')
|
token=SLACK_CLAUDE_USER_TOKEN, proxy=self.proxies_https
|
||||||
|
)
|
||||||
|
print("Claude组件初始化成功。")
|
||||||
except:
|
except:
|
||||||
self.success = False
|
self.success = False
|
||||||
tb_str = '\n```\n' + trimmed_format_exc() + '\n```\n'
|
tb_str = "\n```\n" + trimmed_format_exc() + "\n```\n"
|
||||||
self.child.send(f'[Local Message] 不能加载Claude组件。{tb_str}')
|
self.child.send(f"[Local Message] 不能加载Claude组件。{tb_str}")
|
||||||
self.child.send('[Fail]')
|
self.child.send("[Fail]")
|
||||||
self.child.send('[Finish]')
|
self.child.send("[Finish]")
|
||||||
raise RuntimeError(f"不能加载Claude组件。")
|
raise RuntimeError(f"不能加载Claude组件。")
|
||||||
|
|
||||||
self.success = True
|
self.success = True
|
||||||
@ -169,10 +186,10 @@ class ClaudeHandle(Process):
|
|||||||
# 进入任务等待状态
|
# 进入任务等待状态
|
||||||
asyncio.run(self.async_run())
|
asyncio.run(self.async_run())
|
||||||
except Exception:
|
except Exception:
|
||||||
tb_str = '\n```\n' + trimmed_format_exc() + '\n```\n'
|
tb_str = "\n```\n" + trimmed_format_exc() + "\n```\n"
|
||||||
self.child.send(f'[Local Message] Claude失败 {tb_str}.')
|
self.child.send(f"[Local Message] Claude失败 {tb_str}.")
|
||||||
self.child.send('[Fail]')
|
self.child.send("[Fail]")
|
||||||
self.child.send('[Finish]')
|
self.child.send("[Finish]")
|
||||||
|
|
||||||
def stream_chat(self, **kwargs):
|
def stream_chat(self, **kwargs):
|
||||||
"""
|
"""
|
||||||
@ -182,9 +199,9 @@ class ClaudeHandle(Process):
|
|||||||
self.parent.send(kwargs) # 发送请求到子进程
|
self.parent.send(kwargs) # 发送请求到子进程
|
||||||
while True:
|
while True:
|
||||||
res = self.parent.recv() # 等待Claude回复的片段
|
res = self.parent.recv() # 等待Claude回复的片段
|
||||||
if res == '[Finish]':
|
if res == "[Finish]":
|
||||||
break # 结束
|
break # 结束
|
||||||
elif res == '[Fail]':
|
elif res == "[Fail]":
|
||||||
self.success = False
|
self.success = False
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
@ -193,15 +210,22 @@ class ClaudeHandle(Process):
|
|||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
========================================================================
|
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
|
||||||
第三部分:主进程统一调用函数接口
|
第三部分:主进程统一调用函数接口
|
||||||
========================================================================
|
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
|
||||||
"""
|
"""
|
||||||
global claude_handle
|
global claude_handle
|
||||||
claude_handle = None
|
claude_handle = None
|
||||||
|
|
||||||
|
|
||||||
def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=None, console_slience=False):
|
def predict_no_ui_long_connection(
|
||||||
|
inputs,
|
||||||
|
llm_kwargs,
|
||||||
|
history=[],
|
||||||
|
sys_prompt="",
|
||||||
|
observe_window=None,
|
||||||
|
console_slience=False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
多线程方法
|
多线程方法
|
||||||
函数的说明请见 request_llms/bridge_all.py
|
函数的说明请见 request_llms/bridge_all.py
|
||||||
@ -223,7 +247,14 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
|
|||||||
watch_dog_patience = 5 # 看门狗 (watchdog) 的耐心, 设置5秒即可
|
watch_dog_patience = 5 # 看门狗 (watchdog) 的耐心, 设置5秒即可
|
||||||
response = ""
|
response = ""
|
||||||
observe_window[0] = "[Local Message] 等待Claude响应中 ..."
|
observe_window[0] = "[Local Message] 等待Claude响应中 ..."
|
||||||
for response in claude_handle.stream_chat(query=inputs, history=history_feedin, system_prompt=sys_prompt, max_length=llm_kwargs['max_length'], top_p=llm_kwargs['top_p'], temperature=llm_kwargs['temperature']):
|
for response in claude_handle.stream_chat(
|
||||||
|
query=inputs,
|
||||||
|
history=history_feedin,
|
||||||
|
system_prompt=sys_prompt,
|
||||||
|
max_length=llm_kwargs["max_length"],
|
||||||
|
top_p=llm_kwargs["top_p"],
|
||||||
|
temperature=llm_kwargs["temperature"],
|
||||||
|
):
|
||||||
observe_window[0] = preprocess_newbing_out_simple(response)
|
observe_window[0] = preprocess_newbing_out_simple(response)
|
||||||
if len(observe_window) >= 2:
|
if len(observe_window) >= 2:
|
||||||
if (time.time() - observe_window[1]) > watch_dog_patience:
|
if (time.time() - observe_window[1]) > watch_dog_patience:
|
||||||
@ -231,7 +262,16 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
|
|||||||
return preprocess_newbing_out_simple(response)
|
return preprocess_newbing_out_simple(response)
|
||||||
|
|
||||||
|
|
||||||
def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_prompt='', stream=True, additional_fn=None):
|
def predict(
|
||||||
|
inputs,
|
||||||
|
llm_kwargs,
|
||||||
|
plugin_kwargs,
|
||||||
|
chatbot,
|
||||||
|
history=[],
|
||||||
|
system_prompt="",
|
||||||
|
stream=True,
|
||||||
|
additional_fn=None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
单线程方法
|
单线程方法
|
||||||
函数的说明请见 request_llms/bridge_all.py
|
函数的说明请见 request_llms/bridge_all.py
|
||||||
@ -249,7 +289,10 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
|
|||||||
|
|
||||||
if additional_fn is not None:
|
if additional_fn is not None:
|
||||||
from core_functional import handle_core_functionality
|
from core_functional import handle_core_functionality
|
||||||
inputs, history = handle_core_functionality(additional_fn, inputs, history, chatbot)
|
|
||||||
|
inputs, history = handle_core_functionality(
|
||||||
|
additional_fn, inputs, history, chatbot
|
||||||
|
)
|
||||||
|
|
||||||
history_feedin = []
|
history_feedin = []
|
||||||
for i in range(len(history) // 2):
|
for i in range(len(history) // 2):
|
||||||
@ -257,13 +300,19 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
|
|||||||
|
|
||||||
chatbot[-1] = (inputs, "[Local Message] 等待Claude响应中 ...")
|
chatbot[-1] = (inputs, "[Local Message] 等待Claude响应中 ...")
|
||||||
response = "[Local Message] 等待Claude响应中 ..."
|
response = "[Local Message] 等待Claude响应中 ..."
|
||||||
yield from update_ui(chatbot=chatbot, history=history, msg="Claude响应缓慢,尚未完成全部响应,请耐心完成后再提交新问题。")
|
yield from update_ui(
|
||||||
for response in claude_handle.stream_chat(query=inputs, history=history_feedin, system_prompt=system_prompt):
|
chatbot=chatbot, history=history, msg="Claude响应缓慢,尚未完成全部响应,请耐心完成后再提交新问题。"
|
||||||
|
)
|
||||||
|
for response in claude_handle.stream_chat(
|
||||||
|
query=inputs, history=history_feedin, system_prompt=system_prompt
|
||||||
|
):
|
||||||
chatbot[-1] = (inputs, preprocess_newbing_out(response))
|
chatbot[-1] = (inputs, preprocess_newbing_out(response))
|
||||||
yield from update_ui(chatbot=chatbot, history=history, msg="Claude响应缓慢,尚未完成全部响应,请耐心完成后再提交新问题。")
|
yield from update_ui(
|
||||||
|
chatbot=chatbot, history=history, msg="Claude响应缓慢,尚未完成全部响应,请耐心完成后再提交新问题。"
|
||||||
|
)
|
||||||
if response == "[Local Message] 等待Claude响应中 ...":
|
if response == "[Local Message] 等待Claude响应中 ...":
|
||||||
response = "[Local Message] Claude响应异常,请刷新界面重试 ..."
|
response = "[Local Message] Claude响应异常,请刷新界面重试 ..."
|
||||||
history.extend([inputs, response])
|
history.extend([inputs, response])
|
||||||
logging.info(f'[raw_input] {inputs}')
|
logging.info(f"[raw_input] {inputs}")
|
||||||
logging.info(f'[response] {response}')
|
logging.info(f"[response] {response}")
|
||||||
yield from update_ui(chatbot=chatbot, history=history, msg="完成全部响应,请提交新问题。")
|
yield from update_ui(chatbot=chatbot, history=history, msg="完成全部响应,请提交新问题。")
|
||||||
|
@ -12,7 +12,7 @@ from toolbox import get_conf, encode_image, get_pictures_list
|
|||||||
proxies, TIMEOUT_SECONDS = get_conf("proxies", "TIMEOUT_SECONDS")
|
proxies, TIMEOUT_SECONDS = get_conf("proxies", "TIMEOUT_SECONDS")
|
||||||
|
|
||||||
"""
|
"""
|
||||||
========================================================================
|
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
|
||||||
第五部分 一些文件处理方法
|
第五部分 一些文件处理方法
|
||||||
files_filter_handler 根据type过滤文件
|
files_filter_handler 根据type过滤文件
|
||||||
input_encode_handler 提取input中的文件,并解析
|
input_encode_handler 提取input中的文件,并解析
|
||||||
@ -21,6 +21,7 @@ link_mtime_to_md 文件增加本地时间参数,避免下载到缓存文件
|
|||||||
html_view_blank 超链接
|
html_view_blank 超链接
|
||||||
html_local_file 本地文件取相对路径
|
html_local_file 本地文件取相对路径
|
||||||
to_markdown_tabs 文件list 转换为 md tab
|
to_markdown_tabs 文件list 转换为 md tab
|
||||||
|
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
"""
|
"""
|
||||||
========================================================================
|
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
|
||||||
第一部分:来自EdgeGPT.py
|
第一部分:来自EdgeGPT.py
|
||||||
https://github.com/acheong08/EdgeGPT
|
https://github.com/acheong08/EdgeGPT
|
||||||
========================================================================
|
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
|
||||||
"""
|
"""
|
||||||
"""
|
"""
|
||||||
Main.py
|
Main.py
|
||||||
@ -452,9 +452,11 @@ class _ChatHub:
|
|||||||
ws_cookies = []
|
ws_cookies = []
|
||||||
for cookie in self.cookies:
|
for cookie in self.cookies:
|
||||||
ws_cookies.append(f"{cookie['name']}={cookie['value']}")
|
ws_cookies.append(f"{cookie['name']}={cookie['value']}")
|
||||||
req_header.update({
|
req_header.update(
|
||||||
'Cookie': ';'.join(ws_cookies),
|
{
|
||||||
})
|
"Cookie": ";".join(ws_cookies),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
timeout = aiohttp.ClientTimeout(total=30)
|
timeout = aiohttp.ClientTimeout(total=30)
|
||||||
self.session = aiohttp.ClientSession(timeout=timeout)
|
self.session = aiohttp.ClientSession(timeout=timeout)
|
||||||
|
@ -2,6 +2,7 @@ import markdown
|
|||||||
import re
|
import re
|
||||||
import os
|
import os
|
||||||
import math
|
import math
|
||||||
|
from textwrap import dedent
|
||||||
from latex2mathml.converter import convert as tex2mathml
|
from latex2mathml.converter import convert as tex2mathml
|
||||||
from functools import wraps, lru_cache
|
from functools import wraps, lru_cache
|
||||||
from shared_utils.config_loader import get_conf as get_conf
|
from shared_utils.config_loader import get_conf as get_conf
|
||||||
@ -32,26 +33,6 @@ def text_divide_paragraph(text):
|
|||||||
text = "</br>".join(lines)
|
text = "</br>".join(lines)
|
||||||
return pre + text + suf
|
return pre + text + suf
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=128) # 使用 lru缓存 加快转换速度
|
|
||||||
def markdown_convertion(txt):
|
|
||||||
"""
|
|
||||||
将Markdown格式的文本转换为HTML格式。如果包含数学公式,则先将公式转换为HTML格式。
|
|
||||||
"""
|
|
||||||
pre = '<div class="markdown-body">'
|
|
||||||
suf = '</div>'
|
|
||||||
if txt.startswith(pre) and txt.endswith(suf):
|
|
||||||
# print('警告,输入了已经经过转化的字符串,二次转化可能出问题')
|
|
||||||
return txt # 已经被转化过,不需要再次转化
|
|
||||||
|
|
||||||
markdown_extension_configs = {
|
|
||||||
'mdx_math': {
|
|
||||||
'enable_dollar_delimiter': True,
|
|
||||||
'use_gitlab_delimiters': False,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
find_equation_pattern = r'<script type="math/tex(?:.*?)>(.*?)</script>'
|
|
||||||
|
|
||||||
def tex2mathml_catch_exception(content, *args, **kwargs):
|
def tex2mathml_catch_exception(content, *args, **kwargs):
|
||||||
try:
|
try:
|
||||||
content = tex2mathml(content, *args, **kwargs)
|
content = tex2mathml(content, *args, **kwargs)
|
||||||
@ -121,7 +102,8 @@ def markdown_convertion(txt):
|
|||||||
def fix_markdown_indent(txt):
|
def fix_markdown_indent(txt):
|
||||||
# fix markdown indent
|
# fix markdown indent
|
||||||
if (' - ' not in txt) or ('. ' not in txt):
|
if (' - ' not in txt) or ('. ' not in txt):
|
||||||
return txt # do not need to fix, fast escape
|
# do not need to fix, fast escape
|
||||||
|
return txt
|
||||||
# walk through the lines and fix non-standard indentation
|
# walk through the lines and fix non-standard indentation
|
||||||
lines = txt.split("\n")
|
lines = txt.split("\n")
|
||||||
pattern = re.compile(r'^\s+-')
|
pattern = re.compile(r'^\s+-')
|
||||||
@ -137,7 +119,83 @@ def markdown_convertion(txt):
|
|||||||
lines[i] = ' ' * num_spaces_should_be + stripped_string
|
lines[i] = ' ' * num_spaces_should_be + stripped_string
|
||||||
return '\n'.join(lines)
|
return '\n'.join(lines)
|
||||||
|
|
||||||
|
FENCED_BLOCK_RE = re.compile(
|
||||||
|
dedent(r'''
|
||||||
|
(?P<fence>^[ \t]*(?:~{3,}|`{3,}))[ ]* # opening fence
|
||||||
|
((\{(?P<attrs>[^\}\n]*)\})| # (optional {attrs} or
|
||||||
|
(\.?(?P<lang>[\w#.+-]*)[ ]*)? # optional (.)lang
|
||||||
|
(hl_lines=(?P<quot>"|')(?P<hl_lines>.*?)(?P=quot)[ ]*)?) # optional hl_lines)
|
||||||
|
\n # newline (end of opening fence)
|
||||||
|
(?P<code>.*?)(?<=\n) # the code block
|
||||||
|
(?P=fence)[ ]*$ # closing fence
|
||||||
|
'''),
|
||||||
|
re.MULTILINE | re.DOTALL | re.VERBOSE
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_line_range(re_match_obj, txt):
|
||||||
|
start_pos, end_pos = re_match_obj.regs[0]
|
||||||
|
num_newlines_before = txt[:start_pos+1].count('\n')
|
||||||
|
line_start = num_newlines_before
|
||||||
|
line_end = num_newlines_before + txt[start_pos:end_pos].count('\n')+1
|
||||||
|
return line_start, line_end
|
||||||
|
|
||||||
|
def fix_code_segment_indent(txt):
|
||||||
|
lines = []
|
||||||
|
change_any = False
|
||||||
|
txt_tmp = txt
|
||||||
|
while True:
|
||||||
|
re_match_obj = FENCED_BLOCK_RE.search(txt_tmp)
|
||||||
|
if not re_match_obj: break
|
||||||
|
if len(lines) == 0: lines = txt.split("\n")
|
||||||
|
|
||||||
|
# 清空 txt_tmp 对应的位置方便下次搜索
|
||||||
|
start_pos, end_pos = re_match_obj.regs[0]
|
||||||
|
txt_tmp = txt_tmp[:start_pos] + ' '*(end_pos-start_pos) + txt_tmp[end_pos:]
|
||||||
|
line_start, line_end = get_line_range(re_match_obj, txt)
|
||||||
|
|
||||||
|
# 获取公共缩进
|
||||||
|
shared_indent_cnt = 1e5
|
||||||
|
for i in range(line_start, line_end):
|
||||||
|
stripped_string = lines[i].lstrip()
|
||||||
|
num_spaces = len(lines[i]) - len(stripped_string)
|
||||||
|
if num_spaces < shared_indent_cnt:
|
||||||
|
shared_indent_cnt = num_spaces
|
||||||
|
|
||||||
|
# 修复缩进
|
||||||
|
if (shared_indent_cnt < 1e5) and (shared_indent_cnt % 4) == 3:
|
||||||
|
num_spaces_should_be = math.ceil(shared_indent_cnt / 4) * 4
|
||||||
|
for i in range(line_start, line_end):
|
||||||
|
add_n = num_spaces_should_be - shared_indent_cnt
|
||||||
|
lines[i] = ' ' * add_n + lines[i]
|
||||||
|
if not change_any: # 遇到第一个
|
||||||
|
change_any = True
|
||||||
|
|
||||||
|
if change_any:
|
||||||
|
return '\n'.join(lines)
|
||||||
|
else:
|
||||||
|
return txt
|
||||||
|
|
||||||
|
@lru_cache(maxsize=128) # 使用 lru缓存 加快转换速度
|
||||||
|
def markdown_convertion(txt):
|
||||||
|
"""
|
||||||
|
将Markdown格式的文本转换为HTML格式。如果包含数学公式,则先将公式转换为HTML格式。
|
||||||
|
"""
|
||||||
|
pre = '<div class="markdown-body">'
|
||||||
|
suf = '</div>'
|
||||||
|
if txt.startswith(pre) and txt.endswith(suf):
|
||||||
|
# print('警告,输入了已经经过转化的字符串,二次转化可能出问题')
|
||||||
|
return txt # 已经被转化过,不需要再次转化
|
||||||
|
|
||||||
|
markdown_extension_configs = {
|
||||||
|
'mdx_math': {
|
||||||
|
'enable_dollar_delimiter': True,
|
||||||
|
'use_gitlab_delimiters': False,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
find_equation_pattern = r'<script type="math/tex(?:.*?)>(.*?)</script>'
|
||||||
|
|
||||||
txt = fix_markdown_indent(txt)
|
txt = fix_markdown_indent(txt)
|
||||||
|
txt = fix_code_segment_indent(txt)
|
||||||
if is_equation(txt): # 有$标识的公式符号,且没有代码段```的标识
|
if is_equation(txt): # 有$标识的公式符号,且没有代码段```的标识
|
||||||
# convert everything to html format
|
# convert everything to html format
|
||||||
split = markdown.markdown(text='---')
|
split = markdown.markdown(text='---')
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
"""
|
"""
|
||||||
========================================================================
|
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
|
||||||
接驳void-terminal:
|
接驳void-terminal:
|
||||||
- set_conf: 在运行过程中动态地修改配置
|
- set_conf: 在运行过程中动态地修改配置
|
||||||
- set_multi_conf: 在运行过程中动态地修改多个配置
|
- set_multi_conf: 在运行过程中动态地修改多个配置
|
||||||
@ -9,17 +9,20 @@ import os
|
|||||||
- get_plugin_default_kwargs: 获取插件的默认参数
|
- get_plugin_default_kwargs: 获取插件的默认参数
|
||||||
- get_chat_handle: 获取简单聊天的句柄
|
- get_chat_handle: 获取简单聊天的句柄
|
||||||
- get_chat_default_kwargs: 获取简单聊天的默认参数
|
- get_chat_default_kwargs: 获取简单聊天的默认参数
|
||||||
========================================================================
|
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def get_plugin_handle(plugin_name):
|
def get_plugin_handle(plugin_name):
|
||||||
"""
|
"""
|
||||||
e.g. plugin_name = 'crazy_functions.批量Markdown翻译->Markdown翻译指定语言'
|
e.g. plugin_name = 'crazy_functions.批量Markdown翻译->Markdown翻译指定语言'
|
||||||
"""
|
"""
|
||||||
import importlib
|
import importlib
|
||||||
assert '->' in plugin_name, \
|
|
||||||
"Example of plugin_name: crazy_functions.批量Markdown翻译->Markdown翻译指定语言"
|
assert (
|
||||||
module, fn_name = plugin_name.split('->')
|
"->" in plugin_name
|
||||||
|
), "Example of plugin_name: crazy_functions.批量Markdown翻译->Markdown翻译指定语言"
|
||||||
|
module, fn_name = plugin_name.split("->")
|
||||||
f_hot_reload = getattr(importlib.import_module(module, fn_name), fn_name)
|
f_hot_reload = getattr(importlib.import_module(module, fn_name), fn_name)
|
||||||
return f_hot_reload
|
return f_hot_reload
|
||||||
|
|
||||||
@ -29,6 +32,7 @@ def get_chat_handle():
|
|||||||
Get chat function
|
Get chat function
|
||||||
"""
|
"""
|
||||||
from request_llms.bridge_all import predict_no_ui_long_connection
|
from request_llms.bridge_all import predict_no_ui_long_connection
|
||||||
|
|
||||||
return predict_no_ui_long_connection
|
return predict_no_ui_long_connection
|
||||||
|
|
||||||
|
|
||||||
@ -37,13 +41,14 @@ def get_plugin_default_kwargs():
|
|||||||
Get Plugin Default Arguments
|
Get Plugin Default Arguments
|
||||||
"""
|
"""
|
||||||
from toolbox import ChatBotWithCookies, load_chat_cookies
|
from toolbox import ChatBotWithCookies, load_chat_cookies
|
||||||
|
|
||||||
cookies = load_chat_cookies()
|
cookies = load_chat_cookies()
|
||||||
llm_kwargs = {
|
llm_kwargs = {
|
||||||
'api_key': cookies['api_key'],
|
"api_key": cookies["api_key"],
|
||||||
'llm_model': cookies['llm_model'],
|
"llm_model": cookies["llm_model"],
|
||||||
'top_p': 1.0,
|
"top_p": 1.0,
|
||||||
'max_length': None,
|
"max_length": None,
|
||||||
'temperature': 1.0,
|
"temperature": 1.0,
|
||||||
}
|
}
|
||||||
chatbot = ChatBotWithCookies(llm_kwargs)
|
chatbot = ChatBotWithCookies(llm_kwargs)
|
||||||
|
|
||||||
@ -55,7 +60,7 @@ def get_plugin_default_kwargs():
|
|||||||
"chatbot_with_cookie": chatbot,
|
"chatbot_with_cookie": chatbot,
|
||||||
"history": [],
|
"history": [],
|
||||||
"system_prompt": "You are a good AI.",
|
"system_prompt": "You are a good AI.",
|
||||||
"web_port": None
|
"web_port": None,
|
||||||
}
|
}
|
||||||
return DEFAULT_FN_GROUPS_kwargs
|
return DEFAULT_FN_GROUPS_kwargs
|
||||||
|
|
||||||
@ -65,13 +70,14 @@ def get_chat_default_kwargs():
|
|||||||
Get Chat Default Arguments
|
Get Chat Default Arguments
|
||||||
"""
|
"""
|
||||||
from toolbox import load_chat_cookies
|
from toolbox import load_chat_cookies
|
||||||
|
|
||||||
cookies = load_chat_cookies()
|
cookies = load_chat_cookies()
|
||||||
llm_kwargs = {
|
llm_kwargs = {
|
||||||
'api_key': cookies['api_key'],
|
"api_key": cookies["api_key"],
|
||||||
'llm_model': cookies['llm_model'],
|
"llm_model": cookies["llm_model"],
|
||||||
'top_p': 1.0,
|
"top_p": 1.0,
|
||||||
'max_length': None,
|
"max_length": None,
|
||||||
'temperature': 1.0,
|
"temperature": 1.0,
|
||||||
}
|
}
|
||||||
default_chat_kwargs = {
|
default_chat_kwargs = {
|
||||||
"inputs": "Hello there, are you ready?",
|
"inputs": "Hello there, are you ready?",
|
||||||
|
@ -1,32 +1,75 @@
|
|||||||
md = """
|
md = """
|
||||||
作为您的写作和编程助手,我可以为您提供以下服务:
|
|
||||||
|
|
||||||
1. 写作:
|
要计算文件的哈希值,可以使用哈希算法(如MD5、SHA-1或SHA-256)对文件的内容进行计算。
|
||||||
- 帮助您撰写文章、报告、散文、故事等。
|
|
||||||
- 提供写作建议和技巧。
|
|
||||||
- 协助您进行文案策划和内容创作。
|
|
||||||
|
|
||||||
2. 编程:
|
以下是一个使用sha256算法计算文件哈希值的示例代码:
|
||||||
- 帮助您解决编程问题,提供编程思路和建议。
|
|
||||||
- 协助您编写代码,包括但不限于 Python、Java、C++ 等。
|
|
||||||
- 为您解释复杂的技术概念,让您更容易理解。
|
|
||||||
|
|
||||||
3. 项目支持:
|
```python
|
||||||
- 协助您规划项目进度和任务分配。
|
import hashlib
|
||||||
- 提供项目管理和协作建议。
|
|
||||||
- 在项目实施过程中提供支持,确保项目顺利进行。
|
def calculate_hash(file_path):
|
||||||
|
sha256_hash = hashlib.sha256()
|
||||||
|
with open(file_path, 'rb') as file:
|
||||||
|
for chunk in iter(lambda: file.read(4096), b''):
|
||||||
|
sha256_hash.update(chunk)
|
||||||
|
return sha256_hash.hexdigest()
|
||||||
|
|
||||||
|
# 使用示例
|
||||||
|
file_path = 'path/to/file.txt'
|
||||||
|
hash_value = calculate_hash(file_path)
|
||||||
|
print('File hash:', hash_value)
|
||||||
|
```
|
||||||
|
|
||||||
|
在上面的示例中,`calculate_hash`函数接受一个文件路径作为参数,并打开文件以二进制读取模式读取文件内容。然后,使用哈希对象sha256初始化,并对文件内容进行分块读取并更新哈希值。最后,通过`hexdigest`方法获取哈希值的十六进制表示。
|
||||||
|
|
||||||
|
可以根据需要更改哈希算法(如使用`hashlib.md5()`来使用MD5算法)和块大小(这里使用4096字节)。
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
md = """
|
||||||
|
要在Ubuntu中将NTFS格式转换为ext4格式,您需要进行以下步骤:
|
||||||
|
|
||||||
|
1. 首先,确保您已经安装了gparted软件。如果没有安装,请使用以下命令进行安装:
|
||||||
|
|
||||||
|
```
|
||||||
|
sudo apt update
|
||||||
|
sudo apt install gparted
|
||||||
|
```
|
||||||
|
|
||||||
|
2. 然后,打开GParted软件。您可以在"应用程序"菜单中搜索并启动它。
|
||||||
|
|
||||||
|
3. 在GParted界面中,选择您想要转换格式的NTFS分区。请小心选择,确保选择正确的分区。
|
||||||
|
|
||||||
|
4. 确保分区未挂载。如果分区当前正在使用,您需要首先卸载它。在命令行中,您可以使用以下命令卸载该分区:
|
||||||
|
|
||||||
|
```
|
||||||
|
sudo umount /dev/sdc1
|
||||||
|
```
|
||||||
|
|
||||||
|
注意:请将"/dev/sdc1"替换为您要卸载的分区的正确路径。
|
||||||
|
|
||||||
|
5. 在GParted界面中,单击菜单中的"设备"选项,然后选择"创建"。
|
||||||
|
|
||||||
|
6. 在弹出的对话框中,选择要转换为的文件系统类型。在这种情况下,选择"ext4"。然后单击"添加"按钮。
|
||||||
|
|
||||||
|
7. 在"操作"菜单中,选择"应用所有操作"。这将开始分区格式转换的过程。
|
||||||
|
|
||||||
|
8. 等待GParted完成转换操作。这可能需要一些时间,具体取决于分区的大小和系统性能。
|
||||||
|
|
||||||
|
9. 转换完成后,您将看到分区的文件系统已更改为ext4。
|
||||||
|
|
||||||
|
10. 最后,请确保挂载分区以便访问它。您可以使用以下命令挂载该分区:
|
||||||
|
|
||||||
|
```
|
||||||
|
sudo mount /dev/sdc1 /media/fuqingxu/eb63a8fa-cee9-48a5-9f05-b1388c3fda9e
|
||||||
|
```
|
||||||
|
|
||||||
|
注意:请将"/dev/sdc1"替换为已转换分区的正确路径,并将"/media/fuqingxu/eb63a8fa-cee9-48a5-9f05-b1388c3fda9e"替换为您要挂载的目标路径。
|
||||||
|
|
||||||
|
请注意,在执行任何分区操作之前,务必备份重要的数据。操作不当可能导致数据丢失。
|
||||||
|
|
||||||
4. 学习辅导:
|
|
||||||
- 帮助您巩固编程基础,提高编程能力。
|
|
||||||
- 提供计算机科学、数据科学、人工智能等相关领域的学习资源和建议。
|
|
||||||
- 解答您在学习过程中遇到的问题,让您更好地掌握知识。
|
|
||||||
|
|
||||||
5. 行业动态和趋势分析:
|
|
||||||
- 为您提供业界最新的新闻和技术趋势。
|
|
||||||
- 分析行业动态,帮助您了解市场发展和竞争态势。
|
|
||||||
- 为您制定技术战略提供参考和建议。
|
|
||||||
|
|
||||||
请随时告诉我您的需求,我会尽力提供帮助。如果您有任何问题或需要解答的议题,请随时提问。
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -43,6 +86,6 @@ validate_path() # validate path so you can run from base directory
|
|||||||
from toolbox import markdown_convertion
|
from toolbox import markdown_convertion
|
||||||
|
|
||||||
html = markdown_convertion(md)
|
html = markdown_convertion(md)
|
||||||
print(html)
|
# print(html)
|
||||||
with open("test.html", "w", encoding="utf-8") as f:
|
with open("test.html", "w", encoding="utf-8") as f:
|
||||||
f.write(html)
|
f.write(html)
|
||||||
|
399
toolbox.py
399
toolbox.py
@ -11,6 +11,7 @@ from functools import wraps
|
|||||||
from shared_utils.config_loader import get_conf
|
from shared_utils.config_loader import get_conf
|
||||||
from shared_utils.config_loader import set_conf
|
from shared_utils.config_loader import set_conf
|
||||||
from shared_utils.advanced_markdown_format import format_io
|
from shared_utils.advanced_markdown_format import format_io
|
||||||
|
from shared_utils.advanced_markdown_format import markdown_convertion
|
||||||
from shared_utils.key_pattern_manager import select_api_key
|
from shared_utils.key_pattern_manager import select_api_key
|
||||||
from shared_utils.key_pattern_manager import is_any_api_key
|
from shared_utils.key_pattern_manager import is_any_api_key
|
||||||
from shared_utils.key_pattern_manager import what_keys
|
from shared_utils.key_pattern_manager import what_keys
|
||||||
@ -20,10 +21,10 @@ from shared_utils.connect_void_terminal import get_plugin_default_kwargs
|
|||||||
from shared_utils.connect_void_terminal import get_chat_default_kwargs
|
from shared_utils.connect_void_terminal import get_chat_default_kwargs
|
||||||
|
|
||||||
pj = os.path.join
|
pj = os.path.join
|
||||||
default_user_name = 'default_user'
|
default_user_name = "default_user"
|
||||||
|
|
||||||
"""
|
"""
|
||||||
========================================================================
|
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
|
||||||
第一部分
|
第一部分
|
||||||
函数插件输入输出接驳区
|
函数插件输入输出接驳区
|
||||||
- ChatBotWithCookies: 带Cookies的Chatbot类,为实现更多强大的功能做基础
|
- ChatBotWithCookies: 带Cookies的Chatbot类,为实现更多强大的功能做基础
|
||||||
@ -32,7 +33,7 @@ default_user_name = 'default_user'
|
|||||||
- CatchException: 将插件中出的所有问题显示在界面上
|
- CatchException: 将插件中出的所有问题显示在界面上
|
||||||
- HotReload: 实现插件的热更新
|
- HotReload: 实现插件的热更新
|
||||||
- trimmed_format_exc: 打印traceback,为了安全而隐藏绝对地址
|
- trimmed_format_exc: 打印traceback,为了安全而隐藏绝对地址
|
||||||
========================================================================
|
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -120,22 +121,30 @@ def ArgsGeneralWrapper(f):
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
def update_ui(chatbot, history, msg='正常', **kwargs): # 刷新界面
|
def update_ui(chatbot, history, msg="正常", **kwargs): # 刷新界面
|
||||||
"""
|
"""
|
||||||
刷新用户界面
|
刷新用户界面
|
||||||
"""
|
"""
|
||||||
assert isinstance(chatbot, ChatBotWithCookies), "在传递chatbot的过程中不要将其丢弃。必要时, 可用clear将其清空, 然后用for+append循环重新赋值。"
|
assert isinstance(
|
||||||
|
chatbot, ChatBotWithCookies
|
||||||
|
), "在传递chatbot的过程中不要将其丢弃。必要时, 可用clear将其清空, 然后用for+append循环重新赋值。"
|
||||||
cookies = chatbot.get_cookies()
|
cookies = chatbot.get_cookies()
|
||||||
# 备份一份History作为记录
|
# 备份一份History作为记录
|
||||||
cookies.update({'history': history})
|
cookies.update({"history": history})
|
||||||
# 解决插件锁定时的界面显示问题
|
# 解决插件锁定时的界面显示问题
|
||||||
if cookies.get('lock_plugin', None):
|
if cookies.get("lock_plugin", None):
|
||||||
label = cookies.get('llm_model', "") + " | " + "正在锁定插件" + cookies.get('lock_plugin', None)
|
label = (
|
||||||
|
cookies.get("llm_model", "")
|
||||||
|
+ " | "
|
||||||
|
+ "正在锁定插件"
|
||||||
|
+ cookies.get("lock_plugin", None)
|
||||||
|
)
|
||||||
chatbot_gr = gradio.update(value=chatbot, label=label)
|
chatbot_gr = gradio.update(value=chatbot, label=label)
|
||||||
if cookies.get('label', "") != label: cookies['label'] = label # 记住当前的label
|
if cookies.get("label", "") != label:
|
||||||
elif cookies.get('label', None):
|
cookies["label"] = label # 记住当前的label
|
||||||
chatbot_gr = gradio.update(value=chatbot, label=cookies.get('llm_model', ""))
|
elif cookies.get("label", None):
|
||||||
cookies['label'] = None # 清空label
|
chatbot_gr = gradio.update(value=chatbot, label=cookies.get("llm_model", ""))
|
||||||
|
cookies["label"] = None # 清空label
|
||||||
else:
|
else:
|
||||||
chatbot_gr = chatbot
|
chatbot_gr = chatbot
|
||||||
|
|
||||||
@ -146,7 +155,8 @@ def update_ui_lastest_msg(lastmsg, chatbot, history, delay=1): # 刷新界面
|
|||||||
"""
|
"""
|
||||||
刷新用户界面
|
刷新用户界面
|
||||||
"""
|
"""
|
||||||
if len(chatbot) == 0: chatbot.append(["update_ui_last_msg", lastmsg])
|
if len(chatbot) == 0:
|
||||||
|
chatbot.append(["update_ui_last_msg", lastmsg])
|
||||||
chatbot[-1] = list(chatbot[-1])
|
chatbot[-1] = list(chatbot[-1])
|
||||||
chatbot[-1][-1] = lastmsg
|
chatbot[-1][-1] = lastmsg
|
||||||
yield from update_ui(chatbot=chatbot, history=history)
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
@ -155,6 +165,7 @@ def update_ui_lastest_msg(lastmsg, chatbot, history, delay=1): # 刷新界面
|
|||||||
|
|
||||||
def trimmed_format_exc():
|
def trimmed_format_exc():
|
||||||
import os, traceback
|
import os, traceback
|
||||||
|
|
||||||
str = traceback.format_exc()
|
str = traceback.format_exc()
|
||||||
current_path = os.getcwd()
|
current_path = os.getcwd()
|
||||||
replace_path = "."
|
replace_path = "."
|
||||||
@ -194,19 +205,21 @@ def HotReload(f):
|
|||||||
最后,使用yield from语句返回重新加载过的函数,并在被装饰的函数上执行。
|
最后,使用yield from语句返回重新加载过的函数,并在被装饰的函数上执行。
|
||||||
最终,装饰器函数返回内部函数。这个内部函数可以将函数的原始定义更新为最新版本,并执行函数的新版本。
|
最终,装饰器函数返回内部函数。这个内部函数可以将函数的原始定义更新为最新版本,并执行函数的新版本。
|
||||||
"""
|
"""
|
||||||
if get_conf('PLUGIN_HOT_RELOAD'):
|
if get_conf("PLUGIN_HOT_RELOAD"):
|
||||||
|
|
||||||
@wraps(f)
|
@wraps(f)
|
||||||
def decorated(*args, **kwargs):
|
def decorated(*args, **kwargs):
|
||||||
fn_name = f.__name__
|
fn_name = f.__name__
|
||||||
f_hot_reload = getattr(importlib.reload(inspect.getmodule(f)), fn_name)
|
f_hot_reload = getattr(importlib.reload(inspect.getmodule(f)), fn_name)
|
||||||
yield from f_hot_reload(*args, **kwargs)
|
yield from f_hot_reload(*args, **kwargs)
|
||||||
|
|
||||||
return decorated
|
return decorated
|
||||||
else:
|
else:
|
||||||
return f
|
return f
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
========================================================================
|
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
|
||||||
第二部分
|
第二部分
|
||||||
其他小工具:
|
其他小工具:
|
||||||
- write_history_to_file: 将结果写入markdown文件中
|
- write_history_to_file: 将结果写入markdown文件中
|
||||||
@ -220,7 +233,7 @@ def HotReload(f):
|
|||||||
- clip_history: 当历史上下文过长时,自动截断
|
- clip_history: 当历史上下文过长时,自动截断
|
||||||
- get_conf: 获取设置
|
- get_conf: 获取设置
|
||||||
- select_api_key: 根据当前的模型类别,抽取可用的api-key
|
- select_api_key: 根据当前的模型类别,抽取可用的api-key
|
||||||
========================================================================
|
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -239,36 +252,40 @@ def get_reduce_token_percent(text):
|
|||||||
assert ratio > 0 and ratio < 1
|
assert ratio > 0 and ratio < 1
|
||||||
return ratio, str(int(current_tokens - max_limit))
|
return ratio, str(int(current_tokens - max_limit))
|
||||||
except:
|
except:
|
||||||
return 0.5, '不详'
|
return 0.5, "不详"
|
||||||
|
|
||||||
|
|
||||||
def write_history_to_file(history, file_basename=None, file_fullname=None, auto_caption=True):
|
def write_history_to_file(
|
||||||
|
history, file_basename=None, file_fullname=None, auto_caption=True
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
将对话记录history以Markdown格式写入文件中。如果没有指定文件名,则使用当前时间生成文件名。
|
将对话记录history以Markdown格式写入文件中。如果没有指定文件名,则使用当前时间生成文件名。
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
|
||||||
if file_fullname is None:
|
if file_fullname is None:
|
||||||
if file_basename is not None:
|
if file_basename is not None:
|
||||||
file_fullname = pj(get_log_folder(), file_basename)
|
file_fullname = pj(get_log_folder(), file_basename)
|
||||||
else:
|
else:
|
||||||
file_fullname = pj(get_log_folder(), f'GPT-Academic-{gen_time_str()}.md')
|
file_fullname = pj(get_log_folder(), f"GPT-Academic-{gen_time_str()}.md")
|
||||||
os.makedirs(os.path.dirname(file_fullname), exist_ok=True)
|
os.makedirs(os.path.dirname(file_fullname), exist_ok=True)
|
||||||
with open(file_fullname, 'w', encoding='utf8') as f:
|
with open(file_fullname, "w", encoding="utf8") as f:
|
||||||
f.write('# GPT-Academic Report\n')
|
f.write("# GPT-Academic Report\n")
|
||||||
for i, content in enumerate(history):
|
for i, content in enumerate(history):
|
||||||
try:
|
try:
|
||||||
if type(content) != str: content = str(content)
|
if type(content) != str:
|
||||||
|
content = str(content)
|
||||||
except:
|
except:
|
||||||
continue
|
continue
|
||||||
if i % 2 == 0 and auto_caption:
|
if i % 2 == 0 and auto_caption:
|
||||||
f.write('## ')
|
f.write("## ")
|
||||||
try:
|
try:
|
||||||
f.write(content)
|
f.write(content)
|
||||||
except:
|
except:
|
||||||
# remove everything that cannot be handled by utf8
|
# remove everything that cannot be handled by utf8
|
||||||
f.write(content.encode('utf-8', 'ignore').decode())
|
f.write(content.encode("utf-8", "ignore").decode())
|
||||||
f.write('\n\n')
|
f.write("\n\n")
|
||||||
res = os.path.abspath(file_fullname)
|
res = os.path.abspath(file_fullname)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@ -277,9 +294,9 @@ def regular_txt_to_markdown(text):
|
|||||||
"""
|
"""
|
||||||
将普通文本转换为Markdown格式的文本。
|
将普通文本转换为Markdown格式的文本。
|
||||||
"""
|
"""
|
||||||
text = text.replace('\n', '\n\n')
|
text = text.replace("\n", "\n\n")
|
||||||
text = text.replace('\n\n\n', '\n\n')
|
text = text.replace("\n\n\n", "\n\n")
|
||||||
text = text.replace('\n\n\n', '\n\n')
|
text = text.replace("\n\n\n", "\n\n")
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
@ -297,8 +314,9 @@ def find_free_port():
|
|||||||
"""
|
"""
|
||||||
import socket
|
import socket
|
||||||
from contextlib import closing
|
from contextlib import closing
|
||||||
|
|
||||||
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
|
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
|
||||||
s.bind(('', 0))
|
s.bind(("", 0))
|
||||||
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||||
return s.getsockname()[1]
|
return s.getsockname()[1]
|
||||||
|
|
||||||
@ -307,45 +325,48 @@ def extract_archive(file_path, dest_dir):
|
|||||||
import zipfile
|
import zipfile
|
||||||
import tarfile
|
import tarfile
|
||||||
import os
|
import os
|
||||||
|
|
||||||
# Get the file extension of the input file
|
# Get the file extension of the input file
|
||||||
file_extension = os.path.splitext(file_path)[1]
|
file_extension = os.path.splitext(file_path)[1]
|
||||||
|
|
||||||
# Extract the archive based on its extension
|
# Extract the archive based on its extension
|
||||||
if file_extension == '.zip':
|
if file_extension == ".zip":
|
||||||
with zipfile.ZipFile(file_path, 'r') as zipobj:
|
with zipfile.ZipFile(file_path, "r") as zipobj:
|
||||||
zipobj.extractall(path=dest_dir)
|
zipobj.extractall(path=dest_dir)
|
||||||
print("Successfully extracted zip archive to {}".format(dest_dir))
|
print("Successfully extracted zip archive to {}".format(dest_dir))
|
||||||
|
|
||||||
elif file_extension in ['.tar', '.gz', '.bz2']:
|
elif file_extension in [".tar", ".gz", ".bz2"]:
|
||||||
with tarfile.open(file_path, 'r:*') as tarobj:
|
with tarfile.open(file_path, "r:*") as tarobj:
|
||||||
tarobj.extractall(path=dest_dir)
|
tarobj.extractall(path=dest_dir)
|
||||||
print("Successfully extracted tar archive to {}".format(dest_dir))
|
print("Successfully extracted tar archive to {}".format(dest_dir))
|
||||||
|
|
||||||
# 第三方库,需要预先pip install rarfile
|
# 第三方库,需要预先pip install rarfile
|
||||||
# 此外,Windows上还需要安装winrar软件,配置其Path环境变量,如"C:\Program Files\WinRAR"才可以
|
# 此外,Windows上还需要安装winrar软件,配置其Path环境变量,如"C:\Program Files\WinRAR"才可以
|
||||||
elif file_extension == '.rar':
|
elif file_extension == ".rar":
|
||||||
try:
|
try:
|
||||||
import rarfile
|
import rarfile
|
||||||
|
|
||||||
with rarfile.RarFile(file_path) as rf:
|
with rarfile.RarFile(file_path) as rf:
|
||||||
rf.extractall(path=dest_dir)
|
rf.extractall(path=dest_dir)
|
||||||
print("Successfully extracted rar archive to {}".format(dest_dir))
|
print("Successfully extracted rar archive to {}".format(dest_dir))
|
||||||
except:
|
except:
|
||||||
print("Rar format requires additional dependencies to install")
|
print("Rar format requires additional dependencies to install")
|
||||||
return '\n\n解压失败! 需要安装pip install rarfile来解压rar文件。建议:使用zip压缩格式。'
|
return "\n\n解压失败! 需要安装pip install rarfile来解压rar文件。建议:使用zip压缩格式。"
|
||||||
|
|
||||||
# 第三方库,需要预先pip install py7zr
|
# 第三方库,需要预先pip install py7zr
|
||||||
elif file_extension == '.7z':
|
elif file_extension == ".7z":
|
||||||
try:
|
try:
|
||||||
import py7zr
|
import py7zr
|
||||||
with py7zr.SevenZipFile(file_path, mode='r') as f:
|
|
||||||
|
with py7zr.SevenZipFile(file_path, mode="r") as f:
|
||||||
f.extractall(path=dest_dir)
|
f.extractall(path=dest_dir)
|
||||||
print("Successfully extracted 7z archive to {}".format(dest_dir))
|
print("Successfully extracted 7z archive to {}".format(dest_dir))
|
||||||
except:
|
except:
|
||||||
print("7z format requires additional dependencies to install")
|
print("7z format requires additional dependencies to install")
|
||||||
return '\n\n解压失败! 需要安装pip install py7zr来解压7z文件'
|
return "\n\n解压失败! 需要安装pip install py7zr来解压7z文件"
|
||||||
else:
|
else:
|
||||||
return ''
|
return ""
|
||||||
return ''
|
return ""
|
||||||
|
|
||||||
|
|
||||||
def find_recent_files(directory):
|
def find_recent_files(directory):
|
||||||
@ -355,6 +376,7 @@ def find_recent_files(directory):
|
|||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
one_minute_ago = current_time - 60
|
one_minute_ago = current_time - 60
|
||||||
recent_files = []
|
recent_files = []
|
||||||
@ -362,7 +384,7 @@ def find_recent_files(directory):
|
|||||||
os.makedirs(directory, exist_ok=True)
|
os.makedirs(directory, exist_ok=True)
|
||||||
for filename in os.listdir(directory):
|
for filename in os.listdir(directory):
|
||||||
file_path = pj(directory, filename)
|
file_path = pj(directory, filename)
|
||||||
if file_path.endswith('.log'):
|
if file_path.endswith(".log"):
|
||||||
continue
|
continue
|
||||||
created_time = os.path.getmtime(file_path)
|
created_time = os.path.getmtime(file_path)
|
||||||
if created_time >= one_minute_ago:
|
if created_time >= one_minute_ago:
|
||||||
@ -388,49 +410,53 @@ def file_already_in_downloadzone(file, user_path):
|
|||||||
def promote_file_to_downloadzone(file, rename_file=None, chatbot=None):
|
def promote_file_to_downloadzone(file, rename_file=None, chatbot=None):
|
||||||
# 将文件复制一份到下载区
|
# 将文件复制一份到下载区
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
if chatbot is not None:
|
if chatbot is not None:
|
||||||
user_name = get_user(chatbot)
|
user_name = get_user(chatbot)
|
||||||
else:
|
else:
|
||||||
user_name = default_user_name
|
user_name = default_user_name
|
||||||
if not os.path.exists(file):
|
if not os.path.exists(file):
|
||||||
raise FileNotFoundError(f'文件{file}不存在')
|
raise FileNotFoundError(f"文件{file}不存在")
|
||||||
user_path = get_log_folder(user_name, plugin_name=None)
|
user_path = get_log_folder(user_name, plugin_name=None)
|
||||||
if file_already_in_downloadzone(file, user_path):
|
if file_already_in_downloadzone(file, user_path):
|
||||||
new_path = file
|
new_path = file
|
||||||
else:
|
else:
|
||||||
user_path = get_log_folder(user_name, plugin_name='downloadzone')
|
user_path = get_log_folder(user_name, plugin_name="downloadzone")
|
||||||
if rename_file is None: rename_file = f'{gen_time_str()}-{os.path.basename(file)}'
|
if rename_file is None:
|
||||||
|
rename_file = f"{gen_time_str()}-{os.path.basename(file)}"
|
||||||
new_path = pj(user_path, rename_file)
|
new_path = pj(user_path, rename_file)
|
||||||
# 如果已经存在,先删除
|
# 如果已经存在,先删除
|
||||||
if os.path.exists(new_path) and not os.path.samefile(new_path, file): os.remove(new_path)
|
if os.path.exists(new_path) and not os.path.samefile(new_path, file):
|
||||||
|
os.remove(new_path)
|
||||||
# 把文件复制过去
|
# 把文件复制过去
|
||||||
if not os.path.exists(new_path): shutil.copyfile(file, new_path)
|
if not os.path.exists(new_path):
|
||||||
|
shutil.copyfile(file, new_path)
|
||||||
# 将文件添加到chatbot cookie中
|
# 将文件添加到chatbot cookie中
|
||||||
if chatbot is not None:
|
if chatbot is not None:
|
||||||
if 'files_to_promote' in chatbot._cookies:
|
if "files_to_promote" in chatbot._cookies:
|
||||||
current = chatbot._cookies['files_to_promote']
|
current = chatbot._cookies["files_to_promote"]
|
||||||
else:
|
else:
|
||||||
current = []
|
current = []
|
||||||
if new_path not in current: # 避免把同一个文件添加多次
|
if new_path not in current: # 避免把同一个文件添加多次
|
||||||
chatbot._cookies.update({'files_to_promote': [new_path] + current})
|
chatbot._cookies.update({"files_to_promote": [new_path] + current})
|
||||||
return new_path
|
return new_path
|
||||||
|
|
||||||
|
|
||||||
def disable_auto_promotion(chatbot):
|
def disable_auto_promotion(chatbot):
|
||||||
chatbot._cookies.update({'files_to_promote': []})
|
chatbot._cookies.update({"files_to_promote": []})
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
def del_outdated_uploads(outdate_time_seconds, target_path_base=None):
|
def del_outdated_uploads(outdate_time_seconds, target_path_base=None):
|
||||||
if target_path_base is None:
|
if target_path_base is None:
|
||||||
user_upload_dir = get_conf('PATH_PRIVATE_UPLOAD')
|
user_upload_dir = get_conf("PATH_PRIVATE_UPLOAD")
|
||||||
else:
|
else:
|
||||||
user_upload_dir = target_path_base
|
user_upload_dir = target_path_base
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
one_hour_ago = current_time - outdate_time_seconds
|
one_hour_ago = current_time - outdate_time_seconds
|
||||||
# Get a list of all subdirectories in the user_upload_dir folder
|
# Get a list of all subdirectories in the user_upload_dir folder
|
||||||
# Remove subdirectories that are older than one hour
|
# Remove subdirectories that are older than one hour
|
||||||
for subdirectory in glob.glob(f'{user_upload_dir}/*'):
|
for subdirectory in glob.glob(f"{user_upload_dir}/*"):
|
||||||
subdirectory_time = os.path.getmtime(subdirectory)
|
subdirectory_time = os.path.getmtime(subdirectory)
|
||||||
if subdirectory_time < one_hour_ago:
|
if subdirectory_time < one_hour_ago:
|
||||||
try:
|
try:
|
||||||
@ -447,8 +473,8 @@ def html_local_file(file):
|
|||||||
return file
|
return file
|
||||||
|
|
||||||
|
|
||||||
def html_local_img(__file, layout='left', max_width=None, max_height=None, md=True):
|
def html_local_img(__file, layout="left", max_width=None, max_height=None, md=True):
|
||||||
style = ''
|
style = ""
|
||||||
if max_width is not None:
|
if max_width is not None:
|
||||||
style += f"max-width: {max_width};"
|
style += f"max-width: {max_width};"
|
||||||
if max_height is not None:
|
if max_height is not None:
|
||||||
@ -456,20 +482,23 @@ def html_local_img(__file, layout='left', max_width=None, max_height=None, md=Tr
|
|||||||
__file = html_local_file(__file)
|
__file = html_local_file(__file)
|
||||||
a = f'<div align="{layout}"><img src="{__file}" style="{style}"></div>'
|
a = f'<div align="{layout}"><img src="{__file}" style="{style}"></div>'
|
||||||
if md:
|
if md:
|
||||||
a = f''
|
a = f""
|
||||||
return a
|
return a
|
||||||
|
|
||||||
|
|
||||||
def file_manifest_filter_type(file_list, filter_: list = None):
|
def file_manifest_filter_type(file_list, filter_: list = None):
|
||||||
new_list = []
|
new_list = []
|
||||||
if not filter_: filter_ = ['png', 'jpg', 'jpeg']
|
if not filter_:
|
||||||
|
filter_ = ["png", "jpg", "jpeg"]
|
||||||
for file in file_list:
|
for file in file_list:
|
||||||
if str(os.path.basename(file)).split('.')[-1] in filter_:
|
if str(os.path.basename(file)).split(".")[-1] in filter_:
|
||||||
new_list.append(html_local_img(file, md=False))
|
new_list.append(html_local_img(file, md=False))
|
||||||
else:
|
else:
|
||||||
new_list.append(file)
|
new_list.append(file)
|
||||||
return new_list
|
return new_list
|
||||||
|
|
||||||
def to_markdown_tabs(head: list, tabs: list, alignment=':---:', column=False):
|
|
||||||
|
def to_markdown_tabs(head: list, tabs: list, alignment=":---:", column=False):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
head: 表头:[]
|
head: 表头:[]
|
||||||
@ -487,17 +516,20 @@ def to_markdown_tabs(head: list, tabs: list, alignment=':---:', column=False):
|
|||||||
max_len = max(len(column) for column in transposed_tabs)
|
max_len = max(len(column) for column in transposed_tabs)
|
||||||
|
|
||||||
tab_format = "| %s "
|
tab_format = "| %s "
|
||||||
tabs_list = "".join([tab_format % i for i in head]) + '|\n'
|
tabs_list = "".join([tab_format % i for i in head]) + "|\n"
|
||||||
tabs_list += "".join([tab_format % alignment for i in head]) + '|\n'
|
tabs_list += "".join([tab_format % alignment for i in head]) + "|\n"
|
||||||
|
|
||||||
for i in range(max_len):
|
for i in range(max_len):
|
||||||
row_data = [tab[i] if i < len(tab) else '' for tab in transposed_tabs]
|
row_data = [tab[i] if i < len(tab) else "" for tab in transposed_tabs]
|
||||||
row_data = file_manifest_filter_type(row_data, filter_=None)
|
row_data = file_manifest_filter_type(row_data, filter_=None)
|
||||||
tabs_list += "".join([tab_format % i for i in row_data]) + '|\n'
|
tabs_list += "".join([tab_format % i for i in row_data]) + "|\n"
|
||||||
|
|
||||||
return tabs_list
|
return tabs_list
|
||||||
|
|
||||||
def on_file_uploaded(request: gradio.Request, files, chatbot, txt, txt2, checkboxes, cookies):
|
|
||||||
|
def on_file_uploaded(
|
||||||
|
request: gradio.Request, files, chatbot, txt, txt2, checkboxes, cookies
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
当文件被上传时的回调函数
|
当文件被上传时的回调函数
|
||||||
"""
|
"""
|
||||||
@ -515,94 +547,118 @@ def on_file_uploaded(request: gradio.Request, files, chatbot, txt, txt2, checkbo
|
|||||||
del_outdated_uploads(outdate_time_seconds, get_upload_folder(user_name))
|
del_outdated_uploads(outdate_time_seconds, get_upload_folder(user_name))
|
||||||
|
|
||||||
# 逐个文件转移到目标路径
|
# 逐个文件转移到目标路径
|
||||||
upload_msg = ''
|
upload_msg = ""
|
||||||
for file in files:
|
for file in files:
|
||||||
file_origin_name = os.path.basename(file.orig_name)
|
file_origin_name = os.path.basename(file.orig_name)
|
||||||
this_file_path = pj(target_path_base, file_origin_name)
|
this_file_path = pj(target_path_base, file_origin_name)
|
||||||
shutil.move(file.name, this_file_path)
|
shutil.move(file.name, this_file_path)
|
||||||
upload_msg += extract_archive(file_path=this_file_path, dest_dir=this_file_path + '.extract')
|
upload_msg += extract_archive(
|
||||||
|
file_path=this_file_path, dest_dir=this_file_path + ".extract"
|
||||||
|
)
|
||||||
|
|
||||||
# 整理文件集合 输出消息
|
# 整理文件集合 输出消息
|
||||||
moved_files = [fp for fp in glob.glob(f'{target_path_base}/**/*', recursive=True)]
|
moved_files = [fp for fp in glob.glob(f"{target_path_base}/**/*", recursive=True)]
|
||||||
moved_files_str = to_markdown_tabs(head=['文件'], tabs=[moved_files])
|
moved_files_str = to_markdown_tabs(head=["文件"], tabs=[moved_files])
|
||||||
chatbot.append(['我上传了文件,请查收',
|
chatbot.append(
|
||||||
f'[Local Message] 收到以下文件: \n\n{moved_files_str}' +
|
[
|
||||||
f'\n\n调用路径参数已自动修正到: \n\n{txt}' +
|
"我上传了文件,请查收",
|
||||||
f'\n\n现在您点击任意函数插件时,以上文件将被作为输入参数' + upload_msg])
|
f"[Local Message] 收到以下文件: \n\n{moved_files_str}"
|
||||||
|
+ f"\n\n调用路径参数已自动修正到: \n\n{txt}"
|
||||||
|
+ f"\n\n现在您点击任意函数插件时,以上文件将被作为输入参数"
|
||||||
|
+ upload_msg,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
txt, txt2 = target_path_base, ""
|
txt, txt2 = target_path_base, ""
|
||||||
if "浮动输入区" in checkboxes:
|
if "浮动输入区" in checkboxes:
|
||||||
txt, txt2 = txt2, txt
|
txt, txt2 = txt2, txt
|
||||||
|
|
||||||
# 记录近期文件
|
# 记录近期文件
|
||||||
cookies.update({
|
cookies.update(
|
||||||
'most_recent_uploaded': {
|
{
|
||||||
'path': target_path_base,
|
"most_recent_uploaded": {
|
||||||
'time': time.time(),
|
"path": target_path_base,
|
||||||
'time_str': time_tag
|
"time": time.time(),
|
||||||
}})
|
"time_str": time_tag,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
return chatbot, txt, txt2, cookies
|
return chatbot, txt, txt2, cookies
|
||||||
|
|
||||||
|
|
||||||
def on_report_generated(cookies, files, chatbot):
|
def on_report_generated(cookies, files, chatbot):
|
||||||
# from toolbox import find_recent_files
|
# from toolbox import find_recent_files
|
||||||
# PATH_LOGGING = get_conf('PATH_LOGGING')
|
# PATH_LOGGING = get_conf('PATH_LOGGING')
|
||||||
if 'files_to_promote' in cookies:
|
if "files_to_promote" in cookies:
|
||||||
report_files = cookies['files_to_promote']
|
report_files = cookies["files_to_promote"]
|
||||||
cookies.pop('files_to_promote')
|
cookies.pop("files_to_promote")
|
||||||
else:
|
else:
|
||||||
report_files = []
|
report_files = []
|
||||||
# report_files = find_recent_files(PATH_LOGGING)
|
# report_files = find_recent_files(PATH_LOGGING)
|
||||||
if len(report_files) == 0:
|
if len(report_files) == 0:
|
||||||
return cookies, None, chatbot
|
return cookies, None, chatbot
|
||||||
# files.extend(report_files)
|
# files.extend(report_files)
|
||||||
file_links = ''
|
file_links = ""
|
||||||
for f in report_files: file_links += f'<br/><a href="file={os.path.abspath(f)}" target="_blank">{f}</a>'
|
for f in report_files:
|
||||||
chatbot.append(['报告如何远程获取?', f'报告已经添加到右侧“文件上传区”(可能处于折叠状态),请查收。{file_links}'])
|
file_links += (
|
||||||
|
f'<br/><a href="file={os.path.abspath(f)}" target="_blank">{f}</a>'
|
||||||
|
)
|
||||||
|
chatbot.append(["报告如何远程获取?", f"报告已经添加到右侧“文件上传区”(可能处于折叠状态),请查收。{file_links}"])
|
||||||
return cookies, report_files, chatbot
|
return cookies, report_files, chatbot
|
||||||
|
|
||||||
|
|
||||||
def load_chat_cookies():
|
def load_chat_cookies():
|
||||||
API_KEY, LLM_MODEL, AZURE_API_KEY = get_conf('API_KEY', 'LLM_MODEL', 'AZURE_API_KEY')
|
API_KEY, LLM_MODEL, AZURE_API_KEY = get_conf(
|
||||||
AZURE_CFG_ARRAY, NUM_CUSTOM_BASIC_BTN = get_conf('AZURE_CFG_ARRAY', 'NUM_CUSTOM_BASIC_BTN')
|
"API_KEY", "LLM_MODEL", "AZURE_API_KEY"
|
||||||
|
)
|
||||||
|
AZURE_CFG_ARRAY, NUM_CUSTOM_BASIC_BTN = get_conf(
|
||||||
|
"AZURE_CFG_ARRAY", "NUM_CUSTOM_BASIC_BTN"
|
||||||
|
)
|
||||||
|
|
||||||
# deal with azure openai key
|
# deal with azure openai key
|
||||||
if is_any_api_key(AZURE_API_KEY):
|
if is_any_api_key(AZURE_API_KEY):
|
||||||
if is_any_api_key(API_KEY):
|
if is_any_api_key(API_KEY):
|
||||||
API_KEY = API_KEY + ',' + AZURE_API_KEY
|
API_KEY = API_KEY + "," + AZURE_API_KEY
|
||||||
else:
|
else:
|
||||||
API_KEY = AZURE_API_KEY
|
API_KEY = AZURE_API_KEY
|
||||||
if len(AZURE_CFG_ARRAY) > 0:
|
if len(AZURE_CFG_ARRAY) > 0:
|
||||||
for azure_model_name, azure_cfg_dict in AZURE_CFG_ARRAY.items():
|
for azure_model_name, azure_cfg_dict in AZURE_CFG_ARRAY.items():
|
||||||
if not azure_model_name.startswith('azure'):
|
if not azure_model_name.startswith("azure"):
|
||||||
raise ValueError("AZURE_CFG_ARRAY中配置的模型必须以azure开头")
|
raise ValueError("AZURE_CFG_ARRAY中配置的模型必须以azure开头")
|
||||||
AZURE_API_KEY_ = azure_cfg_dict["AZURE_API_KEY"]
|
AZURE_API_KEY_ = azure_cfg_dict["AZURE_API_KEY"]
|
||||||
if is_any_api_key(AZURE_API_KEY_):
|
if is_any_api_key(AZURE_API_KEY_):
|
||||||
if is_any_api_key(API_KEY):
|
if is_any_api_key(API_KEY):
|
||||||
API_KEY = API_KEY + ',' + AZURE_API_KEY_
|
API_KEY = API_KEY + "," + AZURE_API_KEY_
|
||||||
else:
|
else:
|
||||||
API_KEY = AZURE_API_KEY_
|
API_KEY = AZURE_API_KEY_
|
||||||
|
|
||||||
customize_fn_overwrite_ = {}
|
customize_fn_overwrite_ = {}
|
||||||
for k in range(NUM_CUSTOM_BASIC_BTN):
|
for k in range(NUM_CUSTOM_BASIC_BTN):
|
||||||
customize_fn_overwrite_.update({
|
customize_fn_overwrite_.update(
|
||||||
"自定义按钮" + str(k+1):{
|
{
|
||||||
|
"自定义按钮"
|
||||||
|
+ str(k + 1): {
|
||||||
"Title": r"",
|
"Title": r"",
|
||||||
"Prefix": r"请在自定义菜单中定义提示词前缀.",
|
"Prefix": r"请在自定义菜单中定义提示词前缀.",
|
||||||
"Suffix": r"请在自定义菜单中定义提示词后缀",
|
"Suffix": r"请在自定义菜单中定义提示词后缀",
|
||||||
}
|
}
|
||||||
})
|
}
|
||||||
return {'api_key': API_KEY, 'llm_model': LLM_MODEL, 'customize_fn_overwrite': customize_fn_overwrite_}
|
)
|
||||||
|
return {
|
||||||
|
"api_key": API_KEY,
|
||||||
|
"llm_model": LLM_MODEL,
|
||||||
|
"customize_fn_overwrite": customize_fn_overwrite_,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def clear_line_break(txt):
|
def clear_line_break(txt):
|
||||||
txt = txt.replace('\n', ' ')
|
txt = txt.replace("\n", " ")
|
||||||
txt = txt.replace(' ', ' ')
|
txt = txt.replace(" ", " ")
|
||||||
txt = txt.replace(' ', ' ')
|
txt = txt.replace(" ", " ")
|
||||||
return txt
|
return txt
|
||||||
|
|
||||||
|
|
||||||
class DummyWith():
|
class DummyWith:
|
||||||
"""
|
"""
|
||||||
这段代码定义了一个名为DummyWith的空上下文管理器,
|
这段代码定义了一个名为DummyWith的空上下文管理器,
|
||||||
它的作用是……额……就是不起作用,即在代码结构不变得情况下取代其他的上下文管理器。
|
它的作用是……额……就是不起作用,即在代码结构不变得情况下取代其他的上下文管理器。
|
||||||
@ -626,32 +682,45 @@ def run_gradio_in_subpath(demo, auth, port, custom_path):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def is_path_legal(path: str) -> bool:
|
def is_path_legal(path: str) -> bool:
|
||||||
'''
|
"""
|
||||||
check path for sub url
|
check path for sub url
|
||||||
path: path to check
|
path: path to check
|
||||||
return value: do sub url wrap
|
return value: do sub url wrap
|
||||||
'''
|
"""
|
||||||
if path == "/": return True
|
if path == "/":
|
||||||
|
return True
|
||||||
if len(path) == 0:
|
if len(path) == 0:
|
||||||
print("ilegal custom path: {}\npath must not be empty\ndeploy on root url".format(path))
|
print(
|
||||||
|
"ilegal custom path: {}\npath must not be empty\ndeploy on root url".format(
|
||||||
|
path
|
||||||
|
)
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
if path[0] == '/':
|
if path[0] == "/":
|
||||||
if path[1] != '/':
|
if path[1] != "/":
|
||||||
print("deploy on sub-path {}".format(path))
|
print("deploy on sub-path {}".format(path))
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
print("ilegal custom path: {}\npath should begin with \'/\'\ndeploy on root url".format(path))
|
print(
|
||||||
|
"ilegal custom path: {}\npath should begin with '/'\ndeploy on root url".format(
|
||||||
|
path
|
||||||
|
)
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if not is_path_legal(custom_path): raise RuntimeError('Ilegal custom path')
|
if not is_path_legal(custom_path):
|
||||||
|
raise RuntimeError("Ilegal custom path")
|
||||||
import uvicorn
|
import uvicorn
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
if custom_path != "/":
|
if custom_path != "/":
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
def read_main():
|
def read_main():
|
||||||
return {"message": f"Gradio is running at: {custom_path}"}
|
return {"message": f"Gradio is running at: {custom_path}"}
|
||||||
|
|
||||||
app = gr.mount_gradio_app(app, demo, path=custom_path)
|
app = gr.mount_gradio_app(app, demo, path=custom_path)
|
||||||
uvicorn.run(app, host="0.0.0.0", port=port) # , auth=auth
|
uvicorn.run(app, host="0.0.0.0", port=port) # , auth=auth
|
||||||
|
|
||||||
@ -667,13 +736,18 @@ def clip_history(inputs, history, tokenizer, max_token_limit):
|
|||||||
"""
|
"""
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from request_llms.bridge_all import model_info
|
from request_llms.bridge_all import model_info
|
||||||
|
|
||||||
def get_token_num(txt):
|
def get_token_num(txt):
|
||||||
return len(tokenizer.encode(txt, disallowed_special=()))
|
return len(tokenizer.encode(txt, disallowed_special=()))
|
||||||
|
|
||||||
input_token_num = get_token_num(inputs)
|
input_token_num = get_token_num(inputs)
|
||||||
|
|
||||||
if max_token_limit < 5000: output_token_expect = 256 # 4k & 2k models
|
if max_token_limit < 5000:
|
||||||
elif max_token_limit < 9000: output_token_expect = 512 # 8k models
|
output_token_expect = 256 # 4k & 2k models
|
||||||
else: output_token_expect = 1024 # 16k & 32k models
|
elif max_token_limit < 9000:
|
||||||
|
output_token_expect = 512 # 8k models
|
||||||
|
else:
|
||||||
|
output_token_expect = 1024 # 16k & 32k models
|
||||||
|
|
||||||
if input_token_num < max_token_limit * 3 / 4:
|
if input_token_num < max_token_limit * 3 / 4:
|
||||||
# 当输入部分的token占比小于限制的3/4时,裁剪时
|
# 当输入部分的token占比小于限制的3/4时,裁剪时
|
||||||
@ -690,9 +764,9 @@ def clip_history(inputs, history, tokenizer, max_token_limit):
|
|||||||
history = []
|
history = []
|
||||||
return history
|
return history
|
||||||
|
|
||||||
everything = ['']
|
everything = [""]
|
||||||
everything.extend(history)
|
everything.extend(history)
|
||||||
n_token = get_token_num('\n'.join(everything))
|
n_token = get_token_num("\n".join(everything))
|
||||||
everything_token = [get_token_num(e) for e in everything]
|
everything_token = [get_token_num(e) for e in everything]
|
||||||
|
|
||||||
# 截断时的颗粒度
|
# 截断时的颗粒度
|
||||||
@ -702,29 +776,32 @@ def clip_history(inputs, history, tokenizer, max_token_limit):
|
|||||||
where = np.argmax(everything_token)
|
where = np.argmax(everything_token)
|
||||||
encoded = tokenizer.encode(everything[where], disallowed_special=())
|
encoded = tokenizer.encode(everything[where], disallowed_special=())
|
||||||
clipped_encoded = encoded[: len(encoded) - delta]
|
clipped_encoded = encoded[: len(encoded) - delta]
|
||||||
everything[where] = tokenizer.decode(clipped_encoded)[:-1] # -1 to remove the may-be illegal char
|
everything[where] = tokenizer.decode(clipped_encoded)[
|
||||||
|
:-1
|
||||||
|
] # -1 to remove the may-be illegal char
|
||||||
everything_token[where] = get_token_num(everything[where])
|
everything_token[where] = get_token_num(everything[where])
|
||||||
n_token = get_token_num('\n'.join(everything))
|
n_token = get_token_num("\n".join(everything))
|
||||||
|
|
||||||
history = everything[1:]
|
history = everything[1:]
|
||||||
return history
|
return history
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
========================================================================
|
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
|
||||||
第三部分
|
第三部分
|
||||||
其他小工具:
|
其他小工具:
|
||||||
- zip_folder: 把某个路径下所有文件压缩,然后转移到指定的另一个路径中(gpt写的)
|
- zip_folder: 把某个路径下所有文件压缩,然后转移到指定的另一个路径中(gpt写的)
|
||||||
- gen_time_str: 生成时间戳
|
- gen_time_str: 生成时间戳
|
||||||
- ProxyNetworkActivate: 临时地启动代理网络(如果有)
|
- ProxyNetworkActivate: 临时地启动代理网络(如果有)
|
||||||
- objdump/objload: 快捷的调试函数
|
- objdump/objload: 快捷的调试函数
|
||||||
========================================================================
|
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def zip_folder(source_folder, dest_folder, zip_name):
|
def zip_folder(source_folder, dest_folder, zip_name):
|
||||||
import zipfile
|
import zipfile
|
||||||
import os
|
import os
|
||||||
|
|
||||||
# Make sure the source folder exists
|
# Make sure the source folder exists
|
||||||
if not os.path.exists(source_folder):
|
if not os.path.exists(source_folder):
|
||||||
print(f"{source_folder} does not exist")
|
print(f"{source_folder} does not exist")
|
||||||
@ -739,7 +816,7 @@ def zip_folder(source_folder, dest_folder, zip_name):
|
|||||||
zip_file = pj(dest_folder, zip_name)
|
zip_file = pj(dest_folder, zip_name)
|
||||||
|
|
||||||
# Create a ZipFile object
|
# Create a ZipFile object
|
||||||
with zipfile.ZipFile(zip_file, 'w', zipfile.ZIP_DEFLATED) as zipf:
|
with zipfile.ZipFile(zip_file, "w", zipfile.ZIP_DEFLATED) as zipf:
|
||||||
# Walk through the source folder and add files to the zip file
|
# Walk through the source folder and add files to the zip file
|
||||||
for foldername, subfolders, filenames in os.walk(source_folder):
|
for foldername, subfolders, filenames in os.walk(source_folder):
|
||||||
for filename in filenames:
|
for filename in filenames:
|
||||||
@ -756,29 +833,33 @@ def zip_folder(source_folder, dest_folder, zip_name):
|
|||||||
|
|
||||||
def zip_result(folder):
|
def zip_result(folder):
|
||||||
t = gen_time_str()
|
t = gen_time_str()
|
||||||
zip_folder(folder, get_log_folder(), f'{t}-result.zip')
|
zip_folder(folder, get_log_folder(), f"{t}-result.zip")
|
||||||
return pj(get_log_folder(), f'{t}-result.zip')
|
return pj(get_log_folder(), f"{t}-result.zip")
|
||||||
|
|
||||||
|
|
||||||
def gen_time_str():
|
def gen_time_str():
|
||||||
import time
|
import time
|
||||||
|
|
||||||
return time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
|
return time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
|
||||||
|
|
||||||
|
|
||||||
def get_log_folder(user=default_user_name, plugin_name='shared'):
|
def get_log_folder(user=default_user_name, plugin_name="shared"):
|
||||||
if user is None: user = default_user_name
|
if user is None:
|
||||||
PATH_LOGGING = get_conf('PATH_LOGGING')
|
user = default_user_name
|
||||||
|
PATH_LOGGING = get_conf("PATH_LOGGING")
|
||||||
if plugin_name is None:
|
if plugin_name is None:
|
||||||
_dir = pj(PATH_LOGGING, user)
|
_dir = pj(PATH_LOGGING, user)
|
||||||
else:
|
else:
|
||||||
_dir = pj(PATH_LOGGING, user, plugin_name)
|
_dir = pj(PATH_LOGGING, user, plugin_name)
|
||||||
if not os.path.exists(_dir): os.makedirs(_dir)
|
if not os.path.exists(_dir):
|
||||||
|
os.makedirs(_dir)
|
||||||
return _dir
|
return _dir
|
||||||
|
|
||||||
|
|
||||||
def get_upload_folder(user=default_user_name, tag=None):
|
def get_upload_folder(user=default_user_name, tag=None):
|
||||||
PATH_PRIVATE_UPLOAD = get_conf('PATH_PRIVATE_UPLOAD')
|
PATH_PRIVATE_UPLOAD = get_conf("PATH_PRIVATE_UPLOAD")
|
||||||
if user is None: user = default_user_name
|
if user is None:
|
||||||
|
user = default_user_name
|
||||||
if tag is None or len(tag) == 0:
|
if tag is None or len(tag) == 0:
|
||||||
target_path_base = pj(PATH_PRIVATE_UPLOAD, user)
|
target_path_base = pj(PATH_PRIVATE_UPLOAD, user)
|
||||||
else:
|
else:
|
||||||
@ -787,9 +868,9 @@ def get_upload_folder(user=default_user_name, tag=None):
|
|||||||
|
|
||||||
|
|
||||||
def is_the_upload_folder(string):
|
def is_the_upload_folder(string):
|
||||||
PATH_PRIVATE_UPLOAD = get_conf('PATH_PRIVATE_UPLOAD')
|
PATH_PRIVATE_UPLOAD = get_conf("PATH_PRIVATE_UPLOAD")
|
||||||
pattern = r'^PATH_PRIVATE_UPLOAD[\\/][A-Za-z0-9_-]+[\\/]\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}$'
|
pattern = r"^PATH_PRIVATE_UPLOAD[\\/][A-Za-z0-9_-]+[\\/]\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}$"
|
||||||
pattern = pattern.replace('PATH_PRIVATE_UPLOAD', PATH_PRIVATE_UPLOAD)
|
pattern = pattern.replace("PATH_PRIVATE_UPLOAD", PATH_PRIVATE_UPLOAD)
|
||||||
if re.match(pattern, string):
|
if re.match(pattern, string):
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
@ -797,10 +878,10 @@ def is_the_upload_folder(string):
|
|||||||
|
|
||||||
|
|
||||||
def get_user(chatbotwithcookies):
|
def get_user(chatbotwithcookies):
|
||||||
return chatbotwithcookies._cookies.get('user_name', default_user_name)
|
return chatbotwithcookies._cookies.get("user_name", default_user_name)
|
||||||
|
|
||||||
|
|
||||||
class ProxyNetworkActivate():
|
class ProxyNetworkActivate:
|
||||||
"""
|
"""
|
||||||
这段代码定义了一个名为ProxyNetworkActivate的空上下文管理器, 用于给一小段代码上代理
|
这段代码定义了一个名为ProxyNetworkActivate的空上下文管理器, 用于给一小段代码上代理
|
||||||
"""
|
"""
|
||||||
@ -813,38 +894,48 @@ class ProxyNetworkActivate():
|
|||||||
else:
|
else:
|
||||||
# 给定了task, 我们检查一下
|
# 给定了task, 我们检查一下
|
||||||
from toolbox import get_conf
|
from toolbox import get_conf
|
||||||
WHEN_TO_USE_PROXY = get_conf('WHEN_TO_USE_PROXY')
|
|
||||||
self.valid = (task in WHEN_TO_USE_PROXY)
|
WHEN_TO_USE_PROXY = get_conf("WHEN_TO_USE_PROXY")
|
||||||
|
self.valid = task in WHEN_TO_USE_PROXY
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
if not self.valid: return self
|
if not self.valid:
|
||||||
|
return self
|
||||||
from toolbox import get_conf
|
from toolbox import get_conf
|
||||||
proxies = get_conf('proxies')
|
|
||||||
if 'no_proxy' in os.environ: os.environ.pop('no_proxy')
|
proxies = get_conf("proxies")
|
||||||
|
if "no_proxy" in os.environ:
|
||||||
|
os.environ.pop("no_proxy")
|
||||||
if proxies is not None:
|
if proxies is not None:
|
||||||
if 'http' in proxies: os.environ['HTTP_PROXY'] = proxies['http']
|
if "http" in proxies:
|
||||||
if 'https' in proxies: os.environ['HTTPS_PROXY'] = proxies['https']
|
os.environ["HTTP_PROXY"] = proxies["http"]
|
||||||
|
if "https" in proxies:
|
||||||
|
os.environ["HTTPS_PROXY"] = proxies["https"]
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, traceback):
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
os.environ['no_proxy'] = '*'
|
os.environ["no_proxy"] = "*"
|
||||||
if 'HTTP_PROXY' in os.environ: os.environ.pop('HTTP_PROXY')
|
if "HTTP_PROXY" in os.environ:
|
||||||
if 'HTTPS_PROXY' in os.environ: os.environ.pop('HTTPS_PROXY')
|
os.environ.pop("HTTP_PROXY")
|
||||||
|
if "HTTPS_PROXY" in os.environ:
|
||||||
|
os.environ.pop("HTTPS_PROXY")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
def objdump(obj, file='objdump.tmp'):
|
def objdump(obj, file="objdump.tmp"):
|
||||||
import pickle
|
import pickle
|
||||||
with open(file, 'wb+') as f:
|
|
||||||
|
with open(file, "wb+") as f:
|
||||||
pickle.dump(obj, f)
|
pickle.dump(obj, f)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
def objload(file='objdump.tmp'):
|
def objload(file="objdump.tmp"):
|
||||||
import pickle, os
|
import pickle, os
|
||||||
|
|
||||||
if not os.path.exists(file):
|
if not os.path.exists(file):
|
||||||
return
|
return
|
||||||
with open(file, 'rb') as f:
|
with open(file, "rb") as f:
|
||||||
return pickle.load(f)
|
return pickle.load(f)
|
||||||
|
|
||||||
|
|
||||||
@ -863,22 +954,25 @@ def Singleton(cls):
|
|||||||
|
|
||||||
|
|
||||||
def get_pictures_list(path):
|
def get_pictures_list(path):
|
||||||
file_manifest = [f for f in glob.glob(f'{path}/**/*.jpg', recursive=True)]
|
file_manifest = [f for f in glob.glob(f"{path}/**/*.jpg", recursive=True)]
|
||||||
file_manifest += [f for f in glob.glob(f'{path}/**/*.jpeg', recursive=True)]
|
file_manifest += [f for f in glob.glob(f"{path}/**/*.jpeg", recursive=True)]
|
||||||
file_manifest += [f for f in glob.glob(f'{path}/**/*.png', recursive=True)]
|
file_manifest += [f for f in glob.glob(f"{path}/**/*.png", recursive=True)]
|
||||||
return file_manifest
|
return file_manifest
|
||||||
|
|
||||||
|
|
||||||
def have_any_recent_upload_image_files(chatbot):
|
def have_any_recent_upload_image_files(chatbot):
|
||||||
_5min = 5 * 60
|
_5min = 5 * 60
|
||||||
if chatbot is None: return False, None # chatbot is None
|
if chatbot is None:
|
||||||
|
return False, None # chatbot is None
|
||||||
most_recent_uploaded = chatbot._cookies.get("most_recent_uploaded", None)
|
most_recent_uploaded = chatbot._cookies.get("most_recent_uploaded", None)
|
||||||
if not most_recent_uploaded: return False, None # most_recent_uploaded is None
|
if not most_recent_uploaded:
|
||||||
|
return False, None # most_recent_uploaded is None
|
||||||
if time.time() - most_recent_uploaded["time"] < _5min:
|
if time.time() - most_recent_uploaded["time"] < _5min:
|
||||||
most_recent_uploaded = chatbot._cookies.get("most_recent_uploaded", None)
|
most_recent_uploaded = chatbot._cookies.get("most_recent_uploaded", None)
|
||||||
path = most_recent_uploaded['path']
|
path = most_recent_uploaded["path"]
|
||||||
file_manifest = get_pictures_list(path)
|
file_manifest = get_pictures_list(path)
|
||||||
if len(file_manifest) == 0: return False, None
|
if len(file_manifest) == 0:
|
||||||
|
return False, None
|
||||||
return True, file_manifest # most_recent_uploaded is new
|
return True, file_manifest # most_recent_uploaded is new
|
||||||
else:
|
else:
|
||||||
return False, None # most_recent_uploaded is too old
|
return False, None # most_recent_uploaded is too old
|
||||||
@ -887,16 +981,19 @@ def have_any_recent_upload_image_files(chatbot):
|
|||||||
# Function to encode the image
|
# Function to encode the image
|
||||||
def encode_image(image_path):
|
def encode_image(image_path):
|
||||||
with open(image_path, "rb") as image_file:
|
with open(image_path, "rb") as image_file:
|
||||||
return base64.b64encode(image_file.read()).decode('utf-8')
|
return base64.b64encode(image_file.read()).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
def get_max_token(llm_kwargs):
|
def get_max_token(llm_kwargs):
|
||||||
from request_llms.bridge_all import model_info
|
from request_llms.bridge_all import model_info
|
||||||
return model_info[llm_kwargs['llm_model']]['max_token']
|
|
||||||
|
return model_info[llm_kwargs["llm_model"]]["max_token"]
|
||||||
|
|
||||||
|
|
||||||
def check_packages(packages=[]):
|
def check_packages(packages=[]):
|
||||||
import importlib.util
|
import importlib.util
|
||||||
|
|
||||||
for p in packages:
|
for p in packages:
|
||||||
spam_spec = importlib.util.find_spec(p)
|
spam_spec = importlib.util.find_spec(p)
|
||||||
if spam_spec is None: raise ModuleNotFoundError
|
if spam_spec is None:
|
||||||
|
raise ModuleNotFoundError
|
||||||
|
Loading…
x
Reference in New Issue
Block a user