mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-02-06 06:49:48 +08:00
优化知识库相关功能 (#4153)
- 新功能
    - pypi 包新增 chatchat-kb 命令脚本,对应 init_database.py 功能
- 开发者
    - _model_config.py 中默认包含 xinference 配置项
    - 所有涉及向量库的操作,前置检查当前 Embed 模型是否可用
    - /knowledge_base/create_knowledge_base 接口增加 kb_info 参数
    - /knowledge_base/list_files 接口返回所有数据库字段,而非文件名称列表
    - 修正 xinference 模型管理脚本
This commit is contained in:
parent
8994b25a77
commit
a5b203170b
@ -29,6 +29,11 @@ vim model_providers.yaml
|
|||||||
>
|
>
|
||||||
> 详细配置请参考[README.md](../model-providers/README.md)
|
> 详细配置请参考[README.md](../model-providers/README.md)
|
||||||
|
|
||||||
|
- 初始化知识库
|
||||||
|
```shell
|
||||||
|
chatchat-kb -r
|
||||||
|
```
|
||||||
|
|
||||||
- 启动服务
|
- 启动服务
|
||||||
```shell
|
```shell
|
||||||
chatchat -a
|
chatchat -a
|
||||||
|
|||||||
@ -118,6 +118,25 @@ MODEL_PLATFORMS = [
|
|||||||
"tts_models": [],
|
"tts_models": [],
|
||||||
},
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
"platform_name": "xinference",
|
||||||
|
"platform_type": "xinference",
|
||||||
|
"api_base_url": "http://127.0.0.1:9997/v1",
|
||||||
|
"api_key": "EMPTY",
|
||||||
|
"api_concurrencies": 5,
|
||||||
|
"llm_models": [
|
||||||
|
"glm-4",
|
||||||
|
"qwen2-instruct",
|
||||||
|
"qwen1.5-chat",
|
||||||
|
],
|
||||||
|
"embed_models": [
|
||||||
|
"bge-large-zh-v1.5",
|
||||||
|
],
|
||||||
|
"image_models": [],
|
||||||
|
"reranking_models": [],
|
||||||
|
"speech2text_models": [],
|
||||||
|
"tts_models": [],
|
||||||
|
},
|
||||||
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -20,16 +20,16 @@
|
|||||||
|
|
||||||
xinference:
|
xinference:
|
||||||
model_credential:
|
model_credential:
|
||||||
- model: 'chatglm3-6b'
|
- model: 'glm-4'
|
||||||
model_type: 'llm'
|
model_type: 'llm'
|
||||||
model_credentials:
|
model_credentials:
|
||||||
server_url: 'http://127.0.0.1:9997/'
|
server_url: 'http://127.0.0.1:9997/'
|
||||||
model_uid: 'chatglm3-6b'
|
model_uid: 'glm-4'
|
||||||
- model: 'Qwen1.5-14B-Chat'
|
- model: 'qwen1.5-chat'
|
||||||
model_type: 'llm'
|
model_type: 'llm'
|
||||||
model_credentials:
|
model_credentials:
|
||||||
server_url: 'http://127.0.0.1:9997/'
|
server_url: 'http://127.0.0.1:9997/'
|
||||||
model_uid: 'Qwen1.5-14B-Chat'
|
model_uid: 'qwen1.5-chat'
|
||||||
- model: 'bge-large-zh-v1.5'
|
- model: 'bge-large-zh-v1.5'
|
||||||
model_type: 'embeddings'
|
model_type: 'embeddings'
|
||||||
model_credentials:
|
model_credentials:
|
||||||
@ -46,4 +46,4 @@ xinference:
|
|||||||
# model_type: 'llm'
|
# model_type: 'llm'
|
||||||
# model_credentials:
|
# model_credentials:
|
||||||
# base_url: 'http://172.21.192.1:11434'
|
# base_url: 'http://172.21.192.1:11434'
|
||||||
# mode: 'completion'
|
# mode: 'completion'
|
||||||
|
|||||||
@ -1,13 +1,12 @@
|
|||||||
# Description: 初始化数据库,包括创建表、导入数据、更新向量空间等操作
|
# Description: 初始化数据库,包括创建表、导入数据、更新向量空间等操作
|
||||||
|
from datetime import datetime
|
||||||
|
import multiprocessing as mp
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from chatchat.server.knowledge_base.migrate import (create_tables, reset_tables, import_from_db,
|
from chatchat.server.knowledge_base.migrate import (create_tables, reset_tables, import_from_db,
|
||||||
folder2db, prune_db_docs, prune_folder_files)
|
folder2db, prune_db_docs, prune_folder_files)
|
||||||
from chatchat.configs import DEFAULT_EMBEDDING_MODEL, MODEL_PLATFORMS
|
from chatchat.configs import DEFAULT_EMBEDDING_MODEL, MODEL_PLATFORMS, logger
|
||||||
import multiprocessing as mp
|
|
||||||
import logging
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
|
|
||||||
def run_init_model_provider(
|
def run_init_model_provider(
|
||||||
@ -34,7 +33,7 @@ def run_init_model_provider(
|
|||||||
provider_port=provider_port)
|
provider_port=provider_port)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def main():
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="please specify only one operate method once time.")
|
parser = argparse.ArgumentParser(description="please specify only one operate method once time.")
|
||||||
@ -186,3 +185,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
for p in processes.values():
|
for p in processes.values():
|
||||||
logger.info("Process status: %s", p)
|
logger.info("Process status: %s", p)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|||||||
@ -14,6 +14,7 @@ def list_kbs():
|
|||||||
|
|
||||||
def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
|
def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
|
||||||
vector_store_type: str = Body("faiss"),
|
vector_store_type: str = Body("faiss"),
|
||||||
|
kb_info: str = Body("", description="知识库内容简介,用于Agent选择知识库。"),
|
||||||
embed_model: str = Body(DEFAULT_EMBEDDING_MODEL),
|
embed_model: str = Body(DEFAULT_EMBEDDING_MODEL),
|
||||||
) -> BaseResponse:
|
) -> BaseResponse:
|
||||||
# Create selected knowledge base
|
# Create selected knowledge base
|
||||||
@ -26,7 +27,7 @@ def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
|
|||||||
if kb is not None:
|
if kb is not None:
|
||||||
return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
|
return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")
|
||||||
|
|
||||||
kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model)
|
kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model, kb_info=kb_info)
|
||||||
try:
|
try:
|
||||||
kb.create_kb()
|
kb.create_kb()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@ -62,7 +62,7 @@ class CachePool:
|
|||||||
self._cache_num = cache_num
|
self._cache_num = cache_num
|
||||||
self._cache = OrderedDict()
|
self._cache = OrderedDict()
|
||||||
self.atomic = threading.RLock()
|
self.atomic = threading.RLock()
|
||||||
|
|
||||||
def keys(self) -> List[str]:
|
def keys(self) -> List[str]:
|
||||||
return list(self._cache.keys())
|
return list(self._cache.keys())
|
||||||
|
|
||||||
|
|||||||
@ -95,6 +95,7 @@ class KBFaissPool(_FaissPool):
|
|||||||
embed_model: str = DEFAULT_EMBEDDING_MODEL,
|
embed_model: str = DEFAULT_EMBEDDING_MODEL,
|
||||||
) -> ThreadSafeFaiss:
|
) -> ThreadSafeFaiss:
|
||||||
self.atomic.acquire()
|
self.atomic.acquire()
|
||||||
|
locked = True
|
||||||
vector_name = vector_name or embed_model
|
vector_name = vector_name or embed_model
|
||||||
cache = self.get((kb_name, vector_name)) # 用元组比拼接字符串好一些
|
cache = self.get((kb_name, vector_name)) # 用元组比拼接字符串好一些
|
||||||
try:
|
try:
|
||||||
@ -103,6 +104,7 @@ class KBFaissPool(_FaissPool):
|
|||||||
self.set((kb_name, vector_name), item)
|
self.set((kb_name, vector_name), item)
|
||||||
with item.acquire(msg="初始化"):
|
with item.acquire(msg="初始化"):
|
||||||
self.atomic.release()
|
self.atomic.release()
|
||||||
|
locked = False
|
||||||
logger.info(f"loading vector store in '{kb_name}/vector_store/{vector_name}' from disk.")
|
logger.info(f"loading vector store in '{kb_name}/vector_store/{vector_name}' from disk.")
|
||||||
vs_path = get_vs_path(kb_name, vector_name)
|
vs_path = get_vs_path(kb_name, vector_name)
|
||||||
|
|
||||||
@ -121,8 +123,10 @@ class KBFaissPool(_FaissPool):
|
|||||||
item.finish_loading()
|
item.finish_loading()
|
||||||
else:
|
else:
|
||||||
self.atomic.release()
|
self.atomic.release()
|
||||||
|
locked = False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.atomic.release()
|
if locked: # we don't know exception raised before or after atomic.release
|
||||||
|
self.atomic.release()
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
raise RuntimeError(f"向量库 {kb_name} 加载失败。")
|
raise RuntimeError(f"向量库 {kb_name} 加载失败。")
|
||||||
return self.get((kb_name, vector_name))
|
return self.get((kb_name, vector_name))
|
||||||
|
|||||||
@ -1,21 +1,23 @@
|
|||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import urllib
|
import urllib
|
||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
from fastapi import File, Form, Body, Query, UploadFile
|
from fastapi import File, Form, Body, Query, UploadFile
|
||||||
|
from fastapi.responses import FileResponse
|
||||||
|
from langchain.docstore.document import Document
|
||||||
|
from sse_starlette import EventSourceResponse
|
||||||
|
|
||||||
from chatchat.configs import (DEFAULT_VS_TYPE, DEFAULT_EMBEDDING_MODEL,
|
from chatchat.configs import (DEFAULT_VS_TYPE, DEFAULT_EMBEDDING_MODEL,
|
||||||
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
|
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
|
||||||
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE,
|
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE,
|
||||||
logger, log_verbose, )
|
logger, log_verbose, )
|
||||||
from chatchat.server.utils import BaseResponse, ListResponse, run_in_thread_pool
|
from chatchat.server.db.repository.knowledge_file_repository import get_file_detail
|
||||||
from chatchat.server.knowledge_base.utils import (validate_kb_name, list_files_from_folder, get_file_path,
|
from chatchat.server.knowledge_base.utils import (validate_kb_name, list_files_from_folder, get_file_path,
|
||||||
files2docs_in_thread, KnowledgeFile)
|
files2docs_in_thread, KnowledgeFile)
|
||||||
from fastapi.responses import FileResponse
|
from chatchat.server.knowledge_base.kb_service.base import KBServiceFactory, get_kb_file_details
|
||||||
from sse_starlette import EventSourceResponse
|
|
||||||
import json
|
|
||||||
from chatchat.server.knowledge_base.kb_service.base import KBServiceFactory
|
|
||||||
from chatchat.server.db.repository.knowledge_file_repository import get_file_detail
|
|
||||||
from langchain.docstore.document import Document
|
|
||||||
from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
||||||
from typing import List, Dict
|
from chatchat.server.utils import BaseResponse, ListResponse, run_in_thread_pool, check_embed_model
|
||||||
|
|
||||||
|
|
||||||
def search_docs(
|
def search_docs(
|
||||||
@ -55,8 +57,8 @@ def list_files(
|
|||||||
if kb is None:
|
if kb is None:
|
||||||
return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
|
return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
|
||||||
else:
|
else:
|
||||||
all_doc_names = kb.list_files()
|
all_docs = get_kb_file_details(knowledge_base_name)
|
||||||
return ListResponse(data=all_doc_names)
|
return ListResponse(data=all_docs)
|
||||||
|
|
||||||
|
|
||||||
def _save_files_in_thread(files: List[UploadFile],
|
def _save_files_in_thread(files: List[UploadFile],
|
||||||
@ -352,38 +354,42 @@ def recreate_vector_store(
|
|||||||
if not kb.exists() and not allow_empty_kb:
|
if not kb.exists() and not allow_empty_kb:
|
||||||
yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"}
|
yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"}
|
||||||
else:
|
else:
|
||||||
if kb.exists():
|
error_msg = f"could not recreate vector store because failed to access embed model."
|
||||||
kb.clear_vs()
|
if not kb.check_embed_model(error_msg):
|
||||||
kb.create_kb()
|
yield {"code": 404, "msg": error_msg}
|
||||||
files = list_files_from_folder(knowledge_base_name)
|
else:
|
||||||
kb_files = [(file, knowledge_base_name) for file in files]
|
if kb.exists():
|
||||||
i = 0
|
kb.clear_vs()
|
||||||
for status, result in files2docs_in_thread(kb_files,
|
kb.create_kb()
|
||||||
chunk_size=chunk_size,
|
files = list_files_from_folder(knowledge_base_name)
|
||||||
chunk_overlap=chunk_overlap,
|
kb_files = [(file, knowledge_base_name) for file in files]
|
||||||
zh_title_enhance=zh_title_enhance):
|
i = 0
|
||||||
if status:
|
for status, result in files2docs_in_thread(kb_files,
|
||||||
kb_name, file_name, docs = result
|
chunk_size=chunk_size,
|
||||||
kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=kb_name)
|
chunk_overlap=chunk_overlap,
|
||||||
kb_file.splited_docs = docs
|
zh_title_enhance=zh_title_enhance):
|
||||||
yield json.dumps({
|
if status:
|
||||||
"code": 200,
|
kb_name, file_name, docs = result
|
||||||
"msg": f"({i + 1} / {len(files)}): {file_name}",
|
kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=kb_name)
|
||||||
"total": len(files),
|
kb_file.splited_docs = docs
|
||||||
"finished": i + 1,
|
yield json.dumps({
|
||||||
"doc": file_name,
|
"code": 200,
|
||||||
}, ensure_ascii=False)
|
"msg": f"({i + 1} / {len(files)}): {file_name}",
|
||||||
kb.add_doc(kb_file, not_refresh_vs_cache=True)
|
"total": len(files),
|
||||||
else:
|
"finished": i + 1,
|
||||||
kb_name, file_name, error = result
|
"doc": file_name,
|
||||||
msg = f"添加文件‘{file_name}’到知识库‘{knowledge_base_name}’时出错:{error}。已跳过。"
|
}, ensure_ascii=False)
|
||||||
logger.error(msg)
|
kb.add_doc(kb_file, not_refresh_vs_cache=True)
|
||||||
yield json.dumps({
|
else:
|
||||||
"code": 500,
|
kb_name, file_name, error = result
|
||||||
"msg": msg,
|
msg = f"添加文件‘{file_name}’到知识库‘{knowledge_base_name}’时出错:{error}。已跳过。"
|
||||||
})
|
logger.error(msg)
|
||||||
i += 1
|
yield json.dumps({
|
||||||
if not not_refresh_vs_cache:
|
"code": 500,
|
||||||
kb.save_vector_store()
|
"msg": msg,
|
||||||
|
})
|
||||||
|
i += 1
|
||||||
|
if not not_refresh_vs_cache:
|
||||||
|
kb.save_vector_store()
|
||||||
|
|
||||||
return EventSourceResponse(output())
|
return EventSourceResponse(output())
|
||||||
|
|||||||
@ -5,6 +5,11 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
|
|
||||||
|
from typing import List, Union, Dict, Optional, Tuple
|
||||||
|
|
||||||
|
from chatchat.configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
|
||||||
|
DEFAULT_EMBEDDING_MODEL, KB_INFO, logger)
|
||||||
|
from chatchat.server.db.models.knowledge_base_model import KnowledgeBaseSchema
|
||||||
from chatchat.server.db.repository.knowledge_base_repository import (
|
from chatchat.server.db.repository.knowledge_base_repository import (
|
||||||
add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists,
|
add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists,
|
||||||
load_kb_from_db, get_kb_detail,
|
load_kb_from_db, get_kb_detail,
|
||||||
@ -14,18 +19,12 @@ from chatchat.server.db.repository.knowledge_file_repository import (
|
|||||||
count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db,
|
count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db,
|
||||||
list_docs_from_db,
|
list_docs_from_db,
|
||||||
)
|
)
|
||||||
|
|
||||||
from chatchat.configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
|
|
||||||
DEFAULT_EMBEDDING_MODEL, KB_INFO)
|
|
||||||
from chatchat.server.knowledge_base.utils import (
|
from chatchat.server.knowledge_base.utils import (
|
||||||
get_kb_path, get_doc_path, KnowledgeFile,
|
get_kb_path, get_doc_path, KnowledgeFile,
|
||||||
list_kbs_from_folder, list_files_from_folder,
|
list_kbs_from_folder, list_files_from_folder,
|
||||||
)
|
)
|
||||||
|
|
||||||
from typing import List, Union, Dict, Optional, Tuple
|
|
||||||
|
|
||||||
from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
||||||
from chatchat.server.db.models.knowledge_base_model import KnowledgeBaseSchema
|
from chatchat.server.utils import check_embed_model as _check_embed_model
|
||||||
|
|
||||||
class SupportedVSType:
|
class SupportedVSType:
|
||||||
FAISS = 'faiss'
|
FAISS = 'faiss'
|
||||||
@ -41,10 +40,11 @@ class KBService(ABC):
|
|||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
knowledge_base_name: str,
|
knowledge_base_name: str,
|
||||||
|
kb_info: str = None,
|
||||||
embed_model: str = DEFAULT_EMBEDDING_MODEL,
|
embed_model: str = DEFAULT_EMBEDDING_MODEL,
|
||||||
):
|
):
|
||||||
self.kb_name = knowledge_base_name
|
self.kb_name = knowledge_base_name
|
||||||
self.kb_info = KB_INFO.get(knowledge_base_name, f"关于{knowledge_base_name}的知识库")
|
self.kb_info = kb_info or KB_INFO.get(knowledge_base_name, f"关于{knowledge_base_name}的知识库")
|
||||||
self.embed_model = embed_model
|
self.embed_model = embed_model
|
||||||
self.kb_path = get_kb_path(self.kb_name)
|
self.kb_path = get_kb_path(self.kb_name)
|
||||||
self.doc_path = get_doc_path(self.kb_name)
|
self.doc_path = get_doc_path(self.kb_name)
|
||||||
@ -59,6 +59,13 @@ class KBService(ABC):
|
|||||||
'''
|
'''
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def check_embed_model(self, error_msg: str) -> bool:
|
||||||
|
if not _check_embed_model(self.embed_model):
|
||||||
|
logger.error(error_msg, exc_info=True)
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
def create_kb(self):
|
def create_kb(self):
|
||||||
"""
|
"""
|
||||||
创建知识库
|
创建知识库
|
||||||
@ -93,6 +100,9 @@ class KBService(ABC):
|
|||||||
向知识库添加文件
|
向知识库添加文件
|
||||||
如果指定了docs,则不再将文本向量化,并将数据库对应条目标为custom_docs=True
|
如果指定了docs,则不再将文本向量化,并将数据库对应条目标为custom_docs=True
|
||||||
"""
|
"""
|
||||||
|
if not self.check_embed_model(f"could not add docs because failed to access embed model."):
|
||||||
|
return False
|
||||||
|
|
||||||
if docs:
|
if docs:
|
||||||
custom_docs = True
|
custom_docs = True
|
||||||
else:
|
else:
|
||||||
@ -143,6 +153,9 @@ class KBService(ABC):
|
|||||||
使用content中的文件更新向量库
|
使用content中的文件更新向量库
|
||||||
如果指定了docs,则使用自定义docs,并将数据库对应条目标为custom_docs=True
|
如果指定了docs,则使用自定义docs,并将数据库对应条目标为custom_docs=True
|
||||||
"""
|
"""
|
||||||
|
if not self.check_embed_model(f"could not update docs because failed to access embed model."):
|
||||||
|
return False
|
||||||
|
|
||||||
if os.path.exists(kb_file.filepath):
|
if os.path.exists(kb_file.filepath):
|
||||||
self.delete_doc(kb_file, **kwargs)
|
self.delete_doc(kb_file, **kwargs)
|
||||||
return self.add_doc(kb_file, docs=docs, **kwargs)
|
return self.add_doc(kb_file, docs=docs, **kwargs)
|
||||||
@ -162,6 +175,8 @@ class KBService(ABC):
|
|||||||
top_k: int = VECTOR_SEARCH_TOP_K,
|
top_k: int = VECTOR_SEARCH_TOP_K,
|
||||||
score_threshold: float = SCORE_THRESHOLD,
|
score_threshold: float = SCORE_THRESHOLD,
|
||||||
) ->List[Document]:
|
) ->List[Document]:
|
||||||
|
if not self.check_embed_model(f"could not search docs because failed to access embed model."):
|
||||||
|
return []
|
||||||
docs = self.do_search(query, top_k, score_threshold)
|
docs = self.do_search(query, top_k, score_threshold)
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
@ -176,6 +191,9 @@ class KBService(ABC):
|
|||||||
传入参数为: {doc_id: Document, ...}
|
传入参数为: {doc_id: Document, ...}
|
||||||
如果对应 doc_id 的值为 None,或其 page_content 为空,则删除该文档
|
如果对应 doc_id 的值为 None,或其 page_content 为空,则删除该文档
|
||||||
'''
|
'''
|
||||||
|
if not self.check_embed_model(f"could not update docs because failed to access embed model."):
|
||||||
|
return False
|
||||||
|
|
||||||
self.del_doc_by_ids(list(docs.keys()))
|
self.del_doc_by_ids(list(docs.keys()))
|
||||||
docs = []
|
docs = []
|
||||||
ids = []
|
ids = []
|
||||||
@ -282,31 +300,32 @@ class KBServiceFactory:
|
|||||||
def get_service(kb_name: str,
|
def get_service(kb_name: str,
|
||||||
vector_store_type: Union[str, SupportedVSType],
|
vector_store_type: Union[str, SupportedVSType],
|
||||||
embed_model: str = DEFAULT_EMBEDDING_MODEL,
|
embed_model: str = DEFAULT_EMBEDDING_MODEL,
|
||||||
|
kb_info: str = None,
|
||||||
) -> KBService:
|
) -> KBService:
|
||||||
if isinstance(vector_store_type, str):
|
if isinstance(vector_store_type, str):
|
||||||
vector_store_type = getattr(SupportedVSType, vector_store_type.upper())
|
vector_store_type = getattr(SupportedVSType, vector_store_type.upper())
|
||||||
|
params = {"knowledge_base_name": kb_name, "embed_model": embed_model, "kb_info": kb_info}
|
||||||
if SupportedVSType.FAISS == vector_store_type:
|
if SupportedVSType.FAISS == vector_store_type:
|
||||||
from chatchat.server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
|
from chatchat.server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
|
||||||
return FaissKBService(kb_name, embed_model=embed_model)
|
return FaissKBService(**params)
|
||||||
elif SupportedVSType.PG == vector_store_type:
|
elif SupportedVSType.PG == vector_store_type:
|
||||||
from chatchat.server.knowledge_base.kb_service.pg_kb_service import PGKBService
|
from chatchat.server.knowledge_base.kb_service.pg_kb_service import PGKBService
|
||||||
return PGKBService(kb_name, embed_model=embed_model)
|
return PGKBService(**params)
|
||||||
elif SupportedVSType.MILVUS == vector_store_type:
|
elif SupportedVSType.MILVUS == vector_store_type:
|
||||||
from chatchat.server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
|
from chatchat.server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
|
||||||
return MilvusKBService(kb_name, embed_model=embed_model)
|
return MilvusKBService(**params)
|
||||||
elif SupportedVSType.ZILLIZ == vector_store_type:
|
elif SupportedVSType.ZILLIZ == vector_store_type:
|
||||||
from chatchat.server.knowledge_base.kb_service.zilliz_kb_service import ZillizKBService
|
from chatchat.server.knowledge_base.kb_service.zilliz_kb_service import ZillizKBService
|
||||||
return ZillizKBService(kb_name, embed_model=embed_model)
|
return ZillizKBService(**params)
|
||||||
elif SupportedVSType.DEFAULT == vector_store_type:
|
elif SupportedVSType.DEFAULT == vector_store_type:
|
||||||
from chatchat.server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
|
from chatchat.server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
|
||||||
return MilvusKBService(kb_name,
|
return MilvusKBService(**params) # other milvus parameters are set in model_config.kbs_config
|
||||||
embed_model=embed_model) # other milvus parameters are set in model_config.kbs_config
|
|
||||||
elif SupportedVSType.ES == vector_store_type:
|
elif SupportedVSType.ES == vector_store_type:
|
||||||
from chatchat.server.knowledge_base.kb_service.es_kb_service import ESKBService
|
from chatchat.server.knowledge_base.kb_service.es_kb_service import ESKBService
|
||||||
return ESKBService(kb_name, embed_model=embed_model)
|
return ESKBService(**params)
|
||||||
elif SupportedVSType.CHROMADB == vector_store_type:
|
elif SupportedVSType.CHROMADB == vector_store_type:
|
||||||
from chatchat.server.knowledge_base.kb_service.chromadb_kb_service import ChromaKBService
|
from chatchat.server.knowledge_base.kb_service.chromadb_kb_service import ChromaKBService
|
||||||
return ChromaKBService(kb_name, embed_model=embed_model)
|
return ChromaKBService(**params)
|
||||||
elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier.
|
elif SupportedVSType.DEFAULT == vector_store_type: # kb_exists of default kbservice is False, to make validation easier.
|
||||||
from chatchat.server.knowledge_base.kb_service.default_kb_service import DefaultKBService
|
from chatchat.server.knowledge_base.kb_service.default_kb_service import DefaultKBService
|
||||||
return DefaultKBService(kb_name)
|
return DefaultKBService(kb_name)
|
||||||
|
|||||||
@ -42,55 +42,59 @@ def recreate_summary_vector_store(
|
|||||||
if not kb.exists() and not allow_empty_kb:
|
if not kb.exists() and not allow_empty_kb:
|
||||||
yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"}
|
yield {"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"}
|
||||||
else:
|
else:
|
||||||
# 重新创建知识库
|
error_msg = f"could not recreate summary vector store because failed to access embed model."
|
||||||
kb_summary = KBSummaryService(knowledge_base_name, embed_model)
|
if not kb.check_embed_model(error_msg):
|
||||||
kb_summary.drop_kb_summary()
|
yield {"code": 404, "msg": error_msg}
|
||||||
kb_summary.create_kb_summary()
|
else:
|
||||||
|
# 重新创建知识库
|
||||||
|
kb_summary = KBSummaryService(knowledge_base_name, embed_model)
|
||||||
|
kb_summary.drop_kb_summary()
|
||||||
|
kb_summary.create_kb_summary()
|
||||||
|
|
||||||
llm = get_ChatOpenAI(
|
llm = get_ChatOpenAI(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
local_wrap=True,
|
local_wrap=True,
|
||||||
)
|
)
|
||||||
reduce_llm = get_ChatOpenAI(
|
reduce_llm = get_ChatOpenAI(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
local_wrap=True,
|
local_wrap=True,
|
||||||
)
|
)
|
||||||
# 文本摘要适配器
|
# 文本摘要适配器
|
||||||
summary = SummaryAdapter.form_summary(llm=llm,
|
summary = SummaryAdapter.form_summary(llm=llm,
|
||||||
reduce_llm=reduce_llm,
|
reduce_llm=reduce_llm,
|
||||||
overlap_size=OVERLAP_SIZE)
|
overlap_size=OVERLAP_SIZE)
|
||||||
files = list_files_from_folder(knowledge_base_name)
|
files = list_files_from_folder(knowledge_base_name)
|
||||||
|
|
||||||
i = 0
|
i = 0
|
||||||
for i, file_name in enumerate(files):
|
for i, file_name in enumerate(files):
|
||||||
|
|
||||||
doc_infos = kb.list_docs(file_name=file_name)
|
doc_infos = kb.list_docs(file_name=file_name)
|
||||||
docs = summary.summarize(file_description=file_description,
|
docs = summary.summarize(file_description=file_description,
|
||||||
docs=doc_infos)
|
docs=doc_infos)
|
||||||
|
|
||||||
status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
|
status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
|
||||||
if status_kb_summary:
|
if status_kb_summary:
|
||||||
logger.info(f"({i + 1} / {len(files)}): {file_name} 总结完成")
|
logger.info(f"({i + 1} / {len(files)}): {file_name} 总结完成")
|
||||||
yield json.dumps({
|
yield json.dumps({
|
||||||
"code": 200,
|
"code": 200,
|
||||||
"msg": f"({i + 1} / {len(files)}): {file_name}",
|
"msg": f"({i + 1} / {len(files)}): {file_name}",
|
||||||
"total": len(files),
|
"total": len(files),
|
||||||
"finished": i + 1,
|
"finished": i + 1,
|
||||||
"doc": file_name,
|
"doc": file_name,
|
||||||
}, ensure_ascii=False)
|
}, ensure_ascii=False)
|
||||||
else:
|
else:
|
||||||
|
|
||||||
msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
|
msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
|
||||||
logger.error(msg)
|
logger.error(msg)
|
||||||
yield json.dumps({
|
yield json.dumps({
|
||||||
"code": 500,
|
"code": 500,
|
||||||
"msg": msg,
|
"msg": msg,
|
||||||
})
|
})
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
return EventSourceResponse(output())
|
return EventSourceResponse(output())
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from chatchat.configs import (
|
from chatchat.configs import (
|
||||||
DEFAULT_EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE,
|
DEFAULT_EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE,
|
||||||
CHUNK_SIZE, OVERLAP_SIZE
|
CHUNK_SIZE, OVERLAP_SIZE, logger, log_verbose
|
||||||
)
|
)
|
||||||
from chatchat.server.knowledge_base.utils import (
|
from chatchat.server.knowledge_base.utils import (
|
||||||
get_file_path, list_kbs_from_folder,
|
get_file_path, list_kbs_from_folder,
|
||||||
|
|||||||
@ -244,6 +244,16 @@ def get_Embeddings(
|
|||||||
logger.error(f"failed to create Embeddings for model: {embed_model}.", exc_info=True)
|
logger.error(f"failed to create Embeddings for model: {embed_model}.", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
def check_embed_model(embed_model: str=DEFAULT_EMBEDDING_MODEL) -> bool:
|
||||||
|
embeddings = get_Embeddings(embed_model=embed_model)
|
||||||
|
try:
|
||||||
|
embeddings.embed_query("this is a test")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"failed to access embed model '{embed_model}': {e}", exc_info=True)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def get_OpenAIClient(
|
def get_OpenAIClient(
|
||||||
platform_name: str = None,
|
platform_name: str = None,
|
||||||
model_name: str = None,
|
model_name: str = None,
|
||||||
@ -696,7 +706,7 @@ def get_temp_dir(id: str = None) -> Tuple[str, str]:
|
|||||||
'''
|
'''
|
||||||
创建一个临时目录,返回(路径,文件夹名称)
|
创建一个临时目录,返回(路径,文件夹名称)
|
||||||
'''
|
'''
|
||||||
from chatchat.configs.basic_config import BASE_TEMP_DIR
|
from chatchat.configs import BASE_TEMP_DIR
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
if id is not None: # 如果指定的临时目录已存在,直接返回
|
if id is not None: # 如果指定的临时目录已存在,直接返回
|
||||||
|
|||||||
@ -10,6 +10,7 @@ packages = [
|
|||||||
|
|
||||||
[tool.poetry.scripts]
|
[tool.poetry.scripts]
|
||||||
chatchat = 'chatchat.startup:main'
|
chatchat = 'chatchat.startup:main'
|
||||||
|
chatchat-kb = 'chatchat.init_database:main'
|
||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = ">=3.8.1,<3.12,!=3.9.7"
|
python = ">=3.8.1,<3.12,!=3.9.7"
|
||||||
|
|||||||
@ -133,9 +133,11 @@ model_names = list(regs[model_type].keys())
|
|||||||
model_name = cols[1].selectbox("模型名称:", model_names)
|
model_name = cols[1].selectbox("模型名称:", model_names)
|
||||||
|
|
||||||
cur_reg = regs[model_type][model_name]["reg"]
|
cur_reg = regs[model_type][model_name]["reg"]
|
||||||
|
model_format = None
|
||||||
|
model_quant = None
|
||||||
|
|
||||||
if model_type == "LLM":
|
if model_type == "LLM":
|
||||||
cur_family = xf_llm.LLMFamilyV1.parse_obj(cur_reg)
|
cur_family = xf_llm.LLMFamilyV1.model_validate(cur_reg)
|
||||||
cur_spec = None
|
cur_spec = None
|
||||||
model_formats = []
|
model_formats = []
|
||||||
for spec in cur_reg["model_specs"]:
|
for spec in cur_reg["model_specs"]:
|
||||||
@ -217,7 +219,7 @@ cols = st.columns(3)
|
|||||||
|
|
||||||
if cols[0].button("设置模型缓存"):
|
if cols[0].button("设置模型缓存"):
|
||||||
if os.path.isabs(local_path) and os.path.isdir(local_path):
|
if os.path.isabs(local_path) and os.path.isdir(local_path):
|
||||||
cur_spec.model_uri = local_path
|
cur_spec.__dict__["model_uri"] = local_path # embedding spec has no attribute model_uri
|
||||||
if os.path.isdir(cache_dir):
|
if os.path.isdir(cache_dir):
|
||||||
os.rmdir(cache_dir)
|
os.rmdir(cache_dir)
|
||||||
if model_type == "LLM":
|
if model_type == "LLM":
|
||||||
@ -250,10 +252,10 @@ if cols[2].button("注册为自定义模型"):
|
|||||||
if model_type == "LLM":
|
if model_type == "LLM":
|
||||||
cur_family.model_name = f"{cur_family.model_name}{model_name_suffix}"
|
cur_family.model_name = f"{cur_family.model_name}{model_name_suffix}"
|
||||||
cur_family.model_family = "other"
|
cur_family.model_family = "other"
|
||||||
model_definition = cur_family.json(indent=2, ensure_ascii=False)
|
model_definition = cur_family.model_dump_json(indent=2, ensure_ascii=False)
|
||||||
else:
|
else:
|
||||||
cur_spec.model_name = f"{cur_spec.model_name}{model_name_suffix}"
|
cur_spec.model_name = f"{cur_spec.model_name}{model_name_suffix}"
|
||||||
model_definition = cur_spec.json(indent=2, ensure_ascii=False)
|
model_definition = cur_spec.model_dump_json(indent=2, ensure_ascii=False)
|
||||||
client.register_model(
|
client.register_model(
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
model=model_definition,
|
model=model_definition,
|
||||||
@ -262,4 +264,3 @@ if cols[2].button("注册为自定义模型"):
|
|||||||
st.rerun()
|
st.rerun()
|
||||||
else:
|
else:
|
||||||
st.error("必须输入存在的绝对路径")
|
st.error("必须输入存在的绝对路径")
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user