ConfigModel unit tests

glide-the 2024-06-11 14:39:25 +08:00
parent cd01bb8601
commit 6b6d09a123
2 changed files with 87 additions and 7 deletions


@@ -62,15 +62,13 @@ class ConfigWorkSpace(Generic[CF, F], ABC):
         self.workspace_config = os.path.join(self.workspace, "workspace_config.json")
         # Initialize the workspace config: load it as JSON and instantiate the Config from it.
-        config_type_json = self._load_config()
-        if config_type_json is None:
+        _load_config = self._load_config()
+        if _load_config is None:
             self._config_factory = self._build_config_factory(config_json={})
             self.store_config()
         else:
-            config_type = config_type_json.get("type", None)
-            if self.get_type() != config_type:
-                raise ValueError(f"Config type mismatch: {self.get_type()} != {config_type}")
+            config_type_json = self.get_config_by_type(self.get_type())
             config_json = config_type_json.get("config")
             self._config_factory = self._build_config_factory(config_json)
@@ -98,9 +96,39 @@ class ConfigWorkSpace(Generic[CF, F], ABC):
         except FileNotFoundError:
             return None
 
+    @staticmethod
+    def _get_store_cfg_index_by_type(store_cfg, store_cfg_type) -> int:
+        if store_cfg is None:
+            raise RuntimeError("store_cfg is None.")
+        for cfg in store_cfg:
+            if cfg.get("type") == store_cfg_type:
+                return store_cfg.index(cfg)
+        return -1
+
+    def get_config_by_type(self, cfg_type) -> Dict[str, Any]:
+        store_cfg = self._load_config()
+        if store_cfg is None:
+            raise RuntimeError("store_cfg is None.")
+        get_lambda = lambda store_cfg_type: store_cfg[self._get_store_cfg_index_by_type(store_cfg, store_cfg_type)]
+        return get_lambda(cfg_type)
+
     def store_config(self):
         logger.info("Store workspace config.")
+        _load_config = self._load_config()
         with open(self.workspace_config, "w") as f:
             config_json = self.get_config().to_dict()
+            if _load_config is None:
+                _load_config = []
+            config_json_index = self._get_store_cfg_index_by_type(
+                store_cfg=_load_config,
+                store_cfg_type=self.get_type()
+            )
             config_type_json = {"type": self.get_type(), "config": config_json}
-            f.write(json.dumps(config_type_json, indent=4, ensure_ascii=False))
+            if config_json_index == -1:
+                _load_config.append(config_type_json)
+            else:
+                _load_config[config_json_index] = config_type_json
+            f.write(json.dumps(_load_config, indent=4, ensure_ascii=False))
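
With this change, workspace_config.json holds a list of {"type": ..., "config": ...} entries instead of a single object, so several workspace types can persist into the same file and get_config_by_type() can look each one up by its type. A minimal sketch of the resulting on-disk layout, assuming two workspaces have called store_config(); the type strings and the keys inside "config" are hypothetical placeholders, not values taken from this repository:

    import json

    # Hypothetical contents of workspace_config.json after this commit: a list with
    # one {"type": ..., "config": ...} entry per workspace type (placeholder values).
    example_store = [
        {"type": "basic", "config": {"log_verbose": False}},
        {"type": "model", "config": {"DEFAULT_LLM_MODEL": "glm4"}},
    ]
    print(json.dumps(example_store, indent=4, ensure_ascii=False))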


@@ -3,7 +3,9 @@ from pathlib import Path
 from chatchat.configs import (
     ConfigBasicFactory,
     ConfigBasic,
-    ConfigBasicWorkSpace
+    ConfigBasicWorkSpace,
+    ConfigModelWorkSpace,
+    ConfigModel
 )
 
 import os
@@ -43,3 +45,53 @@ def test_workspace_default():
 def test_config_model_workspace():
     config_model_workspace: ConfigModelWorkSpace = ConfigModelWorkSpace()
     assert config_model_workspace.get_config() is not None
+
+    config_model_workspace.set_default_llm_model(llm_model="glm4")
+    config_model_workspace.set_default_embedding_model(embedding_model="text1")
+    config_model_workspace.set_agent_model(agent_model="agent")
+    config_model_workspace.set_history_len(history_len=1)
+    config_model_workspace.set_max_tokens(max_tokens=1000)
+    config_model_workspace.set_temperature(temperature=0.1)
+    config_model_workspace.set_support_agent_models(support_agent_models=["glm4"])
+    config_model_workspace.set_model_providers_cfg_path_config(model_providers_cfg_path_config="model_providers.yaml")
+    config_model_workspace.set_model_providers_cfg_host(model_providers_cfg_host="127.0.0.1")
+    config_model_workspace.set_model_providers_cfg_port(model_providers_cfg_port=8000)
+
+    config: ConfigModel = config_model_workspace.get_config()
+    assert config.DEFAULT_LLM_MODEL == "glm4"
+    assert config.DEFAULT_EMBEDDING_MODEL == "text1"
+    assert config.Agent_MODEL == "agent"
+    assert config.HISTORY_LEN == 1
+    assert config.MAX_TOKENS == 1000
+    assert config.TEMPERATURE == 0.1
+    assert config.SUPPORT_AGENT_MODELS == ["glm4"]
+    assert config.MODEL_PROVIDERS_CFG_PATH_CONFIG == "model_providers.yaml"
+    assert config.MODEL_PROVIDERS_CFG_HOST == "127.0.0.1"
+    assert config.MODEL_PROVIDERS_CFG_PORT == 8000
+
+    config_model_workspace.clear()
+
+
+def test_model_config():
+    from chatchat.configs import (
+        DEFAULT_LLM_MODEL, DEFAULT_EMBEDDING_MODEL, Agent_MODEL, HISTORY_LEN, MAX_TOKENS, TEMPERATURE,
+        SUPPORT_AGENT_MODELS, MODEL_PROVIDERS_CFG_PATH_CONFIG, MODEL_PROVIDERS_CFG_HOST, MODEL_PROVIDERS_CFG_PORT,
+        TOOL_CONFIG, MODEL_PLATFORMS, LLM_MODEL_CONFIG
+    )
+
+    assert DEFAULT_LLM_MODEL is not None
+    assert DEFAULT_EMBEDDING_MODEL is not None
+    assert Agent_MODEL is None
+    assert HISTORY_LEN is not None
+    assert MAX_TOKENS is None
+    assert TEMPERATURE is not None
+    assert SUPPORT_AGENT_MODELS is not None
+    assert MODEL_PROVIDERS_CFG_PATH_CONFIG is not None
+    assert MODEL_PROVIDERS_CFG_HOST is not None
+    assert MODEL_PROVIDERS_CFG_PORT is not None
+    assert TOOL_CONFIG is not None
+    assert MODEL_PLATFORMS is not None
+    assert LLM_MODEL_CONFIG is not None
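
The new tests are plain assert-style functions like the existing test_workspace_default, so pytest is presumably the intended runner; the diff itself does not name one, and the test file's path is not shown. Under that assumption, a quick way to run only the two new tests is to select them by name:

    # Minimal sketch, assuming pytest is the test runner: select the two new tests
    # by name, since the test file path is not visible in this diff. Equivalent to
    # `pytest -k "test_config_model_workspace or test_model_config" -v`.
    import pytest

    pytest.main(["-k", "test_config_model_workspace or test_model_config", "-v"])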