From ff91508d8b2c23a1d51222cd3e95e89d6b4ea6c2 Mon Sep 17 00:00:00 2001 From: Lijia0 <30282949+Lijia0@users.noreply.github.com> Date: Fri, 12 Jan 2024 12:34:40 +0800 Subject: [PATCH] =?UTF-8?q?=E6=AF=8F=E6=AC=A1=E5=88=9B=E5=BB=BA=E5=AF=B9?= =?UTF-8?q?=E8=B1=A1=E6=97=B6=E4=BB=8E=E8=BF=9E=E6=8E=A5=E6=B1=A0=E8=8E=B7?= =?UTF-8?q?=E5=8F=96=E8=BF=9E=E6=8E=A5=EF=BC=8C=E9=81=BF=E5=85=8D=E6=AF=8F?= =?UTF-8?q?=E6=AC=A1=E6=89=A7=E8=A1=8C=E6=96=B9=E6=B3=95=E6=97=B6=E9=83=BD?= =?UTF-8?q?=E6=96=B0=E5=BB=BA=E8=BF=9E=E6=8E=A5=20(#2480)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../kb_service/pg_kb_service.py | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py index 487e30a0..925bb4d2 100644 --- a/server/knowledge_base/kb_service/pg_kb_service.py +++ b/server/knowledge_base/kb_service/pg_kb_service.py @@ -11,22 +11,27 @@ from server.knowledge_base.kb_service.base import SupportedVSType, KBService, Em score_threshold_process from server.knowledge_base.utils import KnowledgeFile import shutil +import sqlalchemy +from sqlalchemy.engine.base import Engine +from sqlalchemy.orm import Session class PGKBService(KBService): - pg_vector: PGVector + engine: Engine = sqlalchemy.create_engine(kbs_config.get("pg").get("connection_uri"), pool_size=10) def _load_pg_vector(self): + self.connection = PGKBService.engine.connect() self.pg_vector = PGVector(embedding_function=EmbeddingsFunAdapter(self.embed_model), collection_name=self.kb_name, distance_strategy=DistanceStrategy.EUCLIDEAN, + connection=self.connection, connection_string=kbs_config.get("pg").get("connection_uri")) def get_doc_by_ids(self, ids: List[str]) -> List[Document]: - with self.pg_vector._create_engine().connect() as connect: + with Session(self.connection) as session: stmt = text("SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id in :ids") results = [Document(page_content=row[0], metadata=row[1]) for row in - connect.execute(stmt, parameters={'ids': ids}).fetchall()] + session.execute(stmt, {'ids': ids}).fetchall()] return results # TODO: @@ -43,8 +48,8 @@ class PGKBService(KBService): return SupportedVSType.PG def do_drop_kb(self): - with self.pg_vector._create_engine().connect() as connect: - connect.execute(text(f''' + with Session(self.connection) as session: + session.execute(text(f''' -- 删除 langchain_pg_embedding 表中关联到 langchain_pg_collection 表中 的记录 DELETE FROM langchain_pg_embedding WHERE collection_id IN ( @@ -53,11 +58,10 @@ class PGKBService(KBService): -- 删除 langchain_pg_collection 表中 记录 DELETE FROM langchain_pg_collection WHERE name = '{self.kb_name}'; ''')) - connect.commit() + session.commit() shutil.rmtree(self.kb_path) def do_search(self, query: str, top_k: int, score_threshold: float): - self._load_pg_vector() embed_func = EmbeddingsFunAdapter(self.embed_model) embeddings = embed_func.embed_query(query) docs = self.pg_vector.similarity_search_with_score_by_vector(embeddings, top_k) @@ -69,13 +73,13 @@ class PGKBService(KBService): return doc_infos def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs): - with self.pg_vector._create_engine().connect() as connect: + with Session(self.connection) as session: filepath = kb_file.filepath.replace('\\', '\\\\') - connect.execute( + session.execute( text( ''' DELETE FROM langchain_pg_embedding WHERE cmetadata::jsonb @> '{"source": "filepath"}'::jsonb;'''.replace( "filepath", filepath))) - connect.commit() + session.commit() def do_clear_vs(self): self.pg_vector.delete_collection()