liunux4odoo d0846f88cc - pydantic 限定为 v1,并统一项目中所有 pydantic 导入路径,为以后升级 v2 做准备
- 重构 api.py:
    - 按模块划分为不同的 router
    - 添加 openai 兼容的转发接口,项目默认使用该接口以实现模型负载均衡
    - 添加 /tools 接口,可以获取/调用编写的 agent tools
    - 移除所有 EmbeddingFuncAdapter,统一改用 get_Embeddings
    - 待办:
        - /chat/chat 接口改为 openai 兼容
        - 添加 /chat/kb_chat 接口,openai 兼容
        - 改变 nltk/knowledge_base/logs 等数据目录位置
2024-03-06 13:51:34 +08:00

125 lines
4.4 KiB
Python

import os
import shutil
from configs import SCORE_THRESHOLD
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
from server.utils import get_Embeddings
from langchain.docstore.document import Document
from typing import List, Dict, Optional, Tuple
class FaissKBService(KBService):
    """Knowledge-base service backed by a FAISS vector store.

    Vector stores are obtained from the shared ``kb_faiss_pool`` and are
    always used through ``ThreadSafeFaiss.acquire()``.
    """

    # Directory holding the persisted vector store; set in do_init().
    vs_path: str
    # Root directory of this knowledge base; set in do_init().
    kb_path: str
    # Name of the vector store; do_init() falls back to the embed model name.
    vector_name: Optional[str] = None

    def vs_type(self) -> str:
        """Return the vector-store type handled by this service (FAISS)."""
        return SupportedVSType.FAISS

    def get_vs_path(self):
        """Return the vector-store path for this KB and vector name."""
        # Inside a method the bare name resolves to the module-level helper
        # imported from server.knowledge_base.utils, not to this method.
        return get_vs_path(self.kb_name, self.vector_name)

    def get_kb_path(self):
        """Return the root path of this knowledge base."""
        # Same name resolution as in get_vs_path(): this is the imported helper.
        return get_kb_path(self.kb_name)

    def load_vector_store(self) -> ThreadSafeFaiss:
        """Fetch this KB's vector store from the shared FAISS pool."""
        return kb_faiss_pool.load_vector_store(
            kb_name=self.kb_name,
            vector_name=self.vector_name,
            embed_model=self.embed_model,
        )

    def save_vector_store(self):
        """Persist the pooled vector store to ``self.vs_path``."""
        self.load_vector_store().save(self.vs_path)

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        """Look up documents in the docstore by id.

        The result has one entry per requested id, in the same order; an id
        that is not present in the docstore yields ``None`` at its position.
        """
        with self.load_vector_store().acquire() as vs:
            return [vs.docstore._dict.get(doc_id) for doc_id in ids]

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """Delete the given ids from the vector store.

        Returns ``True`` once the delete call has completed without raising.
        Any exception raised by the store's ``delete`` propagates to the caller.
        """
        with self.load_vector_store().acquire() as vs:
            vs.delete(ids)
        # The signature promises a bool; previously the method fell off the
        # end and returned None.
        return True

    def do_init(self):
        """Resolve the vector name and the KB / vector-store paths."""
        self.vector_name = self.vector_name or self.embed_model
        self.kb_path = self.get_kb_path()
        self.vs_path = self.get_vs_path()

    def do_create_kb(self):
        """Ensure the vector-store directory exists and load the store."""
        # exist_ok avoids the check-then-create race and matches do_clear_vs().
        os.makedirs(self.vs_path, exist_ok=True)
        self.load_vector_store()

    def do_drop_kb(self):
        """Clear the vector store, then remove the KB directory."""
        self.clear_vs()
        try:
            shutil.rmtree(self.kb_path)
        except Exception:
            # Deliberately best-effort: a missing or undeletable directory
            # must not make dropping the KB fail.
            pass

    def do_search(self,
                  query: str,
                  top_k: int,
                  score_threshold: float = SCORE_THRESHOLD,
                  ) -> List[Tuple[Document, float]]:
        """Return up to ``top_k`` (document, score) pairs matching ``query``.

        ``score_threshold`` is forwarded unchanged to the vector store's
        ``similarity_search_with_score_by_vector``.
        """
        with self.load_vector_store().acquire() as vs:
            # Embed the query once, with the store's own embedding object
            # (the same one do_add_doc() uses). An earlier, separate embedding
            # call was dropped because its result was overwritten before use.
            query_vector = vs.embeddings.embed_query(query)
            docs = vs.similarity_search_with_score_by_vector(
                query_vector, k=top_k, score_threshold=score_threshold)
        return docs

    def do_add_doc(self,
                   docs: List[Document],
                   **kwargs,
                   ) -> List[Dict]:
        """Embed ``docs``, add them to the store and report the assigned ids.

        Keyword arguments:
            not_refresh_vs_cache: when truthy, skip saving the store to disk.

        Returns a list of ``{"id": ..., "metadata": ...}`` dicts, one per
        (id, document) pair.
        """
        texts = [doc.page_content for doc in docs]
        metadatas = [doc.metadata for doc in docs]
        with self.load_vector_store().acquire() as vs:
            embeddings = vs.embeddings.embed_documents(texts)
            ids = vs.add_embeddings(text_embeddings=zip(texts, embeddings),
                                    metadatas=metadatas)
            if not kwargs.get("not_refresh_vs_cache"):
                vs.save_local(self.vs_path)
        return [{"id": doc_id, "metadata": doc.metadata}
                for doc_id, doc in zip(ids, docs)]

    def do_delete_doc(self,
                      kb_file: KnowledgeFile,
                      **kwargs):
        """Delete every stored chunk whose ``source`` matches ``kb_file``.

        The comparison against ``kb_file.filename`` is case-insensitive.

        Keyword arguments:
            not_refresh_vs_cache: when truthy, skip saving the store to disk.

        Returns the list of docstore ids that matched (possibly empty).
        """
        # Loop-invariant, so compute it once instead of per docstore entry.
        target = kb_file.filename.lower()
        with self.load_vector_store().acquire() as vs:
            # `or ""` keeps entries without a "source" key from raising
            # AttributeError on None.lower(); they simply do not match.
            ids = [key for key, doc in vs.docstore._dict.items()
                   if (doc.metadata.get("source") or "").lower() == target]
            if ids:
                vs.delete(ids)
            if not kwargs.get("not_refresh_vs_cache"):
                vs.save_local(self.vs_path)
        return ids

    def do_clear_vs(self):
        """Evict the store from the pool and recreate an empty store directory."""
        with kb_faiss_pool.atomic:
            kb_faiss_pool.pop((self.kb_name, self.vector_name))
        try:
            shutil.rmtree(self.vs_path)
        except Exception:
            # Deliberately best-effort: the directory may not exist yet.
            ...
        os.makedirs(self.vs_path, exist_ok=True)

    def exist_doc(self, file_name: str):
        """Report where ``file_name`` is known.

        Returns ``"in_db"`` when the parent class reports the document as
        existing, ``"in_folder"`` when only a file of that name exists under
        ``<kb_path>/content``, and ``False`` otherwise.
        """
        if super().exist_doc(file_name):
            return "in_db"

        if os.path.isfile(os.path.join(self.kb_path, "content", file_name)):
            return "in_folder"
        return False
if __name__ == '__main__':
    # Manual smoke test against a knowledge base named "test":
    # add a file, delete it, drop the KB, then run one query.
    service = FaissKBService("test")

    file_to_add = KnowledgeFile("README.md", "test")
    service.add_doc(file_to_add)

    file_to_delete = KnowledgeFile("README.md", "test")
    service.delete_doc(file_to_delete)

    service.do_drop_kb()

    # NOTE(review): the search runs after the KB has been dropped; presumably
    # intentional for this smoke test -- confirm.
    # The query text means "how to start the api service".
    hits = service.search_docs("如何启动api服务")
    print(hits)