agent_m/agentm/logic/config_manager.py

83 lines
2.9 KiB
Python

import copy
from pathlib import Path
from typing import Optional

import yaml
class ConfigManager:
    """Load and merge YAML configuration for a single game.

    Two files are read from ``config_base_path``:

    - ``general.yaml`` supplies the shared defaults.
    - ``<game_id>.yaml`` supplies the per-game overrides.

    The ``settings``, ``wrappers_settings``, ``policy_kwargs`` and
    ``ppo_settings`` sections are deep-merged, with game-specific values
    winning. ``folders`` is taken from ``general.yaml`` only. The merged
    result is available via :meth:`get`.
    """

    def __init__(
        self,
        game_id: str,
        model_name: str = "default_model",
        config_base_path: Optional[Path] = None,
    ):
        """Load both YAML files and build the merged config.

        Args:
            game_id: Selects ``<game_id>.yaml``. It is also injected as
                ``settings["game_id"]``, which the game file may override.
            model_name: Injected as ``folders["model_name"]``.
            config_base_path: Directory holding the YAML files. Defaults to
                the relative path ``agentm/data/game_data``, which is
                resolved against the current working directory.
        """
        self.game_id = game_id
        self.model_name = model_name
        self.config_base_path = (
            Path(config_base_path)
            if config_base_path is not None
            else Path("agentm/data/game_data")
        )
        self.general = self._load_yaml("general.yaml")
        self.game_specific = self._load_yaml(f"{self.game_id}.yaml")
        self.config = self._build_merged_config()

    def _load_yaml(self, filename: str) -> dict:
        """Read ``config_base_path / filename`` as YAML.

        Returns:
            The parsed mapping. Returns ``{}`` when the file does not exist
            or is empty.

        Raises:
            ValueError: If the top-level YAML value is not a mapping.
        """
        path = self.config_base_path / filename
        if not path.exists():
            return {}
        with open(path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
        # safe_load yields None for an empty or comment-only document.
        if data is None:
            return {}
        if not isinstance(data, dict):
            raise ValueError(
                f"Expected a YAML mapping at the top level of {path}, "
                f"got {type(data).__name__}"
            )
        return data

    def _deep_merge(self, source: dict, overrides: dict) -> dict:
        """Recursively merge ``overrides`` into ``source`` in place.

        Merge rules per key:

        - dict + dict: merged recursively.
        - list + list: concatenated, then de-duplicated keeping first
          occurrences (see :meth:`_merge_lists`).
        - Anything else: the override value replaces the source value.

        Returns:
            ``source`` itself, after mutation.
        """
        for key, value in overrides.items():
            if key not in source:
                source[key] = value
                continue
            current = source[key]
            if isinstance(current, dict) and isinstance(value, dict):
                self._deep_merge(current, value)
            elif isinstance(current, list) and isinstance(value, list):
                source[key] = self._merge_lists(current, value)
            else:
                source[key] = value
        return source

    @staticmethod
    def _merge_lists(base: list, extra: list) -> list:
        """Concatenate two lists and drop duplicates, keeping first occurrences.

        NOTE(review): repeated values *within* a list are collapsed too.
        For example, ``[64, 64] + [128]`` becomes ``[64, 128]``. That is
        surprising for positional lists such as layer sizes; confirm that
        this is intended.
        """
        combined = base + extra
        try:
            return list(dict.fromkeys(combined))
        except TypeError:
            # Unhashable elements (dicts, lists): dedup by equality instead.
            merged = []
            for item in combined:
                if item not in merged:
                    merged.append(item)
            return merged

    @staticmethod
    def _section(data: dict, name: str) -> dict:
        """Return a deep copy of ``data[name]``, or ``{}`` if missing or null.

        Deep-copying has two purposes. It keeps merges from mutating
        ``self.general`` and ``self.game_specific``. It also stops the
        merged config from sharing nested objects with them.
        """
        value = data.get(name)
        if value is None:
            return {}
        return copy.deepcopy(value)

    def _build_merged_config(self) -> dict:
        """Combine the general and game-specific sections into one config dict."""
        folders = self._section(self.general, "folders")
        folders["model_name"] = self.model_name

        settings = self._section(self.general, "settings")
        settings["game_id"] = self.game_id
        settings = self._deep_merge(
            settings, self._section(self.game_specific, "settings")
        )
        # YAML only produces lists; frame_shape is exposed as a tuple.
        if isinstance(settings.get("frame_shape"), list):
            settings["frame_shape"] = tuple(settings["frame_shape"])

        # Key order matters: the dump methods use sort_keys=False.
        merged = {"folders": folders, "settings": settings}
        for name in ("wrappers_settings", "policy_kwargs", "ppo_settings"):
            merged[name] = self._deep_merge(
                self._section(self.general, name),
                self._section(self.game_specific, name),
            )
        return merged

    def get(self) -> dict:
        """Return the full merged config dictionary (not a copy)."""
        return self.config

    def dump_to_file(self, out_path: Path):
        """Write the merged config to ``out_path`` as YAML.

        Parent directories are created as needed. ``safe_dump`` emits only
        standard YAML tags, so tuples such as ``frame_shape`` are written as
        plain lists. The file can therefore be read back with
        ``yaml.safe_load``.
        """
        out_path = Path(out_path)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        with open(out_path, "w", encoding="utf-8") as f:
            yaml.safe_dump(self.config, f, sort_keys=False)

    def dump_to_string(self) -> str:
        """Return the merged config as a YAML string.

        Uses ``safe_dump``, so tuples are written as plain lists.
        """
        return yaml.safe_dump(self.config, sort_keys=False)