diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py index 487e30a0..925bb4d2 100644 --- a/server/knowledge_base/kb_service/pg_kb_service.py +++ b/server/knowledge_base/kb_service/pg_kb_service.py @@ -11,22 +11,27 @@ from server.knowledge_base.kb_service.base import SupportedVSType, KBService, Em score_threshold_process from server.knowledge_base.utils import KnowledgeFile import shutil +import sqlalchemy +from sqlalchemy.engine.base import Engine +from sqlalchemy.orm import Session class PGKBService(KBService): - pg_vector: PGVector + engine: Engine = sqlalchemy.create_engine(kbs_config.get("pg").get("connection_uri"), pool_size=10) def _load_pg_vector(self): + self.connection = PGKBService.engine.connect() self.pg_vector = PGVector(embedding_function=EmbeddingsFunAdapter(self.embed_model), collection_name=self.kb_name, distance_strategy=DistanceStrategy.EUCLIDEAN, + connection=self.connection, connection_string=kbs_config.get("pg").get("connection_uri")) def get_doc_by_ids(self, ids: List[str]) -> List[Document]: - with self.pg_vector._create_engine().connect() as connect: + with Session(self.connection) as session: stmt = text("SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id in :ids") results = [Document(page_content=row[0], metadata=row[1]) for row in - connect.execute(stmt, parameters={'ids': ids}).fetchall()] + session.execute(stmt, {'ids': ids}).fetchall()] return results # TODO: @@ -43,8 +48,8 @@ class PGKBService(KBService): return SupportedVSType.PG def do_drop_kb(self): - with self.pg_vector._create_engine().connect() as connect: - connect.execute(text(f''' + with Session(self.connection) as session: + session.execute(text(f''' -- 删除 langchain_pg_embedding 表中关联到 langchain_pg_collection 表中 的记录 DELETE FROM langchain_pg_embedding WHERE collection_id IN ( @@ -53,11 +58,10 @@ class PGKBService(KBService): -- 删除 langchain_pg_collection 表中 记录 DELETE FROM langchain_pg_collection WHERE name = '{self.kb_name}'; ''')) - connect.commit() + session.commit() shutil.rmtree(self.kb_path) def do_search(self, query: str, top_k: int, score_threshold: float): - self._load_pg_vector() embed_func = EmbeddingsFunAdapter(self.embed_model) embeddings = embed_func.embed_query(query) docs = self.pg_vector.similarity_search_with_score_by_vector(embeddings, top_k) @@ -69,13 +73,13 @@ class PGKBService(KBService): return doc_infos def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs): - with self.pg_vector._create_engine().connect() as connect: + with Session(self.connection) as session: filepath = kb_file.filepath.replace('\\', '\\\\') - connect.execute( + session.execute( text( ''' DELETE FROM langchain_pg_embedding WHERE cmetadata::jsonb @> '{"source": "filepath"}'::jsonb;'''.replace( "filepath", filepath))) - connect.commit() + session.commit() def do_clear_vs(self): self.pg_vector.delete_collection()