2024-06-12 18:40:52 +08:00

449 lines
19 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import logging
import sys
from pathlib import Path
from typing import Any, Optional, List, Dict
from dataclasses import dataclass
sys.path.append(str(Path(__file__).parent))
import _core_config as core_config
logger = logging.getLogger()
class ConfigModel(core_config.Config):
DEFAULT_LLM_MODEL: Optional[str] = None
"""默认选用的 LLM 名称"""
DEFAULT_EMBEDDING_MODEL: Optional[str] = None
"""默认选用的 Embedding 名称"""
Agent_MODEL: Optional[str] = None
"""AgentLM模型的名称 (可以不指定指定之后就锁定进入Agent之后的Chain的模型不指定就是LLM_MODELS[0])"""
HISTORY_LEN: Optional[int] = None
"""历史对话轮数"""
MAX_TOKENS: Optional[int] = None
"""大模型最长支持的长度,如果不填写,则使用模型默认的最大长度,如果填写,则为用户设定的最大长度"""
TEMPERATURE: Optional[float] = None
"""LLM通用对话参数"""
SUPPORT_AGENT_MODELS: Optional[List[str]] = None
"""支持的Agent模型"""
LLM_MODEL_CONFIG: Optional[Dict[str, Dict[str, Any]]] = None
"""LLM模型配置包括了不同模态初始化参数"""
MODEL_PLATFORMS: Optional[List[Dict[str, Any]]] = None
"""模型平台配置"""
MODEL_PROVIDERS_CFG_PATH_CONFIG: Optional[str] = None
"""模型平台配置文件路径"""
MODEL_PROVIDERS_CFG_HOST: Optional[str] = None
"""模型平台配置服务host"""
MODEL_PROVIDERS_CFG_PORT: Optional[int] = None
"""模型平台配置服务port"""
TOOL_CONFIG: Optional[Dict[str, Any]] = None
"""工具配置项"""
@classmethod
def class_name(cls) -> str:
return cls.__name__
def __str__(self):
return self.to_json()
@dataclass
class ConfigModelFactory(core_config.ConfigFactory[ConfigModel]):
"""ConfigModel工厂类"""
def __init__(self):
# 默认选用的 LLM 名称
self.DEFAULT_LLM_MODEL = "chatglm3-6b"
# 默认选用的 Embedding 名称
self.DEFAULT_EMBEDDING_MODEL = "bge-large-zh-v1.5"
# AgentLM模型的名称 (可以不指定指定之后就锁定进入Agent之后的Chain的模型不指定就是LLM_MODELS[0])
self.Agent_MODEL = None
# 历史对话轮数
self.HISTORY_LEN = 3
# 大模型最长支持的长度,如果不填写,则使用模型默认的最大长度,如果填写,则为用户设定的最大长度
self.MAX_TOKENS = None
# LLM通用对话参数
self.TEMPERATURE = 0.7
# TOP_P = 0.95 # ChatOpenAI暂不支持该参数
self.SUPPORT_AGENT_MODELS = [
"chatglm3-6b",
"openai-api",
"Qwen-14B-Chat",
"Qwen-7B-Chat",
"qwen-turbo",
]
self.MODEL_PROVIDERS_CFG_PATH_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)),
"model_providers.yaml")
self.MODEL_PROVIDERS_CFG_HOST = "127.0.0.1"
self.MODEL_PROVIDERS_CFG_PORT = 20000
# 可以通过 model_providers 提供转换不同平台的接口为openai endpoint的能力启动后下面变量会自动增加相应的平台
# ### 如果您已经有了一个openai endpoint的能力的地址可以在这里直接配置
# - platform_name 可以任意填写,不要重复即可
# - platform_type 以后可能根据平台类型做一些功能区分,与platform_name一致即可
# - 将框架部署的模型填写到对应列表即可。不同框架可以加载同名模型,项目会自动做负载均衡。
# 创建一个全局的共享字典
self.MODEL_PLATFORMS = [
{
"platform_name": "oneapi",
"platform_type": "oneapi",
"api_base_url": "http://127.0.0.1:3000/v1",
"api_key": "sk-",
"api_concurrencies": 5,
"llm_models": [
# 智谱 API
"chatglm_pro",
"chatglm_turbo",
"chatglm_std",
"chatglm_lite",
# 千问 API
"qwen-turbo",
"qwen-plus",
"qwen-max",
"qwen-max-longcontext",
# 千帆 API
"ERNIE-Bot",
"ERNIE-Bot-turbo",
"ERNIE-Bot-4",
# 星火 API
"SparkDesk",
],
"embed_models": [
# 千问 API
"text-embedding-v1",
# 千帆 API
"Embedding-V1",
],
"image_models": [],
"reranking_models": [],
"speech2text_models": [],
"tts_models": [],
},
]
# 工具配置项
self.TOOL_CONFIG = {
"search_local_knowledgebase": {
"use": False,
"top_k": 3,
"score_threshold": 1.0,
"conclude_prompt": {
"with_result":
'<指令>根据已知信息,简洁和专业的来回答问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题"'
'不允许在答案中添加编造成分,答案请使用中文。 </指令>\n'
'<已知信息>{{ context }}</已知信息>\n'
'<问题>{{ question }}</问题>\n',
"without_result":
'请你根据我的提问回答我的问题:\n'
'{{ question }}\n'
'请注意,你必须在回答结束后强调,你的回答是根据你的经验回答而不是参考资料回答的。\n',
}
},
"search_internet": {
"use": False,
"search_engine_name": "bing",
"search_engine_config":
{
"bing": {
"result_len": 3,
"bing_search_url": "https://api.bing.microsoft.com/v7.0/search",
"bing_key": "",
},
"metaphor": {
"result_len": 3,
"metaphor_api_key": "",
"split_result": False,
"chunk_size": 500,
"chunk_overlap": 0,
},
"duckduckgo": {
"result_len": 3
}
},
"top_k": 10,
"verbose": "Origin",
"conclude_prompt":
"<指令>这是搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 "
"</指令>\n<已知信息>{{ context }}</已知信息>\n"
"<问题>\n"
"{{ question }}\n"
"</问题>\n"
},
"arxiv": {
"use": False,
},
"shell": {
"use": False,
},
"weather_check": {
"use": False,
"api_key": "S8vrB4U_-c5mvAMiK",
},
"search_youtube": {
"use": False,
},
"wolfram": {
"use": False,
"appid": "",
},
"calculate": {
"use": False,
},
"vqa_processor": {
"use": False,
"model_path": "your model path",
"tokenizer_path": "your tokenizer path",
"device": "cuda:1"
},
"aqa_processor": {
"use": False,
"model_path": "your model path",
"tokenizer_path": "yout tokenizer path",
"device": "cuda:2"
},
"text2images": {
"use": False,
},
# text2sql使用建议
# 1、因大模型生成的sql可能与预期有偏差请务必在测试环境中进行充分测试、评估
# 2、生产环境中对于查询操作由于不确定查询效率推荐数据库采用主从数据库架构让text2sql连接从数据库防止可能的慢查询影响主业务
# 3、对于写操作应保持谨慎如不需要写操作设置read_only为True,最好再从数据库层面收回数据库用户的写权限,防止用户通过自然语言对数据库进行修改操作;
# 4、text2sql与大模型在意图理解、sql转换等方面的能力有关可切换不同大模型进行测试
# 5、数据库表名、字段名应与其实际作用保持一致、容易理解且应对数据库表名、字段进行详细的备注说明帮助大模型更好理解数据库结构
# 6、若现有数据库表名难于让大模型理解可配置下面table_comments字段补充说明某些表的作用。
"text2sql": {
"use": False,
# SQLAlchemy连接字符串支持的数据库有
# crate、duckdb、googlesql、mssql、mysql、mariadb、oracle、postgresql、sqlite、clickhouse、prestodb
# 不同的数据库请查询SQLAlchemy修改sqlalchemy_connect_str配置对应的数据库连接如sqlite为sqlite:///数据库文件路径下面示例为mysql
# 如提示缺少对应数据库的驱动请自行通过poetry安装
"sqlalchemy_connect_str": "mysql+pymysql://用户名:密码@主机地址/数据库名称e",
# 务必评估是否需要开启read_only,开启后会对sql语句进行检查请确认text2sql.py中的intercept_sql拦截器是否满足你使用的数据库只读要求
# 优先推荐从数据库层面对用户权限进行限制
"read_only": False,
# 限定返回的行数
"top_k": 50,
# 是否返回中间步骤
"return_intermediate_steps": True,
# 如果想指定特定表,请填写表名称,如["sys_user","sys_dept"],不填写走智能判断应该使用哪些表
"table_names": [],
# 对表名进行额外说明辅助大模型更好的判断应该使用哪些表尤其是SQLDatabaseSequentialChain模式下,是根据表名做的预测,很容易误判。
"table_comments": {
# 如果出现大模型选错表的情况,可尝试根据实际情况填写表名和说明
# "tableA":"这是一个用户表,存储了用户的基本信息",
# "tanleB":"角色表",
}
},
}
self._init_llm_work_config()
def _init_llm_work_config(self):
"""初始化知识库runtime的一些配置"""
self.LLM_MODEL_CONFIG = {
# 意图识别不需要输出,模型后台知道就行
"preprocess_model": {
self.DEFAULT_LLM_MODEL: {
"temperature": 0.05,
"max_tokens": 4096,
"history_len": 100,
"prompt_name": "default",
"callbacks": False
},
},
"llm_model": {
self.DEFAULT_LLM_MODEL: {
"temperature": 0.9,
"max_tokens": 4096,
"history_len": 10,
"prompt_name": "default",
"callbacks": True
},
},
"action_model": {
self.DEFAULT_LLM_MODEL: {
"temperature": 0.01,
"max_tokens": 4096,
"prompt_name": "ChatGLM3",
"callbacks": True
},
},
"postprocess_model": {
self.DEFAULT_LLM_MODEL: {
"temperature": 0.01,
"max_tokens": 4096,
"prompt_name": "default",
"callbacks": True
}
},
"image_model": {
"sd-turbo": {
"size": "256*256",
}
}
}
def default_llm_model(self, llm_model: str):
self.DEFAULT_LLM_MODEL = llm_model
def default_embedding_model(self, embedding_model: str):
self.DEFAULT_EMBEDDING_MODEL = embedding_model
def agent_model(self, agent_model: str):
self.Agent_MODEL = agent_model
def history_len(self, history_len: int):
self.HISTORY_LEN = history_len
def max_tokens(self, max_tokens: int):
self.MAX_TOKENS = max_tokens
def temperature(self, temperature: float):
self.TEMPERATURE = temperature
def support_agent_models(self, support_agent_models: List[str]):
self.SUPPORT_AGENT_MODELS = support_agent_models
def model_providers_cfg_path_config(self, model_providers_cfg_path_config: str):
self.MODEL_PROVIDERS_CFG_PATH_CONFIG = model_providers_cfg_path_config
def model_providers_cfg_host(self, model_providers_cfg_host: str):
self.MODEL_PROVIDERS_CFG_HOST = model_providers_cfg_host
def model_providers_cfg_port(self, model_providers_cfg_port: int):
self.MODEL_PROVIDERS_CFG_PORT = model_providers_cfg_port
def model_platforms(self, model_platforms: List[Dict[str, Any]]):
self.MODEL_PLATFORMS = model_platforms
def tool_config(self, tool_config: Dict[str, Any]):
self.TOOL_CONFIG = tool_config
def get_config(self) -> ConfigModel:
config = ConfigModel()
config.DEFAULT_LLM_MODEL = self.DEFAULT_LLM_MODEL
config.DEFAULT_EMBEDDING_MODEL = self.DEFAULT_EMBEDDING_MODEL
config.Agent_MODEL = self.Agent_MODEL
config.HISTORY_LEN = self.HISTORY_LEN
config.MAX_TOKENS = self.MAX_TOKENS
config.TEMPERATURE = self.TEMPERATURE
config.SUPPORT_AGENT_MODELS = self.SUPPORT_AGENT_MODELS
config.LLM_MODEL_CONFIG = self.LLM_MODEL_CONFIG
config.MODEL_PLATFORMS = self.MODEL_PLATFORMS
config.MODEL_PROVIDERS_CFG_PATH_CONFIG = self.MODEL_PROVIDERS_CFG_PATH_CONFIG
config.MODEL_PROVIDERS_CFG_HOST = self.MODEL_PROVIDERS_CFG_HOST
config.MODEL_PROVIDERS_CFG_PORT = self.MODEL_PROVIDERS_CFG_PORT
config.TOOL_CONFIG = self.TOOL_CONFIG
return config
class ConfigModelWorkSpace(core_config.ConfigWorkSpace[ConfigModelFactory, ConfigModel]):
"""
工作空间的配置预设, 提供ConfigModel建造方法产生实例。
"""
config_factory_cls = ConfigModelFactory
def __init__(self):
super().__init__()
def _build_config_factory(self, config_json: Any) -> ConfigModelFactory:
_config_factory = self.config_factory_cls()
if config_json.get("DEFAULT_LLM_MODEL"):
_config_factory.default_llm_model(config_json.get("DEFAULT_LLM_MODEL"))
if config_json.get("DEFAULT_EMBEDDING_MODEL"):
_config_factory.default_embedding_model(config_json.get("DEFAULT_EMBEDDING_MODEL"))
if config_json.get("Agent_MODEL"):
_config_factory.agent_model(config_json.get("Agent_MODEL"))
if config_json.get("HISTORY_LEN"):
_config_factory.history_len(config_json.get("HISTORY_LEN"))
if config_json.get("MAX_TOKENS"):
_config_factory.max_tokens(config_json.get("MAX_TOKENS"))
if config_json.get("TEMPERATURE"):
_config_factory.temperature(config_json.get("TEMPERATURE"))
if config_json.get("SUPPORT_AGENT_MODELS"):
_config_factory.support_agent_models(config_json.get("SUPPORT_AGENT_MODELS"))
if config_json.get("MODEL_PROVIDERS_CFG_PATH_CONFIG"):
_config_factory.model_providers_cfg_path_config(config_json.get("MODEL_PROVIDERS_CFG_PATH_CONFIG"))
if config_json.get("MODEL_PROVIDERS_CFG_HOST"):
_config_factory.model_providers_cfg_host(config_json.get("MODEL_PROVIDERS_CFG_HOST"))
if config_json.get("MODEL_PROVIDERS_CFG_PORT"):
_config_factory.model_providers_cfg_port(config_json.get("MODEL_PROVIDERS_CFG_PORT"))
if config_json.get("MODEL_PLATFORMS"):
_config_factory.model_platforms(config_json.get("MODEL_PLATFORMS"))
if config_json.get("TOOL_CONFIG"):
_config_factory.tool_config(config_json.get("TOOL_CONFIG"))
return _config_factory
@classmethod
def get_type(cls) -> str:
return ConfigModel.class_name()
def get_config(self) -> ConfigModel:
return self._config_factory.get_config()
def set_default_llm_model(self, llm_model: str):
self._config_factory.default_llm_model(llm_model)
self.store_config()
def set_default_embedding_model(self, embedding_model: str):
self._config_factory.default_embedding_model(embedding_model)
self.store_config()
def set_agent_model(self, agent_model: str):
self._config_factory.agent_model(agent_model)
self.store_config()
def set_history_len(self, history_len: int):
self._config_factory.history_len(history_len)
self.store_config()
def set_max_tokens(self, max_tokens: int):
self._config_factory.max_tokens(max_tokens)
self.store_config()
def set_temperature(self, temperature: float):
self._config_factory.temperature(temperature)
self.store_config()
def set_support_agent_models(self, support_agent_models: List[str]):
self._config_factory.support_agent_models(support_agent_models)
self.store_config()
def set_model_providers_cfg_path_config(self, model_providers_cfg_path_config: str):
self._config_factory.model_providers_cfg_path_config(model_providers_cfg_path_config)
self.store_config()
def set_model_providers_cfg_host(self, model_providers_cfg_host: str):
self._config_factory.model_providers_cfg_host(model_providers_cfg_host)
self.store_config()
def set_model_providers_cfg_port(self, model_providers_cfg_port: int):
self._config_factory.model_providers_cfg_port(model_providers_cfg_port)
self.store_config()
def set_model_platforms(self, model_platforms: List[Dict[str, Any]]):
self._config_factory.model_platforms(model_platforms=model_platforms)
self.store_config()
def set_tool_config(self, tool_config: Dict[str, Any]):
self._config_factory.tool_config(tool_config=tool_config)
self.store_config()
config_model_workspace: ConfigModelWorkSpace = ConfigModelWorkSpace()