Reviewed patch: replace get_doc_by_id with a batch get_doc_by_ids in the
knowledge-base services.

Line breaks and indentation were reconstructed from a whitespace-collapsed
copy of the patch; verify the context lines against the target files before
applying.

Amendments relative to the submitted patch:
  * milvus / zilliz: ids are converted with int() before being interpolated
    into the query expression, so the rendered list contains bare numbers
    (as the removed 'pk == {id}' expression did) instead of quoted strings.
  * pg: the ':ids' parameter is declared as an expanding bind parameter so
    that a Python list can be bound to the IN clause.
  * NOTE(review): the pg query still filters on collection_id, exactly as
    before this patch; confirm that this column really holds the document
    ids passed in by list_docs.
  * NOTE(review): the faiss implementation returns None entries for ids
    that are not in the docstore; callers of list_docs must tolerate that.
  * The post-image hashes on the 'index' lines of the amended files predate
    these amendments.

diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py
index e221e438..fa079f62 100644
--- a/server/knowledge_base/kb_service/base.py
+++ b/server/knowledge_base/kb_service/base.py
@@ -172,15 +172,15 @@ class KBService(ABC):
         docs = self.do_search(query, top_k, score_threshold)
         return docs
 
-    def get_doc_by_id(self, id: str) -> Optional[Document]:
-        return None
+    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
+        return []
 
     def list_docs(self, file_name: str = None, metadata: Dict = {}) -> List[Document]:
         '''
         通过file_name或metadata检索Document
         '''
         doc_infos = list_docs_from_db(kb_name=self.kb_name, file_name=file_name, metadata=metadata)
-        docs = [self.get_doc_by_id(x["id"]) for x in doc_infos]
+        docs = self.get_doc_by_ids([x["id"] for x in doc_infos])
         return docs
 
     @abstractmethod
diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py
index 07c57e05..a444f038 100644
--- a/server/knowledge_base/kb_service/faiss_kb_service.py
+++ b/server/knowledge_base/kb_service/faiss_kb_service.py
@@ -32,9 +32,9 @@ class FaissKBService(KBService):
     def save_vector_store(self):
         self.load_vector_store().save(self.vs_path)
 
-    def get_doc_by_id(self, id: str) -> Optional[Document]:
+    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
         with self.load_vector_store().acquire() as vs:
-            return vs.docstore._dict.get(id)
+            return [vs.docstore._dict.get(id) for id in ids]
 
     def do_init(self):
         self.vector_name = self.vector_name or self.embed_model
diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py
index 5e270400..97c5913e 100644
--- a/server/knowledge_base/kb_service/milvus_kb_service.py
+++ b/server/knowledge_base/kb_service/milvus_kb_service.py
@@ -22,13 +22,14 @@ class MilvusKBService(KBService):
         # if self.milvus.col:
         #     self.milvus.col.flush()
 
-    def get_doc_by_id(self, id: str) -> Optional[Document]:
+    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
+        result = []
         if self.milvus.col:
-            data_list = self.milvus.col.query(expr=f'pk == {id}', output_fields=["*"])
-            if len(data_list) > 0:
-                data = data_list[0]
+            data_list = self.milvus.col.query(expr=f'pk in {[int(_id) for _id in ids]}', output_fields=["*"])
+            for data in data_list:
                 text = data.pop("text")
-                return Document(page_content=text, metadata=data)
+                result.append(Document(page_content=text, metadata=data))
+        return result
 
     @staticmethod
     def search(milvus_name, content, limit=3):
@@ -99,7 +100,7 @@ if __name__ == '__main__':
     milvusService = MilvusKBService("test")
     # milvusService.add_doc(KnowledgeFile("README.md", "test"))
 
-    print(milvusService.get_doc_by_id("444022434274215486"))
+    print(milvusService.get_doc_by_ids(["444022434274215486"]))
     # milvusService.delete_doc(KnowledgeFile("README.md", "test"))
     # milvusService.do_drop_kb()
     # print(milvusService.search_docs("如何启动api服务"))
diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py
index 3a6bab0b..cf58ce37 100644
--- a/server/knowledge_base/kb_service/pg_kb_service.py
+++ b/server/knowledge_base/kb_service/pg_kb_service.py
@@ -22,13 +22,14 @@ class PGKBService(KBService):
                                   distance_strategy=DistanceStrategy.EUCLIDEAN,
                                   connection_string=kbs_config.get("pg").get("connection_uri"))
 
-    def get_doc_by_id(self, id: str) -> Optional[Document]:
+    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
+        from sqlalchemy import bindparam
         with self.pg_vector.connect() as connect:
-            stmt = text("SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id=:id")
+            stmt = text("SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id in :ids")
+            stmt = stmt.bindparams(bindparam("ids", expanding=True))
             results = [Document(page_content=row[0], metadata=row[1]) for row in
-                       connect.execute(stmt, parameters={'id': id}).fetchall()]
-            if len(results) > 0:
-                return results[0]
+                       connect.execute(stmt, parameters={'ids': ids}).fetchall()]
+            return results
 
     def do_init(self):
         self._load_pg_vector()
@@ -88,5 +89,5 @@ if __name__ == '__main__':
     # pGKBService.add_doc(KnowledgeFile("README.md", "test"))
     # pGKBService.delete_doc(KnowledgeFile("README.md", "test"))
     # pGKBService.drop_kb()
-    print(pGKBService.get_doc_by_id("f1e51390-3029-4a19-90dc-7118aaa25772"))
+    print(pGKBService.get_doc_by_ids(["f1e51390-3029-4a19-90dc-7118aaa25772"]))
     # print(pGKBService.search_docs("如何启动api服务"))
diff --git a/server/knowledge_base/kb_service/zilliz_kb_service.py b/server/knowledge_base/kb_service/zilliz_kb_service.py
index bd8b3e9f..d82f8734 100644
--- a/server/knowledge_base/kb_service/zilliz_kb_service.py
+++ b/server/knowledge_base/kb_service/zilliz_kb_service.py
@@ -20,13 +20,14 @@ class ZillizKBService(KBService):
         # if self.zilliz.col:
         #     self.zilliz.col.flush()
 
-    def get_doc_by_id(self, id: str) -> Optional[Document]:
+    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
+        result = []
         if self.zilliz.col:
-            data_list = self.zilliz.col.query(expr=f'pk == {id}', output_fields=["*"])
-            if len(data_list) > 0:
-                data = data_list[0]
+            data_list = self.zilliz.col.query(expr=f'pk in {[int(_id) for _id in ids]}', output_fields=["*"])
+            for data in data_list:
                 text = data.pop("text")
-                return Document(page_content=text, metadata=data)
+                result.append(Document(page_content=text, metadata=data))
+        return result
 
     @staticmethod
     def search(zilliz_name, content, limit=3):