mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-02-04 21:53:14 +08:00
update knowledge base api:
1. list_kbs_from_db return all kbs by default instead of return kbs with file_count > 0 only. 2. KBServiceFactory.get_service_by_name could return a FaissKBService that not in the db
This commit is contained in:
parent
323fc13d4c
commit
997f8b2e3f
@ -13,8 +13,8 @@ def add_kb_to_db(session, kb_name, vs_type, embed_model):
|
|||||||
|
|
||||||
|
|
||||||
@with_session
|
@with_session
|
||||||
def list_kbs_from_db(session):
|
def list_kbs_from_db(session, min_file_count: int = -1):
|
||||||
kbs = session.query(KnowledgeBaseModel.kb_name).filter(KnowledgeBaseModel.file_count > 0).all()
|
kbs = session.query(KnowledgeBaseModel.kb_name).filter(KnowledgeBaseModel.file_count > min_file_count).all()
|
||||||
kbs = [kb[0] for kb in kbs]
|
kbs = [kb[0] for kb in kbs]
|
||||||
return kbs
|
return kbs
|
||||||
|
|
||||||
|
|||||||
@ -12,7 +12,7 @@ async def list_kbs():
|
|||||||
return ListResponse(data=list_kbs_from_db())
|
return ListResponse(data=list_kbs_from_db())
|
||||||
|
|
||||||
|
|
||||||
async def create_kb(knowledge_base_name: str = Body(..., examples=["kb_name"]),
|
async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||||
vector_store_type: str = Body("faiss"),
|
vector_store_type: str = Body("faiss"),
|
||||||
embed_model: str = Body(EMBEDDING_MODEL),
|
embed_model: str = Body(EMBEDDING_MODEL),
|
||||||
):
|
):
|
||||||
@ -22,15 +22,17 @@ async def create_kb(knowledge_base_name: str = Body(..., examples=["kb_name"]),
|
|||||||
if knowledge_base_name is None or knowledge_base_name.strip() == "":
|
if knowledge_base_name is None or knowledge_base_name.strip() == "":
|
||||||
return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称")
|
return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称")
|
||||||
|
|
||||||
kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model)
|
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
||||||
if kb is not None:
|
if kb is not None:
|
||||||
return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
|
return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
|
||||||
kb.create()
|
|
||||||
|
kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model)
|
||||||
|
kb.create_kb()
|
||||||
return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
|
return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
|
||||||
|
|
||||||
|
|
||||||
async def delete_kb(
|
async def delete_kb(
|
||||||
knowledge_base_name: str = Body(..., examples=["kb_name"])
|
knowledge_base_name: str = Body(..., examples=["samples"])
|
||||||
):
|
):
|
||||||
# Delete selected knowledge base
|
# Delete selected knowledge base
|
||||||
if not validate_kb_name(knowledge_base_name):
|
if not validate_kb_name(knowledge_base_name):
|
||||||
|
|||||||
@ -12,7 +12,7 @@ from typing import Union
|
|||||||
|
|
||||||
|
|
||||||
async def list_docs(
|
async def list_docs(
|
||||||
knowledge_base_name: str = Body(..., examples=["kb_name"])
|
knowledge_base_name: str = Body(..., examples=["samples"])
|
||||||
):
|
):
|
||||||
if not validate_kb_name(knowledge_base_name):
|
if not validate_kb_name(knowledge_base_name):
|
||||||
return ListResponse(code=403, msg="Don't attack me", data=[])
|
return ListResponse(code=403, msg="Don't attack me", data=[])
|
||||||
@ -60,7 +60,7 @@ async def upload_doc(file: UploadFile = File(description="上传文件"),
|
|||||||
return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}")
|
return BaseResponse(code=200, msg=f"成功上传文件 {kb_file.filename}")
|
||||||
|
|
||||||
|
|
||||||
async def delete_doc(knowledge_base_name: str = Body(..., examples=["kb_name"]),
|
async def delete_doc(knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||||
doc_name: str = Body(..., examples=["file_name"]),
|
doc_name: str = Body(..., examples=["file_name"]),
|
||||||
delete_content: bool = Body(False),
|
delete_content: bool = Body(False),
|
||||||
):
|
):
|
||||||
@ -82,7 +82,7 @@ async def delete_doc(knowledge_base_name: str = Body(..., examples=["kb_name"]),
|
|||||||
|
|
||||||
|
|
||||||
async def update_doc(
|
async def update_doc(
|
||||||
knowledge_base_name: str = Body(..., examples=["kb_name"]),
|
knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||||
file_name: str = Body(..., examples=["file_name"]),
|
file_name: str = Body(..., examples=["file_name"]),
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
@ -111,7 +111,7 @@ async def download_doc():
|
|||||||
|
|
||||||
|
|
||||||
async def recreate_vector_store(
|
async def recreate_vector_store(
|
||||||
knowledge_base_name: str = Body(..., examples=["kb_name"]),
|
knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||||
allow_empty_kb: bool = Body(True),
|
allow_empty_kb: bool = Body(True),
|
||||||
vs_type: str = Body("faiss"),
|
vs_type: str = Body("faiss"),
|
||||||
):
|
):
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from functools import lru_cache
|
|
||||||
|
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
@ -42,7 +41,7 @@ class KBService(ABC):
|
|||||||
"""
|
"""
|
||||||
if not os.path.exists(self.doc_path):
|
if not os.path.exists(self.doc_path):
|
||||||
os.makedirs(self.doc_path)
|
os.makedirs(self.doc_path)
|
||||||
self.do_create_kb()
|
self.do_create_kb()
|
||||||
status = add_kb_to_db(self.kb_name, self.vs_type(), self.embed_model)
|
status = add_kb_to_db(self.kb_name, self.vs_type(), self.embed_model)
|
||||||
return status
|
return status
|
||||||
|
|
||||||
@ -194,12 +193,15 @@ class KBServiceFactory:
|
|||||||
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
|
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
|
||||||
return MilvusKBService(kb_name, embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config
|
return MilvusKBService(kb_name, embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config
|
||||||
elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier.
|
elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier.
|
||||||
|
from server.knowledge_base.kb_service.default_kb_service import DefaultKBService
|
||||||
return DefaultKBService(kb_name)
|
return DefaultKBService(kb_name)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_service_by_name(kb_name: str
|
def get_service_by_name(kb_name: str
|
||||||
) -> KBService:
|
) -> KBService:
|
||||||
kb_name, vs_type, embed_model = load_kb_from_db(kb_name)
|
_, vs_type, embed_model = load_kb_from_db(kb_name)
|
||||||
|
if vs_type is None and os.path.isdir(get_kb_path(kb_name)): # faiss knowledge base not in db
|
||||||
|
vs_type = "faiss"
|
||||||
return KBServiceFactory.get_service(kb_name, vs_type, embed_model)
|
return KBServiceFactory.get_service(kb_name, vs_type, embed_model)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user