diff --git a/agentm/app.py b/agentm/app.py index 0b9df45..1e8f3e0 100644 --- a/agentm/app.py +++ b/agentm/app.py @@ -1,5 +1,5 @@ from textual.app import App -from agentm.views.home import HomeView +from agentm.views.game_select import HomeView from agentm.views.login import LoginView from agentm import DIAMBRA_CREDENTIALS_PATH from agentm.utils.logger import log_with_caller diff --git a/agentm/assets/headers/agent_select.txt b/agentm/assets/headers/agent_select.txt new file mode 100644 index 0000000..c57dff6 --- /dev/null +++ b/agentm/assets/headers/agent_select.txt @@ -0,0 +1,10 @@ + + + █████╗ ██████╗ ███████╗███╗ ██╗████████╗ ███████╗███████╗██╗ ███████╗ ██████╗████████╗ +██╔══██╗██╔════╝ ██╔════╝████╗ ██║╚══██╔══╝ ██╔════╝██╔════╝██║ ██╔════╝██╔════╝╚══██╔══╝ +███████║██║ ███╗█████╗ ██╔██╗ ██║ ██║ ███████╗█████╗ ██║ █████╗ ██║ ██║ +██╔══██║██║ ██║██╔══╝ ██║╚██╗██║ ██║ ╚════██║██╔══╝ ██║ ██╔══╝ ██║ ██║ +██║ ██║╚██████╔╝███████╗██║ ╚████║ ██║ ███████║███████╗███████╗███████╗╚██████╗ ██║ +╚═╝ ╚═╝ ╚═════╝ ╚══════╝╚═╝ ╚═══╝ ╚═╝ ╚══════╝╚══════╝╚══════╝╚══════╝ ╚═════╝ ╚═╝ + + diff --git a/agentm/assets/headers/create_agent.txt b/agentm/assets/headers/create_agent.txt new file mode 100644 index 0000000..87c6c42 --- /dev/null +++ b/agentm/assets/headers/create_agent.txt @@ -0,0 +1,10 @@ + + + ██████╗██████╗ ███████╗ █████╗ ████████╗███████╗ █████╗ ██████╗ ███████╗███╗ ██╗████████╗ +██╔════╝██╔══██╗██╔════╝██╔══██╗╚══██╔══╝██╔════╝ ██╔══██╗██╔════╝ ██╔════╝████╗ ██║╚══██╔══╝ +██║ ██████╔╝█████╗ ███████║ ██║ █████╗ ███████║██║ ███╗█████╗ ██╔██╗ ██║ ██║ +██║ ██╔══██╗██╔══╝ ██╔══██║ ██║ ██╔══╝ ██╔══██║██║ ██║██╔══╝ ██║╚██╗██║ ██║ +╚██████╗██║ ██║███████╗██║ ██║ ██║ ███████╗ ██║ ██║╚██████╔╝███████╗██║ ╚████║ ██║ + ╚═════╝╚═╝ ╚═╝╚══════╝╚═╝ ╚═╝ ╚═╝ ╚══════╝ ╚═╝ ╚═╝ ╚═════╝ ╚══════╝╚═╝ ╚═══╝ ╚═╝ + + diff --git a/agentm/assets/headers/select_model.txt b/agentm/assets/headers/select_model.txt new file mode 100644 index 0000000..947b136 --- /dev/null +++ b/agentm/assets/headers/select_model.txt @@ -0,0 +1,10 @@ + + +███████╗███████╗██╗ ███████╗ ██████╗████████╗ ███╗ ███╗ ██████╗ ██████╗ ███████╗██╗ +██╔════╝██╔════╝██║ ██╔════╝██╔════╝╚══██╔══╝ ████╗ ████║██╔═══██╗██╔══██╗██╔════╝██║ +███████╗█████╗ ██║ █████╗ ██║ ██║ ██╔████╔██║██║ ██║██║ ██║█████╗ ██║ +╚════██║██╔══╝ ██║ ██╔══╝ ██║ ██║ ██║╚██╔╝██║██║ ██║██║ ██║██╔══╝ ██║ +███████║███████╗███████╗███████╗╚██████╗ ██║ ██║ ╚═╝ ██║╚██████╔╝██████╔╝███████╗███████╗ +╚══════╝╚══════╝╚══════╝╚══════╝ ╚═════╝ ╚═╝ ╚═╝ ╚═╝ ╚═════╝ ╚═════╝ ╚══════╝╚══════╝ + + diff --git a/agentm/assets/headers/select_run.txt b/agentm/assets/headers/select_run.txt new file mode 100644 index 0000000..46e805a --- /dev/null +++ b/agentm/assets/headers/select_run.txt @@ -0,0 +1,10 @@ + + +███████╗███████╗██╗ ███████╗ ██████╗████████╗ ██████╗ ██╗ ██╗███╗ ██╗ +██╔════╝██╔════╝██║ ██╔════╝██╔════╝╚══██╔══╝ ██╔══██╗██║ ██║████╗ ██║ +███████╗█████╗ ██║ █████╗ ██║ ██║ ██████╔╝██║ ██║██╔██╗ ██║ +╚════██║██╔══╝ ██║ ██╔══╝ ██║ ██║ ██╔══██╗██║ ██║██║╚██╗██║ +███████║███████╗███████╗███████╗╚██████╗ ██║ ██║ ██║╚██████╔╝██║ ╚████║ +╚══════╝╚══════╝╚══════╝╚══════╝ ╚═════╝ ╚═╝ ╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═══╝ + + diff --git a/agentm/logic/db.py b/agentm/logic/db.py index 03ba7a2..d83fb50 100644 --- a/agentm/logic/db.py +++ b/agentm/logic/db.py @@ -5,16 +5,21 @@ CACHE_DB_PATH = Path("agentm/data/agentM.db") CACHE_DB_PATH.parent.mkdir(parents=True, exist_ok=True) def get_db_conn(): - return sqlite3.connect(CACHE_DB_PATH) + conn = sqlite3.connect(CACHE_DB_PATH) + conn.row_factory = sqlite3.Row + return conn def initialize_database(): with get_db_conn() as conn: + conn.execute("PRAGMA foreign_keys = ON;") + + # Game metadata table conn.execute(""" CREATE TABLE IF NOT EXISTS roms ( id INTEGER PRIMARY KEY AUTOINCREMENT, title TEXT NOT NULL, rom_file TEXT NOT NULL UNIQUE, - game_id TEXT NOT NULL, + game_id TEXT NOT NULL UNIQUE, sha256 TEXT, difficulty_min INTEGER, difficulty_max INTEGER, @@ -24,4 +29,77 @@ def initialize_database(): verified_at TEXT ); """) - conn.commit() \ No newline at end of file + + # Agent definitions linked to game_id + conn.execute(""" + CREATE TABLE IF NOT EXISTS agents ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + game_id TEXT NOT NULL, + agent_type TEXT DEFAULT 'PPO', + framework TEXT DEFAULT 'SB3', + is_imitation BOOLEAN DEFAULT 0, + dataset_path TEXT, + config_json TEXT NOT NULL, + notes TEXT, + created_at TEXT DEFAULT CURRENT_TIMESTAMP, + last_updated TEXT DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (game_id) REFERENCES roms(game_id) + ON DELETE CASCADE ON UPDATE CASCADE + ); + """) + + # Run definitions linked to agents + conn.execute(""" + CREATE TABLE IF NOT EXISTS runs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + agent_id INTEGER NOT NULL, + name TEXT NOT NULL, + config_yaml TEXT NOT NULL, + notes TEXT, + pending BOOLEAN NOT NULL DEFAULT 1, + created_at TEXT DEFAULT CURRENT_TIMESTAMP, + updated_at TEXT DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (agent_id) REFERENCES agents(id) + ON DELETE CASCADE ON UPDATE CASCADE + ); + """) + + # Model definitions (track training progress, lineage, resume state) + conn.execute(""" + CREATE TABLE IF NOT EXISTS models ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + + agent_id INTEGER NOT NULL, + parent_model_id INTEGER, + + name TEXT NOT NULL, + notes TEXT, + + total_steps_planned INTEGER NOT NULL, + total_steps_completed INTEGER NOT NULL DEFAULT 0, + average_reward REAL, + + current_learning_rate REAL, + current_clip_range REAL, + num_envs INTEGER DEFAULT 1, -- New field added here + + config_patch_yaml TEXT, + checkpoint_path TEXT, + is_frozen BOOLEAN NOT NULL DEFAULT 0, + + status TEXT NOT NULL DEFAULT 'pending' + CHECK(status IN ('pending', 'training', 'paused', 'completed', 'failed')), + + created_at TEXT DEFAULT CURRENT_TIMESTAMP, + updated_at TEXT DEFAULT CURRENT_TIMESTAMP, + + FOREIGN KEY (agent_id) REFERENCES agents(id) + ON DELETE CASCADE ON UPDATE CASCADE, + FOREIGN KEY (parent_model_id) REFERENCES models(id) + ON DELETE CASCADE ON UPDATE CASCADE + ); + """) + + + conn.commit() diff --git a/agentm/logic/db_functions.py b/agentm/logic/db_functions.py index fbcd433..0d1b1c1 100644 --- a/agentm/logic/db_functions.py +++ b/agentm/logic/db_functions.py @@ -118,3 +118,379 @@ def upsert_rom_record( json.dumps(keywords or []), datetime.utcnow().isoformat() )) + +def get_agents_for_game(game_id: str) -> list[dict]: + """ + Retrieve all agents associated with a given game_id. + + Args: + game_id: The unique game ID (e.g., 'sfiii3n'). + + Returns: + A list of agent metadata dictionaries. + """ + with get_db_conn() as conn: + cur = conn.execute(""" + SELECT id, name, game_id, created_at, last_updated, agent_type, framework, is_imitation + FROM agents + WHERE game_id = ? + ORDER BY created_at DESC + """, (game_id,)) + return [dict(row) for row in cur.fetchall()] + + +def insert_agent( + name: str, + game_id: str, + config_json: str, + agent_type: str = "PPO", + framework: str = "SB3", + is_imitation: bool = False, + dataset_path: str = "", + notes: str = "" +) -> None: + """ + Insert a new agent into the database. + + Args: + name: Agent name. + game_id: Game ID this agent is associated with. + config_json: JSON string of the training config. + agent_type: Algorithm type (e.g., PPO, BC). + framework: Framework name (e.g., SB3, Ray). + is_imitation: Whether this is imitation learning. + dataset_path: Optional path to dataset (for IL). + notes: Developer notes or comments. + """ + with get_db_conn() as conn: + conn.execute(""" + INSERT INTO agents ( + name, + game_id, + config_json, + agent_type, + framework, + is_imitation, + dataset_path, + notes, + created_at, + last_updated + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, ( + name, + game_id, + config_json, + agent_type, + framework, + int(is_imitation), + dataset_path, + notes, + datetime.utcnow().isoformat(), + datetime.utcnow().isoformat() + )) + +def delete_agent_by_id(agent_id: int) -> None: + """ + Permanently delete an agent from the database by ID. + + Args: + agent_id: The unique agent ID to delete. + """ + with get_db_conn() as conn: + conn.execute("DELETE FROM agents WHERE id = ?", (agent_id,)) + + +def get_agent_by_id(agent_id: int) -> dict | None: + """ + Fetch a single agent's metadata from the database by ID. + + Args: + agent_id: The unique agent ID. + + Returns: + A dictionary of agent metadata if found, otherwise None. + """ + with get_db_conn() as conn: + cur = conn.execute(""" + SELECT id, name, game_id, created_at, last_updated, agent_type, + framework, is_imitation, dataset_path, config_json, notes + FROM agents + WHERE id = ? + """, (agent_id,)) + row = cur.fetchone() + + if row: + return dict(row) + return None + +from datetime import datetime +from agentm.logic.db import get_db_conn + + +def insert_run( + agent_id: int, + name: str, + config_yaml: str, + notes: str = "" +) -> None: + """ + Insert a new run associated with an agent. + + Args: + agent_id: The agent this run belongs to. + name: The name of the run. + config_yaml: Serialized YAML config (frozen) for this run. + notes: Optional notes about the run. + """ + with get_db_conn() as conn: + conn.execute(""" + INSERT INTO runs ( + agent_id, + name, + config_yaml, + notes, + created_at, + updated_at + ) VALUES (?, ?, ?, ?, ?, ?) + """, ( + agent_id, + name, + config_yaml, + notes, + datetime.utcnow().isoformat(), + datetime.utcnow().isoformat() + )) + + +def get_runs_for_agent(agent_id: int) -> list[dict]: + """ + Get all runs associated with a specific agent. + + Args: + agent_id: The ID of the agent. + + Returns: + A list of run dictionaries. + """ + with get_db_conn() as conn: + cur = conn.execute(""" + SELECT * FROM runs + WHERE agent_id = ? + ORDER BY created_at DESC + """, (agent_id,)) + return [dict(row) for row in cur.fetchall()] + + +def get_run_by_id(run_id: int) -> dict | None: + """ + Get a specific run by its ID. + + Args: + run_id: The ID of the run. + + Returns: + A run dictionary or None if not found. + """ + with get_db_conn() as conn: + cur = conn.execute(""" + SELECT * FROM runs + WHERE id = ? + """, (run_id,)) + row = cur.fetchone() + return dict(row) if row else None + + +def delete_run_by_id(run_id: int) -> None: + """ + Delete a run permanently from the database. + + Args: + run_id: The ID of the run to delete. + """ + with get_db_conn() as conn: + conn.execute(""" + DELETE FROM runs + WHERE id = ? + """, (run_id,)) + +from datetime import datetime +from agentm.logic.db import get_db_conn + +def insert_model( + agent_id: int, + name: str, + total_steps_planned: int, + notes: str = "", + parent_model_id: int | None = None, + average_reward: float | None = None, + current_learning_rate: float | None = None, + current_clip_range: float | None = None, + config_patch_yaml: str = "", + checkpoint_path: str = "", + is_frozen: bool = False, + status: str = "pending" +) -> None: + """ + Insert a new model entry. + """ + with get_db_conn() as conn: + conn.execute(""" + INSERT INTO models ( + agent_id, parent_model_id, name, notes, + total_steps_planned, total_steps_completed, average_reward, + current_learning_rate, current_clip_range, + config_patch_yaml, checkpoint_path, is_frozen, + status, created_at, updated_at + ) VALUES (?, ?, ?, ?, ?, 0, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, ( + agent_id, parent_model_id, name, notes, + total_steps_planned, average_reward, + current_learning_rate, current_clip_range, + config_patch_yaml, checkpoint_path, int(is_frozen), + status, datetime.utcnow().isoformat(), datetime.utcnow().isoformat() + )) + +def get_models_for_agent(agent_id: int, include_children: bool = True) -> list[dict]: + """ + Fetch models for a given agent. + """ + with get_db_conn() as conn: + cur = conn.execute(""" + SELECT * FROM models + WHERE agent_id = ? + ORDER BY created_at DESC + """, (agent_id,)) + return [dict(row) for row in cur.fetchall()] + +def get_model_by_id(model_id: int) -> dict | None: + """ + Fetch a single model by ID. + """ + with get_db_conn() as conn: + cur = conn.execute("SELECT * FROM models WHERE id = ?", (model_id,)) + row = cur.fetchone() + return dict(row) if row else None + +def update_model_progress( + model_id: int, + total_steps_completed: int, + average_reward: float | None = None, + current_learning_rate: float | None = None, + current_clip_range: float | None = None, + checkpoint_path: str | None = None, + status: str | None = None +): + """ + Update step progress, learning rate, and optionally reward/checkpoint info. + """ + fields = ["total_steps_completed = ?"] + values = [total_steps_completed] + + if average_reward is not None: + fields.append("average_reward = ?") + values.append(average_reward) + + if current_learning_rate is not None: + fields.append("current_learning_rate = ?") + values.append(current_learning_rate) + + if current_clip_range is not None: + fields.append("current_clip_range = ?") + values.append(current_clip_range) + + if checkpoint_path is not None: + fields.append("checkpoint_path = ?") + values.append(checkpoint_path) + + if status is not None: + fields.append("status = ?") + values.append(status) + + fields.append("updated_at = ?") + values.append(datetime.utcnow().isoformat()) + + values.append(model_id) + + with get_db_conn() as conn: + conn.execute(f""" + UPDATE models SET {', '.join(fields)} + WHERE id = ? + """, tuple(values)) + +def delete_model_by_id(model_id: int) -> None: + """ + Delete a model record. + """ + with get_db_conn() as conn: + conn.execute("DELETE FROM models WHERE id = ?", (model_id,)) + +def get_models_for_run(run_id: int) -> list[dict]: + """ + Retrieve all models linked to a specific run. + Assumes a run_id is passed and models were inserted with a reference to it (if applicable). + Currently, models are linked by agent_id only, so we filter manually. + + Args: + run_id: The ID of the run. + + Returns: + A list of model dicts. + """ + with get_db_conn() as conn: + # First, get the agent_id for the run + cur = conn.execute("SELECT agent_id FROM runs WHERE id = ?", (run_id,)) + row = cur.fetchone() + if not row: + return [] + + agent_id = row["agent_id"] + # Then, get all models for that agent + cur = conn.execute(""" + SELECT * FROM models + WHERE agent_id = ? + ORDER BY created_at DESC + """, (agent_id,)) + return [dict(r) for r in cur.fetchall()] + + +def update_run( + run_id: int, + name: str | None = None, + config_yaml: str | None = None, + notes: str | None = None, + pending: bool | None = None +) -> None: + """ + Update a run by ID with any combination of fields. + """ + fields = [] + values = [] + + if name is not None: + fields.append("name = ?") + values.append(name) + + if config_yaml is not None: + fields.append("config_yaml = ?") + values.append(config_yaml) + + if notes is not None: + fields.append("notes = ?") + values.append(notes) + + if pending is not None: + fields.append("pending = ?") + values.append(int(pending)) + + if not fields: + return # Nothing to update + + fields.append("updated_at = ?") + values.append(datetime.utcnow().isoformat()) + values.append(run_id) + + with get_db_conn() as conn: + conn.execute( + f"UPDATE runs SET {', '.join(fields)} WHERE id = ?", + tuple(values) + ) diff --git a/agentm/theme/styles.base.tcss b/agentm/theme/styles.base.tcss index d32907c..4389a7c 100644 --- a/agentm/theme/styles.base.tcss +++ b/agentm/theme/styles.base.tcss @@ -12,19 +12,6 @@ Screen { border: none; } -/* === Headers === */ - -# .header { -# dock: top; -# height: 3; -# content-align: center middle; -# background: {{SURFACE_10}}; -# color: {{ACCENT}}; -# text-style: bold; -# padding: 1 2; -# border: solid {{BORDER}}; -# } - /* === Buttons === */ Button { @@ -57,6 +44,19 @@ Button:disabled { text-style: dim; } +/* === Danger Button Variant === */ +.danger_button { + background: {{SURFACE_10}}; + color: {{LIGHT}}; + border: solid {{ERROR}}; +} + +.danger_button:hover { + background: {{SURFACE_20}}; + color: {{ERROR}}; + border: solid {{ERROR}}; +} + /* === Grid Layout === */ .rom_grid { @@ -215,4 +215,111 @@ Button.game_button { align-horizontal: center; height: 100%; padding-left: 1; +} + +/* === Agent Select View === */ + +#agent_select_layout { + layout: vertical; + width: 100%; + height: auto; + padding: 2; +} + +#agent_scroll { + max-height: 35vh; + overflow-y: auto; + padding: 0 1; + border: solid {{BORDER}}; + scrollbar-gutter: stable; +} + +.agent_card { + background: {{SURFACE_10}}; + border-bottom: dashed {{BORDER}}; + padding: 1 2; + text-align: left; + color: {{FOREGROUND}}; +} + +.agent_card:last-child { + border-bottom: none; +} + +.agent_card:hover { + background: {{SURFACE_20}}; + color: {{ACCENT}}; +} + +#agent_scroll > .agent_card { + margin-bottom: 1; +} + +/* Message displayed when no agents exist (mounted manually in Python) */ +#agent_scroll > Static { + color: {{DISABLED}}; + text-style: italic; + padding: 1; +} + +#create_agent_btn { + width: 100%; + max-width: 30%; + margin-top: 1; + align: center middle; +} + +#create_agent_layout { + layout: vertical; + width: 100%; + height: auto; + padding: 2; +} + +#agent_name_input { + width: 100%; + max-width: 40%; + align: center middle; + margin: 1 0; +} + +#agent_info_panel { + layout: horizontal; + width: 100%; + height: auto; + max-height: 150vh; + padding: 0 1; + border: solid {{BORDER}}; + align-vertical: top; + content-align: center middle; +} + +.agent_info_box { + width: 70%; + height: auto; +} + +.agent_card { + max-height: 20; +} + +.agent_action_box { + width: 30%; + height: auto; + layout: vertical; + align-vertical: top; + align-horizontal: right; +} + +#create_button_container { + align-horizontal: center; + padding-top: 1; +} + +/* === Agent Home View === */ +#agent_home_menu Button { + max-width: 40%; + min-height: 3; + margin: 1 0; + align: center middle; } \ No newline at end of file diff --git a/agentm/theme/styles.tcss b/agentm/theme/styles.tcss index 020185e..d96b860 100644 --- a/agentm/theme/styles.tcss +++ b/agentm/theme/styles.tcss @@ -12,19 +12,6 @@ Screen { border: none; } -/* === Headers === */ - -# .header { -# dock: top; -# height: 3; -# content-align: center middle; -# background: #282828; -# color: #ed7d3a; -# text-style: bold; -# padding: 1 2; -# border: solid #3a9bed; -# } - /* === Buttons === */ Button { @@ -57,6 +44,19 @@ Button:disabled { text-style: dim; } +/* === Danger Button Variant === */ +.danger_button { + background: #282828; + color: #ffffff; + border: solid red; +} + +.danger_button:hover { + background: #3f3f3f; + color: red; + border: solid red; +} + /* === Grid Layout === */ .rom_grid { @@ -215,4 +215,111 @@ Button.game_button { align-horizontal: center; height: 100%; padding-left: 1; +} + +/* === Agent Select View === */ + +#agent_select_layout { + layout: vertical; + width: 100%; + height: auto; + padding: 2; +} + +#agent_scroll { + max-height: 35vh; + overflow-y: auto; + padding: 0 1; + border: solid #3a9bed; + scrollbar-gutter: stable; +} + +.agent_card { + background: #282828; + border-bottom: dashed #3a9bed; + padding: 1 2; + text-align: left; + color: #f0f0f0; +} + +.agent_card:last-child { + border-bottom: none; +} + +.agent_card:hover { + background: #3f3f3f; + color: #ed7d3a; +} + +#agent_scroll > .agent_card { + margin-bottom: 1; +} + +/* Message displayed when no agents exist (mounted manually in Python) */ +#agent_scroll > Static { + color: #999999; + text-style: italic; + padding: 1; +} + +#create_agent_btn { + width: 100%; + max-width: 30%; + margin-top: 1; + align: center middle; +} + +#create_agent_layout { + layout: vertical; + width: 100%; + height: auto; + padding: 2; +} + +#agent_name_input { + width: 100%; + max-width: 40%; + align: center middle; + margin: 1 0; +} + +#agent_info_panel { + layout: horizontal; + width: 100%; + height: auto; + max-height: 150vh; + padding: 0 1; + border: solid #3a9bed; + align-vertical: top; + content-align: center middle; +} + +.agent_info_box { + width: 70%; + height: auto; +} + +.agent_card { + max-height: 20; +} + +.agent_action_box { + width: 30%; + height: auto; + layout: vertical; + align-vertical: top; + align-horizontal: right; +} + +#create_button_container { + align-horizontal: center; + padding-top: 1; +} + +/* === Agent Home View === */ +#agent_home_menu Button { + max-width: 40%; + min-height: 3; + margin: 1 0; + align: center middle; } \ No newline at end of file diff --git a/agentm/views/agent_home.py b/agentm/views/agent_home.py new file mode 100644 index 0000000..3974a39 --- /dev/null +++ b/agentm/views/agent_home.py @@ -0,0 +1,67 @@ +from textual.screen import Screen +from textual.widgets import Button, Static +from textual.containers import Vertical +from agentm.utils.logger import log_with_caller +from agentm.components.footer import AgentMFooter +from agentm.theme.palette import get_theme +from agentm.logic.db_functions import get_models_for_run +from agentm.views.model_select import ModelSelectView +from pyfiglet import Figlet + +palette = get_theme() + +class AgentHomeView(Screen): + BINDINGS = [("escape", "app.pop_screen", "Back")] + + def __init__(self, agent: dict, run: dict): + super().__init__() + self.agent_metadata = agent + self.run_metadata = run + self.models_exist = False + + def compose(self): + f = Figlet(font="ansi_shadow") + ascii_header = f.renderText(self.agent_metadata["name"]) + + yield Static(f"[{palette.ACCENT}]{ascii_header}[/{palette.ACCENT}]", classes="header", expand=False) + yield Static( + f"[b]Agent ID:[/] {self.agent_metadata['id']} | " + f"[b]Run:[/] {self.run_metadata['name']}", + classes="subheader", + expand=False + ) + + # Buttons - only Eval and Submit may be disabled + self.train_btn = Button("🚀 Train", id="train_btn", classes="confirm_button") + self.eval_btn = Button("🎮 Eval", id="eval_btn", classes="confirm_button", disabled=True) + self.submit_btn = Button("📦 Submit", id="submit_btn", classes="confirm_button", disabled=True) + + yield Vertical( + self.train_btn, + self.eval_btn, + self.submit_btn, + id="agent_home_menu", + classes="centered_layout" + ) + + yield AgentMFooter(compact=True) + + async def on_mount(self): + log_with_caller("info", f"Entered AgentHomeView for agent '{self.agent_metadata['name']}'") + models = get_models_for_run(self.run_metadata["id"]) + self.models_exist = len(models) > 0 + self.eval_btn.disabled = not self.models_exist + self.submit_btn.disabled = not self.models_exist + log_with_caller("debug", f"Found {len(models)} model(s) for run '{self.run_metadata['name']}'") + + async def on_button_pressed(self, event: Button.Pressed) -> None: + action = event.button.id.replace("_btn", "") # 'train', 'eval', or 'submit' + log_with_caller("info", f"{action.title()} selected for run '{self.run_metadata['name']}'") + + # Always go to model select before action + await self.app.push_screen( + ModelSelectView( + agent_metadata=self.agent_metadata, + mode=action + ) + ) diff --git a/agentm/views/agent_select.py b/agentm/views/agent_select.py new file mode 100644 index 0000000..69e0323 --- /dev/null +++ b/agentm/views/agent_select.py @@ -0,0 +1,166 @@ +from textual.screen import Screen +from textual.containers import Vertical, Horizontal, VerticalScroll +from textual.widgets import Static, Button +from agentm.logic.db_functions import get_agents_for_game, delete_agent_by_id +from agentm.theme.palette import get_theme +from agentm.utils.logger import log_with_caller +from agentm.components.footer import AgentMFooter +from .create_agent import CreateAgentView +from .select_run import SelectRunView +from pathlib import Path + +palette = get_theme() + +class AgentCard(Static): + def __init__(self, agent_data: dict, parent_view): + super().__init__(classes="agent_card") + self.agent_data = agent_data + self.parent_view = parent_view + + def render(self) -> str: + return f""" +[bold {palette.ACCENT}]{self.agent_data['name']}[/] +[dim]ID:[/] {self.agent_data['id']} +[dim]Created:[/] {self.agent_data.get('created_at', '—')} +[dim]Game ID:[/] {self.agent_data['game_id']} + """.strip() + + async def on_click(self): + await self.parent_view.display_agent_info(self.agent_data) + + +class AgentSelectView(Screen): + BINDINGS = [ + ("escape", "app.pop_screen", "Back"), + ("n", "create_agent", "New Agent"), + ("r", "refresh_agents", "Refresh"), + ("up", "scroll_up", "Scroll Up"), + ("down", "scroll_down", "Scroll Down"), + ] + + def __init__(self, game_metadata: dict): + super().__init__() + self.game_metadata = game_metadata + self.selected_agent = None + self.delete_mode = False + + def compose(self): + ascii_path = Path(__file__).parent.parent / "assets" / "headers" / "agent_select.txt" + try: + header_text = ascii_path.read_text() + except FileNotFoundError: + header_text = "=== SELECT AGENT ===" + + self.header = Static(f"[{palette.ACCENT}]{header_text}[/{palette.ACCENT}]", classes="header") + self.subheader = Static(f"[b]{self.game_metadata['title']}[/b]", classes="subheader") + + self.agent_list = VerticalScroll(id="agent_scroll") + self.agent_info = Horizontal(id="agent_info_panel") + + self.confirm_button = Button("✅ Select Agent", id="confirm_agent_btn", classes="confirm_button", disabled=True) + self.delete_button = Button("🗑️ Delete Agent", id="delete_agent_btn", classes="danger_button", disabled=True) + self.create_button = Button("➕ Create Agent", id="create_agent_btn", classes="confirm_button") + + yield Vertical( + self.header, + self.subheader, + self.agent_list, + self.agent_info, + self.create_button, + AgentMFooter(compact=True), + id="agent_select_layout" + ) + + async def on_mount(self): + log_with_caller("debug", f"Mounted AgentSelectView for game_id: {self.game_metadata['game_id']}") + await self.refresh_agent_list() + + async def on_resume(self): + await self.refresh_agent_list() + + async def refresh_agent_list(self): + try: + for child in list(self.agent_list.children): + await child.remove() + + agents = get_agents_for_game(self.game_metadata["game_id"]) + log_with_caller("info", f"Refreshed: {len(agents)} agents found for {self.game_metadata['game_id']}") + + if agents: + for agent in agents: + await self.agent_list.mount(AgentCard(agent, self)) + else: + await self.agent_list.mount(Static("[dim]No agents found for this game.[/dim]")) + + self.selected_agent = None + self.delete_mode = False + await self.display_agent_info(None) + except Exception as e: + log_with_caller("error", f"Error rendering agent list: {e}") + + async def display_agent_info(self, agent: dict | None): + self.selected_agent = agent + self.delete_mode = False + await self.agent_info.remove_children() + + if not agent: + await self.agent_info.mount(Static("[dim]Select an agent to view details[/dim]")) + self.confirm_button.disabled = True + self.delete_button.disabled = True + return + + from rich.panel import Panel + from rich.table import Table + + table = Table.grid(padding=(0, 1)) + table.add_column("Key", style="bold underline") + table.add_column("Value", style=palette.ACCENT, overflow="fold") + + table.add_row("Name", agent["name"]) + table.add_row("Agent ID", str(agent["id"])) + table.add_row("Game ID", agent["game_id"]) + table.add_row("Created", agent.get("created_at", "—")) + table.add_row("Type", agent.get("agent_type", "—")) + table.add_row("Framework", agent.get("framework", "—")) + table.add_row("Imitation", str(agent.get("is_imitation", False))) + + info_panel = Static(Panel(table, title="Agent Info", border_style=palette.BORDER), classes="agent_info_box") + button_column = Vertical(self.confirm_button, self.delete_button, classes="agent_action_box") + + await self.agent_info.mount(info_panel, button_column) + self.confirm_button.disabled = False + self.delete_button.disabled = False + self.delete_button.label = "🗑️ Delete Agent" + + async def on_button_pressed(self, event: Button.Pressed) -> None: + if event.button.id == "create_agent_btn": + log_with_caller("info", "Create Agent button pressed.") + await self.app.push_screen(CreateAgentView(self.game_metadata)) + + elif event.button.id == "confirm_agent_btn" and self.selected_agent: + log_with_caller("info", f"Agent confirmed: {self.selected_agent['name']}") + await self.app.push_screen(SelectRunView(self.selected_agent)) + + elif event.button.id == "delete_agent_btn" and self.selected_agent: + if not self.delete_mode: + self.delete_mode = True + self.delete_button.label = "❌ Confirm Delete" + log_with_caller("warning", f"Pending deletion for agent: {self.selected_agent['name']}") + else: + log_with_caller("info", f"Deleting agent: {self.selected_agent['name']} ({self.selected_agent['id']})") + delete_agent_by_id(self.selected_agent["id"]) + await self.refresh_agent_list() + + async def action_create_agent(self): + await self.app.push_screen(CreateAgentView(self.game_metadata)) + + async def action_refresh_agents(self): + await self.refresh_agent_list() + + async def action_scroll_up(self): + scroll = self.query_one("#agent_scroll", VerticalScroll) + scroll.scroll_up() + + async def action_scroll_down(self): + scroll = self.query_one("#agent_scroll", VerticalScroll) + scroll.scroll_down() diff --git a/agentm/views/create_agent.py b/agentm/views/create_agent.py new file mode 100644 index 0000000..4969d06 --- /dev/null +++ b/agentm/views/create_agent.py @@ -0,0 +1,61 @@ +from textual.screen import Screen +from textual.containers import Vertical +from textual.widgets import Static, Button, Input +from agentm.components.footer import AgentMFooter +from agentm.theme.palette import get_theme +from agentm.utils.logger import log_with_caller + +from pathlib import Path +from agentm.logic.db_functions import insert_agent # ✅ Add this import + +palette = get_theme() + + +class CreateAgentView(Screen): + BINDINGS = [("escape", "app.pop_screen", "Back")] + + def __init__(self, game_metadata: dict): + super().__init__() + self.game_metadata = game_metadata + + def compose(self): + # Load ASCII header + ascii_path = Path(__file__).parent.parent / "assets" / "headers" / "create_agent.txt" + try: + header_text = ascii_path.read_text() + except FileNotFoundError: + header_text = "### Create Agent ###" + + self.header = Static(f"[{palette.ACCENT}]{header_text}[/{palette.ACCENT}]", classes="header") + + self.name_input = Input(placeholder="Enter agent name...", id="agent_name_input") + self.create_button = Button("🚀 Create Agent", id="create_button", classes="confirm_button") + + yield Vertical( + self.header, + self.name_input, + self.create_button, + AgentMFooter(compact=True), + id="create_agent_layout", + classes="centered_layout" + ) + + async def on_button_pressed(self, event: Button.Pressed) -> None: + if event.button.id == "create_button": + name = self.name_input.value.strip() + + if name: + log_with_caller("info", f"Creating agent '{name}' for game_id: {self.game_metadata['game_id']}") + + # ✅ Insert the agent into the DB + insert_agent( + name=name, + game_id=self.game_metadata["game_id"], + config_json="{}", # Replace with actual config later + notes="Created via CreateAgentView" + ) + + # ✅ Dismiss screen + await self.app.pop_screen() + else: + log_with_caller("warning", "Tried to create agent with empty name.") diff --git a/agentm/views/eval.py b/agentm/views/eval.py deleted file mode 100644 index e69de29..0000000 diff --git a/agentm/views/evaluation.py b/agentm/views/evaluation.py new file mode 100644 index 0000000..95a896f --- /dev/null +++ b/agentm/views/evaluation.py @@ -0,0 +1,9 @@ +from textual.screen import Screen +from textual.widgets import Static +from agentm.theme.palette import get_theme + +palette = get_theme() + +class EvaluationView(Screen): + def compose(self): + yield Static(f"[{palette.ACCENT}]Evaluation View Placeholder[/]", classes="centered_layout") diff --git a/agentm/views/home.py b/agentm/views/game_select.py similarity index 86% rename from agentm/views/home.py rename to agentm/views/game_select.py index c6a79a4..42b9aaf 100644 --- a/agentm/views/home.py +++ b/agentm/views/game_select.py @@ -15,6 +15,8 @@ from agentm.logic.roms import get_verified_roms, GAME_FILES from agentm.theme.palette import get_theme from agentm.components.footer import AgentMFooter +from .agent_select import AgentSelectView + palette = get_theme() @@ -72,7 +74,14 @@ class GameCardButton(Button): class HomeView(Screen): - BINDINGS = [("escape", "app.quit", "Quit")] + BINDINGS = [ + ("escape", "app.quit", "Quit"), + ("up", "scroll_up", "Scroll Up"), + ("down", "scroll_down", "Scroll Down"), + ("enter", "confirm_game", "Confirm Game"), + ("r", "refresh_roms", "Refresh ROMs"), + ] + def highlight_selected(self, selected_widget: GameCardButton): for card in self.rom_grid.children: @@ -195,4 +204,24 @@ class HomeView(Screen): async def on_button_pressed(self, event: Button.Pressed) -> None: if event.button.id == "confirm_button" and self.selected_game: - await self.app.push_screen("training", self.selected_game) + await self.app.push_screen(AgentSelectView(self.selected_game)) + + + async def action_scroll_up(self): + scroll = self.query_one("#rom_grid_scroll", VerticalScroll) + scroll.scroll_up() + + async def action_scroll_down(self): + scroll = self.query_one("#rom_grid_scroll", VerticalScroll) + scroll.scroll_down() + + async def action_confirm_game(self): + if self.selected_game: + log_with_caller("info", f"Keyboard: Confirming game {self.selected_game['title']}") + await self.app.push_screen(AgentSelectView(self.selected_game)) + else: + log_with_caller("warning", "Keyboard: Tried to confirm with no game selected.") + + async def action_refresh_roms(self): + log_with_caller("info", "Keyboard: Refreshing ROM list") + self.run_worker(self.run_verification, thread=True, exclusive=True, name="rom-verification") diff --git a/agentm/views/model_select.py b/agentm/views/model_select.py new file mode 100644 index 0000000..0e945e2 --- /dev/null +++ b/agentm/views/model_select.py @@ -0,0 +1,146 @@ +from typing import Literal +from textual.screen import Screen +from textual.containers import Vertical, Horizontal, VerticalScroll +from textual.widgets import Static, Button +from agentm.logic.db_functions import get_models_for_agent +from agentm.theme.palette import get_theme +from agentm.utils.logger import log_with_caller +from agentm.components.footer import AgentMFooter +from pathlib import Path + +# Import destination views (these are placeholders; replace as needed) +from agentm.views.training import TrainingView +from agentm.views.evaluation import EvaluationView +from agentm.views.submission import SubmissionView + +# Ensure the palette is loaded +palette = get_theme() + + +class ModelCard(Static): + def __init__(self, model_data: dict, parent_view): + super().__init__(classes="agent_card") + self.model_data = model_data + self.parent_view = parent_view + + def render(self) -> str: + return f""" +[bold {palette.ACCENT}]{self.model_data['name']}[/] +[dim]Steps:[/] {self.model_data['total_steps_completed']} / {self.model_data['total_steps_planned']} +[dim]Avg Reward:[/] {self.model_data.get('average_reward', '—')} +[dim]Status:[/] {self.model_data['status']} +[dim]Created:[/] {self.model_data.get('created_at', '—')} + """.strip() + + async def on_click(self): + await self.parent_view.display_model_info(self.model_data) + + +class ModelSelectView(Screen): + BINDINGS = [ + ("escape", "app.pop_screen", "Back"), + ("r", "refresh_models", "Refresh"), + ] + + def __init__(self, agent_metadata: dict, mode: Literal["train", "eval", "submit"]): + super().__init__() + self.agent_metadata = agent_metadata + self.mode = mode + self.selected_model = None + + def compose(self): + header_path = Path(__file__).parent.parent / "assets" / "headers" / "select_model.txt" + try: + header_text = header_path.read_text() + except FileNotFoundError: + header_text = "=== SELECT MODEL ===" + + self.header = Static(f"[{palette.ACCENT}]{header_text}[/{palette.ACCENT}]", classes="header") + self.subheader = Static(f"[b]{self.agent_metadata['name']}[/b]", classes="subheader") + + self.model_list = VerticalScroll(id="agent_scroll") + self.model_info = Horizontal(id="agent_info_panel") + + self.select_button = Button("✅ Select Model", id="select_model_btn", classes="confirm_button", disabled=True) + self.create_button = Button("➕ Create New Model", id="create_model_btn", classes="confirm_button") + + yield Vertical( + self.header, + self.subheader, + self.model_list, + self.model_info, + self.select_button, + self.create_button, + AgentMFooter(compact=True), + id="agent_select_layout" + ) + + async def on_mount(self): + log_with_caller("debug", f"Mounted ModelSelectView for agent_id={self.agent_metadata['id']} mode={self.mode}") + await self.refresh_model_list() + + async def refresh_model_list(self): + try: + for child in list(self.model_list.children): + await child.remove() + + models = get_models_for_agent(self.agent_metadata["id"]) + log_with_caller("info", f"Refreshed: {len(models)} models found for agent {self.agent_metadata['name']}") + + if models: + for model in models: + await self.model_list.mount(ModelCard(model, self)) + else: + await self.model_list.mount(Static("[dim]No models found for this agent.[/dim]")) + + self.selected_model = None + await self.display_model_info(None) + except Exception as e: + log_with_caller("error", f"Error rendering model list: {e}") + + async def display_model_info(self, model: dict | None): + self.selected_model = model + await self.model_info.remove_children() + + if not model: + await self.model_info.mount(Static("[dim]Select a model to view details[/dim]")) + self.select_button.disabled = True + return + + from rich.panel import Panel + from rich.table import Table + + table = Table.grid(padding=(0, 1)) + table.add_column("Key", style="bold underline") + table.add_column("Value", style=palette.ACCENT, overflow="fold") + + table.add_row("Name", model["name"]) + table.add_row("Status", model["status"]) + table.add_row("Steps", f"{model['total_steps_completed']} / {model['total_steps_planned']}") + table.add_row("Avg Reward", str(model.get("average_reward", "—"))) + table.add_row("Learning Rate", str(model.get("current_learning_rate", "—"))) + table.add_row("Clip Range", str(model.get("current_clip_range", "—"))) + table.add_row("Checkpoint", model.get("checkpoint_path", "—")) + table.add_row("Created", model.get("created_at", "—")) + + info_panel = Static(Panel(table, title="Model Info", border_style=palette.BORDER), classes="agent_info_box") + await self.model_info.mount(info_panel) + self.select_button.disabled = False + + async def on_button_pressed(self, event: Button.Pressed) -> None: + if event.button.id == "select_model_btn" and self.selected_model: + log_with_caller("info", f"Model confirmed: {self.selected_model['name']} for mode={self.mode}") + match self.mode: + case "train": + await self.app.push_screen(TrainingView(agent=self.agent_metadata, model=self.selected_model)) + case "eval": + await self.app.push_screen(EvaluationView(agent=self.agent_metadata, model=self.selected_model)) + case "submit": + await self.app.push_screen(SubmissionView(agent=self.agent_metadata, model=self.selected_model)) + + elif event.button.id == "create_model_btn": + log_with_caller("info", f"Creating new model for agent: {self.agent_metadata['name']}") + await self.app.push_screen(TrainingView(agent=self.agent_metadata, model=None)) # New model starts from training + + async def action_refresh_models(self): + await self.refresh_model_list() diff --git a/agentm/views/select_run.py b/agentm/views/select_run.py new file mode 100644 index 0000000..b85826b --- /dev/null +++ b/agentm/views/select_run.py @@ -0,0 +1,147 @@ +from textual.screen import Screen +from textual.containers import Vertical, Horizontal, VerticalScroll +from textual.widgets import Static, Button +from agentm.logic.db_functions import get_runs_for_agent, insert_run +from agentm.theme.palette import get_theme +from agentm.utils.logger import log_with_caller +from agentm.components.footer import AgentMFooter +from agentm.views.agent_home import AgentHomeView +from pathlib import Path +from datetime import datetime + +palette = get_theme() + + +class RunCard(Static): + def __init__(self, run_data: dict, parent_view): + super().__init__(classes="agent_card") + self.run_data = run_data + self.parent_view = parent_view + + def render(self) -> str: + return f""" +[bold {palette.ACCENT}]{self.run_data['name']}[/] +[dim]Created:[/] {self.run_data.get('created_at', '—')} + """.strip() + + async def on_click(self): + await self.parent_view.display_run_info(self.run_data) + + + +class SelectRunView(Screen): + BINDINGS = [ + ("escape", "app.pop_screen", "Back"), + ("r", "refresh_runs", "Refresh"), + ] + + def __init__(self, agent_metadata: dict): + super().__init__() + self.agent_metadata = agent_metadata + self.selected_run = None + + def compose(self): + header_path = Path(__file__).parent.parent / "assets" / "headers" / "select_run.txt" + try: + header_text = header_path.read_text() + except FileNotFoundError: + header_text = "=== SELECT RUN ===" + + self.header = Static(f"[{palette.ACCENT}]{header_text}[/{palette.ACCENT}]", classes="header") + self.subheader = Static(f"[b]{self.agent_metadata['name']}[/b]", classes="subheader") + + self.run_list = VerticalScroll(id="agent_scroll") + self.run_info = Horizontal(id="agent_info_panel") + + self.select_button = Button("✅ Select Run", id="select_run_btn", classes="confirm_button", disabled=True) + self.create_button = Button("➕ Create New Run", id="create_run_btn", classes="confirm_button") + + yield Vertical( + self.header, + self.subheader, + self.run_list, + self.run_info, + self.select_button, + self.create_button, + AgentMFooter(compact=True), + id="agent_select_layout" + ) + + async def on_mount(self): + log_with_caller("debug", f"Mounted SelectRunView for agent_id: {self.agent_metadata['id']}") + await self.refresh_run_list() + + async def refresh_run_list(self): + try: + for child in list(self.run_list.children): + await child.remove() + + runs = get_runs_for_agent(self.agent_metadata["id"]) + log_with_caller("info", f"Refreshed: {len(runs)} runs found for agent {self.agent_metadata['name']}") + + if runs: + for run in runs: + await self.run_list.mount(RunCard(run, self)) + else: + await self.run_list.mount(Static("[dim]No runs found for this agent.[/dim]")) + + self.selected_run = None + await self.display_run_info(None) + except Exception as e: + log_with_caller("error", f"Error rendering run list: {e}") + + async def display_run_info(self, run: dict | None): + self.selected_run = run + await self.run_info.remove_children() + + if not run: + await self.run_info.mount(Static("[dim]Select a run to view details[/dim]")) + self.select_button.disabled = True + return + + from rich.panel import Panel + from rich.table import Table + + table = Table.grid(padding=(0, 1)) + table.add_column("Key", style="bold underline") + table.add_column("Value", style=palette.ACCENT, overflow="fold") + + table.add_row("Name", run["name"]) + table.add_row("Created", run.get("created_at", "—")) + table.add_row("Notes", run.get("notes", "—")) + + info_panel = Static(Panel(table, title="Run Info", border_style=palette.BORDER), classes="agent_info_box") + await self.run_info.mount(info_panel) + self.select_button.disabled = False + + async def on_button_pressed(self, event: Button.Pressed) -> None: + if event.button.id == "select_run_btn" and self.selected_run: + log_with_caller("info", f"Run confirmed: {self.selected_run['name']}") + await self.app.push_screen(AgentHomeView(agent=self.agent_metadata, run=self.selected_run)) + + elif event.button.id == "create_run_btn": + # Auto-generate name and basic config + run_name = f"{self.agent_metadata['name']} - {datetime.utcnow().strftime('%Y%m%d-%H%M%S')}" + config_yaml = "# Auto-generated config for new run\n" + notes = f"Initial run for agent {self.agent_metadata['name']}" + + # Insert into DB + insert_run( + agent_id=self.agent_metadata["id"], + name=run_name, + config_yaml=config_yaml, + notes=notes + ) + + # Fetch the new run and push to AgentHomeView + runs = get_runs_for_agent(self.agent_metadata["id"]) + new_run = runs[0] if runs else None + + if new_run: + log_with_caller("info", f"Created and selected new run: {new_run['name']}") + await self.app.push_screen(AgentHomeView(agent=self.agent_metadata, run=new_run)) + else: + log_with_caller("error", "Failed to retrieve newly created run") + + async def action_refresh_runs(self): + await self.refresh_run_list() diff --git a/agentm/views/submission.py b/agentm/views/submission.py new file mode 100644 index 0000000..279280c --- /dev/null +++ b/agentm/views/submission.py @@ -0,0 +1,9 @@ +from textual.screen import Screen +from textual.widgets import Static +from agentm.theme.palette import get_theme + +palette = get_theme() + +class SubmissionView(Screen): + def compose(self): + yield Static(f"[{palette.ACCENT}]Submission View Placeholder[/]", classes="centered_layout") diff --git a/agentm/views/training.py b/agentm/views/training.py index e69de29..0fdc834 100644 --- a/agentm/views/training.py +++ b/agentm/views/training.py @@ -0,0 +1,106 @@ +from textual.screen import Screen +from textual.widgets import Static, Input, Button +from textual.containers import Vertical, Horizontal +from agentm.theme.palette import get_theme +from agentm.utils.logger import log_with_caller +from agentm.logic.db_functions import insert_model, update_run_pending +from datetime import datetime + +palette = get_theme() + +class TrainingView(Screen): + BINDINGS = [ + ("escape", "app.pop_screen", "Back") + ] + + def __init__(self, agent: dict, model: dict | None, run: dict): + super().__init__() + self.agent_metadata = agent + self.model_metadata = model + self.run_metadata = run + self.is_pending_run = run.get("pending", True) + + def compose(self): + log_with_caller("info", f"Opening TrainingView for agent='{self.agent_metadata['name']}', run='{self.run_metadata['name']}', pending={self.is_pending_run}") + + yield Static(f"[{palette.ACCENT} bold]Training: {self.agent_metadata['name']}[/]", classes="header") + + if self.is_pending_run: + yield self.render_pending_setup() + else: + yield self.render_resume_options() + + def render_pending_setup(self): + yield Static("[b]Initial Model Setup[/b]", classes="subheader") + + self.name_input = Input(placeholder="Model Name", id="model_name") + self.steps_input = Input(placeholder="Total Training Steps", id="total_steps") + self.lr_input = Input(placeholder="Initial Learning Rate", id="learning_rate") + self.clip_input = Input(placeholder="Initial Clip Range", id="clip_range") + self.notes_input = Input(placeholder="Notes", id="notes") + + self.confirm_button = Button("✅ Create & Start Training", id="confirm_model_btn", classes="confirm_button") + + return Vertical( + self.name_input, + self.steps_input, + self.lr_input, + self.clip_input, + self.notes_input, + self.confirm_button, + classes="centered_layout" + ) + + def render_resume_options(self): + model = self.model_metadata + from rich.table import Table + from rich.panel import Panel + + table = Table.grid(padding=(0, 1)) + table.add_column("Key", style="bold underline") + table.add_column("Value", style=palette.ACCENT, overflow="fold") + + table.add_row("Name", model["name"]) + table.add_row("Steps", f"{model['total_steps_completed']} / {model['total_steps_planned']}") + table.add_row("Reward", str(model.get("average_reward", "—"))) + table.add_row("Learning Rate", str(model.get("current_learning_rate", "—"))) + table.add_row("Clip Range", str(model.get("current_clip_range", "—"))) + table.add_row("Created", model.get("created_at", "—")) + + return Vertical( + Static(Panel(table, title="Model Info", border_style=palette.BORDER), classes="agent_info_box"), + Button("⏯️ Resume Training", id="resume_training_btn", classes="confirm_button"), + classes="centered_layout" + ) + + async def on_button_pressed(self, event: Button.Pressed) -> None: + if event.button.id == "confirm_model_btn": + try: + name = self.name_input.value.strip() + total_steps = int(self.steps_input.value) + learning_rate = float(self.lr_input.value) + clip_range = float(self.clip_input.value) + notes = self.notes_input.value.strip() + + log_with_caller("info", f"Creating new model: {name} for agent_id={self.agent_metadata['id']}") + + insert_model( + agent_id=self.agent_metadata["id"], + name=name, + total_steps_planned=total_steps, + current_learning_rate=learning_rate, + current_clip_range=clip_range, + notes=notes + ) + + update_run_pending(self.run_metadata["id"], False) + log_with_caller("info", f"Model '{name}' created and run marked as not pending") + + await self.app.pop_screen() + + except Exception as e: + log_with_caller("error", f"Failed to create model: {e}") + + elif event.button.id == "resume_training_btn": + log_with_caller("info", f"Resuming training for model: {self.model_metadata['name']}") + await self.app.pop_screen() diff --git a/requirements.txt b/requirements.txt index bbde3e2..5e6f7e4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,5 @@ tensorboard requests sqlite3 rich-pixels -pillow \ No newline at end of file +pillow +pyfiglet \ No newline at end of file