166 lines
6.8 KiB
Python
166 lines
6.8 KiB
Python
from textual.screen import Screen
|
|
from textual.widgets import Static, Input, Button, Checkbox, Label, Select
|
|
from textual.containers import Vertical, VerticalScroll, Horizontal, Grid
|
|
from agentm.components.footer import AgentMFooter
|
|
from agentm.utils.logger import log_with_caller
|
|
from agentm.theme.palette import get_theme
|
|
from pathlib import Path
|
|
|
|
# Active colour theme, resolved once when the module is imported;
# TrainingView.compose() reads palette.ACCENT from it for the header text.
palette = get_theme()

# Text file rendered as the header of the training setup screen.
# NOTE(review): this path is relative to the current working directory, so it
# presumably only resolves when the app is launched from the repository root.
HEADER_PATH = Path("agentm") / "assets" / "headers" / "training_setup.txt"
|
|
|
|
class TrainingView(Screen):
    """Screen that collects the configuration for a new training run.

    Every form widget is created once in ``__init__`` and kept as an instance
    attribute; ``compose`` only arranges those pre-built widgets inside a
    scrollable column, followed by a "Review Config" button and the footer.
    """

    # Escape delegates to the app-level ``pop_screen`` action to go back.
    BINDINGS = [("escape", "app.pop_screen", "Back")]

    def __init__(self, agent: dict, model: dict | None, run: dict):
        """Store the caller's metadata and build every form widget.

        Args:
            agent: Agent metadata. Only its ``"name"`` entry is read here,
                for the initialization log line.
            model: Model metadata, or ``None``. Stored as
                ``self.model_metadata``; not otherwise read in this class.
            run: Run metadata. Stored as ``self.run_metadata``; not otherwise
                read in this class.
        """
        super().__init__()
        self.agent_metadata = agent
        self.model_metadata = model
        self.run_metadata = run

        # ``.get`` keeps a metadata dict without "name" from aborting screen
        # construction just to produce a log message.
        log_with_caller("info", f"Initialized TrainingView for agent: {agent.get('name')}")

        self.model_name_input = Input(id="model_name")
        self.num_env_input = Input(id="num_env")

        self.framework_dropdown = Select(
            id="framework",
            options=[("SB3", "SB3"), ("RAY", "RAY"), ("SHEEPRL", "SHEEPRL")],
            value="SB3",
        )
        # NOTE(review): only PPO is offered, whichever framework is selected.
        self.algo_dropdown = Select(id="algo", options=[("PPO", "PPO")], value="PPO")

        # Free-text general settings with pre-filled defaults; nothing in this
        # class parses them, so their values remain strings.
        self.settings_inputs = {
            "step_ratio": Input(value="6", id="step_ratio"),
            "frame_shape": Input(value="[128, 128, 1]", id="frame_shape"),
            "continue_game": Input(value="0.0", id="continue_game"),
            "action_space": Input(value="discrete", id="action_space"),
        }

        # Boolean wrapper options, all enabled by default. The dict key
        # doubles as the widget id.
        self.wrapper_checkboxes = {
            "normalize_reward": Checkbox(label="Normalize Reward", value=True, id="normalize_reward"),
            "no_attack_buttons_combinations": Checkbox(
                label="No Attack Combos", value=True, id="no_attack_buttons_combinations"
            ),
            "add_last_action": Checkbox(label="Add Last Action", value=True, id="add_last_action"),
            "scale": Checkbox(label="Scale", value=True, id="scale"),
            "exclude_image_scaling": Checkbox(
                label="Exclude Image Scaling", value=True, id="exclude_image_scaling"
            ),
            "role_relative": Checkbox(label="Role Relative", value=True, id="role_relative"),
            "flatten": Checkbox(label="Flatten", value=True, id="flatten"),
        }

        # Free-text wrapper settings with pre-filled defaults.
        self.wrapper_inputs = {
            "stack_frames": Input(value="4", id="stack_frames"),
            "stack_actions": Input(value="12", id="stack_actions"),
            "dilation": Input(value="1", id="dilation"),
        }

        # One checkbox per filter key, all enabled, with ids "fk_<key>".
        self.filter_key_checkboxes = [
            Checkbox(label=key, value=True, id=f"fk_{key}")
            for key in ["frame", "action", "stage", "timer"]
        ]

        self.net_arch_input = Input(value="[64, 64]", id="net_arch")

        # PPO hyper-parameters. Note that some widget ids are abbreviations of
        # their dict keys (e.g. "learning_rate_start" -> id "lr_start").
        self.ppo_inputs = {
            "gamma": Input(value="0.94", id="gamma"),
            "model_checkpoint": Input(value="0", id="model_checkpoint"),
            "learning_rate_start": Input(value="0.00025", id="lr_start"),
            "learning_rate_end": Input(value="0.0000025", id="lr_end"),
            "clip_range_start": Input(value="0.15", id="clip_start"),
            "clip_range_end": Input(value="0.025", id="clip_end"),
            "batch_size": Input(value="256", id="batch_size"),
            "n_epochs": Input(value="4", id="n_epochs"),
            "n_steps": Input(value="128", id="n_steps"),
            "autosave_freq": Input(value="256", id="autosave_freq"),
            "time_steps": Input(value="512", id="time_steps"),
        }

    @staticmethod
    def _labelled_grid(inputs: dict) -> Grid:
        """Build an ``input_grid`` holding one labelled cell per ``(key, widget)`` pair.

        Each label is the key with underscores replaced by spaces and
        title-cased, e.g. ``"step_ratio"`` -> ``"Step Ratio"``.
        """
        return Grid(
            *(
                Vertical(Label(key.replace("_", " ").title()), widget, classes="input_cell")
                for key, widget in inputs.items()
            ),
            classes="input_grid",
        )

    def compose(self):
        """Yield the header, the scrollable form, the review button and the footer."""
        log_with_caller("debug", "Composing TrainingView layout...")

        try:
            # Explicit UTF-8: without it read_text() uses the locale's
            # encoding, which differs by platform (e.g. cp1252 on Windows).
            header_text = HEADER_PATH.read_text(encoding="utf-8")
        except FileNotFoundError:
            log_with_caller("error", f"Missing header file at: {HEADER_PATH}")
            header_text = "[bold red]Missing Header File"
        except (OSError, UnicodeDecodeError) as exc:
            # Any other read/decode failure degrades to the same placeholder
            # instead of aborting compose().
            log_with_caller("error", f"Could not read header file at {HEADER_PATH}: {exc}")
            header_text = "[bold red]Missing Header File"

        yield Static(f"[{palette.ACCENT}]{header_text}[/{palette.ACCENT}]", classes="header")

        yield VerticalScroll(
            Vertical(
                # Framework & Algorithm
                Static("Framework & Algorithm", classes="section_label"),
                Grid(
                    Vertical(Label("Framework"), self.framework_dropdown, classes="input_cell"),
                    Vertical(Label("Algorithm"), self.algo_dropdown, classes="input_cell"),
                    classes="input_grid",
                ),

                # Model name / num_env
                Static("Model Configuration", classes="section_label"),
                Grid(
                    Vertical(Label("Model Name"), self.model_name_input, classes="input_cell"),
                    Vertical(Label("Num Envs"), self.num_env_input, classes="input_cell"),
                    classes="input_grid",
                ),

                # General Settings
                Static("General Settings", classes="section_label"),
                self._labelled_grid(self.settings_inputs),

                # Wrapper Checkboxes (checkboxes carry their own labels)
                Static("Wrapper Flags", classes="section_label"),
                Grid(
                    *[Vertical(cb, classes="input_cell") for cb in self.wrapper_checkboxes.values()],
                    classes="input_grid",
                ),

                # Wrapper Inputs: rendered directly under the flags, without a
                # section label of their own.
                self._labelled_grid(self.wrapper_inputs),

                # Filter Keys
                Static("Filter Keys", classes="section_label"),
                Grid(*self.filter_key_checkboxes, id="filter_grid"),

                # Policy kwargs
                Static("Policy Architecture", classes="section_label"),
                Grid(
                    Vertical(Label("Net Arch"), self.net_arch_input, classes="input_cell"),
                    classes="input_grid",
                ),

                # PPO Settings
                Static("PPO Settings", classes="section_label"),
                self._labelled_grid(self.ppo_inputs),

                id="form_layout",
                classes="form_column",
            ),
            id="config_scroll_container",
        )

        yield Button("🔍 Review Config", id="review_config_btn", classes="confirm_button")
        yield AgentMFooter(compact=True)

        log_with_caller("debug", "Finished composing TrainingView")
|