fix local vector store bug

Replaces the repo-global vector_store/ and content/ folders with a caller-supplied vs_path (resolved per user via get_log_folder), fixes the FAISS query-embedding call, and adds a debug switch to the plugin test harness.

parent 8a6e96c369
commit 7bac8f4bd3
@@ -26,10 +26,6 @@ EMBEDDING_MODEL = "text2vec"
 # Embedding running device
 EMBEDDING_DEVICE = "cpu"
 
-VS_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "vector_store")
-
-UPLOAD_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "content")
-
 # 基于上下文的prompt模版,请务必保留"{question}"和"{context}"
 PROMPT_TEMPLATE = """已知信息:
 {context}
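The two module-level path constants disappear because storage now follows a caller-supplied vs_path (see the construct_vector_store hunk further down). For reference, a sketch of what the removed lines evaluated to; the module path is a stand-in, not the real file name:

```python
# What the removed constants resolved to: sibling folders of the module's
# parent directory. '/repo/pkg/module.py' is a stand-in for __file__.
import os

module_file = '/repo/pkg/module.py'
root = os.path.dirname(os.path.dirname(module_file))  # '/repo'
VS_ROOT_PATH = os.path.join(root, "vector_store")     # '/repo/vector_store'
UPLOAD_ROOT_PATH = os.path.join(root, "content")      # '/repo/content'
```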
@@ -159,7 +155,7 @@ class LocalDocQA:
         elif os.path.isfile(filepath):
             file = os.path.split(filepath)[-1]
             try:
-                docs = load_file(filepath, sentence_size)
+                docs = load_file(filepath, SENTENCE_SIZE)
                 print(f"{file} 已成功加载")
                 loaded_files.append(filepath)
             except Exception as e:
@@ -171,7 +167,7 @@ class LocalDocQA:
             for file in tqdm(os.listdir(filepath), desc="加载文件"):
                 fullfilepath = os.path.join(filepath, file)
                 try:
-                    docs += load_file(fullfilepath, sentence_size)
+                    docs += load_file(fullfilepath, SENTENCE_SIZE)
                     loaded_files.append(fullfilepath)
                 except Exception as e:
                     print(e)
@@ -185,21 +181,19 @@ class LocalDocQA:
         else:
             docs = []
             for file in filepath:
-                try:
-                    docs += load_file(file)
-                    print(f"{file} 已成功加载")
-                    loaded_files.append(file)
-                except Exception as e:
-                    print(e)
-                    print(f"{file} 未能成功加载")
+                docs += load_file(file, SENTENCE_SIZE)
+                print(f"{file} 已成功加载")
+                loaded_files.append(file)
 
         if len(docs) > 0:
             print("文件加载完毕,正在生成向量库")
             if vs_path and os.path.isdir(vs_path):
-                self.vector_store = FAISS.load_local(vs_path, text2vec)
-                self.vector_store.add_documents(docs)
+                try:
+                    self.vector_store = FAISS.load_local(vs_path, text2vec)
+                    self.vector_store.add_documents(docs)
+                except:
+                    self.vector_store = FAISS.from_documents(docs, text2vec)
             else:
-                if not vs_path: assert False
                 self.vector_store = FAISS.from_documents(docs, text2vec)  # docs 为Document列表
 
             self.vector_store.save_local(vs_path)
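The new try/except turns index creation into a load-or-create pattern: extend an existing on-disk index when it loads cleanly, otherwise rebuild from scratch. A minimal sketch of that pattern, assuming the era's langchain package and the same names as the hunk (`docs` a list of Documents, `text2vec` an Embeddings object, `vs_path` a directory):

```python
from langchain.vectorstores import FAISS

def load_or_create_index(docs, text2vec, vs_path):
    try:
        # Extending an existing index keeps previously embedded documents.
        vector_store = FAISS.load_local(vs_path, text2vec)
        vector_store.add_documents(docs)
    except Exception:
        # No usable index on disk yet (or it is incompatible): build anew.
        vector_store = FAISS.from_documents(docs, text2vec)
    vector_store.save_local(vs_path)
    return vector_store
```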
@@ -208,9 +202,9 @@ class LocalDocQA:
             self.vector_store = FAISS.load_local(vs_path, text2vec)
         return vs_path, loaded_files
 
-    def get_loaded_file(self):
+    def get_loaded_file(self, vs_path):
         ds = self.vector_store.docstore
-        return set([ds._dict[k].metadata['source'].split(UPLOAD_ROOT_PATH)[-1] for k in ds._dict])
+        return set([ds._dict[k].metadata['source'].split(vs_path)[-1] for k in ds._dict])
 
 
     # query 查询内容
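get_loaded_file now relativizes stored source paths against the caller's vs_path instead of the removed UPLOAD_ROOT_PATH constant. The split works because each Document's metadata['source'] records the absolute path of the file copied under the store. A self-contained illustration with made-up paths:

```python
# Each docstore entry's metadata['source'] sits below the store directory,
# so splitting on vs_path leaves the store-relative tail.
source = '/logs/alice/vec_store/default/readme.md'  # hypothetical
vs_path = '/logs/alice/vec_store'                   # hypothetical
print(source.split(vs_path)[-1])                    # -> '/default/readme.md'
```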
@@ -228,7 +222,7 @@ class LocalDocQA:
         self.vector_store.score_threshold = score_threshold
         self.vector_store.chunk_size = chunk_size
 
-        embedding = self.vector_store.embedding_function(query)
+        embedding = self.vector_store.embedding_function.embed_query(query)
         related_docs_with_score = similarity_search_with_score_by_vector(self.vector_store, embedding, k=vector_search_top_k)
 
         if not related_docs_with_score:
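This is the core of the bug fix: in the langchain versions targeted here, the FAISS wrapper's embedding_function attribute holds the whole Embeddings object rather than a bare callable, so the query vector must come from its embed_query() method. A runnable stand-in (toy vectors; the real code uses the text2vec model):

```python
# Minimal stand-in for the Embeddings object that langchain stores on the
# FAISS wrapper; the fix calls its embed_query() instead of treating it
# as a plain function.
class DummyEmbeddings:
    def embed_documents(self, texts):
        return [[float(len(t)), 0.0] for t in texts]  # toy 2-d vectors
    def embed_query(self, text):
        return [float(len(text)), 0.0]

emb = DummyEmbeddings()
vec = emb.embed_query("向量库")  # the shape of the fixed call
# emb("向量库") would raise TypeError, which is exactly why the old line broke.
```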
@@ -247,27 +241,23 @@
 
 
 
-def construct_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation, text2vec):
+def construct_vector_store(vs_id, vs_path, files, sentence_size, history, one_conent, one_content_segmentation, text2vec):
     for file in files:
         assert os.path.exists(file), "输入文件不存在"
     import nltk
     if NLTK_DATA_PATH not in nltk.data.path: nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
     local_doc_qa = LocalDocQA()
     local_doc_qa.init_cfg()
-    vs_path = os.path.join(VS_ROOT_PATH, vs_id)
     filelist = []
-    if not os.path.exists(os.path.join(UPLOAD_ROOT_PATH, vs_id)):
-        os.makedirs(os.path.join(UPLOAD_ROOT_PATH, vs_id))
-    if isinstance(files, list):
-        for file in files:
-            file_name = file.name if not isinstance(file, str) else file
-            filename = os.path.split(file_name)[-1]
-            shutil.copyfile(file_name, os.path.join(UPLOAD_ROOT_PATH, vs_id, filename))
-            filelist.append(os.path.join(UPLOAD_ROOT_PATH, vs_id, filename))
-        vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, vs_path, sentence_size, text2vec)
-    else:
-        vs_path, loaded_files = local_doc_qa.one_knowledge_add(vs_path, files, one_conent, one_content_segmentation,
-                                                               sentence_size, text2vec)
+    if not os.path.exists(os.path.join(vs_path, vs_id)):
+        os.makedirs(os.path.join(vs_path, vs_id))
+    for file in files:
+        file_name = file.name if not isinstance(file, str) else file
+        filename = os.path.split(file_name)[-1]
+        shutil.copyfile(file_name, os.path.join(vs_path, vs_id, filename))
+        filelist.append(os.path.join(vs_path, vs_id, filename))
+    vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(filelist, os.path.join(vs_path, vs_id), sentence_size, text2vec)
+
     if len(loaded_files):
         file_status = f"已添加 {'、'.join([os.path.split(i)[-1] for i in loaded_files if i])} 内容至知识库,并已加载知识库,请开始提问"
     else:
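With this change the uploaded copies and the FAISS index both land under os.path.join(vs_path, vs_id) instead of the repo-global content/ and vector_store/ folders. A sketch of the resulting layout, with illustrative values; the index file names are what langchain's FAISS.save_local writes:

```python
import os

vs_path, vs_id = 'gpt_log/alice/vec_store', 'default'  # hypothetical values
store_dir = os.path.join(vs_path, vs_id)
# store_dir/readme.md    <- shutil.copyfile target for each upload
# store_dir/index.faiss  <- written by FAISS.save_local
# store_dir/index.pkl    <- docstore + id mapping, also from save_local
```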
@@ -297,12 +287,13 @@ class knowledge_archive_interface():
         return self.text2vec_large_chinese
 
 
-    def feed_archive(self, file_manifest, id="default"):
+    def feed_archive(self, file_manifest, vs_path, id="default"):
         self.threadLock.acquire()
         # import uuid
         self.current_id = id
         self.qa_handle, self.kai_path = construct_vector_store(
             vs_id=self.current_id,
+            vs_path=vs_path,
             files=file_manifest,
             sentence_size=100,
             history=[],
@@ -315,15 +306,16 @@ class knowledge_archive_interface():
     def get_current_archive_id(self):
         return self.current_id
 
-    def get_loaded_file(self):
-        return self.qa_handle.get_loaded_file()
+    def get_loaded_file(self, vs_path):
+        return self.qa_handle.get_loaded_file(vs_path)
 
-    def answer_with_archive_by_id(self, txt, id):
+    def answer_with_archive_by_id(self, txt, id, vs_path):
         self.threadLock.acquire()
         if not self.current_id == id:
             self.current_id = id
             self.qa_handle, self.kai_path = construct_vector_store(
                 vs_id=self.current_id,
+                vs_path=vs_path,
                 files=[],
                 sentence_size=100,
                 history=[],
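Every public entry point of knowledge_archive_interface now threads vs_path through to construct_vector_store. A hedged usage sketch of the new signatures, mirroring the calls in the plugin hunks below; the path and file names are illustrative:

```python
kai = knowledge_archive_interface()
vs_path = 'gpt_log/alice/vec_store'  # hypothetical; normally from get_log_folder
kai.feed_archive(file_manifest=['a.md'], vs_path=vs_path, id='default')
loaded = kai.get_loaded_file(vs_path=vs_path)  # store-relative source names
resp, prompt = kai.answer_with_archive_by_id('question', 'default', vs_path)
```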
@@ -1,9 +1,10 @@
-from toolbox import CatchException, update_ui, ProxyNetworkActivate, update_ui_lastest_msg
+from toolbox import CatchException, update_ui, ProxyNetworkActivate, update_ui_lastest_msg, get_log_folder, get_user
 from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive, get_files_from_everything
 
 install_msg ="""
 pip3 install torch --index-url https://download.pytorch.org/whl/cpu
-pip3 install langchain sentence-transformers unstructured[local-inference] faiss-cpu nltk beautifulsoup4 bitsandbytes tabulate icetk
+pip3 install transformers --upgrade
+pip3 install langchain sentence-transformers unstructured[all-docs] faiss-cpu nltk beautifulsoup4 bitsandbytes tabulate icetk --upgrade
 """
 
 @CatchException
@@ -65,8 +66,9 @@ def 知识库文件注入(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
     print('Establishing knowledge archive ...')
     with ProxyNetworkActivate('Download_LLM'): # 临时地激活代理网络
         kai = knowledge_archive_interface()
-        kai.feed_archive(file_manifest=file_manifest, id=kai_id)
-        kai_files = kai.get_loaded_file()
+        vs_path = get_log_folder(user=get_user(chatbot), plugin_name='vec_store')
+        kai.feed_archive(file_manifest=file_manifest, vs_path=vs_path, id=kai_id)
+        kai_files = kai.get_loaded_file(vs_path=vs_path)
         kai_files = '<br/>'.join(kai_files)
         # chatbot.append(['知识库构建成功', "正在将知识库存储至cookie中"])
         # yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
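The vector store now lives inside the per-user log folder, so concurrent users no longer share one global index. The exact directory shape depends on the repo's toolbox configuration; a stand-in sketch of what the call plausibly resolves to (root and layout are assumptions, not confirmed by this diff):

```python
import os

def get_log_folder_sketch(user, plugin_name):
    # Assumed layout: <log root>/<user>/<plugin_name>. The real helper is
    # toolbox.get_log_folder and may differ in root and normalization.
    folder = os.path.join('gpt_log', user, plugin_name)
    os.makedirs(folder, exist_ok=True)
    return folder

vs_path = get_log_folder_sketch('alice', 'vec_store')  # 'gpt_log/alice/vec_store'
```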
@@ -96,7 +98,8 @@ def 读取知识库作答(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
 
     if ("advanced_arg" in plugin_kwargs) and (plugin_kwargs["advanced_arg"] == ""): plugin_kwargs.pop("advanced_arg")
     kai_id = plugin_kwargs.get("advanced_arg", 'default')
-    resp, prompt = kai.answer_with_archive_by_id(txt, kai_id)
+    vs_path = get_log_folder(user=get_user(chatbot), plugin_name='vec_store')
+    resp, prompt = kai.answer_with_archive_by_id(txt, kai_id, vs_path)
 
     chatbot.append((txt, f'[知识库 {kai_id}] ' + prompt))
     yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 # 由于请求gpt需要一段时间,我们先及时地做一次界面更新
@@ -49,18 +49,18 @@ class VoidTerminal():
     pass
 
 vt = VoidTerminal()
-vt.get_conf = (get_conf)
-vt.set_conf = (set_conf)
-vt.set_multi_conf = (set_multi_conf)
-vt.get_plugin_handle = (get_plugin_handle)
-vt.get_plugin_default_kwargs = (get_plugin_default_kwargs)
-vt.get_chat_handle = (get_chat_handle)
-vt.get_chat_default_kwargs = (get_chat_default_kwargs)
+vt.get_conf = silence_stdout_fn(get_conf)
+vt.set_conf = silence_stdout_fn(set_conf)
+vt.set_multi_conf = silence_stdout_fn(set_multi_conf)
+vt.get_plugin_handle = silence_stdout_fn(get_plugin_handle)
+vt.get_plugin_default_kwargs = silence_stdout_fn(get_plugin_default_kwargs)
+vt.get_chat_handle = silence_stdout_fn(get_chat_handle)
+vt.get_chat_default_kwargs = silence_stdout_fn(get_chat_default_kwargs)
 vt.chat_to_markdown_str = (chat_to_markdown_str)
 proxies, WEB_PORT, LLM_MODEL, CONCURRENT_COUNT, AUTHENTICATION, CHATBOT_HEIGHT, LAYOUT, API_KEY = \
     vt.get_conf('proxies', 'WEB_PORT', 'LLM_MODEL', 'CONCURRENT_COUNT', 'AUTHENTICATION', 'CHATBOT_HEIGHT', 'LAYOUT', 'API_KEY')
 
-def plugin_test(main_input, plugin, advanced_arg=None):
+def plugin_test(main_input, plugin, advanced_arg=None, debug=True):
     from rich.live import Live
     from rich.markdown import Markdown
 
@@ -72,7 +72,10 @@ def plugin_test(main_input, plugin, advanced_arg=None):
     plugin_kwargs['main_input'] = main_input
     if advanced_arg is not None:
         plugin_kwargs['plugin_kwargs'] = advanced_arg
-    my_working_plugin = silence_stdout(plugin)(**plugin_kwargs)
+    if debug:
+        my_working_plugin = (plugin)(**plugin_kwargs)
+    else:
+        my_working_plugin = silence_stdout(plugin)(**plugin_kwargs)
 
     with Live(Markdown(""), auto_refresh=False, vertical_overflow="visible") as live:
         for cookies, chat, hist, msg in my_working_plugin:
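The test harness gains a debug switch: with debug=True the plugin generator runs unwrapped so its prints reach the console, otherwise it stays muted. The repo's silence_stdout/silence_stdout_fn helpers are defined elsewhere; a hedged stand-in showing the kind of wrapper being applied (plain-function flavor only; the generator case in the hunk needs the redirection kept alive while the generator is consumed):

```python
import contextlib
import functools
import io

def silence_stdout_fn(fn):
    # Run fn with stdout discarded; illustrative, not the repo's helper.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        with contextlib.redirect_stdout(io.StringIO()):
            return fn(*args, **kwargs)
    return wrapper

@silence_stdout_fn
def chatty():
    print("this never reaches the console")
    return 42

assert chatty() == 42
```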