mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-01-19 13:23:16 +08:00
优化知识库相关功能 (#4153)
- 新功能
- pypi 包新增 chatchat-kb 命令脚本,对应 init_database.py 功能
- 开发者
- _model_config.py 中默认包含 xinference 配置项
- 所有涉及向量库的操作,前置检查当前 Embed 模型是否可用
- /knowledge_base/create_knowledge_base 接口增加 kb_info 参数
- /knowledge_base/list_files 接口返回所有数据库字段,而非文件名称列表
- 修正 xinference 模型管理脚本
This commit is contained in:
parent
8994b25a77
commit
a5b203170b
@ -29,6 +29,11 @@ vim model_providers.yaml
|
||||
>
|
||||
> 详细配置请参考[README.md](../model-providers/README.md)
|
||||
|
||||
- 初始化知识库
|
||||
```shell
|
||||
chatchat-kb -r
|
||||
```
|
||||
|
||||
- 启动服务
|
||||
```shell
|
||||
chatchat -a
|
||||
|
||||
@ -118,6 +118,25 @@ MODEL_PLATFORMS = [
|
||||
"tts_models": [],
|
||||
},
|
||||
|
||||
{
|
||||
"platform_name": "xinference",
|
||||
"platform_type": "xinference",
|
||||
"api_base_url": "http://127.0.0.1:9997/v1",
|
||||
"api_key": "EMPTY",
|
||||
"api_concurrencies": 5,
|
||||
"llm_models": [
|
||||
"glm-4",
|
||||
"qwen2-instruct",
|
||||
"qwen1.5-chat",
|
||||
],
|
||||
"embed_models": [
|
||||
"bge-large-zh-v1.5",
|
||||
],
|
||||
"image_models": [],
|
||||
"reranking_models": [],
|
||||
"speech2text_models": [],
|
||||
"tts_models": [],
|
||||
},
|
||||
|
||||
]
|
||||
|
||||
|
||||
@ -20,16 +20,16 @@
|
||||
|
||||
xinference:
|
||||
model_credential:
|
||||
- model: 'chatglm3-6b'
|
||||
- model: 'glm-4'
|
||||
model_type: 'llm'
|
||||
model_credentials:
|
||||
server_url: 'http://127.0.0.1:9997/'
|
||||
model_uid: 'chatglm3-6b'
|
||||
- model: 'Qwen1.5-14B-Chat'
|
||||
model_uid: 'glm-4'
|
||||
- model: 'qwen1.5-chat'
|
||||
model_type: 'llm'
|
||||
model_credentials:
|
||||
server_url: 'http://127.0.0.1:9997/'
|
||||
model_uid: 'Qwen1.5-14B-Chat'
|
||||
model_uid: 'qwen1.5-chat'
|
||||
- model: 'bge-large-zh-v1.5'
|
||||
model_type: 'embeddings'
|
||||
model_credentials:
|
||||
@ -46,4 +46,4 @@ xinference:
|
||||
# model_type: 'llm'
|
||||
# model_credentials:
|
||||
# base_url: 'http://172.21.192.1:11434'
|
||||
# mode: 'completion'
|
||||
# mode: 'completion'
|
||||
|
||||
@ -1,13 +1,12 @@
|
||||
# Description: 初始化数据库,包括创建表、导入数据、更新向量空间等操作
|
||||
from datetime import datetime
|
||||
import multiprocessing as mp
|
||||
from typing import Dict
|
||||
|
||||
from chatchat.server.knowledge_base.migrate import (create_tables, reset_tables, import_from_db,
|
||||
folder2db, prune_db_docs, prune_folder_files)
|
||||
from chatchat.configs import DEFAULT_EMBEDDING_MODEL, MODEL_PLATFORMS
|
||||
import multiprocessing as mp
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
from chatchat.configs import DEFAULT_EMBEDDING_MODEL, MODEL_PLATFORMS, logger
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def run_init_model_provider(
|
||||
@ -34,7 +33,7 @@ def run_init_model_provider(
|
||||
provider_port=provider_port)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def main():
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="please specify only one operate method once time.")
|
||||
@ -186,3 +185,7 @@ if __name__ == "__main__":
|
||||
|
||||
for p in processes.values():
|
||||
logger.info("Process status: %s", p)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -14,6 +14,7 @@ def list_kbs():
|
||||
|
||||
def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||
vector_store_type: str = Body("faiss"),
|
||||
kb_info: str = Body("", description="知识库内容简介,用于Agent选择知识库。"),
|
||||
embed_model: str = Body(DEFAULT_EMBEDDING_MODEL),
|
||||
) -> BaseResponse:
|
||||
# Create selected knowledge base
|
||||
@ -26,7 +27,7 @@ def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||
if kb is not None:
|
||||
return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
|
||||
|
||||
kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model)
|
||||
kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model, kb_info=kb_info)
|
||||
try:
|
||||
kb.create_kb()
|
||||
except Exception as e:
|
||||
|
||||
@ -62,7 +62,7 @@ class CachePool:
|
||||
self._cache_num = cache_num
|
||||
self._cache = OrderedDict()
|
||||
self.atomic = threading.RLock()
|
||||
|
||||
|
||||
def keys(self) -> List[str]:
|
||||
return list(self._cache.keys())
|
||||
|
||||
|
||||
@ -95,6 +95,7 @@ class KBFaissPool(_FaissPool):
|
||||
embed_model: str = DEFAULT_EMBEDDING_MODEL,
|
||||
) -> ThreadSafeFaiss:
|
||||
self.atomic.acquire()
|
||||
locked = True
|
||||
vector_name = vector_name or embed_model
|
||||
cache = self.get((kb_name, vector_name)) # 用元组比拼接字符串好一些
|
||||
try:
|
||||
@ -103,6 +104,7 @@ class KBFaissPool(_FaissPool):
|
||||
self.set((kb_name, vector_name), item)
|
||||
with item.acquire(msg="初始化"):
|
||||
self.atomic.release()
|
||||
locked = False
|
||||
logger.info(f"loading vector store in '{kb_name}/vector_store/{vector_name}' from disk.")
|
||||
vs_path = get_vs_path(kb_name, vector_name)
|
||||
|
||||
@ -121,8 +123,10 @@ class KBFaissPool(_FaissPool):
|
||||
item.finish_loading()
|
||||
else:
|
||||
self.atomic.release()
|
||||
locked = False
|
||||
except Exception as e:
|
||||
self.atomic.release()
|
||||
if locked: # we don't know exception raised before or after atomic.release
|
||||
self.atomic.release()
|
||||
logger.error(e)
|
||||
raise RuntimeError(f"向量库 {kb_name} 加载失败。")
|
||||
return self.get((kb_name, vector_name))
|
||||
|
||||
@ -1,21 +1,23 @@
|
||||
import json
|
||||
import os
|
||||
import urllib
|
||||
from typing import List, Dict
|
||||
|
||||
from fastapi import File, Form, Body, Query, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from langchain.docstore.document import Document
|
||||
from sse_starlette import EventSourceResponse
|
||||
|
||||
from chatchat.configs import (DEFAULT_VS_TYPE, DEFAULT_EMBEDDING_MODEL,
|
||||
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
|
||||
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE,
|
||||
logger, log_verbose, )
|
||||
from chatchat.server.utils import BaseResponse, ListResponse, run_in_thread_pool
|
||||
from chatchat.server.db.repository.knowledge_file_repository import get_file_detail
|
||||
from chatchat.server.knowledge_base.utils import (validate_kb_name, list_files_from_folder, get_file_path,
|
||||
files2docs_in_thread, KnowledgeFile)
|
||||
from fastapi.responses import FileResponse
|
||||
from sse_starlette import EventSourceResponse
|
||||
import json
|
||||
from chatchat.server.knowledge_base.kb_service.base import KBServiceFactory
|
||||
from chatchat.server.db.repository.knowledge_file_repository import get_file_detail
|
||||
from langchain.docstore.document import Document
|
||||
from chatchat.server.knowledge_base.kb_service.base import KBServiceFactory, get_kb_file_details
|
||||
from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
||||
from typing import List, Dict
|
||||
from chatchat.server.utils import BaseResponse, ListResponse, run_in_thread_pool, check_embed_model
|
||||
|
||||
|
||||
def search_docs(
|
||||
@ -55,8 +57,8 @@ def list_files(
|
||||
if kb is None:
|
||||
return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
|
||||
else:
|
||||
all_doc_names = kb.list_files()
|
||||
return ListResponse(data=all_doc_names)
|
||||
all_docs = get_kb_file_details(knowledge_base_name)
|
||||
return ListResponse(data=all_docs)
|
||||
|
||||
|
||||
def _save_files_in_thread(files: List[UploadFile],
|
||||
@ -352,38 +354,42 @@ def recreate_vector_store(
|
||||
if not kb.exists() and not allow_empty_kb:
|
||||
yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"}
|
||||
else:
|
||||
if kb.exists():
|
||||
kb.clear_vs()
|
||||
kb.create_kb()
|
||||
files = list_files_from_folder(knowledge_base_name)
|
||||
kb_files = [(file, knowledge_base_name) for file in files]
|
||||
i = 0
|
||||
for status, result in files2docs_in_thread(kb_files,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
zh_title_enhance=zh_title_enhance):
|
||||
if status:
|
||||
kb_name, file_name, docs = result
|
||||
kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=kb_name)
|
||||
kb_file.splited_docs = docs
|
||||
yield json.dumps({
|
||||
"code": 200,
|
||||
"msg": f"({i + 1} / {len(files)}): {file_name}",
|
||||
"total": len(files),
|
||||
"finished": i + 1,
|
||||
"doc": file_name,
|
||||
}, ensure_ascii=False)
|
||||
kb.add_doc(kb_file, not_refresh_vs_cache=True)
|
||||
else:
|
||||
kb_name, file_name, error = result
|
||||
msg = f"添加文件‘{file_name}’到知识库‘{knowledge_base_name}’时出错:{error}。已跳过。"
|
||||
logger.error(msg)
|
||||
yield json.dumps({
|
||||
"code": 500,
|
||||
"msg": msg,
|
||||
})
|
||||
i += 1
|
||||
if not not_refresh_vs_cache:
|
||||
kb.save_vector_store()
|
||||
error_msg = f"could not recreate vector store because failed to access embed model."
|
||||
if not kb.check_embed_model(error_msg):
|
||||
yield {"code": 404, "msg": error_msg}
|
||||
else:
|
||||
if kb.exists():
|
||||
kb.clear_vs()
|
||||
kb.create_kb()
|
||||
files = list_files_from_folder(knowledge_base_name)
|
||||
kb_files = [(file, knowledge_base_name) for file in files]
|
||||
i = 0
|
||||
for status, result in files2docs_in_thread(kb_files,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
zh_title_enhance=zh_title_enhance):
|
||||
if status:
|
||||
kb_name, file_name, docs = result
|
||||
kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=kb_name)
|
||||
kb_file.splited_docs = docs
|
||||
yield json.dumps({
|
||||
"code": 200,
|
||||
"msg": f"({i + 1} / {len(files)}): {file_name}",
|
||||
"total": len(files),
|
||||
"finished": i + 1,
|
||||
"doc": file_name,
|
||||
}, ensure_ascii=False)
|
||||
kb.add_doc(kb_file, not_refresh_vs_cache=True)
|
||||
else:
|
||||
kb_name, file_name, error = result
|
||||
msg = f"添加文件‘{file_name}’到知识库‘{knowledge_base_name}’时出错:{error}。已跳过。"
|
||||
logger.error(msg)
|
||||
yield json.dumps({
|
||||
"code": 500,
|
||||
"msg": msg,
|
||||
})
|
||||
i += 1
|
||||
if not not_refresh_vs_cache:
|
||||
kb.save_vector_store()
|
||||
|
||||
return EventSourceResponse(output())
|
||||
|
||||
@ -5,6 +5,11 @@ import os
|
||||
from pathlib import Path
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
from typing import List, Union, Dict, Optional, Tuple
|
||||
|
||||
from chatchat.configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
|
||||
DEFAULT_EMBEDDING_MODEL, KB_INFO, logger)
|
||||
from chatchat.server.db.models.knowledge_base_model import KnowledgeBaseSchema
|
||||
from chatchat.server.db.repository.knowledge_base_repository import (
|
||||
add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists,
|
||||
load_kb_from_db, get_kb_detail,
|
||||
@ -14,18 +19,12 @@ from chatchat.server.db.repository.knowledge_file_repository import (
|
||||
count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db,
|
||||
list_docs_from_db,
|
||||
)
|
||||
|
||||
from chatchat.configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
|
||||
DEFAULT_EMBEDDING_MODEL, KB_INFO)
|
||||
from chatchat.server.knowledge_base.utils import (
|
||||
get_kb_path, get_doc_path, KnowledgeFile,
|
||||
list_kbs_from_folder, list_files_from_folder,
|
||||
)
|
||||
|
||||
from typing import List, Union, Dict, Optional, Tuple
|
||||
|
||||
from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
||||
from chatchat.server.db.models.knowledge_base_model import KnowledgeBaseSchema
|
||||
from chatchat.server.utils import check_embed_model as _check_embed_model
|
||||
|
||||
class SupportedVSType:
|
||||
FAISS = 'faiss'
|
||||
@ -41,10 +40,11 @@ class KBService(ABC):
|
||||
|
||||
def __init__(self,
|
||||
knowledge_base_name: str,
|
||||
kb_info: str = None,
|
||||
embed_model: str = DEFAULT_EMBEDDING_MODEL,
|
||||
):
|
||||
self.kb_name = knowledge_base_name
|
||||
self.kb_info = KB_INFO.get(knowledge_base_name, f"关于{knowledge_base_name}的知识库")
|
||||
self.kb_info = kb_info or KB_INFO.get(knowledge_base_name, f"关于{knowledge_base_name}的知识库")
|
||||
self.embed_model = embed_model
|
||||
self.kb_path = get_kb_path(self.kb_name)
|
||||
self.doc_path = get_doc_path(self.kb_name)
|
||||
@ -59,6 +59,13 @@ class KBService(ABC):
|
||||
'''
|
||||
pass
|
||||
|
||||
def check_embed_model(self, error_msg: str) -> bool:
|
||||
if not _check_embed_model(self.embed_model):
|
||||
logger.error(error_msg, exc_info=True)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def create_kb(self):
|
||||
"""
|
||||
创建知识库
|
||||
@ -93,6 +100,9 @@ class KBService(ABC):
|
||||
向知识库添加文件
|
||||
如果指定了docs,则不再将文本向量化,并将数据库对应条目标为custom_docs=True
|
||||
"""
|
||||
if not self.check_embed_model(f"could not add docs because failed to access embed model."):
|
||||
return False
|
||||
|
||||
if docs:
|
||||
custom_docs = True
|
||||
else:
|
||||
@ -143,6 +153,9 @@ class KBService(ABC):
|
||||
使用content中的文件更新向量库
|
||||
如果指定了docs,则使用自定义docs,并将数据库对应条目标为custom_docs=True
|
||||
"""
|
||||
if not self.check_embed_model(f"could not update docs because failed to access embed model."):
|
||||
return False
|
||||
|
||||
if os.path.exists(kb_file.filepath):
|
||||
self.delete_doc(kb_file, **kwargs)
|
||||
return self.add_doc(kb_file, docs=docs, **kwargs)
|
||||
@ -162,6 +175,8 @@ class KBService(ABC):
|
||||
top_k: int = VECTOR_SEARCH_TOP_K,
|
||||
score_threshold: float = SCORE_THRESHOLD,
|
||||
) ->List[Document]:
|
||||
if not self.check_embed_model(f"could not search docs because failed to access embed model."):
|
||||
return []
|
||||
docs = self.do_search(query, top_k, score_threshold)
|
||||
return docs
|
||||
|
||||
@ -176,6 +191,9 @@ class KBService(ABC):
|
||||
传入参数为: {doc_id: Document, ...}
|
||||
如果对应 doc_id 的值为 None,或其 page_content 为空,则删除该文档
|
||||
'''
|
||||
if not self.check_embed_model(f"could not update docs because failed to access embed model."):
|
||||
return False
|
||||
|
||||
self.del_doc_by_ids(list(docs.keys()))
|
||||
docs = []
|
||||
ids = []
|
||||
@ -282,31 +300,32 @@ class KBServiceFactory:
|
||||
def get_service(kb_name: str,
|
||||
vector_store_type: Union[str, SupportedVSType],
|
||||
embed_model: str = DEFAULT_EMBEDDING_MODEL,
|
||||
kb_info: str = None,
|
||||
) -> KBService:
|
||||
if isinstance(vector_store_type, str):
|
||||
vector_store_type = getattr(SupportedVSType, vector_store_type.upper())
|
||||
params = {"knowledge_base_name": kb_name, "embed_model": embed_model, "kb_info": kb_info}
|
||||
if SupportedVSType.FAISS == vector_store_type:
|
||||
from chatchat.server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
|
||||
return FaissKBService(kb_name, embed_model=embed_model)
|
||||
return FaissKBService(**params)
|
||||
elif SupportedVSType.PG == vector_store_type:
|
||||
from chatchat.server.knowledge_base.kb_service.pg_kb_service import PGKBService
|
||||
return PGKBService(kb_name, embed_model=embed_model)
|
||||
return PGKBService(**params)
|
||||
elif SupportedVSType.MILVUS == vector_store_type:
|
||||
from chatchat.server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
|
||||
return MilvusKBService(kb_name, embed_model=embed_model)
|
||||
return MilvusKBService(**params)
|
||||
elif SupportedVSType.ZILLIZ == vector_store_type:
|
||||
from chatchat.server.knowledge_base.kb_service.zilliz_kb_service import ZillizKBService
|
||||
return ZillizKBService(kb_name, embed_model=embed_model)
|
||||
return ZillizKBService(**params)
|
||||
elif SupportedVSType.DEFAULT == vector_store_type:
|
||||
from chatchat.server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
|
||||
return MilvusKBService(kb_name,
|
||||
embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config
|
||||
return MilvusKBService(**params) # other milvus parameters are set in model_config.kbs_config
|
||||
elif SupportedVSType.ES == vector_store_type:
|
||||
from chatchat.server.knowledge_base.kb_service.es_kb_service import ESKBService
|
||||
return ESKBService(kb_name, embed_model=embed_model)
|
||||
return ESKBService(**params)
|
||||
elif SupportedVSType.CHROMADB == vector_store_type:
|
||||
from chatchat.server.knowledge_base.kb_service.chromadb_kb_service import ChromaKBService
|
||||
return ChromaKBService(kb_name, embed_model=embed_model)
|
||||
return ChromaKBService(**params)
|
||||
elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier.
|
||||
from chatchat.server.knowledge_base.kb_service.default_kb_service import DefaultKBService
|
||||
return DefaultKBService(kb_name)
|
||||
|
||||
@ -42,55 +42,59 @@ def recreate_summary_vector_store(
|
||||
if not kb.exists() and not allow_empty_kb:
|
||||
yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"}
|
||||
else:
|
||||
# 重新创建知识库
|
||||
kb_summary = KBSummaryService(knowledge_base_name, embed_model)
|
||||
kb_summary.drop_kb_summary()
|
||||
kb_summary.create_kb_summary()
|
||||
error_msg = f"could not recreate summary vector store because failed to access embed model."
|
||||
if not kb.check_embed_model(error_msg):
|
||||
yield {"code": 404, "msg": error_msg}
|
||||
else:
|
||||
# 重新创建知识库
|
||||
kb_summary = KBSummaryService(knowledge_base_name, embed_model)
|
||||
kb_summary.drop_kb_summary()
|
||||
kb_summary.create_kb_summary()
|
||||
|
||||
llm = get_ChatOpenAI(
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
local_wrap=True,
|
||||
)
|
||||
reduce_llm = get_ChatOpenAI(
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
local_wrap=True,
|
||||
)
|
||||
# 文本摘要适配器
|
||||
summary = SummaryAdapter.form_summary(llm=llm,
|
||||
reduce_llm=reduce_llm,
|
||||
overlap_size=OVERLAP_SIZE)
|
||||
files = list_files_from_folder(knowledge_base_name)
|
||||
llm = get_ChatOpenAI(
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
local_wrap=True,
|
||||
)
|
||||
reduce_llm = get_ChatOpenAI(
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
local_wrap=True,
|
||||
)
|
||||
# 文本摘要适配器
|
||||
summary = SummaryAdapter.form_summary(llm=llm,
|
||||
reduce_llm=reduce_llm,
|
||||
overlap_size=OVERLAP_SIZE)
|
||||
files = list_files_from_folder(knowledge_base_name)
|
||||
|
||||
i = 0
|
||||
for i, file_name in enumerate(files):
|
||||
i = 0
|
||||
for i, file_name in enumerate(files):
|
||||
|
||||
doc_infos = kb.list_docs(file_name=file_name)
|
||||
docs = summary.summarize(file_description=file_description,
|
||||
docs=doc_infos)
|
||||
doc_infos = kb.list_docs(file_name=file_name)
|
||||
docs = summary.summarize(file_description=file_description,
|
||||
docs=doc_infos)
|
||||
|
||||
status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
|
||||
if status_kb_summary:
|
||||
logger.info(f"({i + 1} / {len(files)}): {file_name} 总结完成")
|
||||
yield json.dumps({
|
||||
"code": 200,
|
||||
"msg": f"({i + 1} / {len(files)}): {file_name}",
|
||||
"total": len(files),
|
||||
"finished": i + 1,
|
||||
"doc": file_name,
|
||||
}, ensure_ascii=False)
|
||||
else:
|
||||
status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
|
||||
if status_kb_summary:
|
||||
logger.info(f"({i + 1} / {len(files)}): {file_name} 总结完成")
|
||||
yield json.dumps({
|
||||
"code": 200,
|
||||
"msg": f"({i + 1} / {len(files)}): {file_name}",
|
||||
"total": len(files),
|
||||
"finished": i + 1,
|
||||
"doc": file_name,
|
||||
}, ensure_ascii=False)
|
||||
else:
|
||||
|
||||
msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
|
||||
logger.error(msg)
|
||||
yield json.dumps({
|
||||
"code": 500,
|
||||
"msg": msg,
|
||||
})
|
||||
i += 1
|
||||
msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
|
||||
logger.error(msg)
|
||||
yield json.dumps({
|
||||
"code": 500,
|
||||
"msg": msg,
|
||||
})
|
||||
i += 1
|
||||
|
||||
return EventSourceResponse(output())
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from chatchat.configs import (
|
||||
DEFAULT_EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE,
|
||||
CHUNK_SIZE, OVERLAP_SIZE
|
||||
CHUNK_SIZE, OVERLAP_SIZE, logger, log_verbose
|
||||
)
|
||||
from chatchat.server.knowledge_base.utils import (
|
||||
get_file_path, list_kbs_from_folder,
|
||||
|
||||
@ -244,6 +244,16 @@ def get_Embeddings(
|
||||
logger.error(f"failed to create Embeddings for model: {embed_model}.", exc_info=True)
|
||||
|
||||
|
||||
def check_embed_model(embed_model: str=DEFAULT_EMBEDDING_MODEL) -> bool:
|
||||
embeddings = get_Embeddings(embed_model=embed_model)
|
||||
try:
|
||||
embeddings.embed_query("this is a test")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"failed to access embed model '{embed_model}': {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
|
||||
def get_OpenAIClient(
|
||||
platform_name: str = None,
|
||||
model_name: str = None,
|
||||
@ -696,7 +706,7 @@ def get_temp_dir(id: str = None) -> Tuple[str, str]:
|
||||
'''
|
||||
创建一个临时目录,返回(路径,文件夹名称)
|
||||
'''
|
||||
from chatchat.configs.basic_config import BASE_TEMP_DIR
|
||||
from chatchat.configs import BASE_TEMP_DIR
|
||||
import uuid
|
||||
|
||||
if id is not None: # 如果指定的临时目录已存在,直接返回
|
||||
|
||||
@ -10,6 +10,7 @@ packages = [
|
||||
|
||||
[tool.poetry.scripts]
|
||||
chatchat = 'chatchat.startup:main'
|
||||
chatchat-kb = 'chatchat.init_database:main'
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.8.1,<3.12,!=3.9.7"
|
||||
|
||||
@ -133,9 +133,11 @@ model_names = list(regs[model_type].keys())
|
||||
model_name = cols[1].selectbox("模型名称:", model_names)
|
||||
|
||||
cur_reg = regs[model_type][model_name]["reg"]
|
||||
model_format = None
|
||||
model_quant = None
|
||||
|
||||
if model_type == "LLM":
|
||||
cur_family = xf_llm.LLMFamilyV1.parse_obj(cur_reg)
|
||||
cur_family = xf_llm.LLMFamilyV1.model_validate(cur_reg)
|
||||
cur_spec = None
|
||||
model_formats = []
|
||||
for spec in cur_reg["model_specs"]:
|
||||
@ -217,7 +219,7 @@ cols = st.columns(3)
|
||||
|
||||
if cols[0].button("设置模型缓存"):
|
||||
if os.path.isabs(local_path) and os.path.isdir(local_path):
|
||||
cur_spec.model_uri = local_path
|
||||
cur_spec.__dict__["model_uri"] = local_path # embedding spec has no attribute model_uri
|
||||
if os.path.isdir(cache_dir):
|
||||
os.rmdir(cache_dir)
|
||||
if model_type == "LLM":
|
||||
@ -250,10 +252,10 @@ if cols[2].button("注册为自定义模型"):
|
||||
if model_type == "LLM":
|
||||
cur_family.model_name = f"{cur_family.model_name}{model_name_suffix}"
|
||||
cur_family.model_family = "other"
|
||||
model_definition = cur_family.json(indent=2, ensure_ascii=False)
|
||||
model_definition = cur_family.model_dump_json(indent=2, ensure_ascii=False)
|
||||
else:
|
||||
cur_spec.model_name = f"{cur_spec.model_name}{model_name_suffix}"
|
||||
model_definition = cur_spec.json(indent=2, ensure_ascii=False)
|
||||
model_definition = cur_spec.model_dump_json(indent=2, ensure_ascii=False)
|
||||
client.register_model(
|
||||
model_type=model_type,
|
||||
model=model_definition,
|
||||
@ -262,4 +264,3 @@ if cols[2].button("注册为自定义模型"):
|
||||
st.rerun()
|
||||
else:
|
||||
st.error("必须输入存在的绝对路径")
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user