import os
import json
from dataclasses import dataclass
from pathlib import Path
import sys
import logging
from typing import Any, Optional, Dict, Tuple
sys.path.append(str(Path(__file__).parent))
import _core_config as core_config
from _basic_config import config_basic_workspace
class ConfigKb(core_config.Config):
    DEFAULT_KNOWLEDGE_BASE: Optional[str] = None
    """Default knowledge base to use."""
    DEFAULT_VS_TYPE: Optional[str] = None
    """Default vector store / full-text search engine type. Options: faiss, milvus (offline) & zilliz (online), pgvector, and the full-text search engine es."""
    CACHED_VS_NUM: Optional[int] = None
    """Number of cached vector stores (FAISS only)."""
    CACHED_MEMO_VS_NUM: Optional[int] = None
    """Number of cached temporary vector stores (FAISS only), used for file chat."""
    CHUNK_SIZE: Optional[int] = None
    """Length of a single text chunk in the knowledge base (does not apply to MarkdownHeaderTextSplitter)."""
    OVERLAP_SIZE: Optional[int] = None
    """Overlap length between adjacent text chunks in the knowledge base (does not apply to MarkdownHeaderTextSplitter)."""
    VECTOR_SEARCH_TOP_K: Optional[int] = None
    """Number of matching vectors to retrieve from the knowledge base."""
    SCORE_THRESHOLD: Optional[float] = None
    """Relevance threshold for knowledge base matching, in the range 0-1. A smaller score means higher relevance; 1 is equivalent to no filtering. A value around 0.5 is recommended."""
    DEFAULT_SEARCH_ENGINE: Optional[str] = None
    """Default search engine. Options: bing, duckduckgo, metaphor."""
    SEARCH_ENGINE_TOP_K: Optional[int] = None
    """Number of search engine results to retrieve."""
    ZH_TITLE_ENHANCE: Optional[bool] = None
    """Whether to enable Chinese title enhancement, plus its related settings."""
    PDF_OCR_THRESHOLD: Optional[Tuple[float, float]] = None
    """
    PDF OCR control: only run OCR on images whose width and height exceed the given
    fractions of the page (image width / page width, image height / page height).
    This avoids noise from small images in PDFs and speeds up processing of non-scanned PDFs.
    """
    KB_INFO: Optional[Dict[str, str]] = None
    """Introductory description of each knowledge base, shown when the knowledge base is initialized and used for Agent invocation. If omitted, the knowledge base has no description and will not be called by an Agent."""
    KB_ROOT_PATH: Optional[str] = None
    """Default storage path for knowledge bases."""
    DB_ROOT_PATH: Optional[str] = None
    """Default storage path for the database. With SQLite you can simply change DB_ROOT_PATH; for other databases, change SQLALCHEMY_DATABASE_URI directly."""
    SQLALCHEMY_DATABASE_URI: Optional[str] = None
    """Database connection URI."""
    kbs_config: Optional[Dict[str, Dict[str, Any]]] = None
    """Available vector store types and their configurations."""
    text_splitter_dict: Optional[Dict[str, Dict[str, Any]]] = None
    """TextSplitter settings. Do not modify these unless you understand what they mean."""
    TEXT_SPLITTER_NAME: Optional[str] = None
    """Name of the TEXT_SPLITTER to use."""
    EMBEDDING_KEYWORD_FILE: Optional[str] = None
    """Vocabulary file with custom terms for the embedding model."""
@classmethod
def class_name(cls) -> str:
return cls.__name__
def __str__(self):
return self.to_json()
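

# Illustrative sketch only (not part of the original configuration module): how a
# retriever would typically apply VECTOR_SEARCH_TOP_K and SCORE_THRESHOLD to
# FAISS-style L2 distances, where a smaller score means a closer match and a
# threshold of 1 effectively disables filtering. The (score, text) pairs passed in
# are assumed sample data.
def _demo_score_threshold(scored_docs, top_k: int = 3, score_threshold: float = 1.0):
    # Keep only documents whose distance is within the threshold,
    # then return the closest top_k of them.
    kept = [(score, doc) for score, doc in scored_docs if score <= score_threshold]
    kept.sort(key=lambda pair: pair[0])
    return kept[:top_k]
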
@dataclass
class ConfigKbFactory(core_config.ConfigFactory[ConfigKb]):
"""ConfigKb 配置工厂类"""
def __init__(self):
        # Default knowledge base to use
        self.DEFAULT_KNOWLEDGE_BASE = "samples"
        # Default vector store / full-text search engine type.
        # Options: faiss, milvus (offline) & zilliz (online), pgvector, and the full-text search engine es
        self.DEFAULT_VS_TYPE = "faiss"
        # Number of cached vector stores (FAISS only)
        self.CACHED_VS_NUM = 1
        # Number of cached temporary vector stores (FAISS only), used for file chat
        self.CACHED_MEMO_VS_NUM = 10
        # Length of a single text chunk in the knowledge base (does not apply to MarkdownHeaderTextSplitter)
        self.CHUNK_SIZE = 250
        # Overlap length between adjacent text chunks in the knowledge base (does not apply to MarkdownHeaderTextSplitter)
        self.OVERLAP_SIZE = 50
        # Number of matching vectors to retrieve from the knowledge base
        self.VECTOR_SEARCH_TOP_K = 3
        # Relevance threshold for knowledge base matching, in the range 0-1.
        # A smaller score means higher relevance; 1 is equivalent to no filtering. A value around 0.5 is recommended.
        self.SCORE_THRESHOLD = 1
        # Default search engine. Options: bing, duckduckgo, metaphor
        self.DEFAULT_SEARCH_ENGINE = "duckduckgo"
        # Number of search engine results to retrieve
        self.SEARCH_ENGINE_TOP_K = 3
        # Whether to enable Chinese title enhancement, plus its related settings.
        # Title detection marks which text segments are headings in the metadata,
        # then each segment is concatenated with its parent heading to enrich the text.
        self.ZH_TITLE_ENHANCE = False
        # PDF OCR control: only run OCR on images whose width and height exceed the given
        # fractions of the page (image width / page width, image height / page height).
        # This avoids noise from small images in PDFs and speeds up processing of non-scanned PDFs.
        self.PDF_OCR_THRESHOLD = (0.6, 0.6)
        # Introductory description of each knowledge base, shown when the knowledge base is initialized
        # and used for Agent invocation. If omitted, the knowledge base has no description and will not be called by an Agent.
        self.KB_INFO = {
            "samples": "Answers to issues about this project",
        }
        # The settings below usually do not need to be changed
        # Default storage path for knowledge bases
        self.KB_ROOT_PATH = os.path.join(config_basic_workspace.get_config().DATA_PATH, "knowledge_base")
        if not os.path.exists(self.KB_ROOT_PATH):
            os.makedirs(self.KB_ROOT_PATH, exist_ok=True)
        # Default storage path for the database.
        # With SQLite you can simply change DB_ROOT_PATH; for other databases, change SQLALCHEMY_DATABASE_URI directly.
self.DB_ROOT_PATH = os.path.join(self.KB_ROOT_PATH, "info.db")
self.SQLALCHEMY_DATABASE_URI = f"sqlite:///{self.DB_ROOT_PATH}"
        # Available vector store types and their configurations
self.kbs_config = {
"faiss": {
},
"milvus": {
"host": "127.0.0.1",
"port": "19530",
"user": "",
"password": "",
"secure": False,
},
"zilliz": {
"host": "in01-a7ce524e41e3935.ali-cn-hangzhou.vectordb.zilliz.com.cn",
"port": "19530",
"user": "",
"password": "",
"secure": True,
},
"pg": {
"connection_uri": "postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatchat",
},
"es": {
"host": "127.0.0.1",
"port": "9200",
"index_name": "test_index",
"user": "",
"password": ""
},
"milvus_kwargs": {
"search_params": {"metric_type": "L2"}, #在此处增加search_params
"index_params": {"metric_type": "L2", "index_type": "HNSW"} # 在此处增加index_params
},
"chromadb": {}
}
        # TextSplitter settings. Do not modify these unless you understand what they mean.
self.text_splitter_dict = {
"ChineseRecursiveTextSplitter": {
"source": "huggingface", # 选择tiktoken则使用openai的方法
"tokenizer_name_or_path": "",
},
"SpacyTextSplitter": {
"source": "huggingface",
"tokenizer_name_or_path": "gpt2",
},
"RecursiveCharacterTextSplitter": {
"source": "tiktoken",
"tokenizer_name_or_path": "cl100k_base",
},
"MarkdownHeaderTextSplitter": {
"headers_to_split_on":
[
("#", "head1"),
("##", "head2"),
("###", "head3"),
("####", "head4"),
]
},
}
        # Name of the TEXT_SPLITTER to use
        self.TEXT_SPLITTER_NAME = "ChineseRecursiveTextSplitter"
        # Vocabulary file with custom terms for the embedding model
self.EMBEDDING_KEYWORD_FILE = "embedding_keywords.txt"
def get_config(self) -> ConfigKb:
config = ConfigKb()
for key, value in self.__dict__.items():
setattr(config, key, value)
return config
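

# A minimal usage sketch (not part of the original module): ConfigKbFactory()
# populates the defaults defined above, and get_config() copies every attribute
# onto a fresh ConfigKb. Note that constructing the factory creates KB_ROOT_PATH
# on disk as a side effect, and ConfigKb.__str__ serializes the result to JSON
# via to_json().
def _demo_factory_defaults() -> ConfigKb:
    factory = ConfigKbFactory()
    kb_config = factory.get_config()
    assert kb_config.CHUNK_SIZE == 250
    assert kb_config.DEFAULT_VS_TYPE == "faiss"
    return kb_config
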
class ConfigKbWorkSpace(core_config.ConfigWorkSpace[ConfigKbFactory, ConfigKb]):
"""
    Workspace configuration preset; provides the builder method that produces ConfigKb instances.
"""
config_factory_cls = ConfigKbFactory
def __init__(self):
super().__init__()
def _build_config_factory(self, config_json: Any) -> ConfigKbFactory:
_config_factory = self.config_factory_cls()
if config_json.get("DEFAULT_KNOWLEDGE_BASE"):
_config_factory.DEFAULT_KNOWLEDGE_BASE = config_json.get("DEFAULT_KNOWLEDGE_BASE")
if config_json.get("DEFAULT_VS_TYPE"):
_config_factory.DEFAULT_VS_TYPE = config_json.get("DEFAULT_VS_TYPE")
if config_json.get("CACHED_VS_NUM"):
_config_factory.CACHED_VS_NUM = config_json.get("CACHED_VS_NUM")
if config_json.get("CACHED_MEMO_VS_NUM"):
_config_factory.CACHED_MEMO_VS_NUM = config_json.get("CACHED_MEMO_VS_NUM")
if config_json.get("CHUNK_SIZE"):
_config_factory.CHUNK_SIZE = config_json.get("CHUNK_SIZE")
if config_json.get("OVERLAP_SIZE"):
_config_factory.OVERLAP_SIZE = config_json.get("OVERLAP_SIZE")
if config_json.get("VECTOR_SEARCH_TOP_K"):
_config_factory.VECTOR_SEARCH_TOP_K = config_json.get("VECTOR_SEARCH_TOP_K")
if config_json.get("SCORE_THRESHOLD"):
_config_factory.SCORE_THRESHOLD = config_json.get("SCORE_THRESHOLD")
if config_json.get("DEFAULT_SEARCH_ENGINE"):
_config_factory.DEFAULT_SEARCH_ENGINE = config_json.get("DEFAULT_SEARCH_ENGINE")
if config_json.get("SEARCH_ENGINE_TOP_K"):
_config_factory.SEARCH_ENGINE_TOP_K = config_json.get("SEARCH_ENGINE_TOP_K")
if config_json.get("ZH_TITLE_ENHANCE"):
_config_factory.ZH_TITLE_ENHANCE = config_json.get("ZH_TITLE_ENHANCE")
if config_json.get("PDF_OCR_THRESHOLD"):
_config_factory.PDF_OCR_THRESHOLD = config_json.get("PDF_OCR_THRESHOLD")
if config_json.get("KB_INFO"):
_config_factory.KB_INFO = config_json.get("KB_INFO")
if config_json.get("KB_ROOT_PATH"):
_config_factory.KB_ROOT_PATH = config_json.get("KB_ROOT_PATH")
if config_json.get("DB_ROOT_PATH"):
_config_factory.DB_ROOT_PATH = config_json.get("DB_ROOT_PATH")
if config_json.get("SQLALCHEMY_DATABASE_URI"):
_config_factory.SQLALCHEMY_DATABASE_URI = config_json.get("SQLALCHEMY_DATABASE_URI")
if config_json.get("TEXT_SPLITTER_NAME"):
_config_factory.TEXT_SPLITTER_NAME = config_json.get("TEXT_SPLITTER_NAME")
if config_json.get("EMBEDDING_KEYWORD_FILE"):
_config_factory.EMBEDDING_KEYWORD_FILE = config_json.get("EMBEDDING_KEYWORD_FILE")
return _config_factory
@classmethod
def get_type(cls) -> str:
return ConfigKb.class_name()
def get_config(self) -> ConfigKb:
return self._config_factory.get_config()
def set_default_knowledge_base(self, kb_name: str):
self._config_factory.DEFAULT_KNOWLEDGE_BASE = kb_name
self.store_config()
def set_default_vs_type(self, vs_type: str):
self._config_factory.DEFAULT_VS_TYPE = vs_type
self.store_config()
def set_cached_vs_num(self, cached_vs_num: int):
self._config_factory.CACHED_VS_NUM = cached_vs_num
self.store_config()
def set_cached_memo_vs_num(self, cached_memo_vs_num: int):
self._config_factory.CACHED_MEMO_VS_NUM = cached_memo_vs_num
self.store_config()
def set_chunk_size(self, chunk_size: int):
self._config_factory.CHUNK_SIZE = chunk_size
self.store_config()
def set_overlap_size(self, overlap_size: int):
self._config_factory.OVERLAP_SIZE = overlap_size
self.store_config()
def set_vector_search_top_k(self, vector_search_top_k: int):
self._config_factory.VECTOR_SEARCH_TOP_K = vector_search_top_k
self.store_config()
def set_score_threshold(self, score_threshold: float):
self._config_factory.SCORE_THRESHOLD = score_threshold
self.store_config()
def set_default_search_engine(self, default_search_engine: str):
self._config_factory.DEFAULT_SEARCH_ENGINE = default_search_engine
self.store_config()
def set_search_engine_top_k(self, search_engine_top_k: int):
self._config_factory.SEARCH_ENGINE_TOP_K = search_engine_top_k
self.store_config()
def set_zh_title_enhance(self, zh_title_enhance: bool):
self._config_factory.ZH_TITLE_ENHANCE = zh_title_enhance
self.store_config()
def set_pdf_ocr_threshold(self, pdf_ocr_threshold: Tuple[float, float]):
self._config_factory.PDF_OCR_THRESHOLD = pdf_ocr_threshold
self.store_config()
def set_kb_info(self, kb_info: Dict[str, str]):
self._config_factory.KB_INFO = kb_info
self.store_config()
def set_kb_root_path(self, kb_root_path: str):
self._config_factory.KB_ROOT_PATH = kb_root_path
self.store_config()
def set_db_root_path(self, db_root_path: str):
self._config_factory.DB_ROOT_PATH = db_root_path
self.store_config()
def set_sqlalchemy_database_uri(self, sqlalchemy_database_uri: str):
self._config_factory.SQLALCHEMY_DATABASE_URI = sqlalchemy_database_uri
self.store_config()
def set_text_splitter_name(self, text_splitter_name: str):
self._config_factory.TEXT_SPLITTER_NAME = text_splitter_name
self.store_config()
def set_embedding_keyword_file(self, embedding_keyword_file: str):
self._config_factory.EMBEDDING_KEYWORD_FILE = embedding_keyword_file
self.store_config()
config_kb_workspace: ConfigKbWorkSpace = ConfigKbWorkSpace()
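

# Illustrative usage sketch (assumption: running this module directly is acceptable
# in your environment). It reads the current knowledge-base settings through the
# module-level singleton above and persists a couple of overrides via the setter
# methods, each of which calls store_config() to save the change.
if __name__ == "__main__":
    kb = config_kb_workspace.get_config()
    print(kb.DEFAULT_VS_TYPE, kb.CHUNK_SIZE, kb.VECTOR_SEARCH_TOP_K)

    config_kb_workspace.set_chunk_size(300)
    config_kb_workspace.set_overlap_size(60)
    print(config_kb_workspace.get_config().CHUNK_SIZE)  # -> 300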