format source code

parent 1714116a89
commit 3d4c6f54f1
@@ -5,7 +5,7 @@ import glob, os, requests, time

pj = os.path.join
ARXIV_CACHE_DIR = os.path.expanduser(f"~/arxiv_cache/")

# =================================== 工具函数 ===============================================
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- 工具函数 =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
# 专业词汇声明  = 'If the term "agent" is used in this section, it should be translated to "智能体". '
def switch_prompt(pfg, mode, more_requirement):
    """
@@ -142,7 +142,7 @@ def arxiv_download(chatbot, history, txt, allow_cache=True):
    from toolbox import extract_archive
    extract_archive(file_path=dst, dest_dir=extract_dst)
    return extract_dst, arxiv_id
# ========================================= 插件主程序1 =====================================================
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= 插件主程序1 =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=


@CatchException
@@ -218,7 +218,7 @@ def Latex英文纠错加PDF对比(txt, llm_kwargs, plugin_kwargs, chatbot, histo
    # <-------------- we are done ------------->
    return success

# ========================================= 插件主程序2 =====================================================
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= 插件主程序2 =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=

@CatchException
def Latex翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):

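Note: every hunk in this diff follows the same pattern — single quotes become double quotes, long statements are wrapped, and one-line conditionals are expanded — which is consistent with running an automatic formatter over the source tree. The commit message does not name the tool, so the snippet below is only a sketch of how such a pass could be reproduced, assuming the formatter is black:

    # Assumption: the formatter is black; the commit only says "format source code".
    import black

    src = "if isinstance(pattern, list): pattern = '|'.join(pattern)\n"
    print(black.format_str(src, mode=black.Mode()))
    # Output matches the style seen throughout this diff:
    # if isinstance(pattern, list):
    #     pattern = "|".join(pattern)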
@@ -1,15 +1,18 @@
import os, shutil
import re
import numpy as np

PRESERVE = 0
TRANSFORM = 1

pj = os.path.join

class LinkedListNode():

class LinkedListNode:
    """
    Linked List Node
    """

    def __init__(self, string, preserve=True) -> None:
        self.string = string
        self.preserve = preserve
@@ -18,19 +21,22 @@ class LinkedListNode():
        # self.begin_line = 0
        # self.begin_char = 0


def convert_to_linklist(text, mask):
    root = LinkedListNode("", preserve=True)
    current_node = root
    for c, m, i in zip(text, mask, range(len(text))):
        if (m==PRESERVE and current_node.preserve) \
            or (m==TRANSFORM and not current_node.preserve):
        if (m == PRESERVE and current_node.preserve) or (
            m == TRANSFORM and not current_node.preserve
        ):
            # add
            current_node.string += c
        else:
            current_node.next = LinkedListNode(c, preserve=(m==PRESERVE))
            current_node.next = LinkedListNode(c, preserve=(m == PRESERVE))
            current_node = current_node.next
    return root


def post_process(root):
    # 修复括号
    node = root
@@ -38,21 +44,24 @@ def post_process(root):
        string = node.string
        if node.preserve:
            node = node.next
            if node is None: break
            if node is None:
                break
            continue

        def break_check(string):
            str_stack = [""] # (lv, index)
            str_stack = [""]  # (lv, index)
            for i, c in enumerate(string):
                if c == '{':
                    str_stack.append('{')
                elif c == '}':
                if c == "{":
                    str_stack.append("{")
                elif c == "}":
                    if len(str_stack) == 1:
                        print('stack fix')
                        print("stack fix")
                        return i
                    str_stack.pop(-1)
                else:
                    str_stack[-1] += c
            return -1

        bp = break_check(string)

        if bp == -1:
@@ -69,51 +78,66 @@ def post_process(root):
            node.next = q

        node = node.next
        if node is None: break
        if node is None:
            break

    # 屏蔽空行和太短的句子
    node = root
    while True:
        if len(node.string.strip('\n').strip(''))==0: node.preserve = True
        if len(node.string.strip('\n').strip(''))<42: node.preserve = True
        if len(node.string.strip("\n").strip("")) == 0:
            node.preserve = True
        if len(node.string.strip("\n").strip("")) < 42:
            node.preserve = True
        node = node.next
        if node is None: break
        if node is None:
            break
    node = root
    while True:
        if node.next and node.preserve and node.next.preserve:
            node.string += node.next.string
            node.next = node.next.next
        node = node.next
        if node is None: break
        if node is None:
            break

    # 将前后断行符脱离
    node = root
    prev_node = None
    while True:
        if not node.preserve:
            lstriped_ = node.string.lstrip().lstrip('\n')
            if (prev_node is not None) and (prev_node.preserve) and (len(lstriped_)!=len(node.string)):
                prev_node.string += node.string[:-len(lstriped_)]
            lstriped_ = node.string.lstrip().lstrip("\n")
            if (
                (prev_node is not None)
                and (prev_node.preserve)
                and (len(lstriped_) != len(node.string))
            ):
                prev_node.string += node.string[: -len(lstriped_)]
                node.string = lstriped_
            rstriped_ = node.string.rstrip().rstrip('\n')
            if (node.next is not None) and (node.next.preserve) and (len(rstriped_)!=len(node.string)):
                node.next.string = node.string[len(rstriped_):] + node.next.string
            rstriped_ = node.string.rstrip().rstrip("\n")
            if (
                (node.next is not None)
                and (node.next.preserve)
                and (len(rstriped_) != len(node.string))
            ):
                node.next.string = node.string[len(rstriped_) :] + node.next.string
                node.string = rstriped_
        # =====
        # =-=-=
        prev_node = node
        node = node.next
        if node is None: break
        if node is None:
            break

    # 标注节点的行数范围
    node = root
    n_line = 0
    expansion = 2
    while True:
        n_l = node.string.count('\n')
        node.range = [n_line-expansion, n_line+n_l+expansion]   # 失败时,扭转的范围
        n_line = n_line+n_l
        n_l = node.string.count("\n")
        node.range = [n_line - expansion, n_line + n_l + expansion]  # 失败时,扭转的范围
        n_line = n_line + n_l
        node = node.next
        if node is None: break
        if node is None:
            break
    return root


@@ -131,12 +155,14 @@ def set_forbidden_text(text, mask, pattern, flags=0):
    you can mask out (mask = PRESERVE so that text become untouchable for GPT)
    everything between "\begin{equation}" and "\end{equation}"
    """
    if isinstance(pattern, list): pattern = '|'.join(pattern)
    if isinstance(pattern, list):
        pattern = "|".join(pattern)
    pattern_compile = re.compile(pattern, flags)
    for res in pattern_compile.finditer(text):
        mask[res.span()[0]:res.span()[1]] = PRESERVE
        mask[res.span()[0] : res.span()[1]] = PRESERVE
    return text, mask


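Note: the PRESERVE/TRANSFORM constants introduced above drive the whole masking scheme — set_forbidden_text marks regex matches as untouchable, and convert_to_linklist then groups characters that share a mask value into linked-list nodes. The sketch below mirrors that flow in standalone form (it does not import the project module, and the sample text and regex are illustrative only):

    import re
    import numpy as np

    PRESERVE = 0   # text GPT must not touch
    TRANSFORM = 1  # text GPT may edit

    text = r"Intro text. \begin{equation} E = mc^2 \end{equation} More prose."
    mask = np.zeros(len(text), dtype=np.uint8) + TRANSFORM

    # Same idea as set_forbidden_text: protect everything inside the equation environment.
    for res in re.compile(r"\\begin\{equation\}.*?\\end\{equation\}", re.DOTALL).finditer(text):
        mask[res.span()[0] : res.span()[1]] = PRESERVE

    # Same idea as convert_to_linklist: group runs of characters sharing a mask value.
    segments, start = [], 0
    for i in range(1, len(text) + 1):
        if i == len(text) or mask[i] != mask[i - 1]:
            segments.append((bool(mask[start] == PRESERVE), text[start:i]))
            start = i
    print(segments)  # three segments: editable prose, the preserved equation, editable prose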
def reverse_forbidden_text(text, mask, pattern, flags=0, forbid_wrapper=True):
    """
    Move area out of preserve area (make text editable for GPT)
@@ -144,17 +170,19 @@ def reverse_forbidden_text(text, mask, pattern, flags=0, forbid_wrapper=True):
    e.g.
    \begin{abstract} blablablablablabla. \end{abstract}
    """
    if isinstance(pattern, list): pattern = '|'.join(pattern)
    if isinstance(pattern, list):
        pattern = "|".join(pattern)
    pattern_compile = re.compile(pattern, flags)
    for res in pattern_compile.finditer(text):
        if not forbid_wrapper:
            mask[res.span()[0]:res.span()[1]] = TRANSFORM
            mask[res.span()[0] : res.span()[1]] = TRANSFORM
        else:
            mask[res.regs[0][0]: res.regs[1][0]] = PRESERVE   # '\\begin{abstract}'
            mask[res.regs[1][0]: res.regs[1][1]] = TRANSFORM   # abstract
            mask[res.regs[1][1]: res.regs[0][1]] = PRESERVE   # abstract
            mask[res.regs[0][0] : res.regs[1][0]] = PRESERVE  # '\\begin{abstract}'
            mask[res.regs[1][0] : res.regs[1][1]] = TRANSFORM  # abstract
            mask[res.regs[1][1] : res.regs[0][1]] = PRESERVE  # abstract
    return text, mask


def set_forbidden_text_careful_brace(text, mask, pattern, flags=0):
    """
    Add a preserve text area in this paper (text become untouchable for GPT).
@@ -166,16 +194,22 @@ def set_forbidden_text_careful_brace(text, mask, pattern, flags=0):
    for res in pattern_compile.finditer(text):
        brace_level = -1
        p = begin = end = res.regs[0][0]
        for _ in range(1024*16):
            if text[p] == '}' and brace_level == 0: break
            elif text[p] == '}':  brace_level -= 1
            elif text[p] == '{':  brace_level += 1
        for _ in range(1024 * 16):
            if text[p] == "}" and brace_level == 0:
                break
            elif text[p] == "}":
                brace_level -= 1
            elif text[p] == "{":
                brace_level += 1
            p += 1
        end = p+1
        end = p + 1
        mask[begin:end] = PRESERVE
    return text, mask

def reverse_forbidden_text_careful_brace(text, mask, pattern, flags=0, forbid_wrapper=True):

def reverse_forbidden_text_careful_brace(
    text, mask, pattern, flags=0, forbid_wrapper=True
):
    """
    Move area out of preserve area (make text editable for GPT)
    count the number of the braces so as to catch compelete text area.
@@ -186,39 +220,57 @@ def reverse_forbidden_text_careful_brace(text, mask, pattern, flags=0, forbid_wr
    for res in pattern_compile.finditer(text):
        brace_level = 0
        p = begin = end = res.regs[1][0]
        for _ in range(1024*16):
            if text[p] == '}' and brace_level == 0: break
            elif text[p] == '}':  brace_level -= 1
            elif text[p] == '{':  brace_level += 1
        for _ in range(1024 * 16):
            if text[p] == "}" and brace_level == 0:
                break
            elif text[p] == "}":
                brace_level -= 1
            elif text[p] == "{":
                brace_level += 1
            p += 1
        end = p
        mask[begin:end] = TRANSFORM
        if forbid_wrapper:
            mask[res.regs[0][0]:begin] = PRESERVE
            mask[end:res.regs[0][1]] = PRESERVE
            mask[res.regs[0][0] : begin] = PRESERVE
            mask[end : res.regs[0][1]] = PRESERVE
    return text, mask


def set_forbidden_text_begin_end(text, mask, pattern, flags=0, limit_n_lines=42):
    """
    Find all \begin{} ... \end{} text block that with less than limit_n_lines lines.
    Add it to preserve area
    """
    pattern_compile = re.compile(pattern, flags)

    def search_with_line_limit(text, mask):
        for res in pattern_compile.finditer(text):
            cmd = res.group(1)  # begin{what}
            this = res.group(2) # content between begin and end
            this_mask = mask[res.regs[2][0]:res.regs[2][1]]
            white_list = ['document', 'abstract', 'lemma', 'definition', 'sproof',
                          'em', 'emph', 'textit', 'textbf', 'itemize', 'enumerate']
            if (cmd in white_list) or this.count('\n') >= limit_n_lines: # use a magical number 42
            this = res.group(2)  # content between begin and end
            this_mask = mask[res.regs[2][0] : res.regs[2][1]]
            white_list = [
                "document",
                "abstract",
                "lemma",
                "definition",
                "sproof",
                "em",
                "emph",
                "textit",
                "textbf",
                "itemize",
                "enumerate",
            ]
            if (cmd in white_list) or this.count(
                "\n"
            ) >= limit_n_lines:  # use a magical number 42
                this, this_mask = search_with_line_limit(this, this_mask)
                mask[res.regs[2][0]:res.regs[2][1]] = this_mask
                mask[res.regs[2][0] : res.regs[2][1]] = this_mask
            else:
                mask[res.regs[0][0]:res.regs[0][1]] = PRESERVE
                mask[res.regs[0][0] : res.regs[0][1]] = PRESERVE
        return text, mask
    return search_with_line_limit(text, mask)

    return search_with_line_limit(text, mask)


"""
@@ -227,6 +279,7 @@ Latex Merge File
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
"""


def find_main_tex_file(file_manifest, mode):
    """
    在多Tex文档中,寻找主文件,必须包含documentclass,返回找到的第一个。
@@ -234,27 +287,36 @@ def find_main_tex_file(file_manifest, mode):
    """
    canidates = []
    for texf in file_manifest:
        if os.path.basename(texf).startswith('merge'):
        if os.path.basename(texf).startswith("merge"):
            continue
        with open(texf, 'r', encoding='utf8', errors='ignore') as f:
        with open(texf, "r", encoding="utf8", errors="ignore") as f:
            file_content = f.read()
        if r'\documentclass' in file_content:
        if r"\documentclass" in file_content:
            canidates.append(texf)
        else:
            continue

    if len(canidates) == 0:
        raise RuntimeError('无法找到一个主Tex文件(包含documentclass关键字)')
        raise RuntimeError("无法找到一个主Tex文件(包含documentclass关键字)")
    elif len(canidates) == 1:
        return canidates[0]
    else: # if len(canidates) >= 2 通过一些Latex模板中常见(但通常不会出现在正文)的单词,对不同latex源文件扣分,取评分最高者返回
    else:  # if len(canidates) >= 2 通过一些Latex模板中常见(但通常不会出现在正文)的单词,对不同latex源文件扣分,取评分最高者返回
        canidates_score = []
        # 给出一些判定模板文档的词作为扣分项
        unexpected_words = ['\\LaTeX', 'manuscript', 'Guidelines', 'font', 'citations', 'rejected', 'blind review', 'reviewers']
        expected_words = ['\\input', '\\ref', '\\cite']
        unexpected_words = [
            "\\LaTeX",
            "manuscript",
            "Guidelines",
            "font",
            "citations",
            "rejected",
            "blind review",
            "reviewers",
        ]
        expected_words = ["\\input", "\\ref", "\\cite"]
        for texf in canidates:
            canidates_score.append(0)
            with open(texf, 'r', encoding='utf8', errors='ignore') as f:
            with open(texf, "r", encoding="utf8", errors="ignore") as f:
                file_content = f.read()
                file_content = rm_comments(file_content)
            for uw in unexpected_words:
@@ -263,9 +325,10 @@ def find_main_tex_file(file_manifest, mode):
            for uw in expected_words:
                if uw in file_content:
                    canidates_score[-1] += 1
        select = np.argmax(canidates_score) # 取评分最高者返回
        select = np.argmax(canidates_score)  # 取评分最高者返回
        return canidates[select]


def rm_comments(main_file):
    new_file_remove_comment_lines = []
    for l in main_file.splitlines():
@@ -274,30 +337,39 @@ def rm_comments(main_file):
            pass
        else:
            new_file_remove_comment_lines.append(l)
    main_file = '\n'.join(new_file_remove_comment_lines)
    main_file = "\n".join(new_file_remove_comment_lines)
    # main_file = re.sub(r"\\include{(.*?)}", r"\\input{\1}", main_file)  # 将 \include 命令转换为 \input 命令
    main_file = re.sub(r'(?<!\\)%.*', '', main_file)  # 使用正则表达式查找半行注释, 并替换为空字符串
    main_file = re.sub(r"(?<!\\)%.*", "", main_file)  # 使用正则表达式查找半行注释, 并替换为空字符串
    return main_file


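Note: the regular expression used in rm_comments above, (?<!\\)%.*, removes a LaTeX comment only when the % is not escaped as \%. A quick standalone check (illustrative input, not project code):

    import re

    line = r"50\% of samples % this trailing comment is dropped"
    print(re.sub(r"(?<!\\)%.*", "", line))
    # 50\% of samples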
def find_tex_file_ignore_case(fp):
    dir_name = os.path.dirname(fp)
    base_name = os.path.basename(fp)
    # 如果输入的文件路径是正确的
    if os.path.isfile(pj(dir_name, base_name)): return pj(dir_name, base_name)
    if os.path.isfile(pj(dir_name, base_name)):
        return pj(dir_name, base_name)
    # 如果不正确,试着加上.tex后缀试试
    if not base_name.endswith('.tex'): base_name+='.tex'
    if os.path.isfile(pj(dir_name, base_name)): return pj(dir_name, base_name)
    if not base_name.endswith(".tex"):
        base_name += ".tex"
    if os.path.isfile(pj(dir_name, base_name)):
        return pj(dir_name, base_name)
    # 如果还找不到,解除大小写限制,再试一次
    import glob
    for f in glob.glob(dir_name+'/*.tex'):

    for f in glob.glob(dir_name + "/*.tex"):
        base_name_s = os.path.basename(fp)
        base_name_f = os.path.basename(f)
        if base_name_s.lower() == base_name_f.lower(): return f
        if base_name_s.lower() == base_name_f.lower():
            return f
        # 试着加上.tex后缀试试
        if not base_name_s.endswith('.tex'): base_name_s+='.tex'
        if base_name_s.lower() == base_name_f.lower(): return f
        if not base_name_s.endswith(".tex"):
            base_name_s += ".tex"
        if base_name_s.lower() == base_name_f.lower():
            return f
    return None


def merge_tex_files_(project_foler, main_file, mode):
    """
    Merge Tex project recrusively
@ -309,18 +381,18 @@ def merge_tex_files_(project_foler, main_file, mode):
 | 
			
		||||
        fp_ = find_tex_file_ignore_case(fp)
 | 
			
		||||
        if fp_:
 | 
			
		||||
            try:
 | 
			
		||||
                with open(fp_, 'r', encoding='utf-8', errors='replace') as fx: c = fx.read()
 | 
			
		||||
                with open(fp_, "r", encoding="utf-8", errors="replace") as fx:
 | 
			
		||||
                    c = fx.read()
 | 
			
		||||
            except:
 | 
			
		||||
                c = f"\n\nWarning from GPT-Academic: LaTex source file is missing!\n\n"
 | 
			
		||||
        else:
 | 
			
		||||
            raise RuntimeError(f'找不到{fp},Tex源文件缺失!')
 | 
			
		||||
            raise RuntimeError(f"找不到{fp},Tex源文件缺失!")
 | 
			
		||||
        c = merge_tex_files_(project_foler, c, mode)
 | 
			
		||||
        main_file = main_file[:s.span()[0]] + c + main_file[s.span()[1]:]
 | 
			
		||||
        main_file = main_file[: s.span()[0]] + c + main_file[s.span()[1] :]
 | 
			
		||||
    return main_file
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def find_title_and_abs(main_file):
 | 
			
		||||
 | 
			
		||||
    def extract_abstract_1(text):
 | 
			
		||||
        pattern = r"\\abstract\{(.*?)\}"
 | 
			
		||||
        match = re.search(pattern, text, re.DOTALL)
 | 
			
		||||
@ -362,21 +434,30 @@ def merge_tex_files(project_foler, main_file, mode):
 | 
			
		||||
    main_file = merge_tex_files_(project_foler, main_file, mode)
 | 
			
		||||
    main_file = rm_comments(main_file)
 | 
			
		||||
 | 
			
		||||
    if mode == 'translate_zh':
 | 
			
		||||
    if mode == "translate_zh":
 | 
			
		||||
        # find paper documentclass
 | 
			
		||||
        pattern = re.compile(r'\\documentclass.*\n')
 | 
			
		||||
        pattern = re.compile(r"\\documentclass.*\n")
 | 
			
		||||
        match = pattern.search(main_file)
 | 
			
		||||
        assert match is not None, "Cannot find documentclass statement!"
 | 
			
		||||
        position = match.end()
 | 
			
		||||
        add_ctex = '\\usepackage{ctex}\n'
 | 
			
		||||
        add_url = '\\usepackage{url}\n' if '{url}' not in main_file else ''
 | 
			
		||||
        add_ctex = "\\usepackage{ctex}\n"
 | 
			
		||||
        add_url = "\\usepackage{url}\n" if "{url}" not in main_file else ""
 | 
			
		||||
        main_file = main_file[:position] + add_ctex + add_url + main_file[position:]
 | 
			
		||||
        # fontset=windows
 | 
			
		||||
        import platform
 | 
			
		||||
        main_file = re.sub(r"\\documentclass\[(.*?)\]{(.*?)}", r"\\documentclass[\1,fontset=windows,UTF8]{\2}",main_file)
 | 
			
		||||
        main_file = re.sub(r"\\documentclass{(.*?)}", r"\\documentclass[fontset=windows,UTF8]{\1}",main_file)
 | 
			
		||||
 | 
			
		||||
        main_file = re.sub(
 | 
			
		||||
            r"\\documentclass\[(.*?)\]{(.*?)}",
 | 
			
		||||
            r"\\documentclass[\1,fontset=windows,UTF8]{\2}",
 | 
			
		||||
            main_file,
 | 
			
		||||
        )
 | 
			
		||||
        main_file = re.sub(
 | 
			
		||||
            r"\\documentclass{(.*?)}",
 | 
			
		||||
            r"\\documentclass[fontset=windows,UTF8]{\1}",
 | 
			
		||||
            main_file,
 | 
			
		||||
        )
 | 
			
		||||
        # find paper abstract
 | 
			
		||||
        pattern_opt1 = re.compile(r'\\begin\{abstract\}.*\n')
 | 
			
		||||
        pattern_opt1 = re.compile(r"\\begin\{abstract\}.*\n")
 | 
			
		||||
        pattern_opt2 = re.compile(r"\\abstract\{(.*?)\}", flags=re.DOTALL)
 | 
			
		||||
        match_opt1 = pattern_opt1.search(main_file)
 | 
			
		||||
        match_opt2 = pattern_opt2.search(main_file)
 | 
			
		||||
@ -385,7 +466,9 @@ def merge_tex_files(project_foler, main_file, mode):
 | 
			
		||||
            main_file = insert_abstract(main_file)
 | 
			
		||||
        match_opt1 = pattern_opt1.search(main_file)
 | 
			
		||||
        match_opt2 = pattern_opt2.search(main_file)
 | 
			
		||||
        assert (match_opt1 is not None) or (match_opt2 is not None), "Cannot find paper abstract section!"
 | 
			
		||||
        assert (match_opt1 is not None) or (
 | 
			
		||||
            match_opt2 is not None
 | 
			
		||||
        ), "Cannot find paper abstract section!"
 | 
			
		||||
    return main_file
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -395,6 +478,7 @@ The GPT-Academic program cannot find abstract section in this paper.
 | 
			
		||||
\end{abstract}
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def insert_abstract(tex_content):
 | 
			
		||||
    if "\\maketitle" in tex_content:
 | 
			
		||||
        # find the position of "\maketitle"
 | 
			
		||||
@ -402,7 +486,13 @@ def insert_abstract(tex_content):
 | 
			
		||||
        # find the nearest ending line
 | 
			
		||||
        end_line_index = tex_content.find("\n", find_index)
 | 
			
		||||
        # insert "abs_str" on the next line
 | 
			
		||||
        modified_tex = tex_content[:end_line_index+1] + '\n\n' + insert_missing_abs_str + '\n\n' + tex_content[end_line_index+1:]
 | 
			
		||||
        modified_tex = (
 | 
			
		||||
            tex_content[: end_line_index + 1]
 | 
			
		||||
            + "\n\n"
 | 
			
		||||
            + insert_missing_abs_str
 | 
			
		||||
            + "\n\n"
 | 
			
		||||
            + tex_content[end_line_index + 1 :]
 | 
			
		||||
        )
 | 
			
		||||
        return modified_tex
 | 
			
		||||
    elif r"\begin{document}" in tex_content:
 | 
			
		||||
        # find the position of "\maketitle"
 | 
			
		||||
@ -410,16 +500,25 @@ def insert_abstract(tex_content):
 | 
			
		||||
        # find the nearest ending line
 | 
			
		||||
        end_line_index = tex_content.find("\n", find_index)
 | 
			
		||||
        # insert "abs_str" on the next line
 | 
			
		||||
        modified_tex = tex_content[:end_line_index+1] + '\n\n' + insert_missing_abs_str + '\n\n' + tex_content[end_line_index+1:]
 | 
			
		||||
        modified_tex = (
 | 
			
		||||
            tex_content[: end_line_index + 1]
 | 
			
		||||
            + "\n\n"
 | 
			
		||||
            + insert_missing_abs_str
 | 
			
		||||
            + "\n\n"
 | 
			
		||||
            + tex_content[end_line_index + 1 :]
 | 
			
		||||
        )
 | 
			
		||||
        return modified_tex
 | 
			
		||||
    else:
 | 
			
		||||
        return tex_content
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
 | 
			
		||||
Post process
 | 
			
		||||
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def mod_inbraket(match):
 | 
			
		||||
    """
 | 
			
		||||
    为啥chatgpt会把cite里面的逗号换成中文逗号呀
 | 
			
		||||
@ -428,11 +527,12 @@ def mod_inbraket(match):
 | 
			
		||||
    cmd = match.group(1)
 | 
			
		||||
    str_to_modify = match.group(2)
 | 
			
		||||
    # modify the matched string
 | 
			
		||||
    str_to_modify = str_to_modify.replace(':', ':')    # 前面是中文冒号,后面是英文冒号
 | 
			
		||||
    str_to_modify = str_to_modify.replace(',', ',')    # 前面是中文逗号,后面是英文逗号
 | 
			
		||||
    str_to_modify = str_to_modify.replace(":", ":")  # 前面是中文冒号,后面是英文冒号
 | 
			
		||||
    str_to_modify = str_to_modify.replace(",", ",")  # 前面是中文逗号,后面是英文逗号
 | 
			
		||||
    # str_to_modify = 'BOOM'
 | 
			
		||||
    return "\\" + cmd + "{" + str_to_modify + "}"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def fix_content(final_tex, node_string):
 | 
			
		||||
    """
 | 
			
		||||
    Fix common GPT errors to increase success rate
 | 
			
		||||
@ -443,10 +543,10 @@ def fix_content(final_tex, node_string):
 | 
			
		||||
    final_tex = re.sub(r"\\([a-z]{2,10})\{([^\}]*?)\}", mod_inbraket, string=final_tex)
 | 
			
		||||
 | 
			
		||||
    if "Traceback" in final_tex and "[Local Message]" in final_tex:
 | 
			
		||||
        final_tex = node_string # 出问题了,还原原文
 | 
			
		||||
    if node_string.count('\\begin') != final_tex.count('\\begin'):
 | 
			
		||||
        final_tex = node_string # 出问题了,还原原文
 | 
			
		||||
    if node_string.count('\_') > 0 and node_string.count('\_') > final_tex.count('\_'):
 | 
			
		||||
        final_tex = node_string  # 出问题了,还原原文
 | 
			
		||||
    if node_string.count("\\begin") != final_tex.count("\\begin"):
 | 
			
		||||
        final_tex = node_string  # 出问题了,还原原文
 | 
			
		||||
    if node_string.count("\_") > 0 and node_string.count("\_") > final_tex.count("\_"):
 | 
			
		||||
        # walk and replace any _ without \
 | 
			
		||||
        final_tex = re.sub(r"(?<!\\)_", "\\_", final_tex)
 | 
			
		||||
 | 
			
		||||
@ -454,24 +554,32 @@ def fix_content(final_tex, node_string):
 | 
			
		||||
        # this function count the number of { and }
 | 
			
		||||
        brace_level = 0
 | 
			
		||||
        for c in string:
 | 
			
		||||
            if c == "{": brace_level += 1
 | 
			
		||||
            elif c == "}": brace_level -= 1
 | 
			
		||||
            if c == "{":
 | 
			
		||||
                brace_level += 1
 | 
			
		||||
            elif c == "}":
 | 
			
		||||
                brace_level -= 1
 | 
			
		||||
        return brace_level
 | 
			
		||||
 | 
			
		||||
    def join_most(tex_t, tex_o):
 | 
			
		||||
        # this function join translated string and original string when something goes wrong
 | 
			
		||||
        p_t = 0
 | 
			
		||||
        p_o = 0
 | 
			
		||||
 | 
			
		||||
        def find_next(string, chars, begin):
 | 
			
		||||
            p = begin
 | 
			
		||||
            while p < len(string):
 | 
			
		||||
                if string[p] in chars: return p, string[p]
 | 
			
		||||
                if string[p] in chars:
 | 
			
		||||
                    return p, string[p]
 | 
			
		||||
                p += 1
 | 
			
		||||
            return None, None
 | 
			
		||||
 | 
			
		||||
        while True:
 | 
			
		||||
            res1, char = find_next(tex_o, ['{','}'], p_o)
 | 
			
		||||
            if res1 is None: break
 | 
			
		||||
            res1, char = find_next(tex_o, ["{", "}"], p_o)
 | 
			
		||||
            if res1 is None:
 | 
			
		||||
                break
 | 
			
		||||
            res2, char = find_next(tex_t, [char], p_t)
 | 
			
		||||
            if res2 is None: break
 | 
			
		||||
            if res2 is None:
 | 
			
		||||
                break
 | 
			
		||||
            p_o = res1 + 1
 | 
			
		||||
            p_t = res2 + 1
 | 
			
		||||
        return tex_t[:p_t] + tex_o[p_o:]
 | 
			
		||||
@ -481,9 +589,13 @@ def fix_content(final_tex, node_string):
 | 
			
		||||
        final_tex = join_most(final_tex, node_string)
 | 
			
		||||
    return final_tex
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def compile_latex_with_timeout(command, cwd, timeout=60):
 | 
			
		||||
    import subprocess
 | 
			
		||||
    process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=cwd)
 | 
			
		||||
 | 
			
		||||
    process = subprocess.Popen(
 | 
			
		||||
        command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=cwd
 | 
			
		||||
    )
 | 
			
		||||
    try:
 | 
			
		||||
        stdout, stderr = process.communicate(timeout=timeout)
 | 
			
		||||
    except subprocess.TimeoutExpired:
 | 
			
		||||
@ -493,43 +605,52 @@ def compile_latex_with_timeout(command, cwd, timeout=60):
 | 
			
		||||
        return False
 | 
			
		||||
    return True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def run_in_subprocess_wrapper_func(func, args, kwargs, return_dict, exception_dict):
 | 
			
		||||
    import sys
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        result = func(*args, **kwargs)
 | 
			
		||||
        return_dict['result'] = result
 | 
			
		||||
        return_dict["result"] = result
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        exc_info = sys.exc_info()
 | 
			
		||||
        exception_dict['exception'] = exc_info
 | 
			
		||||
        exception_dict["exception"] = exc_info
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def run_in_subprocess(func):
 | 
			
		||||
    import multiprocessing
 | 
			
		||||
 | 
			
		||||
    def wrapper(*args, **kwargs):
 | 
			
		||||
        return_dict = multiprocessing.Manager().dict()
 | 
			
		||||
        exception_dict = multiprocessing.Manager().dict()
 | 
			
		||||
        process = multiprocessing.Process(target=run_in_subprocess_wrapper_func, 
 | 
			
		||||
                                            args=(func, args, kwargs, return_dict, exception_dict))
 | 
			
		||||
        process = multiprocessing.Process(
 | 
			
		||||
            target=run_in_subprocess_wrapper_func,
 | 
			
		||||
            args=(func, args, kwargs, return_dict, exception_dict),
 | 
			
		||||
        )
 | 
			
		||||
        process.start()
 | 
			
		||||
        process.join()
 | 
			
		||||
        process.close()
 | 
			
		||||
        if 'exception' in exception_dict:
 | 
			
		||||
        if "exception" in exception_dict:
 | 
			
		||||
            # ooops, the subprocess ran into an exception
 | 
			
		||||
            exc_info = exception_dict['exception']
 | 
			
		||||
            exc_info = exception_dict["exception"]
 | 
			
		||||
            raise exc_info[1].with_traceback(exc_info[2])
 | 
			
		||||
        if 'result' in return_dict.keys():
 | 
			
		||||
        if "result" in return_dict.keys():
 | 
			
		||||
            # If the subprocess ran successfully, return the result
 | 
			
		||||
            return return_dict['result']
 | 
			
		||||
            return return_dict["result"]
 | 
			
		||||
 | 
			
		||||
    return wrapper
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _merge_pdfs(pdf1_path, pdf2_path, output_path):
 | 
			
		||||
    import PyPDF2   # PyPDF2这个库有严重的内存泄露问题,把它放到子进程中运行,从而方便内存的释放
 | 
			
		||||
    import PyPDF2  # PyPDF2这个库有严重的内存泄露问题,把它放到子进程中运行,从而方便内存的释放
 | 
			
		||||
 | 
			
		||||
    Percent = 0.95
 | 
			
		||||
    # raise RuntimeError('PyPDF2 has a serious memory leak problem, please use other tools to merge PDF files.')
 | 
			
		||||
    # Open the first PDF file
 | 
			
		||||
    with open(pdf1_path, 'rb') as pdf1_file:
 | 
			
		||||
    with open(pdf1_path, "rb") as pdf1_file:
 | 
			
		||||
        pdf1_reader = PyPDF2.PdfFileReader(pdf1_file)
 | 
			
		||||
        # Open the second PDF file
 | 
			
		||||
        with open(pdf2_path, 'rb') as pdf2_file:
 | 
			
		||||
        with open(pdf2_path, "rb") as pdf2_file:
 | 
			
		||||
            pdf2_reader = PyPDF2.PdfFileReader(pdf2_file)
 | 
			
		||||
            # Create a new PDF file to store the merged pages
 | 
			
		||||
            output_writer = PyPDF2.PdfFileWriter()
 | 
			
		||||
@ -549,14 +670,25 @@ def _merge_pdfs(pdf1_path, pdf2_path, output_path):
 | 
			
		||||
                    page2 = PyPDF2.PageObject.createBlankPage(pdf1_reader)
 | 
			
		||||
                # Create a new empty page with double width
 | 
			
		||||
                new_page = PyPDF2.PageObject.createBlankPage(
 | 
			
		||||
                    width = int(int(page1.mediaBox.getWidth()) + int(page2.mediaBox.getWidth()) * Percent),
 | 
			
		||||
                    height = max(page1.mediaBox.getHeight(), page2.mediaBox.getHeight())
 | 
			
		||||
                    width=int(
 | 
			
		||||
                        int(page1.mediaBox.getWidth())
 | 
			
		||||
                        + int(page2.mediaBox.getWidth()) * Percent
 | 
			
		||||
                    ),
 | 
			
		||||
                    height=max(page1.mediaBox.getHeight(), page2.mediaBox.getHeight()),
 | 
			
		||||
                )
 | 
			
		||||
                new_page.mergeTranslatedPage(page1, 0, 0)
 | 
			
		||||
                new_page.mergeTranslatedPage(page2, int(int(page1.mediaBox.getWidth())-int(page2.mediaBox.getWidth())* (1-Percent)), 0)
 | 
			
		||||
                new_page.mergeTranslatedPage(
 | 
			
		||||
                    page2,
 | 
			
		||||
                    int(
 | 
			
		||||
                        int(page1.mediaBox.getWidth())
 | 
			
		||||
                        - int(page2.mediaBox.getWidth()) * (1 - Percent)
 | 
			
		||||
                    ),
 | 
			
		||||
                    0,
 | 
			
		||||
                )
 | 
			
		||||
                output_writer.addPage(new_page)
 | 
			
		||||
            # Save the merged PDF file
 | 
			
		||||
            with open(output_path, 'wb') as output_file:
 | 
			
		||||
            with open(output_path, "wb") as output_file:
 | 
			
		||||
                output_writer.write(output_file)
 | 
			
		||||
 | 
			
		||||
merge_pdfs = run_in_subprocess(_merge_pdfs) # PyPDF2这个库有严重的内存泄露问题,把它放到子进程中运行,从而方便内存的释放
 | 
			
		||||
 | 
			
		||||
merge_pdfs = run_in_subprocess(_merge_pdfs)  # PyPDF2这个库有严重的内存泄露问题,把它放到子进程中运行,从而方便内存的释放
 | 
			
		||||
 | 
			
		||||
@ -352,9 +352,9 @@ def step_1_core_key_translate():
 | 
			
		||||
        chinese_core_keys_norepeat_mapping.update({k:cached_translation[k]})
 | 
			
		||||
    chinese_core_keys_norepeat_mapping = dict(sorted(chinese_core_keys_norepeat_mapping.items(), key=lambda x: -len(x[0])))
 | 
			
		||||
 | 
			
		||||
    # ===============================================
 | 
			
		||||
    # =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
 | 
			
		||||
    # copy
 | 
			
		||||
    # ===============================================
 | 
			
		||||
    # =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
 | 
			
		||||
    def copy_source_code():
 | 
			
		||||
 | 
			
		||||
        from toolbox import get_conf
 | 
			
		||||
@ -367,9 +367,9 @@ def step_1_core_key_translate():
 | 
			
		||||
        shutil.copytree('./', backup_dir, ignore=lambda x, y: blacklist)
 | 
			
		||||
    copy_source_code()
 | 
			
		||||
 | 
			
		||||
    # ===============================================
 | 
			
		||||
    # =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
 | 
			
		||||
    # primary key replace
 | 
			
		||||
    # ===============================================
 | 
			
		||||
    # =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
 | 
			
		||||
    directory_path = f'./multi-language/{LANG}/'
 | 
			
		||||
    for root, dirs, files in os.walk(directory_path):
 | 
			
		||||
        for file in files:
 | 
			
		||||
@ -389,9 +389,9 @@ def step_1_core_key_translate():
 | 
			
		||||
 | 
			
		||||
def step_2_core_key_translate():
 | 
			
		||||
 | 
			
		||||
    # =================================================================================================
 | 
			
		||||
    # =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
 | 
			
		||||
    # step2 
 | 
			
		||||
    # =================================================================================================
 | 
			
		||||
    # =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
 | 
			
		||||
 | 
			
		||||
    def load_string(strings, string_input):
 | 
			
		||||
        string_ = string_input.strip().strip(',').strip().strip('.').strip()
 | 
			
		||||
@ -492,9 +492,9 @@ def step_2_core_key_translate():
 | 
			
		||||
    cached_translation.update(read_map_from_json(language=LANG_STD))
 | 
			
		||||
    cached_translation = dict(sorted(cached_translation.items(), key=lambda x: -len(x[0])))
 | 
			
		||||
 | 
			
		||||
    # ===============================================
 | 
			
		||||
    # =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
 | 
			
		||||
    # literal key replace
 | 
			
		||||
    # ===============================================
 | 
			
		||||
    # =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
 | 
			
		||||
    directory_path = f'./multi-language/{LANG}/'
 | 
			
		||||
    for root, dirs, files in os.walk(directory_path):
 | 
			
		||||
        for file in files:
 | 
			
		||||
 | 
			
		||||
@ -1,16 +1,17 @@
 | 
			
		||||
"""
 | 
			
		||||
========================================================================
 | 
			
		||||
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
 | 
			
		||||
第一部分:来自EdgeGPT.py
 | 
			
		||||
https://github.com/acheong08/EdgeGPT
 | 
			
		||||
========================================================================
 | 
			
		||||
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
 | 
			
		||||
"""
 | 
			
		||||
from .edge_gpt_free import Chatbot as NewbingChatbot
 | 
			
		||||
 | 
			
		||||
load_message = "等待NewBing响应。"
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
========================================================================
 | 
			
		||||
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
 | 
			
		||||
第二部分:子进程Worker(调用主体)
 | 
			
		||||
========================================================================
 | 
			
		||||
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
 | 
			
		||||
"""
 | 
			
		||||
import time
 | 
			
		||||
import json
 | 
			
		||||
@ -22,19 +23,30 @@ import threading
 | 
			
		||||
from toolbox import update_ui, get_conf, trimmed_format_exc
 | 
			
		||||
from multiprocessing import Process, Pipe
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def preprocess_newbing_out(s):
 | 
			
		||||
    pattern = r'\^(\d+)\^' # 匹配^数字^
 | 
			
		||||
    sub = lambda m: '('+m.group(1)+')' # 将匹配到的数字作为替换值
 | 
			
		||||
    result = re.sub(pattern, sub, s) # 替换操作
 | 
			
		||||
    if '[1]' in result:
 | 
			
		||||
        result += '\n\n```reference\n' + "\n".join([r for r in result.split('\n') if r.startswith('[')]) + '\n```\n'
 | 
			
		||||
    pattern = r"\^(\d+)\^"  # 匹配^数字^
 | 
			
		||||
    sub = lambda m: "(" + m.group(1) + ")"  # 将匹配到的数字作为替换值
 | 
			
		||||
    result = re.sub(pattern, sub, s)  # 替换操作
 | 
			
		||||
    if "[1]" in result:
 | 
			
		||||
        result += (
 | 
			
		||||
            "\n\n```reference\n"
 | 
			
		||||
            + "\n".join([r for r in result.split("\n") if r.startswith("[")])
 | 
			
		||||
            + "\n```\n"
 | 
			
		||||
        )
 | 
			
		||||
    return result
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def preprocess_newbing_out_simple(result):
 | 
			
		||||
    if '[1]' in result:
 | 
			
		||||
        result += '\n\n```reference\n' + "\n".join([r for r in result.split('\n') if r.startswith('[')]) + '\n```\n'
 | 
			
		||||
    if "[1]" in result:
 | 
			
		||||
        result += (
 | 
			
		||||
            "\n\n```reference\n"
 | 
			
		||||
            + "\n".join([r for r in result.split("\n") if r.startswith("[")])
 | 
			
		||||
            + "\n```\n"
 | 
			
		||||
        )
 | 
			
		||||
    return result
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class NewBingHandle(Process):
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        super().__init__(daemon=True)
 | 
			
		||||
@ -51,6 +63,7 @@ class NewBingHandle(Process):
 | 
			
		||||
        try:
 | 
			
		||||
            self.success = False
 | 
			
		||||
            import certifi, httpx, rich
 | 
			
		||||
 | 
			
		||||
            self.info = "依赖检测通过,等待NewBing响应。注意目前不能多人同时调用NewBing接口(有线程锁),否则将导致每个人的NewBing问询历史互相渗透。调用NewBing时,会自动使用已配置的代理。"
 | 
			
		||||
            self.success = True
 | 
			
		||||
        except:
 | 
			
		||||
@ -62,18 +75,19 @@ class NewBingHandle(Process):
 | 
			
		||||
 | 
			
		||||
    async def async_run(self):
 | 
			
		||||
        # 读取配置
 | 
			
		||||
        NEWBING_STYLE = get_conf('NEWBING_STYLE')
 | 
			
		||||
        NEWBING_STYLE = get_conf("NEWBING_STYLE")
 | 
			
		||||
        from request_llms.bridge_all import model_info
 | 
			
		||||
        endpoint = model_info['newbing']['endpoint']
 | 
			
		||||
 | 
			
		||||
        endpoint = model_info["newbing"]["endpoint"]
 | 
			
		||||
        while True:
 | 
			
		||||
            # 等待
 | 
			
		||||
            kwargs = self.child.recv()
 | 
			
		||||
            question=kwargs['query']
 | 
			
		||||
            history=kwargs['history']
 | 
			
		||||
            system_prompt=kwargs['system_prompt']
 | 
			
		||||
            question = kwargs["query"]
 | 
			
		||||
            history = kwargs["history"]
 | 
			
		||||
            system_prompt = kwargs["system_prompt"]
 | 
			
		||||
 | 
			
		||||
            # 是否重置
 | 
			
		||||
            if len(self.local_history) > 0 and len(history)==0:
 | 
			
		||||
            if len(self.local_history) > 0 and len(history) == 0:
 | 
			
		||||
                await self.newbing_model.reset()
 | 
			
		||||
                self.local_history = []
 | 
			
		||||
 | 
			
		||||
@ -81,34 +95,33 @@ class NewBingHandle(Process):
 | 
			
		||||
            prompt = ""
 | 
			
		||||
            if system_prompt not in self.local_history:
 | 
			
		||||
                self.local_history.append(system_prompt)
 | 
			
		||||
                prompt += system_prompt + '\n'
 | 
			
		||||
                prompt += system_prompt + "\n"
 | 
			
		||||
 | 
			
		||||
            # 追加历史
 | 
			
		||||
            for ab in history:
 | 
			
		||||
                a, b = ab
 | 
			
		||||
                if a not in self.local_history:
 | 
			
		||||
                    self.local_history.append(a)
 | 
			
		||||
                    prompt += a + '\n'
 | 
			
		||||
                    prompt += a + "\n"
 | 
			
		||||
 | 
			
		||||
            # 问题
 | 
			
		||||
            prompt += question
 | 
			
		||||
            self.local_history.append(question)
 | 
			
		||||
            print('question:', prompt)
 | 
			
		||||
            print("question:", prompt)
 | 
			
		||||
            # 提交
 | 
			
		||||
            async for final, response in self.newbing_model.ask_stream(
 | 
			
		||||
                prompt=question,
 | 
			
		||||
                conversation_style=NEWBING_STYLE,     # ["creative", "balanced", "precise"]
 | 
			
		||||
                wss_link=endpoint,                    # "wss://sydney.bing.com/sydney/ChatHub"
 | 
			
		||||
                conversation_style=NEWBING_STYLE,  # ["creative", "balanced", "precise"]
 | 
			
		||||
                wss_link=endpoint,  # "wss://sydney.bing.com/sydney/ChatHub"
 | 
			
		||||
            ):
 | 
			
		||||
                if not final:
 | 
			
		||||
                    print(response)
 | 
			
		||||
                    self.child.send(str(response))
 | 
			
		||||
                else:
 | 
			
		||||
                    print('-------- receive final ---------')
 | 
			
		||||
                    self.child.send('[Finish]')
 | 
			
		||||
                    print("-------- receive final ---------")
 | 
			
		||||
                    self.child.send("[Finish]")
 | 
			
		||||
                    # self.local_history.append(response)
 | 
			
		||||
 | 
			
		||||
    
 | 
			
		||||
    def run(self):
 | 
			
		||||
        """
 | 
			
		||||
        这个函数运行在子进程
 | 
			
		||||
@ -118,32 +131,37 @@ class NewBingHandle(Process):
 | 
			
		||||
        self.local_history = []
 | 
			
		||||
        if (self.newbing_model is None) or (not self.success):
 | 
			
		||||
            # 代理设置
 | 
			
		||||
            proxies, NEWBING_COOKIES = get_conf('proxies', 'NEWBING_COOKIES')
 | 
			
		||||
            proxies, NEWBING_COOKIES = get_conf("proxies", "NEWBING_COOKIES")
 | 
			
		||||
            if proxies is None:
 | 
			
		||||
                self.proxies_https = None
 | 
			
		||||
            else:
 | 
			
		||||
                self.proxies_https = proxies['https']
 | 
			
		||||
                self.proxies_https = proxies["https"]
 | 
			
		||||
 | 
			
		||||
            if (NEWBING_COOKIES is not None) and len(NEWBING_COOKIES) > 100:
 | 
			
		||||
                try:
 | 
			
		||||
                    cookies = json.loads(NEWBING_COOKIES)
 | 
			
		||||
                except:
 | 
			
		||||
                    self.success = False
 | 
			
		||||
                    tb_str = '\n```\n' + trimmed_format_exc() + '\n```\n'
 | 
			
		||||
                    self.child.send(f'[Local Message] NEWBING_COOKIES未填写或有格式错误。')
 | 
			
		||||
                    self.child.send('[Fail]'); self.child.send('[Finish]')
 | 
			
		||||
                    tb_str = "\n```\n" + trimmed_format_exc() + "\n```\n"
 | 
			
		||||
                    self.child.send(f"[Local Message] NEWBING_COOKIES未填写或有格式错误。")
 | 
			
		||||
                    self.child.send("[Fail]")
 | 
			
		||||
                    self.child.send("[Finish]")
 | 
			
		||||
                    raise RuntimeError(f"NEWBING_COOKIES未填写或有格式错误。")
 | 
			
		||||
            else:
 | 
			
		||||
                cookies = None
 | 
			
		||||
 | 
			
		||||
            try:
 | 
			
		||||
                self.newbing_model = NewbingChatbot(proxy=self.proxies_https, cookies=cookies)
 | 
			
		||||
                self.newbing_model = NewbingChatbot(
 | 
			
		||||
                    proxy=self.proxies_https, cookies=cookies
 | 
			
		||||
                )
 | 
			
		||||
            except:
 | 
			
		||||
                self.success = False
 | 
			
		||||
                tb_str = '\n```\n' + trimmed_format_exc() + '\n```\n'
 | 
			
		||||
                self.child.send(f'[Local Message] 不能加载Newbing组件,请注意Newbing组件已不再维护。{tb_str}')
 | 
			
		||||
                self.child.send('[Fail]')
 | 
			
		||||
                self.child.send('[Finish]')
 | 
			
		||||
                tb_str = "\n```\n" + trimmed_format_exc() + "\n```\n"
 | 
			
		||||
                self.child.send(
 | 
			
		||||
                    f"[Local Message] 不能加载Newbing组件,请注意Newbing组件已不再维护。{tb_str}"
 | 
			
		||||
                )
 | 
			
		||||
                self.child.send("[Fail]")
 | 
			
		||||
                self.child.send("[Finish]")
 | 
			
		||||
                raise RuntimeError(f"不能加载Newbing组件,请注意Newbing组件已不再维护。")
 | 
			
		||||
 | 
			
		||||
        self.success = True
 | 
			
		||||
@ -151,42 +169,57 @@ class NewBingHandle(Process):
 | 
			
		||||
            # 进入任务等待状态
 | 
			
		||||
            asyncio.run(self.async_run())
 | 
			
		||||
        except Exception:
 | 
			
		||||
            tb_str = '\n```\n' + trimmed_format_exc() + '\n```\n'
 | 
			
		||||
            self.child.send(f'[Local Message] Newbing 请求失败,报错信息如下. 如果是与网络相关的问题,建议更换代理协议(推荐http)或代理节点 {tb_str}.')
 | 
			
		||||
            self.child.send('[Fail]')
 | 
			
		||||
            self.child.send('[Finish]')
 | 
			
		||||
            tb_str = "\n```\n" + trimmed_format_exc() + "\n```\n"
 | 
			
		||||
            self.child.send(
 | 
			
		||||
                f"[Local Message] Newbing 请求失败,报错信息如下. 如果是与网络相关的问题,建议更换代理协议(推荐http)或代理节点 {tb_str}."
 | 
			
		||||
            )
 | 
			
		||||
            self.child.send("[Fail]")
 | 
			
		||||
            self.child.send("[Finish]")
 | 
			
		||||
 | 
			
		||||
    def stream_chat(self, **kwargs):
 | 
			
		||||
        """
 | 
			
		||||
        这个函数运行在主进程
 | 
			
		||||
        """
 | 
			
		||||
        self.threadLock.acquire()   # 获取线程锁
 | 
			
		||||
        self.parent.send(kwargs)    # 请求子进程
 | 
			
		||||
        self.threadLock.acquire()  # 获取线程锁
 | 
			
		||||
        self.parent.send(kwargs)  # 请求子进程
 | 
			
		||||
        while True:
 | 
			
		||||
            res = self.parent.recv()                            # 等待newbing回复的片段
 | 
			
		||||
            if res == '[Finish]': break                         # 结束
 | 
			
		||||
            elif res == '[Fail]': self.success = False; break   # 失败
 | 
			
		||||
            else: yield res                                     # newbing回复的片段
 | 
			
		||||
        self.threadLock.release()   # 释放线程锁
 | 
			
		||||
            res = self.parent.recv()  # 等待newbing回复的片段
 | 
			
		||||
            if res == "[Finish]":
 | 
			
		||||
                break  # 结束
 | 
			
		||||
            elif res == "[Fail]":
 | 
			
		||||
                self.success = False
 | 
			
		||||
                break  # 失败
 | 
			
		||||
            else:
 | 
			
		||||
                yield res  # newbing回复的片段
 | 
			
		||||
        self.threadLock.release()  # 释放线程锁
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
========================================================================
 | 
			
		||||
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
 | 
			
		||||
第三部分:主进程统一调用函数接口
 | 
			
		||||
========================================================================
 | 
			
		||||
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
 | 
			
		||||
"""
 | 
			
		||||
global newbingfree_handle
 | 
			
		||||
newbingfree_handle = None
 | 
			
		||||
 | 
			
		||||
def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=[], console_slience=False):
 | 
			
		||||
 | 
			
		||||
def predict_no_ui_long_connection(
 | 
			
		||||
    inputs,
 | 
			
		||||
    llm_kwargs,
 | 
			
		||||
    history=[],
 | 
			
		||||
    sys_prompt="",
 | 
			
		||||
    observe_window=[],
 | 
			
		||||
    console_slience=False,
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
        多线程方法
 | 
			
		||||
        函数的说明请见 request_llms/bridge_all.py
 | 
			
		||||
    多线程方法
 | 
			
		||||
    函数的说明请见 request_llms/bridge_all.py
 | 
			
		||||
    """
 | 
			
		||||
    global newbingfree_handle
 | 
			
		||||
    if (newbingfree_handle is None) or (not newbingfree_handle.success):
 | 
			
		||||
        newbingfree_handle = NewBingHandle()
 | 
			
		||||
        if len(observe_window) >= 1: observe_window[0] = load_message + "\n\n" + newbingfree_handle.info
 | 
			
		||||
        if len(observe_window) >= 1:
 | 
			
		||||
            observe_window[0] = load_message + "\n\n" + newbingfree_handle.info
 | 
			
		||||
        if not newbingfree_handle.success:
 | 
			
		||||
            error = newbingfree_handle.info
 | 
			
		||||
            newbingfree_handle = None
 | 
			
		||||
@ -194,23 +227,42 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
 | 
			
		||||
 | 
			
		||||
    # 没有 sys_prompt 接口,因此把prompt加入 history
 | 
			
		||||
    history_feedin = []
 | 
			
		||||
    for i in range(len(history)//2):
 | 
			
		||||
        history_feedin.append([history[2*i], history[2*i+1]] )
 | 
			
		||||
    for i in range(len(history) // 2):
 | 
			
		||||
        history_feedin.append([history[2 * i], history[2 * i + 1]])
 | 
			
		||||
 | 
			
		||||
    watch_dog_patience = 5 # 看门狗 (watchdog) 的耐心, 设置5秒即可
 | 
			
		||||
    watch_dog_patience = 5  # 看门狗 (watchdog) 的耐心, 设置5秒即可
 | 
			
		||||
    response = ""
 | 
			
		||||
    if len(observe_window) >= 1: observe_window[0] = "[Local Message] 等待NewBing响应中 ..."
 | 
			
		||||
    for response in newbingfree_handle.stream_chat(query=inputs, history=history_feedin, system_prompt=sys_prompt, max_length=llm_kwargs['max_length'], top_p=llm_kwargs['top_p'], temperature=llm_kwargs['temperature']):
 | 
			
		||||
        if len(observe_window) >= 1:  observe_window[0] = preprocess_newbing_out_simple(response)
 | 
			
		||||
    if len(observe_window) >= 1:
 | 
			
		||||
        observe_window[0] = "[Local Message] 等待NewBing响应中 ..."
 | 
			
		||||
    for response in newbingfree_handle.stream_chat(
 | 
			
		||||
        query=inputs,
 | 
			
		||||
        history=history_feedin,
 | 
			
		||||
        system_prompt=sys_prompt,
 | 
			
		||||
        max_length=llm_kwargs["max_length"],
 | 
			
		||||
        top_p=llm_kwargs["top_p"],
 | 
			
		||||
        temperature=llm_kwargs["temperature"],
 | 
			
		||||
    ):
 | 
			
		||||
        if len(observe_window) >= 1:
 | 
			
		||||
            observe_window[0] = preprocess_newbing_out_simple(response)
 | 
			
		||||
        if len(observe_window) >= 2:
 | 
			
		||||
            if (time.time()-observe_window[1]) > watch_dog_patience:
 | 
			
		||||
            if (time.time() - observe_window[1]) > watch_dog_patience:
 | 
			
		||||
                raise RuntimeError("程序终止。")
 | 
			
		||||
    return preprocess_newbing_out_simple(response)
 | 
			
		||||
 | 
			
		||||
def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_prompt='', stream = True, additional_fn=None):
 | 
			
		||||
 | 
			
		||||
def predict(
 | 
			
		||||
    inputs,
 | 
			
		||||
    llm_kwargs,
 | 
			
		||||
    plugin_kwargs,
 | 
			
		||||
    chatbot,
 | 
			
		||||
    history=[],
 | 
			
		||||
    system_prompt="",
 | 
			
		||||
    stream=True,
 | 
			
		||||
    additional_fn=None,
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
        单线程方法
 | 
			
		||||
        函数的说明请见 request_llms/bridge_all.py
 | 
			
		||||
    单线程方法
 | 
			
		||||
    函数的说明请见 request_llms/bridge_all.py
 | 
			
		||||
    """
 | 
			
		||||
    chatbot.append((inputs, "[Local Message] 等待NewBing响应中 ..."))
 | 
			
		||||
 | 
			
		||||
@ -225,21 +277,35 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
 | 
			
		||||
 | 
			
		||||
    if additional_fn is not None:
 | 
			
		||||
        from core_functional import handle_core_functionality
 | 
			
		||||
        inputs, history = handle_core_functionality(additional_fn, inputs, history, chatbot)
 | 
			
		||||
 | 
			
		||||
        inputs, history = handle_core_functionality(
 | 
			
		||||
            additional_fn, inputs, history, chatbot
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    history_feedin = []
 | 
			
		||||
    for i in range(len(history)//2):
 | 
			
		||||
        history_feedin.append([history[2*i], history[2*i+1]] )
 | 
			
		||||
    for i in range(len(history) // 2):
 | 
			
		||||
        history_feedin.append([history[2 * i], history[2 * i + 1]])
 | 
			
		||||
 | 
			
		||||
    chatbot[-1] = (inputs, "[Local Message] 等待NewBing响应中 ...")
 | 
			
		||||
    response = "[Local Message] 等待NewBing响应中 ..."
 | 
			
		||||
    yield from update_ui(chatbot=chatbot, history=history, msg="NewBing响应缓慢,尚未完成全部响应,请耐心完成后再提交新问题。")
 | 
			
		||||
    for response in newbingfree_handle.stream_chat(query=inputs, history=history_feedin, system_prompt=system_prompt, max_length=llm_kwargs['max_length'], top_p=llm_kwargs['top_p'], temperature=llm_kwargs['temperature']):
 | 
			
		||||
    yield from update_ui(
 | 
			
		||||
        chatbot=chatbot, history=history, msg="NewBing响应缓慢,尚未完成全部响应,请耐心完成后再提交新问题。"
 | 
			
		||||
    )
 | 
			
		||||
    for response in newbingfree_handle.stream_chat(
 | 
			
		||||
        query=inputs,
 | 
			
		||||
        history=history_feedin,
 | 
			
		||||
        system_prompt=system_prompt,
 | 
			
		||||
        max_length=llm_kwargs["max_length"],
 | 
			
		||||
        top_p=llm_kwargs["top_p"],
 | 
			
		||||
        temperature=llm_kwargs["temperature"],
 | 
			
		||||
    ):
 | 
			
		||||
        chatbot[-1] = (inputs, preprocess_newbing_out(response))
 | 
			
		||||
        yield from update_ui(chatbot=chatbot, history=history, msg="NewBing响应缓慢,尚未完成全部响应,请耐心完成后再提交新问题。")
 | 
			
		||||
    if response == "[Local Message] 等待NewBing响应中 ...": response = "[Local Message] NewBing响应异常,请刷新界面重试 ..."
 | 
			
		||||
        yield from update_ui(
 | 
			
		||||
            chatbot=chatbot, history=history, msg="NewBing响应缓慢,尚未完成全部响应,请耐心完成后再提交新问题。"
 | 
			
		||||
        )
 | 
			
		||||
    if response == "[Local Message] 等待NewBing响应中 ...":
 | 
			
		||||
        response = "[Local Message] NewBing响应异常,请刷新界面重试 ..."
 | 
			
		||||
    history.extend([inputs, response])
 | 
			
		||||
    logging.info(f'[raw_input] {inputs}')
 | 
			
		||||
    logging.info(f'[response] {response}')
 | 
			
		||||
    logging.info(f"[raw_input] {inputs}")
 | 
			
		||||
    logging.info(f"[response] {response}")
 | 
			
		||||
    yield from update_ui(chatbot=chatbot, history=history, msg="完成全部响应,请提交新问题。")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -7,14 +7,15 @@ import logging
 | 
			
		||||
import time
 | 
			
		||||
from toolbox import get_conf
 | 
			
		||||
import asyncio
 | 
			
		||||
 | 
			
		||||
load_message = "正在加载Claude组件,请稍候..."
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
    """
 | 
			
		||||
    ========================================================================
 | 
			
		||||
    =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
 | 
			
		||||
    第一部分:Slack API Client
 | 
			
		||||
    https://github.com/yokonsan/claude-in-slack-api
 | 
			
		||||
    ========================================================================
 | 
			
		||||
    =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    from slack_sdk.errors import SlackApiError
 | 
			
		||||
@ -23,20 +24,23 @@ try:
 | 
			
		||||
    class SlackClient(AsyncWebClient):
 | 
			
		||||
        """SlackClient类用于与Slack API进行交互,实现消息发送、接收等功能。
 | 
			
		||||
 | 
			
		||||
            属性:
 | 
			
		||||
            - CHANNEL_ID:str类型,表示频道ID。
 | 
			
		||||
        属性:
 | 
			
		||||
        - CHANNEL_ID:str类型,表示频道ID。
 | 
			
		||||
 | 
			
		||||
            方法:
 | 
			
		||||
            - open_channel():异步方法。通过调用conversations_open方法打开一个频道,并将返回的频道ID保存在属性CHANNEL_ID中。
 | 
			
		||||
            - chat(text: str):异步方法。向已打开的频道发送一条文本消息。
 | 
			
		||||
            - get_slack_messages():异步方法。获取已打开频道的最新消息并返回消息列表,目前不支持历史消息查询。
 | 
			
		||||
            - get_reply():异步方法。循环监听已打开频道的消息,如果收到"Typing…_"结尾的消息说明Claude还在继续输出,否则结束循环。
 | 
			
		||||
        方法:
 | 
			
		||||
        - open_channel():异步方法。通过调用conversations_open方法打开一个频道,并将返回的频道ID保存在属性CHANNEL_ID中。
 | 
			
		||||
        - chat(text: str):异步方法。向已打开的频道发送一条文本消息。
 | 
			
		||||
        - get_slack_messages():异步方法。获取已打开频道的最新消息并返回消息列表,目前不支持历史消息查询。
 | 
			
		||||
        - get_reply():异步方法。循环监听已打开频道的消息,如果收到"Typing…_"结尾的消息说明Claude还在继续输出,否则结束循环。
 | 
			
		||||
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        CHANNEL_ID = None
 | 
			
		||||
 | 
			
		||||
        async def open_channel(self):
 | 
			
		||||
            response = await self.conversations_open(users=get_conf('SLACK_CLAUDE_BOT_ID'))
 | 
			
		||||
            response = await self.conversations_open(
 | 
			
		||||
                users=get_conf("SLACK_CLAUDE_BOT_ID")
 | 
			
		||||
            )
 | 
			
		||||
            self.CHANNEL_ID = response["channel"]["id"]
 | 
			
		||||
 | 
			
		||||
        async def chat(self, text):
 | 
			
		||||
@ -49,9 +53,14 @@ try:
 | 
			
		||||
        async def get_slack_messages(self):
 | 
			
		||||
            try:
 | 
			
		||||
                # TODO:暂时不支持历史消息,因为在同一个频道里存在多人使用时历史消息渗透问题
 | 
			
		||||
                resp = await self.conversations_history(channel=self.CHANNEL_ID, oldest=self.LAST_TS, limit=1)
 | 
			
		||||
                msg = [msg for msg in resp["messages"]
 | 
			
		||||
                    if msg.get("user") == get_conf('SLACK_CLAUDE_BOT_ID')]
 | 
			
		||||
                resp = await self.conversations_history(
 | 
			
		||||
                    channel=self.CHANNEL_ID, oldest=self.LAST_TS, limit=1
 | 
			
		||||
                )
 | 
			
		||||
                msg = [
 | 
			
		||||
                    msg
 | 
			
		||||
                    for msg in resp["messages"]
 | 
			
		||||
                    if msg.get("user") == get_conf("SLACK_CLAUDE_BOT_ID")
 | 
			
		||||
                ]
 | 
			
		||||
                return msg
 | 
			
		||||
            except (SlackApiError, KeyError) as e:
 | 
			
		||||
                raise RuntimeError(f"获取Slack消息失败。")
 | 
			
		||||
@ -69,13 +78,14 @@ try:
 | 
			
		||||
                else:
 | 
			
		||||
                    yield True, msg["text"]
 | 
			
		||||
                    break
 | 
			
		||||
 | 
			
		||||
except:
 | 
			
		||||
    pass
 | 
			
		||||
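Taken together, the methods documented in the SlackClient docstring form a simple ask-and-poll loop: open the channel once, post the prompt, then read messages until Claude stops typing. A hedged usage sketch (it assumes a valid SLACK_CLAUDE_USER_TOKEN, that the claude-in-slack-api bot is reachable, and that get_reply is an async generator yielding (finished, text) pairs as the docstring describes; the exact yield values may differ):

```python
import asyncio


async def ask_claude_via_slack(client, question: str) -> str:
    # `client` is an instance of the SlackClient defined above
    await client.open_channel()      # resolves and caches CHANNEL_ID
    await client.chat(question)      # post the prompt to the channel
    async for finished, text in client.get_reply():
        if finished:                 # reply no longer ends with "Typing…_"
            return text


# Example (hypothetical token value):
# asyncio.run(ask_claude_via_slack(SlackClient(token="xoxp-...", proxy=None), "Hello"))
```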
 | 
			
		||||
"""
 | 
			
		||||
========================================================================
 | 
			
		||||
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
 | 
			
		||||
第二部分:子进程Worker(调用主体)
 | 
			
		||||
========================================================================
 | 
			
		||||
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -96,6 +106,7 @@ class ClaudeHandle(Process):
 | 
			
		||||
        try:
 | 
			
		||||
            self.success = False
 | 
			
		||||
            import slack_sdk
 | 
			
		||||
 | 
			
		||||
            self.info = "依赖检测通过,等待Claude响应。注意目前不能多人同时调用Claude接口(有线程锁),否则将导致每个人的Claude问询历史互相渗透。调用Claude时,会自动使用已配置的代理。"
 | 
			
		||||
            self.success = True
 | 
			
		||||
        except:
 | 
			
		||||
@ -110,15 +121,15 @@ class ClaudeHandle(Process):
 | 
			
		||||
        while True:
 | 
			
		||||
            # 等待
 | 
			
		||||
            kwargs = self.child.recv()
 | 
			
		||||
            question = kwargs['query']
 | 
			
		||||
            history = kwargs['history']
 | 
			
		||||
            question = kwargs["query"]
 | 
			
		||||
            history = kwargs["history"]
 | 
			
		||||
 | 
			
		||||
            # 开始问问题
 | 
			
		||||
            prompt = ""
 | 
			
		||||
 | 
			
		||||
            # 问题
 | 
			
		||||
            prompt += question
 | 
			
		||||
            print('question:', prompt)
 | 
			
		||||
            print("question:", prompt)
 | 
			
		||||
 | 
			
		||||
            # 提交
 | 
			
		||||
            await self.claude_model.chat(prompt)
 | 
			
		||||
@ -131,11 +142,15 @@ class ClaudeHandle(Process):
 | 
			
		||||
                else:
 | 
			
		||||
                    # 防止丢失最后一条消息
 | 
			
		||||
                    slack_msgs = await self.claude_model.get_slack_messages()
 | 
			
		||||
                    last_msg = slack_msgs[-1]["text"] if slack_msgs and len(slack_msgs) > 0 else ""
 | 
			
		||||
                    last_msg = (
 | 
			
		||||
                        slack_msgs[-1]["text"]
 | 
			
		||||
                        if slack_msgs and len(slack_msgs) > 0
 | 
			
		||||
                        else ""
 | 
			
		||||
                    )
 | 
			
		||||
                    if last_msg:
 | 
			
		||||
                        self.child.send(last_msg)
 | 
			
		||||
                    print('-------- receive final ---------')
 | 
			
		||||
                    self.child.send('[Finish]')
 | 
			
		||||
                    print("-------- receive final ---------")
 | 
			
		||||
                    self.child.send("[Finish]")
 | 
			
		||||
 | 
			
		||||
    def run(self):
 | 
			
		||||
        """
 | 
			
		||||
@ -146,22 +161,24 @@ class ClaudeHandle(Process):
 | 
			
		||||
        self.local_history = []
 | 
			
		||||
        if (self.claude_model is None) or (not self.success):
 | 
			
		||||
            # 代理设置
 | 
			
		||||
            proxies = get_conf('proxies')
 | 
			
		||||
            proxies = get_conf("proxies")
 | 
			
		||||
            if proxies is None:
 | 
			
		||||
                self.proxies_https = None
 | 
			
		||||
            else:
 | 
			
		||||
                self.proxies_https = proxies['https']
 | 
			
		||||
                self.proxies_https = proxies["https"]
 | 
			
		||||
 | 
			
		||||
            try:
 | 
			
		||||
                SLACK_CLAUDE_USER_TOKEN = get_conf('SLACK_CLAUDE_USER_TOKEN')
 | 
			
		||||
                self.claude_model = SlackClient(token=SLACK_CLAUDE_USER_TOKEN, proxy=self.proxies_https)
 | 
			
		||||
                print('Claude组件初始化成功。')
 | 
			
		||||
                SLACK_CLAUDE_USER_TOKEN = get_conf("SLACK_CLAUDE_USER_TOKEN")
 | 
			
		||||
                self.claude_model = SlackClient(
 | 
			
		||||
                    token=SLACK_CLAUDE_USER_TOKEN, proxy=self.proxies_https
 | 
			
		||||
                )
 | 
			
		||||
                print("Claude组件初始化成功。")
 | 
			
		||||
            except:
 | 
			
		||||
                self.success = False
 | 
			
		||||
                tb_str = '\n```\n' + trimmed_format_exc() + '\n```\n'
 | 
			
		||||
                self.child.send(f'[Local Message] 不能加载Claude组件。{tb_str}')
 | 
			
		||||
                self.child.send('[Fail]')
 | 
			
		||||
                self.child.send('[Finish]')
 | 
			
		||||
                tb_str = "\n```\n" + trimmed_format_exc() + "\n```\n"
 | 
			
		||||
                self.child.send(f"[Local Message] 不能加载Claude组件。{tb_str}")
 | 
			
		||||
                self.child.send("[Fail]")
 | 
			
		||||
                self.child.send("[Finish]")
 | 
			
		||||
                raise RuntimeError(f"不能加载Claude组件。")
 | 
			
		||||
 | 
			
		||||
        self.success = True
 | 
			
		||||
@ -169,42 +186,49 @@ class ClaudeHandle(Process):
 | 
			
		||||
            # 进入任务等待状态
 | 
			
		||||
            asyncio.run(self.async_run())
 | 
			
		||||
        except Exception:
 | 
			
		||||
            tb_str = '\n```\n' + trimmed_format_exc() + '\n```\n'
 | 
			
		||||
            self.child.send(f'[Local Message] Claude失败 {tb_str}.')
 | 
			
		||||
            self.child.send('[Fail]')
 | 
			
		||||
            self.child.send('[Finish]')
 | 
			
		||||
            tb_str = "\n```\n" + trimmed_format_exc() + "\n```\n"
 | 
			
		||||
            self.child.send(f"[Local Message] Claude失败 {tb_str}.")
 | 
			
		||||
            self.child.send("[Fail]")
 | 
			
		||||
            self.child.send("[Finish]")
 | 
			
		||||
 | 
			
		||||
    def stream_chat(self, **kwargs):
 | 
			
		||||
        """
 | 
			
		||||
        这个函数运行在主进程
 | 
			
		||||
        """
 | 
			
		||||
        self.threadLock.acquire()
 | 
			
		||||
        self.parent.send(kwargs)    # 发送请求到子进程
 | 
			
		||||
        self.parent.send(kwargs)  # 发送请求到子进程
 | 
			
		||||
        while True:
 | 
			
		||||
            res = self.parent.recv()    # 等待Claude回复的片段
 | 
			
		||||
            if res == '[Finish]':
 | 
			
		||||
                break       # 结束
 | 
			
		||||
            elif res == '[Fail]':
 | 
			
		||||
            res = self.parent.recv()  # 等待Claude回复的片段
 | 
			
		||||
            if res == "[Finish]":
 | 
			
		||||
                break  # 结束
 | 
			
		||||
            elif res == "[Fail]":
 | 
			
		||||
                self.success = False
 | 
			
		||||
                break
 | 
			
		||||
            else:
 | 
			
		||||
                yield res   # Claude回复的片段
 | 
			
		||||
                yield res  # Claude回复的片段
 | 
			
		||||
        self.threadLock.release()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
========================================================================
 | 
			
		||||
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
 | 
			
		||||
第三部分:主进程统一调用函数接口
 | 
			
		||||
========================================================================
 | 
			
		||||
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
 | 
			
		||||
"""
 | 
			
		||||
global claude_handle
 | 
			
		||||
claude_handle = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=None, console_slience=False):
 | 
			
		||||
def predict_no_ui_long_connection(
 | 
			
		||||
    inputs,
 | 
			
		||||
    llm_kwargs,
 | 
			
		||||
    history=[],
 | 
			
		||||
    sys_prompt="",
 | 
			
		||||
    observe_window=None,
 | 
			
		||||
    console_slience=False,
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
        多线程方法
 | 
			
		||||
        函数的说明请见 request_llms/bridge_all.py
 | 
			
		||||
    多线程方法
 | 
			
		||||
    函数的说明请见 request_llms/bridge_all.py
 | 
			
		||||
    """
 | 
			
		||||
    global claude_handle
 | 
			
		||||
    if (claude_handle is None) or (not claude_handle.success):
 | 
			
		||||
@ -217,24 +241,40 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
 | 
			
		||||
 | 
			
		||||
    # 没有 sys_prompt 接口,因此把prompt加入 history
 | 
			
		||||
    history_feedin = []
 | 
			
		||||
    for i in range(len(history)//2):
 | 
			
		||||
        history_feedin.append([history[2*i], history[2*i+1]])
 | 
			
		||||
    for i in range(len(history) // 2):
 | 
			
		||||
        history_feedin.append([history[2 * i], history[2 * i + 1]])
 | 
			
		||||
 | 
			
		||||
    watch_dog_patience = 5  # 看门狗 (watchdog) 的耐心, 设置5秒即可
 | 
			
		||||
    response = ""
 | 
			
		||||
    observe_window[0] = "[Local Message] 等待Claude响应中 ..."
 | 
			
		||||
    for response in claude_handle.stream_chat(query=inputs, history=history_feedin, system_prompt=sys_prompt, max_length=llm_kwargs['max_length'], top_p=llm_kwargs['top_p'], temperature=llm_kwargs['temperature']):
 | 
			
		||||
    for response in claude_handle.stream_chat(
 | 
			
		||||
        query=inputs,
 | 
			
		||||
        history=history_feedin,
 | 
			
		||||
        system_prompt=sys_prompt,
 | 
			
		||||
        max_length=llm_kwargs["max_length"],
 | 
			
		||||
        top_p=llm_kwargs["top_p"],
 | 
			
		||||
        temperature=llm_kwargs["temperature"],
 | 
			
		||||
    ):
 | 
			
		||||
        observe_window[0] = preprocess_newbing_out_simple(response)
 | 
			
		||||
        if len(observe_window) >= 2:
 | 
			
		||||
            if (time.time()-observe_window[1]) > watch_dog_patience:
 | 
			
		||||
            if (time.time() - observe_window[1]) > watch_dog_patience:
 | 
			
		||||
                raise RuntimeError("程序终止。")
 | 
			
		||||
    return preprocess_newbing_out_simple(response)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_prompt='', stream=True, additional_fn=None):
 | 
			
		||||
def predict(
 | 
			
		||||
    inputs,
 | 
			
		||||
    llm_kwargs,
 | 
			
		||||
    plugin_kwargs,
 | 
			
		||||
    chatbot,
 | 
			
		||||
    history=[],
 | 
			
		||||
    system_prompt="",
 | 
			
		||||
    stream=True,
 | 
			
		||||
    additional_fn=None,
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
        单线程方法
 | 
			
		||||
        函数的说明请见 request_llms/bridge_all.py
 | 
			
		||||
    单线程方法
 | 
			
		||||
    函数的说明请见 request_llms/bridge_all.py
 | 
			
		||||
    """
 | 
			
		||||
    chatbot.append((inputs, "[Local Message] 等待Claude响应中 ..."))
 | 
			
		||||
 | 
			
		||||
@ -249,21 +289,30 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
 | 
			
		||||
 | 
			
		||||
    if additional_fn is not None:
 | 
			
		||||
        from core_functional import handle_core_functionality
 | 
			
		||||
        inputs, history = handle_core_functionality(additional_fn, inputs, history, chatbot)
 | 
			
		||||
 | 
			
		||||
        inputs, history = handle_core_functionality(
 | 
			
		||||
            additional_fn, inputs, history, chatbot
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    history_feedin = []
 | 
			
		||||
    for i in range(len(history)//2):
 | 
			
		||||
        history_feedin.append([history[2*i], history[2*i+1]])
 | 
			
		||||
    for i in range(len(history) // 2):
 | 
			
		||||
        history_feedin.append([history[2 * i], history[2 * i + 1]])
 | 
			
		||||
 | 
			
		||||
    chatbot[-1] = (inputs, "[Local Message] 等待Claude响应中 ...")
 | 
			
		||||
    response = "[Local Message] 等待Claude响应中 ..."
 | 
			
		||||
    yield from update_ui(chatbot=chatbot, history=history, msg="Claude响应缓慢,尚未完成全部响应,请耐心完成后再提交新问题。")
 | 
			
		||||
    for response in claude_handle.stream_chat(query=inputs, history=history_feedin, system_prompt=system_prompt):
 | 
			
		||||
    yield from update_ui(
 | 
			
		||||
        chatbot=chatbot, history=history, msg="Claude响应缓慢,尚未完成全部响应,请耐心完成后再提交新问题。"
 | 
			
		||||
    )
 | 
			
		||||
    for response in claude_handle.stream_chat(
 | 
			
		||||
        query=inputs, history=history_feedin, system_prompt=system_prompt
 | 
			
		||||
    ):
 | 
			
		||||
        chatbot[-1] = (inputs, preprocess_newbing_out(response))
 | 
			
		||||
        yield from update_ui(chatbot=chatbot, history=history, msg="Claude响应缓慢,尚未完成全部响应,请耐心完成后再提交新问题。")
 | 
			
		||||
        yield from update_ui(
 | 
			
		||||
            chatbot=chatbot, history=history, msg="Claude响应缓慢,尚未完成全部响应,请耐心完成后再提交新问题。"
 | 
			
		||||
        )
 | 
			
		||||
    if response == "[Local Message] 等待Claude响应中 ...":
 | 
			
		||||
        response = "[Local Message] Claude响应异常,请刷新界面重试 ..."
 | 
			
		||||
    history.extend([inputs, response])
 | 
			
		||||
    logging.info(f'[raw_input] {inputs}')
 | 
			
		||||
    logging.info(f'[response] {response}')
 | 
			
		||||
    logging.info(f"[raw_input] {inputs}")
 | 
			
		||||
    logging.info(f"[response] {response}")
 | 
			
		||||
    yield from update_ui(chatbot=chatbot, history=history, msg="完成全部响应,请提交新问题。")
 | 
			
		||||
 | 
			
		||||
@ -12,7 +12,7 @@ from toolbox import get_conf, encode_image, get_pictures_list
 | 
			
		||||
proxies, TIMEOUT_SECONDS = get_conf("proxies", "TIMEOUT_SECONDS")
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
========================================================================
 | 
			
		||||
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
 | 
			
		||||
第五部分 一些文件处理方法
 | 
			
		||||
files_filter_handler 根据type过滤文件
 | 
			
		||||
input_encode_handler 提取input中的文件,并解析
 | 
			
		||||
@ -21,6 +21,7 @@ link_mtime_to_md 文件增加本地时间参数,避免下载到缓存文件
 | 
			
		||||
html_view_blank 超链接
 | 
			
		||||
html_local_file 本地文件取相对路径
 | 
			
		||||
to_markdown_tabs 文件list 转换为 md tab
 | 
			
		||||
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1,8 +1,8 @@
 | 
			
		||||
"""
 | 
			
		||||
========================================================================
 | 
			
		||||
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
 | 
			
		||||
第一部分:来自EdgeGPT.py
 | 
			
		||||
https://github.com/acheong08/EdgeGPT
 | 
			
		||||
========================================================================
 | 
			
		||||
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
 | 
			
		||||
"""
 | 
			
		||||
"""
 | 
			
		||||
Main.py
 | 
			
		||||
@ -196,9 +196,9 @@ class _ChatHubRequest:
 | 
			
		||||
        self,
 | 
			
		||||
        prompt: str,
 | 
			
		||||
        conversation_style: CONVERSATION_STYLE_TYPE,
 | 
			
		||||
        options = None,
 | 
			
		||||
        webpage_context = None,
 | 
			
		||||
        search_result = False,
 | 
			
		||||
        options=None,
 | 
			
		||||
        webpage_context=None,
 | 
			
		||||
        search_result=False,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Updates request object
 | 
			
		||||
@ -294,9 +294,9 @@ class _Conversation:
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        proxy = None,
 | 
			
		||||
        async_mode = False,
 | 
			
		||||
        cookies = None,
 | 
			
		||||
        proxy=None,
 | 
			
		||||
        async_mode=False,
 | 
			
		||||
        cookies=None,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        if async_mode:
 | 
			
		||||
            return
 | 
			
		||||
@ -350,8 +350,8 @@ class _Conversation:
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    async def create(
 | 
			
		||||
        proxy = None,
 | 
			
		||||
        cookies = None,
 | 
			
		||||
        proxy=None,
 | 
			
		||||
        cookies=None,
 | 
			
		||||
    ):
 | 
			
		||||
        self = _Conversation(async_mode=True)
 | 
			
		||||
        self.struct = {
 | 
			
		||||
@ -418,8 +418,8 @@ class _ChatHub:
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        conversation: _Conversation,
 | 
			
		||||
        proxy = None,
 | 
			
		||||
        cookies = None,
 | 
			
		||||
        proxy=None,
 | 
			
		||||
        cookies=None,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        self.session = None
 | 
			
		||||
        self.wss = None
 | 
			
		||||
@ -441,7 +441,7 @@ class _ChatHub:
 | 
			
		||||
        conversation_style: CONVERSATION_STYLE_TYPE = None,
 | 
			
		||||
        raw: bool = False,
 | 
			
		||||
        options: dict = None,
 | 
			
		||||
        webpage_context = None,
 | 
			
		||||
        webpage_context=None,
 | 
			
		||||
        search_result: bool = False,
 | 
			
		||||
    ) -> Generator[str, None, None]:
 | 
			
		||||
        """
 | 
			
		||||
@ -452,9 +452,11 @@ class _ChatHub:
 | 
			
		||||
            ws_cookies = []
 | 
			
		||||
            for cookie in self.cookies:
 | 
			
		||||
                ws_cookies.append(f"{cookie['name']}={cookie['value']}")
 | 
			
		||||
            req_header.update({
 | 
			
		||||
                'Cookie': ';'.join(ws_cookies),
 | 
			
		||||
            })
 | 
			
		||||
            req_header.update(
 | 
			
		||||
                {
 | 
			
		||||
                    "Cookie": ";".join(ws_cookies),
 | 
			
		||||
                }
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        timeout = aiohttp.ClientTimeout(total=30)
 | 
			
		||||
        self.session = aiohttp.ClientSession(timeout=timeout)
 | 
			
		||||
@ -521,7 +523,7 @@ class _ChatHub:
 | 
			
		||||
            msg = await self.wss.receive()
 | 
			
		||||
            try:
 | 
			
		||||
                objects = msg.data.split(DELIMITER)
 | 
			
		||||
            except :
 | 
			
		||||
            except:
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            for obj in objects:
 | 
			
		||||
@ -624,8 +626,8 @@ class Chatbot:
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        proxy = None,
 | 
			
		||||
        cookies = None,
 | 
			
		||||
        proxy=None,
 | 
			
		||||
        cookies=None,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        self.proxy = proxy
 | 
			
		||||
        self.chat_hub: _ChatHub = _ChatHub(
 | 
			
		||||
@ -636,8 +638,8 @@ class Chatbot:
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    async def create(
 | 
			
		||||
        proxy = None,
 | 
			
		||||
        cookies = None,
 | 
			
		||||
        proxy=None,
 | 
			
		||||
        cookies=None,
 | 
			
		||||
    ):
 | 
			
		||||
        self = Chatbot.__new__(Chatbot)
 | 
			
		||||
        self.proxy = proxy
 | 
			
		||||
@ -654,7 +656,7 @@ class Chatbot:
 | 
			
		||||
        wss_link: str = "wss://sydney.bing.com/sydney/ChatHub",
 | 
			
		||||
        conversation_style: CONVERSATION_STYLE_TYPE = None,
 | 
			
		||||
        options: dict = None,
 | 
			
		||||
        webpage_context = None,
 | 
			
		||||
        webpage_context=None,
 | 
			
		||||
        search_result: bool = False,
 | 
			
		||||
    ) -> dict:
 | 
			
		||||
        """
 | 
			
		||||
@ -680,7 +682,7 @@ class Chatbot:
 | 
			
		||||
        conversation_style: CONVERSATION_STYLE_TYPE = None,
 | 
			
		||||
        raw: bool = False,
 | 
			
		||||
        options: dict = None,
 | 
			
		||||
        webpage_context = None,
 | 
			
		||||
        webpage_context=None,
 | 
			
		||||
        search_result: bool = False,
 | 
			
		||||
    ) -> Generator[str, None, None]:
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
@ -2,6 +2,7 @@ import markdown
 | 
			
		||||
import re
 | 
			
		||||
import os
 | 
			
		||||
import math
 | 
			
		||||
from textwrap import dedent
 | 
			
		||||
from latex2mathml.converter import convert as tex2mathml
 | 
			
		||||
from functools import wraps, lru_cache
 | 
			
		||||
from shared_utils.config_loader import get_conf as get_conf
 | 
			
		||||
@ -32,6 +33,147 @@ def text_divide_paragraph(text):
 | 
			
		||||
        text = "</br>".join(lines)
 | 
			
		||||
        return pre + text + suf
 | 
			
		||||
 | 
			
		||||
def tex2mathml_catch_exception(content, *args, **kwargs):
 | 
			
		||||
    try:
 | 
			
		||||
        content = tex2mathml(content, *args, **kwargs)
 | 
			
		||||
    except:
 | 
			
		||||
        content = content
 | 
			
		||||
    return content
 | 
			
		||||
 | 
			
		||||
def replace_math_no_render(match):
 | 
			
		||||
    content = match.group(1)
 | 
			
		||||
    if 'mode=display' in match.group(0):
 | 
			
		||||
        content = content.replace('\n', '</br>')
 | 
			
		||||
        return f"<font color=\"#00FF00\">$$</font><font color=\"#FF00FF\">{content}</font><font color=\"#00FF00\">$$</font>"
 | 
			
		||||
    else:
 | 
			
		||||
        return f"<font color=\"#00FF00\">$</font><font color=\"#FF00FF\">{content}</font><font color=\"#00FF00\">$</font>"
 | 
			
		||||
 | 
			
		||||
def replace_math_render(match):
 | 
			
		||||
    content = match.group(1)
 | 
			
		||||
    if 'mode=display' in match.group(0):
 | 
			
		||||
        if '\\begin{aligned}' in content:
 | 
			
		||||
            content = content.replace('\\begin{aligned}', '\\begin{array}')
 | 
			
		||||
            content = content.replace('\\end{aligned}', '\\end{array}')
 | 
			
		||||
            content = content.replace('&', ' ')
 | 
			
		||||
        content = tex2mathml_catch_exception(content, display="block")
 | 
			
		||||
        return content
 | 
			
		||||
    else:
 | 
			
		||||
        return tex2mathml_catch_exception(content)
 | 
			
		||||
 | 
			
		||||
def markdown_bug_hunt(content):
 | 
			
		||||
    """
 | 
			
		||||
    解决一个mdx_math的bug(单$包裹begin命令时多余<script>)
 | 
			
		||||
    """
 | 
			
		||||
    content = content.replace('<script type="math/tex">\n<script type="math/tex; mode=display">',
 | 
			
		||||
                                '<script type="math/tex; mode=display">')
 | 
			
		||||
    content = content.replace('</script>\n</script>', '</script>')
 | 
			
		||||
    return content
 | 
			
		||||
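A quick illustration of the bug being patched: when mdx_math wraps a single-dollar begin command, it emits a nested opening script tag and a duplicated closing tag, which the two replacements above collapse back into one well-formed block (hypothetical input):

```python
broken = (
    '<script type="math/tex">\n<script type="math/tex; mode=display">'
    "x^2</script>\n</script>"
)
print(markdown_bug_hunt(broken))
# -> <script type="math/tex; mode=display">x^2</script>
```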
 | 
			
		||||
def is_equation(txt):
 | 
			
		||||
    """
 | 
			
		||||
    判定是否为公式 | 测试1 写出洛伦兹定律,使用tex格式公式 测试2 给出柯西不等式,使用latex格式 测试3 写出麦克斯韦方程组
 | 
			
		||||
    """
 | 
			
		||||
    if '```' in txt and '```reference' not in txt: return False
 | 
			
		||||
    if '$' not in txt and '\\[' not in txt: return False
 | 
			
		||||
    mathpatterns = {
 | 
			
		||||
        r'(?<!\\|\$)(\$)([^\$]+)(\$)': {'allow_multi_lines': False},                       #  $...$
 | 
			
		||||
        r'(?<!\\)(\$\$)([^\$]+)(\$\$)': {'allow_multi_lines': True},                       # $$...$$
 | 
			
		||||
        r'(?<!\\)(\\\[)(.+?)(\\\])': {'allow_multi_lines': False},                         # \[...\]
 | 
			
		||||
        # r'(?<!\\)(\\\()(.+?)(\\\))': {'allow_multi_lines': False},                       # \(...\)
 | 
			
		||||
        # r'(?<!\\)(\\begin{([a-z]+?\*?)})(.+?)(\\end{\2})': {'allow_multi_lines': True},  # \begin...\end
 | 
			
		||||
        # r'(?<!\\)(\$`)([^`]+)(`\$)': {'allow_multi_lines': False},                       # $`...`$
 | 
			
		||||
    }
 | 
			
		||||
    matches = []
 | 
			
		||||
    for pattern, property in mathpatterns.items():
 | 
			
		||||
        flags = re.ASCII | re.DOTALL if property['allow_multi_lines'] else re.ASCII
 | 
			
		||||
        matches.extend(re.findall(pattern, txt, flags))
 | 
			
		||||
    if len(matches) == 0: return False
 | 
			
		||||
    contain_any_eq = False
 | 
			
		||||
    illegal_pattern = re.compile(r'[^\x00-\x7F]|echo')
 | 
			
		||||
    for match in matches:
 | 
			
		||||
        if len(match) != 3: return False
 | 
			
		||||
        eq_canidate = match[1]
 | 
			
		||||
        if illegal_pattern.search(eq_canidate):
 | 
			
		||||
            return False
 | 
			
		||||
        else:
 | 
			
		||||
            contain_any_eq = True
 | 
			
		||||
    return contain_any_eq
 | 
			
		||||
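A few illustrative calls (hypothetical inputs) showing what this heuristic accepts and rejects:

```python
print(is_equation(r"Maxwell: $\nabla \cdot \vec{B} = 0$"))  # True: plain TeX between $...$
print(is_equation("价格是 $100,折扣是 $20"))                 # False: non-ASCII characters inside the $...$ span
fenced = "`" * 3 + "\nprice = '$100'\n" + "`" * 3
print(is_equation(fenced))                                  # False: a fenced code block disables math detection
```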
 | 
			
		||||
def fix_markdown_indent(txt):
 | 
			
		||||
    # fix markdown indent
 | 
			
		||||
    if (' - ' not in txt) or ('. ' not in txt):
 | 
			
		||||
        # do not need to fix, fast escape
 | 
			
		||||
        return txt
 | 
			
		||||
    # walk through the lines and fix non-standard indentation
 | 
			
		||||
    lines = txt.split("\n")
 | 
			
		||||
    pattern = re.compile(r'^\s+-')
 | 
			
		||||
    activated = False
 | 
			
		||||
    for i, line in enumerate(lines):
 | 
			
		||||
        if line.startswith('- ') or line.startswith('1. '):
 | 
			
		||||
            activated = True
 | 
			
		||||
        if activated and pattern.match(line):
 | 
			
		||||
            stripped_string = line.lstrip()
 | 
			
		||||
            num_spaces = len(line) - len(stripped_string)
 | 
			
		||||
            if (num_spaces % 4) == 3:
 | 
			
		||||
                num_spaces_should_be = math.ceil(num_spaces / 4) * 4
 | 
			
		||||
                lines[i] = ' ' * num_spaces_should_be + stripped_string
 | 
			
		||||
    return '\n'.join(lines)
 | 
			
		||||
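For illustration, a nested bullet indented by three spaces (3 % 4 == 3) is promoted to a four-space indent so that standard Markdown parsers keep it nested under its parent item (hypothetical input):

```python
sample = "1. parent item\n   - child indented by three spaces"
print(fix_markdown_indent(sample))
# the "- child ..." line now starts with four spaces instead of three
```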
 | 
			
		||||
FENCED_BLOCK_RE = re.compile(
 | 
			
		||||
    dedent(r'''
 | 
			
		||||
        (?P<fence>^[ \t]*(?:~{3,}|`{3,}))[ ]*                      # opening fence
 | 
			
		||||
        ((\{(?P<attrs>[^\}\n]*)\})|                              # (optional {attrs} or
 | 
			
		||||
        (\.?(?P<lang>[\w#.+-]*)[ ]*)?                            # optional (.)lang
 | 
			
		||||
        (hl_lines=(?P<quot>"|')(?P<hl_lines>.*?)(?P=quot)[ ]*)?) # optional hl_lines)
 | 
			
		||||
        \n                                                       # newline (end of opening fence)
 | 
			
		||||
        (?P<code>.*?)(?<=\n)                                     # the code block
 | 
			
		||||
        (?P=fence)[ ]*$                                          # closing fence
 | 
			
		||||
    '''),
 | 
			
		||||
    re.MULTILINE | re.DOTALL | re.VERBOSE
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
def get_line_range(re_match_obj, txt):
 | 
			
		||||
    start_pos, end_pos = re_match_obj.regs[0]
 | 
			
		||||
    num_newlines_before = txt[:start_pos+1].count('\n')
 | 
			
		||||
    line_start = num_newlines_before
 | 
			
		||||
    line_end = num_newlines_before + txt[start_pos:end_pos].count('\n')+1
 | 
			
		||||
    return line_start, line_end
 | 
			
		||||
 | 
			
		||||
def fix_code_segment_indent(txt):
 | 
			
		||||
    lines = []
 | 
			
		||||
    change_any = False
 | 
			
		||||
    txt_tmp = txt
 | 
			
		||||
    while True:
 | 
			
		||||
        re_match_obj = FENCED_BLOCK_RE.search(txt_tmp)
 | 
			
		||||
        if not re_match_obj: break
 | 
			
		||||
        if len(lines) == 0: lines = txt.split("\n")
 | 
			
		||||
        
 | 
			
		||||
        # 清空 txt_tmp 对应的位置方便下次搜索
 | 
			
		||||
        start_pos, end_pos = re_match_obj.regs[0]
 | 
			
		||||
        txt_tmp = txt_tmp[:start_pos] + ' '*(end_pos-start_pos) + txt_tmp[end_pos:]
 | 
			
		||||
        line_start, line_end = get_line_range(re_match_obj, txt)
 | 
			
		||||
        
 | 
			
		||||
        # 获取公共缩进
 | 
			
		||||
        shared_indent_cnt = 1e5
 | 
			
		||||
        for i in range(line_start, line_end):
 | 
			
		||||
            stripped_string = lines[i].lstrip()
 | 
			
		||||
            num_spaces = len(lines[i]) - len(stripped_string)
 | 
			
		||||
            if num_spaces < shared_indent_cnt:
 | 
			
		||||
                shared_indent_cnt = num_spaces
 | 
			
		||||
 | 
			
		||||
        # 修复缩进
 | 
			
		||||
        if (shared_indent_cnt < 1e5) and (shared_indent_cnt % 4) == 3:
 | 
			
		||||
            num_spaces_should_be = math.ceil(shared_indent_cnt / 4) * 4
 | 
			
		||||
            for i in range(line_start, line_end):
 | 
			
		||||
                add_n = num_spaces_should_be - shared_indent_cnt
 | 
			
		||||
                lines[i] = ' ' * add_n + lines[i]
 | 
			
		||||
            if not change_any: # 遇到第一个
 | 
			
		||||
                change_any = True
 | 
			
		||||
 | 
			
		||||
    if change_any:
 | 
			
		||||
        return '\n'.join(lines)
 | 
			
		||||
    else:
 | 
			
		||||
        return txt
 | 
			
		||||
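In the same spirit, a fenced block whose lines share a three-space indent is shifted to four spaces so it stays attached to the surrounding list item. A small hypothetical demonstration, written with a ~~~ fence (which FENCED_BLOCK_RE also matches) so the example itself stays readable:

```python
txt = "1. step\n   ~~~bash\n   echo hi\n   ~~~\n"
print(fix_code_segment_indent(txt))
# every line of the fenced block now starts with four spaces instead of three
```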
    
 | 
			
		||||
@lru_cache(maxsize=128) # 使用 lru缓存 加快转换速度
 | 
			
		||||
def markdown_convertion(txt):
 | 
			
		||||
@ -52,92 +194,8 @@ def markdown_convertion(txt):
 | 
			
		||||
    }
 | 
			
		||||
    find_equation_pattern = r'<script type="math/tex(?:.*?)>(.*?)</script>'
 | 
			
		||||
 | 
			
		||||
    def tex2mathml_catch_exception(content, *args, **kwargs):
 | 
			
		||||
        try:
 | 
			
		||||
            content = tex2mathml(content, *args, **kwargs)
 | 
			
		||||
        except:
 | 
			
		||||
            content = content
 | 
			
		||||
        return content
 | 
			
		||||
 | 
			
		||||
    def replace_math_no_render(match):
 | 
			
		||||
        content = match.group(1)
 | 
			
		||||
        if 'mode=display' in match.group(0):
 | 
			
		||||
            content = content.replace('\n', '</br>')
 | 
			
		||||
            return f"<font color=\"#00FF00\">$$</font><font color=\"#FF00FF\">{content}</font><font color=\"#00FF00\">$$</font>"
 | 
			
		||||
        else:
 | 
			
		||||
            return f"<font color=\"#00FF00\">$</font><font color=\"#FF00FF\">{content}</font><font color=\"#00FF00\">$</font>"
 | 
			
		||||
 | 
			
		||||
    def replace_math_render(match):
 | 
			
		||||
        content = match.group(1)
 | 
			
		||||
        if 'mode=display' in match.group(0):
 | 
			
		||||
            if '\\begin{aligned}' in content:
 | 
			
		||||
                content = content.replace('\\begin{aligned}', '\\begin{array}')
 | 
			
		||||
                content = content.replace('\\end{aligned}', '\\end{array}')
 | 
			
		||||
                content = content.replace('&', ' ')
 | 
			
		||||
            content = tex2mathml_catch_exception(content, display="block")
 | 
			
		||||
            return content
 | 
			
		||||
        else:
 | 
			
		||||
            return tex2mathml_catch_exception(content)
 | 
			
		||||
 | 
			
		||||
    def markdown_bug_hunt(content):
 | 
			
		||||
        """
 | 
			
		||||
        解决一个mdx_math的bug(单$包裹begin命令时多余<script>)
 | 
			
		||||
        """
 | 
			
		||||
        content = content.replace('<script type="math/tex">\n<script type="math/tex; mode=display">',
 | 
			
		||||
                                  '<script type="math/tex; mode=display">')
 | 
			
		||||
        content = content.replace('</script>\n</script>', '</script>')
 | 
			
		||||
        return content
 | 
			
		||||
 | 
			
		||||
    def is_equation(txt):
 | 
			
		||||
        """
 | 
			
		||||
        判定是否为公式 | 测试1 写出洛伦兹定律,使用tex格式公式 测试2 给出柯西不等式,使用latex格式 测试3 写出麦克斯韦方程组
 | 
			
		||||
        """
 | 
			
		||||
        if '```' in txt and '```reference' not in txt: return False
 | 
			
		||||
        if '$' not in txt and '\\[' not in txt: return False
 | 
			
		||||
        mathpatterns = {
 | 
			
		||||
            r'(?<!\\|\$)(\$)([^\$]+)(\$)': {'allow_multi_lines': False},                       #  $...$
 | 
			
		||||
            r'(?<!\\)(\$\$)([^\$]+)(\$\$)': {'allow_multi_lines': True},                       # $$...$$
 | 
			
		||||
            r'(?<!\\)(\\\[)(.+?)(\\\])': {'allow_multi_lines': False},                         # \[...\]
 | 
			
		||||
            # r'(?<!\\)(\\\()(.+?)(\\\))': {'allow_multi_lines': False},                       # \(...\)
 | 
			
		||||
            # r'(?<!\\)(\\begin{([a-z]+?\*?)})(.+?)(\\end{\2})': {'allow_multi_lines': True},  # \begin...\end
 | 
			
		||||
            # r'(?<!\\)(\$`)([^`]+)(`\$)': {'allow_multi_lines': False},                       # $`...`$
 | 
			
		||||
        }
 | 
			
		||||
        matches = []
 | 
			
		||||
        for pattern, property in mathpatterns.items():
 | 
			
		||||
            flags = re.ASCII | re.DOTALL if property['allow_multi_lines'] else re.ASCII
 | 
			
		||||
            matches.extend(re.findall(pattern, txt, flags))
 | 
			
		||||
        if len(matches) == 0: return False
 | 
			
		||||
        contain_any_eq = False
 | 
			
		||||
        illegal_pattern = re.compile(r'[^\x00-\x7F]|echo')
 | 
			
		||||
        for match in matches:
 | 
			
		||||
            if len(match) != 3: return False
 | 
			
		||||
            eq_canidate = match[1]
 | 
			
		||||
            if illegal_pattern.search(eq_canidate):
 | 
			
		||||
                return False
 | 
			
		||||
            else:
 | 
			
		||||
                contain_any_eq = True
 | 
			
		||||
        return contain_any_eq
 | 
			
		||||
 | 
			
		||||
    def fix_markdown_indent(txt):
 | 
			
		||||
        # fix markdown indent
 | 
			
		||||
        if (' - ' not in txt) or ('. ' not in txt):
 | 
			
		||||
            return txt  # do not need to fix, fast escape
 | 
			
		||||
        # walk through the lines and fix non-standard indentation
 | 
			
		||||
        lines = txt.split("\n")
 | 
			
		||||
        pattern = re.compile(r'^\s+-')
 | 
			
		||||
        activated = False
 | 
			
		||||
        for i, line in enumerate(lines):
 | 
			
		||||
            if line.startswith('- ') or line.startswith('1. '):
 | 
			
		||||
                activated = True
 | 
			
		||||
            if activated and pattern.match(line):
 | 
			
		||||
                stripped_string = line.lstrip()
 | 
			
		||||
                num_spaces = len(line) - len(stripped_string)
 | 
			
		||||
                if (num_spaces % 4) == 3:
 | 
			
		||||
                    num_spaces_should_be = math.ceil(num_spaces / 4) * 4
 | 
			
		||||
                    lines[i] = ' ' * num_spaces_should_be + stripped_string
 | 
			
		||||
        return '\n'.join(lines)
 | 
			
		||||
 | 
			
		||||
    txt = fix_markdown_indent(txt)
 | 
			
		||||
    txt = fix_code_segment_indent(txt)
 | 
			
		||||
    if is_equation(txt):  # 有$标识的公式符号,且没有代码段```的标识
 | 
			
		||||
        # convert everything to html format
 | 
			
		||||
        split = markdown.markdown(text='---')
 | 
			
		||||
 | 
			
		||||
@ -1,7 +1,7 @@
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
========================================================================
 | 
			
		||||
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
 | 
			
		||||
接驳void-terminal:
 | 
			
		||||
    - set_conf:                     在运行过程中动态地修改配置
 | 
			
		||||
    - set_multi_conf:               在运行过程中动态地修改多个配置
 | 
			
		||||
@ -9,17 +9,20 @@ import os
 | 
			
		||||
    - get_plugin_default_kwargs:    获取插件的默认参数
 | 
			
		||||
    - get_chat_handle:              获取简单聊天的句柄
 | 
			
		||||
    - get_chat_default_kwargs:      获取简单聊天的默认参数
 | 
			
		||||
========================================================================
 | 
			
		||||
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_plugin_handle(plugin_name):
 | 
			
		||||
    """
 | 
			
		||||
    e.g. plugin_name = 'crazy_functions.批量Markdown翻译->Markdown翻译指定语言'
 | 
			
		||||
    """
 | 
			
		||||
    import importlib
 | 
			
		||||
    assert '->' in plugin_name, \
 | 
			
		||||
        "Example of plugin_name: crazy_functions.批量Markdown翻译->Markdown翻译指定语言"
 | 
			
		||||
    module, fn_name = plugin_name.split('->')
 | 
			
		||||
 | 
			
		||||
    assert (
 | 
			
		||||
        "->" in plugin_name
 | 
			
		||||
    ), "Example of plugin_name: crazy_functions.批量Markdown翻译->Markdown翻译指定语言"
 | 
			
		||||
    module, fn_name = plugin_name.split("->")
 | 
			
		||||
    f_hot_reload = getattr(importlib.import_module(module, fn_name), fn_name)
 | 
			
		||||
    return f_hot_reload
 | 
			
		||||
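Illustrative only (it assumes the repository root is on sys.path and the plugin's dependencies are installed): the "module->function" path is split on "->", the module is imported, and the named attribute is returned as a callable.

```python
fn = get_plugin_handle("crazy_functions.批量Markdown翻译->Markdown翻译指定语言")
print(fn.__name__)  # -> "Markdown翻译指定语言", assuming the plugin's decorators preserve __name__
```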
 | 
			
		||||
@ -29,6 +32,7 @@ def get_chat_handle():
 | 
			
		||||
    Get chat function
 | 
			
		||||
    """
 | 
			
		||||
    from request_llms.bridge_all import predict_no_ui_long_connection
 | 
			
		||||
 | 
			
		||||
    return predict_no_ui_long_connection
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -37,13 +41,14 @@ def get_plugin_default_kwargs():
 | 
			
		||||
    Get Plugin Default Arguments
 | 
			
		||||
    """
 | 
			
		||||
    from toolbox import ChatBotWithCookies, load_chat_cookies
 | 
			
		||||
 | 
			
		||||
    cookies = load_chat_cookies()
 | 
			
		||||
    llm_kwargs = {
 | 
			
		||||
        'api_key': cookies['api_key'],
 | 
			
		||||
        'llm_model': cookies['llm_model'],
 | 
			
		||||
        'top_p': 1.0,
 | 
			
		||||
        'max_length': None,
 | 
			
		||||
        'temperature': 1.0,
 | 
			
		||||
        "api_key": cookies["api_key"],
 | 
			
		||||
        "llm_model": cookies["llm_model"],
 | 
			
		||||
        "top_p": 1.0,
 | 
			
		||||
        "max_length": None,
 | 
			
		||||
        "temperature": 1.0,
 | 
			
		||||
    }
 | 
			
		||||
    chatbot = ChatBotWithCookies(llm_kwargs)
 | 
			
		||||
 | 
			
		||||
@ -55,7 +60,7 @@ def get_plugin_default_kwargs():
 | 
			
		||||
        "chatbot_with_cookie": chatbot,
 | 
			
		||||
        "history": [],
 | 
			
		||||
        "system_prompt": "You are a good AI.",
 | 
			
		||||
        "web_port": None
 | 
			
		||||
        "web_port": None,
 | 
			
		||||
    }
 | 
			
		||||
    return DEFAULT_FN_GROUPS_kwargs
 | 
			
		||||
 | 
			
		||||
@ -65,13 +70,14 @@ def get_chat_default_kwargs():
 | 
			
		||||
    Get Chat Default Arguments
 | 
			
		||||
    """
 | 
			
		||||
    from toolbox import load_chat_cookies
 | 
			
		||||
 | 
			
		||||
    cookies = load_chat_cookies()
 | 
			
		||||
    llm_kwargs = {
 | 
			
		||||
        'api_key': cookies['api_key'],
 | 
			
		||||
        'llm_model': cookies['llm_model'],
 | 
			
		||||
        'top_p': 1.0,
 | 
			
		||||
        'max_length': None,
 | 
			
		||||
        'temperature': 1.0,
 | 
			
		||||
        "api_key": cookies["api_key"],
 | 
			
		||||
        "llm_model": cookies["llm_model"],
 | 
			
		||||
        "top_p": 1.0,
 | 
			
		||||
        "max_length": None,
 | 
			
		||||
        "temperature": 1.0,
 | 
			
		||||
    }
 | 
			
		||||
    default_chat_kwargs = {
 | 
			
		||||
        "inputs": "Hello there, are you ready?",
 | 
			
		||||
 | 
			
		||||
@ -1,32 +1,75 @@
 | 
			
		||||
md = """
 | 
			
		||||
作为您的写作和编程助手,我可以为您提供以下服务:
 | 
			
		||||
 | 
			
		||||
1. 写作:
 | 
			
		||||
    - 帮助您撰写文章、报告、散文、故事等。
 | 
			
		||||
    - 提供写作建议和技巧。
 | 
			
		||||
    - 协助您进行文案策划和内容创作。
 | 
			
		||||
要计算文件的哈希值,可以使用哈希算法(如MD5、SHA-1或SHA-256)对文件的内容进行计算。
 | 
			
		||||
 | 
			
		||||
2. 编程:
 | 
			
		||||
    - 帮助您解决编程问题,提供编程思路和建议。
 | 
			
		||||
    - 协助您编写代码,包括但不限于 Python、Java、C++ 等。
 | 
			
		||||
    - 为您解释复杂的技术概念,让您更容易理解。
 | 
			
		||||
以下是一个使用sha256算法计算文件哈希值的示例代码:
 | 
			
		||||
 | 
			
		||||
3. 项目支持:
 | 
			
		||||
    - 协助您规划项目进度和任务分配。
 | 
			
		||||
    - 提供项目管理和协作建议。
 | 
			
		||||
    - 在项目实施过程中提供支持,确保项目顺利进行。
 | 
			
		||||
```python
 | 
			
		||||
import hashlib
 | 
			
		||||
 | 
			
		||||
def calculate_hash(file_path):
 | 
			
		||||
    sha256_hash = hashlib.sha256()
 | 
			
		||||
    with open(file_path, 'rb') as file:
 | 
			
		||||
        for chunk in iter(lambda: file.read(4096), b''):
 | 
			
		||||
            sha256_hash.update(chunk)
 | 
			
		||||
    return sha256_hash.hexdigest()
 | 
			
		||||
 | 
			
		||||
# 使用示例
 | 
			
		||||
file_path = 'path/to/file.txt'
 | 
			
		||||
hash_value = calculate_hash(file_path)
 | 
			
		||||
print('File hash:', hash_value)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
在上面的示例中,`calculate_hash`函数接受一个文件路径作为参数,并打开文件以二进制读取模式读取文件内容。然后,使用哈希对象sha256初始化,并对文件内容进行分块读取并更新哈希值。最后,通过`hexdigest`方法获取哈希值的十六进制表示。
 | 
			
		||||
 | 
			
		||||
可以根据需要更改哈希算法(如使用`hashlib.md5()`来使用MD5算法)和块大小(这里使用4096字节)。
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
md = """
 | 
			
		||||
要在Ubuntu中将NTFS格式转换为ext4格式,您需要进行以下步骤:
 | 
			
		||||
 | 
			
		||||
1. 首先,确保您已经安装了gparted软件。如果没有安装,请使用以下命令进行安装:
 | 
			
		||||
 | 
			
		||||
   ```
 | 
			
		||||
   sudo apt update
 | 
			
		||||
   sudo apt install gparted
 | 
			
		||||
   ```
 | 
			
		||||
 | 
			
		||||
2. 然后,打开GParted软件。您可以在"应用程序"菜单中搜索并启动它。
 | 
			
		||||
 | 
			
		||||
3. 在GParted界面中,选择您想要转换格式的NTFS分区。请小心选择,确保选择正确的分区。
 | 
			
		||||
 | 
			
		||||
4. 确保分区未挂载。如果分区当前正在使用,您需要首先卸载它。在命令行中,您可以使用以下命令卸载该分区:
 | 
			
		||||
 | 
			
		||||
    ```
 | 
			
		||||
    sudo umount /dev/sdc1
 | 
			
		||||
    ```
 | 
			
		||||
 | 
			
		||||
   注意:请将"/dev/sdc1"替换为您要卸载的分区的正确路径。
 | 
			
		||||
 | 
			
		||||
5. 在GParted界面中,单击菜单中的"设备"选项,然后选择"创建"。
 | 
			
		||||
 | 
			
		||||
6. 在弹出的对话框中,选择要转换为的文件系统类型。在这种情况下,选择"ext4"。然后单击"添加"按钮。
 | 
			
		||||
 | 
			
		||||
7. 在"操作"菜单中,选择"应用所有操作"。这将开始分区格式转换的过程。
 | 
			
		||||
 | 
			
		||||
8. 等待GParted完成转换操作。这可能需要一些时间,具体取决于分区的大小和系统性能。
 | 
			
		||||
 | 
			
		||||
9. 转换完成后,您将看到分区的文件系统已更改为ext4。
 | 
			
		||||
 | 
			
		||||
10. 最后,请确保挂载分区以便访问它。您可以使用以下命令挂载该分区:
 | 
			
		||||
 | 
			
		||||
    ```
 | 
			
		||||
    sudo mount /dev/sdc1 /media/fuqingxu/eb63a8fa-cee9-48a5-9f05-b1388c3fda9e
 | 
			
		||||
    ```
 | 
			
		||||
 | 
			
		||||
    注意:请将"/dev/sdc1"替换为已转换分区的正确路径,并将"/media/fuqingxu/eb63a8fa-cee9-48a5-9f05-b1388c3fda9e"替换为您要挂载的目标路径。
 | 
			
		||||
 | 
			
		||||
请注意,在执行任何分区操作之前,务必备份重要的数据。操作不当可能导致数据丢失。
 | 
			
		||||
 | 
			
		||||
4. 学习辅导:
 | 
			
		||||
    - 帮助您巩固编程基础,提高编程能力。
 | 
			
		||||
    - 提供计算机科学、数据科学、人工智能等相关领域的学习资源和建议。
 | 
			
		||||
    - 解答您在学习过程中遇到的问题,让您更好地掌握知识。
 | 
			
		||||
 | 
			
		||||
5. 行业动态和趋势分析:
 | 
			
		||||
    - 为您提供业界最新的新闻和技术趋势。
 | 
			
		||||
    - 分析行业动态,帮助您了解市场发展和竞争态势。
 | 
			
		||||
    - 为您制定技术战略提供参考和建议。
 | 
			
		||||
 | 
			
		||||
请随时告诉我您的需求,我会尽力提供帮助。如果您有任何问题或需要解答的议题,请随时提问。
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -43,6 +86,6 @@ validate_path()  # validate path so you can run from base directory
 | 
			
		||||
from toolbox import markdown_convertion
 | 
			
		||||
 | 
			
		||||
html = markdown_convertion(md)
 | 
			
		||||
print(html)
 | 
			
		||||
# print(html)
 | 
			
		||||
with open("test.html", "w", encoding="utf-8") as f:
 | 
			
		||||
    f.write(html)
 | 
			
		||||
 | 
			
		||||
							
								
								
									
417 toolbox.py
@ -11,6 +11,7 @@ from functools import wraps
 | 
			
		||||
from shared_utils.config_loader import get_conf
 | 
			
		||||
from shared_utils.config_loader import set_conf
 | 
			
		||||
from shared_utils.advanced_markdown_format import format_io
 | 
			
		||||
from shared_utils.advanced_markdown_format import markdown_convertion
 | 
			
		||||
from shared_utils.key_pattern_manager import select_api_key
 | 
			
		||||
from shared_utils.key_pattern_manager import is_any_api_key
 | 
			
		||||
from shared_utils.key_pattern_manager import what_keys
 | 
			
		||||
@ -20,10 +21,10 @@ from shared_utils.connect_void_terminal import get_plugin_default_kwargs
 | 
			
		||||
from shared_utils.connect_void_terminal import get_chat_default_kwargs
 | 
			
		||||
 | 
			
		||||
pj = os.path.join
 | 
			
		||||
default_user_name = 'default_user'
 | 
			
		||||
default_user_name = "default_user"
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
========================================================================
 | 
			
		||||
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
 | 
			
		||||
第一部分
 | 
			
		||||
函数插件输入输出接驳区
 | 
			
		||||
    - ChatBotWithCookies:   带Cookies的Chatbot类,为实现更多强大的功能做基础
 | 
			
		||||
@ -32,7 +33,7 @@ default_user_name = 'default_user'
 | 
			
		||||
    - CatchException:       将插件中出的所有问题显示在界面上
 | 
			
		||||
    - HotReload:            实现插件的热更新
 | 
			
		||||
    - trimmed_format_exc:   打印traceback,为了安全而隐藏绝对地址
 | 
			
		||||
========================================================================
 | 
			
		||||
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -120,22 +121,30 @@ def ArgsGeneralWrapper(f):
 | 
			
		||||
    return decorated
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def update_ui(chatbot, history, msg='正常', **kwargs):  # 刷新界面
 | 
			
		||||
def update_ui(chatbot, history, msg="正常", **kwargs):  # 刷新界面
 | 
			
		||||
    """
 | 
			
		||||
    刷新用户界面
 | 
			
		||||
    """
 | 
			
		||||
    assert isinstance(chatbot, ChatBotWithCookies), "在传递chatbot的过程中不要将其丢弃。必要时, 可用clear将其清空, 然后用for+append循环重新赋值。"
 | 
			
		||||
    assert isinstance(
 | 
			
		||||
        chatbot, ChatBotWithCookies
 | 
			
		||||
    ), "在传递chatbot的过程中不要将其丢弃。必要时, 可用clear将其清空, 然后用for+append循环重新赋值。"
 | 
			
		||||
    cookies = chatbot.get_cookies()
 | 
			
		||||
    # 备份一份History作为记录
 | 
			
		||||
    cookies.update({'history': history})
 | 
			
		||||
    cookies.update({"history": history})
 | 
			
		||||
    # 解决插件锁定时的界面显示问题
 | 
			
		||||
    if cookies.get('lock_plugin', None):
 | 
			
		||||
        label = cookies.get('llm_model', "") + " | " + "正在锁定插件" + cookies.get('lock_plugin', None)
 | 
			
		||||
    if cookies.get("lock_plugin", None):
 | 
			
		||||
        label = (
 | 
			
		||||
            cookies.get("llm_model", "")
 | 
			
		||||
            + " | "
 | 
			
		||||
            + "正在锁定插件"
 | 
			
		||||
            + cookies.get("lock_plugin", None)
 | 
			
		||||
        )
 | 
			
		||||
        chatbot_gr = gradio.update(value=chatbot, label=label)
 | 
			
		||||
        if cookies.get('label', "") != label: cookies['label'] = label   # 记住当前的label
 | 
			
		||||
    elif cookies.get('label', None):
 | 
			
		||||
        chatbot_gr = gradio.update(value=chatbot, label=cookies.get('llm_model', ""))
 | 
			
		||||
        cookies['label'] = None    # 清空label
 | 
			
		||||
        if cookies.get("label", "") != label:
 | 
			
		||||
            cookies["label"] = label  # 记住当前的label
 | 
			
		||||
    elif cookies.get("label", None):
 | 
			
		||||
        chatbot_gr = gradio.update(value=chatbot, label=cookies.get("llm_model", ""))
 | 
			
		||||
        cookies["label"] = None  # 清空label
 | 
			
		||||
    else:
 | 
			
		||||
        chatbot_gr = chatbot
 | 
			
		||||
 | 
			
		||||
@ -146,7 +155,8 @@ def update_ui_lastest_msg(lastmsg, chatbot, history, delay=1):  # 刷新界面
 | 
			
		||||
    """
 | 
			
		||||
    刷新用户界面
 | 
			
		||||
    """
 | 
			
		||||
    if len(chatbot) == 0: chatbot.append(["update_ui_last_msg", lastmsg])
 | 
			
		||||
    if len(chatbot) == 0:
 | 
			
		||||
        chatbot.append(["update_ui_last_msg", lastmsg])
 | 
			
		||||
    chatbot[-1] = list(chatbot[-1])
 | 
			
		||||
    chatbot[-1][-1] = lastmsg
 | 
			
		||||
    yield from update_ui(chatbot=chatbot, history=history)
 | 
			
		||||
@ -155,6 +165,7 @@ def update_ui_lastest_msg(lastmsg, chatbot, history, delay=1):  # 刷新界面
 | 
			
		||||
 | 
			
		||||
def trimmed_format_exc():
 | 
			
		||||
    import os, traceback
 | 
			
		||||
 | 
			
		||||
    str = traceback.format_exc()
 | 
			
		||||
    current_path = os.getcwd()
 | 
			
		||||
    replace_path = "."
 | 
			
		||||
@ -194,19 +205,21 @@ def HotReload(f):
 | 
			
		||||
    最后,使用yield from语句返回重新加载过的函数,并在被装饰的函数上执行。
 | 
			
		||||
    最终,装饰器函数返回内部函数。这个内部函数可以将函数的原始定义更新为最新版本,并执行函数的新版本。
 | 
			
		||||
    """
 | 
			
		||||
    if get_conf('PLUGIN_HOT_RELOAD'):
 | 
			
		||||
    if get_conf("PLUGIN_HOT_RELOAD"):
 | 
			
		||||
 | 
			
		||||
        @wraps(f)
 | 
			
		||||
        def decorated(*args, **kwargs):
 | 
			
		||||
            fn_name = f.__name__
 | 
			
		||||
            f_hot_reload = getattr(importlib.reload(inspect.getmodule(f)), fn_name)
 | 
			
		||||
            yield from f_hot_reload(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
        return decorated
 | 
			
		||||
    else:
 | 
			
		||||
        return f
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
========================================================================
 | 
			
		||||
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
 | 
			
		||||
第二部分
 | 
			
		||||
其他小工具:
 | 
			
		||||
    - write_history_to_file:    将结果写入markdown文件中
 | 
			
		||||
@ -220,13 +233,13 @@ def HotReload(f):
 | 
			
		||||
    - clip_history:             当历史上下文过长时,自动截断
 | 
			
		||||
    - get_conf:                 获取设置
 | 
			
		||||
    - select_api_key:           根据当前的模型类别,抽取可用的api-key
 | 
			
		||||
========================================================================
 | 
			
		||||
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_reduce_token_percent(text):
 | 
			
		||||
    """
 | 
			
		||||
        * 此函数未来将被弃用
 | 
			
		||||
    * 此函数未来将被弃用
 | 
			
		||||
    """
 | 
			
		||||
    try:
 | 
			
		||||
        # text = "maximum context length is 4097 tokens. However, your messages resulted in 4870 tokens"
 | 
			
		||||
@ -239,36 +252,40 @@ def get_reduce_token_percent(text):
 | 
			
		||||
        assert ratio > 0 and ratio < 1
 | 
			
		||||
        return ratio, str(int(current_tokens - max_limit))
 | 
			
		||||
    except:
 | 
			
		||||
        return 0.5, '不详'
 | 
			
		||||
        return 0.5, "不详"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def write_history_to_file(history, file_basename=None, file_fullname=None, auto_caption=True):
 | 
			
		||||
def write_history_to_file(
 | 
			
		||||
    history, file_basename=None, file_fullname=None, auto_caption=True
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    将对话记录history以Markdown格式写入文件中。如果没有指定文件名,则使用当前时间生成文件名。
 | 
			
		||||
    """
 | 
			
		||||
    import os
 | 
			
		||||
    import time
 | 
			
		||||
 | 
			
		||||
    if file_fullname is None:
 | 
			
		||||
        if file_basename is not None:
 | 
			
		||||
            file_fullname = pj(get_log_folder(), file_basename)
 | 
			
		||||
        else:
 | 
			
		||||
            file_fullname = pj(get_log_folder(), f'GPT-Academic-{gen_time_str()}.md')
 | 
			
		||||
            file_fullname = pj(get_log_folder(), f"GPT-Academic-{gen_time_str()}.md")
 | 
			
		||||
    os.makedirs(os.path.dirname(file_fullname), exist_ok=True)
 | 
			
		||||
    with open(file_fullname, 'w', encoding='utf8') as f:
 | 
			
		||||
        f.write('# GPT-Academic Report\n')
 | 
			
		||||
    with open(file_fullname, "w", encoding="utf8") as f:
 | 
			
		||||
        f.write("# GPT-Academic Report\n")
 | 
			
		||||
        for i, content in enumerate(history):
 | 
			
		||||
            try:
 | 
			
		||||
                if type(content) != str: content = str(content)
 | 
			
		||||
                if type(content) != str:
 | 
			
		||||
                    content = str(content)
 | 
			
		||||
            except:
 | 
			
		||||
                continue
 | 
			
		||||
            if i % 2 == 0 and auto_caption:
 | 
			
		||||
                f.write('## ')
 | 
			
		||||
                f.write("## ")
 | 
			
		||||
            try:
 | 
			
		||||
                f.write(content)
 | 
			
		||||
            except:
 | 
			
		||||
                # remove everything that cannot be handled by utf8
 | 
			
		||||
                f.write(content.encode('utf-8', 'ignore').decode())
 | 
			
		||||
            f.write('\n\n')
 | 
			
		||||
                f.write(content.encode("utf-8", "ignore").decode())
 | 
			
		||||
            f.write("\n\n")
 | 
			
		||||
    res = os.path.abspath(file_fullname)
 | 
			
		||||
    return res
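A usage sketch of write_history_to_file (not part of this commit); the history entries are made up, and the output location depends on the PATH_LOGGING setting.

from toolbox import write_history_to_file

history = ["What does this function do?", "It writes the dialogue to a Markdown report."]  # alternating Q/A entries
report_path = write_history_to_file(history)  # creates GPT-Academic-<timestamp>.md under the log folder
print(report_path)  # absolute path of the generated report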
 | 
			
		||||
 | 
			
		||||
@ -277,9 +294,9 @@ def regular_txt_to_markdown(text):
 | 
			
		||||
    """
 | 
			
		||||
    将普通文本转换为Markdown格式的文本。
 | 
			
		||||
    """
 | 
			
		||||
    text = text.replace('\n', '\n\n')
 | 
			
		||||
    text = text.replace('\n\n\n', '\n\n')
 | 
			
		||||
    text = text.replace('\n\n\n', '\n\n')
 | 
			
		||||
    text = text.replace("\n", "\n\n")
 | 
			
		||||
    text = text.replace("\n\n\n", "\n\n")
 | 
			
		||||
    text = text.replace("\n\n\n", "\n\n")
 | 
			
		||||
    return text
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -297,8 +314,9 @@ def find_free_port():
 | 
			
		||||
    """
 | 
			
		||||
    import socket
 | 
			
		||||
    from contextlib import closing
 | 
			
		||||
 | 
			
		||||
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
 | 
			
		||||
        s.bind(('', 0))
 | 
			
		||||
        s.bind(("", 0))
 | 
			
		||||
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
 | 
			
		||||
        return s.getsockname()[1]
 | 
			
		||||
 | 
			
		||||
@ -307,54 +325,58 @@ def extract_archive(file_path, dest_dir):
 | 
			
		||||
    import zipfile
 | 
			
		||||
    import tarfile
 | 
			
		||||
    import os
 | 
			
		||||
 | 
			
		||||
    # Get the file extension of the input file
 | 
			
		||||
    file_extension = os.path.splitext(file_path)[1]
 | 
			
		||||
 | 
			
		||||
    # Extract the archive based on its extension
 | 
			
		||||
    if file_extension == '.zip':
 | 
			
		||||
        with zipfile.ZipFile(file_path, 'r') as zipobj:
 | 
			
		||||
    if file_extension == ".zip":
 | 
			
		||||
        with zipfile.ZipFile(file_path, "r") as zipobj:
 | 
			
		||||
            zipobj.extractall(path=dest_dir)
 | 
			
		||||
            print("Successfully extracted zip archive to {}".format(dest_dir))
 | 
			
		||||
 | 
			
		||||
    elif file_extension in ['.tar', '.gz', '.bz2']:
 | 
			
		||||
        with tarfile.open(file_path, 'r:*') as tarobj:
 | 
			
		||||
    elif file_extension in [".tar", ".gz", ".bz2"]:
 | 
			
		||||
        with tarfile.open(file_path, "r:*") as tarobj:
 | 
			
		||||
            tarobj.extractall(path=dest_dir)
 | 
			
		||||
            print("Successfully extracted tar archive to {}".format(dest_dir))
 | 
			
		||||
 | 
			
		||||
    # 第三方库,需要预先pip install rarfile
 | 
			
		||||
    # 此外,Windows上还需要安装winrar软件,配置其Path环境变量,如"C:\Program Files\WinRAR"才可以
 | 
			
		||||
    elif file_extension == '.rar':
 | 
			
		||||
    elif file_extension == ".rar":
 | 
			
		||||
        try:
 | 
			
		||||
            import rarfile
 | 
			
		||||
 | 
			
		||||
            with rarfile.RarFile(file_path) as rf:
 | 
			
		||||
                rf.extractall(path=dest_dir)
 | 
			
		||||
                print("Successfully extracted rar archive to {}".format(dest_dir))
 | 
			
		||||
        except:
 | 
			
		||||
            print("Rar format requires additional dependencies to install")
 | 
			
		||||
            return '\n\n解压失败! 需要安装pip install rarfile来解压rar文件。建议:使用zip压缩格式。'
 | 
			
		||||
            return "\n\n解压失败! 需要安装pip install rarfile来解压rar文件。建议:使用zip压缩格式。"
 | 
			
		||||
 | 
			
		||||
    # 第三方库,需要预先pip install py7zr
 | 
			
		||||
    elif file_extension == '.7z':
 | 
			
		||||
    elif file_extension == ".7z":
 | 
			
		||||
        try:
 | 
			
		||||
            import py7zr
 | 
			
		||||
            with py7zr.SevenZipFile(file_path, mode='r') as f:
 | 
			
		||||
 | 
			
		||||
            with py7zr.SevenZipFile(file_path, mode="r") as f:
 | 
			
		||||
                f.extractall(path=dest_dir)
 | 
			
		||||
                print("Successfully extracted 7z archive to {}".format(dest_dir))
 | 
			
		||||
        except:
 | 
			
		||||
            print("7z format requires additional dependencies to install")
 | 
			
		||||
            return '\n\n解压失败! 需要安装pip install py7zr来解压7z文件'
 | 
			
		||||
            return "\n\n解压失败! 需要安装pip install py7zr来解压7z文件"
 | 
			
		||||
    else:
 | 
			
		||||
        return ''
 | 
			
		||||
    return ''
 | 
			
		||||
        return ""
 | 
			
		||||
    return ""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def find_recent_files(directory):
 | 
			
		||||
    """
 | 
			
		||||
        me: find files that is created with in one minutes under a directory with python, write a function
 | 
			
		||||
        gpt: here it is!
 | 
			
		||||
    me: find files that is created with in one minutes under a directory with python, write a function
 | 
			
		||||
    gpt: here it is!
 | 
			
		||||
    """
 | 
			
		||||
    import os
 | 
			
		||||
    import time
 | 
			
		||||
 | 
			
		||||
    current_time = time.time()
 | 
			
		||||
    one_minute_ago = current_time - 60
 | 
			
		||||
    recent_files = []
 | 
			
		||||
@ -362,7 +384,7 @@ def find_recent_files(directory):
 | 
			
		||||
        os.makedirs(directory, exist_ok=True)
 | 
			
		||||
    for filename in os.listdir(directory):
 | 
			
		||||
        file_path = pj(directory, filename)
 | 
			
		||||
        if file_path.endswith('.log'):
 | 
			
		||||
        if file_path.endswith(".log"):
 | 
			
		||||
            continue
 | 
			
		||||
        created_time = os.path.getmtime(file_path)
 | 
			
		||||
        if created_time >= one_minute_ago:
 | 
			
		||||
@ -388,49 +410,53 @@ def file_already_in_downloadzone(file, user_path):
 | 
			
		||||
def promote_file_to_downloadzone(file, rename_file=None, chatbot=None):
 | 
			
		||||
    # 将文件复制一份到下载区
 | 
			
		||||
    import shutil
 | 
			
		||||
 | 
			
		||||
    if chatbot is not None:
 | 
			
		||||
        user_name = get_user(chatbot)
 | 
			
		||||
    else:
 | 
			
		||||
        user_name = default_user_name
 | 
			
		||||
    if not os.path.exists(file):
 | 
			
		||||
        raise FileNotFoundError(f'文件{file}不存在')
 | 
			
		||||
        raise FileNotFoundError(f"文件{file}不存在")
 | 
			
		||||
    user_path = get_log_folder(user_name, plugin_name=None)
 | 
			
		||||
    if file_already_in_downloadzone(file, user_path):
 | 
			
		||||
        new_path = file
 | 
			
		||||
    else:
 | 
			
		||||
        user_path = get_log_folder(user_name, plugin_name='downloadzone')
 | 
			
		||||
        if rename_file is None: rename_file = f'{gen_time_str()}-{os.path.basename(file)}'
 | 
			
		||||
        user_path = get_log_folder(user_name, plugin_name="downloadzone")
 | 
			
		||||
        if rename_file is None:
 | 
			
		||||
            rename_file = f"{gen_time_str()}-{os.path.basename(file)}"
 | 
			
		||||
        new_path = pj(user_path, rename_file)
 | 
			
		||||
        # 如果已经存在,先删除
 | 
			
		||||
        if os.path.exists(new_path) and not os.path.samefile(new_path, file): os.remove(new_path)
 | 
			
		||||
        if os.path.exists(new_path) and not os.path.samefile(new_path, file):
 | 
			
		||||
            os.remove(new_path)
 | 
			
		||||
        # 把文件复制过去
 | 
			
		||||
        if not os.path.exists(new_path): shutil.copyfile(file, new_path)
 | 
			
		||||
        if not os.path.exists(new_path):
 | 
			
		||||
            shutil.copyfile(file, new_path)
 | 
			
		||||
    # 将文件添加到chatbot cookie中
 | 
			
		||||
    if chatbot is not None:
 | 
			
		||||
        if 'files_to_promote' in chatbot._cookies:
 | 
			
		||||
            current = chatbot._cookies['files_to_promote']
 | 
			
		||||
        if "files_to_promote" in chatbot._cookies:
 | 
			
		||||
            current = chatbot._cookies["files_to_promote"]
 | 
			
		||||
        else:
 | 
			
		||||
            current = []
 | 
			
		||||
        if new_path not in current: # 避免把同一个文件添加多次
 | 
			
		||||
            chatbot._cookies.update({'files_to_promote': [new_path] + current})
 | 
			
		||||
        if new_path not in current:  # 避免把同一个文件添加多次
 | 
			
		||||
            chatbot._cookies.update({"files_to_promote": [new_path] + current})
 | 
			
		||||
    return new_path
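A usage sketch of promote_file_to_downloadzone (not part of this commit). With chatbot=None the copy lands in the default user's download zone; inside a plugin the ChatBotWithCookies instance would be passed instead, so the file is also queued in the "files_to_promote" cookie.

from toolbox import write_history_to_file, promote_file_to_downloadzone

report = write_history_to_file(history=["question", "answer"])  # illustrative content
new_path = promote_file_to_downloadzone(report, chatbot=None)   # copies the report into the download zone
print(new_path)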
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def disable_auto_promotion(chatbot):
 | 
			
		||||
    chatbot._cookies.update({'files_to_promote': []})
 | 
			
		||||
    chatbot._cookies.update({"files_to_promote": []})
 | 
			
		||||
    return
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def del_outdated_uploads(outdate_time_seconds, target_path_base=None):
 | 
			
		||||
    if target_path_base is None:
 | 
			
		||||
        user_upload_dir = get_conf('PATH_PRIVATE_UPLOAD')
 | 
			
		||||
        user_upload_dir = get_conf("PATH_PRIVATE_UPLOAD")
 | 
			
		||||
    else:
 | 
			
		||||
        user_upload_dir = target_path_base
 | 
			
		||||
    current_time = time.time()
 | 
			
		||||
    one_hour_ago = current_time - outdate_time_seconds
 | 
			
		||||
    # Get a list of all subdirectories in the user_upload_dir folder
 | 
			
		||||
    # Remove subdirectories that are older than one hour
 | 
			
		||||
    for subdirectory in glob.glob(f'{user_upload_dir}/*'):
 | 
			
		||||
    for subdirectory in glob.glob(f"{user_upload_dir}/*"):
 | 
			
		||||
        subdirectory_time = os.path.getmtime(subdirectory)
 | 
			
		||||
        if subdirectory_time < one_hour_ago:
 | 
			
		||||
            try:
 | 
			
		||||
@ -447,8 +473,8 @@ def html_local_file(file):
 | 
			
		||||
    return file
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def html_local_img(__file, layout='left', max_width=None, max_height=None, md=True):
 | 
			
		||||
    style = ''
 | 
			
		||||
def html_local_img(__file, layout="left", max_width=None, max_height=None, md=True):
 | 
			
		||||
    style = ""
 | 
			
		||||
    if max_width is not None:
 | 
			
		||||
        style += f"max-width: {max_width};"
 | 
			
		||||
    if max_height is not None:
 | 
			
		||||
@ -456,20 +482,23 @@ def html_local_img(__file, layout='left', max_width=None, max_height=None, md=Tr
 | 
			
		||||
    __file = html_local_file(__file)
 | 
			
		||||
    a = f'<div align="{layout}"><img src="{__file}" style="{style}"></div>'
 | 
			
		||||
    if md:
 | 
			
		||||
        a = f'![]({__file})'

 | 
			
		||||
        a = f""
 | 
			
		||||
    return a
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def file_manifest_filter_type(file_list, filter_: list = None):
 | 
			
		||||
    new_list = []
 | 
			
		||||
    if not filter_: filter_ = ['png', 'jpg', 'jpeg']
 | 
			
		||||
    if not filter_:
 | 
			
		||||
        filter_ = ["png", "jpg", "jpeg"]
 | 
			
		||||
    for file in file_list:
 | 
			
		||||
        if str(os.path.basename(file)).split('.')[-1] in filter_:
 | 
			
		||||
        if str(os.path.basename(file)).split(".")[-1] in filter_:
 | 
			
		||||
            new_list.append(html_local_img(file, md=False))
 | 
			
		||||
        else:
 | 
			
		||||
            new_list.append(file)
 | 
			
		||||
    return new_list
 | 
			
		||||
 | 
			
		||||
def to_markdown_tabs(head: list, tabs: list, alignment=':---:', column=False):
 | 
			
		||||
 | 
			
		||||
def to_markdown_tabs(head: list, tabs: list, alignment=":---:", column=False):
 | 
			
		||||
    """
 | 
			
		||||
    Args:
 | 
			
		||||
        head: 表头:[]
 | 
			
		||||
@ -487,17 +516,20 @@ def to_markdown_tabs(head: list, tabs: list, alignment=':---:', column=False):
 | 
			
		||||
    max_len = max(len(column) for column in transposed_tabs)
 | 
			
		||||
 | 
			
		||||
    tab_format = "| %s "
 | 
			
		||||
    tabs_list = "".join([tab_format % i for i in head]) + '|\n'
 | 
			
		||||
    tabs_list += "".join([tab_format % alignment for i in head]) + '|\n'
 | 
			
		||||
    tabs_list = "".join([tab_format % i for i in head]) + "|\n"
 | 
			
		||||
    tabs_list += "".join([tab_format % alignment for i in head]) + "|\n"
 | 
			
		||||
 | 
			
		||||
    for i in range(max_len):
 | 
			
		||||
        row_data = [tab[i] if i < len(tab) else '' for tab in transposed_tabs]
 | 
			
		||||
        row_data = [tab[i] if i < len(tab) else "" for tab in transposed_tabs]
 | 
			
		||||
        row_data = file_manifest_filter_type(row_data, filter_=None)
 | 
			
		||||
        tabs_list += "".join([tab_format % i for i in row_data]) + '|\n'
 | 
			
		||||
        tabs_list += "".join([tab_format % i for i in row_data]) + "|\n"
 | 
			
		||||
 | 
			
		||||
    return tabs_list
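An illustrative call to to_markdown_tabs (not part of this commit) showing the shape of the generated table; with the default column=False, each entry of tabs is treated as one column.

from toolbox import to_markdown_tabs

md = to_markdown_tabs(head=["文件", "大小"], tabs=[["a.pdf", "b.pdf"], ["1 MB", "2 MB"]])
# Expected output, roughly:
# | 文件 | 大小 |
# | :---: | :---: |
# | a.pdf | 1 MB |
# | b.pdf | 2 MB |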
 | 
			
		||||
 | 
			
		||||
def on_file_uploaded(request: gradio.Request, files, chatbot, txt, txt2, checkboxes, cookies):
 | 
			
		||||
 | 
			
		||||
def on_file_uploaded(
 | 
			
		||||
    request: gradio.Request, files, chatbot, txt, txt2, checkboxes, cookies
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    当文件被上传时的回调函数
 | 
			
		||||
    """
 | 
			
		||||
@ -515,94 +547,118 @@ def on_file_uploaded(request: gradio.Request, files, chatbot, txt, txt2, checkbo
 | 
			
		||||
    del_outdated_uploads(outdate_time_seconds, get_upload_folder(user_name))
 | 
			
		||||
 | 
			
		||||
    # 逐个文件转移到目标路径
 | 
			
		||||
    upload_msg = ''
 | 
			
		||||
    upload_msg = ""
 | 
			
		||||
    for file in files:
 | 
			
		||||
        file_origin_name = os.path.basename(file.orig_name)
 | 
			
		||||
        this_file_path = pj(target_path_base, file_origin_name)
 | 
			
		||||
        shutil.move(file.name, this_file_path)
 | 
			
		||||
        upload_msg += extract_archive(file_path=this_file_path, dest_dir=this_file_path + '.extract')
 | 
			
		||||
        upload_msg += extract_archive(
 | 
			
		||||
            file_path=this_file_path, dest_dir=this_file_path + ".extract"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    # 整理文件集合 输出消息
 | 
			
		||||
    moved_files = [fp for fp in glob.glob(f'{target_path_base}/**/*', recursive=True)]
 | 
			
		||||
    moved_files_str = to_markdown_tabs(head=['文件'], tabs=[moved_files])
 | 
			
		||||
    chatbot.append(['我上传了文件,请查收',
 | 
			
		||||
                    f'[Local Message] 收到以下文件: \n\n{moved_files_str}' +
 | 
			
		||||
                    f'\n\n调用路径参数已自动修正到: \n\n{txt}' +
 | 
			
		||||
                    f'\n\n现在您点击任意函数插件时,以上文件将被作为输入参数' + upload_msg])
 | 
			
		||||
    moved_files = [fp for fp in glob.glob(f"{target_path_base}/**/*", recursive=True)]
 | 
			
		||||
    moved_files_str = to_markdown_tabs(head=["文件"], tabs=[moved_files])
 | 
			
		||||
    chatbot.append(
 | 
			
		||||
        [
 | 
			
		||||
            "我上传了文件,请查收",
 | 
			
		||||
            f"[Local Message] 收到以下文件: \n\n{moved_files_str}"
 | 
			
		||||
            + f"\n\n调用路径参数已自动修正到: \n\n{txt}"
 | 
			
		||||
            + f"\n\n现在您点击任意函数插件时,以上文件将被作为输入参数"
 | 
			
		||||
            + upload_msg,
 | 
			
		||||
        ]
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    txt, txt2 = target_path_base, ""
 | 
			
		||||
    if "浮动输入区" in checkboxes:
 | 
			
		||||
        txt, txt2 = txt2, txt
 | 
			
		||||
 | 
			
		||||
    # 记录近期文件
 | 
			
		||||
    cookies.update({
 | 
			
		||||
        'most_recent_uploaded': {
 | 
			
		||||
            'path': target_path_base,
 | 
			
		||||
            'time': time.time(),
 | 
			
		||||
            'time_str': time_tag
 | 
			
		||||
    }})
 | 
			
		||||
    cookies.update(
 | 
			
		||||
        {
 | 
			
		||||
            "most_recent_uploaded": {
 | 
			
		||||
                "path": target_path_base,
 | 
			
		||||
                "time": time.time(),
 | 
			
		||||
                "time_str": time_tag,
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    )
 | 
			
		||||
    return chatbot, txt, txt2, cookies
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def on_report_generated(cookies, files, chatbot):
 | 
			
		||||
    # from toolbox import find_recent_files
 | 
			
		||||
    # PATH_LOGGING = get_conf('PATH_LOGGING')
 | 
			
		||||
    if 'files_to_promote' in cookies:
 | 
			
		||||
        report_files = cookies['files_to_promote']
 | 
			
		||||
        cookies.pop('files_to_promote')
 | 
			
		||||
    if "files_to_promote" in cookies:
 | 
			
		||||
        report_files = cookies["files_to_promote"]
 | 
			
		||||
        cookies.pop("files_to_promote")
 | 
			
		||||
    else:
 | 
			
		||||
        report_files = []
 | 
			
		||||
    #     report_files = find_recent_files(PATH_LOGGING)
 | 
			
		||||
    if len(report_files) == 0:
 | 
			
		||||
        return cookies, None, chatbot
 | 
			
		||||
    # files.extend(report_files)
 | 
			
		||||
    file_links = ''
 | 
			
		||||
    for f in report_files: file_links += f'<br/><a href="file={os.path.abspath(f)}" target="_blank">{f}</a>'
 | 
			
		||||
    chatbot.append(['报告如何远程获取?', f'报告已经添加到右侧“文件上传区”(可能处于折叠状态),请查收。{file_links}'])
 | 
			
		||||
    file_links = ""
 | 
			
		||||
    for f in report_files:
 | 
			
		||||
        file_links += (
 | 
			
		||||
            f'<br/><a href="file={os.path.abspath(f)}" target="_blank">{f}</a>'
 | 
			
		||||
        )
 | 
			
		||||
    chatbot.append(["报告如何远程获取?", f"报告已经添加到右侧“文件上传区”(可能处于折叠状态),请查收。{file_links}"])
 | 
			
		||||
    return cookies, report_files, chatbot
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_chat_cookies():
 | 
			
		||||
    API_KEY, LLM_MODEL, AZURE_API_KEY = get_conf('API_KEY', 'LLM_MODEL', 'AZURE_API_KEY')
 | 
			
		||||
    AZURE_CFG_ARRAY, NUM_CUSTOM_BASIC_BTN = get_conf('AZURE_CFG_ARRAY', 'NUM_CUSTOM_BASIC_BTN')
 | 
			
		||||
    API_KEY, LLM_MODEL, AZURE_API_KEY = get_conf(
 | 
			
		||||
        "API_KEY", "LLM_MODEL", "AZURE_API_KEY"
 | 
			
		||||
    )
 | 
			
		||||
    AZURE_CFG_ARRAY, NUM_CUSTOM_BASIC_BTN = get_conf(
 | 
			
		||||
        "AZURE_CFG_ARRAY", "NUM_CUSTOM_BASIC_BTN"
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # deal with azure openai key
 | 
			
		||||
    if is_any_api_key(AZURE_API_KEY):
 | 
			
		||||
        if is_any_api_key(API_KEY):
 | 
			
		||||
            API_KEY = API_KEY + ',' + AZURE_API_KEY
 | 
			
		||||
            API_KEY = API_KEY + "," + AZURE_API_KEY
 | 
			
		||||
        else:
 | 
			
		||||
            API_KEY = AZURE_API_KEY
 | 
			
		||||
    if len(AZURE_CFG_ARRAY) > 0:
 | 
			
		||||
        for azure_model_name, azure_cfg_dict in AZURE_CFG_ARRAY.items():
 | 
			
		||||
            if not azure_model_name.startswith('azure'):
 | 
			
		||||
            if not azure_model_name.startswith("azure"):
 | 
			
		||||
                raise ValueError("AZURE_CFG_ARRAY中配置的模型必须以azure开头")
 | 
			
		||||
            AZURE_API_KEY_ = azure_cfg_dict["AZURE_API_KEY"]
 | 
			
		||||
            if is_any_api_key(AZURE_API_KEY_):
 | 
			
		||||
                if is_any_api_key(API_KEY):
 | 
			
		||||
                    API_KEY = API_KEY + ',' + AZURE_API_KEY_
 | 
			
		||||
                    API_KEY = API_KEY + "," + AZURE_API_KEY_
 | 
			
		||||
                else:
 | 
			
		||||
                    API_KEY = AZURE_API_KEY_
 | 
			
		||||
 | 
			
		||||
    customize_fn_overwrite_ = {}
 | 
			
		||||
    for k in range(NUM_CUSTOM_BASIC_BTN):
 | 
			
		||||
        customize_fn_overwrite_.update({
 | 
			
		||||
            "自定义按钮" + str(k+1):{
 | 
			
		||||
                "Title":  r"",
 | 
			
		||||
                "Prefix": r"请在自定义菜单中定义提示词前缀.",
 | 
			
		||||
                "Suffix": r"请在自定义菜单中定义提示词后缀",
 | 
			
		||||
        customize_fn_overwrite_.update(
 | 
			
		||||
            {
 | 
			
		||||
                "自定义按钮"
 | 
			
		||||
                + str(k + 1): {
 | 
			
		||||
                    "Title": r"",
 | 
			
		||||
                    "Prefix": r"请在自定义菜单中定义提示词前缀.",
 | 
			
		||||
                    "Suffix": r"请在自定义菜单中定义提示词后缀",
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        })
 | 
			
		||||
    return {'api_key': API_KEY, 'llm_model': LLM_MODEL, 'customize_fn_overwrite': customize_fn_overwrite_}
 | 
			
		||||
        )
 | 
			
		||||
    return {
 | 
			
		||||
        "api_key": API_KEY,
 | 
			
		||||
        "llm_model": LLM_MODEL,
 | 
			
		||||
        "customize_fn_overwrite": customize_fn_overwrite_,
 | 
			
		||||
    }
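For reference, a sketch of the structure returned by load_chat_cookies (not part of this commit); the concrete values depend on config.py.

from toolbox import load_chat_cookies

cookies = load_chat_cookies()
# e.g. {"api_key": "sk-...", "llm_model": "gpt-3.5-turbo",
#       "customize_fn_overwrite": {"自定义按钮1": {"Title": "", "Prefix": "...", "Suffix": "..."}, ...}}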
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def clear_line_break(txt):
 | 
			
		||||
    txt = txt.replace('\n', ' ')
 | 
			
		||||
    txt = txt.replace('  ', ' ')
 | 
			
		||||
    txt = txt.replace('  ', ' ')
 | 
			
		||||
    txt = txt.replace("\n", " ")
 | 
			
		||||
    txt = txt.replace("  ", " ")
 | 
			
		||||
    txt = txt.replace("  ", " ")
 | 
			
		||||
    return txt
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DummyWith():
 | 
			
		||||
class DummyWith:
 | 
			
		||||
    """
 | 
			
		||||
    这段代码定义了一个名为DummyWith的空上下文管理器,
 | 
			
		||||
    它的作用是……额……就是不起作用,即在代码结构不变得情况下取代其他的上下文管理器。
 | 
			
		||||
@ -626,34 +682,47 @@ def run_gradio_in_subpath(demo, auth, port, custom_path):
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def is_path_legal(path: str) -> bool:
 | 
			
		||||
        '''
 | 
			
		||||
        """
 | 
			
		||||
        check path for sub url
 | 
			
		||||
        path: path to check
 | 
			
		||||
        return value: do sub url wrap
 | 
			
		||||
        '''
 | 
			
		||||
        if path == "/": return True
 | 
			
		||||
        """
 | 
			
		||||
        if path == "/":
 | 
			
		||||
            return True
 | 
			
		||||
        if len(path) == 0:
 | 
			
		||||
            print("ilegal custom path: {}\npath must not be empty\ndeploy on root url".format(path))
 | 
			
		||||
            print(
 | 
			
		||||
                "ilegal custom path: {}\npath must not be empty\ndeploy on root url".format(
 | 
			
		||||
                    path
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
            return False
 | 
			
		||||
        if path[0] == '/':
 | 
			
		||||
            if path[1] != '/':
 | 
			
		||||
        if path[0] == "/":
 | 
			
		||||
            if path[1] != "/":
 | 
			
		||||
                print("deploy on sub-path {}".format(path))
 | 
			
		||||
                return True
 | 
			
		||||
            return False
 | 
			
		||||
        print("ilegal custom path: {}\npath should begin with \'/\'\ndeploy on root url".format(path))
 | 
			
		||||
        print(
 | 
			
		||||
            "ilegal custom path: {}\npath should begin with '/'\ndeploy on root url".format(
 | 
			
		||||
                path
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
    if not is_path_legal(custom_path): raise RuntimeError('Ilegal custom path')
 | 
			
		||||
    if not is_path_legal(custom_path):
 | 
			
		||||
        raise RuntimeError("Ilegal custom path")
 | 
			
		||||
    import uvicorn
 | 
			
		||||
    import gradio as gr
 | 
			
		||||
    from fastapi import FastAPI
 | 
			
		||||
 | 
			
		||||
    app = FastAPI()
 | 
			
		||||
    if custom_path != "/":
 | 
			
		||||
 | 
			
		||||
        @app.get("/")
 | 
			
		||||
        def read_main():
 | 
			
		||||
            return {"message": f"Gradio is running at: {custom_path}"}
 | 
			
		||||
 | 
			
		||||
    app = gr.mount_gradio_app(app, demo, path=custom_path)
 | 
			
		||||
    uvicorn.run(app, host="0.0.0.0", port=port) # , auth=auth
 | 
			
		||||
    uvicorn.run(app, host="0.0.0.0", port=port)  # , auth=auth
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def clip_history(inputs, history, tokenizer, max_token_limit):
 | 
			
		||||
@ -667,13 +736,18 @@ def clip_history(inputs, history, tokenizer, max_token_limit):
 | 
			
		||||
    """
 | 
			
		||||
    import numpy as np
 | 
			
		||||
    from request_llms.bridge_all import model_info
 | 
			
		||||
 | 
			
		||||
    def get_token_num(txt):
 | 
			
		||||
        return len(tokenizer.encode(txt, disallowed_special=()))
 | 
			
		||||
 | 
			
		||||
    input_token_num = get_token_num(inputs)
 | 
			
		||||
 | 
			
		||||
    if max_token_limit < 5000:   output_token_expect = 256  # 4k & 2k models
 | 
			
		||||
    elif max_token_limit < 9000: output_token_expect = 512  # 8k models
 | 
			
		||||
    else: output_token_expect = 1024                        # 16k & 32k models
 | 
			
		||||
    if max_token_limit < 5000:
 | 
			
		||||
        output_token_expect = 256  # 4k & 2k models
 | 
			
		||||
    elif max_token_limit < 9000:
 | 
			
		||||
        output_token_expect = 512  # 8k models
 | 
			
		||||
    else:
 | 
			
		||||
        output_token_expect = 1024  # 16k & 32k models
 | 
			
		||||
 | 
			
		||||
    if input_token_num < max_token_limit * 3 / 4:
 | 
			
		||||
        # 当输入部分的token占比小于限制的3/4时,裁剪时
 | 
			
		||||
@ -690,9 +764,9 @@ def clip_history(inputs, history, tokenizer, max_token_limit):
 | 
			
		||||
        history = []
 | 
			
		||||
        return history
 | 
			
		||||
 | 
			
		||||
    everything = ['']
 | 
			
		||||
    everything = [""]
 | 
			
		||||
    everything.extend(history)
 | 
			
		||||
    n_token = get_token_num('\n'.join(everything))
 | 
			
		||||
    n_token = get_token_num("\n".join(everything))
 | 
			
		||||
    everything_token = [get_token_num(e) for e in everything]
 | 
			
		||||
 | 
			
		||||
    # 截断时的颗粒度
 | 
			
		||||
@ -701,30 +775,33 @@ def clip_history(inputs, history, tokenizer, max_token_limit):
 | 
			
		||||
    while n_token > max_token_limit:
 | 
			
		||||
        where = np.argmax(everything_token)
 | 
			
		||||
        encoded = tokenizer.encode(everything[where], disallowed_special=())
 | 
			
		||||
        clipped_encoded = encoded[:len(encoded) - delta]
 | 
			
		||||
        everything[where] = tokenizer.decode(clipped_encoded)[:-1] # -1 to remove the may-be illegal char
 | 
			
		||||
        clipped_encoded = encoded[: len(encoded) - delta]
 | 
			
		||||
        everything[where] = tokenizer.decode(clipped_encoded)[
 | 
			
		||||
            :-1
 | 
			
		||||
        ]  # -1 to remove the may-be illegal char
 | 
			
		||||
        everything_token[where] = get_token_num(everything[where])
 | 
			
		||||
        n_token = get_token_num('\n'.join(everything))
 | 
			
		||||
        n_token = get_token_num("\n".join(everything))
 | 
			
		||||
 | 
			
		||||
    history = everything[1:]
 | 
			
		||||
    return history
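A hedged sketch of a typical clip_history call (not part of this commit); the tiktoken tokenizer and the 4096-token budget are illustrative assumptions.

import tiktoken
from toolbox import clip_history

tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
inputs = "Please summarise the discussion so far."
history = ["a long question ...", "an even longer answer ..."] * 50
history = clip_history(inputs, history, tokenizer=tokenizer, max_token_limit=4096)
# the longest entries are trimmed first until inputs + history fit within the budget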
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
========================================================================
 | 
			
		||||
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
 | 
			
		||||
第三部分
 | 
			
		||||
其他小工具:
 | 
			
		||||
    - zip_folder:    把某个路径下所有文件压缩,然后转移到指定的另一个路径中(gpt写的)
 | 
			
		||||
    - gen_time_str:  生成时间戳
 | 
			
		||||
    - ProxyNetworkActivate: 临时地启动代理网络(如果有)
 | 
			
		||||
    - objdump/objload: 快捷的调试函数
 | 
			
		||||
========================================================================
 | 
			
		||||
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def zip_folder(source_folder, dest_folder, zip_name):
 | 
			
		||||
    import zipfile
 | 
			
		||||
    import os
 | 
			
		||||
 | 
			
		||||
    # Make sure the source folder exists
 | 
			
		||||
    if not os.path.exists(source_folder):
 | 
			
		||||
        print(f"{source_folder} does not exist")
 | 
			
		||||
@ -739,7 +816,7 @@ def zip_folder(source_folder, dest_folder, zip_name):
 | 
			
		||||
    zip_file = pj(dest_folder, zip_name)
 | 
			
		||||
 | 
			
		||||
    # Create a ZipFile object
 | 
			
		||||
    with zipfile.ZipFile(zip_file, 'w', zipfile.ZIP_DEFLATED) as zipf:
 | 
			
		||||
    with zipfile.ZipFile(zip_file, "w", zipfile.ZIP_DEFLATED) as zipf:
 | 
			
		||||
        # Walk through the source folder and add files to the zip file
 | 
			
		||||
        for foldername, subfolders, filenames in os.walk(source_folder):
 | 
			
		||||
            for filename in filenames:
 | 
			
		||||
@ -756,29 +833,33 @@ def zip_folder(source_folder, dest_folder, zip_name):
 | 
			
		||||
 | 
			
		||||
def zip_result(folder):
 | 
			
		||||
    t = gen_time_str()
 | 
			
		||||
    zip_folder(folder, get_log_folder(), f'{t}-result.zip')
 | 
			
		||||
    return pj(get_log_folder(), f'{t}-result.zip')
 | 
			
		||||
    zip_folder(folder, get_log_folder(), f"{t}-result.zip")
 | 
			
		||||
    return pj(get_log_folder(), f"{t}-result.zip")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def gen_time_str():
 | 
			
		||||
    import time
 | 
			
		||||
 | 
			
		||||
    return time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_log_folder(user=default_user_name, plugin_name='shared'):
 | 
			
		||||
    if user is None: user = default_user_name
 | 
			
		||||
    PATH_LOGGING = get_conf('PATH_LOGGING')
 | 
			
		||||
def get_log_folder(user=default_user_name, plugin_name="shared"):
 | 
			
		||||
    if user is None:
 | 
			
		||||
        user = default_user_name
 | 
			
		||||
    PATH_LOGGING = get_conf("PATH_LOGGING")
 | 
			
		||||
    if plugin_name is None:
 | 
			
		||||
        _dir = pj(PATH_LOGGING, user)
 | 
			
		||||
    else:
 | 
			
		||||
        _dir = pj(PATH_LOGGING, user, plugin_name)
 | 
			
		||||
    if not os.path.exists(_dir): os.makedirs(_dir)
 | 
			
		||||
    if not os.path.exists(_dir):
 | 
			
		||||
        os.makedirs(_dir)
 | 
			
		||||
    return _dir
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_upload_folder(user=default_user_name, tag=None):
 | 
			
		||||
    PATH_PRIVATE_UPLOAD = get_conf('PATH_PRIVATE_UPLOAD')
 | 
			
		||||
    if user is None: user = default_user_name
 | 
			
		||||
    PATH_PRIVATE_UPLOAD = get_conf("PATH_PRIVATE_UPLOAD")
 | 
			
		||||
    if user is None:
 | 
			
		||||
        user = default_user_name
 | 
			
		||||
    if tag is None or len(tag) == 0:
 | 
			
		||||
        target_path_base = pj(PATH_PRIVATE_UPLOAD, user)
 | 
			
		||||
    else:
 | 
			
		||||
@ -787,9 +868,9 @@ def get_upload_folder(user=default_user_name, tag=None):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def is_the_upload_folder(string):
 | 
			
		||||
    PATH_PRIVATE_UPLOAD = get_conf('PATH_PRIVATE_UPLOAD')
 | 
			
		||||
    pattern = r'^PATH_PRIVATE_UPLOAD[\\/][A-Za-z0-9_-]+[\\/]\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}$'
 | 
			
		||||
    pattern = pattern.replace('PATH_PRIVATE_UPLOAD', PATH_PRIVATE_UPLOAD)
 | 
			
		||||
    PATH_PRIVATE_UPLOAD = get_conf("PATH_PRIVATE_UPLOAD")
 | 
			
		||||
    pattern = r"^PATH_PRIVATE_UPLOAD[\\/][A-Za-z0-9_-]+[\\/]\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}$"
 | 
			
		||||
    pattern = pattern.replace("PATH_PRIVATE_UPLOAD", PATH_PRIVATE_UPLOAD)
 | 
			
		||||
    if re.match(pattern, string):
 | 
			
		||||
        return True
 | 
			
		||||
    else:
 | 
			
		||||
@ -797,10 +878,10 @@ def is_the_upload_folder(string):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_user(chatbotwithcookies):
 | 
			
		||||
    return chatbotwithcookies._cookies.get('user_name', default_user_name)
 | 
			
		||||
    return chatbotwithcookies._cookies.get("user_name", default_user_name)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ProxyNetworkActivate():
 | 
			
		||||
class ProxyNetworkActivate:
 | 
			
		||||
    """
 | 
			
		||||
    这段代码定义了一个名为ProxyNetworkActivate的空上下文管理器, 用于给一小段代码上代理
 | 
			
		||||
    """
 | 
			
		||||
@ -813,38 +894,48 @@ class ProxyNetworkActivate():
 | 
			
		||||
        else:
 | 
			
		||||
            # 给定了task, 我们检查一下
 | 
			
		||||
            from toolbox import get_conf
 | 
			
		||||
            WHEN_TO_USE_PROXY = get_conf('WHEN_TO_USE_PROXY')
 | 
			
		||||
            self.valid = (task in WHEN_TO_USE_PROXY)
 | 
			
		||||
 | 
			
		||||
            WHEN_TO_USE_PROXY = get_conf("WHEN_TO_USE_PROXY")
 | 
			
		||||
            self.valid = task in WHEN_TO_USE_PROXY
 | 
			
		||||
 | 
			
		||||
    def __enter__(self):
 | 
			
		||||
        if not self.valid: return self
 | 
			
		||||
        if not self.valid:
 | 
			
		||||
            return self
 | 
			
		||||
        from toolbox import get_conf
 | 
			
		||||
        proxies = get_conf('proxies')
 | 
			
		||||
        if 'no_proxy' in os.environ: os.environ.pop('no_proxy')
 | 
			
		||||
 | 
			
		||||
        proxies = get_conf("proxies")
 | 
			
		||||
        if "no_proxy" in os.environ:
 | 
			
		||||
            os.environ.pop("no_proxy")
 | 
			
		||||
        if proxies is not None:
 | 
			
		||||
            if 'http' in proxies: os.environ['HTTP_PROXY'] = proxies['http']
 | 
			
		||||
            if 'https' in proxies: os.environ['HTTPS_PROXY'] = proxies['https']
 | 
			
		||||
            if "http" in proxies:
 | 
			
		||||
                os.environ["HTTP_PROXY"] = proxies["http"]
 | 
			
		||||
            if "https" in proxies:
 | 
			
		||||
                os.environ["HTTPS_PROXY"] = proxies["https"]
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    def __exit__(self, exc_type, exc_value, traceback):
 | 
			
		||||
        os.environ['no_proxy'] = '*'
 | 
			
		||||
        if 'HTTP_PROXY' in os.environ: os.environ.pop('HTTP_PROXY')
 | 
			
		||||
        if 'HTTPS_PROXY' in os.environ: os.environ.pop('HTTPS_PROXY')
 | 
			
		||||
        os.environ["no_proxy"] = "*"
 | 
			
		||||
        if "HTTP_PROXY" in os.environ:
 | 
			
		||||
            os.environ.pop("HTTP_PROXY")
 | 
			
		||||
        if "HTTPS_PROXY" in os.environ:
 | 
			
		||||
            os.environ.pop("HTTPS_PROXY")
 | 
			
		||||
        return
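A usage sketch of ProxyNetworkActivate (not part of this commit); "Download_LLM" is an illustrative task name and only activates the proxy if it appears in the WHEN_TO_USE_PROXY setting.

import requests
from toolbox import ProxyNetworkActivate

with ProxyNetworkActivate("Download_LLM"):  # HTTP(S)_PROXY are exported on enter and cleared on exit
    requests.get("https://huggingface.co", timeout=10)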
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def objdump(obj, file='objdump.tmp'):
 | 
			
		||||
def objdump(obj, file="objdump.tmp"):
 | 
			
		||||
    import pickle
 | 
			
		||||
    with open(file, 'wb+') as f:
 | 
			
		||||
 | 
			
		||||
    with open(file, "wb+") as f:
 | 
			
		||||
        pickle.dump(obj, f)
 | 
			
		||||
    return
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def objload(file='objdump.tmp'):
 | 
			
		||||
def objload(file="objdump.tmp"):
 | 
			
		||||
    import pickle, os
 | 
			
		||||
 | 
			
		||||
    if not os.path.exists(file):
 | 
			
		||||
        return
 | 
			
		||||
    with open(file, 'rb') as f:
 | 
			
		||||
    with open(file, "rb") as f:
 | 
			
		||||
        return pickle.load(f)
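The intended round trip of the objdump/objload debug helpers, sketched for reference (not part of this commit).

from toolbox import objdump, objload

objdump({"step": 3, "payload": [1, 2, 3]})  # pickles the object into ./objdump.tmp
state = objload()                           # -> {"step": 3, "payload": [1, 2, 3]}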
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -863,22 +954,25 @@ def Singleton(cls):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_pictures_list(path):
 | 
			
		||||
    file_manifest = [f for f in glob.glob(f'{path}/**/*.jpg', recursive=True)]
 | 
			
		||||
    file_manifest += [f for f in glob.glob(f'{path}/**/*.jpeg', recursive=True)]
 | 
			
		||||
    file_manifest += [f for f in glob.glob(f'{path}/**/*.png', recursive=True)]
 | 
			
		||||
    file_manifest = [f for f in glob.glob(f"{path}/**/*.jpg", recursive=True)]
 | 
			
		||||
    file_manifest += [f for f in glob.glob(f"{path}/**/*.jpeg", recursive=True)]
 | 
			
		||||
    file_manifest += [f for f in glob.glob(f"{path}/**/*.png", recursive=True)]
 | 
			
		||||
    return file_manifest
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def have_any_recent_upload_image_files(chatbot):
 | 
			
		||||
    _5min = 5 * 60
 | 
			
		||||
    if chatbot is None: return False, None  # chatbot is None
 | 
			
		||||
    if chatbot is None:
 | 
			
		||||
        return False, None  # chatbot is None
 | 
			
		||||
    most_recent_uploaded = chatbot._cookies.get("most_recent_uploaded", None)
 | 
			
		||||
    if not most_recent_uploaded: return False, None  # most_recent_uploaded is None
 | 
			
		||||
    if not most_recent_uploaded:
 | 
			
		||||
        return False, None  # most_recent_uploaded is None
 | 
			
		||||
    if time.time() - most_recent_uploaded["time"] < _5min:
 | 
			
		||||
        most_recent_uploaded = chatbot._cookies.get("most_recent_uploaded", None)
 | 
			
		||||
        path = most_recent_uploaded['path']
 | 
			
		||||
        path = most_recent_uploaded["path"]
 | 
			
		||||
        file_manifest = get_pictures_list(path)
 | 
			
		||||
        if len(file_manifest) == 0: return False, None
 | 
			
		||||
        if len(file_manifest) == 0:
 | 
			
		||||
            return False, None
 | 
			
		||||
        return True, file_manifest  # most_recent_uploaded is new
 | 
			
		||||
    else:
 | 
			
		||||
        return False, None  # most_recent_uploaded is too old
 | 
			
		||||
@ -887,16 +981,19 @@ def have_any_recent_upload_image_files(chatbot):
 | 
			
		||||
# Function to encode the image
 | 
			
		||||
def encode_image(image_path):
 | 
			
		||||
    with open(image_path, "rb") as image_file:
 | 
			
		||||
        return base64.b64encode(image_file.read()).decode('utf-8')
 | 
			
		||||
        return base64.b64encode(image_file.read()).decode("utf-8")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_max_token(llm_kwargs):
 | 
			
		||||
    from request_llms.bridge_all import model_info
 | 
			
		||||
    return model_info[llm_kwargs['llm_model']]['max_token']
 | 
			
		||||
 | 
			
		||||
    return model_info[llm_kwargs["llm_model"]]["max_token"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def check_packages(packages=[]):
 | 
			
		||||
    import importlib.util
 | 
			
		||||
 | 
			
		||||
    for p in packages:
 | 
			
		||||
        spam_spec = importlib.util.find_spec(p)
 | 
			
		||||
        if spam_spec is None: raise ModuleNotFoundError
 | 
			
		||||
        if spam_spec is None:
 | 
			
		||||
            raise ModuleNotFoundError
 | 
			
		||||
 | 
			
		||||