diff --git a/colorful.py b/colorful.py
index d90972b..9749861 100644
--- a/colorful.py
+++ b/colorful.py
@@ -34,58 +34,28 @@ def print亮紫(*kw,**kargs):
def print亮靛(*kw,**kargs):
print("\033[1;36m",*kw,"\033[0m",**kargs)
-
-
-def print亮红(*kw,**kargs):
- print("\033[1;31m",*kw,"\033[0m",**kargs)
-def print亮绿(*kw,**kargs):
- print("\033[1;32m",*kw,"\033[0m",**kargs)
-def print亮黄(*kw,**kargs):
- print("\033[1;33m",*kw,"\033[0m",**kargs)
-def print亮蓝(*kw,**kargs):
- print("\033[1;34m",*kw,"\033[0m",**kargs)
-def print亮紫(*kw,**kargs):
- print("\033[1;35m",*kw,"\033[0m",**kargs)
-def print亮靛(*kw,**kargs):
- print("\033[1;36m",*kw,"\033[0m",**kargs)
-
-print_red = print红
-print_green = print绿
-print_yellow = print黄
-print_blue = print蓝
-print_purple = print紫
-print_indigo = print靛
-
-print_bold_red = print亮红
-print_bold_green = print亮绿
-print_bold_yellow = print亮黄
-print_bold_blue = print亮蓝
-print_bold_purple = print亮紫
-print_bold_indigo = print亮靛
-
-if not stdout.isatty():
- # redirection, avoid a fucked up log file
- print红 = print
- print绿 = print
- print黄 = print
- print蓝 = print
- print紫 = print
- print靛 = print
- print亮红 = print
- print亮绿 = print
- print亮黄 = print
- print亮蓝 = print
- print亮紫 = print
- print亮靛 = print
- print_red = print
- print_green = print
- print_yellow = print
- print_blue = print
- print_purple = print
- print_indigo = print
- print_bold_red = print
- print_bold_green = print
- print_bold_yellow = print
- print_bold_blue = print
- print_bold_purple = print
- print_bold_indigo = print
\ No newline at end of file
+# Do you like the elegance of Chinese characters?
+def sprint红(*kw):  # sprint* helpers: return the ANSI-colored text instead of printing it
+    return "\033[0;31m"+' '.join(map(str, kw))+"\033[0m"  # str() every argument so non-string values are accepted
+def sprint绿(*kw):
+    return "\033[0;32m"+' '.join(map(str, kw))+"\033[0m"
+def sprint黄(*kw):
+    return "\033[0;33m"+' '.join(map(str, kw))+"\033[0m"
+def sprint蓝(*kw):
+    return "\033[0;34m"+' '.join(map(str, kw))+"\033[0m"
+def sprint紫(*kw):
+    return "\033[0;35m"+' '.join(map(str, kw))+"\033[0m"
+def sprint靛(*kw):
+    return "\033[0;36m"+' '.join(map(str, kw))+"\033[0m"
+def sprint亮红(*kw):  # the "亮" variants use the bold attribute (1;) instead of normal (0;)
+    return "\033[1;31m"+' '.join(map(str, kw))+"\033[0m"
+def sprint亮绿(*kw):
+    return "\033[1;32m"+' '.join(map(str, kw))+"\033[0m"
+def sprint亮黄(*kw):
+    return "\033[1;33m"+' '.join(map(str, kw))+"\033[0m"
+def sprint亮蓝(*kw):
+    return "\033[1;34m"+' '.join(map(str, kw))+"\033[0m"
+def sprint亮紫(*kw):
+    return "\033[1;35m"+' '.join(map(str, kw))+"\033[0m"
+def sprint亮靛(*kw):
+    return "\033[1;36m"+' '.join(map(str, kw))+"\033[0m"
diff --git a/crazy_functional.py b/crazy_functional.py
index 91c85cf..3b295ac 100644
--- a/crazy_functional.py
+++ b/crazy_functional.py
@@ -295,5 +295,17 @@ def get_crazy_functions():
except:
print('Load function plugin failed')
+ try:
+ from crazy_functions.Langchain知识库 import 知识库问答
+ function_plugins.update({
+ "构建知识库(请先上传文件素材)": {
+ "Color": "stop",
+ "AsButton": False,
+ "Function": HotReload(知识库问答)
+ }
+ })
+ except:
+ print('Load function plugin failed')
+
###################### 第n组插件 ###########################
return function_plugins
diff --git a/crazy_functions/Langchain知识库.py b/crazy_functions/Langchain知识库.py
new file mode 100644
index 0000000..0bdb7f5
--- /dev/null
+++ b/crazy_functions/Langchain知识库.py
@@ -0,0 +1,88 @@
+from toolbox import CatchException, update_ui, ProxyNetworkActivate
+from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive, get_files_from_everything
+
+
+
+@CatchException
+def 知识库问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
+    """Collect files found via `txt`, build a knowledge archive from them, then lock the session onto 读取知识库作答.
+    txt             text from the input box; passed to get_files_from_everything to locate the files
+    llm_kwargs      LLM parameters such as temperature and top_p; not read by this plugin
+    plugin_kwargs   plugin parameters; not read by this plugin
+    chatbot         handle of the chat display, used to show progress to the user
+    history         chat history; replaced by an empty list on entry
+    system_prompt   silent system prompt for the LLM; not read by this plugin
+    web_port        port the application is running on; not read by this plugin
+    """
+    history = []    # clear the history so the input does not overflow
+    spl = ["doc", "docx", "email", "epub", "html", "image", "json", "md", "msg", "odt", "pdf", "ppt", "pptx", "rtf", "text"]
+    chatbot.append(("这是什么功能?", f"[Local Message] 从一批文件({', '.join(spl)})中读取数据构建知识库, 然后进行问答。"))
+    yield from update_ui(chatbot=chatbot, history=history) # refresh the UI
+
+    try:
+        import zh_langchain  # imported only to check that the dependency is installed
+        from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+        from .crazy_utils import knowledge_archive_interface
+    except Exception:
+        chatbot.append(
+            ["依赖不足",
+             "导入依赖失败。使用该模块需要额外依赖,安装方法```pip install --upgrade langchain zh_langchain```。"]
+        )
+        yield from update_ui(chatbot=chatbot, history=history) # refresh the UI
+        return
+    # < -------------------- collect the files --------------- >
+    file_manifest = []
+    for sp in spl:
+        _, file_manifest_tmp, _ = get_files_from_everything(txt, type=f'.{sp}')
+        file_manifest += file_manifest_tmp
+
+    if not file_manifest:
+        chatbot.append(["没有找到任何可读取文件", f"当前支持的格式包括: {', '.join(spl)}"])
+        yield from update_ui(chatbot=chatbot, history=history) # refresh the UI
+        return
+
+    # < ------------------- warm up the text-embedding model --------------- >
+    chatbot.append(['<br/>'.join(file_manifest), "正在预热文本向量化模组, 如果是第一次运行, 将消耗较长时间下载中文向量化模型..."])
+    yield from update_ui(chatbot=chatbot, history=history) # refresh the UI
+    print('Checking Text2vec ...')
+    # The instance is discarded; presumably it is created only so the model gets downloaded and cached.
+    with ProxyNetworkActivate(): # temporarily enable the proxy network
+        HuggingFaceEmbeddings(model_name="GanymedeNil/text2vec-large-chinese", model_kwargs={'device': 'cpu'})
+
+    # < ------------------- build the knowledge archive --------------- >
+    chatbot.append(['<br/>'.join(file_manifest), "正在构建知识库..."])
+    yield from update_ui(chatbot=chatbot, history=history) # refresh the UI
+    print('Establishing knowledge archive ...')
+    with ProxyNetworkActivate(): # temporarily enable the proxy network
+        kai = knowledge_archive_interface()
+        kai.feed_archive(file_manifest=file_manifest, id="default")
+
+    chatbot.append(['知识库构建成功', "正在将知识库存储至cookie中"])
+    yield from update_ui(chatbot=chatbot, history=history) # refresh the UI
+    chatbot._cookies['langchain_plugin_embedding'] = kai.get_current_archive_id()
+    chatbot._cookies['lock_plugin'] = 'crazy_functions.Langchain知识库->读取知识库作答'  # "module->function" read by the dispatcher
+    chatbot.append(['完成', "“根据知识库作答”函数插件已经接管问答系统, 提问吧! 但注意, 您接下来不能再使用其他插件了,刷新页面即可以退出知识库问答模式。"])
+    yield from update_ui(chatbot=chatbot, history=history) # refresh the UI
+
+@CatchException
+def 读取知识库作答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port=-1):
+    """Answer `txt` from a knowledge archive: fetch a prompt from the archive, then send that prompt to the LLM."""
+    # Archive id: the cookie value when present, otherwise plugin_kwargs["advanced_arg"], otherwise 'default'.
+    from .crazy_utils import knowledge_archive_interface
+    kai = knowledge_archive_interface()
+
+    if 'langchain_plugin_embedding' in chatbot._cookies:
+        kai_id = chatbot._cookies['langchain_plugin_embedding']
+    else:
+        if plugin_kwargs.get("advanced_arg") == "": plugin_kwargs.pop("advanced_arg")  # an empty string counts as unset
+        kai_id = plugin_kwargs.get("advanced_arg", 'default')
+    _, prompt = kai.answer_with_archive_by_id(txt, kai_id)
+
+    chatbot.append((txt, '[Local Message] ' + prompt))
+    reply = yield from request_gpt_model_in_new_thread_with_ui_alive(
+        inputs=prompt, inputs_show_user=txt,
+        llm_kwargs=llm_kwargs, chatbot=chatbot, history=[],
+        sys_prompt=system_prompt
+    )
+    history.extend([prompt, reply])
+    yield from update_ui(chatbot=chatbot, history=history) # refresh the UI
diff --git a/crazy_functions/crazy_functions_test.py b/crazy_functions/crazy_functions_test.py
index a9bfbf8..3d9e103 100644
--- a/crazy_functions/crazy_functions_test.py
+++ b/crazy_functions/crazy_functions_test.py
@@ -3,6 +3,8 @@
这个文件用于函数插件的单元测试
运行方法 python crazy_functions/crazy_functions_test.py
"""
+
+# ==============================================================================================================================
def validate_path():
import os, sys
@@ -10,10 +12,16 @@ def validate_path():
root_dir_assume = os.path.abspath(os.path.dirname(__file__) + '/..')
os.chdir(root_dir_assume)
sys.path.append(root_dir_assume)
-
validate_path() # validate path so you can run from base directory
+
+# ==============================================================================================================================
+
from colorful import *
from toolbox import get_conf, ChatBotWithCookies
+import contextlib
+import os
+import sys
+from functools import wraps
proxies, WEB_PORT, LLM_MODEL, CONCURRENT_COUNT, AUTHENTICATION, CHATBOT_HEIGHT, LAYOUT, API_KEY = \
get_conf('proxies', 'WEB_PORT', 'LLM_MODEL', 'CONCURRENT_COUNT', 'AUTHENTICATION', 'CHATBOT_HEIGHT', 'LAYOUT', 'API_KEY')
@@ -30,7 +38,43 @@ history = []
system_prompt = "Serve me as a writing and programming assistant."
web_port = 1024
+# ==============================================================================================================================
+def silence_stdout(func):  # decorator: mute stdout while the generator returned by func is running
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        gen, done = iter(func(*args, **kwargs)), object()  # `done` marks exhaustion without try/except
+        with open(os.devnull, 'w') as devnull:  # a single sink, closed when the wrapper finishes
+            while True:
+                with contextlib.redirect_stdout(devnull):  # stdout is restored even if func raises
+                    q = next(gen, done)
+                if q is done:
+                    return
+                yield q  # yielded outside the redirect so the consumer can print normally
+    return wrapper
+
+class CLI_Printer():
+    """Echo a chat transcript to the terminal, printing only what is new since the previous call."""
+    def __init__(self) -> None:
+        self.pre_buf = ""  # transcript text rendered by the previous print() call
+
+    def print(self, buf):
+        """Render `buf` (a sequence of (question, answer) pairs) and print the part not shown yet."""
+        rendered = '\n'.join(
+            sprint亮靛('[Me]:' + me) + '\n' + '[GPT]:' + gpt
+            for me, gpt in buf
+        )
+
+        if self.pre_buf != "" and rendered.startswith(self.pre_buf):
+            # The transcript only grew: emit just the new tail.
+            print(rendered[len(self.pre_buf):], end='')
+        else:
+            # First call or the earlier text changed: push old output up and print everything again.
+            print('\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n' + rendered, end='')
+        self.pre_buf = rendered
+
+cli_printer = CLI_Printer()
+# ==============================================================================================================================
def test_解析一个Python项目():
from crazy_functions.解析项目源代码 import 解析一个Python项目
txt = "crazy_functions/test_project/python/dqn"
@@ -116,6 +160,25 @@ def test_Markdown多语言():
for cookies, cb, hist, msg in Markdown翻译指定语言(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
print(cb)
+def test_Langchain知识库():
+    """Run 知识库问答 on README.md, then ask one question through 读取知识库作答."""
+    from crazy_functions.Langchain知识库 import 知识库问答, 读取知识库作答
+    build, ask = silence_stdout(知识库问答), silence_stdout(读取知识库作答)
+    # Stage 1: build the archive; every yielded chat state is echoed incrementally.
+    chatbot = ChatBotWithCookies(llm_kwargs)
+    for cookies, chat, _, _ in build("README.md", llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
+        cli_printer.print(chat)
+    # Stage 2: start a fresh chatbot from the last yielded cookies and query.
+    chatbot = ChatBotWithCookies(cookies)
+    for cookies, chat, _, _ in ask("摘要?", llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
+        cli_printer.print(chat)
+
+def test_Langchain知识库读取():  # query only, using the module-level `chatbot`; nothing is built here
+    from crazy_functions.Langchain知识库 import 读取知识库作答 as ask
+    question = "远程云服务器部署?"
+    for _, chat, _, _ in silence_stdout(ask)(question, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
+        cli_printer.print(chat)
+
# test_解析一个Python项目()
@@ -129,7 +192,8 @@ def test_Markdown多语言():
# test_联网回答问题()
# test_解析ipynb文件()
# test_数学动画生成manim()
-test_Markdown多语言()
+test_Langchain知识库()
+# test_Langchain知识库读取()
input("程序完成,回车退出。")
print("退出。")
\ No newline at end of file
diff --git a/crazy_functions/crazy_utils.py b/crazy_functions/crazy_utils.py
index de205d7..3416d91 100644
--- a/crazy_functions/crazy_utils.py
+++ b/crazy_functions/crazy_utils.py
@@ -1,4 +1,5 @@
from toolbox import update_ui, get_conf, trimmed_format_exc
+import threading
def input_clipping(inputs, history, max_token_limit):
import numpy as np
@@ -606,3 +607,87 @@ def get_files_from_everything(txt, type): # type='.md'
success = False
return success, file_manifest, project_folder
+
+
+
+
+def Singleton(cls):
+    """Class decorator: the first call constructs the instance, every later call returns that same object."""
+    _instance = {}
+    def _singleton(*args, **kargs):
+        if cls in _instance:  # already built: later arguments are ignored
+            return _instance[cls]
+        _instance[cls] = cls(*args, **kargs)
+        return _instance[cls]
+    return _singleton
+
+
+@Singleton
+class knowledge_archive_interface():
+    """Holds the current vector store and the embedding model; store access is serialized by a lock."""
+    def __init__(self) -> None:
+        self.threadLock = threading.Lock()
+        self.current_id = ""
+        self.kai_path = None
+        self.qa_handle = None
+        self.text2vec_large_chinese = None
+
+    def get_chinese_text2vec(self):
+        """Return the embedding model, creating it on first use."""
+        if self.text2vec_large_chinese is None:
+            # < ------------------- warm up the text-embedding model --------------- >
+            from toolbox import ProxyNetworkActivate
+            print('Checking Text2vec ...')
+            from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+            with ProxyNetworkActivate(): # temporarily enable the proxy network
+                self.text2vec_large_chinese = HuggingFaceEmbeddings(model_name="GanymedeNil/text2vec-large-chinese", model_kwargs={'device': 'cpu'})
+        return self.text2vec_large_chinese
+
+    def feed_archive(self, file_manifest, id="default"):
+        """Build vector store `id` from `file_manifest` and make it the current one."""
+        with self.threadLock:  # released even when construction raises
+            from zh_langchain import construct_vector_store
+            # Build into locals first so a failed build cannot pair a new id with a stale handle.
+            qa_handle, kai_path = construct_vector_store(
+                vs_id=id,
+                files=file_manifest,
+                sentence_size=100,
+                history=[],
+                one_conent="",
+                one_content_segmentation="",
+                text2vec = self.get_chinese_text2vec(),
+            )
+            self.qa_handle, self.kai_path, self.current_id = qa_handle, kai_path, id
+
+    def get_current_archive_id(self):
+        return self.current_id
+
+    def answer_with_archive_by_id(self, txt, id):
+        """Query store `id` with `txt` (re-opening it with no new files if it is not current); returns (resp, prompt)."""
+        with self.threadLock:  # released even when the lookup raises
+            if not self.current_id == id:
+                from zh_langchain import construct_vector_store
+                qa_handle, kai_path = construct_vector_store(
+                    vs_id=id,
+                    files=[],
+                    sentence_size=100,
+                    history=[],
+                    one_conent="",
+                    one_content_segmentation="",
+                    text2vec = self.get_chinese_text2vec(),
+                )
+                self.qa_handle, self.kai_path, self.current_id = qa_handle, kai_path, id
+            VECTOR_SEARCH_SCORE_THRESHOLD = 0
+            VECTOR_SEARCH_TOP_K = 4
+            CHUNK_SIZE = 512
+            resp, prompt = self.qa_handle.get_knowledge_based_conent_test(
+                query = txt,
+                vs_path = self.kai_path,
+                score_threshold=VECTOR_SEARCH_SCORE_THRESHOLD,
+                vector_search_top_k=VECTOR_SEARCH_TOP_K,
+                chunk_conent=True,
+                chunk_size=CHUNK_SIZE,
+                text2vec = self.get_chinese_text2vec(),
+            )
+        return resp, prompt
+
diff --git a/toolbox.py b/toolbox.py
index 10e5a87..6903684 100644
--- a/toolbox.py
+++ b/toolbox.py
@@ -59,7 +59,15 @@ def ArgsGeneralWrapper(f):
}
chatbot_with_cookie = ChatBotWithCookies(cookies)
chatbot_with_cookie.write_list(chatbot)
- yield from f(txt_passon, llm_kwargs, plugin_kwargs, chatbot_with_cookie, history, system_prompt, *args)
+
+ if 'lock_plugin' in cookies and cookies['lock_plugin'] is not None:
+ # 处理插件锁定状态
+ module, fn_name = cookies['lock_plugin'].split('->')
+ f_hot_reload = getattr(importlib.import_module(module, fn_name), fn_name)
+ yield from HotReload(f_hot_reload)(txt_passon, llm_kwargs, plugin_kwargs, chatbot_with_cookie, history, system_prompt, *args)
+ else:
+ # 正常状态
+ yield from f(txt_passon, llm_kwargs, plugin_kwargs, chatbot_with_cookie, history, system_prompt, *args)
return decorated
@@ -83,7 +91,7 @@ def CatchException(f):
"""
@wraps(f)
- def decorated(txt, top_p, temperature, chatbot, history, systemPromptTxt, WEB_PORT):
+ def decorated(txt, top_p, temperature, chatbot, history, systemPromptTxt, WEB_PORT=-1):
try:
yield from f(txt, top_p, temperature, chatbot, history, systemPromptTxt, WEB_PORT)
except Exception as e: