Update knowledge base API:

1. list_kbs_from_db returns all kbs by default instead of returning only
   kbs with file_count > 0.
2. KBServiceFactory.get_service_by_name can return a FaissKBService for a
   kb that is not in the db.
This commit is contained in:
liunux4odoo 2023-08-09 21:57:40 +08:00
parent 323fc13d4c
commit 997f8b2e3f
4 changed files with 17 additions and 13 deletions

View File

@ -13,8 +13,8 @@ def add_kb_to_db(session, kb_name, vs_type, embed_model):
@with_session @with_session
def list_kbs_from_db(session): def list_kbs_from_db(session, min_file_count: int = -1):
kbs = session.query(KnowledgeBaseModel.kb_name).filter(KnowledgeBaseModel.file_count > 0).all() kbs = session.query(KnowledgeBaseModel.kb_name).filter(KnowledgeBaseModel.file_count > min_file_count).all()
kbs = [kb[0] for kb in kbs] kbs = [kb[0] for kb in kbs]
return kbs return kbs

View File

@ -12,7 +12,7 @@ async def list_kbs():
return ListResponse(data=list_kbs_from_db()) return ListResponse(data=list_kbs_from_db())
async def create_kb(knowledge_base_name: str = Body(..., examples=["kb_name"]), async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
vector_store_type: str = Body("faiss"), vector_store_type: str = Body("faiss"),
embed_model: str = Body(EMBEDDING_MODEL), embed_model: str = Body(EMBEDDING_MODEL),
): ):
@ -22,15 +22,17 @@ async def create_kb(knowledge_base_name: str = Body(..., examples=["kb_name"]),
if knowledge_base_name is None or knowledge_base_name.strip() == "": if knowledge_base_name is None or knowledge_base_name.strip() == "":
return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称") return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称")
kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model) kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is not None: if kb is not None:
return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}") return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
kb.create()
kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model)
kb.create_kb()
return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}") return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
async def delete_kb( async def delete_kb(
knowledge_base_name: str = Body(..., examples=["kb_name"]) knowledge_base_name: str = Body(..., examples=["samples"])
): ):
# Delete selected knowledge base # Delete selected knowledge base
if not validate_kb_name(knowledge_base_name): if not validate_kb_name(knowledge_base_name):

View File

@ -12,7 +12,7 @@ from typing import Union
async def list_docs( async def list_docs(
knowledge_base_name: str = Body(..., examples=["kb_name"]) knowledge_base_name: str = Body(..., examples=["samples"])
): ):
if not validate_kb_name(knowledge_base_name): if not validate_kb_name(knowledge_base_name):
return ListResponse(code=403, msg="Don't attack me", data=[]) return ListResponse(code=403, msg="Don't attack me", data=[])
@ -60,7 +60,7 @@ async def upload_doc(file: UploadFile = File(description="上传文件"),
return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}") return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}")
async def delete_doc(knowledge_base_name: str = Body(..., examples=["kb_name"]), async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
doc_name: str = Body(..., examples=["file_name"]), doc_name: str = Body(..., examples=["file_name"]),
delete_content: bool = Body(False), delete_content: bool = Body(False),
): ):
@ -82,7 +82,7 @@ async def delete_doc(knowledge_base_name: str = Body(..., examples=["kb_name"]),
async def update_doc( async def update_doc(
knowledge_base_name: str = Body(..., examples=["kb_name"]), knowledge_base_name: str = Body(..., examples=["samples"]),
file_name: str = Body(..., examples=["file_name"]), file_name: str = Body(..., examples=["file_name"]),
): ):
''' '''
@ -111,7 +111,7 @@ async def download_doc():
async def recreate_vector_store( async def recreate_vector_store(
knowledge_base_name: str = Body(..., examples=["kb_name"]), knowledge_base_name: str = Body(..., examples=["samples"]),
allow_empty_kb: bool = Body(True), allow_empty_kb: bool = Body(True),
vs_type: str = Body("faiss"), vs_type: str = Body("faiss"),
): ):

View File

@ -1,7 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import os import os
from functools import lru_cache
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain.docstore.document import Document from langchain.docstore.document import Document
@ -42,7 +41,7 @@ class KBService(ABC):
""" """
if not os.path.exists(self.doc_path): if not os.path.exists(self.doc_path):
os.makedirs(self.doc_path) os.makedirs(self.doc_path)
self.do_create_kb() self.do_create_kb()
status = add_kb_to_db(self.kb_name, self.vs_type(), self.embed_model) status = add_kb_to_db(self.kb_name, self.vs_type(), self.embed_model)
return status return status
@ -194,12 +193,15 @@ class KBServiceFactory:
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
return MilvusKBService(kb_name, embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config return MilvusKBService(kb_name, embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config
elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier. elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier.
from server.knowledge_base.kb_service.default_kb_service import DefaultKBService
return DefaultKBService(kb_name) return DefaultKBService(kb_name)
@staticmethod @staticmethod
def get_service_by_name(kb_name: str def get_service_by_name(kb_name: str
) -> KBService: ) -> KBService:
kb_name, vs_type, embed_model = load_kb_from_db(kb_name) _, vs_type, embed_model = load_kb_from_db(kb_name)
if vs_type is None and os.path.isdir(get_kb_path(kb_name)): # faiss knowledge base not in db
vs_type = "faiss"
return KBServiceFactory.get_service(kb_name, vs_type, embed_model) return KBServiceFactory.get_service(kb_name, vs_type, embed_model)
@staticmethod @staticmethod