From 14de282302ebfbc024f34e6bedfcc2a9a2047efc Mon Sep 17 00:00:00 2001 From: binary-husky Date: Wed, 13 Sep 2023 23:21:00 +0800 Subject: [PATCH] =?UTF-8?q?=E7=BB=99nougat=E5=8A=A0=E7=BA=BF=E7=A8=8B?= =?UTF-8?q?=E9=94=81=20=E5=90=88=E5=B9=B6=E5=86=97=E4=BD=99=E4=BB=A3?= =?UTF-8?q?=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crazy_functions/crazy_utils.py | 68 ++++++++++++++++----- crazy_functions/批量翻译PDF文档_NOUGAT.py | 74 +---------------------- crazy_functions/批量翻译PDF文档_多线程.py | 48 +-------------- 3 files changed, 57 insertions(+), 133 deletions(-) diff --git a/crazy_functions/crazy_utils.py b/crazy_functions/crazy_utils.py index 9c77b8a..6cf471e 100644 --- a/crazy_functions/crazy_utils.py +++ b/crazy_functions/crazy_utils.py @@ -1,5 +1,6 @@ -from toolbox import update_ui, get_conf, trimmed_format_exc +from toolbox import update_ui, get_conf, trimmed_format_exc, get_log_folder import threading +import os def input_clipping(inputs, history, max_token_limit): import numpy as np @@ -705,6 +706,40 @@ class knowledge_archive_interface(): ) self.threadLock.release() return resp, prompt + +@Singleton +class nougat_interface(): + def __init__(self): + self.threadLock = threading.Lock() + + def nougat_with_timeout(self, command, cwd, timeout=3600): + import subprocess + process = subprocess.Popen(command, shell=True, cwd=cwd) + try: + stdout, stderr = process.communicate(timeout=timeout) + except subprocess.TimeoutExpired: + process.kill() + stdout, stderr = process.communicate() + print("Process timed out!") + return False + return True + + + def NOUGAT_parse_pdf(self, fp): + self.threadLock.acquire() + import glob, threading, os + from toolbox import get_log_folder, gen_time_str + dst = os.path.join(get_log_folder(plugin_name='nougat'), gen_time_str()) + os.makedirs(dst) + self.nougat_with_timeout(f'nougat --out "{os.path.abspath(dst)}" "{os.path.abspath(fp)}"', os.getcwd()) + res = glob.glob(os.path.join(dst,'*.mmd')) + if len(res) == 0: + raise RuntimeError("Nougat解析论文失败。") + self.threadLock.release() + return res[0] + + + def try_install_deps(deps, reload_m=[]): import subprocess, sys, importlib @@ -715,42 +750,43 @@ def try_install_deps(deps, reload_m=[]): for m in reload_m: importlib.reload(__import__(m)) -class construct_html(): - def __init__(self) -> None: - self.css = """ + +HTML_CSS = """ .row { display: flex; flex-wrap: wrap; } - .column { flex: 1; padding: 10px; } - .table-header { font-weight: bold; border-bottom: 1px solid black; } - .table-row { border-bottom: 1px solid lightgray; } - .table-cell { padding: 5px; } - """ - self.html_string = f'翻译结果' +""" - - def add_row(self, a, b): - tmp = """ +TABLE_CSS = """
REPLACE_A
REPLACE_B
- """ +""" + +class construct_html(): + def __init__(self) -> None: + self.css = HTML_CSS + self.html_string = f'翻译结果' + + + def add_row(self, a, b): + tmp = TABLE_CSS from toolbox import markdown_convertion tmp = tmp.replace('REPLACE_A', markdown_convertion(a)) tmp = tmp.replace('REPLACE_B', markdown_convertion(b)) @@ -758,6 +794,6 @@ class construct_html(): def save_file(self, file_name): - with open(f'./gpt_log/{file_name}', 'w', encoding='utf8') as f: + with open(os.path.join(get_log_folder(), file_name), 'w', encoding='utf8') as f: f.write(self.html_string.encode('utf-8', 'ignore').decode()) - + return os.path.join(get_log_folder(), file_name) diff --git a/crazy_functions/批量翻译PDF文档_NOUGAT.py b/crazy_functions/批量翻译PDF文档_NOUGAT.py index ed15121..0e2fb81 100644 --- a/crazy_functions/批量翻译PDF文档_NOUGAT.py +++ b/crazy_functions/批量翻译PDF文档_NOUGAT.py @@ -86,31 +86,8 @@ def 批量翻译PDF文档(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst # 开始正式执行任务 yield from 解析PDF_基于NOUGAT(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt) - -def nougat_with_timeout(command, cwd, timeout=3600): - import subprocess - process = subprocess.Popen(command, shell=True, cwd=cwd) - try: - stdout, stderr = process.communicate(timeout=timeout) - except subprocess.TimeoutExpired: - process.kill() - stdout, stderr = process.communicate() - print("Process timed out!") - return False - return True -def NOUGAT_parse_pdf(fp): - import glob - from toolbox import get_log_folder, gen_time_str - dst = os.path.join(get_log_folder(plugin_name='nougat'), gen_time_str()) - os.makedirs(dst) - nougat_with_timeout(f'nougat --out "{os.path.abspath(dst)}" "{os.path.abspath(fp)}"', os.getcwd()) - res = glob.glob(os.path.join(dst,'*.mmd')) - if len(res) == 0: - raise RuntimeError("Nougat解析论文失败。") - return res[0] - def 解析PDF_基于NOUGAT(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt): import copy @@ -119,9 +96,11 @@ def 解析PDF_基于NOUGAT(file_manifest, project_folder, llm_kwargs, plugin_kwa generated_conclusion_files = [] generated_html_files = [] DST_LANG = "中文" + from crazy_functions.crazy_utils import nougat_interface, construct_html + nougat_handle = nougat_interface() for index, fp in enumerate(file_manifest): chatbot.append(["当前进度:", f"正在解析论文,请稍候。(第一次运行时,需要花费较长时间下载NOUGAT参数)"]); yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 - fpp = NOUGAT_parse_pdf(fp) + fpp = nougat_handle.NOUGAT_parse_pdf(fp) with open(fpp, 'r', encoding='utf8') as f: article_content = f.readlines() @@ -222,50 +201,3 @@ def 解析PDF_基于NOUGAT(file_manifest, project_folder, llm_kwargs, plugin_kwa yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 - -class construct_html(): - def __init__(self) -> None: - self.css = """ -.row { - display: flex; - flex-wrap: wrap; -} - -.column { - flex: 1; - padding: 10px; -} - -.table-header { - font-weight: bold; - border-bottom: 1px solid black; -} - -.table-row { - border-bottom: 1px solid lightgray; -} - -.table-cell { - padding: 5px; -} - """ - self.html_string = f'翻译结果' - - - def add_row(self, a, b): - tmp = """ -
-
REPLACE_A
-
REPLACE_B
-
- """ - from toolbox import markdown_convertion - tmp = tmp.replace('REPLACE_A', markdown_convertion(a)) - tmp = tmp.replace('REPLACE_B', markdown_convertion(b)) - self.html_string += tmp - - - def save_file(self, file_name): - with open(os.path.join(get_log_folder(), file_name), 'w', encoding='utf8') as f: - f.write(self.html_string.encode('utf-8', 'ignore').decode()) - return os.path.join(get_log_folder(), file_name) diff --git a/crazy_functions/批量翻译PDF文档_多线程.py b/crazy_functions/批量翻译PDF文档_多线程.py index 6e9fe6a..440004e 100644 --- a/crazy_functions/批量翻译PDF文档_多线程.py +++ b/crazy_functions/批量翻译PDF文档_多线程.py @@ -63,6 +63,7 @@ def 解析PDF_基于GROBID(file_manifest, project_folder, llm_kwargs, plugin_kwa generated_conclusion_files = [] generated_html_files = [] DST_LANG = "中文" + from crazy_functions.crazy_utils import construct_html for index, fp in enumerate(file_manifest): chatbot.append(["当前进度:", f"正在连接GROBID服务,请稍候: {grobid_url}\n如果等待时间过长,请修改config中的GROBID_URL,可修改成本地GROBID服务。"]); yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 article_dict = parse_pdf(fp, grobid_url) @@ -166,6 +167,7 @@ def 解析PDF(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, TOKEN_LIMIT_PER_FRAGMENT = 1280 generated_conclusion_files = [] generated_html_files = [] + from crazy_functions.crazy_utils import construct_html for index, fp in enumerate(file_manifest): # 读取PDF文件 file_content, page_one = read_and_clean_pdf_text(fp) @@ -261,49 +263,3 @@ def 解析PDF(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 -class construct_html(): - def __init__(self) -> None: - self.css = """ -.row { - display: flex; - flex-wrap: wrap; -} - -.column { - flex: 1; - padding: 10px; -} - -.table-header { - font-weight: bold; - border-bottom: 1px solid black; -} - -.table-row { - border-bottom: 1px solid lightgray; -} - -.table-cell { - padding: 5px; -} - """ - self.html_string = f'翻译结果' - - - def add_row(self, a, b): - tmp = """ -
-
REPLACE_A
-
REPLACE_B
-
- """ - from toolbox import markdown_convertion - tmp = tmp.replace('REPLACE_A', markdown_convertion(a)) - tmp = tmp.replace('REPLACE_B', markdown_convertion(b)) - self.html_string += tmp - - - def save_file(self, file_name): - with open(os.path.join(get_log_folder(), file_name), 'w', encoding='utf8') as f: - f.write(self.html_string.encode('utf-8', 'ignore').decode()) - return os.path.join(get_log_folder(), file_name)