diff --git a/libs/chatchat-server/chatchat/configs/_core_config.py b/libs/chatchat-server/chatchat/configs/_core_config.py index 5059ff3a..339c3dfa 100644 --- a/libs/chatchat-server/chatchat/configs/_core_config.py +++ b/libs/chatchat-server/chatchat/configs/_core_config.py @@ -62,15 +62,13 @@ class ConfigWorkSpace(Generic[CF, F], ABC): self.workspace_config = os.path.join(self.workspace, "workspace_config.json") # 初始化工作空间配置,转换成json格式,实现Config的实例化 - config_type_json = self._load_config() - if config_type_json is None: + _load_config = self._load_config() + if _load_config is None: self._config_factory = self._build_config_factory(config_json={}) self.store_config() else: - config_type = config_type_json.get("type", None) - if self.get_type() != config_type: - raise ValueError(f"Config type mismatch: {self.get_type()} != {config_type}") + config_type_json = self.get_config_by_type(self.get_type()) config_json = config_type_json.get("config") self._config_factory = self._build_config_factory(config_json) @@ -98,9 +96,39 @@ class ConfigWorkSpace(Generic[CF, F], ABC): except FileNotFoundError: return None + @staticmethod + def _get_store_cfg_index_by_type(store_cfg, store_cfg_type) -> int: + if store_cfg is None: + raise RuntimeError("store_cfg is None.") + for cfg in store_cfg: + if cfg.get("type") == store_cfg_type: + return store_cfg.index(cfg) + + return -1 + + def get_config_by_type(self, cfg_type) -> Dict[str, Any]: + store_cfg = self._load_config() + if store_cfg is None: + raise RuntimeError("store_cfg is None.") + + get_lambda = lambda store_cfg_type: store_cfg[self._get_store_cfg_index_by_type(store_cfg, store_cfg_type)] + return get_lambda(cfg_type) + def store_config(self): logger.info("Store workspace config.") + _load_config = self._load_config() with open(self.workspace_config, "w") as f: config_json = self.get_config().to_dict() + + if _load_config is None: + _load_config = [] + config_json_index = self._get_store_cfg_index_by_type( + store_cfg=_load_config, + store_cfg_type=self.get_type() + ) config_type_json = {"type": self.get_type(), "config": config_json} - f.write(json.dumps(config_type_json, indent=4, ensure_ascii=False)) + if config_json_index == -1: + _load_config.append(config_type_json) + else: + _load_config[config_json_index] = config_type_json + f.write(json.dumps(_load_config, indent=4, ensure_ascii=False)) diff --git a/libs/chatchat-server/tests/unit_tests/config/test_config.py b/libs/chatchat-server/tests/unit_tests/config/test_config.py index cb272155..d540e5f6 100644 --- a/libs/chatchat-server/tests/unit_tests/config/test_config.py +++ b/libs/chatchat-server/tests/unit_tests/config/test_config.py @@ -3,7 +3,9 @@ from pathlib import Path from chatchat.configs import ( ConfigBasicFactory, ConfigBasic, - ConfigBasicWorkSpace + ConfigBasicWorkSpace, + ConfigModelWorkSpace, + ConfigModel ) import os @@ -43,3 +45,53 @@ def test_workspace_default(): def test_config_model_workspace(): + + config_model_workspace: ConfigModelWorkSpace = ConfigModelWorkSpace() + + assert config_model_workspace.get_config() is not None + + config_model_workspace.set_default_llm_model(llm_model="glm4") + config_model_workspace.set_default_embedding_model(embedding_model="text1") + config_model_workspace.set_agent_model(agent_model="agent") + config_model_workspace.set_history_len(history_len=1) + config_model_workspace.set_max_tokens(max_tokens=1000) + config_model_workspace.set_temperature(temperature=0.1) + config_model_workspace.set_support_agent_models(support_agent_models=["glm4"]) + config_model_workspace.set_model_providers_cfg_path_config(model_providers_cfg_path_config="model_providers.yaml") + config_model_workspace.set_model_providers_cfg_host(model_providers_cfg_host="127.0.0.1") + config_model_workspace.set_model_providers_cfg_port(model_providers_cfg_port=8000) + + config: ConfigModel = config_model_workspace.get_config() + + assert config.DEFAULT_LLM_MODEL == "glm4" + assert config.DEFAULT_EMBEDDING_MODEL == "text1" + assert config.Agent_MODEL == "agent" + assert config.HISTORY_LEN == 1 + assert config.MAX_TOKENS == 1000 + assert config.TEMPERATURE == 0.1 + assert config.SUPPORT_AGENT_MODELS == ["glm4"] + assert config.MODEL_PROVIDERS_CFG_PATH_CONFIG == "model_providers.yaml" + assert config.MODEL_PROVIDERS_CFG_HOST == "127.0.0.1" + assert config.MODEL_PROVIDERS_CFG_PORT == 8000 + config_model_workspace.clear() + + +def test_model_config(): + from chatchat.configs import ( + DEFAULT_LLM_MODEL, DEFAULT_EMBEDDING_MODEL, Agent_MODEL, HISTORY_LEN, MAX_TOKENS, TEMPERATURE, + SUPPORT_AGENT_MODELS, MODEL_PROVIDERS_CFG_PATH_CONFIG, MODEL_PROVIDERS_CFG_HOST, MODEL_PROVIDERS_CFG_PORT, + TOOL_CONFIG, MODEL_PLATFORMS, LLM_MODEL_CONFIG + ) + assert DEFAULT_LLM_MODEL is not None + assert DEFAULT_EMBEDDING_MODEL is not None + assert Agent_MODEL is None + assert HISTORY_LEN is not None + assert MAX_TOKENS is None + assert TEMPERATURE is not None + assert SUPPORT_AGENT_MODELS is not None + assert MODEL_PROVIDERS_CFG_PATH_CONFIG is not None + assert MODEL_PROVIDERS_CFG_HOST is not None + assert MODEL_PROVIDERS_CFG_PORT is not None + assert TOOL_CONFIG is not None + assert MODEL_PLATFORMS is not None + assert LLM_MODEL_CONFIG is not None