Update pgvector connection method following updates in langchain_community, to resolve the 'PGVector' object has no attribute 'connect' error. (#2591)

This commit is contained in:
HALIndex 2024-01-12 10:17:04 +08:00 committed by GitHub
parent 03eb5e9d2e
commit 29ef5dda64
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -23,7 +23,7 @@ class PGKBService(KBService):
connection_string=kbs_config.get("pg").get("connection_uri"))
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
with self.pg_vector.connect() as connect:
with self.pg_vector._create_engine().connect() as connect:
stmt = text("SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id in :ids")
results = [Document(page_content=row[0], metadata=row[1]) for row in
connect.execute(stmt, parameters={'ids': ids}).fetchall()]
@ -43,7 +43,7 @@ class PGKBService(KBService):
return SupportedVSType.PG
def do_drop_kb(self):
with self.pg_vector.connect() as connect:
with self.pg_vector._create_engine().connect() as connect:
connect.execute(text(f'''
-- 删除 langchain_pg_embedding 表中关联到 langchain_pg_collection 表中 的记录
DELETE FROM langchain_pg_embedding
@ -69,7 +69,7 @@ class PGKBService(KBService):
return doc_infos
def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
with self.pg_vector.connect() as connect:
with self.pg_vector._create_engine().connect() as connect:
filepath = kb_file.filepath.replace('\\', '\\\\')
connect.execute(
text(