mirror of
https://github.com/RYDE-WORK/Langchain-Chatchat.git
synced 2026-02-04 13:43:12 +08:00
ConfigModel单元测试
This commit is contained in:
parent
cd01bb8601
commit
6b6d09a123
@ -62,15 +62,13 @@ class ConfigWorkSpace(Generic[CF, F], ABC):
|
||||
self.workspace_config = os.path.join(self.workspace, "workspace_config.json")
|
||||
# 初始化工作空间配置,转换成json格式,实现Config的实例化
|
||||
|
||||
config_type_json = self._load_config()
|
||||
if config_type_json is None:
|
||||
_load_config = self._load_config()
|
||||
if _load_config is None:
|
||||
self._config_factory = self._build_config_factory(config_json={})
|
||||
self.store_config()
|
||||
|
||||
else:
|
||||
config_type = config_type_json.get("type", None)
|
||||
if self.get_type() != config_type:
|
||||
raise ValueError(f"Config type mismatch: {self.get_type()} != {config_type}")
|
||||
config_type_json = self.get_config_by_type(self.get_type())
|
||||
|
||||
config_json = config_type_json.get("config")
|
||||
self._config_factory = self._build_config_factory(config_json)
|
||||
@ -98,9 +96,39 @@ class ConfigWorkSpace(Generic[CF, F], ABC):
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_store_cfg_index_by_type(store_cfg, store_cfg_type) -> int:
|
||||
if store_cfg is None:
|
||||
raise RuntimeError("store_cfg is None.")
|
||||
for cfg in store_cfg:
|
||||
if cfg.get("type") == store_cfg_type:
|
||||
return store_cfg.index(cfg)
|
||||
|
||||
return -1
|
||||
|
||||
def get_config_by_type(self, cfg_type) -> Dict[str, Any]:
|
||||
store_cfg = self._load_config()
|
||||
if store_cfg is None:
|
||||
raise RuntimeError("store_cfg is None.")
|
||||
|
||||
get_lambda = lambda store_cfg_type: store_cfg[self._get_store_cfg_index_by_type(store_cfg, store_cfg_type)]
|
||||
return get_lambda(cfg_type)
|
||||
|
||||
def store_config(self):
|
||||
logger.info("Store workspace config.")
|
||||
_load_config = self._load_config()
|
||||
with open(self.workspace_config, "w") as f:
|
||||
config_json = self.get_config().to_dict()
|
||||
|
||||
if _load_config is None:
|
||||
_load_config = []
|
||||
config_json_index = self._get_store_cfg_index_by_type(
|
||||
store_cfg=_load_config,
|
||||
store_cfg_type=self.get_type()
|
||||
)
|
||||
config_type_json = {"type": self.get_type(), "config": config_json}
|
||||
f.write(json.dumps(config_type_json, indent=4, ensure_ascii=False))
|
||||
if config_json_index == -1:
|
||||
_load_config.append(config_type_json)
|
||||
else:
|
||||
_load_config[config_json_index] = config_type_json
|
||||
f.write(json.dumps(_load_config, indent=4, ensure_ascii=False))
|
||||
|
||||
@ -3,7 +3,9 @@ from pathlib import Path
|
||||
from chatchat.configs import (
|
||||
ConfigBasicFactory,
|
||||
ConfigBasic,
|
||||
ConfigBasicWorkSpace
|
||||
ConfigBasicWorkSpace,
|
||||
ConfigModelWorkSpace,
|
||||
ConfigModel
|
||||
)
|
||||
import os
|
||||
|
||||
@ -43,3 +45,53 @@ def test_workspace_default():
|
||||
|
||||
|
||||
def test_config_model_workspace():
|
||||
|
||||
config_model_workspace: ConfigModelWorkSpace = ConfigModelWorkSpace()
|
||||
|
||||
assert config_model_workspace.get_config() is not None
|
||||
|
||||
config_model_workspace.set_default_llm_model(llm_model="glm4")
|
||||
config_model_workspace.set_default_embedding_model(embedding_model="text1")
|
||||
config_model_workspace.set_agent_model(agent_model="agent")
|
||||
config_model_workspace.set_history_len(history_len=1)
|
||||
config_model_workspace.set_max_tokens(max_tokens=1000)
|
||||
config_model_workspace.set_temperature(temperature=0.1)
|
||||
config_model_workspace.set_support_agent_models(support_agent_models=["glm4"])
|
||||
config_model_workspace.set_model_providers_cfg_path_config(model_providers_cfg_path_config="model_providers.yaml")
|
||||
config_model_workspace.set_model_providers_cfg_host(model_providers_cfg_host="127.0.0.1")
|
||||
config_model_workspace.set_model_providers_cfg_port(model_providers_cfg_port=8000)
|
||||
|
||||
config: ConfigModel = config_model_workspace.get_config()
|
||||
|
||||
assert config.DEFAULT_LLM_MODEL == "glm4"
|
||||
assert config.DEFAULT_EMBEDDING_MODEL == "text1"
|
||||
assert config.Agent_MODEL == "agent"
|
||||
assert config.HISTORY_LEN == 1
|
||||
assert config.MAX_TOKENS == 1000
|
||||
assert config.TEMPERATURE == 0.1
|
||||
assert config.SUPPORT_AGENT_MODELS == ["glm4"]
|
||||
assert config.MODEL_PROVIDERS_CFG_PATH_CONFIG == "model_providers.yaml"
|
||||
assert config.MODEL_PROVIDERS_CFG_HOST == "127.0.0.1"
|
||||
assert config.MODEL_PROVIDERS_CFG_PORT == 8000
|
||||
config_model_workspace.clear()
|
||||
|
||||
|
||||
def test_model_config():
|
||||
from chatchat.configs import (
|
||||
DEFAULT_LLM_MODEL, DEFAULT_EMBEDDING_MODEL, Agent_MODEL, HISTORY_LEN, MAX_TOKENS, TEMPERATURE,
|
||||
SUPPORT_AGENT_MODELS, MODEL_PROVIDERS_CFG_PATH_CONFIG, MODEL_PROVIDERS_CFG_HOST, MODEL_PROVIDERS_CFG_PORT,
|
||||
TOOL_CONFIG, MODEL_PLATFORMS, LLM_MODEL_CONFIG
|
||||
)
|
||||
assert DEFAULT_LLM_MODEL is not None
|
||||
assert DEFAULT_EMBEDDING_MODEL is not None
|
||||
assert Agent_MODEL is None
|
||||
assert HISTORY_LEN is not None
|
||||
assert MAX_TOKENS is None
|
||||
assert TEMPERATURE is not None
|
||||
assert SUPPORT_AGENT_MODELS is not None
|
||||
assert MODEL_PROVIDERS_CFG_PATH_CONFIG is not None
|
||||
assert MODEL_PROVIDERS_CFG_HOST is not None
|
||||
assert MODEL_PROVIDERS_CFG_PORT is not None
|
||||
assert TOOL_CONFIG is not None
|
||||
assert MODEL_PLATFORMS is not None
|
||||
assert LLM_MODEL_CONFIG is not None
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user