From 6d1ea643e96b470b2e10124a55e7d58e68a535c3 Mon Sep 17 00:00:00 2001 From: 505030475 <505030475@qq.com> Date: Tue, 30 May 2023 12:54:42 +0800 Subject: [PATCH] langchain --- colorful.py | 80 +++++++--------------- crazy_functional.py | 12 ++++ crazy_functions/Langchain知识库.py | 88 +++++++++++++++++++++++++ crazy_functions/crazy_functions_test.py | 68 ++++++++++++++++++- crazy_functions/crazy_utils.py | 85 ++++++++++++++++++++++++ toolbox.py | 12 +++- 6 files changed, 286 insertions(+), 59 deletions(-) create mode 100644 crazy_functions/Langchain知识库.py diff --git a/colorful.py b/colorful.py index d90972b..9749861 100644 --- a/colorful.py +++ b/colorful.py @@ -34,58 +34,28 @@ def print亮紫(*kw,**kargs): def print亮靛(*kw,**kargs): print("\033[1;36m",*kw,"\033[0m",**kargs) - - -def print亮红(*kw,**kargs): - print("\033[1;31m",*kw,"\033[0m",**kargs) -def print亮绿(*kw,**kargs): - print("\033[1;32m",*kw,"\033[0m",**kargs) -def print亮黄(*kw,**kargs): - print("\033[1;33m",*kw,"\033[0m",**kargs) -def print亮蓝(*kw,**kargs): - print("\033[1;34m",*kw,"\033[0m",**kargs) -def print亮紫(*kw,**kargs): - print("\033[1;35m",*kw,"\033[0m",**kargs) -def print亮靛(*kw,**kargs): - print("\033[1;36m",*kw,"\033[0m",**kargs) - -print_red = print红 -print_green = print绿 -print_yellow = print黄 -print_blue = print蓝 -print_purple = print紫 -print_indigo = print靛 - -print_bold_red = print亮红 -print_bold_green = print亮绿 -print_bold_yellow = print亮黄 -print_bold_blue = print亮蓝 -print_bold_purple = print亮紫 -print_bold_indigo = print亮靛 - -if not stdout.isatty(): - # redirection, avoid a fucked up log file - print红 = print - print绿 = print - print黄 = print - print蓝 = print - print紫 = print - print靛 = print - print亮红 = print - print亮绿 = print - print亮黄 = print - print亮蓝 = print - print亮紫 = print - print亮靛 = print - print_red = print - print_green = print - print_yellow = print - print_blue = print - print_purple = print - print_indigo = print - print_bold_red = print - print_bold_green = print - print_bold_yellow = print - print_bold_blue = print - print_bold_purple = print - print_bold_indigo = print \ No newline at end of file +# Do you like the elegance of Chinese characters? +def sprint红(*kw): + return "\033[0;31m"+' '.join(kw)+"\033[0m" +def sprint绿(*kw): + return "\033[0;32m"+' '.join(kw)+"\033[0m" +def sprint黄(*kw): + return "\033[0;33m"+' '.join(kw)+"\033[0m" +def sprint蓝(*kw): + return "\033[0;34m"+' '.join(kw)+"\033[0m" +def sprint紫(*kw): + return "\033[0;35m"+' '.join(kw)+"\033[0m" +def sprint靛(*kw): + return "\033[0;36m"+' '.join(kw)+"\033[0m" +def sprint亮红(*kw): + return "\033[1;31m"+' '.join(kw)+"\033[0m" +def sprint亮绿(*kw): + return "\033[1;32m"+' '.join(kw)+"\033[0m" +def sprint亮黄(*kw): + return "\033[1;33m"+' '.join(kw)+"\033[0m" +def sprint亮蓝(*kw): + return "\033[1;34m"+' '.join(kw)+"\033[0m" +def sprint亮紫(*kw): + return "\033[1;35m"+' '.join(kw)+"\033[0m" +def sprint亮靛(*kw): + return "\033[1;36m"+' '.join(kw)+"\033[0m" diff --git a/crazy_functional.py b/crazy_functional.py index 91c85cf..3b295ac 100644 --- a/crazy_functional.py +++ b/crazy_functional.py @@ -295,5 +295,17 @@ def get_crazy_functions(): except: print('Load function plugin failed') + try: + from crazy_functions.Langchain知识库 import 知识库问答 + function_plugins.update({ + "构建知识库(请先上传文件素材)": { + "Color": "stop", + "AsButton": False, + "Function": HotReload(知识库问答) + } + }) + except: + print('Load function plugin failed') + ###################### 第n组插件 ########################### return function_plugins diff --git a/crazy_functions/Langchain知识库.py b/crazy_functions/Langchain知识库.py new file mode 100644 index 0000000..0bdb7f5 --- /dev/null +++ b/crazy_functions/Langchain知识库.py @@ -0,0 +1,88 @@ +from toolbox import CatchException, update_ui, ProxyNetworkActivate +from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive, get_files_from_everything + + + +@CatchException +def 知识库问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port): + """ + txt 输入栏用户输入的文本,例如需要翻译的一段话,再例如一个包含了待处理文件的路径 + llm_kwargs gpt模型参数, 如温度和top_p等, 一般原样传递下去就行 + plugin_kwargs 插件模型的参数,暂时没有用武之地 + chatbot 聊天显示框的句柄,用于显示给用户 + history 聊天历史,前情提要 + system_prompt 给gpt的静默提醒 + web_port 当前软件运行的端口号 + """ + history = [] # 清空历史,以免输入溢出 + chatbot.append(("这是什么功能?", "[Local Message] 从一批文件(txt, md, tex)中读取数据构建知识库, 然后进行问答。")) + yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 + + try: + import zh_langchain + from langchain.embeddings.huggingface import HuggingFaceEmbeddings + from .crazy_utils import knowledge_archive_interface + except Exception as e: + chatbot.append( + ["依赖不足", + "导入依赖失败。使用该模块需要额外依赖,安装方法```pip install --upgrade langchain zh_langchain```。"] + ) + yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 + return + # < --------------------读取文件--------------- > + file_manifest = [] + spl = ["doc", "docx", "email", "epub", "html", "image", "json", "md", "msg", "odt", "pdf", "ppt", "pptx", "rtf", "text"] + for sp in spl: + _, file_manifest_tmp, _ = get_files_from_everything(txt, type=f'.{sp}') + file_manifest += file_manifest_tmp + + if len(file_manifest) == 0: + chatbot.append(["没有找到任何可读取文件", "当前支持的格式包括: txt, md, tex"]) + yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 + return + + # < -------------------预热文本向量化模组--------------- > + chatbot.append(['
'.join(file_manifest), "正在预热文本向量化模组, 如果是第一次运行, 将消耗较长时间下载中文向量化模型..."]) + yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 + print('Checking Text2vec ...') + from langchain.embeddings.huggingface import HuggingFaceEmbeddings + with ProxyNetworkActivate(): # 临时地激活代理网络 + HuggingFaceEmbeddings(model_name="GanymedeNil/text2vec-large-chinese", model_kwargs={'device': 'cpu'}) + + # < -------------------构建知识库--------------- > + chatbot.append(['
'.join(file_manifest), "正在构建知识库..."]) + yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 + print('Establishing knowledge archive ...') + with ProxyNetworkActivate(): # 临时地激活代理网络 + kai = knowledge_archive_interface() + kai.feed_archive(file_manifest=file_manifest, id="default") + + chatbot.append(['知识库构建成功', "正在将知识库存储至cookie中"]) + yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 + chatbot._cookies['langchain_plugin_embedding'] = kai.get_current_archive_id() + chatbot._cookies['lock_plugin'] = 'crazy_functions.Langchain知识库->读取知识库作答' + chatbot.append(['完成', "“根据知识库作答”函数插件已经接管问答系统, 提问吧! 但注意, 您接下来不能再使用其他插件了,刷新页面即可以退出知识库问答模式。"]) + yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 # 由于请求gpt需要一段时间,我们先及时地做一次界面更新 + +@CatchException +def 读取知识库作答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port=-1): + + # < ------------------- --------------- > + from .crazy_utils import knowledge_archive_interface + kai = knowledge_archive_interface() + + if 'langchain_plugin_embedding' in chatbot._cookies: + resp, prompt = kai.answer_with_archive_by_id(txt, chatbot._cookies['langchain_plugin_embedding']) + else: + if ("advanced_arg" in plugin_kwargs) and (plugin_kwargs["advanced_arg"] == ""): plugin_kwargs.pop("advanced_arg") + kai_id = plugin_kwargs.get("advanced_arg", 'default') + resp, prompt = kai.answer_with_archive_by_id(txt, kai_id) + + chatbot.append((txt, '[Local Message] ' + prompt)) + gpt_say = yield from request_gpt_model_in_new_thread_with_ui_alive( + inputs=prompt, inputs_show_user=txt, + llm_kwargs=llm_kwargs, chatbot=chatbot, history=[], + sys_prompt=system_prompt + ) + history.extend((prompt, gpt_say)) + yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 # 由于请求gpt需要一段时间,我们先及时地做一次界面更新 diff --git a/crazy_functions/crazy_functions_test.py b/crazy_functions/crazy_functions_test.py index a9bfbf8..3d9e103 100644 --- a/crazy_functions/crazy_functions_test.py +++ b/crazy_functions/crazy_functions_test.py @@ -3,6 +3,8 @@ 这个文件用于函数插件的单元测试 运行方法 python crazy_functions/crazy_functions_test.py """ + +# ============================================================================================================================== def validate_path(): import os, sys @@ -10,10 +12,16 @@ def validate_path(): root_dir_assume = os.path.abspath(os.path.dirname(__file__) + '/..') os.chdir(root_dir_assume) sys.path.append(root_dir_assume) - validate_path() # validate path so you can run from base directory + +# ============================================================================================================================== + from colorful import * from toolbox import get_conf, ChatBotWithCookies +import contextlib +import os +import sys +from functools import wraps proxies, WEB_PORT, LLM_MODEL, CONCURRENT_COUNT, AUTHENTICATION, CHATBOT_HEIGHT, LAYOUT, API_KEY = \ get_conf('proxies', 'WEB_PORT', 'LLM_MODEL', 'CONCURRENT_COUNT', 'AUTHENTICATION', 'CHATBOT_HEIGHT', 'LAYOUT', 'API_KEY') @@ -30,7 +38,43 @@ history = [] system_prompt = "Serve me as a writing and programming assistant." web_port = 1024 +# ============================================================================================================================== +def silence_stdout(func): + @wraps(func) + def wrapper(*args, **kwargs): + _original_stdout = sys.stdout + sys.stdout = open(os.devnull, 'w') + for q in func(*args, **kwargs): + sys.stdout = _original_stdout + yield q + sys.stdout = open(os.devnull, 'w') + sys.stdout.close() + sys.stdout = _original_stdout + return wrapper + +class CLI_Printer(): + def __init__(self) -> None: + self.pre_buf = "" + + def print(self, buf): + bufp = "" + for index, chat in enumerate(buf): + a, b = chat + bufp += sprint亮靛('[Me]:' + a) + '\n' + bufp += '[GPT]:' + b + if index < len(buf)-1: + bufp += '\n' + + if self.pre_buf!="" and bufp.startswith(self.pre_buf): + print(bufp[len(self.pre_buf):], end='') + else: + print('\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n'+bufp, end='') + self.pre_buf = bufp + return + +cli_printer = CLI_Printer() +# ============================================================================================================================== def test_解析一个Python项目(): from crazy_functions.解析项目源代码 import 解析一个Python项目 txt = "crazy_functions/test_project/python/dqn" @@ -116,6 +160,25 @@ def test_Markdown多语言(): for cookies, cb, hist, msg in Markdown翻译指定语言(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port): print(cb) +def test_Langchain知识库(): + from crazy_functions.Langchain知识库 import 知识库问答 + txt = "README.md" + chatbot = ChatBotWithCookies(llm_kwargs) + for cookies, cb, hist, msg in silence_stdout(知识库问答)(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port): + cli_printer.print(cb) # print(cb) + + chatbot = ChatBotWithCookies(cookies) + from crazy_functions.Langchain知识库 import 读取知识库作答 + txt = "摘要?" + for cookies, cb, hist, msg in silence_stdout(读取知识库作答)(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port): + cli_printer.print(cb) # print(cb) + +def test_Langchain知识库读取(): + from crazy_functions.Langchain知识库 import 读取知识库作答 + txt = "远程云服务器部署?" + for cookies, cb, hist, msg in silence_stdout(读取知识库作答)(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port): + cli_printer.print(cb) # print(cb) + # test_解析一个Python项目() @@ -129,7 +192,8 @@ def test_Markdown多语言(): # test_联网回答问题() # test_解析ipynb文件() # test_数学动画生成manim() -test_Markdown多语言() +test_Langchain知识库() +# test_Langchain知识库读取() input("程序完成,回车退出。") print("退出。") \ No newline at end of file diff --git a/crazy_functions/crazy_utils.py b/crazy_functions/crazy_utils.py index de205d7..3416d91 100644 --- a/crazy_functions/crazy_utils.py +++ b/crazy_functions/crazy_utils.py @@ -1,4 +1,5 @@ from toolbox import update_ui, get_conf, trimmed_format_exc +import threading def input_clipping(inputs, history, max_token_limit): import numpy as np @@ -606,3 +607,87 @@ def get_files_from_everything(txt, type): # type='.md' success = False return success, file_manifest, project_folder + + + + +def Singleton(cls): + _instance = {} + + def _singleton(*args, **kargs): + if cls not in _instance: + _instance[cls] = cls(*args, **kargs) + return _instance[cls] + + return _singleton + + +@Singleton +class knowledge_archive_interface(): + def __init__(self) -> None: + self.threadLock = threading.Lock() + self.current_id = "" + self.kai_path = None + self.qa_handle = None + self.text2vec_large_chinese = None + + def get_chinese_text2vec(self): + if self.text2vec_large_chinese is None: + # < -------------------预热文本向量化模组--------------- > + from toolbox import ProxyNetworkActivate + print('Checking Text2vec ...') + from langchain.embeddings.huggingface import HuggingFaceEmbeddings + with ProxyNetworkActivate(): # 临时地激活代理网络 + self.text2vec_large_chinese = HuggingFaceEmbeddings(model_name="GanymedeNil/text2vec-large-chinese", model_kwargs={'device': 'cpu'}) + + return self.text2vec_large_chinese + + + def feed_archive(self, file_manifest, id="default"): + self.threadLock.acquire() + # import uuid + self.current_id = id + from zh_langchain import construct_vector_store + self.qa_handle, self.kai_path = construct_vector_store( + vs_id=self.current_id, + files=file_manifest, + sentence_size=100, + history=[], + one_conent="", + one_content_segmentation="", + text2vec = self.get_chinese_text2vec(), + ) + self.threadLock.release() + + def get_current_archive_id(self): + return self.current_id + + def answer_with_archive_by_id(self, txt, id): + self.threadLock.acquire() + if not self.current_id == id: + self.current_id = id + from zh_langchain import construct_vector_store + self.qa_handle, self.kai_path = construct_vector_store( + vs_id=self.current_id, + files=[], + sentence_size=100, + history=[], + one_conent="", + one_content_segmentation="", + text2vec = self.get_chinese_text2vec(), + ) + VECTOR_SEARCH_SCORE_THRESHOLD = 0 + VECTOR_SEARCH_TOP_K = 4 + CHUNK_SIZE = 512 + resp, prompt = self.qa_handle.get_knowledge_based_conent_test( + query = txt, + vs_path = self.kai_path, + score_threshold=VECTOR_SEARCH_SCORE_THRESHOLD, + vector_search_top_k=VECTOR_SEARCH_TOP_K, + chunk_conent=True, + chunk_size=CHUNK_SIZE, + text2vec = self.get_chinese_text2vec(), + ) + self.threadLock.release() + return resp, prompt + diff --git a/toolbox.py b/toolbox.py index 10e5a87..6903684 100644 --- a/toolbox.py +++ b/toolbox.py @@ -59,7 +59,15 @@ def ArgsGeneralWrapper(f): } chatbot_with_cookie = ChatBotWithCookies(cookies) chatbot_with_cookie.write_list(chatbot) - yield from f(txt_passon, llm_kwargs, plugin_kwargs, chatbot_with_cookie, history, system_prompt, *args) + + if 'lock_plugin' in cookies and cookies['lock_plugin'] is not None: + # 处理插件锁定状态 + module, fn_name = cookies['lock_plugin'].split('->') + f_hot_reload = getattr(importlib.import_module(module, fn_name), fn_name) + yield from HotReload(f_hot_reload)(txt_passon, llm_kwargs, plugin_kwargs, chatbot_with_cookie, history, system_prompt, *args) + else: + # 正常状态 + yield from f(txt_passon, llm_kwargs, plugin_kwargs, chatbot_with_cookie, history, system_prompt, *args) return decorated @@ -83,7 +91,7 @@ def CatchException(f): """ @wraps(f) - def decorated(txt, top_p, temperature, chatbot, history, systemPromptTxt, WEB_PORT): + def decorated(txt, top_p, temperature, chatbot, history, systemPromptTxt, WEB_PORT=-1): try: yield from f(txt, top_p, temperature, chatbot, history, systemPromptTxt, WEB_PORT) except Exception as e: