配置中心agent信息 工具配置

This commit is contained in:
glide-the 2024-06-12 18:05:43 +08:00
parent 04590a2fd5
commit 52a119fd49
2 changed files with 74 additions and 47 deletions

View File

@ -50,6 +50,8 @@ def basic(**kwargs):
@click.option("--model_providers_cfg_path_config", help="模型平台配置文件路径")
@click.option("--model_providers_cfg_host", help="模型平台配置服务host")
@click.option("--model_providers_cfg_port", type=int, help="模型平台配置服务port")
@click.option("--set_model_platforms", type=str, help="模型平台配置")
@click.option("--set_tool_config", type=str, help="工具配置项 ")
@click.option("--clear", is_flag=True, help="清除配置")
@click.option("--show", is_flag=True, help="显示配置")
def model(**kwargs):
@ -85,6 +87,13 @@ def model(**kwargs):
if kwargs["model_providers_cfg_port"]:
config_model_workspace.set_model_providers_cfg_port(model_providers_cfg_port=kwargs["model_providers_cfg_port"])
if kwargs["set_model_platforms"]:
model_platforms_dict = json.loads(kwargs["set_model_platforms"])
config_model_workspace.set_model_platforms(model_platforms=model_platforms_dict)
if kwargs["set_tool_config"]:
tool_config_dict = json.loads(kwargs["set_tool_config"])
config_model_workspace.set_tool_config(tool_config=tool_config_dict)
if kwargs["clear"]:
config_model_workspace.clear()
if kwargs["show"]:

View File

@ -86,53 +86,6 @@ class ConfigModelFactory(core_config.ConfigFactory[ConfigModel]):
self.MODEL_PROVIDERS_CFG_PORT = 20000
self._init_llm_work_config()
def _init_llm_work_config(self):
    """Initialize the runtime LLM workload configuration.

    Builds ``self.LLM_MODEL_CONFIG``, a mapping keyed by pipeline stage,
    using the currently selected ``self.DEFAULT_LLM_MODEL`` for every
    text stage and a fixed image model entry.
    """
    default_model = self.DEFAULT_LLM_MODEL
    # Intent recognition (preprocess) produces no user-visible output —
    # the backend only needs the result — so its callbacks stay disabled.
    preprocess = {
        default_model: {
            "temperature": 0.05,
            "max_tokens": 4096,
            "history_len": 100,
            "prompt_name": "default",
            "callbacks": False,
        },
    }
    llm = {
        default_model: {
            "temperature": 0.9,
            "max_tokens": 4096,
            "history_len": 10,
            "prompt_name": "default",
            "callbacks": True,
        },
    }
    action = {
        default_model: {
            "temperature": 0.01,
            "max_tokens": 4096,
            "prompt_name": "ChatGLM3",
            "callbacks": True,
        },
    }
    postprocess = {
        default_model: {
            "temperature": 0.01,
            "max_tokens": 4096,
            "prompt_name": "default",
            "callbacks": True,
        },
    }
    self.LLM_MODEL_CONFIG = {
        "preprocess_model": preprocess,
        "llm_model": llm,
        "action_model": action,
        "postprocess_model": postprocess,
        "image_model": {"sd-turbo": {"size": "256*256"}},
    }
# 可以通过 model_providers 提供转换不同平台的接口为openai endpoint的能力启动后下面变量会自动增加相应的平台
# ### 如果您已经有了一个openai endpoint的能力的地址可以在这里直接配置
@ -314,6 +267,53 @@ class ConfigModelFactory(core_config.ConfigFactory[ConfigModel]):
}
},
}
self._init_llm_work_config()
def _init_llm_work_config(self):
    """Seed ``self.LLM_MODEL_CONFIG`` with per-stage runtime defaults.

    Every text pipeline stage is keyed under ``self.DEFAULT_LLM_MODEL``;
    the image stage is fixed to the ``sd-turbo`` entry.
    """
    model = self.DEFAULT_LLM_MODEL
    # Per-stage sampling settings. Intent recognition (preprocess) runs
    # silently — the backend only needs its result — so it alone leaves
    # callbacks off.
    stage_params = {
        "preprocess_model": {
            "temperature": 0.05,
            "max_tokens": 4096,
            "history_len": 100,
            "prompt_name": "default",
            "callbacks": False,
        },
        "llm_model": {
            "temperature": 0.9,
            "max_tokens": 4096,
            "history_len": 10,
            "prompt_name": "default",
            "callbacks": True,
        },
        "action_model": {
            "temperature": 0.01,
            "max_tokens": 4096,
            "prompt_name": "ChatGLM3",
            "callbacks": True,
        },
        "postprocess_model": {
            "temperature": 0.01,
            "max_tokens": 4096,
            "prompt_name": "default",
            "callbacks": True,
        },
    }
    config = {stage: {model: params} for stage, params in stage_params.items()}
    config["image_model"] = {"sd-turbo": {"size": "256*256"}}
    self.LLM_MODEL_CONFIG = config
def default_llm_model(self, llm_model: str):
    """Select *llm_model* as the factory's default LLM model name."""
    self.DEFAULT_LLM_MODEL = llm_model
@ -345,6 +345,12 @@ class ConfigModelFactory(core_config.ConfigFactory[ConfigModel]):
def model_providers_cfg_port(self, model_providers_cfg_port: int):
    """Set the listening port of the model-providers config service."""
    self.MODEL_PROVIDERS_CFG_PORT = model_providers_cfg_port
def model_platforms(self, model_platforms: List[Dict[str, Any]]):
    """Replace the configured model-platform entries wholesale."""
    self.MODEL_PLATFORMS = model_platforms
def tool_config(self, tool_config: Dict[str, Any]):
    """Replace the agent tool configuration wholesale."""
    self.TOOL_CONFIG = tool_config
def get_config(self) -> ConfigModel:
config = ConfigModel()
config.DEFAULT_LLM_MODEL = self.DEFAULT_LLM_MODEL
@ -396,6 +402,10 @@ class ConfigModelWorkSpace(core_config.ConfigWorkSpace[ConfigModelFactory, Confi
_config_factory.model_providers_cfg_host(config_json.get("MODEL_PROVIDERS_CFG_HOST"))
if config_json.get("MODEL_PROVIDERS_CFG_PORT"):
_config_factory.model_providers_cfg_port(config_json.get("MODEL_PROVIDERS_CFG_PORT"))
if config_json.get("MODEL_PLATFORMS"):
_config_factory.model_platforms(config_json.get("MODEL_PLATFORMS"))
if config_json.get("TOOL_CONFIG"):
_config_factory.tool_config(config_json.get("TOOL_CONFIG"))
return _config_factory
@ -446,5 +456,13 @@ class ConfigModelWorkSpace(core_config.ConfigWorkSpace[ConfigModelFactory, Confi
self._config_factory.model_providers_cfg_port(model_providers_cfg_port)
self.store_config()
def set_model_platforms(self, model_platforms: List[Dict[str, Any]]):
    """Update the factory's model platforms, then persist the workspace config."""
    self._config_factory.model_platforms(model_platforms=model_platforms)
    self.store_config()
def set_tool_config(self, tool_config: Dict[str, Any]):
    """Update the factory's tool configuration, then persist the workspace config."""
    self._config_factory.tool_config(tool_config=tool_config)
    self.store_config()
config_model_workspace: ConfigModelWorkSpace = ConfigModelWorkSpace()