# agent_m/agentm/views/training.py
# 153 lines
# 6.3 KiB
# Python

from textual.screen import Screen
from textual.widgets import Static, Input, Button, Checkbox, Label, Select
from textual.containers import Vertical, VerticalScroll, Horizontal, Grid
from agentm.components.footer import AgentMFooter
from agentm.utils.logger import log_with_caller
from agentm.theme.palette import get_theme
from pathlib import Path
# Theme palette, resolved once at import time; TrainingView.compose() reads
# palette.ACCENT to colour the header text.
palette = get_theme()
# Text file whose contents are shown as the screen header. The path is
# relative, so it is resolved against the process's current working directory.
HEADER_PATH = Path("agentm/assets/headers/training_setup.txt")
class TrainingView(Screen):
    """Screen that lays out the training-configuration form.

    Every form widget is created once in ``__init__`` and kept as an instance
    attribute; ``compose`` only arranges those widgets into the layout
    (header, scrollable form, review button, footer).
    """

    BINDINGS = [("escape", "app.pop_screen", "Back")]

    def __init__(self, agent: dict, model: dict | None, run: dict):
        """Store the metadata dicts and build all form widgets.

        Args:
            agent: Agent metadata; ``agent["name"]`` is read for the log line.
            model: Model metadata, or ``None``.
            run: Run metadata.

        Raises:
            KeyError: If ``agent`` has no ``"name"`` key.
        """
        super().__init__()
        self.agent_metadata = agent
        self.model_metadata = model
        self.run_metadata = run
        log_with_caller("info", f"Initialized TrainingView for agent: {agent['name']}")

        # --- Model configuration -------------------------------------------
        self.model_name_input = Input(id="model_name")
        self.num_env_input = Input(id="num_env")

        # --- Framework & algorithm selectors -------------------------------
        self.framework_dropdown = Select(
            id="framework",
            options=[("SB3", "SB3"), ("RAY", "RAY"), ("SHEEPRL", "SHEEPRL")],
            value="SB3",
        )
        self.algo_dropdown = Select(id="algo", options=[("PPO", "PPO")], value="PPO")

        # --- General settings (free-text inputs, keyed by setting name) ----
        self.settings_inputs = {
            "step_ratio": Input(value="6", id="step_ratio"),
            "frame_shape": Input(value="[128, 128, 1]", id="frame_shape"),
            "continue_game": Input(value="0.0", id="continue_game"),
            "action_space": Input(value="discrete", id="action_space"),
        }

        # --- Wrapper flags (all ticked by default) -------------------------
        self.wrapper_checkboxes = {
            "normalize_reward": Checkbox(
                label="Normalize Reward", value=True, id="normalize_reward"
            ),
            "no_attack_buttons_combinations": Checkbox(
                label="No Attack Combos", value=True, id="no_attack_buttons_combinations"
            ),
            "add_last_action": Checkbox(
                label="Add Last Action", value=True, id="add_last_action"
            ),
            "scale": Checkbox(label="Scale", value=True, id="scale"),
            "exclude_image_scaling": Checkbox(
                label="Exclude Image Scaling", value=True, id="exclude_image_scaling"
            ),
            "role_relative": Checkbox(label="Role Relative", value=True, id="role_relative"),
            "flatten": Checkbox(label="Flatten", value=True, id="flatten"),
        }

        # --- Wrapper numeric inputs ----------------------------------------
        self.wrapper_inputs = {
            "stack_frames": Input(value="4", id="stack_frames"),
            "stack_actions": Input(value="12", id="stack_actions"),
            "dilation": Input(value="1", id="dilation"),
        }

        # --- Filter keys: one checkbox per key, ids prefixed with "fk_" ----
        self.filter_key_checkboxes = [
            Checkbox(label=key, value=True, id=f"fk_{key}")
            for key in ["frame", "action", "stage", "timer"]
        ]

        # --- Policy architecture -------------------------------------------
        self.net_arch_input = Input(value="[64, 64]", id="net_arch")

        # --- PPO settings ---------------------------------------------------
        # NOTE: dict keys are the long setting names while some widget ids are
        # abbreviated (e.g. "learning_rate_start" -> id "lr_start").
        self.ppo_inputs = {
            "gamma": Input(value="0.94", id="gamma"),
            "model_checkpoint": Input(value="0", id="model_checkpoint"),
            "learning_rate_start": Input(value="0.00025", id="lr_start"),
            "learning_rate_end": Input(value="0.0000025", id="lr_end"),
            "clip_range_start": Input(value="0.15", id="clip_start"),
            "clip_range_end": Input(value="0.025", id="clip_end"),
            "batch_size": Input(value="256", id="batch_size"),
            "n_epochs": Input(value="4", id="n_epochs"),
            "n_steps": Input(value="128", id="n_steps"),
            "autosave_freq": Input(value="256", id="autosave_freq"),
            "time_steps": Input(value="512", id="time_steps"),
        }

    @staticmethod
    def _format_field_label(key: str) -> str:
        """Turn a snake_case setting key into a title-cased label."""
        return key.replace("_", " ").title()

    @staticmethod
    def _load_header_text() -> str:
        """Return the header file's text, or placeholder markup if unreadable.

        The file is decoded as UTF-8 explicitly so the result does not depend
        on the platform's default encoding. Any failure to read or decode it
        is logged and replaced by the placeholder instead of propagating out
        of ``compose``.
        """
        try:
            return HEADER_PATH.read_text(encoding="utf-8")
        except FileNotFoundError:
            log_with_caller("error", f"Missing header file at: {HEADER_PATH}")
        except (OSError, UnicodeDecodeError) as exc:
            # e.g. permission denied, path is a directory, or bad encoding.
            log_with_caller("error", f"Could not read header file at: {HEADER_PATH} ({exc})")
        return "[bold red]Missing Header File"

    def compose(self):
        """Yield the header, the scrollable form, the review button and the footer."""
        log_with_caller("debug", "Composing TrainingView layout...")
        header_text = self._load_header_text()
        yield Static(f"[{palette.ACCENT}]{header_text}[/{palette.ACCENT}]", classes="header")

        # Build scrollable form content
        yield VerticalScroll(
            Vertical(
                # Framework & Algorithm
                Static("Framework & Algorithm", classes="section_label"),
                Horizontal(
                    Vertical(Label("Framework"), self.framework_dropdown),
                    Vertical(Label("Algorithm"), self.algo_dropdown),
                    classes="input_row",
                ),
                # Model name / num_env
                Static("Model Configuration", classes="section_label"),
                Horizontal(
                    Vertical(Label("Model Name"), self.model_name_input),
                    Vertical(Label("Num Envs"), self.num_env_input),
                    classes="input_row",
                ),
                # General Settings (labels humanised like the sections below)
                Static("General Settings", classes="section_label"),
                *[
                    Horizontal(
                        Vertical(Label(self._format_field_label(key))),
                        field,
                        classes="input_row",
                    )
                    for key, field in self.settings_inputs.items()
                ],
                # Wrapper Checkboxes
                Static("Wrapper Flags", classes="section_label"),
                *[
                    Horizontal(checkbox, classes="input_row")
                    for checkbox in self.wrapper_checkboxes.values()
                ],
                # Wrapper Inputs
                *[
                    Horizontal(
                        Vertical(Label(self._format_field_label(key))),
                        field,
                        classes="input_row",
                    )
                    for key, field in self.wrapper_inputs.items()
                ],
                # Filter Keys
                Static("Filter Keys", classes="section_label"),
                Grid(*self.filter_key_checkboxes, id="filter_grid"),
                # Policy kwargs
                Static("Policy Architecture", classes="section_label"),
                Horizontal(Label("Net Arch"), self.net_arch_input, classes="input_row"),
                # PPO Settings
                Static("PPO Settings", classes="section_label"),
                *[
                    Horizontal(
                        Label(self._format_field_label(key)),
                        field,
                        classes="input_row",
                    )
                    for key, field in self.ppo_inputs.items()
                ],
                id="form_layout",
                classes="form_column",
            ),
            id="config_scroll_container",
        )

        # Final button and footer
        yield Button("🔍 Review Config", id="review_config_btn", classes="confirm_button")
        yield AgentMFooter(compact=True)
        # Runs only once the generator is fully consumed.
        log_with_caller("debug", "Finished composing TrainingView")