diff --git a/crazy_functions/vector_fns/vector_database.py b/crazy_functions/vector_fns/vector_database.py
index 098eb22..b256e70 100644
--- a/crazy_functions/vector_fns/vector_database.py
+++ b/crazy_functions/vector_fns/vector_database.py
@@ -26,10 +26,6 @@ EMBEDDING_MODEL = "text2vec"
 # Embedding running device
 EMBEDDING_DEVICE = "cpu"
 
-VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vector_store")
-
-UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "content")
-
 # 基于上下文的prompt模版,请务必保留"{question}"和"{context}"
 PROMPT_TEMPLATE = """已知信息:
 {context}
@@ -159,7 +155,7 @@ class LocalDocQA:
         elif os.path.isfile(filepath):
             file = os.path.split(filepath)[-1]
             try:
-                docs = load_file(filepath, sentence_size)
+                docs = load_file(filepath, SENTENCE_SIZE)
                 print(f"{file} 已成功加载")
                 loaded_files.append(filepath)
             except Exception as e:
@@ -171,7 +167,7 @@ class LocalDocQA:
             for file in tqdm(os.listdir(filepath), desc="加载文件"):
                 fullfilepath = os.path.join(filepath, file)
                 try:
-                    docs += load_file(fullfilepath, sentence_size)
+                    docs += load_file(fullfilepath, SENTENCE_SIZE)
                     loaded_files.append(fullfilepath)
                 except Exception as e:
                     print(e)
@@ -185,21 +181,19 @@ class LocalDocQA:
         else:
             docs = []
             for file in filepath:
-                try:
-                    docs += load_file(file)
-                    print(f"{file} 已成功加载")
-                    loaded_files.append(file)
-                except Exception as e:
-                    print(e)
-                    print(f"{file} 未能成功加载")
+                docs += load_file(file, SENTENCE_SIZE)
+                print(f"{file} 已成功加载")
+                loaded_files.append(file)
         if len(docs) > 0:
             print("文件加载完毕,正在生成向量库")
             if vs_path and os.path.isdir(vs_path):
-                self.vector_store = FAISS.load_local(vs_path, text2vec)
-                self.vector_store.add_documents(docs)
+                try:
+                    self.vector_store = FAISS.load_local(vs_path, text2vec)
+                    self.vector_store.add_documents(docs)
+                except:
+                    self.vector_store = FAISS.from_documents(docs, text2vec)
             else:
-                if not vs_path: assert False
                 self.vector_store = FAISS.from_documents(docs, text2vec)  # docs 为Document列表
                 self.vector_store.save_local(vs_path)
@@ -208,9 +202,9 @@ class LocalDocQA:
             self.vector_store = FAISS.load_local(vs_path, text2vec)
             return vs_path, loaded_files
 
-    def get_loaded_file(self):
+    def get_loaded_file(self, vs_path):
         ds = self.vector_store.docstore
-        return set([ds._dict[k].metadata['source'].split(UPLOAD_ROOT_PATH)[-1] for k in ds._dict])
+        return set([ds._dict[k].metadata['source'].split(vs_path)[-1] for k in ds._dict])
 
 
     # query 查询内容
@@ -228,7 +222,7 @@
         self.vector_store.score_threshold = score_threshold
         self.vector_store.chunk_size = chunk_size
 
-        embedding = self.vector_store.embedding_function(query)
+        embedding = self.vector_store.embedding_function.embed_query(query)
         related_docs_with_score = similarity_search_with_score_by_vector(self.vector_store, embedding, k=vector_search_top_k)
 
         if not related_docs_with_score:
@@ -247,27 +241,23 @@
 
 
-def construct_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation, text2vec):
+def construct_vector_store(vs_id, vs_path, files, sentence_size, history, one_conent, one_content_segmentation, text2vec):
     for file in files:
         assert os.path.exists(file), "输入文件不存在"
     import nltk
     if NLTK_DATA_PATH not in nltk.data.path: nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
     local_doc_qa = LocalDocQA()
     local_doc_qa.init_cfg()
-    vs_path = os.path.join(VS_ROOT_PATH, vs_id)
     filelist = []
-    if not os.path.exists(os.path.join(UPLOAD_ROOT_PATH, vs_id)):
-        os.makedirs(os.path.join(UPLOAD_ROOT_PATH, vs_id))
-    if isinstance(files, list):
-        for file in files:
-            file_name = file.name if not isinstance(file, str) else file
-            filename = os.path.split(file_name)[-1]
-            shutil.copyfile(file_name, os.path.join(UPLOAD_ROOT_PATH, vs_id, filename))
-            filelist.append(os.path.join(UPLOAD_ROOT_PATH, vs_id, filename))
-        vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, vs_path, sentence_size, text2vec)
-    else:
-        vs_path, loaded_files = local_doc_qa.one_knowledge_add(vs_path, files, one_conent, one_content_segmentation,
-                                                               sentence_size, text2vec)
+    if not os.path.exists(os.path.join(vs_path, vs_id)):
+        os.makedirs(os.path.join(vs_path, vs_id))
+    for file in files:
+        file_name = file.name if not isinstance(file, str) else file
+        filename = os.path.split(file_name)[-1]
+        shutil.copyfile(file_name, os.path.join(vs_path, vs_id, filename))
+        filelist.append(os.path.join(vs_path, vs_id, filename))
+    vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, os.path.join(vs_path, vs_id), sentence_size, text2vec)
+
     if len(loaded_files):
         file_status = f"已添加 {'、'.join([os.path.split(i)[-1] for i in loaded_files if i])} 内容至知识库,并已加载知识库,请开始提问"
     else:
@@ -297,12 +287,13 @@ class knowledge_archive_interface():
 
         return self.text2vec_large_chinese
 
-    def feed_archive(self, file_manifest, id="default"):
+    def feed_archive(self, file_manifest, vs_path, id="default"):
         self.threadLock.acquire()
         # import uuid
         self.current_id = id
         self.qa_handle, self.kai_path = construct_vector_store(
             vs_id=self.current_id,
+            vs_path=vs_path,
             files=file_manifest,
             sentence_size=100,
             history=[],
@@ -315,15 +306,16 @@ class knowledge_archive_interface():
     def get_current_archive_id(self):
         return self.current_id
 
-    def get_loaded_file(self):
-        return self.qa_handle.get_loaded_file()
+    def get_loaded_file(self, vs_path):
+        return self.qa_handle.get_loaded_file(vs_path)
 
-    def answer_with_archive_by_id(self, txt, id):
+    def answer_with_archive_by_id(self, txt, id, vs_path):
         self.threadLock.acquire()
         if not self.current_id == id:
             self.current_id = id
             self.qa_handle, self.kai_path = construct_vector_store(
                 vs_id=self.current_id,
+                vs_path=vs_path,
                 files=[],
                 sentence_size=100,
                 history=[],
diff --git a/crazy_functions/知识库问答.py b/crazy_functions/知识库问答.py
index 3015328..4898835 100644
--- a/crazy_functions/知识库问答.py
+++ b/crazy_functions/知识库问答.py
@@ -1,9 +1,10 @@
-from toolbox import CatchException, update_ui, ProxyNetworkActivate, update_ui_lastest_msg
+from toolbox import CatchException, update_ui, ProxyNetworkActivate, update_ui_lastest_msg, get_log_folder, get_user
 from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive, get_files_from_everything
 
 install_msg ="""
 pip3 install torch --index-url https://download.pytorch.org/whl/cpu
-pip3 install langchain sentence-transformers unstructured[local-inference] faiss-cpu nltk beautifulsoup4 bitsandbytes tabulate icetk
+pip3 install transformers --upgrade
+pip3 install langchain sentence-transformers unstructured[all-docs] faiss-cpu nltk beautifulsoup4 bitsandbytes tabulate icetk --upgrade
 """
 
 @CatchException
@@ -65,8 +66,9 @@ def 知识库文件注入(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
     print('Establishing knowledge archive ...')
     with ProxyNetworkActivate('Download_LLM'):    # 临时地激活代理网络
         kai = knowledge_archive_interface()
-        kai.feed_archive(file_manifest=file_manifest, id=kai_id)
-        kai_files = kai.get_loaded_file()
+        vs_path = get_log_folder(user=get_user(chatbot), plugin_name='vec_store')
+        kai.feed_archive(file_manifest=file_manifest, vs_path=vs_path, id=kai_id)
+        kai_files = kai.get_loaded_file(vs_path=vs_path)
         kai_files = '<br/>'.join(kai_files)
     # chatbot.append(['知识库构建成功', "正在将知识库存储至cookie中"])
     # yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
@@ -96,7 +98,8 @@ def 读取知识库作答(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
 
     if ("advanced_arg" in plugin_kwargs) and (plugin_kwargs["advanced_arg"] == ""): plugin_kwargs.pop("advanced_arg")
     kai_id = plugin_kwargs.get("advanced_arg", 'default')
-    resp, prompt = kai.answer_with_archive_by_id(txt, kai_id)
+    vs_path = get_log_folder(user=get_user(chatbot), plugin_name='vec_store')
+    resp, prompt = kai.answer_with_archive_by_id(txt, kai_id, vs_path)
     chatbot.append((txt, f'[知识库 {kai_id}] ' + prompt))
     yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
     # 由于请求gpt需要一段时间,我们先及时地做一次界面更新
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 346f58f..c87908f 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -49,18 +49,18 @@ class VoidTerminal():
     pass
 
 vt = VoidTerminal()
-vt.get_conf = (get_conf)
-vt.set_conf = (set_conf)
-vt.set_multi_conf = (set_multi_conf)
-vt.get_plugin_handle = (get_plugin_handle)
-vt.get_plugin_default_kwargs = (get_plugin_default_kwargs)
-vt.get_chat_handle = (get_chat_handle)
-vt.get_chat_default_kwargs = (get_chat_default_kwargs)
+vt.get_conf = silence_stdout_fn(get_conf)
+vt.set_conf = silence_stdout_fn(set_conf)
+vt.set_multi_conf = silence_stdout_fn(set_multi_conf)
+vt.get_plugin_handle = silence_stdout_fn(get_plugin_handle)
+vt.get_plugin_default_kwargs = silence_stdout_fn(get_plugin_default_kwargs)
+vt.get_chat_handle = silence_stdout_fn(get_chat_handle)
+vt.get_chat_default_kwargs = silence_stdout_fn(get_chat_default_kwargs)
 vt.chat_to_markdown_str = (chat_to_markdown_str)
 proxies, WEB_PORT, LLM_MODEL, CONCURRENT_COUNT, AUTHENTICATION, CHATBOT_HEIGHT, LAYOUT, API_KEY = \
     vt.get_conf('proxies', 'WEB_PORT', 'LLM_MODEL', 'CONCURRENT_COUNT', 'AUTHENTICATION', 'CHATBOT_HEIGHT', 'LAYOUT', 'API_KEY')
 
-def plugin_test(main_input, plugin, advanced_arg=None):
+def plugin_test(main_input, plugin, advanced_arg=None, debug=True):
     from rich.live import Live
     from rich.markdown import Markdown
 
@@ -72,7 +72,10 @@ def plugin_test(main_input, plugin, advanced_arg=None):
     plugin_kwargs['main_input'] = main_input
     if advanced_arg is not None:
         plugin_kwargs['plugin_kwargs'] = advanced_arg
-    my_working_plugin = silence_stdout(plugin)(**plugin_kwargs)
+    if debug:
+        my_working_plugin = (plugin)(**plugin_kwargs)
+    else:
+        my_working_plugin = silence_stdout(plugin)(**plugin_kwargs)
 
     with Live(Markdown(""), auto_refresh=False, vertical_overflow="visible") as live:
         for cookies, chat, hist, msg in my_working_plugin:
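
# ---------------------------------------------------------------------------
# Usage sketch (not part of the patch above): the removed VS_ROOT_PATH and
# UPLOAD_ROOT_PATH module constants are replaced by an explicit per-user
# vs_path, so callers now decide where the FAISS store lives. A minimal
# sketch of the new call flow, mirroring the patched 知识库问答.py; the user
# name, file path, and question below are illustrative only — inside a real
# plugin the user comes from get_user(chatbot) as the diff shows.
from toolbox import get_log_folder
from crazy_functions.vector_fns.vector_database import knowledge_archive_interface

kai = knowledge_archive_interface()
vs_path = get_log_folder(user='test_user', plugin_name='vec_store')              # per-user store location
kai.feed_archive(file_manifest=['./README.md'], vs_path=vs_path, id='default')   # copy files, build/extend the FAISS index
resp, prompt = kai.answer_with_archive_by_id('What is this repo?', 'default', vs_path)  # retrieve context, compose prompt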