from textual.screen import Screen
from textual.widgets import Static, Input, Button, Checkbox, Label, Select
from textual.containers import Vertical, VerticalScroll, Horizontal, Grid
from agentm.components.footer import AgentMFooter
from agentm.utils.logger import log_with_caller
from agentm.theme.palette import get_theme
from pathlib import Path

palette = get_theme()

# NOTE(review): relative path, so it is resolved against the process's current
# working directory, not against this module's location.
HEADER_PATH = Path("agentm/assets/headers/training_setup.txt")


class TrainingView(Screen):
    """Form screen for configuring a training run.

    Every form widget is created once in ``__init__`` (pre-filled with its
    default value) and laid out inside a scrollable column by ``compose``.
    Pressing Escape pops this screen.
    """

    BINDINGS = [("escape", "app.pop_screen", "Back")]

    def __init__(self, agent: dict, model: dict | None, run: dict):
        """Store the metadata dicts and build all form widgets.

        Args:
            agent: Agent metadata. Only ``"name"`` is read here, for logging.
            model: Model metadata, or ``None``. Stored on the instance; not
                read in this screen's visible code.
            run: Run metadata. Stored on the instance; not read here.
        """
        super().__init__()
        self.agent_metadata = agent
        self.model_metadata = model
        self.run_metadata = run
        # .get(): metadata without a "name" key should not abort the screen
        # just to produce a log line.
        log_with_caller(
            "info",
            f"Initialized TrainingView for agent: {agent.get('name', '<unnamed>')}",
        )

        # Model configuration (no defaults; the user must fill these in).
        self.model_name_input = Input(id="model_name")
        self.num_env_input = Input(id="num_env")

        # Framework / algorithm choice.
        self.framework_dropdown = Select(
            id="framework",
            options=[("SB3", "SB3"), ("RAY", "RAY"), ("SHEEPRL", "SHEEPRL")],
            value="SB3",
        )
        self.algo_dropdown = Select(id="algo", options=[("PPO", "PPO")], value="PPO")

        # General environment settings, keyed by setting name.
        self.settings_inputs = {
            "step_ratio": Input(value="6", id="step_ratio"),
            "frame_shape": Input(value="[128, 128, 1]", id="frame_shape"),
            "continue_game": Input(value="0.0", id="continue_game"),
            "action_space": Input(value="discrete", id="action_space"),
        }

        # Boolean wrapper flags, all enabled by default.
        self.wrapper_checkboxes = {
            "normalize_reward": Checkbox(label="Normalize Reward", value=True, id="normalize_reward"),
            "no_attack_buttons_combinations": Checkbox(
                label="No Attack Combos", value=True, id="no_attack_buttons_combinations"
            ),
            "add_last_action": Checkbox(label="Add Last Action", value=True, id="add_last_action"),
            "scale": Checkbox(label="Scale", value=True, id="scale"),
            "exclude_image_scaling": Checkbox(
                label="Exclude Image Scaling", value=True, id="exclude_image_scaling"
            ),
            "role_relative": Checkbox(label="Role Relative", value=True, id="role_relative"),
            "flatten": Checkbox(label="Flatten", value=True, id="flatten"),
        }

        # Numeric wrapper settings.
        self.wrapper_inputs = {
            "stack_frames": Input(value="4", id="stack_frames"),
            "stack_actions": Input(value="12", id="stack_actions"),
            "dilation": Input(value="1", id="dilation"),
        }

        # One checkbox per filter key; IDs are prefixed with "fk_".
        self.filter_key_checkboxes = [
            Checkbox(label=key, value=True, id=f"fk_{key}")
            for key in ["frame", "action", "stage", "timer"]
        ]

        # Policy network architecture, entered as a list literal string.
        self.net_arch_input = Input(value="[64, 64]", id="net_arch")

        # PPO hyperparameters. Note that some widget IDs are shorter than their
        # dict keys (e.g. "learning_rate_start" -> id "lr_start").
        self.ppo_inputs = {
            "gamma": Input(value="0.94", id="gamma"),
            "model_checkpoint": Input(value="0", id="model_checkpoint"),
            "learning_rate_start": Input(value="0.00025", id="lr_start"),
            "learning_rate_end": Input(value="0.0000025", id="lr_end"),
            "clip_range_start": Input(value="0.15", id="clip_start"),
            "clip_range_end": Input(value="0.025", id="clip_end"),
            "batch_size": Input(value="256", id="batch_size"),
            "n_epochs": Input(value="4", id="n_epochs"),
            "n_steps": Input(value="128", id="n_steps"),
            "autosave_freq": Input(value="256", id="autosave_freq"),
            "time_steps": Input(value="512", id="time_steps"),
        }

    @staticmethod
    def _read_header_text() -> str:
        """Return the header banner text, or a red markup placeholder if it can't be read."""
        try:
            # Explicit UTF-8 so non-ASCII header characters decode the same on
            # every platform instead of depending on the locale encoding.
            return HEADER_PATH.read_text(encoding="utf-8")
        except FileNotFoundError:
            log_with_caller("error", f"Missing header file at: {HEADER_PATH}")
            return "[bold red]Missing Header File"
        except (OSError, UnicodeDecodeError) as exc:
            # Best-effort: an unreadable header should degrade the banner, not
            # crash the screen.
            log_with_caller("error", f"Could not read header file at {HEADER_PATH}: {exc}")
            return "[bold red]Unreadable Header File"

    @staticmethod
    def _labeled_cell(label: str, widget) -> Vertical:
        """Stack a ``Label`` above ``widget`` in an ``input_cell`` column."""
        return Vertical(Label(label), widget, classes="input_cell")

    @classmethod
    def _labeled_grid(cls, inputs: dict) -> Grid:
        """Build an ``input_grid`` with one labeled cell per ``inputs`` entry.

        Labels are derived from the keys, e.g. ``"batch_size"`` -> ``"Batch Size"``.
        """
        return Grid(
            *(
                cls._labeled_cell(key.replace("_", " ").title(), widget)
                for key, widget in inputs.items()
            ),
            classes="input_grid",
        )

    def compose(self):
        """Yield the header, the scrollable config form, the review button and the footer."""
        log_with_caller("debug", "Composing TrainingView layout...")
        header_text = self._read_header_text()
        # NOTE(review): the header text is embedded in markup, so any "[" in the
        # header file itself would be parsed as a markup tag; confirm header
        # files never contain square brackets.
        yield Static(f"[{palette.ACCENT}]{header_text}[/{palette.ACCENT}]", classes="header")

        yield VerticalScroll(
            Vertical(
                # Framework & Algorithm
                Static("Framework & Algorithm", classes="section_label"),
                Grid(
                    self._labeled_cell("Framework", self.framework_dropdown),
                    self._labeled_cell("Algorithm", self.algo_dropdown),
                    classes="input_grid",
                ),
                # Model name / num_env
                Static("Model Configuration", classes="section_label"),
                Grid(
                    self._labeled_cell("Model Name", self.model_name_input),
                    self._labeled_cell("Num Envs", self.num_env_input),
                    classes="input_grid",
                ),
                # General Settings
                Static("General Settings", classes="section_label"),
                self._labeled_grid(self.settings_inputs),
                # Wrapper Checkboxes (checkboxes carry their own labels)
                Static("Wrapper Flags", classes="section_label"),
                Grid(
                    *(Vertical(cb, classes="input_cell") for cb in self.wrapper_checkboxes.values()),
                    classes="input_grid",
                ),
                # Wrapper Inputs: shown directly under "Wrapper Flags", no own heading
                self._labeled_grid(self.wrapper_inputs),
                # Filter Keys
                Static("Filter Keys", classes="section_label"),
                Grid(*self.filter_key_checkboxes, id="filter_grid"),
                # Policy kwargs
                Static("Policy Architecture", classes="section_label"),
                Grid(
                    self._labeled_cell("Net Arch", self.net_arch_input),
                    classes="input_grid",
                ),
                # PPO Settings
                Static("PPO Settings", classes="section_label"),
                self._labeled_grid(self.ppo_inputs),
                id="form_layout",
                classes="form_column",
            ),
            id="config_scroll_container",
        )

        yield Button("🔍 Review Config", id="review_config_btn", classes="confirm_button")
        yield AgentMFooter(compact=True)
        log_with_caller("debug", "Finished composing TrainingView")