Improve knowledge base related features (#4153)

- New features
    - The pypi package adds a chatchat-kb command script, matching the functionality of init_database.py

- Developers
    - _model_config.py now includes an xinference platform entry by default
    - Every operation that touches a vector store first checks that the current embedding model is reachable
    - The /knowledge_base/create_knowledge_base endpoint gains a kb_info parameter (exercised in the client sketch below)
    - The /knowledge_base/list_files endpoint returns the full database record for each file instead of a bare list of file names
    - Fix the xinference model management script
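
The two endpoint changes can be exercised from a client as in the hedged sketch below (assumptions: a chatchat API server reachable at http://127.0.0.1:7861, the project's usual default port; the request shapes follow the Body/Query parameters visible in the diffs further down):

```python
# Hypothetical client-side sketch, not part of this commit.
# Assumes a running chatchat API server at http://127.0.0.1:7861.
import requests

BASE_URL = "http://127.0.0.1:7861"

# create_knowledge_base now accepts kb_info, a short description used
# by an Agent to pick the right knowledge base.
resp = requests.post(f"{BASE_URL}/knowledge_base/create_knowledge_base", json={
    "knowledge_base_name": "samples",
    "vector_store_type": "faiss",
    "kb_info": "sample documents for smoke testing",
})
print(resp.json())

# list_files now returns the full database record per file instead of
# a bare list of file names.
resp = requests.get(f"{BASE_URL}/knowledge_base/list_files",
                    params={"knowledge_base_name": "samples"})
print(resp.json()["data"])
```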
liunux4odoo 2024-06-08 14:34:50 +08:00 committed by GitHub
parent 8994b25a77
commit a5b203170b
14 changed files with 197 additions and 124 deletions

View File

@@ -29,6 +29,11 @@ vim model_providers.yaml
 >
 > 详细配置请参考[README.md](../model-providers/README.md)
+- 初始化知识库
+  ```shell
+  chatchat-kb -r
+  ```
+
 - 启动服务
   ```shell
   chatchat -a

View File

@@ -118,6 +118,25 @@ MODEL_PLATFORMS = [
         "tts_models": [],
     },
+    {
+        "platform_name": "xinference",
+        "platform_type": "xinference",
+        "api_base_url": "http://127.0.0.1:9997/v1",
+        "api_key": "EMPTY",
+        "api_concurrencies": 5,
+        "llm_models": [
+            "glm-4",
+            "qwen2-instruct",
+            "qwen1.5-chat",
+        ],
+        "embed_models": [
+            "bge-large-zh-v1.5",
+        ],
+        "image_models": [],
+        "reranking_models": [],
+        "speech2text_models": [],
+        "tts_models": [],
+    },
 ]
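
To confirm the new default platform entry points at a live server, one can hit Xinference's OpenAI-compatible model listing (a hedged sketch; the URL mirrors the api_base_url configured above and assumes the server is already running):

```python
# Sanity-check sketch (not part of this commit): list the models served
# at the OpenAI-compatible endpoint configured above.
import requests

resp = requests.get("http://127.0.0.1:9997/v1/models", timeout=5)
resp.raise_for_status()
print([m["id"] for m in resp.json().get("data", [])])
```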

View File

@@ -20,16 +20,16 @@
 xinference:
   model_credential:
-    - model: 'chatglm3-6b'
+    - model: 'glm-4'
       model_type: 'llm'
       model_credentials:
         server_url: 'http://127.0.0.1:9997/'
-        model_uid: 'chatglm3-6b'
+        model_uid: 'glm-4'
-    - model: 'Qwen1.5-14B-Chat'
+    - model: 'qwen1.5-chat'
      model_type: 'llm'
       model_credentials:
         server_url: 'http://127.0.0.1:9997/'
-        model_uid: 'Qwen1.5-14B-Chat'
+        model_uid: 'qwen1.5-chat'
     - model: 'bge-large-zh-v1.5'
       model_type: 'embeddings'
       model_credentials:

@@ -46,4 +46,4 @@ xinference:
 #      model_type: 'llm'
 #      model_credentials:
 #        base_url: 'http://172.21.192.1:11434'
-#        mode: 'completion'
+#        mode: 'completion'

View File

@@ -1,13 +1,12 @@
 # Description: 初始化数据库,包括创建表、导入数据、更新向量空间等操作
+from datetime import datetime
+import multiprocessing as mp
 from typing import Dict
+
 from chatchat.server.knowledge_base.migrate import (create_tables, reset_tables, import_from_db,
                                                     folder2db, prune_db_docs, prune_folder_files)
-from chatchat.configs import DEFAULT_EMBEDDING_MODEL, MODEL_PLATFORMS
+from chatchat.configs import DEFAULT_EMBEDDING_MODEL, MODEL_PLATFORMS, logger
-import multiprocessing as mp
-import logging
-logger = logging.getLogger(__name__)
-from datetime import datetime
 
 
 def run_init_model_provider(

@@ -34,7 +33,7 @@ def run_init_model_provider(
                             provider_port=provider_port)
 
 
-if __name__ == "__main__":
+def main():
     import argparse
 
     parser = argparse.ArgumentParser(description="please specify only one operate method once time.")

@@ -186,3 +185,7 @@ if __name__ == "__main__":
     for p in processes.values():
         logger.info("Process status: %s", p)
+
+
+if __name__ == "__main__":
+    main()

View File

@@ -14,6 +14,7 @@ def list_kbs():
 def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
               vector_store_type: str = Body("faiss"),
+              kb_info: str = Body("", description="知识库内容简介,用于Agent选择知识库。"),
               embed_model: str = Body(DEFAULT_EMBEDDING_MODEL),
               ) -> BaseResponse:
     # Create selected knowledge base

@@ -26,7 +27,7 @@ def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
     if kb is not None:
         return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
 
-    kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model)
+    kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model, kb_info=kb_info)
     try:
         kb.create_kb()
     except Exception as e:

View File

@@ -62,7 +62,7 @@ class CachePool:
         self._cache_num = cache_num
         self._cache = OrderedDict()
         self.atomic = threading.RLock()
 
     def keys(self) -> List[str]:
         return list(self._cache.keys())

View File

@@ -95,6 +95,7 @@ class KBFaissPool(_FaissPool):
                  embed_model: str = DEFAULT_EMBEDDING_MODEL,
                  ) -> ThreadSafeFaiss:
         self.atomic.acquire()
+        locked = True
         vector_name = vector_name or embed_model
         cache = self.get((kb_name, vector_name))  # 用元组比拼接字符串好一些
         try:

@@ -103,6 +104,7 @@
                 self.set((kb_name, vector_name), item)
                 with item.acquire(msg="初始化"):
                     self.atomic.release()
+                    locked = False
                     logger.info(f"loading vector store in '{kb_name}/vector_store/{vector_name}' from disk.")
                     vs_path = get_vs_path(kb_name, vector_name)

@@ -121,8 +123,10 @@
                 item.finish_loading()
             else:
                 self.atomic.release()
+                locked = False
         except Exception as e:
-            self.atomic.release()
+            if locked:  # we don't know whether the exception was raised before or after atomic.release
+                self.atomic.release()
             logger.error(e)
             raise RuntimeError(f"向量库 {kb_name} 加载失败。")
         return self.get((kb_name, vector_name))
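
The `locked` flag fixes a double-release bug: the old code always called `atomic.release()` in the except branch, even when the exception was raised after the lock had already been released, and releasing an unheld `RLock` raises a `RuntimeError` that masks the original error. A minimal standalone sketch of the pattern, with hypothetical names (illustrative only, not the project's code):

```python
# Standalone sketch of the "locked flag" pattern used above: remember
# whether the lock is still held so the error path never releases twice.
import threading

lock = threading.RLock()

def load_resource(load, post_process):
    lock.acquire()
    locked = True
    try:
        resource = load()       # may raise while the lock is still held
        lock.release()          # release early so other threads can proceed
        locked = False
        post_process(resource)  # may raise after the lock is already gone
        return resource
    except Exception:
        if locked:              # release only if this thread still holds it
            lock.release()
        raise
```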

View File

@@ -1,21 +1,23 @@
+import json
 import os
 import urllib
+from typing import List, Dict
+
 from fastapi import File, Form, Body, Query, UploadFile
+from fastapi.responses import FileResponse
+from langchain.docstore.document import Document
+from sse_starlette import EventSourceResponse
+
 from chatchat.configs import (DEFAULT_VS_TYPE, DEFAULT_EMBEDDING_MODEL,
                               VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
                               CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE,
                               logger, log_verbose, )
-from chatchat.server.utils import BaseResponse, ListResponse, run_in_thread_pool
+from chatchat.server.db.repository.knowledge_file_repository import get_file_detail
 from chatchat.server.knowledge_base.utils import (validate_kb_name, list_files_from_folder, get_file_path,
                                                   files2docs_in_thread, KnowledgeFile)
-from fastapi.responses import FileResponse
-from sse_starlette import EventSourceResponse
-import json
-from chatchat.server.knowledge_base.kb_service.base import KBServiceFactory
-from chatchat.server.db.repository.knowledge_file_repository import get_file_detail
-from langchain.docstore.document import Document
+from chatchat.server.knowledge_base.kb_service.base import KBServiceFactory, get_kb_file_details
 from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId
-from typing import List, Dict
+from chatchat.server.utils import BaseResponse, ListResponse, run_in_thread_pool, check_embed_model
 
 
 def search_docs(

@@ -55,8 +57,8 @@ def list_files(
     if kb is None:
         return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
     else:
-        all_doc_names = kb.list_files()
-        return ListResponse(data=all_doc_names)
+        all_docs = get_kb_file_details(knowledge_base_name)
+        return ListResponse(data=all_docs)

@@ -352,38 +354,42 @@ def recreate_vector_store(
        if not kb.exists() and not allow_empty_kb:
            yield {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}"}
        else:
+           error_msg = f"could not recreate vector store because failed to access embed model."
+           if not kb.check_embed_model(error_msg):
+               yield {"code": 404, "msg": error_msg}
+           else:
                if kb.exists():
                    kb.clear_vs()
                kb.create_kb()
                files = list_files_from_folder(knowledge_base_name)
                kb_files = [(file, knowledge_base_name) for file in files]
                i = 0
                for status, result in files2docs_in_thread(kb_files,
                                                           chunk_size=chunk_size,
                                                           chunk_overlap=chunk_overlap,
                                                           zh_title_enhance=zh_title_enhance):
                    if status:
                        kb_name, file_name, docs = result
                        kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=kb_name)
                        kb_file.splited_docs = docs
                        yield json.dumps({
                            "code": 200,
                            "msg": f"({i + 1} / {len(files)}): {file_name}",
                            "total": len(files),
                            "finished": i + 1,
                            "doc": file_name,
                        }, ensure_ascii=False)
                        kb.add_doc(kb_file, not_refresh_vs_cache=True)
                    else:
                        kb_name, file_name, error = result
                        msg = f"添加文件‘{file_name}’到知识库‘{knowledge_base_name}’时出错:{error}。已跳过。"
                        logger.error(msg)
                        yield json.dumps({
                            "code": 500,
                            "msg": msg,
                        })
                    i += 1
                if not not_refresh_vs_cache:
                    kb.save_vector_store()
 
    return EventSourceResponse(output())

View File

@@ -5,6 +5,11 @@ import os
 from pathlib import Path
 
 from langchain.docstore.document import Document
+from typing import List, Union, Dict, Optional, Tuple
+
+from chatchat.configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
+                              DEFAULT_EMBEDDING_MODEL, KB_INFO, logger)
+from chatchat.server.db.models.knowledge_base_model import KnowledgeBaseSchema
 from chatchat.server.db.repository.knowledge_base_repository import (
     add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists,
     load_kb_from_db, get_kb_detail,

@@ -14,18 +19,12 @@ from chatchat.server.db.repository.knowledge_file_repository import (
     count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db,
     list_docs_from_db,
 )
-from chatchat.configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
-                              DEFAULT_EMBEDDING_MODEL, KB_INFO)
 from chatchat.server.knowledge_base.utils import (
     get_kb_path, get_doc_path, KnowledgeFile,
     list_kbs_from_folder, list_files_from_folder,
 )
-from typing import List, Union, Dict, Optional, Tuple
 from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId
-from chatchat.server.db.models.knowledge_base_model import KnowledgeBaseSchema
+from chatchat.server.utils import check_embed_model as _check_embed_model
 
 
 class SupportedVSType:
     FAISS = 'faiss'

@@ -41,10 +40,11 @@ class KBService(ABC):
     def __init__(self,
                  knowledge_base_name: str,
+                 kb_info: str = None,
                  embed_model: str = DEFAULT_EMBEDDING_MODEL,
                  ):
         self.kb_name = knowledge_base_name
-        self.kb_info = KB_INFO.get(knowledge_base_name, f"关于{knowledge_base_name}的知识库")
+        self.kb_info = kb_info or KB_INFO.get(knowledge_base_name, f"关于{knowledge_base_name}的知识库")
         self.embed_model = embed_model
         self.kb_path = get_kb_path(self.kb_name)
         self.doc_path = get_doc_path(self.kb_name)

@@ -59,6 +59,13 @@
         '''
         pass
 
+    def check_embed_model(self, error_msg: str) -> bool:
+        if not _check_embed_model(self.embed_model):
+            logger.error(error_msg, exc_info=True)
+            return False
+        else:
+            return True
+
     def create_kb(self):
         """
         创建知识库
         """

@@ -93,6 +100,9 @@
         向知识库添加文件
         如果指定了docs,则不再将文本向量化,并将数据库对应条目标为custom_docs=True
         """
+        if not self.check_embed_model(f"could not add docs because failed to access embed model."):
+            return False
+
         if docs:
             custom_docs = True
         else:

@@ -143,6 +153,9 @@
         使用content中的文件更新向量库
         如果指定了docs,则使用自定义docs,并将数据库对应条目标为custom_docs=True
         """
+        if not self.check_embed_model(f"could not update docs because failed to access embed model."):
+            return False
+
         if os.path.exists(kb_file.filepath):
             self.delete_doc(kb_file, **kwargs)
             return self.add_doc(kb_file, docs=docs, **kwargs)

@@ -162,6 +175,8 @@
                    top_k: int = VECTOR_SEARCH_TOP_K,
                    score_threshold: float = SCORE_THRESHOLD,
                    ) -> List[Document]:
+        if not self.check_embed_model(f"could not search docs because failed to access embed model."):
+            return []
         docs = self.do_search(query, top_k, score_threshold)
         return docs

@@ -176,6 +191,9 @@
         传入参数为 {doc_id: Document, ...}
         如果对应 doc_id 的值为 None,或其 page_content 为空,则删除该文档
         '''
+        if not self.check_embed_model(f"could not update docs because failed to access embed model."):
+            return False
+
         self.del_doc_by_ids(list(docs.keys()))
         docs = []
         ids = []

@@ -282,31 +300,32 @@ class KBServiceFactory:
     def get_service(kb_name: str,
                     vector_store_type: Union[str, SupportedVSType],
                     embed_model: str = DEFAULT_EMBEDDING_MODEL,
+                    kb_info: str = None,
                     ) -> KBService:
         if isinstance(vector_store_type, str):
             vector_store_type = getattr(SupportedVSType, vector_store_type.upper())
+        params = {"knowledge_base_name": kb_name, "embed_model": embed_model, "kb_info": kb_info}
         if SupportedVSType.FAISS == vector_store_type:
             from chatchat.server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
-            return FaissKBService(kb_name, embed_model=embed_model)
+            return FaissKBService(**params)
         elif SupportedVSType.PG == vector_store_type:
             from chatchat.server.knowledge_base.kb_service.pg_kb_service import PGKBService
-            return PGKBService(kb_name, embed_model=embed_model)
+            return PGKBService(**params)
         elif SupportedVSType.MILVUS == vector_store_type:
             from chatchat.server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
-            return MilvusKBService(kb_name, embed_model=embed_model)
+            return MilvusKBService(**params)
         elif SupportedVSType.ZILLIZ == vector_store_type:
             from chatchat.server.knowledge_base.kb_service.zilliz_kb_service import ZillizKBService
-            return ZillizKBService(kb_name, embed_model=embed_model)
+            return ZillizKBService(**params)
         elif SupportedVSType.DEFAULT == vector_store_type:
             from chatchat.server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
-            return MilvusKBService(kb_name,
-                                   embed_model=embed_model)  # other milvus parameters are set in model_config.kbs_config
+            return MilvusKBService(**params)  # other milvus parameters are set in model_config.kbs_config
         elif SupportedVSType.ES == vector_store_type:
             from chatchat.server.knowledge_base.kb_service.es_kb_service import ESKBService
-            return ESKBService(kb_name, embed_model=embed_model)
+            return ESKBService(**params)
         elif SupportedVSType.CHROMADB == vector_store_type:
             from chatchat.server.knowledge_base.kb_service.chromadb_kb_service import ChromaKBService
-            return ChromaKBService(kb_name, embed_model=embed_model)
+            return ChromaKBService(**params)
         elif SupportedVSType.DEFAULT == vector_store_type:  # kb_exists of default kbservice is False, to make validation easier.
             from chatchat.server.knowledge_base.kb_service.default_kb_service import DefaultKBService
             return DefaultKBService(kb_name)

View File

@@ -42,55 +42,59 @@ def recreate_summary_vector_store(
        if not kb.exists() and not allow_empty_kb:
            yield {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}"}
        else:
+           error_msg = f"could not recreate summary vector store because failed to access embed model."
+           if not kb.check_embed_model(error_msg):
+               yield {"code": 404, "msg": error_msg}
+           else:
                # 重新创建知识库
                kb_summary = KBSummaryService(knowledge_base_name, embed_model)
                kb_summary.drop_kb_summary()
                kb_summary.create_kb_summary()
 
                llm = get_ChatOpenAI(
                    model_name=model_name,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    local_wrap=True,
                )
                reduce_llm = get_ChatOpenAI(
                    model_name=model_name,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    local_wrap=True,
                )
                # 文本摘要适配器
                summary = SummaryAdapter.form_summary(llm=llm,
                                                      reduce_llm=reduce_llm,
                                                      overlap_size=OVERLAP_SIZE)
                files = list_files_from_folder(knowledge_base_name)
 
                i = 0
                for i, file_name in enumerate(files):
                    doc_infos = kb.list_docs(file_name=file_name)
                    docs = summary.summarize(file_description=file_description,
                                             docs=doc_infos)
                    status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
                    if status_kb_summary:
                        logger.info(f"({i + 1} / {len(files)}): {file_name} 总结完成")
                        yield json.dumps({
                            "code": 200,
                            "msg": f"({i + 1} / {len(files)}): {file_name}",
                            "total": len(files),
                            "finished": i + 1,
                            "doc": file_name,
                        }, ensure_ascii=False)
                    else:
                        msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
                        logger.error(msg)
                        yield json.dumps({
                            "code": 500,
                            "msg": msg,
                        })
                    i += 1
 
    return EventSourceResponse(output())

View File

@@ -1,6 +1,6 @@
 from chatchat.configs import (
     DEFAULT_EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE,
-    CHUNK_SIZE, OVERLAP_SIZE
+    CHUNK_SIZE, OVERLAP_SIZE, logger, log_verbose
 )
 from chatchat.server.knowledge_base.utils import (
     get_file_path, list_kbs_from_folder,

View File

@@ -244,6 +244,16 @@ def get_Embeddings(
         logger.error(f"failed to create Embeddings for model: {embed_model}.", exc_info=True)
 
 
+def check_embed_model(embed_model: str = DEFAULT_EMBEDDING_MODEL) -> bool:
+    embeddings = get_Embeddings(embed_model=embed_model)
+    try:
+        embeddings.embed_query("this is a test")
+        return True
+    except Exception as e:
+        logger.error(f"failed to access embed model '{embed_model}': {e}", exc_info=True)
+        return False
+
+
 def get_OpenAIClient(
         platform_name: str = None,
         model_name: str = None,

@@ -696,7 +706,7 @@ def get_temp_dir(id: str = None) -> Tuple[str, str]:
     '''
     创建一个临时目录,返回路径、文件夹名称
     '''
-    from chatchat.configs.basic_config import BASE_TEMP_DIR
+    from chatchat.configs import BASE_TEMP_DIR
     import uuid
 
     if id is not None:  # 如果指定的临时目录已存在,直接返回
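
check_embed_model gives callers a cheap way to fail fast before mutating a vector store; this commit wires it into KBService and the KB APIs shown above. A hedged sketch of a caller-side guard (hypothetical usage, not code from this commit):

```python
# Hypothetical caller-side guard built on the new helper above.
from chatchat.configs import DEFAULT_EMBEDDING_MODEL
from chatchat.server.utils import check_embed_model

if not check_embed_model(DEFAULT_EMBEDDING_MODEL):
    raise RuntimeError(
        f"embed model '{DEFAULT_EMBEDDING_MODEL}' is not reachable; "
        "start the model platform (e.g. xinference) before rebuilding vector stores."
    )
```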

View File

@@ -10,6 +10,7 @@ packages = [
 [tool.poetry.scripts]
 chatchat = 'chatchat.startup:main'
+chatchat-kb = 'chatchat.init_database:main'
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<3.12,!=3.9.7"

View File

@@ -133,9 +133,11 @@ model_names = list(regs[model_type].keys())
 model_name = cols[1].selectbox("模型名称:", model_names)
 cur_reg = regs[model_type][model_name]["reg"]
 
+model_format = None
+model_quant = None
 if model_type == "LLM":
-    cur_family = xf_llm.LLMFamilyV1.parse_obj(cur_reg)
+    cur_family = xf_llm.LLMFamilyV1.model_validate(cur_reg)
     cur_spec = None
     model_formats = []
     for spec in cur_reg["model_specs"]:

@@ -217,7 +219,7 @@ cols = st.columns(3)
 if cols[0].button("设置模型缓存"):
     if os.path.isabs(local_path) and os.path.isdir(local_path):
-        cur_spec.model_uri = local_path
+        cur_spec.__dict__["model_uri"] = local_path  # embedding spec has no attribute model_uri
         if os.path.isdir(cache_dir):
             os.rmdir(cache_dir)
         if model_type == "LLM":

@@ -250,10 +252,10 @@ if cols[2].button("注册为自定义模型"):
     if model_type == "LLM":
         cur_family.model_name = f"{cur_family.model_name}{model_name_suffix}"
         cur_family.model_family = "other"
-        model_definition = cur_family.json(indent=2, ensure_ascii=False)
+        model_definition = cur_family.model_dump_json(indent=2, ensure_ascii=False)
     else:
         cur_spec.model_name = f"{cur_spec.model_name}{model_name_suffix}"
-        model_definition = cur_spec.json(indent=2, ensure_ascii=False)
+        model_definition = cur_spec.model_dump_json(indent=2, ensure_ascii=False)
     client.register_model(
         model_type=model_type,
         model=model_definition,

@@ -262,4 +264,3 @@ if cols[2].button("注册为自定义模型"):
     st.rerun()
 else:
     st.error("必须输入存在的绝对路径")
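
The script fixes above track Xinference's move to pydantic v2, where `parse_obj` became `model_validate` and `.json()` became `.model_dump_json()`. A minimal generic illustration of the renamed calls (plain pydantic, not the actual xinference classes):

```python
# Generic pydantic v2 sketch of the renames applied in this commit.
from pydantic import BaseModel

class Spec(BaseModel):
    model_name: str

spec = Spec.model_validate({"model_name": "demo"})  # pydantic v1: Spec.parse_obj(...)
print(spec.model_dump_json(indent=2))               # pydantic v1: spec.json(indent=2)
```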