# File-viewer header (extraction residue): 83 lines, 2.9 KiB, Python
import copy
from pathlib import Path
from typing import Optional

import yaml


class ConfigManager:
    """Load and merge the general and per-game YAML configuration.

    On construction, ``general.yaml`` and ``<game_id>.yaml`` are read from
    ``agentm/data/game_data`` (resolved relative to the current working
    directory) and combined into a single dictionary with the sections
    ``folders``, ``settings``, ``wrappers_settings``, ``policy_kwargs`` and
    ``ppo_settings``. Game-specific values override general ones.

    Attributes:
        game_id: Identifier of the game; also the stem of its YAML file.
        model_name: Name stored under ``folders["model_name"]``.
        config_base_path: Directory that holds the YAML files.
        general: Parsed content of ``general.yaml`` (``{}`` if missing).
        game_specific: Parsed content of ``<game_id>.yaml`` (``{}`` if
            missing).
        config: The merged configuration.
    """

    # Sections that are a plain deep merge of general + game-specific values.
    _MERGED_SECTIONS = ("wrappers_settings", "policy_kwargs", "ppo_settings")

    def __init__(self, game_id: str, model_name: str = "default_model"):
        """Read both YAML files and build the merged configuration.

        Args:
            game_id: Identifier of the game whose overrides are loaded.
            model_name: Model name written into the ``folders`` section.
        """
        self.game_id = game_id
        self.model_name = model_name
        self.config_base_path = Path("agentm/data/game_data")

        self.general = self._load_yaml("general.yaml")
        self.game_specific = self._load_yaml(f"{self.game_id}.yaml")

        self.config = self._build_merged_config()

    def _load_yaml(self, filename: str) -> dict:
        """Parse ``filename`` from the config directory.

        Args:
            filename: File name relative to ``config_base_path``.

        Returns:
            The parsed document, or ``{}`` when the file does not exist or
            is empty.
        """
        path = self.config_base_path / filename
        if not path.exists():
            return {}
        with open(path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
        # safe_load yields None for an empty document; callers expect a dict.
        return {} if data is None else data

    @staticmethod
    def _section(config: dict, name: str) -> dict:
        """Return an independent deep copy of ``config[name]``.

        A missing section, or one written as a bare ``name:`` (parsed as
        ``None``), is treated as empty.
        """
        return copy.deepcopy(config.get(name) or {})

    @staticmethod
    def _merge_lists(base: list, extra: list) -> list:
        """Concatenate two lists and drop duplicates, keeping first-seen order."""
        combined = base + extra
        try:
            return list(dict.fromkeys(combined))
        except TypeError:
            # Unhashable items (nested lists/dicts): fall back to an
            # equality-based scan. Config lists are small, so O(n^2) is fine.
            unique = []
            for item in combined:
                if item not in unique:
                    unique.append(item)
            return unique

    def _deep_merge(self, source: dict, overrides: dict) -> dict:
        """Recursively merge ``overrides`` into ``source`` in place.

        Dicts are merged key by key, lists are concatenated with duplicates
        removed (order preserved), and any other value in ``overrides``
        replaces the one in ``source``.

        Args:
            source: Dictionary that is updated in place.
            overrides: Values that take precedence over ``source``.

        Returns:
            ``source`` itself, after the merge.
        """
        for key, value in overrides.items():
            current = source.get(key)
            if isinstance(current, dict) and isinstance(value, dict):
                self._deep_merge(current, value)
            elif isinstance(current, list) and isinstance(value, list):
                source[key] = self._merge_lists(current, value)
            else:
                # Missing key or mismatched/scalar types: the override wins.
                source[key] = value
        return source

    def _build_merged_config(self) -> dict:
        """Combine ``self.general`` and ``self.game_specific``.

        Every section is deep-copied before merging so that neither loaded
        document is mutated and the result shares no nested objects with
        them.

        Returns:
            A dict with the keys ``folders``, ``settings``,
            ``wrappers_settings``, ``policy_kwargs`` and ``ppo_settings``.
        """
        folders = self._section(self.general, "folders")
        folders["model_name"] = self.model_name

        settings = self._section(self.general, "settings")
        # Set before merging so a game-specific "game_id" can still override.
        settings["game_id"] = self.game_id
        self._deep_merge(settings, self._section(self.game_specific, "settings"))
        # YAML has no tuple type; frame_shape is stored as a tuple.
        if isinstance(settings.get("frame_shape"), list):
            settings["frame_shape"] = tuple(settings["frame_shape"])

        merged = {"folders": folders, "settings": settings}
        for name in self._MERGED_SECTIONS:
            merged[name] = self._deep_merge(
                self._section(self.general, name),
                self._section(self.game_specific, name),
            )
        return merged

    def get(self) -> dict:
        """Return the full merged config dictionary."""
        return self.config

    # NOTE(review): settings["frame_shape"] is a tuple, which PyYAML's default
    # dumper writes with a "!!python/tuple" tag; yaml.safe_load rejects that
    # tag, so dumped files cannot be read back by _load_yaml. Left unchanged
    # to keep the output format stable -- confirm whether that is intended.
    def dump_to_file(self, out_path: Path) -> None:
        """Write the config to a YAML file.

        Parent directories of ``out_path`` are created if needed.
        """
        out_path.parent.mkdir(parents=True, exist_ok=True)
        with open(out_path, "w", encoding="utf-8") as f:
            yaml.dump(self.config, f, sort_keys=False)

    def dump_to_string(self) -> str:
        """Return the YAML string representation of the config."""
        return yaml.dump(self.config, sort_keys=False)