liunux4odoo 9ff7bef2c2
新功能:知识库管理界面支持查看、编辑、删除向量库文档 (#2471)
* 新功能:
- 知识库管理界面支持查看、编辑、删除向量库文档。暂不支持新增文档(aggrid 添加新行比较麻烦,需要另外实现)
- 去除知识库管理界面中重建知识库和删除知识库按钮,建议用户改用终端命令操作

修复:
- 所有与知识库名称、文件名称有关的数据库操作函数都改成大小写不敏感,所有路径统一为 posix 风格,避免因路径文本不一致导致数据重复和操作失效 (close #2232)

开发者:
- 添加 update_docs_by_id 函数与 API 接口。当前仅支持 FAISS,暂时未用到,为将来对知识库做更细致的修改做准备
- 统一 DocumentWithScore 与 DocumentWithVsId
- FAISS 返回的 Document.metadata 中包含 ID, 方便后续查找比对
- /knowledge_base/search_docs 接口支持 file_name, metadata 参数,可以据此检索文档

* fix bug
2023-12-26 13:44:36 +08:00

178 lines
6.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from configs import CACHED_VS_NUM, CACHED_MEMO_VS_NUM
from server.knowledge_base.kb_cache.base import *
from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
from server.utils import load_local_embeddings
from server.knowledge_base.utils import get_vs_path
from langchain.vectorstores.faiss import FAISS
from langchain.docstore.in_memory import InMemoryDocstore
from langchain.schema import Document
import os
from langchain.schema import Document
# Patch FAISS's in-memory docstore so every Document it hands back carries its
# docstore ID in metadata["id"], which makes later lookup/comparison by ID easy.
def _new_ds_search(self, search: str) -> Union[str, Document]:
    """Replacement for ``InMemoryDocstore.search`` that tags hits with their ID.

    Returns the object stored under ``search``. If that object is a
    ``Document``, its ``metadata["id"]`` is first set to ``search``; the stored
    instance itself is modified in place. For an unknown ID the string
    ``"ID <search> not found."`` is returned instead of raising.
    """
    entries = self._dict
    if search in entries:
        hit = entries[search]
        if isinstance(hit, Document):
            hit.metadata["id"] = search
        return hit
    return f"ID {search} not found."


InMemoryDocstore.search = _new_ds_search
class ThreadSafeFaiss(ThreadSafeObject):
    """A ``ThreadSafeObject`` wrapping a langchain ``FAISS`` vector store.

    Adds FAISS-specific helpers (document count, saving to disk, clearing).
    The mutating helpers run inside ``self.acquire()``.
    """

    def __repr__(self) -> str:
        cls = type(self).__name__
        # The pools assign ``obj`` only after loading finishes, so the wrapped
        # store may not exist yet (or at all, if loading failed). repr() is
        # used in logs and debuggers and must not raise in that window.
        try:
            count = self.docs_count()
        except AttributeError:
            count = None
        return f"<{cls}: key: {self.key}, obj: {self._obj}, docs_count: {count}>"

    def docs_count(self) -> int:
        """Return the number of entries in the wrapped store's docstore."""
        return len(self._obj.docstore._dict)

    def save(self, path: str, create_path: bool = True):
        """Persist the wrapped store to ``path`` via ``FAISS.save_local``.

        Args:
            path: target directory.
            create_path: create ``path`` (including parents) if it is missing.

        Returns:
            Whatever ``save_local`` returns.
        """
        with self.acquire():
            if create_path:
                # exist_ok replaces an isdir()-then-makedirs() check that could
                # race with another writer creating the same directory. An
                # existing non-directory at ``path`` still raises FileExistsError.
                os.makedirs(path, exist_ok=True)
            ret = self._obj.save_local(path)
            logger.info(f"已将向量库 {self.key} 保存到磁盘")
        return ret

    def clear(self):
        """Delete every document from the wrapped store.

        Returns:
            The result of ``FAISS.delete``, or an empty list when the store
            was already empty.
        """
        ret = []
        with self.acquire():
            ids = list(self._obj.docstore._dict.keys())
            if ids:
                ret = self._obj.delete(ids)
                # Internal post-condition (not input validation): nothing left.
                assert len(self._obj.docstore._dict) == 0
            logger.info(f"已将向量库 {self.key} 清空")
        return ret
class _FaissPool(CachePool):
    """Shared FAISS helpers for the knowledge-base and in-memory store pools."""

    def new_vector_store(
        self,
        embed_model: str = EMBEDDING_MODEL,
        embed_device: str = embedding_device(),
    ) -> FAISS:
        """Build an empty FAISS store that uses ``embed_model``'s embeddings.

        The store is seeded with a single placeholder document, which is then
        deleted, leaving an empty index.

        NOTE(review): ``embed_device`` is accepted but not used here; only
        ``embed_model`` reaches ``EmbeddingsFunAdapter``. Confirm intent.
        NOTE(review): ``"METRIC_INNER_PRODUCT"`` does not appear to be a
        langchain ``DistanceStrategy`` value (that enum uses
        ``"MAX_INNER_PRODUCT"``). If so, FAISS builds an L2 index here. Do not
        change it without checking score semantics of existing indexes.
        """
        # TODO: the whole Embeddings loading logic is somewhat messy and needs cleanup.
        adapter = EmbeddingsFunAdapter(embed_model)
        store = FAISS.from_documents(
            [Document(page_content="init", metadata={})],
            adapter,
            normalize_L2=True,
            distance_strategy="METRIC_INNER_PRODUCT",
        )
        store.delete(list(store.docstore._dict.keys()))
        return store

    def save_vector_store(self, kb_name: str, path: str=None):
        """Save the cached store under pool key ``kb_name`` to ``path``.

        Returns ``None`` without doing anything when the key is not cached.
        NOTE(review): the ``path=None`` default is forwarded unchanged to
        ``ThreadSafeFaiss.save``, which treats it as a directory. A real path
        appears to be required; confirm before relying on the default.
        """
        cached = self.get(kb_name)
        if not cached:
            return None
        return cached.save(path)

    def unload_vector_store(self, kb_name: str):
        """Drop pool key ``kb_name`` if it is cached, and log the release.

        For ``KBFaissPool`` the pool key is a ``(kb_name, vector_name)``
        tuple, so that tuple must be passed here.
        """
        if not self.get(kb_name):
            return
        self.pop(kb_name)
        logger.info(f"成功释放向量库:{kb_name}")
class KBFaissPool(_FaissPool):
    """Pool of on-disk knowledge-base FAISS stores keyed by ``(kb_name, vector_name)``."""

    def load_vector_store(
        self,
        kb_name: str,
        vector_name: str = None,
        create: bool = True,
        embed_model: str = EMBEDDING_MODEL,
        embed_device: str = embedding_device(),
    ) -> ThreadSafeFaiss:
        """Return the pooled store for ``kb_name``, loading or creating it on first use.

        Args:
            kb_name: knowledge base name.
            vector_name: vector store sub-directory name; defaults to ``embed_model``.
            create: when no ``index.faiss`` exists on disk, create and save an
                empty store instead of raising.
            embed_model: embedding model used to load or create the store.
            embed_device: device forwarded to the embedding loaders.

        Raises:
            RuntimeError: the store is absent on disk and ``create`` is False.
            Any other error from loading or creating the store is re-raised
            after the half-initialized pool entry has been removed.
        """
        self.atomic.acquire()
        vector_name = vector_name or embed_model
        key = (kb_name, vector_name)  # a tuple key is more robust than a concatenated string
        cache = self.get(key)
        if cache is None:
            item = ThreadSafeFaiss(key, pool=self)
            self.set(key, item)
            # The item lock is taken before the pool lock is released, so the
            # loader holds the entry from the moment it is published.
            with item.acquire(msg="初始化"):
                self.atomic.release()
                try:
                    item.obj = self._load_or_create(
                        kb_name, vector_name, create, embed_model, embed_device
                    )
                except BaseException:
                    # Without this the entry would stay in the pool with no
                    # store, and every later call would get it back instead of
                    # retrying the load.
                    self.pop(key)
                    raise
                finally:
                    # Always signal completion so nothing waiting on this entry
                    # (presumably via get()) stays blocked after a failure.
                    item.finish_loading()
        else:
            self.atomic.release()
        return self.get(key)

    def _load_or_create(
        self,
        kb_name: str,
        vector_name: str,
        create: bool,
        embed_model: str,
        embed_device: str,
    ) -> FAISS:
        """Load the store from disk, or create and save an empty one when ``create`` is set."""
        logger.info(f"loading vector store in '{kb_name}/vector_store/{vector_name}' from disk.")
        vs_path = get_vs_path(kb_name, vector_name)
        if os.path.isfile(os.path.join(vs_path, "index.faiss")):
            embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device, default_embed_model=embed_model)
            return FAISS.load_local(vs_path, embeddings, normalize_L2=True, distance_strategy="METRIC_INNER_PRODUCT")
        if not create:
            raise RuntimeError(f"knowledge base {kb_name} not exist.")
        # Create an empty vector store and save it so the directory holds an index.
        os.makedirs(vs_path, exist_ok=True)
        vector_store = self.new_vector_store(embed_model=embed_model, embed_device=embed_device)
        vector_store.save_local(vs_path)
        return vector_store
class MemoFaissPool(_FaissPool):
    """Pool of in-memory FAISS stores keyed by ``kb_name`` (nothing is read from disk)."""

    def load_vector_store(
        self,
        kb_name: str,
        embed_model: str = EMBEDDING_MODEL,
        embed_device: str = embedding_device(),
    ) -> ThreadSafeFaiss:
        """Return the pooled store for ``kb_name``, creating an empty one on first use.

        Any error raised while creating the store is re-raised after the
        half-initialized pool entry has been removed, so a later call retries.
        """
        self.atomic.acquire()
        cache = self.get(kb_name)
        if cache is None:
            item = ThreadSafeFaiss(kb_name, pool=self)
            self.set(kb_name, item)
            # Take the item lock before releasing the pool lock (same ordering
            # as KBFaissPool).
            with item.acquire(msg="初始化"):
                self.atomic.release()
                logger.info(f"loading vector store in '{kb_name}' to memory.")
                try:
                    # create an empty vector store
                    item.obj = self.new_vector_store(embed_model=embed_model, embed_device=embed_device)
                except BaseException:
                    # Don't leave an entry with no store in the pool.
                    self.pop(kb_name)
                    raise
                finally:
                    # Always signal completion so nothing stays blocked on this entry.
                    item.finish_loading()
        else:
            self.atomic.release()
        return self.get(kb_name)
# Module-level pool of on-disk knowledge-base stores; cache_num comes from
# CACHED_VS_NUM (presumably the max number of cached stores — confirm in CachePool).
kb_faiss_pool = KBFaissPool(cache_num=CACHED_VS_NUM)
# Module-level pool of in-memory stores, sized from CACHED_MEMO_VS_NUM.
memo_faiss_pool = MemoFaissPool(cache_num=CACHED_MEMO_VS_NUM)
if __name__ == "__main__":
    # Manual concurrency smoke test: many threads add to, search, or clear one
    # knowledge-base store at the same time.
    import random
    import threading
    import time
    from pprint import pprint

    kb_names = ["vs1", "vs2", "vs3"]
    # for name in kb_names:
    #     memo_faiss_pool.load_vector_store(name)

    def worker(vs_name: str, name: str):
        """Sleep briefly, then randomly add to, search, or clear the store."""
        # NOTE(review): the argument is overridden, so every worker targets the
        # "samples" knowledge base regardless of the kb_names choice below.
        vs_name = "samples"
        time.sleep(random.randint(1, 5))
        embeddings = load_local_embeddings()
        r = random.randint(1, 3)
        store = kb_faiss_pool.load_vector_store(vs_name)
        with store.acquire(name) as vs:
            if r == 1:  # add docs
                ids = vs.add_texts([f"text added by {name}"], embeddings=embeddings)
                pprint(ids)
            elif r == 2:  # search docs
                docs = vs.similarity_search_with_score(f"{name}", k=3, score_threshold=1.0)
                pprint(docs)
        if r == 3:  # delete docs
            logger.warning(f"清除 {vs_name} by {name}")
            # kb_faiss_pool keys entries by (kb_name, vector_name), so a lookup
            # by the bare name does not match; clear the handle we already hold.
            # clear() takes the store lock itself, so it runs after the `with`
            # above has released ours.
            store.clear()

    threads = []
    for n in range(1, 30):
        t = threading.Thread(
            target=worker,
            kwargs={"vs_name": random.choice(kb_names), "name": f"worker {n}"},
            daemon=True,
        )
        t.start()
        threads.append(t)
    for t in threads:
        t.join()