update knowledge base api:

1. list_kbs_from_db return all kbs by default instead of return kbs with
   file_count > 0 only.
2. KBServiceFactory.get_service_by_name could return a FaissKBService
   that not in the db
This commit is contained in:
liunux4odoo 2023-08-09 21:57:40 +08:00
parent 323fc13d4c
commit 997f8b2e3f
4 changed files with 17 additions and 13 deletions

View File

@ -13,8 +13,8 @@ def add_kb_to_db(session, kb_name, vs_type, embed_model):
@with_session
def list_kbs_from_db(session):
kbs = session.query(KnowledgeBaseModel.kb_name).filter(KnowledgeBaseModel.file_count > 0).all()
def list_kbs_from_db(session, min_file_count: int = -1):
kbs = session.query(KnowledgeBaseModel.kb_name).filter(KnowledgeBaseModel.file_count > min_file_count).all()
kbs = [kb[0] for kb in kbs]
return kbs

View File

@ -12,7 +12,7 @@ async def list_kbs():
return ListResponse(data=list_kbs_from_db())
async def create_kb(knowledge_base_name: str = Body(..., examples=["kb_name"]),
async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
vector_store_type: str = Body("faiss"),
embed_model: str = Body(EMBEDDING_MODEL),
):
@ -22,15 +22,17 @@ async def create_kb(knowledge_base_name: str = Body(..., examples=["kb_name"]),
if knowledge_base_name is None or knowledge_base_name.strip() == "":
return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称")
kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model)
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is not None:
return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
kb.create()
kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model)
kb.create_kb()
return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
async def delete_kb(
knowledge_base_name: str = Body(..., examples=["kb_name"])
knowledge_base_name: str = Body(..., examples=["samples"])
):
# Delete selected knowledge base
if not validate_kb_name(knowledge_base_name):

View File

@ -12,7 +12,7 @@ from typing import Union
async def list_docs(
knowledge_base_name: str = Body(..., examples=["kb_name"])
knowledge_base_name: str = Body(..., examples=["samples"])
):
if not validate_kb_name(knowledge_base_name):
return ListResponse(code=403, msg="Don't attack me", data=[])
@ -60,7 +60,7 @@ async def upload_doc(file: UploadFile = File(description="上传文件"),
return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}")
async def delete_doc(knowledge_base_name: str = Body(..., examples=["kb_name"]),
async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
doc_name: str = Body(..., examples=["file_name"]),
delete_content: bool = Body(False),
):
@ -82,7 +82,7 @@ async def delete_doc(knowledge_base_name: str = Body(..., examples=["kb_name"]),
async def update_doc(
knowledge_base_name: str = Body(..., examples=["kb_name"]),
knowledge_base_name: str = Body(..., examples=["samples"]),
file_name: str = Body(..., examples=["file_name"]),
):
'''
@ -111,7 +111,7 @@ async def download_doc():
async def recreate_vector_store(
knowledge_base_name: str = Body(..., examples=["kb_name"]),
knowledge_base_name: str = Body(..., examples=["samples"]),
allow_empty_kb: bool = Body(True),
vs_type: str = Body("faiss"),
):

View File

@ -1,7 +1,6 @@
from abc import ABC, abstractmethod
import os
from functools import lru_cache
from langchain.embeddings.base import Embeddings
from langchain.docstore.document import Document
@ -42,7 +41,7 @@ class KBService(ABC):
"""
if not os.path.exists(self.doc_path):
os.makedirs(self.doc_path)
self.do_create_kb()
self.do_create_kb()
status = add_kb_to_db(self.kb_name, self.vs_type(), self.embed_model)
return status
@ -194,12 +193,15 @@ class KBServiceFactory:
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
return MilvusKBService(kb_name, embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config
elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier.
from server.knowledge_base.kb_service.default_kb_service import DefaultKBService
return DefaultKBService(kb_name)
@staticmethod
def get_service_by_name(kb_name: str
) -> KBService:
kb_name, vs_type, embed_model = load_kb_from_db(kb_name)
_, vs_type, embed_model = load_kb_from_db(kb_name)
if vs_type is None and os.path.isdir(get_kb_path(kb_name)): # faiss knowledge base not in db
vs_type = "faiss"
return KBServiceFactory.get_service(kb_name, vs_type, embed_model)
@staticmethod