From a5b203170b481bf235fbdcdd517ee36b2fdfa294 Mon Sep 17 00:00:00 2001 From: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com> Date: Sat, 8 Jun 2024 14:34:50 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E7=9F=A5=E8=AF=86=E5=BA=93?= =?UTF-8?q?=E7=9B=B8=E5=85=B3=E5=8A=9F=E8=83=BD=20(#4153)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新功能 - pypi 包新增 chatchat-kb 命令脚本,对应 init_database.py 功能 - 开发者 - _model_config.py 中默认包含 xinference 配置项 - 所有涉及向量库的操作,前置检查当前 Embed 模型是否可用 - /knowledge_base/create_knowledge_base 接口增加 kb_info 参数 - /knowledge_base/list_files 接口返回所有数据库字段,而非文件名称列表 - 修正 xinference 模型管理脚本 --- libs/chatchat-server/README.md | 5 + .../chatchat/configs/_model_config.py | 19 ++++ .../chatchat/configs/model_providers.yaml | 10 +- .../chatchat-server/chatchat/init_database.py | 15 +-- .../chatchat/server/knowledge_base/kb_api.py | 3 +- .../server/knowledge_base/kb_cache/base.py | 2 +- .../knowledge_base/kb_cache/faiss_cache.py | 6 +- .../server/knowledge_base/kb_doc_api.py | 92 ++++++++++--------- .../server/knowledge_base/kb_service/base.py | 51 ++++++---- .../server/knowledge_base/kb_summary_api.py | 92 ++++++++++--------- .../chatchat/server/knowledge_base/migrate.py | 2 +- libs/chatchat-server/chatchat/server/utils.py | 12 ++- libs/chatchat-server/pyproject.toml | 1 + tools/model_loaders/xinference_manager.py | 11 ++- 14 files changed, 197 insertions(+), 124 deletions(-) diff --git a/libs/chatchat-server/README.md b/libs/chatchat-server/README.md index 4e0786c8..f2ce0940 100644 --- a/libs/chatchat-server/README.md +++ b/libs/chatchat-server/README.md @@ -29,6 +29,11 @@ vim model_providers.yaml > > 详细配置请参考[README.md](../model-providers/README.md) +- 初始化知识库 +```shell +chatchat-kb -r +``` + - 启动服务 ```shell chatchat -a diff --git a/libs/chatchat-server/chatchat/configs/_model_config.py b/libs/chatchat-server/chatchat/configs/_model_config.py index 856981e9..2762628a 100644 --- a/libs/chatchat-server/chatchat/configs/_model_config.py +++ b/libs/chatchat-server/chatchat/configs/_model_config.py @@ -118,6 +118,25 @@ MODEL_PLATFORMS = [ "tts_models": [], }, + { + "platform_name": "xinference", + "platform_type": "xinference", + "api_base_url": "http://127.0.0.1:9997/v1", + "api_key": "EMPTY", + "api_concurrencies": 5, + "llm_models": [ + "glm-4", + "qwen2-instruct", + "qwen1.5-chat", + ], + "embed_models": [ + "bge-large-zh-v1.5", + ], + "image_models": [], + "reranking_models": [], + "speech2text_models": [], + "tts_models": [], + }, ] diff --git a/libs/chatchat-server/chatchat/configs/model_providers.yaml b/libs/chatchat-server/chatchat/configs/model_providers.yaml index b47142fc..4032a21e 100644 --- a/libs/chatchat-server/chatchat/configs/model_providers.yaml +++ b/libs/chatchat-server/chatchat/configs/model_providers.yaml @@ -20,16 +20,16 @@ xinference: model_credential: - - model: 'chatglm3-6b' + - model: 'glm-4' model_type: 'llm' model_credentials: server_url: 'http://127.0.0.1:9997/' - model_uid: 'chatglm3-6b' - - model: 'Qwen1.5-14B-Chat' + model_uid: 'glm-4' + - model: 'qwen1.5-chat' model_type: 'llm' model_credentials: server_url: 'http://127.0.0.1:9997/' - model_uid: 'Qwen1.5-14B-Chat' + model_uid: 'qwen1.5-chat' - model: 'bge-large-zh-v1.5' model_type: 'embeddings' model_credentials: @@ -46,4 +46,4 @@ xinference: # model_type: 'llm' # model_credentials: # base_url: 'http://172.21.192.1:11434' -# mode: 'completion' \ No newline at end of file +# mode: 'completion' diff --git a/libs/chatchat-server/chatchat/init_database.py b/libs/chatchat-server/chatchat/init_database.py index 7a1baec7..d08ec87a 100644 --- a/libs/chatchat-server/chatchat/init_database.py +++ b/libs/chatchat-server/chatchat/init_database.py @@ -1,13 +1,12 @@ # Description: 初始化数据库,包括创建表、导入数据、更新向量空间等操作 +from datetime import datetime +import multiprocessing as mp from typing import Dict + from chatchat.server.knowledge_base.migrate import (create_tables, reset_tables, import_from_db, folder2db, prune_db_docs, prune_folder_files) -from chatchat.configs import DEFAULT_EMBEDDING_MODEL, MODEL_PLATFORMS -import multiprocessing as mp -import logging -logger = logging.getLogger(__name__) +from chatchat.configs import DEFAULT_EMBEDDING_MODEL, MODEL_PLATFORMS, logger -from datetime import datetime def run_init_model_provider( @@ -34,7 +33,7 @@ def run_init_model_provider( provider_port=provider_port) -if __name__ == "__main__": +def main(): import argparse parser = argparse.ArgumentParser(description="please specify only one operate method once time.") @@ -186,3 +185,7 @@ if __name__ == "__main__": for p in processes.values(): logger.info("Process status: %s", p) + + +if __name__ == "__main__": + main() diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_api.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_api.py index ea3d6089..0ccc0704 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_api.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_api.py @@ -14,6 +14,7 @@ def list_kbs(): def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]), vector_store_type: str = Body("faiss"), + kb_info: str = Body("", description="知识库内容简介,用于Agent选择知识库。"), embed_model: str = Body(DEFAULT_EMBEDDING_MODEL), ) -> BaseResponse: # Create selected knowledge base @@ -26,7 +27,7 @@ def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]), if kb is not None: return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}") - kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model) + kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model, kb_info=kb_info) try: kb.create_kb() except Exception as e: diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_cache/base.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_cache/base.py index a6d3f425..a11e0054 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_cache/base.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_cache/base.py @@ -62,7 +62,7 @@ class CachePool: self._cache_num = cache_num self._cache = OrderedDict() self.atomic = threading.RLock() - + def keys(self) -> List[str]: return list(self._cache.keys()) diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_cache/faiss_cache.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_cache/faiss_cache.py index d37bd29d..ec8adba2 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_cache/faiss_cache.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_cache/faiss_cache.py @@ -95,6 +95,7 @@ class KBFaissPool(_FaissPool): embed_model: str = DEFAULT_EMBEDDING_MODEL, ) -> ThreadSafeFaiss: self.atomic.acquire() + locked = True vector_name = vector_name or embed_model cache = self.get((kb_name, vector_name)) # 用元组比拼接字符串好一些 try: @@ -103,6 +104,7 @@ class KBFaissPool(_FaissPool): self.set((kb_name, vector_name), item) with item.acquire(msg="初始化"): self.atomic.release() + locked = False logger.info(f"loading vector store in '{kb_name}/vector_store/{vector_name}' from disk.") vs_path = get_vs_path(kb_name, vector_name) @@ -121,8 +123,10 @@ class KBFaissPool(_FaissPool): item.finish_loading() else: self.atomic.release() + locked = False except Exception as e: - self.atomic.release() + if locked: # we don't know exception raised before or after atomic.release + self.atomic.release() logger.error(e) raise RuntimeError(f"向量库 {kb_name} 加载失败。") return self.get((kb_name, vector_name)) diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_doc_api.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_doc_api.py index 2763028f..0e92c091 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_doc_api.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_doc_api.py @@ -1,21 +1,23 @@ +import json import os import urllib +from typing import List, Dict + from fastapi import File, Form, Body, Query, UploadFile +from fastapi.responses import FileResponse +from langchain.docstore.document import Document +from sse_starlette import EventSourceResponse + from chatchat.configs import (DEFAULT_VS_TYPE, DEFAULT_EMBEDDING_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE, logger, log_verbose, ) -from chatchat.server.utils import BaseResponse, ListResponse, run_in_thread_pool +from chatchat.server.db.repository.knowledge_file_repository import get_file_detail from chatchat.server.knowledge_base.utils import (validate_kb_name, list_files_from_folder, get_file_path, files2docs_in_thread, KnowledgeFile) -from fastapi.responses import FileResponse -from sse_starlette import EventSourceResponse -import json -from chatchat.server.knowledge_base.kb_service.base import KBServiceFactory -from chatchat.server.db.repository.knowledge_file_repository import get_file_detail -from langchain.docstore.document import Document +from chatchat.server.knowledge_base.kb_service.base import KBServiceFactory, get_kb_file_details from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId -from typing import List, Dict +from chatchat.server.utils import BaseResponse, ListResponse, run_in_thread_pool, check_embed_model def search_docs( @@ -55,8 +57,8 @@ def list_files( if kb is None: return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[]) else: - all_doc_names = kb.list_files() - return ListResponse(data=all_doc_names) + all_docs = get_kb_file_details(knowledge_base_name) + return ListResponse(data=all_docs) def _save_files_in_thread(files: List[UploadFile], @@ -352,38 +354,42 @@ def recreate_vector_store( if not kb.exists() and not allow_empty_kb: yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"} else: - if kb.exists(): - kb.clear_vs() - kb.create_kb() - files = list_files_from_folder(knowledge_base_name) - kb_files = [(file, knowledge_base_name) for file in files] - i = 0 - for status, result in files2docs_in_thread(kb_files, - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - zh_title_enhance=zh_title_enhance): - if status: - kb_name, file_name, docs = result - kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=kb_name) - kb_file.splited_docs = docs - yield json.dumps({ - "code": 200, - "msg": f"({i + 1} / {len(files)}): {file_name}", - "total": len(files), - "finished": i + 1, - "doc": file_name, - }, ensure_ascii=False) - kb.add_doc(kb_file, not_refresh_vs_cache=True) - else: - kb_name, file_name, error = result - msg = f"添加文件‘{file_name}’到知识库‘{knowledge_base_name}’时出错:{error}。已跳过。" - logger.error(msg) - yield json.dumps({ - "code": 500, - "msg": msg, - }) - i += 1 - if not not_refresh_vs_cache: - kb.save_vector_store() + error_msg = f"could not recreate vector store because failed to access embed model." + if not kb.check_embed_model(error_msg): + yield {"code": 404, "msg": error_msg} + else: + if kb.exists(): + kb.clear_vs() + kb.create_kb() + files = list_files_from_folder(knowledge_base_name) + kb_files = [(file, knowledge_base_name) for file in files] + i = 0 + for status, result in files2docs_in_thread(kb_files, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + zh_title_enhance=zh_title_enhance): + if status: + kb_name, file_name, docs = result + kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=kb_name) + kb_file.splited_docs = docs + yield json.dumps({ + "code": 200, + "msg": f"({i + 1} / {len(files)}): {file_name}", + "total": len(files), + "finished": i + 1, + "doc": file_name, + }, ensure_ascii=False) + kb.add_doc(kb_file, not_refresh_vs_cache=True) + else: + kb_name, file_name, error = result + msg = f"添加文件‘{file_name}’到知识库‘{knowledge_base_name}’时出错:{error}。已跳过。" + logger.error(msg) + yield json.dumps({ + "code": 500, + "msg": msg, + }) + i += 1 + if not not_refresh_vs_cache: + kb.save_vector_store() return EventSourceResponse(output()) diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/base.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/base.py index 7cdbf102..ac21863e 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/base.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/base.py @@ -5,6 +5,11 @@ import os from pathlib import Path from langchain.docstore.document import Document +from typing import List, Union, Dict, Optional, Tuple + +from chatchat.configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, + DEFAULT_EMBEDDING_MODEL, KB_INFO, logger) +from chatchat.server.db.models.knowledge_base_model import KnowledgeBaseSchema from chatchat.server.db.repository.knowledge_base_repository import ( add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists, load_kb_from_db, get_kb_detail, @@ -14,18 +19,12 @@ from chatchat.server.db.repository.knowledge_file_repository import ( count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db, list_docs_from_db, ) - -from chatchat.configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, - DEFAULT_EMBEDDING_MODEL, KB_INFO) from chatchat.server.knowledge_base.utils import ( get_kb_path, get_doc_path, KnowledgeFile, list_kbs_from_folder, list_files_from_folder, ) - -from typing import List, Union, Dict, Optional, Tuple - from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId -from chatchat.server.db.models.knowledge_base_model import KnowledgeBaseSchema +from chatchat.server.utils import check_embed_model as _check_embed_model class SupportedVSType: FAISS = 'faiss' @@ -41,10 +40,11 @@ class KBService(ABC): def __init__(self, knowledge_base_name: str, + kb_info: str = None, embed_model: str = DEFAULT_EMBEDDING_MODEL, ): self.kb_name = knowledge_base_name - self.kb_info = KB_INFO.get(knowledge_base_name, f"关于{knowledge_base_name}的知识库") + self.kb_info = kb_info or KB_INFO.get(knowledge_base_name, f"关于{knowledge_base_name}的知识库") self.embed_model = embed_model self.kb_path = get_kb_path(self.kb_name) self.doc_path = get_doc_path(self.kb_name) @@ -59,6 +59,13 @@ class KBService(ABC): ''' pass + def check_embed_model(self, error_msg: str) -> bool: + if not _check_embed_model(self.embed_model): + logger.error(error_msg, exc_info=True) + return False + else: + return True + def create_kb(self): """ 创建知识库 @@ -93,6 +100,9 @@ class KBService(ABC): 向知识库添加文件 如果指定了docs,则不再将文本向量化,并将数据库对应条目标为custom_docs=True """ + if not self.check_embed_model(f"could not add docs because failed to access embed model."): + return False + if docs: custom_docs = True else: @@ -143,6 +153,9 @@ class KBService(ABC): 使用content中的文件更新向量库 如果指定了docs,则使用自定义docs,并将数据库对应条目标为custom_docs=True """ + if not self.check_embed_model(f"could not update docs because failed to access embed model."): + return False + if os.path.exists(kb_file.filepath): self.delete_doc(kb_file, **kwargs) return self.add_doc(kb_file, docs=docs, **kwargs) @@ -162,6 +175,8 @@ class KBService(ABC): top_k: int = VECTOR_SEARCH_TOP_K, score_threshold: float = SCORE_THRESHOLD, ) ->List[Document]: + if not self.check_embed_model(f"could not search docs because failed to access embed model."): + return [] docs = self.do_search(query, top_k, score_threshold) return docs @@ -176,6 +191,9 @@ class KBService(ABC): 传入参数为: {doc_id: Document, ...} 如果对应 doc_id 的值为 None,或其 page_content 为空,则删除该文档 ''' + if not self.check_embed_model(f"could not update docs because failed to access embed model."): + return False + self.del_doc_by_ids(list(docs.keys())) docs = [] ids = [] @@ -282,31 +300,32 @@ class KBServiceFactory: def get_service(kb_name: str, vector_store_type: Union[str, SupportedVSType], embed_model: str = DEFAULT_EMBEDDING_MODEL, + kb_info: str = None, ) -> KBService: if isinstance(vector_store_type, str): vector_store_type = getattr(SupportedVSType, vector_store_type.upper()) + params = {"knowledge_base_name": kb_name, "embed_model": embed_model, "kb_info": kb_info} if SupportedVSType.FAISS == vector_store_type: from chatchat.server.knowledge_base.kb_service.faiss_kb_service import FaissKBService - return FaissKBService(kb_name, embed_model=embed_model) + return FaissKBService(**params) elif SupportedVSType.PG == vector_store_type: from chatchat.server.knowledge_base.kb_service.pg_kb_service import PGKBService - return PGKBService(kb_name, embed_model=embed_model) + return PGKBService(**params) elif SupportedVSType.MILVUS == vector_store_type: from chatchat.server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService - return MilvusKBService(kb_name, embed_model=embed_model) + return MilvusKBService(**params) elif SupportedVSType.ZILLIZ == vector_store_type: from chatchat.server.knowledge_base.kb_service.zilliz_kb_service import ZillizKBService - return ZillizKBService(kb_name, embed_model=embed_model) + return ZillizKBService(**params) elif SupportedVSType.DEFAULT == vector_store_type: from chatchat.server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService - return MilvusKBService(kb_name, - embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config + return MilvusKBService(**params) # other milvus parameters are set in model_config.kbs_config elif SupportedVSType.ES == vector_store_type: from chatchat.server.knowledge_base.kb_service.es_kb_service import ESKBService - return ESKBService(kb_name, embed_model=embed_model) + return ESKBService(**params) elif SupportedVSType.CHROMADB == vector_store_type: from chatchat.server.knowledge_base.kb_service.chromadb_kb_service import ChromaKBService - return ChromaKBService(kb_name, embed_model=embed_model) + return ChromaKBService(**params) elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier. from chatchat.server.knowledge_base.kb_service.default_kb_service import DefaultKBService return DefaultKBService(kb_name) diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_summary_api.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_summary_api.py index f1b83fa8..fd18fe17 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_summary_api.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_summary_api.py @@ -42,55 +42,59 @@ def recreate_summary_vector_store( if not kb.exists() and not allow_empty_kb: yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"} else: - # 重新创建知识库 - kb_summary = KBSummaryService(knowledge_base_name, embed_model) - kb_summary.drop_kb_summary() - kb_summary.create_kb_summary() + error_msg = f"could not recreate summary vector store because failed to access embed model." + if not kb.check_embed_model(error_msg): + yield {"code": 404, "msg": error_msg} + else: + # 重新创建知识库 + kb_summary = KBSummaryService(knowledge_base_name, embed_model) + kb_summary.drop_kb_summary() + kb_summary.create_kb_summary() - llm = get_ChatOpenAI( - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - local_wrap=True, - ) - reduce_llm = get_ChatOpenAI( - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - local_wrap=True, - ) - # 文本摘要适配器 - summary = SummaryAdapter.form_summary(llm=llm, - reduce_llm=reduce_llm, - overlap_size=OVERLAP_SIZE) - files = list_files_from_folder(knowledge_base_name) + llm = get_ChatOpenAI( + model_name=model_name, + temperature=temperature, + max_tokens=max_tokens, + local_wrap=True, + ) + reduce_llm = get_ChatOpenAI( + model_name=model_name, + temperature=temperature, + max_tokens=max_tokens, + local_wrap=True, + ) + # 文本摘要适配器 + summary = SummaryAdapter.form_summary(llm=llm, + reduce_llm=reduce_llm, + overlap_size=OVERLAP_SIZE) + files = list_files_from_folder(knowledge_base_name) - i = 0 - for i, file_name in enumerate(files): + i = 0 + for i, file_name in enumerate(files): - doc_infos = kb.list_docs(file_name=file_name) - docs = summary.summarize(file_description=file_description, - docs=doc_infos) + doc_infos = kb.list_docs(file_name=file_name) + docs = summary.summarize(file_description=file_description, + docs=doc_infos) - status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs) - if status_kb_summary: - logger.info(f"({i + 1} / {len(files)}): {file_name} 总结完成") - yield json.dumps({ - "code": 200, - "msg": f"({i + 1} / {len(files)}): {file_name}", - "total": len(files), - "finished": i + 1, - "doc": file_name, - }, ensure_ascii=False) - else: + status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs) + if status_kb_summary: + logger.info(f"({i + 1} / {len(files)}): {file_name} 总结完成") + yield json.dumps({ + "code": 200, + "msg": f"({i + 1} / {len(files)}): {file_name}", + "total": len(files), + "finished": i + 1, + "doc": file_name, + }, ensure_ascii=False) + else: - msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。" - logger.error(msg) - yield json.dumps({ - "code": 500, - "msg": msg, - }) - i += 1 + msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。" + logger.error(msg) + yield json.dumps({ + "code": 500, + "msg": msg, + }) + i += 1 return EventSourceResponse(output()) diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/migrate.py b/libs/chatchat-server/chatchat/server/knowledge_base/migrate.py index f4d35c40..01811691 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/migrate.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/migrate.py @@ -1,6 +1,6 @@ from chatchat.configs import ( DEFAULT_EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE, - CHUNK_SIZE, OVERLAP_SIZE + CHUNK_SIZE, OVERLAP_SIZE, logger, log_verbose ) from chatchat.server.knowledge_base.utils import ( get_file_path, list_kbs_from_folder, diff --git a/libs/chatchat-server/chatchat/server/utils.py b/libs/chatchat-server/chatchat/server/utils.py index 314ef0cc..fdf1a9ab 100644 --- a/libs/chatchat-server/chatchat/server/utils.py +++ b/libs/chatchat-server/chatchat/server/utils.py @@ -244,6 +244,16 @@ def get_Embeddings( logger.error(f"failed to create Embeddings for model: {embed_model}.", exc_info=True) +def check_embed_model(embed_model: str=DEFAULT_EMBEDDING_MODEL) -> bool: + embeddings = get_Embeddings(embed_model=embed_model) + try: + embeddings.embed_query("this is a test") + return True + except Exception as e: + logger.error(f"failed to access embed model '{embed_model}': {e}", exc_info=True) + return False + + def get_OpenAIClient( platform_name: str = None, model_name: str = None, @@ -696,7 +706,7 @@ def get_temp_dir(id: str = None) -> Tuple[str, str]: ''' 创建一个临时目录,返回(路径,文件夹名称) ''' - from chatchat.configs.basic_config import BASE_TEMP_DIR + from chatchat.configs import BASE_TEMP_DIR import uuid if id is not None: # 如果指定的临时目录已存在,直接返回 diff --git a/libs/chatchat-server/pyproject.toml b/libs/chatchat-server/pyproject.toml index 3bb9870b..4d1846b2 100644 --- a/libs/chatchat-server/pyproject.toml +++ b/libs/chatchat-server/pyproject.toml @@ -10,6 +10,7 @@ packages = [ [tool.poetry.scripts] chatchat = 'chatchat.startup:main' +chatchat-kb = 'chatchat.init_database:main' [tool.poetry.dependencies] python = ">=3.8.1,<3.12,!=3.9.7" diff --git a/tools/model_loaders/xinference_manager.py b/tools/model_loaders/xinference_manager.py index c2a86cc0..650d7cc2 100644 --- a/tools/model_loaders/xinference_manager.py +++ b/tools/model_loaders/xinference_manager.py @@ -133,9 +133,11 @@ model_names = list(regs[model_type].keys()) model_name = cols[1].selectbox("模型名称:", model_names) cur_reg = regs[model_type][model_name]["reg"] +model_format = None +model_quant = None if model_type == "LLM": - cur_family = xf_llm.LLMFamilyV1.parse_obj(cur_reg) + cur_family = xf_llm.LLMFamilyV1.model_validate(cur_reg) cur_spec = None model_formats = [] for spec in cur_reg["model_specs"]: @@ -217,7 +219,7 @@ cols = st.columns(3) if cols[0].button("设置模型缓存"): if os.path.isabs(local_path) and os.path.isdir(local_path): - cur_spec.model_uri = local_path + cur_spec.__dict__["model_uri"] = local_path # embedding spec has no attribute model_uri if os.path.isdir(cache_dir): os.rmdir(cache_dir) if model_type == "LLM": @@ -250,10 +252,10 @@ if cols[2].button("注册为自定义模型"): if model_type == "LLM": cur_family.model_name = f"{cur_family.model_name}{model_name_suffix}" cur_family.model_family = "other" - model_definition = cur_family.json(indent=2, ensure_ascii=False) + model_definition = cur_family.model_dump_json(indent=2, ensure_ascii=False) else: cur_spec.model_name = f"{cur_spec.model_name}{model_name_suffix}" - model_definition = cur_spec.json(indent=2, ensure_ascii=False) + model_definition = cur_spec.model_dump_json(indent=2, ensure_ascii=False) client.register_model( model_type=model_type, model=model_definition, @@ -262,4 +264,3 @@ if cols[2].button("注册为自定义模型"): st.rerun() else: st.error("必须输入存在的绝对路径") -