Fix the PyTorch loading problem when importing plugins

commit b9b7bf38ab (parent 7e56ace2c0)
binary-husky committed 2023-11-13 00:15:15 +08:00
6 changed files with 45 additions and 24 deletions
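In plain terms: before this commit, the request_llms bridge modules imported transformers (and, transitively, PyTorch) at module top level, so merely importing them while building the plugin list could trigger a multi-second torch load, or fail outright on machines without torch installed. The fix defers those imports into the method that actually loads the model. A minimal sketch of the pattern (the class and model names below are illustrative, not the repository's):

# Sketch: deferred heavy imports. Importing this module is cheap;
# transformers/torch are only pulled in when a model is really loaded.

class LazyLLMHandle:
    def load_model_and_tokenizer(self):
        # The import lives inside the method, so only the process that
        # actually loads the model pays the torch import cost.
        from transformers import AutoModel, AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
        model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).eval()
        return model, tokenizer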

View File

@@ -1,4 +1,5 @@
 from toolbox import HotReload  # HotReload means hot reload: after modifying a function plugin, the change takes effect without restarting the program
+from toolbox import trimmed_format_exc

 def get_crazy_functions():
@@ -292,6 +293,7 @@ def get_crazy_functions():
             }
         })
     except:
+        print(trimmed_format_exc())
         print('Load function plugin failed')
     try:
@@ -316,6 +318,7 @@ def get_crazy_functions():
             }
         })
     except:
+        print(trimmed_format_exc())
         print('Load function plugin failed')
     try:
@@ -331,6 +334,7 @@ def get_crazy_functions():
             },
         })
     except:
+        print(trimmed_format_exc())
         print('Load function plugin failed')
     try:
@@ -346,19 +350,20 @@ def get_crazy_functions():
             },
         })
     except:
+        print(trimmed_format_exc())
         print('Load function plugin failed')
     try:
-        from crazy_functions.图片生成 import 图片生成, 图片生成_DALLE3
+        from crazy_functions.图片生成 import 图片生成_DALLE2, 图片生成_DALLE3
         function_plugins.update({
-            "图片生成(先切换模型到openai或api2d)": {
+            "图片生成_DALLE2 (先切换模型到openai或api2d)": {
                 "Group": "对话",
                 "Color": "stop",
                 "AsButton": False,
                 "AdvancedArgs": True,  # bring up the advanced-argument input area when invoked (default False)
                 "ArgsReminder": "在这里输入分辨率, 如1024x1024(默认), 支持 256x256, 512x512, 1024x1024",  # hint text shown in the advanced-argument input area
                 "Info": "使用DALLE2生成图片 | 输入参数字符串,提供图像的内容",
-                "Function": HotReload(图片生成)
+                "Function": HotReload(图片生成_DALLE2)
             },
         })
         function_plugins.update({
@@ -373,6 +378,7 @@ def get_crazy_functions():
             },
         })
     except:
+        print(trimmed_format_exc())
         print('Load function plugin failed')
     try:
@@ -389,6 +395,7 @@ def get_crazy_functions():
             }
         })
     except:
+        print(trimmed_format_exc())
         print('Load function plugin failed')
     try:
@@ -403,6 +410,7 @@ def get_crazy_functions():
             }
         })
     except:
+        print(trimmed_format_exc())
         print('Load function plugin failed')
     try:
@@ -418,6 +426,7 @@ def get_crazy_functions():
             }
         })
     except:
+        print(trimmed_format_exc())
         print('Load function plugin failed')
     try:
@@ -433,6 +442,7 @@ def get_crazy_functions():
             }
         })
     except:
+        print(trimmed_format_exc())
         print('Load function plugin failed')
     try:
@@ -448,6 +458,7 @@ def get_crazy_functions():
             }
         })
     except:
+        print(trimmed_format_exc())
         print('Load function plugin failed')
     try:
@@ -461,6 +472,7 @@ def get_crazy_functions():
             }
         })
     except:
+        print(trimmed_format_exc())
         print('Load function plugin failed')
     try:
@@ -505,6 +517,7 @@ def get_crazy_functions():
             }
         })
     except:
+        print(trimmed_format_exc())
         print('Load function plugin failed')
     try:
@@ -522,6 +535,7 @@ def get_crazy_functions():
             }
         })
     except:
+        print(trimmed_format_exc())
         print('Load function plugin failed')
     try:
@@ -535,6 +549,7 @@ def get_crazy_functions():
             }
         })
     except:
+        print(trimmed_format_exc())
         print('Load function plugin failed')
     try:
@@ -548,8 +563,10 @@ def get_crazy_functions():
             }
         })
     except:
+        print(trimmed_format_exc())
         print('Load function plugin failed')
     try:
         from crazy_functions.多智能体 import 多智能体终端
         function_plugins.update({
             "AutoGen多智能体终端(仅供测试)": {
@@ -559,6 +576,9 @@ def get_crazy_functions():
                 "Function": HotReload(多智能体终端)
             }
         })
+    except:
+        print(trimmed_format_exc())
+        print('Load function plugin failed')
     # try:
     #     from crazy_functions.chatglm微调工具 import 微调数据集生成
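Every block in crazy_functional.py above repeats one registration pattern: import the plugin inside a try, add its descriptor to function_plugins, and on any failure print the trimmed traceback (the line this commit adds throughout) plus a short notice, so a single broken plugin cannot take down the whole menu. Condensed, with an illustrative plugin name and the descriptor fields used above:

try:
    from crazy_functions.some_plugin import some_plugin_fn   # illustrative name
    function_plugins.update({
        "Some plugin": {
            "Group": "对话",        # which menu group the entry belongs to
            "Color": "stop",        # button color style
            "AsButton": False,      # list in the dropdown rather than as a button
            "Function": HotReload(some_plugin_fn),  # hot-reload wrapper
        }
    })
except:
    print(trimmed_format_exc())     # new in this commit: show the trimmed traceback
    print('Load function plugin failed')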

View File

@@ -1,4 +1,4 @@
-from toolbox import CatchException, report_exception, get_log_folder, gen_time_str
+from toolbox import CatchException, report_exception, get_log_folder, gen_time_str, check_packages
 from toolbox import update_ui, promote_file_to_downloadzone, update_ui_lastest_msg, disable_auto_promotion
 from toolbox import write_history_to_file, promote_file_to_downloadzone
 from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
@@ -6,9 +6,8 @@ from .crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
 from .crazy_utils import read_and_clean_pdf_text
 from .pdf_fns.parse_pdf import parse_pdf, get_avail_grobid_url, translate_pdf
 from colorful import *
 import copy
 import os
 import math

 @CatchException
 def 批量翻译PDF文档(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
@@ -22,9 +21,7 @@ def 批量翻译PDF文档(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
     # Try to import dependencies; if a dependency is missing, suggest how to install it
     try:
-        import fitz
-        import tiktoken
-        import scipdf
+        check_packages(["fitz", "tiktoken", "scipdf"])
     except:
         report_exception(chatbot, history,
                          a=f"解析项目: {txt}",
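Inside 批量翻译PDF文档, the dependency guard no longer imports fitz/tiktoken/scipdf just to see whether they exist: check_packages raises ModuleNotFoundError for a missing package, and the existing except branch turns that into an installation hint via report_exception. Condensed, with the hint text (elided in this diff) replaced by a placeholder, and the refresh-and-return lines following the file's usual idiom:

try:
    check_packages(["fitz", "tiktoken", "scipdf"])
except:
    report_exception(chatbot, history,
                     a=f"解析项目: {txt}",
                     b="(placeholder) install the missing PDF dependencies first")
    yield from update_ui(chatbot=chatbot, history=history)  # refresh the UI
    return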

View File

@ -2,7 +2,6 @@ model_name = "ChatGLM"
cmd_to_install = "`pip install -r request_llms/requirements_chatglm.txt`"
from transformers import AutoModel, AutoTokenizer
from toolbox import get_conf, ProxyNetworkActivate
from .local_llm_class import LocalLLMHandle, get_local_llm_predict_fns
@@ -23,6 +22,7 @@ class GetGLM2Handle(LocalLLMHandle):
-        import os, glob
+        import os
         import platform
+        from transformers import AutoModel, AutoTokenizer
         LOCAL_MODEL_QUANT, device = get_conf('LOCAL_MODEL_QUANT', 'LOCAL_MODEL_DEVICE')
         if LOCAL_MODEL_QUANT == "INT4":  # INT4

View File

@ -2,7 +2,6 @@ model_name = "ChatGLM3"
cmd_to_install = "`pip install -r request_llms/requirements_chatglm.txt`"
from transformers import AutoModel, AutoTokenizer
from toolbox import get_conf, ProxyNetworkActivate
from .local_llm_class import LocalLLMHandle, get_local_llm_predict_fns
@@ -20,6 +19,7 @@ class GetGLM3Handle(LocalLLMHandle):
     def load_model_and_tokenizer(self):
         # 🏃‍♂️🏃‍♂️🏃‍♂️ runs in the child process
+        from transformers import AutoModel, AutoTokenizer
-        import os, glob
+        import os
         import platform

View File

@@ -1,8 +1,6 @@
-from transformers import AutoModel, AutoTokenizer
 import time
 import threading
-import importlib
 from toolbox import update_ui, get_conf
 from multiprocessing import Process, Pipe
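local_llm_class.py likewise drops its top-level transformers import. The multiprocessing imports that remain point at why deferring is sufficient: LocalLLMHandle loads the model in a child process, and load_model_and_tokenizer only runs there, so the parent process never pays the import cost. A stripped-down, runnable illustration of that division of labor (not the repository's actual classes):

from multiprocessing import Process, Pipe

def model_process(conn):
    # The heavy import happens here, in the child process only.
    from transformers import AutoTokenizer  # illustrative
    conn.send("model subprocess ready")

if __name__ == "__main__":
    parent_conn, child_conn = Pipe()
    p = Process(target=model_process, args=(child_conn,))
    p.start()
    print(parent_conn.recv())   # the parent never imports transformers
    p.join()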

View File

@@ -1146,3 +1146,9 @@ def get_chat_default_kwargs():
 def get_max_token(llm_kwargs):
     from request_llms.bridge_all import model_info
     return model_info[llm_kwargs['llm_model']]['max_token']
+
+def check_packages(packages=[]):
+    import importlib.util
+    for p in packages:
+        spam_spec = importlib.util.find_spec(p)
+        if spam_spec is None: raise ModuleNotFoundError
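The new check_packages helper leans on importlib.util.find_spec, which only locates a module on the import path without executing it, so probing for a heavyweight package like torch costs almost nothing. A quick usage sketch:

from toolbox import check_packages

try:
    check_packages(["fitz", "tiktoken"])
    print("all plugin dependencies present")
except ModuleNotFoundError:
    print("a dependency is missing; show the plugin's install hint")

A possible refinement, not part of this commit, would be to raise ModuleNotFoundError(p), so the caller learns which package was absent.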