1、修改知识库列表接口,返回全量属性字段,同时修改受影响的相关代码。 (#4119)

2、run_in_process_pool改为run_in_thread_pool,解决兼容性问题。
3、poetry配置文件修复。
This commit is contained in:
srszzw 2024-06-01 18:44:06 +08:00 committed by GitHub
parent 67ed340c3b
commit 10c5dcfdde
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 81 additions and 64 deletions

View File

@ -46,7 +46,7 @@ def search_knowledgebase(query: str, database: str, config: dict):
@regist_tool(description=template_knowledge, title="本地知识库") @regist_tool(description=template_knowledge, title="本地知识库")
def search_local_knowledgebase( def search_local_knowledgebase(
database: str = Field(description="Database for Knowledge Search", choices=list_kbs().data), database: str = Field(description="Database for Knowledge Search", choices=[kb.kb_name for kb in list_kbs().data]),
query: str = Field(description="Query for Knowledge Search"), query: str = Field(description="Query for Knowledge Search"),
): ):
'''''' ''''''

View File

@ -1,5 +1,7 @@
from sqlalchemy import Column, Integer, String, DateTime, func from sqlalchemy import Column, Integer, String, DateTime, func
from pydantic import BaseModel
from typing import Optional
from datetime import datetime
from chatchat.server.db.base import Base from chatchat.server.db.base import Base
@ -18,3 +20,16 @@ class KnowledgeBaseModel(Base):
def __repr__(self): def __repr__(self):
return f"<KnowledgeBase(id='{self.id}', kb_name='{self.kb_name}',kb_intro='{self.kb_info} vs_type='{self.vs_type}', embed_model='{self.embed_model}', file_count='{self.file_count}', create_time='{self.create_time}')>" return f"<KnowledgeBase(id='{self.id}', kb_name='{self.kb_name}',kb_intro='{self.kb_info} vs_type='{self.vs_type}', embed_model='{self.embed_model}', file_count='{self.file_count}', create_time='{self.create_time}')>"
# Create the Pydantic model that corresponds to the ORM model above.
class KnowledgeBaseSchema(BaseModel):
    """Pydantic representation of a knowledge base record.

    The field names match the attributes that ``KnowledgeBaseModel.__repr__``
    prints: id, kb_name, kb_info, vs_type, embed_model, file_count and
    create_time.
    """

    id: int
    kb_name: str
    # Every field below is annotated Optional, so None is an accepted value.
    kb_info: Optional[str]
    vs_type: Optional[str]
    embed_model: Optional[str]
    file_count: Optional[int]
    create_time: Optional[datetime]

    class Config:
        from_attributes = True  # make sure ORM instances can be validated

View File

@ -1,4 +1,5 @@
from chatchat.server.db.models.knowledge_base_model import KnowledgeBaseModel from chatchat.server.db.models.knowledge_base_model import KnowledgeBaseModel
from chatchat.server.db.models.knowledge_base_model import KnowledgeBaseSchema
from chatchat.server.db.session import with_session from chatchat.server.db.session import with_session
@ -18,8 +19,8 @@ def add_kb_to_db(session, kb_name, kb_info, vs_type, embed_model):
@with_session @with_session
def list_kbs_from_db(session, min_file_count: int = -1): def list_kbs_from_db(session, min_file_count: int = -1):
kbs = session.query(KnowledgeBaseModel.kb_name).filter(KnowledgeBaseModel.file_count > min_file_count).all() kbs = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.file_count > min_file_count).all()
kbs = [kb[0] for kb in kbs] kbs = [KnowledgeBaseSchema.model_validate(kb) for kb in kbs]
return kbs return kbs

View File

@ -25,7 +25,7 @@ from chatchat.server.knowledge_base.utils import (
from typing import List, Union, Dict, Optional, Tuple from typing import List, Union, Dict, Optional, Tuple
from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId
from chatchat.server.db.models.knowledge_base_model import KnowledgeBaseSchema
class SupportedVSType: class SupportedVSType:
FAISS = 'faiss' FAISS = 'faiss'
@ -325,7 +325,7 @@ class KBServiceFactory:
def get_kb_details() -> List[Dict]: def get_kb_details() -> List[Dict]:
kbs_in_folder = list_kbs_from_folder() kbs_in_folder = list_kbs_from_folder()
kbs_in_db = KBService.list_kbs() kbs_in_db:List[KnowledgeBaseSchema] = KBService.list_kbs()
result = {} result = {}
for kb in kbs_in_folder: for kb in kbs_in_folder:
@ -340,15 +340,16 @@ def get_kb_details() -> List[Dict]:
"in_db": False, "in_db": False,
} }
for kb in kbs_in_db: for kb_detail in kbs_in_db:
kb_detail = get_kb_detail(kb) kb_detail=kb_detail.model_dump()
if kb_detail: kb_name=kb_detail["kb_name"]
kb_detail["in_db"] = True kb_detail["in_db"] = True
if kb in result: if kb_name in result:
result[kb].update(kb_detail) result[kb_name].update(kb_detail)
else: else:
kb_detail["in_folder"] = False kb_detail["in_folder"] = False
result[kb] = kb_detail result[kb_name] = kb_detail
data = [] data = []
for i, v in enumerate(result.values()): for i, v in enumerate(result.values()):

View File

@ -404,7 +404,7 @@ def files2docs_in_thread(
except Exception as e: except Exception as e:
yield False, (kb_name, filename, str(e)) yield False, (kb_name, filename, str(e))
for result in run_in_process_pool(func=files2docs_in_thread_file2docs, params=kwargs_list): for result in run_in_thread_pool(func=files2docs_in_thread_file2docs, params=kwargs_list):
yield result yield result

View File

@ -302,7 +302,7 @@ class BaseResponse(BaseModel):
class ListResponse(BaseResponse): class ListResponse(BaseResponse):
data: List[str] = Field(..., description="List of names") data: List[Any] = Field(..., description="List of data")
class Config: class Config:
json_schema_extra = { json_schema_extra = {

View File

@ -62,53 +62,6 @@ xinference = ["xinference_client"]
zhipuai = ["zhipuai"] zhipuai = ["zhipuai"]
cli = ["typer"] cli = ["typer"]
[tool.poetry.group.test]
optional = true
[tool.poetry.group.test.dependencies]
# The only dependencies that should be added are
# dependencies used for running tests (e.g., pytest, freezegun, response).
# Any dependencies that do not meet that criteria will be removed.
pytest = "^7.3.0"
pytest-cov = "^4.0.0"
pytest-dotenv = "^0.5.2"
duckdb-engine = "^0.9.2"
pytest-watcher = "^0.2.6"
freezegun = "^1.2.2"
responses = "^0.22.0"
pytest-asyncio = "^0.23.2"
lark = "^1.1.5"
pytest-mock = "^3.10.0"
pytest-socket = "^0.6.0"
syrupy = "^4.0.2"
requests-mock = "^1.11.0"
model-providers = { path = "../model-providers", develop = true }
[tool.poetry.group.lint]
optional = true
[tool.poetry.group.lint.dependencies]
ruff = "^0.1.5"
[tool.poetry.group.codespell]
optional = true
[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"
[tool.poetry.group.dev]
optional = true
[tool.poetry.group.dev.dependencies]
jupyter = "^1.0.0"
setuptools = "^67.6.1"
model-providers = { path = "../model-providers", develop = true }
# An extra used to be able to add extended testing. # An extra used to be able to add extended testing.
# Please use new-line on formatting to make it easier to add new packages without # Please use new-line on formatting to make it easier to add new packages without
# merge-conflicts # merge-conflicts
@ -194,6 +147,53 @@ extended_testing = [
"friendli-client" "friendli-client"
] ]
[tool.poetry.group.test]
optional = true
[tool.poetry.group.test.dependencies]
# The only dependencies that should be added are
# dependencies used for running tests (e.g., pytest, freezegun, response).
# Any dependencies that do not meet that criteria will be removed.
pytest = "^7.3.0"
pytest-cov = "^4.0.0"
pytest-dotenv = "^0.5.2"
duckdb-engine = "^0.9.2"
pytest-watcher = "^0.2.6"
freezegun = "^1.2.2"
responses = "^0.22.0"
pytest-asyncio = "^0.23.2"
lark = "^1.1.5"
pytest-mock = "^3.10.0"
pytest-socket = "^0.6.0"
syrupy = "^4.0.2"
requests-mock = "^1.11.0"
model-providers = { path = "../model-providers", develop = true }
[tool.poetry.group.lint]
optional = true
[tool.poetry.group.lint.dependencies]
ruff = "^0.1.5"
[tool.poetry.group.codespell]
optional = true
[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"
[tool.poetry.group.dev]
optional = true
[tool.poetry.group.dev.dependencies]
jupyter = "^1.0.0"
setuptools = "^67.6.1"
model-providers = { path = "../model-providers", develop = true }
[tool.ruff] [tool.ruff]
exclude = [ exclude = [
"tests/examples/non-utf8-encoding.py", "tests/examples/non-utf8-encoding.py",