mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-19 21:37:20 +08:00
update knowledge base api:
1. list_kbs_from_db return all kbs by default instead of return kbs with file_count > 0 only. 2. KBServiceFactory.get_service_by_name could return a FaissKBService that not in the db
This commit is contained in:
parent
323fc13d4c
commit
997f8b2e3f
@ -13,8 +13,8 @@ def add_kb_to_db(session, kb_name, vs_type, embed_model):
|
||||
|
||||
|
||||
@with_session
|
||||
def list_kbs_from_db(session):
|
||||
kbs = session.query(KnowledgeBaseModel.kb_name).filter(KnowledgeBaseModel.file_count > 0).all()
|
||||
def list_kbs_from_db(session, min_file_count: int = -1):
|
||||
kbs = session.query(KnowledgeBaseModel.kb_name).filter(KnowledgeBaseModel.file_count > min_file_count).all()
|
||||
kbs = [kb[0] for kb in kbs]
|
||||
return kbs
|
||||
|
||||
|
||||
@ -12,7 +12,7 @@ async def list_kbs():
|
||||
return ListResponse(data=list_kbs_from_db())
|
||||
|
||||
|
||||
async def create_kb(knowledge_base_name: str = Body(..., examples=["kb_name"]),
|
||||
async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||
vector_store_type: str = Body("faiss"),
|
||||
embed_model: str = Body(EMBEDDING_MODEL),
|
||||
):
|
||||
@ -22,15 +22,17 @@ async def create_kb(knowledge_base_name: str = Body(..., examples=["kb_name"]),
|
||||
if knowledge_base_name is None or knowledge_base_name.strip() == "":
|
||||
return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称")
|
||||
|
||||
kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model)
|
||||
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
||||
if kb is not None:
|
||||
return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
|
||||
kb.create()
|
||||
|
||||
kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model)
|
||||
kb.create_kb()
|
||||
return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
|
||||
|
||||
|
||||
async def delete_kb(
|
||||
knowledge_base_name: str = Body(..., examples=["kb_name"])
|
||||
knowledge_base_name: str = Body(..., examples=["samples"])
|
||||
):
|
||||
# Delete selected knowledge base
|
||||
if not validate_kb_name(knowledge_base_name):
|
||||
|
||||
@ -12,7 +12,7 @@ from typing import Union
|
||||
|
||||
|
||||
async def list_docs(
|
||||
knowledge_base_name: str = Body(..., examples=["kb_name"])
|
||||
knowledge_base_name: str = Body(..., examples=["samples"])
|
||||
):
|
||||
if not validate_kb_name(knowledge_base_name):
|
||||
return ListResponse(code=403, msg="Don't attack me", data=[])
|
||||
@ -60,7 +60,7 @@ async def upload_doc(file: UploadFile = File(description="上传文件"),
|
||||
return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}")
|
||||
|
||||
|
||||
async def delete_doc(knowledge_base_name: str = Body(..., examples=["kb_name"]),
|
||||
async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||
doc_name: str = Body(..., examples=["file_name"]),
|
||||
delete_content: bool = Body(False),
|
||||
):
|
||||
@ -82,7 +82,7 @@ async def delete_doc(knowledge_base_name: str = Body(..., examples=["kb_name"]),
|
||||
|
||||
|
||||
async def update_doc(
|
||||
knowledge_base_name: str = Body(..., examples=["kb_name"]),
|
||||
knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||
file_name: str = Body(..., examples=["file_name"]),
|
||||
):
|
||||
'''
|
||||
@ -111,7 +111,7 @@ async def download_doc():
|
||||
|
||||
|
||||
async def recreate_vector_store(
|
||||
knowledge_base_name: str = Body(..., examples=["kb_name"]),
|
||||
knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||
allow_empty_kb: bool = Body(True),
|
||||
vs_type: str = Body("faiss"),
|
||||
):
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import os
|
||||
from functools import lru_cache
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.docstore.document import Document
|
||||
@ -42,7 +41,7 @@ class KBService(ABC):
|
||||
"""
|
||||
if not os.path.exists(self.doc_path):
|
||||
os.makedirs(self.doc_path)
|
||||
self.do_create_kb()
|
||||
self.do_create_kb()
|
||||
status = add_kb_to_db(self.kb_name, self.vs_type(), self.embed_model)
|
||||
return status
|
||||
|
||||
@ -194,12 +193,15 @@ class KBServiceFactory:
|
||||
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
|
||||
return MilvusKBService(kb_name, embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config
|
||||
elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier.
|
||||
from server.knowledge_base.kb_service.default_kb_service import DefaultKBService
|
||||
return DefaultKBService(kb_name)
|
||||
|
||||
@staticmethod
|
||||
def get_service_by_name(kb_name: str
|
||||
) -> KBService:
|
||||
kb_name, vs_type, embed_model = load_kb_from_db(kb_name)
|
||||
_, vs_type, embed_model = load_kb_from_db(kb_name)
|
||||
if vs_type is None and os.path.isdir(get_kb_path(kb_name)): # faiss knowledge base not in db
|
||||
vs_type = "faiss"
|
||||
return KBServiceFactory.get_service(kb_name, vs_type, embed_model)
|
||||
|
||||
@staticmethod
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user