优化知识库相关功能 (#4153)

- 新功能
    - pypi 包新增 chatchat-kb 命令脚本,对应 init_database.py 功能

- 开发者
    - _model_config.py 中默认包含 xinference 配置项
    - 所有涉及向量库的操作,前置检查当前 Embed 模型是否可用
    - /knowledge_base/create_knowledge_base 接口增加 kb_info 参数
    - /knowledge_base/list_files 接口返回所有数据库字段,而非文件名称列表
    - 修正 xinference 模型管理脚本
This commit is contained in:
liunux4odoo 2024-06-08 14:34:50 +08:00 committed by GitHub
parent 8994b25a77
commit a5b203170b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 197 additions and 124 deletions

View File

@ -29,6 +29,11 @@ vim model_providers.yaml
>
> 详细配置请参考[README.md](../model-providers/README.md)
- 初始化知识库
```shell
chatchat-kb -r
```
- 启动服务
```shell
chatchat -a

View File

@ -118,6 +118,25 @@ MODEL_PLATFORMS = [
"tts_models": [],
},
{
"platform_name": "xinference",
"platform_type": "xinference",
"api_base_url": "http://127.0.0.1:9997/v1",
"api_key": "EMPTY",
"api_concurrencies": 5,
"llm_models": [
"glm-4",
"qwen2-instruct",
"qwen1.5-chat",
],
"embed_models": [
"bge-large-zh-v1.5",
],
"image_models": [],
"reranking_models": [],
"speech2text_models": [],
"tts_models": [],
},
]

View File

@ -20,16 +20,16 @@
xinference:
model_credential:
- model: 'chatglm3-6b'
- model: 'glm-4'
model_type: 'llm'
model_credentials:
server_url: 'http://127.0.0.1:9997/'
model_uid: 'chatglm3-6b'
- model: 'Qwen1.5-14B-Chat'
model_uid: 'glm-4'
- model: 'qwen1.5-chat'
model_type: 'llm'
model_credentials:
server_url: 'http://127.0.0.1:9997/'
model_uid: 'Qwen1.5-14B-Chat'
model_uid: 'qwen1.5-chat'
- model: 'bge-large-zh-v1.5'
model_type: 'embeddings'
model_credentials:
@ -46,4 +46,4 @@ xinference:
# model_type: 'llm'
# model_credentials:
# base_url: 'http://172.21.192.1:11434'
# mode: 'completion'
# mode: 'completion'

View File

@ -1,13 +1,12 @@
# Description: 初始化数据库,包括创建表、导入数据、更新向量空间等操作
from datetime import datetime
import multiprocessing as mp
from typing import Dict
from chatchat.server.knowledge_base.migrate import (create_tables, reset_tables, import_from_db,
folder2db, prune_db_docs, prune_folder_files)
from chatchat.configs import DEFAULT_EMBEDDING_MODEL, MODEL_PLATFORMS
import multiprocessing as mp
import logging
logger = logging.getLogger(__name__)
from chatchat.configs import DEFAULT_EMBEDDING_MODEL, MODEL_PLATFORMS, logger
from datetime import datetime
def run_init_model_provider(
@ -34,7 +33,7 @@ def run_init_model_provider(
provider_port=provider_port)
if __name__ == "__main__":
def main():
import argparse
parser = argparse.ArgumentParser(description="please specify only one operate method once time.")
@ -186,3 +185,7 @@ if __name__ == "__main__":
for p in processes.values():
logger.info("Process status: %s", p)
if __name__ == "__main__":
main()

View File

@ -14,6 +14,7 @@ def list_kbs():
def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
vector_store_type: str = Body("faiss"),
kb_info: str = Body("", description="知识库内容简介用于Agent选择知识库。"),
embed_model: str = Body(DEFAULT_EMBEDDING_MODEL),
) -> BaseResponse:
# Create selected knowledge base
@ -26,7 +27,7 @@ def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
if kb is not None:
return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model)
kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model, kb_info=kb_info)
try:
kb.create_kb()
except Exception as e:

View File

@ -62,7 +62,7 @@ class CachePool:
self._cache_num = cache_num
self._cache = OrderedDict()
self.atomic = threading.RLock()
def keys(self) -> List[str]:
return list(self._cache.keys())

View File

@ -95,6 +95,7 @@ class KBFaissPool(_FaissPool):
embed_model: str = DEFAULT_EMBEDDING_MODEL,
) -> ThreadSafeFaiss:
self.atomic.acquire()
locked = True
vector_name = vector_name or embed_model
cache = self.get((kb_name, vector_name)) # 用元组比拼接字符串好一些
try:
@ -103,6 +104,7 @@ class KBFaissPool(_FaissPool):
self.set((kb_name, vector_name), item)
with item.acquire(msg="初始化"):
self.atomic.release()
locked = False
logger.info(f"loading vector store in '{kb_name}/vector_store/{vector_name}' from disk.")
vs_path = get_vs_path(kb_name, vector_name)
@ -121,8 +123,10 @@ class KBFaissPool(_FaissPool):
item.finish_loading()
else:
self.atomic.release()
locked = False
except Exception as e:
self.atomic.release()
if locked: # we don't know exception raised before or after atomic.release
self.atomic.release()
logger.error(e)
raise RuntimeError(f"向量库 {kb_name} 加载失败。")
return self.get((kb_name, vector_name))

View File

@ -1,21 +1,23 @@
import json
import os
import urllib
from typing import List, Dict
from fastapi import File, Form, Body, Query, UploadFile
from fastapi.responses import FileResponse
from langchain.docstore.document import Document
from sse_starlette import EventSourceResponse
from chatchat.configs import (DEFAULT_VS_TYPE, DEFAULT_EMBEDDING_MODEL,
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE,
logger, log_verbose, )
from chatchat.server.utils import BaseResponse, ListResponse, run_in_thread_pool
from chatchat.server.db.repository.knowledge_file_repository import get_file_detail
from chatchat.server.knowledge_base.utils import (validate_kb_name, list_files_from_folder, get_file_path,
files2docs_in_thread, KnowledgeFile)
from fastapi.responses import FileResponse
from sse_starlette import EventSourceResponse
import json
from chatchat.server.knowledge_base.kb_service.base import KBServiceFactory
from chatchat.server.db.repository.knowledge_file_repository import get_file_detail
from langchain.docstore.document import Document
from chatchat.server.knowledge_base.kb_service.base import KBServiceFactory, get_kb_file_details
from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId
from typing import List, Dict
from chatchat.server.utils import BaseResponse, ListResponse, run_in_thread_pool, check_embed_model
def search_docs(
@ -55,8 +57,8 @@ def list_files(
if kb is None:
return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
else:
all_doc_names = kb.list_files()
return ListResponse(data=all_doc_names)
all_docs = get_kb_file_details(knowledge_base_name)
return ListResponse(data=all_docs)
def _save_files_in_thread(files: List[UploadFile],
@ -352,38 +354,42 @@ def recreate_vector_store(
if not kb.exists() and not allow_empty_kb:
yield {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}"}
else:
if kb.exists():
kb.clear_vs()
kb.create_kb()
files = list_files_from_folder(knowledge_base_name)
kb_files = [(file, knowledge_base_name) for file in files]
i = 0
for status, result in files2docs_in_thread(kb_files,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
zh_title_enhance=zh_title_enhance):
if status:
kb_name, file_name, docs = result
kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=kb_name)
kb_file.splited_docs = docs
yield json.dumps({
"code": 200,
"msg": f"({i + 1} / {len(files)}): {file_name}",
"total": len(files),
"finished": i + 1,
"doc": file_name,
}, ensure_ascii=False)
kb.add_doc(kb_file, not_refresh_vs_cache=True)
else:
kb_name, file_name, error = result
msg = f"添加文件‘{file_name}’到知识库‘{knowledge_base_name}’时出错:{error}。已跳过。"
logger.error(msg)
yield json.dumps({
"code": 500,
"msg": msg,
})
i += 1
if not not_refresh_vs_cache:
kb.save_vector_store()
error_msg = f"could not recreate vector store because failed to access embed model."
if not kb.check_embed_model(error_msg):
yield {"code": 404, "msg": error_msg}
else:
if kb.exists():
kb.clear_vs()
kb.create_kb()
files = list_files_from_folder(knowledge_base_name)
kb_files = [(file, knowledge_base_name) for file in files]
i = 0
for status, result in files2docs_in_thread(kb_files,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
zh_title_enhance=zh_title_enhance):
if status:
kb_name, file_name, docs = result
kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=kb_name)
kb_file.splited_docs = docs
yield json.dumps({
"code": 200,
"msg": f"({i + 1} / {len(files)}): {file_name}",
"total": len(files),
"finished": i + 1,
"doc": file_name,
}, ensure_ascii=False)
kb.add_doc(kb_file, not_refresh_vs_cache=True)
else:
kb_name, file_name, error = result
msg = f"添加文件‘{file_name}’到知识库‘{knowledge_base_name}’时出错:{error}。已跳过。"
logger.error(msg)
yield json.dumps({
"code": 500,
"msg": msg,
})
i += 1
if not not_refresh_vs_cache:
kb.save_vector_store()
return EventSourceResponse(output())

View File

@ -5,6 +5,11 @@ import os
from pathlib import Path
from langchain.docstore.document import Document
from typing import List, Union, Dict, Optional, Tuple
from chatchat.configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
DEFAULT_EMBEDDING_MODEL, KB_INFO, logger)
from chatchat.server.db.models.knowledge_base_model import KnowledgeBaseSchema
from chatchat.server.db.repository.knowledge_base_repository import (
add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists,
load_kb_from_db, get_kb_detail,
@ -14,18 +19,12 @@ from chatchat.server.db.repository.knowledge_file_repository import (
count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db,
list_docs_from_db,
)
from chatchat.configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
DEFAULT_EMBEDDING_MODEL, KB_INFO)
from chatchat.server.knowledge_base.utils import (
get_kb_path, get_doc_path, KnowledgeFile,
list_kbs_from_folder, list_files_from_folder,
)
from typing import List, Union, Dict, Optional, Tuple
from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId
from chatchat.server.db.models.knowledge_base_model import KnowledgeBaseSchema
from chatchat.server.utils import check_embed_model as _check_embed_model
class SupportedVSType:
FAISS = 'faiss'
@ -41,10 +40,11 @@ class KBService(ABC):
def __init__(self,
knowledge_base_name: str,
kb_info: str = None,
embed_model: str = DEFAULT_EMBEDDING_MODEL,
):
self.kb_name = knowledge_base_name
self.kb_info = KB_INFO.get(knowledge_base_name, f"关于{knowledge_base_name}的知识库")
self.kb_info = kb_info or KB_INFO.get(knowledge_base_name, f"关于{knowledge_base_name}的知识库")
self.embed_model = embed_model
self.kb_path = get_kb_path(self.kb_name)
self.doc_path = get_doc_path(self.kb_name)
@ -59,6 +59,13 @@ class KBService(ABC):
'''
pass
def check_embed_model(self, error_msg: str) -> bool:
if not _check_embed_model(self.embed_model):
logger.error(error_msg, exc_info=True)
return False
else:
return True
def create_kb(self):
"""
创建知识库
@ -93,6 +100,9 @@ class KBService(ABC):
向知识库添加文件
如果指定了docs则不再将文本向量化并将数据库对应条目标为custom_docs=True
"""
if not self.check_embed_model(f"could not add docs because failed to access embed model."):
return False
if docs:
custom_docs = True
else:
@ -143,6 +153,9 @@ class KBService(ABC):
使用content中的文件更新向量库
如果指定了docs则使用自定义docs并将数据库对应条目标为custom_docs=True
"""
if not self.check_embed_model(f"could not update docs because failed to access embed model."):
return False
if os.path.exists(kb_file.filepath):
self.delete_doc(kb_file, **kwargs)
return self.add_doc(kb_file, docs=docs, **kwargs)
@ -162,6 +175,8 @@ class KBService(ABC):
top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: float = SCORE_THRESHOLD,
) ->List[Document]:
if not self.check_embed_model(f"could not search docs because failed to access embed model."):
return []
docs = self.do_search(query, top_k, score_threshold)
return docs
@ -176,6 +191,9 @@ class KBService(ABC):
传入参数为 {doc_id: Document, ...}
如果对应 doc_id 的值为 None或其 page_content 为空则删除该文档
'''
if not self.check_embed_model(f"could not update docs because failed to access embed model."):
return False
self.del_doc_by_ids(list(docs.keys()))
docs = []
ids = []
@ -282,31 +300,32 @@ class KBServiceFactory:
def get_service(kb_name: str,
vector_store_type: Union[str, SupportedVSType],
embed_model: str = DEFAULT_EMBEDDING_MODEL,
kb_info: str = None,
) -> KBService:
if isinstance(vector_store_type, str):
vector_store_type = getattr(SupportedVSType, vector_store_type.upper())
params = {"knowledge_base_name": kb_name, "embed_model": embed_model, "kb_info": kb_info}
if SupportedVSType.FAISS == vector_store_type:
from chatchat.server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
return FaissKBService(kb_name, embed_model=embed_model)
return FaissKBService(**params)
elif SupportedVSType.PG == vector_store_type:
from chatchat.server.knowledge_base.kb_service.pg_kb_service import PGKBService
return PGKBService(kb_name, embed_model=embed_model)
return PGKBService(**params)
elif SupportedVSType.MILVUS == vector_store_type:
from chatchat.server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
return MilvusKBService(kb_name, embed_model=embed_model)
return MilvusKBService(**params)
elif SupportedVSType.ZILLIZ == vector_store_type:
from chatchat.server.knowledge_base.kb_service.zilliz_kb_service import ZillizKBService
return ZillizKBService(kb_name, embed_model=embed_model)
return ZillizKBService(**params)
elif SupportedVSType.DEFAULT == vector_store_type:
from chatchat.server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
return MilvusKBService(kb_name,
embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config
return MilvusKBService(**params) # other milvus parameters are set in model_config.kbs_config
elif SupportedVSType.ES == vector_store_type:
from chatchat.server.knowledge_base.kb_service.es_kb_service import ESKBService
return ESKBService(kb_name, embed_model=embed_model)
return ESKBService(**params)
elif SupportedVSType.CHROMADB == vector_store_type:
from chatchat.server.knowledge_base.kb_service.chromadb_kb_service import ChromaKBService
return ChromaKBService(kb_name, embed_model=embed_model)
return ChromaKBService(**params)
elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier.
from chatchat.server.knowledge_base.kb_service.default_kb_service import DefaultKBService
return DefaultKBService(kb_name)

View File

@ -42,55 +42,59 @@ def recreate_summary_vector_store(
if not kb.exists() and not allow_empty_kb:
yield {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}"}
else:
# 重新创建知识库
kb_summary = KBSummaryService(knowledge_base_name, embed_model)
kb_summary.drop_kb_summary()
kb_summary.create_kb_summary()
error_msg = f"could not recreate summary vector store because failed to access embed model."
if not kb.check_embed_model(error_msg):
yield {"code": 404, "msg": error_msg}
else:
# 重新创建知识库
kb_summary = KBSummaryService(knowledge_base_name, embed_model)
kb_summary.drop_kb_summary()
kb_summary.create_kb_summary()
llm = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
local_wrap=True,
)
reduce_llm = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
local_wrap=True,
)
# 文本摘要适配器
summary = SummaryAdapter.form_summary(llm=llm,
reduce_llm=reduce_llm,
overlap_size=OVERLAP_SIZE)
files = list_files_from_folder(knowledge_base_name)
llm = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
local_wrap=True,
)
reduce_llm = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
local_wrap=True,
)
# 文本摘要适配器
summary = SummaryAdapter.form_summary(llm=llm,
reduce_llm=reduce_llm,
overlap_size=OVERLAP_SIZE)
files = list_files_from_folder(knowledge_base_name)
i = 0
for i, file_name in enumerate(files):
i = 0
for i, file_name in enumerate(files):
doc_infos = kb.list_docs(file_name=file_name)
docs = summary.summarize(file_description=file_description,
docs=doc_infos)
doc_infos = kb.list_docs(file_name=file_name)
docs = summary.summarize(file_description=file_description,
docs=doc_infos)
status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
if status_kb_summary:
logger.info(f"({i + 1} / {len(files)}): {file_name} 总结完成")
yield json.dumps({
"code": 200,
"msg": f"({i + 1} / {len(files)}): {file_name}",
"total": len(files),
"finished": i + 1,
"doc": file_name,
}, ensure_ascii=False)
else:
status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
if status_kb_summary:
logger.info(f"({i + 1} / {len(files)}): {file_name} 总结完成")
yield json.dumps({
"code": 200,
"msg": f"({i + 1} / {len(files)}): {file_name}",
"total": len(files),
"finished": i + 1,
"doc": file_name,
}, ensure_ascii=False)
else:
msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
logger.error(msg)
yield json.dumps({
"code": 500,
"msg": msg,
})
i += 1
msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
logger.error(msg)
yield json.dumps({
"code": 500,
"msg": msg,
})
i += 1
return EventSourceResponse(output())

View File

@ -1,6 +1,6 @@
from chatchat.configs import (
DEFAULT_EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE,
CHUNK_SIZE, OVERLAP_SIZE
CHUNK_SIZE, OVERLAP_SIZE, logger, log_verbose
)
from chatchat.server.knowledge_base.utils import (
get_file_path, list_kbs_from_folder,

View File

@ -244,6 +244,16 @@ def get_Embeddings(
logger.error(f"failed to create Embeddings for model: {embed_model}.", exc_info=True)
def check_embed_model(embed_model: str=DEFAULT_EMBEDDING_MODEL) -> bool:
embeddings = get_Embeddings(embed_model=embed_model)
try:
embeddings.embed_query("this is a test")
return True
except Exception as e:
logger.error(f"failed to access embed model '{embed_model}': {e}", exc_info=True)
return False
def get_OpenAIClient(
platform_name: str = None,
model_name: str = None,
@ -696,7 +706,7 @@ def get_temp_dir(id: str = None) -> Tuple[str, str]:
'''
创建一个临时目录返回路径文件夹名称
'''
from chatchat.configs.basic_config import BASE_TEMP_DIR
from chatchat.configs import BASE_TEMP_DIR
import uuid
if id is not None: # 如果指定的临时目录已存在,直接返回

View File

@ -10,6 +10,7 @@ packages = [
[tool.poetry.scripts]
chatchat = 'chatchat.startup:main'
chatchat-kb = 'chatchat.init_database:main'
[tool.poetry.dependencies]
python = ">=3.8.1,<3.12,!=3.9.7"

View File

@ -133,9 +133,11 @@ model_names = list(regs[model_type].keys())
model_name = cols[1].selectbox("模型名称:", model_names)
cur_reg = regs[model_type][model_name]["reg"]
model_format = None
model_quant = None
if model_type == "LLM":
cur_family = xf_llm.LLMFamilyV1.parse_obj(cur_reg)
cur_family = xf_llm.LLMFamilyV1.model_validate(cur_reg)
cur_spec = None
model_formats = []
for spec in cur_reg["model_specs"]:
@ -217,7 +219,7 @@ cols = st.columns(3)
if cols[0].button("设置模型缓存"):
if os.path.isabs(local_path) and os.path.isdir(local_path):
cur_spec.model_uri = local_path
cur_spec.__dict__["model_uri"] = local_path # embedding spec has no attribute model_uri
if os.path.isdir(cache_dir):
os.rmdir(cache_dir)
if model_type == "LLM":
@ -250,10 +252,10 @@ if cols[2].button("注册为自定义模型"):
if model_type == "LLM":
cur_family.model_name = f"{cur_family.model_name}{model_name_suffix}"
cur_family.model_family = "other"
model_definition = cur_family.json(indent=2, ensure_ascii=False)
model_definition = cur_family.model_dump_json(indent=2, ensure_ascii=False)
else:
cur_spec.model_name = f"{cur_spec.model_name}{model_name_suffix}"
model_definition = cur_spec.json(indent=2, ensure_ascii=False)
model_definition = cur_spec.model_dump_json(indent=2, ensure_ascii=False)
client.register_model(
model_type=model_type,
model=model_definition,
@ -262,4 +264,3 @@ if cols[2].button("注册为自定义模型"):
st.rerun()
else:
st.error("必须输入存在的绝对路径")