From 997f8b2e3fb5b65eed6f2dfdc55e756dde543abc Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Wed, 9 Aug 2023 21:57:40 +0800 Subject: [PATCH] update knowledge base api: 1. list_kbs_from_db return all kbs by default instead of return kbs with file_count > 0 only. 2. KBServiceFactory.get_service_by_name could return a FaissKBService that not in the db --- server/db/repository/knowledge_base_repository.py | 4 ++-- server/knowledge_base/kb_api.py | 10 ++++++---- server/knowledge_base/kb_doc_api.py | 8 ++++---- server/knowledge_base/kb_service/base.py | 8 +++++--- 4 files changed, 17 insertions(+), 13 deletions(-) diff --git a/server/db/repository/knowledge_base_repository.py b/server/db/repository/knowledge_base_repository.py index cef420a9..9f5ad408 100644 --- a/server/db/repository/knowledge_base_repository.py +++ b/server/db/repository/knowledge_base_repository.py @@ -13,8 +13,8 @@ def add_kb_to_db(session, kb_name, vs_type, embed_model): @with_session -def list_kbs_from_db(session): - kbs = session.query(KnowledgeBaseModel.kb_name).filter(KnowledgeBaseModel.file_count > 0).all() +def list_kbs_from_db(session, min_file_count: int = -1): + kbs = session.query(KnowledgeBaseModel.kb_name).filter(KnowledgeBaseModel.file_count > min_file_count).all() kbs = [kb[0] for kb in kbs] return kbs diff --git a/server/knowledge_base/kb_api.py b/server/knowledge_base/kb_api.py index 5186dbc4..8df14cb3 100644 --- a/server/knowledge_base/kb_api.py +++ b/server/knowledge_base/kb_api.py @@ -12,7 +12,7 @@ async def list_kbs(): return ListResponse(data=list_kbs_from_db()) -async def create_kb(knowledge_base_name: str = Body(..., examples=["kb_name"]), +async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]), vector_store_type: str = Body("faiss"), embed_model: str = Body(EMBEDDING_MODEL), ): @@ -22,15 +22,17 @@ async def create_kb(knowledge_base_name: str = Body(..., examples=["kb_name"]), if knowledge_base_name is None or knowledge_base_name.strip() == "": return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称") - kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model) + kb = KBServiceFactory.get_service_by_name(knowledge_base_name) if kb is not None: return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}") - kb.create() + + kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model) + kb.create_kb() return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}") async def delete_kb( - knowledge_base_name: str = Body(..., examples=["kb_name"]) + knowledge_base_name: str = Body(..., examples=["samples"]) ): # Delete selected knowledge base if not validate_kb_name(knowledge_base_name): diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index 8a256a8f..23f8ed04 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -12,7 +12,7 @@ from typing import Union async def list_docs( - knowledge_base_name: str = Body(..., examples=["kb_name"]) + knowledge_base_name: str = Body(..., examples=["samples"]) ): if not validate_kb_name(knowledge_base_name): return ListResponse(code=403, msg="Don't attack me", data=[]) @@ -60,7 +60,7 @@ async def upload_doc(file: UploadFile = File(description="上传文件"), return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}") -async def delete_doc(knowledge_base_name: str = Body(..., examples=["kb_name"]), +async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]), doc_name: str = Body(..., examples=["file_name"]), delete_content: bool = Body(False), ): @@ -82,7 +82,7 @@ async def delete_doc(knowledge_base_name: str = Body(..., examples=["kb_name"]), async def update_doc( - knowledge_base_name: str = Body(..., examples=["kb_name"]), + knowledge_base_name: str = Body(..., examples=["samples"]), file_name: str = Body(..., examples=["file_name"]), ): ''' @@ -111,7 +111,7 @@ async def download_doc(): async def recreate_vector_store( - knowledge_base_name: str = Body(..., examples=["kb_name"]), + knowledge_base_name: str = Body(..., examples=["samples"]), allow_empty_kb: bool = Body(True), vs_type: str = Body("faiss"), ): diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index 8a675f3b..7209dcad 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -1,7 +1,6 @@ from abc import ABC, abstractmethod import os -from functools import lru_cache from langchain.embeddings.base import Embeddings from langchain.docstore.document import Document @@ -42,7 +41,7 @@ class KBService(ABC): """ if not os.path.exists(self.doc_path): os.makedirs(self.doc_path) - self.do_create_kb() + self.do_create_kb() status = add_kb_to_db(self.kb_name, self.vs_type(), self.embed_model) return status @@ -194,12 +193,15 @@ class KBServiceFactory: from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService return MilvusKBService(kb_name, embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier. + from server.knowledge_base.kb_service.default_kb_service import DefaultKBService return DefaultKBService(kb_name) @staticmethod def get_service_by_name(kb_name: str ) -> KBService: - kb_name, vs_type, embed_model = load_kb_from_db(kb_name) + _, vs_type, embed_model = load_kb_from_db(kb_name) + if vs_type is None and os.path.isdir(get_kb_path(kb_name)): # faiss knowledge base not in db + vs_type = "faiss" return KBServiceFactory.get_service(kb_name, vs_type, embed_model) @staticmethod