liunux4odoo 55e417a263 升级注意
数据库表结构发生变化，需要重建知识库。

 新功能
- 增加FileDocModel库表,存储知识文件与向量库Document ID对应关系以及元数据,便于检索向量库
- 增加FileDocModel对应的数据库操作函数(这些函数主要是给KBService调用,用户一般无需使用):
  - list_docs_from_db: 根据知识库名称、文件名称、元数据检索对应的Document IDs
  - delete_docs_from_db: 根据知识库名称、文件名称删除对应的file-doc映射
  - add_docs_to_db: 添加对应的file-doc映射
- KBService增加list_docs方法,可以根据文件名、元数据检索Document。当前仅支持FAISS,待milvus/pg实现get_doc_by_id方法后即自动支持。
- 去除server.utils对torch的依赖

 待完善
- milvus/pg kb_service需要实现get_doc_by_id方法
2023-09-01 22:54:57 +08:00

86 lines
3.5 KiB
Python

import json
from typing import List, Dict, Optional

from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import PGVector
from langchain.vectorstores.pgvector import DistanceStrategy
from sqlalchemy import text

from configs.model_config import kbs_config
from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
    score_threshold_process
from server.knowledge_base.utils import load_embeddings, KnowledgeFile
from server.utils import embedding_device as get_embedding_device
class PGKBService(KBService):
    """Knowledge-base service backed by PostgreSQL/pgvector through langchain's ``PGVector``.

    Each knowledge base uses a PGVector collection named ``self.kb_name``; the
    connection string is read from ``kbs_config["pg"]["connection_uri"]``.
    """

    # Vector-store handle; (re)built by ``_load_pg_vector``.
    pg_vector: PGVector

    def _load_pg_vector(self, embedding_device: Optional[str] = None,
                        embeddings: Optional[Embeddings] = None):
        """(Re)build ``self.pg_vector`` for this knowledge base.

        Args:
            embedding_device: device handed to ``load_embeddings`` when no
                ``embeddings`` object is given. When ``None`` it is resolved via
                ``server.utils.embedding_device()`` at call time.
            embeddings: a ready embeddings object; when ``None`` one is loaded
                from ``self.embed_model``.
        """
        if embeddings is None:
            # Resolve the device lazily: as a default argument,
            # get_embedding_device() ran only once, at import time.
            if embedding_device is None:
                embedding_device = get_embedding_device()
            embeddings = load_embeddings(self.embed_model, embedding_device)
        self.pg_vector = PGVector(embedding_function=EmbeddingsFunAdapter(embeddings),
                                  collection_name=self.kb_name,
                                  distance_strategy=DistanceStrategy.EUCLIDEAN,
                                  connection_string=kbs_config.get("pg").get("connection_uri"))

    # TODO: not implemented for PostgreSQL yet. Until it is, this always returns
    # None, so KBService.list_docs cannot return Documents for this backend.
    def get_doc_by_id(self, id: str) -> Optional[Document]:
        """Return the Document stored under ``id`` (not implemented: always ``None``)."""
        return None

    def do_init(self):
        """Initialise the vector store using the default embeddings."""
        self._load_pg_vector()

    def do_create_kb(self):
        """No explicit creation step for this backend (intentional no-op)."""
        pass

    def vs_type(self) -> str:
        """Return the vector-store type identifier for PostgreSQL."""
        return SupportedVSType.PG

    def do_drop_kb(self):
        """Delete this knowledge base's embeddings and its collection record."""
        # kb_name is bound as a parameter instead of being formatted into the SQL,
        # so names containing quotes can neither break nor inject into the statement.
        params = {"name": self.kb_name}
        with self.pg_vector.connect() as connect:
            # First remove embeddings that belong to this knowledge base's collection...
            connect.execute(text(
                "DELETE FROM langchain_pg_embedding "
                "WHERE collection_id IN ("
                "SELECT uuid FROM langchain_pg_collection WHERE name = :name"
                ")"
            ), params)
            # ...then the collection record itself.
            connect.execute(text(
                "DELETE FROM langchain_pg_collection WHERE name = :name"
            ), params)
            connect.commit()

    def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
        """Return up to ``top_k`` (Document, score) hits for ``query``, filtered by ``score_threshold``.

        Rebuilds ``self.pg_vector`` with the caller-supplied ``embeddings`` first.
        """
        self._load_pg_vector(embeddings=embeddings)
        return score_threshold_process(score_threshold, top_k,
                                       self.pg_vector.similarity_search_with_score(query, top_k))

    def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
        """Store ``docs`` and return ``[{"id": <vector-store id>, "metadata": <doc metadata>}, ...]``."""
        ids = self.pg_vector.add_documents(docs)
        return [{"id": doc_id, "metadata": doc.metadata} for doc_id, doc in zip(ids, docs)]

    def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
        """Delete every embedding whose ``cmetadata`` contains ``{"source": kb_file.filepath}``."""
        # Build the JSON filter with json.dumps and bind it as a parameter. Hand-escaping
        # only backslashes produced invalid JSON/SQL for paths containing quotes, and
        # spliced the raw path into the statement.
        source_filter = json.dumps({"source": kb_file.filepath}, ensure_ascii=False)
        with self.pg_vector.connect() as connect:
            connect.execute(
                text("DELETE FROM langchain_pg_embedding "
                     "WHERE CAST(cmetadata AS jsonb) @> CAST(:source_filter AS jsonb)"),
                {"source_filter": source_filter},
            )
            connect.commit()

    def do_clear_vs(self):
        """Remove all vectors of this knowledge base, leaving an empty collection behind."""
        self.pg_vector.delete_collection()
        # delete_collection() also removes the collection row. Recreate it so later
        # add_documents() calls on this store do not fail with "Collection not found".
        self.pg_vector.create_collection()
if __name__ == '__main__':
    # Manual smoke test. Requires the PostgreSQL connection configured in
    # kbs_config["pg"] and the project's database models.
    from server.db.base import Base, engine

    Base.metadata.create_all(bind=engine)
    pg_kb_service = PGKBService("test")
    pg_kb_service.create_kb()
    pg_kb_service.add_doc(KnowledgeFile("README.md", "test"))
    # Query while README.md is still indexed. Searching after delete_doc/drop_kb
    # would never exercise retrieval.
    print(pg_kb_service.search_docs("如何启动api服务"))
    pg_kb_service.delete_doc(KnowledgeFile("README.md", "test"))
    pg_kb_service.drop_kb()