Refactor agent management views and styles
- Removed commented-out header styles from styles.base.tcss and styles.tcss. - Added new styles for danger buttons and agent selection views in styles.base.tcss and styles.tcss. - Implemented AgentHomeView to manage agent actions and display metadata. - Created AgentSelectView for selecting agents with a new layout and functionality. - Added CreateAgentView for creating new agents with input validation. - Removed obsolete eval.py and replaced it with evaluation.py. - Developed GameSelectView for selecting games with a dynamic loading interface. - Introduced ModelSelectView for selecting models associated with agents. - Created SelectRunView for managing runs associated with agents. - Added SubmissionView and TrainingView for handling model training and submission processes. - Updated requirements.txt to include pyfiglet for ASCII art rendering.
This commit is contained in:
parent
d62820dd80
commit
4500bfd388
@ -1,5 +1,5 @@
|
|||||||
from textual.app import App
|
from textual.app import App
|
||||||
from agentm.views.home import HomeView
|
from agentm.views.game_select import HomeView
|
||||||
from agentm.views.login import LoginView
|
from agentm.views.login import LoginView
|
||||||
from agentm import DIAMBRA_CREDENTIALS_PATH
|
from agentm import DIAMBRA_CREDENTIALS_PATH
|
||||||
from agentm.utils.logger import log_with_caller
|
from agentm.utils.logger import log_with_caller
|
||||||
|
|||||||
10
agentm/assets/headers/agent_select.txt
Normal file
10
agentm/assets/headers/agent_select.txt
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
|
||||||
|
|
||||||
|
█████╗ ██████╗ ███████╗███╗ ██╗████████╗ ███████╗███████╗██╗ ███████╗ ██████╗████████╗
|
||||||
|
██╔══██╗██╔════╝ ██╔════╝████╗ ██║╚══██╔══╝ ██╔════╝██╔════╝██║ ██╔════╝██╔════╝╚══██╔══╝
|
||||||
|
███████║██║ ███╗█████╗ ██╔██╗ ██║ ██║ ███████╗█████╗ ██║ █████╗ ██║ ██║
|
||||||
|
██╔══██║██║ ██║██╔══╝ ██║╚██╗██║ ██║ ╚════██║██╔══╝ ██║ ██╔══╝ ██║ ██║
|
||||||
|
██║ ██║╚██████╔╝███████╗██║ ╚████║ ██║ ███████║███████╗███████╗███████╗╚██████╗ ██║
|
||||||
|
╚═╝ ╚═╝ ╚═════╝ ╚══════╝╚═╝ ╚═══╝ ╚═╝ ╚══════╝╚══════╝╚══════╝╚══════╝ ╚═════╝ ╚═╝
|
||||||
|
|
||||||
|
|
||||||
10
agentm/assets/headers/create_agent.txt
Normal file
10
agentm/assets/headers/create_agent.txt
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
|
||||||
|
|
||||||
|
██████╗██████╗ ███████╗ █████╗ ████████╗███████╗ █████╗ ██████╗ ███████╗███╗ ██╗████████╗
|
||||||
|
██╔════╝██╔══██╗██╔════╝██╔══██╗╚══██╔══╝██╔════╝ ██╔══██╗██╔════╝ ██╔════╝████╗ ██║╚══██╔══╝
|
||||||
|
██║ ██████╔╝█████╗ ███████║ ██║ █████╗ ███████║██║ ███╗█████╗ ██╔██╗ ██║ ██║
|
||||||
|
██║ ██╔══██╗██╔══╝ ██╔══██║ ██║ ██╔══╝ ██╔══██║██║ ██║██╔══╝ ██║╚██╗██║ ██║
|
||||||
|
╚██████╗██║ ██║███████╗██║ ██║ ██║ ███████╗ ██║ ██║╚██████╔╝███████╗██║ ╚████║ ██║
|
||||||
|
╚═════╝╚═╝ ╚═╝╚══════╝╚═╝ ╚═╝ ╚═╝ ╚══════╝ ╚═╝ ╚═╝ ╚═════╝ ╚══════╝╚═╝ ╚═══╝ ╚═╝
|
||||||
|
|
||||||
|
|
||||||
10
agentm/assets/headers/select_model.txt
Normal file
10
agentm/assets/headers/select_model.txt
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
|
||||||
|
|
||||||
|
███████╗███████╗██╗ ███████╗ ██████╗████████╗ ███╗ ███╗ ██████╗ ██████╗ ███████╗██╗
|
||||||
|
██╔════╝██╔════╝██║ ██╔════╝██╔════╝╚══██╔══╝ ████╗ ████║██╔═══██╗██╔══██╗██╔════╝██║
|
||||||
|
███████╗█████╗ ██║ █████╗ ██║ ██║ ██╔████╔██║██║ ██║██║ ██║█████╗ ██║
|
||||||
|
╚════██║██╔══╝ ██║ ██╔══╝ ██║ ██║ ██║╚██╔╝██║██║ ██║██║ ██║██╔══╝ ██║
|
||||||
|
███████║███████╗███████╗███████╗╚██████╗ ██║ ██║ ╚═╝ ██║╚██████╔╝██████╔╝███████╗███████╗
|
||||||
|
╚══════╝╚══════╝╚══════╝╚══════╝ ╚═════╝ ╚═╝ ╚═╝ ╚═╝ ╚═════╝ ╚═════╝ ╚══════╝╚══════╝
|
||||||
|
|
||||||
|
|
||||||
10
agentm/assets/headers/select_run.txt
Normal file
10
agentm/assets/headers/select_run.txt
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
|
||||||
|
|
||||||
|
███████╗███████╗██╗ ███████╗ ██████╗████████╗ ██████╗ ██╗ ██╗███╗ ██╗
|
||||||
|
██╔════╝██╔════╝██║ ██╔════╝██╔════╝╚══██╔══╝ ██╔══██╗██║ ██║████╗ ██║
|
||||||
|
███████╗█████╗ ██║ █████╗ ██║ ██║ ██████╔╝██║ ██║██╔██╗ ██║
|
||||||
|
╚════██║██╔══╝ ██║ ██╔══╝ ██║ ██║ ██╔══██╗██║ ██║██║╚██╗██║
|
||||||
|
███████║███████╗███████╗███████╗╚██████╗ ██║ ██║ ██║╚██████╔╝██║ ╚████║
|
||||||
|
╚══════╝╚══════╝╚══════╝╚══════╝ ╚═════╝ ╚═╝ ╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═══╝
|
||||||
|
|
||||||
|
|
||||||
@ -5,16 +5,21 @@ CACHE_DB_PATH = Path("agentm/data/agentM.db")
|
|||||||
CACHE_DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
CACHE_DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
def get_db_conn():
|
def get_db_conn():
|
||||||
return sqlite3.connect(CACHE_DB_PATH)
|
conn = sqlite3.connect(CACHE_DB_PATH)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
return conn
|
||||||
|
|
||||||
def initialize_database():
|
def initialize_database():
|
||||||
with get_db_conn() as conn:
|
with get_db_conn() as conn:
|
||||||
|
conn.execute("PRAGMA foreign_keys = ON;")
|
||||||
|
|
||||||
|
# Game metadata table
|
||||||
conn.execute("""
|
conn.execute("""
|
||||||
CREATE TABLE IF NOT EXISTS roms (
|
CREATE TABLE IF NOT EXISTS roms (
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
title TEXT NOT NULL,
|
title TEXT NOT NULL,
|
||||||
rom_file TEXT NOT NULL UNIQUE,
|
rom_file TEXT NOT NULL UNIQUE,
|
||||||
game_id TEXT NOT NULL,
|
game_id TEXT NOT NULL UNIQUE,
|
||||||
sha256 TEXT,
|
sha256 TEXT,
|
||||||
difficulty_min INTEGER,
|
difficulty_min INTEGER,
|
||||||
difficulty_max INTEGER,
|
difficulty_max INTEGER,
|
||||||
@ -24,4 +29,77 @@ def initialize_database():
|
|||||||
verified_at TEXT
|
verified_at TEXT
|
||||||
);
|
);
|
||||||
""")
|
""")
|
||||||
conn.commit()
|
|
||||||
|
# Agent definitions linked to game_id
|
||||||
|
conn.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS agents (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
game_id TEXT NOT NULL,
|
||||||
|
agent_type TEXT DEFAULT 'PPO',
|
||||||
|
framework TEXT DEFAULT 'SB3',
|
||||||
|
is_imitation BOOLEAN DEFAULT 0,
|
||||||
|
dataset_path TEXT,
|
||||||
|
config_json TEXT NOT NULL,
|
||||||
|
notes TEXT,
|
||||||
|
created_at TEXT DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
last_updated TEXT DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
FOREIGN KEY (game_id) REFERENCES roms(game_id)
|
||||||
|
ON DELETE CASCADE ON UPDATE CASCADE
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Run definitions linked to agents
|
||||||
|
conn.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS runs (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
agent_id INTEGER NOT NULL,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
config_yaml TEXT NOT NULL,
|
||||||
|
notes TEXT,
|
||||||
|
pending BOOLEAN NOT NULL DEFAULT 1,
|
||||||
|
created_at TEXT DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
updated_at TEXT DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
FOREIGN KEY (agent_id) REFERENCES agents(id)
|
||||||
|
ON DELETE CASCADE ON UPDATE CASCADE
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Model definitions (track training progress, lineage, resume state)
|
||||||
|
conn.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS models (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
|
||||||
|
agent_id INTEGER NOT NULL,
|
||||||
|
parent_model_id INTEGER,
|
||||||
|
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
notes TEXT,
|
||||||
|
|
||||||
|
total_steps_planned INTEGER NOT NULL,
|
||||||
|
total_steps_completed INTEGER NOT NULL DEFAULT 0,
|
||||||
|
average_reward REAL,
|
||||||
|
|
||||||
|
current_learning_rate REAL,
|
||||||
|
current_clip_range REAL,
|
||||||
|
num_envs INTEGER DEFAULT 1, -- New field added here
|
||||||
|
|
||||||
|
config_patch_yaml TEXT,
|
||||||
|
checkpoint_path TEXT,
|
||||||
|
is_frozen BOOLEAN NOT NULL DEFAULT 0,
|
||||||
|
|
||||||
|
status TEXT NOT NULL DEFAULT 'pending'
|
||||||
|
CHECK(status IN ('pending', 'training', 'paused', 'completed', 'failed')),
|
||||||
|
|
||||||
|
created_at TEXT DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
updated_at TEXT DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
|
||||||
|
FOREIGN KEY (agent_id) REFERENCES agents(id)
|
||||||
|
ON DELETE CASCADE ON UPDATE CASCADE,
|
||||||
|
FOREIGN KEY (parent_model_id) REFERENCES models(id)
|
||||||
|
ON DELETE CASCADE ON UPDATE CASCADE
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
|
||||||
|
|
||||||
|
conn.commit()
|
||||||
|
|||||||
@ -118,3 +118,379 @@ def upsert_rom_record(
|
|||||||
json.dumps(keywords or []),
|
json.dumps(keywords or []),
|
||||||
datetime.utcnow().isoformat()
|
datetime.utcnow().isoformat()
|
||||||
))
|
))
|
||||||
|
|
||||||
|
def get_agents_for_game(game_id: str) -> list[dict]:
|
||||||
|
"""
|
||||||
|
Retrieve all agents associated with a given game_id.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
game_id: The unique game ID (e.g., 'sfiii3n').
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of agent metadata dictionaries.
|
||||||
|
"""
|
||||||
|
with get_db_conn() as conn:
|
||||||
|
cur = conn.execute("""
|
||||||
|
SELECT id, name, game_id, created_at, last_updated, agent_type, framework, is_imitation
|
||||||
|
FROM agents
|
||||||
|
WHERE game_id = ?
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
""", (game_id,))
|
||||||
|
return [dict(row) for row in cur.fetchall()]
|
||||||
|
|
||||||
|
|
||||||
|
def insert_agent(
|
||||||
|
name: str,
|
||||||
|
game_id: str,
|
||||||
|
config_json: str,
|
||||||
|
agent_type: str = "PPO",
|
||||||
|
framework: str = "SB3",
|
||||||
|
is_imitation: bool = False,
|
||||||
|
dataset_path: str = "",
|
||||||
|
notes: str = ""
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Insert a new agent into the database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Agent name.
|
||||||
|
game_id: Game ID this agent is associated with.
|
||||||
|
config_json: JSON string of the training config.
|
||||||
|
agent_type: Algorithm type (e.g., PPO, BC).
|
||||||
|
framework: Framework name (e.g., SB3, Ray).
|
||||||
|
is_imitation: Whether this is imitation learning.
|
||||||
|
dataset_path: Optional path to dataset (for IL).
|
||||||
|
notes: Developer notes or comments.
|
||||||
|
"""
|
||||||
|
with get_db_conn() as conn:
|
||||||
|
conn.execute("""
|
||||||
|
INSERT INTO agents (
|
||||||
|
name,
|
||||||
|
game_id,
|
||||||
|
config_json,
|
||||||
|
agent_type,
|
||||||
|
framework,
|
||||||
|
is_imitation,
|
||||||
|
dataset_path,
|
||||||
|
notes,
|
||||||
|
created_at,
|
||||||
|
last_updated
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""", (
|
||||||
|
name,
|
||||||
|
game_id,
|
||||||
|
config_json,
|
||||||
|
agent_type,
|
||||||
|
framework,
|
||||||
|
int(is_imitation),
|
||||||
|
dataset_path,
|
||||||
|
notes,
|
||||||
|
datetime.utcnow().isoformat(),
|
||||||
|
datetime.utcnow().isoformat()
|
||||||
|
))
|
||||||
|
|
||||||
|
def delete_agent_by_id(agent_id: int) -> None:
|
||||||
|
"""
|
||||||
|
Permanently delete an agent from the database by ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: The unique agent ID to delete.
|
||||||
|
"""
|
||||||
|
with get_db_conn() as conn:
|
||||||
|
conn.execute("DELETE FROM agents WHERE id = ?", (agent_id,))
|
||||||
|
|
||||||
|
|
||||||
|
def get_agent_by_id(agent_id: int) -> dict | None:
|
||||||
|
"""
|
||||||
|
Fetch a single agent's metadata from the database by ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: The unique agent ID.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary of agent metadata if found, otherwise None.
|
||||||
|
"""
|
||||||
|
with get_db_conn() as conn:
|
||||||
|
cur = conn.execute("""
|
||||||
|
SELECT id, name, game_id, created_at, last_updated, agent_type,
|
||||||
|
framework, is_imitation, dataset_path, config_json, notes
|
||||||
|
FROM agents
|
||||||
|
WHERE id = ?
|
||||||
|
""", (agent_id,))
|
||||||
|
row = cur.fetchone()
|
||||||
|
|
||||||
|
if row:
|
||||||
|
return dict(row)
|
||||||
|
return None
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from agentm.logic.db import get_db_conn
|
||||||
|
|
||||||
|
|
||||||
|
def insert_run(
|
||||||
|
agent_id: int,
|
||||||
|
name: str,
|
||||||
|
config_yaml: str,
|
||||||
|
notes: str = ""
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Insert a new run associated with an agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: The agent this run belongs to.
|
||||||
|
name: The name of the run.
|
||||||
|
config_yaml: Serialized YAML config (frozen) for this run.
|
||||||
|
notes: Optional notes about the run.
|
||||||
|
"""
|
||||||
|
with get_db_conn() as conn:
|
||||||
|
conn.execute("""
|
||||||
|
INSERT INTO runs (
|
||||||
|
agent_id,
|
||||||
|
name,
|
||||||
|
config_yaml,
|
||||||
|
notes,
|
||||||
|
created_at,
|
||||||
|
updated_at
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?)
|
||||||
|
""", (
|
||||||
|
agent_id,
|
||||||
|
name,
|
||||||
|
config_yaml,
|
||||||
|
notes,
|
||||||
|
datetime.utcnow().isoformat(),
|
||||||
|
datetime.utcnow().isoformat()
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
|
def get_runs_for_agent(agent_id: int) -> list[dict]:
|
||||||
|
"""
|
||||||
|
Get all runs associated with a specific agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: The ID of the agent.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of run dictionaries.
|
||||||
|
"""
|
||||||
|
with get_db_conn() as conn:
|
||||||
|
cur = conn.execute("""
|
||||||
|
SELECT * FROM runs
|
||||||
|
WHERE agent_id = ?
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
""", (agent_id,))
|
||||||
|
return [dict(row) for row in cur.fetchall()]
|
||||||
|
|
||||||
|
|
||||||
|
def get_run_by_id(run_id: int) -> dict | None:
|
||||||
|
"""
|
||||||
|
Get a specific run by its ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_id: The ID of the run.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A run dictionary or None if not found.
|
||||||
|
"""
|
||||||
|
with get_db_conn() as conn:
|
||||||
|
cur = conn.execute("""
|
||||||
|
SELECT * FROM runs
|
||||||
|
WHERE id = ?
|
||||||
|
""", (run_id,))
|
||||||
|
row = cur.fetchone()
|
||||||
|
return dict(row) if row else None
|
||||||
|
|
||||||
|
|
||||||
|
def delete_run_by_id(run_id: int) -> None:
|
||||||
|
"""
|
||||||
|
Delete a run permanently from the database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_id: The ID of the run to delete.
|
||||||
|
"""
|
||||||
|
with get_db_conn() as conn:
|
||||||
|
conn.execute("""
|
||||||
|
DELETE FROM runs
|
||||||
|
WHERE id = ?
|
||||||
|
""", (run_id,))
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from agentm.logic.db import get_db_conn
|
||||||
|
|
||||||
|
def insert_model(
|
||||||
|
agent_id: int,
|
||||||
|
name: str,
|
||||||
|
total_steps_planned: int,
|
||||||
|
notes: str = "",
|
||||||
|
parent_model_id: int | None = None,
|
||||||
|
average_reward: float | None = None,
|
||||||
|
current_learning_rate: float | None = None,
|
||||||
|
current_clip_range: float | None = None,
|
||||||
|
config_patch_yaml: str = "",
|
||||||
|
checkpoint_path: str = "",
|
||||||
|
is_frozen: bool = False,
|
||||||
|
status: str = "pending"
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Insert a new model entry.
|
||||||
|
"""
|
||||||
|
with get_db_conn() as conn:
|
||||||
|
conn.execute("""
|
||||||
|
INSERT INTO models (
|
||||||
|
agent_id, parent_model_id, name, notes,
|
||||||
|
total_steps_planned, total_steps_completed, average_reward,
|
||||||
|
current_learning_rate, current_clip_range,
|
||||||
|
config_patch_yaml, checkpoint_path, is_frozen,
|
||||||
|
status, created_at, updated_at
|
||||||
|
) VALUES (?, ?, ?, ?, ?, 0, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""", (
|
||||||
|
agent_id, parent_model_id, name, notes,
|
||||||
|
total_steps_planned, average_reward,
|
||||||
|
current_learning_rate, current_clip_range,
|
||||||
|
config_patch_yaml, checkpoint_path, int(is_frozen),
|
||||||
|
status, datetime.utcnow().isoformat(), datetime.utcnow().isoformat()
|
||||||
|
))
|
||||||
|
|
||||||
|
def get_models_for_agent(agent_id: int, include_children: bool = True) -> list[dict]:
|
||||||
|
"""
|
||||||
|
Fetch models for a given agent.
|
||||||
|
"""
|
||||||
|
with get_db_conn() as conn:
|
||||||
|
cur = conn.execute("""
|
||||||
|
SELECT * FROM models
|
||||||
|
WHERE agent_id = ?
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
""", (agent_id,))
|
||||||
|
return [dict(row) for row in cur.fetchall()]
|
||||||
|
|
||||||
|
def get_model_by_id(model_id: int) -> dict | None:
|
||||||
|
"""
|
||||||
|
Fetch a single model by ID.
|
||||||
|
"""
|
||||||
|
with get_db_conn() as conn:
|
||||||
|
cur = conn.execute("SELECT * FROM models WHERE id = ?", (model_id,))
|
||||||
|
row = cur.fetchone()
|
||||||
|
return dict(row) if row else None
|
||||||
|
|
||||||
|
def update_model_progress(
|
||||||
|
model_id: int,
|
||||||
|
total_steps_completed: int,
|
||||||
|
average_reward: float | None = None,
|
||||||
|
current_learning_rate: float | None = None,
|
||||||
|
current_clip_range: float | None = None,
|
||||||
|
checkpoint_path: str | None = None,
|
||||||
|
status: str | None = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Update step progress, learning rate, and optionally reward/checkpoint info.
|
||||||
|
"""
|
||||||
|
fields = ["total_steps_completed = ?"]
|
||||||
|
values = [total_steps_completed]
|
||||||
|
|
||||||
|
if average_reward is not None:
|
||||||
|
fields.append("average_reward = ?")
|
||||||
|
values.append(average_reward)
|
||||||
|
|
||||||
|
if current_learning_rate is not None:
|
||||||
|
fields.append("current_learning_rate = ?")
|
||||||
|
values.append(current_learning_rate)
|
||||||
|
|
||||||
|
if current_clip_range is not None:
|
||||||
|
fields.append("current_clip_range = ?")
|
||||||
|
values.append(current_clip_range)
|
||||||
|
|
||||||
|
if checkpoint_path is not None:
|
||||||
|
fields.append("checkpoint_path = ?")
|
||||||
|
values.append(checkpoint_path)
|
||||||
|
|
||||||
|
if status is not None:
|
||||||
|
fields.append("status = ?")
|
||||||
|
values.append(status)
|
||||||
|
|
||||||
|
fields.append("updated_at = ?")
|
||||||
|
values.append(datetime.utcnow().isoformat())
|
||||||
|
|
||||||
|
values.append(model_id)
|
||||||
|
|
||||||
|
with get_db_conn() as conn:
|
||||||
|
conn.execute(f"""
|
||||||
|
UPDATE models SET {', '.join(fields)}
|
||||||
|
WHERE id = ?
|
||||||
|
""", tuple(values))
|
||||||
|
|
||||||
|
def delete_model_by_id(model_id: int) -> None:
|
||||||
|
"""
|
||||||
|
Delete a model record.
|
||||||
|
"""
|
||||||
|
with get_db_conn() as conn:
|
||||||
|
conn.execute("DELETE FROM models WHERE id = ?", (model_id,))
|
||||||
|
|
||||||
|
def get_models_for_run(run_id: int) -> list[dict]:
|
||||||
|
"""
|
||||||
|
Retrieve all models linked to a specific run.
|
||||||
|
Assumes a run_id is passed and models were inserted with a reference to it (if applicable).
|
||||||
|
Currently, models are linked by agent_id only, so we filter manually.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_id: The ID of the run.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of model dicts.
|
||||||
|
"""
|
||||||
|
with get_db_conn() as conn:
|
||||||
|
# First, get the agent_id for the run
|
||||||
|
cur = conn.execute("SELECT agent_id FROM runs WHERE id = ?", (run_id,))
|
||||||
|
row = cur.fetchone()
|
||||||
|
if not row:
|
||||||
|
return []
|
||||||
|
|
||||||
|
agent_id = row["agent_id"]
|
||||||
|
# Then, get all models for that agent
|
||||||
|
cur = conn.execute("""
|
||||||
|
SELECT * FROM models
|
||||||
|
WHERE agent_id = ?
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
""", (agent_id,))
|
||||||
|
return [dict(r) for r in cur.fetchall()]
|
||||||
|
|
||||||
|
|
||||||
|
def update_run(
|
||||||
|
run_id: int,
|
||||||
|
name: str | None = None,
|
||||||
|
config_yaml: str | None = None,
|
||||||
|
notes: str | None = None,
|
||||||
|
pending: bool | None = None
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Update a run by ID with any combination of fields.
|
||||||
|
"""
|
||||||
|
fields = []
|
||||||
|
values = []
|
||||||
|
|
||||||
|
if name is not None:
|
||||||
|
fields.append("name = ?")
|
||||||
|
values.append(name)
|
||||||
|
|
||||||
|
if config_yaml is not None:
|
||||||
|
fields.append("config_yaml = ?")
|
||||||
|
values.append(config_yaml)
|
||||||
|
|
||||||
|
if notes is not None:
|
||||||
|
fields.append("notes = ?")
|
||||||
|
values.append(notes)
|
||||||
|
|
||||||
|
if pending is not None:
|
||||||
|
fields.append("pending = ?")
|
||||||
|
values.append(int(pending))
|
||||||
|
|
||||||
|
if not fields:
|
||||||
|
return # Nothing to update
|
||||||
|
|
||||||
|
fields.append("updated_at = ?")
|
||||||
|
values.append(datetime.utcnow().isoformat())
|
||||||
|
values.append(run_id)
|
||||||
|
|
||||||
|
with get_db_conn() as conn:
|
||||||
|
conn.execute(
|
||||||
|
f"UPDATE runs SET {', '.join(fields)} WHERE id = ?",
|
||||||
|
tuple(values)
|
||||||
|
)
|
||||||
|
|||||||
@ -12,19 +12,6 @@ Screen {
|
|||||||
border: none;
|
border: none;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* === Headers === */
|
|
||||||
|
|
||||||
# .header {
|
|
||||||
# dock: top;
|
|
||||||
# height: 3;
|
|
||||||
# content-align: center middle;
|
|
||||||
# background: {{SURFACE_10}};
|
|
||||||
# color: {{ACCENT}};
|
|
||||||
# text-style: bold;
|
|
||||||
# padding: 1 2;
|
|
||||||
# border: solid {{BORDER}};
|
|
||||||
# }
|
|
||||||
|
|
||||||
/* === Buttons === */
|
/* === Buttons === */
|
||||||
|
|
||||||
Button {
|
Button {
|
||||||
@ -57,6 +44,19 @@ Button:disabled {
|
|||||||
text-style: dim;
|
text-style: dim;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* === Danger Button Variant === */
|
||||||
|
.danger_button {
|
||||||
|
background: {{SURFACE_10}};
|
||||||
|
color: {{LIGHT}};
|
||||||
|
border: solid {{ERROR}};
|
||||||
|
}
|
||||||
|
|
||||||
|
.danger_button:hover {
|
||||||
|
background: {{SURFACE_20}};
|
||||||
|
color: {{ERROR}};
|
||||||
|
border: solid {{ERROR}};
|
||||||
|
}
|
||||||
|
|
||||||
/* === Grid Layout === */
|
/* === Grid Layout === */
|
||||||
|
|
||||||
.rom_grid {
|
.rom_grid {
|
||||||
@ -215,4 +215,111 @@ Button.game_button {
|
|||||||
align-horizontal: center;
|
align-horizontal: center;
|
||||||
height: 100%;
|
height: 100%;
|
||||||
padding-left: 1;
|
padding-left: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* === Agent Select View === */
|
||||||
|
|
||||||
|
#agent_select_layout {
|
||||||
|
layout: vertical;
|
||||||
|
width: 100%;
|
||||||
|
height: auto;
|
||||||
|
padding: 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
#agent_scroll {
|
||||||
|
max-height: 35vh;
|
||||||
|
overflow-y: auto;
|
||||||
|
padding: 0 1;
|
||||||
|
border: solid {{BORDER}};
|
||||||
|
scrollbar-gutter: stable;
|
||||||
|
}
|
||||||
|
|
||||||
|
.agent_card {
|
||||||
|
background: {{SURFACE_10}};
|
||||||
|
border-bottom: dashed {{BORDER}};
|
||||||
|
padding: 1 2;
|
||||||
|
text-align: left;
|
||||||
|
color: {{FOREGROUND}};
|
||||||
|
}
|
||||||
|
|
||||||
|
.agent_card:last-child {
|
||||||
|
border-bottom: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.agent_card:hover {
|
||||||
|
background: {{SURFACE_20}};
|
||||||
|
color: {{ACCENT}};
|
||||||
|
}
|
||||||
|
|
||||||
|
#agent_scroll > .agent_card {
|
||||||
|
margin-bottom: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Message displayed when no agents exist (mounted manually in Python) */
|
||||||
|
#agent_scroll > Static {
|
||||||
|
color: {{DISABLED}};
|
||||||
|
text-style: italic;
|
||||||
|
padding: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
#create_agent_btn {
|
||||||
|
width: 100%;
|
||||||
|
max-width: 30%;
|
||||||
|
margin-top: 1;
|
||||||
|
align: center middle;
|
||||||
|
}
|
||||||
|
|
||||||
|
#create_agent_layout {
|
||||||
|
layout: vertical;
|
||||||
|
width: 100%;
|
||||||
|
height: auto;
|
||||||
|
padding: 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
#agent_name_input {
|
||||||
|
width: 100%;
|
||||||
|
max-width: 40%;
|
||||||
|
align: center middle;
|
||||||
|
margin: 1 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
#agent_info_panel {
|
||||||
|
layout: horizontal;
|
||||||
|
width: 100%;
|
||||||
|
height: auto;
|
||||||
|
max-height: 150vh;
|
||||||
|
padding: 0 1;
|
||||||
|
border: solid {{BORDER}};
|
||||||
|
align-vertical: top;
|
||||||
|
content-align: center middle;
|
||||||
|
}
|
||||||
|
|
||||||
|
.agent_info_box {
|
||||||
|
width: 70%;
|
||||||
|
height: auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
.agent_card {
|
||||||
|
max-height: 20;
|
||||||
|
}
|
||||||
|
|
||||||
|
.agent_action_box {
|
||||||
|
width: 30%;
|
||||||
|
height: auto;
|
||||||
|
layout: vertical;
|
||||||
|
align-vertical: top;
|
||||||
|
align-horizontal: right;
|
||||||
|
}
|
||||||
|
|
||||||
|
#create_button_container {
|
||||||
|
align-horizontal: center;
|
||||||
|
padding-top: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* === Agent Home View === */
|
||||||
|
#agent_home_menu Button {
|
||||||
|
max-width: 40%;
|
||||||
|
min-height: 3;
|
||||||
|
margin: 1 0;
|
||||||
|
align: center middle;
|
||||||
}
|
}
|
||||||
@ -12,19 +12,6 @@ Screen {
|
|||||||
border: none;
|
border: none;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* === Headers === */
|
|
||||||
|
|
||||||
# .header {
|
|
||||||
# dock: top;
|
|
||||||
# height: 3;
|
|
||||||
# content-align: center middle;
|
|
||||||
# background: #282828;
|
|
||||||
# color: #ed7d3a;
|
|
||||||
# text-style: bold;
|
|
||||||
# padding: 1 2;
|
|
||||||
# border: solid #3a9bed;
|
|
||||||
# }
|
|
||||||
|
|
||||||
/* === Buttons === */
|
/* === Buttons === */
|
||||||
|
|
||||||
Button {
|
Button {
|
||||||
@ -57,6 +44,19 @@ Button:disabled {
|
|||||||
text-style: dim;
|
text-style: dim;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* === Danger Button Variant === */
|
||||||
|
.danger_button {
|
||||||
|
background: #282828;
|
||||||
|
color: #ffffff;
|
||||||
|
border: solid red;
|
||||||
|
}
|
||||||
|
|
||||||
|
.danger_button:hover {
|
||||||
|
background: #3f3f3f;
|
||||||
|
color: red;
|
||||||
|
border: solid red;
|
||||||
|
}
|
||||||
|
|
||||||
/* === Grid Layout === */
|
/* === Grid Layout === */
|
||||||
|
|
||||||
.rom_grid {
|
.rom_grid {
|
||||||
@ -215,4 +215,111 @@ Button.game_button {
|
|||||||
align-horizontal: center;
|
align-horizontal: center;
|
||||||
height: 100%;
|
height: 100%;
|
||||||
padding-left: 1;
|
padding-left: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* === Agent Select View === */
|
||||||
|
|
||||||
|
#agent_select_layout {
|
||||||
|
layout: vertical;
|
||||||
|
width: 100%;
|
||||||
|
height: auto;
|
||||||
|
padding: 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
#agent_scroll {
|
||||||
|
max-height: 35vh;
|
||||||
|
overflow-y: auto;
|
||||||
|
padding: 0 1;
|
||||||
|
border: solid #3a9bed;
|
||||||
|
scrollbar-gutter: stable;
|
||||||
|
}
|
||||||
|
|
||||||
|
.agent_card {
|
||||||
|
background: #282828;
|
||||||
|
border-bottom: dashed #3a9bed;
|
||||||
|
padding: 1 2;
|
||||||
|
text-align: left;
|
||||||
|
color: #f0f0f0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.agent_card:last-child {
|
||||||
|
border-bottom: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.agent_card:hover {
|
||||||
|
background: #3f3f3f;
|
||||||
|
color: #ed7d3a;
|
||||||
|
}
|
||||||
|
|
||||||
|
#agent_scroll > .agent_card {
|
||||||
|
margin-bottom: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Message displayed when no agents exist (mounted manually in Python) */
|
||||||
|
#agent_scroll > Static {
|
||||||
|
color: #999999;
|
||||||
|
text-style: italic;
|
||||||
|
padding: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
#create_agent_btn {
|
||||||
|
width: 100%;
|
||||||
|
max-width: 30%;
|
||||||
|
margin-top: 1;
|
||||||
|
align: center middle;
|
||||||
|
}
|
||||||
|
|
||||||
|
#create_agent_layout {
|
||||||
|
layout: vertical;
|
||||||
|
width: 100%;
|
||||||
|
height: auto;
|
||||||
|
padding: 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
#agent_name_input {
|
||||||
|
width: 100%;
|
||||||
|
max-width: 40%;
|
||||||
|
align: center middle;
|
||||||
|
margin: 1 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
#agent_info_panel {
|
||||||
|
layout: horizontal;
|
||||||
|
width: 100%;
|
||||||
|
height: auto;
|
||||||
|
max-height: 150vh;
|
||||||
|
padding: 0 1;
|
||||||
|
border: solid #3a9bed;
|
||||||
|
align-vertical: top;
|
||||||
|
content-align: center middle;
|
||||||
|
}
|
||||||
|
|
||||||
|
.agent_info_box {
|
||||||
|
width: 70%;
|
||||||
|
height: auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
.agent_card {
|
||||||
|
max-height: 20;
|
||||||
|
}
|
||||||
|
|
||||||
|
.agent_action_box {
|
||||||
|
width: 30%;
|
||||||
|
height: auto;
|
||||||
|
layout: vertical;
|
||||||
|
align-vertical: top;
|
||||||
|
align-horizontal: right;
|
||||||
|
}
|
||||||
|
|
||||||
|
#create_button_container {
|
||||||
|
align-horizontal: center;
|
||||||
|
padding-top: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* === Agent Home View === */
|
||||||
|
#agent_home_menu Button {
|
||||||
|
max-width: 40%;
|
||||||
|
min-height: 3;
|
||||||
|
margin: 1 0;
|
||||||
|
align: center middle;
|
||||||
}
|
}
|
||||||
67
agentm/views/agent_home.py
Normal file
67
agentm/views/agent_home.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
from textual.screen import Screen
|
||||||
|
from textual.widgets import Button, Static
|
||||||
|
from textual.containers import Vertical
|
||||||
|
from agentm.utils.logger import log_with_caller
|
||||||
|
from agentm.components.footer import AgentMFooter
|
||||||
|
from agentm.theme.palette import get_theme
|
||||||
|
from agentm.logic.db_functions import get_models_for_run
|
||||||
|
from agentm.views.model_select import ModelSelectView
|
||||||
|
from pyfiglet import Figlet
|
||||||
|
|
||||||
|
palette = get_theme()
|
||||||
|
|
||||||
|
class AgentHomeView(Screen):
|
||||||
|
BINDINGS = [("escape", "app.pop_screen", "Back")]
|
||||||
|
|
||||||
|
def __init__(self, agent: dict, run: dict):
|
||||||
|
super().__init__()
|
||||||
|
self.agent_metadata = agent
|
||||||
|
self.run_metadata = run
|
||||||
|
self.models_exist = False
|
||||||
|
|
||||||
|
def compose(self):
|
||||||
|
f = Figlet(font="ansi_shadow")
|
||||||
|
ascii_header = f.renderText(self.agent_metadata["name"])
|
||||||
|
|
||||||
|
yield Static(f"[{palette.ACCENT}]{ascii_header}[/{palette.ACCENT}]", classes="header", expand=False)
|
||||||
|
yield Static(
|
||||||
|
f"[b]Agent ID:[/] {self.agent_metadata['id']} | "
|
||||||
|
f"[b]Run:[/] {self.run_metadata['name']}",
|
||||||
|
classes="subheader",
|
||||||
|
expand=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Buttons - only Eval and Submit may be disabled
|
||||||
|
self.train_btn = Button("🚀 Train", id="train_btn", classes="confirm_button")
|
||||||
|
self.eval_btn = Button("🎮 Eval", id="eval_btn", classes="confirm_button", disabled=True)
|
||||||
|
self.submit_btn = Button("📦 Submit", id="submit_btn", classes="confirm_button", disabled=True)
|
||||||
|
|
||||||
|
yield Vertical(
|
||||||
|
self.train_btn,
|
||||||
|
self.eval_btn,
|
||||||
|
self.submit_btn,
|
||||||
|
id="agent_home_menu",
|
||||||
|
classes="centered_layout"
|
||||||
|
)
|
||||||
|
|
||||||
|
yield AgentMFooter(compact=True)
|
||||||
|
|
||||||
|
async def on_mount(self):
|
||||||
|
log_with_caller("info", f"Entered AgentHomeView for agent '{self.agent_metadata['name']}'")
|
||||||
|
models = get_models_for_run(self.run_metadata["id"])
|
||||||
|
self.models_exist = len(models) > 0
|
||||||
|
self.eval_btn.disabled = not self.models_exist
|
||||||
|
self.submit_btn.disabled = not self.models_exist
|
||||||
|
log_with_caller("debug", f"Found {len(models)} model(s) for run '{self.run_metadata['name']}'")
|
||||||
|
|
||||||
|
async def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||||
|
action = event.button.id.replace("_btn", "") # 'train', 'eval', or 'submit'
|
||||||
|
log_with_caller("info", f"{action.title()} selected for run '{self.run_metadata['name']}'")
|
||||||
|
|
||||||
|
# Always go to model select before action
|
||||||
|
await self.app.push_screen(
|
||||||
|
ModelSelectView(
|
||||||
|
agent_metadata=self.agent_metadata,
|
||||||
|
mode=action
|
||||||
|
)
|
||||||
|
)
|
||||||
166
agentm/views/agent_select.py
Normal file
166
agentm/views/agent_select.py
Normal file
@ -0,0 +1,166 @@
|
|||||||
|
from textual.screen import Screen
|
||||||
|
from textual.containers import Vertical, Horizontal, VerticalScroll
|
||||||
|
from textual.widgets import Static, Button
|
||||||
|
from agentm.logic.db_functions import get_agents_for_game, delete_agent_by_id
|
||||||
|
from agentm.theme.palette import get_theme
|
||||||
|
from agentm.utils.logger import log_with_caller
|
||||||
|
from agentm.components.footer import AgentMFooter
|
||||||
|
from .create_agent import CreateAgentView
|
||||||
|
from .select_run import SelectRunView
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
palette = get_theme()
|
||||||
|
|
||||||
|
class AgentCard(Static):
|
||||||
|
def __init__(self, agent_data: dict, parent_view):
|
||||||
|
super().__init__(classes="agent_card")
|
||||||
|
self.agent_data = agent_data
|
||||||
|
self.parent_view = parent_view
|
||||||
|
|
||||||
|
def render(self) -> str:
|
||||||
|
return f"""
|
||||||
|
[bold {palette.ACCENT}]{self.agent_data['name']}[/]
|
||||||
|
[dim]ID:[/] {self.agent_data['id']}
|
||||||
|
[dim]Created:[/] {self.agent_data.get('created_at', '—')}
|
||||||
|
[dim]Game ID:[/] {self.agent_data['game_id']}
|
||||||
|
""".strip()
|
||||||
|
|
||||||
|
async def on_click(self):
|
||||||
|
await self.parent_view.display_agent_info(self.agent_data)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentSelectView(Screen):
|
||||||
|
BINDINGS = [
|
||||||
|
("escape", "app.pop_screen", "Back"),
|
||||||
|
("n", "create_agent", "New Agent"),
|
||||||
|
("r", "refresh_agents", "Refresh"),
|
||||||
|
("up", "scroll_up", "Scroll Up"),
|
||||||
|
("down", "scroll_down", "Scroll Down"),
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, game_metadata: dict):
|
||||||
|
super().__init__()
|
||||||
|
self.game_metadata = game_metadata
|
||||||
|
self.selected_agent = None
|
||||||
|
self.delete_mode = False
|
||||||
|
|
||||||
|
def compose(self):
|
||||||
|
ascii_path = Path(__file__).parent.parent / "assets" / "headers" / "agent_select.txt"
|
||||||
|
try:
|
||||||
|
header_text = ascii_path.read_text()
|
||||||
|
except FileNotFoundError:
|
||||||
|
header_text = "=== SELECT AGENT ==="
|
||||||
|
|
||||||
|
self.header = Static(f"[{palette.ACCENT}]{header_text}[/{palette.ACCENT}]", classes="header")
|
||||||
|
self.subheader = Static(f"[b]{self.game_metadata['title']}[/b]", classes="subheader")
|
||||||
|
|
||||||
|
self.agent_list = VerticalScroll(id="agent_scroll")
|
||||||
|
self.agent_info = Horizontal(id="agent_info_panel")
|
||||||
|
|
||||||
|
self.confirm_button = Button("✅ Select Agent", id="confirm_agent_btn", classes="confirm_button", disabled=True)
|
||||||
|
self.delete_button = Button("🗑️ Delete Agent", id="delete_agent_btn", classes="danger_button", disabled=True)
|
||||||
|
self.create_button = Button("➕ Create Agent", id="create_agent_btn", classes="confirm_button")
|
||||||
|
|
||||||
|
yield Vertical(
|
||||||
|
self.header,
|
||||||
|
self.subheader,
|
||||||
|
self.agent_list,
|
||||||
|
self.agent_info,
|
||||||
|
self.create_button,
|
||||||
|
AgentMFooter(compact=True),
|
||||||
|
id="agent_select_layout"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_mount(self):
|
||||||
|
log_with_caller("debug", f"Mounted AgentSelectView for game_id: {self.game_metadata['game_id']}")
|
||||||
|
await self.refresh_agent_list()
|
||||||
|
|
||||||
|
async def on_resume(self):
|
||||||
|
await self.refresh_agent_list()
|
||||||
|
|
||||||
|
async def refresh_agent_list(self):
|
||||||
|
try:
|
||||||
|
for child in list(self.agent_list.children):
|
||||||
|
await child.remove()
|
||||||
|
|
||||||
|
agents = get_agents_for_game(self.game_metadata["game_id"])
|
||||||
|
log_with_caller("info", f"Refreshed: {len(agents)} agents found for {self.game_metadata['game_id']}")
|
||||||
|
|
||||||
|
if agents:
|
||||||
|
for agent in agents:
|
||||||
|
await self.agent_list.mount(AgentCard(agent, self))
|
||||||
|
else:
|
||||||
|
await self.agent_list.mount(Static("[dim]No agents found for this game.[/dim]"))
|
||||||
|
|
||||||
|
self.selected_agent = None
|
||||||
|
self.delete_mode = False
|
||||||
|
await self.display_agent_info(None)
|
||||||
|
except Exception as e:
|
||||||
|
log_with_caller("error", f"Error rendering agent list: {e}")
|
||||||
|
|
||||||
|
async def display_agent_info(self, agent: dict | None):
|
||||||
|
self.selected_agent = agent
|
||||||
|
self.delete_mode = False
|
||||||
|
await self.agent_info.remove_children()
|
||||||
|
|
||||||
|
if not agent:
|
||||||
|
await self.agent_info.mount(Static("[dim]Select an agent to view details[/dim]"))
|
||||||
|
self.confirm_button.disabled = True
|
||||||
|
self.delete_button.disabled = True
|
||||||
|
return
|
||||||
|
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
|
table = Table.grid(padding=(0, 1))
|
||||||
|
table.add_column("Key", style="bold underline")
|
||||||
|
table.add_column("Value", style=palette.ACCENT, overflow="fold")
|
||||||
|
|
||||||
|
table.add_row("Name", agent["name"])
|
||||||
|
table.add_row("Agent ID", str(agent["id"]))
|
||||||
|
table.add_row("Game ID", agent["game_id"])
|
||||||
|
table.add_row("Created", agent.get("created_at", "—"))
|
||||||
|
table.add_row("Type", agent.get("agent_type", "—"))
|
||||||
|
table.add_row("Framework", agent.get("framework", "—"))
|
||||||
|
table.add_row("Imitation", str(agent.get("is_imitation", False)))
|
||||||
|
|
||||||
|
info_panel = Static(Panel(table, title="Agent Info", border_style=palette.BORDER), classes="agent_info_box")
|
||||||
|
button_column = Vertical(self.confirm_button, self.delete_button, classes="agent_action_box")
|
||||||
|
|
||||||
|
await self.agent_info.mount(info_panel, button_column)
|
||||||
|
self.confirm_button.disabled = False
|
||||||
|
self.delete_button.disabled = False
|
||||||
|
self.delete_button.label = "🗑️ Delete Agent"
|
||||||
|
|
||||||
|
async def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||||
|
if event.button.id == "create_agent_btn":
|
||||||
|
log_with_caller("info", "Create Agent button pressed.")
|
||||||
|
await self.app.push_screen(CreateAgentView(self.game_metadata))
|
||||||
|
|
||||||
|
elif event.button.id == "confirm_agent_btn" and self.selected_agent:
|
||||||
|
log_with_caller("info", f"Agent confirmed: {self.selected_agent['name']}")
|
||||||
|
await self.app.push_screen(SelectRunView(self.selected_agent))
|
||||||
|
|
||||||
|
elif event.button.id == "delete_agent_btn" and self.selected_agent:
|
||||||
|
if not self.delete_mode:
|
||||||
|
self.delete_mode = True
|
||||||
|
self.delete_button.label = "❌ Confirm Delete"
|
||||||
|
log_with_caller("warning", f"Pending deletion for agent: {self.selected_agent['name']}")
|
||||||
|
else:
|
||||||
|
log_with_caller("info", f"Deleting agent: {self.selected_agent['name']} ({self.selected_agent['id']})")
|
||||||
|
delete_agent_by_id(self.selected_agent["id"])
|
||||||
|
await self.refresh_agent_list()
|
||||||
|
|
||||||
|
async def action_create_agent(self):
|
||||||
|
await self.app.push_screen(CreateAgentView(self.game_metadata))
|
||||||
|
|
||||||
|
async def action_refresh_agents(self):
|
||||||
|
await self.refresh_agent_list()
|
||||||
|
|
||||||
|
async def action_scroll_up(self):
|
||||||
|
scroll = self.query_one("#agent_scroll", VerticalScroll)
|
||||||
|
scroll.scroll_up()
|
||||||
|
|
||||||
|
async def action_scroll_down(self):
|
||||||
|
scroll = self.query_one("#agent_scroll", VerticalScroll)
|
||||||
|
scroll.scroll_down()
|
||||||
61
agentm/views/create_agent.py
Normal file
61
agentm/views/create_agent.py
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
from textual.screen import Screen
|
||||||
|
from textual.containers import Vertical
|
||||||
|
from textual.widgets import Static, Button, Input
|
||||||
|
from agentm.components.footer import AgentMFooter
|
||||||
|
from agentm.theme.palette import get_theme
|
||||||
|
from agentm.utils.logger import log_with_caller
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from agentm.logic.db_functions import insert_agent # ✅ Add this import
|
||||||
|
|
||||||
|
palette = get_theme()
|
||||||
|
|
||||||
|
|
||||||
|
class CreateAgentView(Screen):
|
||||||
|
BINDINGS = [("escape", "app.pop_screen", "Back")]
|
||||||
|
|
||||||
|
def __init__(self, game_metadata: dict):
|
||||||
|
super().__init__()
|
||||||
|
self.game_metadata = game_metadata
|
||||||
|
|
||||||
|
def compose(self):
|
||||||
|
# Load ASCII header
|
||||||
|
ascii_path = Path(__file__).parent.parent / "assets" / "headers" / "create_agent.txt"
|
||||||
|
try:
|
||||||
|
header_text = ascii_path.read_text()
|
||||||
|
except FileNotFoundError:
|
||||||
|
header_text = "### Create Agent ###"
|
||||||
|
|
||||||
|
self.header = Static(f"[{palette.ACCENT}]{header_text}[/{palette.ACCENT}]", classes="header")
|
||||||
|
|
||||||
|
self.name_input = Input(placeholder="Enter agent name...", id="agent_name_input")
|
||||||
|
self.create_button = Button("🚀 Create Agent", id="create_button", classes="confirm_button")
|
||||||
|
|
||||||
|
yield Vertical(
|
||||||
|
self.header,
|
||||||
|
self.name_input,
|
||||||
|
self.create_button,
|
||||||
|
AgentMFooter(compact=True),
|
||||||
|
id="create_agent_layout",
|
||||||
|
classes="centered_layout"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||||
|
if event.button.id == "create_button":
|
||||||
|
name = self.name_input.value.strip()
|
||||||
|
|
||||||
|
if name:
|
||||||
|
log_with_caller("info", f"Creating agent '{name}' for game_id: {self.game_metadata['game_id']}")
|
||||||
|
|
||||||
|
# ✅ Insert the agent into the DB
|
||||||
|
insert_agent(
|
||||||
|
name=name,
|
||||||
|
game_id=self.game_metadata["game_id"],
|
||||||
|
config_json="{}", # Replace with actual config later
|
||||||
|
notes="Created via CreateAgentView"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ✅ Dismiss screen
|
||||||
|
await self.app.pop_screen()
|
||||||
|
else:
|
||||||
|
log_with_caller("warning", "Tried to create agent with empty name.")
|
||||||
9
agentm/views/evaluation.py
Normal file
9
agentm/views/evaluation.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
from textual.screen import Screen
|
||||||
|
from textual.widgets import Static
|
||||||
|
from agentm.theme.palette import get_theme
|
||||||
|
|
||||||
|
palette = get_theme()
|
||||||
|
|
||||||
|
class EvaluationView(Screen):
|
||||||
|
def compose(self):
|
||||||
|
yield Static(f"[{palette.ACCENT}]Evaluation View Placeholder[/]", classes="centered_layout")
|
||||||
@ -15,6 +15,8 @@ from agentm.logic.roms import get_verified_roms, GAME_FILES
|
|||||||
from agentm.theme.palette import get_theme
|
from agentm.theme.palette import get_theme
|
||||||
from agentm.components.footer import AgentMFooter
|
from agentm.components.footer import AgentMFooter
|
||||||
|
|
||||||
|
from .agent_select import AgentSelectView
|
||||||
|
|
||||||
palette = get_theme()
|
palette = get_theme()
|
||||||
|
|
||||||
|
|
||||||
@ -72,7 +74,14 @@ class GameCardButton(Button):
|
|||||||
|
|
||||||
|
|
||||||
class HomeView(Screen):
|
class HomeView(Screen):
|
||||||
BINDINGS = [("escape", "app.quit", "Quit")]
|
BINDINGS = [
|
||||||
|
("escape", "app.quit", "Quit"),
|
||||||
|
("up", "scroll_up", "Scroll Up"),
|
||||||
|
("down", "scroll_down", "Scroll Down"),
|
||||||
|
("enter", "confirm_game", "Confirm Game"),
|
||||||
|
("r", "refresh_roms", "Refresh ROMs"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def highlight_selected(self, selected_widget: GameCardButton):
|
def highlight_selected(self, selected_widget: GameCardButton):
|
||||||
for card in self.rom_grid.children:
|
for card in self.rom_grid.children:
|
||||||
@ -195,4 +204,24 @@ class HomeView(Screen):
|
|||||||
|
|
||||||
async def on_button_pressed(self, event: Button.Pressed) -> None:
|
async def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||||
if event.button.id == "confirm_button" and self.selected_game:
|
if event.button.id == "confirm_button" and self.selected_game:
|
||||||
await self.app.push_screen("training", self.selected_game)
|
await self.app.push_screen(AgentSelectView(self.selected_game))
|
||||||
|
|
||||||
|
|
||||||
|
async def action_scroll_up(self):
|
||||||
|
scroll = self.query_one("#rom_grid_scroll", VerticalScroll)
|
||||||
|
scroll.scroll_up()
|
||||||
|
|
||||||
|
async def action_scroll_down(self):
|
||||||
|
scroll = self.query_one("#rom_grid_scroll", VerticalScroll)
|
||||||
|
scroll.scroll_down()
|
||||||
|
|
||||||
|
async def action_confirm_game(self):
|
||||||
|
if self.selected_game:
|
||||||
|
log_with_caller("info", f"Keyboard: Confirming game {self.selected_game['title']}")
|
||||||
|
await self.app.push_screen(AgentSelectView(self.selected_game))
|
||||||
|
else:
|
||||||
|
log_with_caller("warning", "Keyboard: Tried to confirm with no game selected.")
|
||||||
|
|
||||||
|
async def action_refresh_roms(self):
|
||||||
|
log_with_caller("info", "Keyboard: Refreshing ROM list")
|
||||||
|
self.run_worker(self.run_verification, thread=True, exclusive=True, name="rom-verification")
|
||||||
146
agentm/views/model_select.py
Normal file
146
agentm/views/model_select.py
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
from typing import Literal
|
||||||
|
from textual.screen import Screen
|
||||||
|
from textual.containers import Vertical, Horizontal, VerticalScroll
|
||||||
|
from textual.widgets import Static, Button
|
||||||
|
from agentm.logic.db_functions import get_models_for_agent
|
||||||
|
from agentm.theme.palette import get_theme
|
||||||
|
from agentm.utils.logger import log_with_caller
|
||||||
|
from agentm.components.footer import AgentMFooter
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Import destination views (these are placeholders; replace as needed)
|
||||||
|
from agentm.views.training import TrainingView
|
||||||
|
from agentm.views.evaluation import EvaluationView
|
||||||
|
from agentm.views.submission import SubmissionView
|
||||||
|
|
||||||
|
# Ensure the palette is loaded
|
||||||
|
palette = get_theme()
|
||||||
|
|
||||||
|
|
||||||
|
class ModelCard(Static):
|
||||||
|
def __init__(self, model_data: dict, parent_view):
|
||||||
|
super().__init__(classes="agent_card")
|
||||||
|
self.model_data = model_data
|
||||||
|
self.parent_view = parent_view
|
||||||
|
|
||||||
|
def render(self) -> str:
|
||||||
|
return f"""
|
||||||
|
[bold {palette.ACCENT}]{self.model_data['name']}[/]
|
||||||
|
[dim]Steps:[/] {self.model_data['total_steps_completed']} / {self.model_data['total_steps_planned']}
|
||||||
|
[dim]Avg Reward:[/] {self.model_data.get('average_reward', '—')}
|
||||||
|
[dim]Status:[/] {self.model_data['status']}
|
||||||
|
[dim]Created:[/] {self.model_data.get('created_at', '—')}
|
||||||
|
""".strip()
|
||||||
|
|
||||||
|
async def on_click(self):
|
||||||
|
await self.parent_view.display_model_info(self.model_data)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelSelectView(Screen):
|
||||||
|
BINDINGS = [
|
||||||
|
("escape", "app.pop_screen", "Back"),
|
||||||
|
("r", "refresh_models", "Refresh"),
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, agent_metadata: dict, mode: Literal["train", "eval", "submit"]):
|
||||||
|
super().__init__()
|
||||||
|
self.agent_metadata = agent_metadata
|
||||||
|
self.mode = mode
|
||||||
|
self.selected_model = None
|
||||||
|
|
||||||
|
def compose(self):
|
||||||
|
header_path = Path(__file__).parent.parent / "assets" / "headers" / "select_model.txt"
|
||||||
|
try:
|
||||||
|
header_text = header_path.read_text()
|
||||||
|
except FileNotFoundError:
|
||||||
|
header_text = "=== SELECT MODEL ==="
|
||||||
|
|
||||||
|
self.header = Static(f"[{palette.ACCENT}]{header_text}[/{palette.ACCENT}]", classes="header")
|
||||||
|
self.subheader = Static(f"[b]{self.agent_metadata['name']}[/b]", classes="subheader")
|
||||||
|
|
||||||
|
self.model_list = VerticalScroll(id="agent_scroll")
|
||||||
|
self.model_info = Horizontal(id="agent_info_panel")
|
||||||
|
|
||||||
|
self.select_button = Button("✅ Select Model", id="select_model_btn", classes="confirm_button", disabled=True)
|
||||||
|
self.create_button = Button("➕ Create New Model", id="create_model_btn", classes="confirm_button")
|
||||||
|
|
||||||
|
yield Vertical(
|
||||||
|
self.header,
|
||||||
|
self.subheader,
|
||||||
|
self.model_list,
|
||||||
|
self.model_info,
|
||||||
|
self.select_button,
|
||||||
|
self.create_button,
|
||||||
|
AgentMFooter(compact=True),
|
||||||
|
id="agent_select_layout"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_mount(self):
|
||||||
|
log_with_caller("debug", f"Mounted ModelSelectView for agent_id={self.agent_metadata['id']} mode={self.mode}")
|
||||||
|
await self.refresh_model_list()
|
||||||
|
|
||||||
|
async def refresh_model_list(self):
|
||||||
|
try:
|
||||||
|
for child in list(self.model_list.children):
|
||||||
|
await child.remove()
|
||||||
|
|
||||||
|
models = get_models_for_agent(self.agent_metadata["id"])
|
||||||
|
log_with_caller("info", f"Refreshed: {len(models)} models found for agent {self.agent_metadata['name']}")
|
||||||
|
|
||||||
|
if models:
|
||||||
|
for model in models:
|
||||||
|
await self.model_list.mount(ModelCard(model, self))
|
||||||
|
else:
|
||||||
|
await self.model_list.mount(Static("[dim]No models found for this agent.[/dim]"))
|
||||||
|
|
||||||
|
self.selected_model = None
|
||||||
|
await self.display_model_info(None)
|
||||||
|
except Exception as e:
|
||||||
|
log_with_caller("error", f"Error rendering model list: {e}")
|
||||||
|
|
||||||
|
async def display_model_info(self, model: dict | None):
|
||||||
|
self.selected_model = model
|
||||||
|
await self.model_info.remove_children()
|
||||||
|
|
||||||
|
if not model:
|
||||||
|
await self.model_info.mount(Static("[dim]Select a model to view details[/dim]"))
|
||||||
|
self.select_button.disabled = True
|
||||||
|
return
|
||||||
|
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
|
table = Table.grid(padding=(0, 1))
|
||||||
|
table.add_column("Key", style="bold underline")
|
||||||
|
table.add_column("Value", style=palette.ACCENT, overflow="fold")
|
||||||
|
|
||||||
|
table.add_row("Name", model["name"])
|
||||||
|
table.add_row("Status", model["status"])
|
||||||
|
table.add_row("Steps", f"{model['total_steps_completed']} / {model['total_steps_planned']}")
|
||||||
|
table.add_row("Avg Reward", str(model.get("average_reward", "—")))
|
||||||
|
table.add_row("Learning Rate", str(model.get("current_learning_rate", "—")))
|
||||||
|
table.add_row("Clip Range", str(model.get("current_clip_range", "—")))
|
||||||
|
table.add_row("Checkpoint", model.get("checkpoint_path", "—"))
|
||||||
|
table.add_row("Created", model.get("created_at", "—"))
|
||||||
|
|
||||||
|
info_panel = Static(Panel(table, title="Model Info", border_style=palette.BORDER), classes="agent_info_box")
|
||||||
|
await self.model_info.mount(info_panel)
|
||||||
|
self.select_button.disabled = False
|
||||||
|
|
||||||
|
async def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||||
|
if event.button.id == "select_model_btn" and self.selected_model:
|
||||||
|
log_with_caller("info", f"Model confirmed: {self.selected_model['name']} for mode={self.mode}")
|
||||||
|
match self.mode:
|
||||||
|
case "train":
|
||||||
|
await self.app.push_screen(TrainingView(agent=self.agent_metadata, model=self.selected_model))
|
||||||
|
case "eval":
|
||||||
|
await self.app.push_screen(EvaluationView(agent=self.agent_metadata, model=self.selected_model))
|
||||||
|
case "submit":
|
||||||
|
await self.app.push_screen(SubmissionView(agent=self.agent_metadata, model=self.selected_model))
|
||||||
|
|
||||||
|
elif event.button.id == "create_model_btn":
|
||||||
|
log_with_caller("info", f"Creating new model for agent: {self.agent_metadata['name']}")
|
||||||
|
await self.app.push_screen(TrainingView(agent=self.agent_metadata, model=None)) # New model starts from training
|
||||||
|
|
||||||
|
async def action_refresh_models(self):
|
||||||
|
await self.refresh_model_list()
|
||||||
147
agentm/views/select_run.py
Normal file
147
agentm/views/select_run.py
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
from textual.screen import Screen
|
||||||
|
from textual.containers import Vertical, Horizontal, VerticalScroll
|
||||||
|
from textual.widgets import Static, Button
|
||||||
|
from agentm.logic.db_functions import get_runs_for_agent, insert_run
|
||||||
|
from agentm.theme.palette import get_theme
|
||||||
|
from agentm.utils.logger import log_with_caller
|
||||||
|
from agentm.components.footer import AgentMFooter
|
||||||
|
from agentm.views.agent_home import AgentHomeView
|
||||||
|
from pathlib import Path
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
palette = get_theme()
|
||||||
|
|
||||||
|
|
||||||
|
class RunCard(Static):
|
||||||
|
def __init__(self, run_data: dict, parent_view):
|
||||||
|
super().__init__(classes="agent_card")
|
||||||
|
self.run_data = run_data
|
||||||
|
self.parent_view = parent_view
|
||||||
|
|
||||||
|
def render(self) -> str:
|
||||||
|
return f"""
|
||||||
|
[bold {palette.ACCENT}]{self.run_data['name']}[/]
|
||||||
|
[dim]Created:[/] {self.run_data.get('created_at', '—')}
|
||||||
|
""".strip()
|
||||||
|
|
||||||
|
async def on_click(self):
|
||||||
|
await self.parent_view.display_run_info(self.run_data)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class SelectRunView(Screen):
|
||||||
|
BINDINGS = [
|
||||||
|
("escape", "app.pop_screen", "Back"),
|
||||||
|
("r", "refresh_runs", "Refresh"),
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, agent_metadata: dict):
|
||||||
|
super().__init__()
|
||||||
|
self.agent_metadata = agent_metadata
|
||||||
|
self.selected_run = None
|
||||||
|
|
||||||
|
def compose(self):
|
||||||
|
header_path = Path(__file__).parent.parent / "assets" / "headers" / "select_run.txt"
|
||||||
|
try:
|
||||||
|
header_text = header_path.read_text()
|
||||||
|
except FileNotFoundError:
|
||||||
|
header_text = "=== SELECT RUN ==="
|
||||||
|
|
||||||
|
self.header = Static(f"[{palette.ACCENT}]{header_text}[/{palette.ACCENT}]", classes="header")
|
||||||
|
self.subheader = Static(f"[b]{self.agent_metadata['name']}[/b]", classes="subheader")
|
||||||
|
|
||||||
|
self.run_list = VerticalScroll(id="agent_scroll")
|
||||||
|
self.run_info = Horizontal(id="agent_info_panel")
|
||||||
|
|
||||||
|
self.select_button = Button("✅ Select Run", id="select_run_btn", classes="confirm_button", disabled=True)
|
||||||
|
self.create_button = Button("➕ Create New Run", id="create_run_btn", classes="confirm_button")
|
||||||
|
|
||||||
|
yield Vertical(
|
||||||
|
self.header,
|
||||||
|
self.subheader,
|
||||||
|
self.run_list,
|
||||||
|
self.run_info,
|
||||||
|
self.select_button,
|
||||||
|
self.create_button,
|
||||||
|
AgentMFooter(compact=True),
|
||||||
|
id="agent_select_layout"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_mount(self):
|
||||||
|
log_with_caller("debug", f"Mounted SelectRunView for agent_id: {self.agent_metadata['id']}")
|
||||||
|
await self.refresh_run_list()
|
||||||
|
|
||||||
|
async def refresh_run_list(self):
|
||||||
|
try:
|
||||||
|
for child in list(self.run_list.children):
|
||||||
|
await child.remove()
|
||||||
|
|
||||||
|
runs = get_runs_for_agent(self.agent_metadata["id"])
|
||||||
|
log_with_caller("info", f"Refreshed: {len(runs)} runs found for agent {self.agent_metadata['name']}")
|
||||||
|
|
||||||
|
if runs:
|
||||||
|
for run in runs:
|
||||||
|
await self.run_list.mount(RunCard(run, self))
|
||||||
|
else:
|
||||||
|
await self.run_list.mount(Static("[dim]No runs found for this agent.[/dim]"))
|
||||||
|
|
||||||
|
self.selected_run = None
|
||||||
|
await self.display_run_info(None)
|
||||||
|
except Exception as e:
|
||||||
|
log_with_caller("error", f"Error rendering run list: {e}")
|
||||||
|
|
||||||
|
async def display_run_info(self, run: dict | None):
|
||||||
|
self.selected_run = run
|
||||||
|
await self.run_info.remove_children()
|
||||||
|
|
||||||
|
if not run:
|
||||||
|
await self.run_info.mount(Static("[dim]Select a run to view details[/dim]"))
|
||||||
|
self.select_button.disabled = True
|
||||||
|
return
|
||||||
|
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
|
table = Table.grid(padding=(0, 1))
|
||||||
|
table.add_column("Key", style="bold underline")
|
||||||
|
table.add_column("Value", style=palette.ACCENT, overflow="fold")
|
||||||
|
|
||||||
|
table.add_row("Name", run["name"])
|
||||||
|
table.add_row("Created", run.get("created_at", "—"))
|
||||||
|
table.add_row("Notes", run.get("notes", "—"))
|
||||||
|
|
||||||
|
info_panel = Static(Panel(table, title="Run Info", border_style=palette.BORDER), classes="agent_info_box")
|
||||||
|
await self.run_info.mount(info_panel)
|
||||||
|
self.select_button.disabled = False
|
||||||
|
|
||||||
|
async def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||||
|
if event.button.id == "select_run_btn" and self.selected_run:
|
||||||
|
log_with_caller("info", f"Run confirmed: {self.selected_run['name']}")
|
||||||
|
await self.app.push_screen(AgentHomeView(agent=self.agent_metadata, run=self.selected_run))
|
||||||
|
|
||||||
|
elif event.button.id == "create_run_btn":
|
||||||
|
# Auto-generate name and basic config
|
||||||
|
run_name = f"{self.agent_metadata['name']} - {datetime.utcnow().strftime('%Y%m%d-%H%M%S')}"
|
||||||
|
config_yaml = "# Auto-generated config for new run\n"
|
||||||
|
notes = f"Initial run for agent {self.agent_metadata['name']}"
|
||||||
|
|
||||||
|
# Insert into DB
|
||||||
|
insert_run(
|
||||||
|
agent_id=self.agent_metadata["id"],
|
||||||
|
name=run_name,
|
||||||
|
config_yaml=config_yaml,
|
||||||
|
notes=notes
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fetch the new run and push to AgentHomeView
|
||||||
|
runs = get_runs_for_agent(self.agent_metadata["id"])
|
||||||
|
new_run = runs[0] if runs else None
|
||||||
|
|
||||||
|
if new_run:
|
||||||
|
log_with_caller("info", f"Created and selected new run: {new_run['name']}")
|
||||||
|
await self.app.push_screen(AgentHomeView(agent=self.agent_metadata, run=new_run))
|
||||||
|
else:
|
||||||
|
log_with_caller("error", "Failed to retrieve newly created run")
|
||||||
|
|
||||||
|
async def action_refresh_runs(self):
|
||||||
|
await self.refresh_run_list()
|
||||||
9
agentm/views/submission.py
Normal file
9
agentm/views/submission.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
from textual.screen import Screen
|
||||||
|
from textual.widgets import Static
|
||||||
|
from agentm.theme.palette import get_theme
|
||||||
|
|
||||||
|
palette = get_theme()
|
||||||
|
|
||||||
|
class SubmissionView(Screen):
|
||||||
|
def compose(self):
|
||||||
|
yield Static(f"[{palette.ACCENT}]Submission View Placeholder[/]", classes="centered_layout")
|
||||||
@ -0,0 +1,106 @@
|
|||||||
|
from textual.screen import Screen
|
||||||
|
from textual.widgets import Static, Input, Button
|
||||||
|
from textual.containers import Vertical, Horizontal
|
||||||
|
from agentm.theme.palette import get_theme
|
||||||
|
from agentm.utils.logger import log_with_caller
|
||||||
|
from agentm.logic.db_functions import insert_model, update_run_pending
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
palette = get_theme()
|
||||||
|
|
||||||
|
class TrainingView(Screen):
|
||||||
|
BINDINGS = [
|
||||||
|
("escape", "app.pop_screen", "Back")
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, agent: dict, model: dict | None, run: dict):
|
||||||
|
super().__init__()
|
||||||
|
self.agent_metadata = agent
|
||||||
|
self.model_metadata = model
|
||||||
|
self.run_metadata = run
|
||||||
|
self.is_pending_run = run.get("pending", True)
|
||||||
|
|
||||||
|
def compose(self):
|
||||||
|
log_with_caller("info", f"Opening TrainingView for agent='{self.agent_metadata['name']}', run='{self.run_metadata['name']}', pending={self.is_pending_run}")
|
||||||
|
|
||||||
|
yield Static(f"[{palette.ACCENT} bold]Training: {self.agent_metadata['name']}[/]", classes="header")
|
||||||
|
|
||||||
|
if self.is_pending_run:
|
||||||
|
yield self.render_pending_setup()
|
||||||
|
else:
|
||||||
|
yield self.render_resume_options()
|
||||||
|
|
||||||
|
def render_pending_setup(self):
|
||||||
|
yield Static("[b]Initial Model Setup[/b]", classes="subheader")
|
||||||
|
|
||||||
|
self.name_input = Input(placeholder="Model Name", id="model_name")
|
||||||
|
self.steps_input = Input(placeholder="Total Training Steps", id="total_steps")
|
||||||
|
self.lr_input = Input(placeholder="Initial Learning Rate", id="learning_rate")
|
||||||
|
self.clip_input = Input(placeholder="Initial Clip Range", id="clip_range")
|
||||||
|
self.notes_input = Input(placeholder="Notes", id="notes")
|
||||||
|
|
||||||
|
self.confirm_button = Button("✅ Create & Start Training", id="confirm_model_btn", classes="confirm_button")
|
||||||
|
|
||||||
|
return Vertical(
|
||||||
|
self.name_input,
|
||||||
|
self.steps_input,
|
||||||
|
self.lr_input,
|
||||||
|
self.clip_input,
|
||||||
|
self.notes_input,
|
||||||
|
self.confirm_button,
|
||||||
|
classes="centered_layout"
|
||||||
|
)
|
||||||
|
|
||||||
|
def render_resume_options(self):
|
||||||
|
model = self.model_metadata
|
||||||
|
from rich.table import Table
|
||||||
|
from rich.panel import Panel
|
||||||
|
|
||||||
|
table = Table.grid(padding=(0, 1))
|
||||||
|
table.add_column("Key", style="bold underline")
|
||||||
|
table.add_column("Value", style=palette.ACCENT, overflow="fold")
|
||||||
|
|
||||||
|
table.add_row("Name", model["name"])
|
||||||
|
table.add_row("Steps", f"{model['total_steps_completed']} / {model['total_steps_planned']}")
|
||||||
|
table.add_row("Reward", str(model.get("average_reward", "—")))
|
||||||
|
table.add_row("Learning Rate", str(model.get("current_learning_rate", "—")))
|
||||||
|
table.add_row("Clip Range", str(model.get("current_clip_range", "—")))
|
||||||
|
table.add_row("Created", model.get("created_at", "—"))
|
||||||
|
|
||||||
|
return Vertical(
|
||||||
|
Static(Panel(table, title="Model Info", border_style=palette.BORDER), classes="agent_info_box"),
|
||||||
|
Button("⏯️ Resume Training", id="resume_training_btn", classes="confirm_button"),
|
||||||
|
classes="centered_layout"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||||
|
if event.button.id == "confirm_model_btn":
|
||||||
|
try:
|
||||||
|
name = self.name_input.value.strip()
|
||||||
|
total_steps = int(self.steps_input.value)
|
||||||
|
learning_rate = float(self.lr_input.value)
|
||||||
|
clip_range = float(self.clip_input.value)
|
||||||
|
notes = self.notes_input.value.strip()
|
||||||
|
|
||||||
|
log_with_caller("info", f"Creating new model: {name} for agent_id={self.agent_metadata['id']}")
|
||||||
|
|
||||||
|
insert_model(
|
||||||
|
agent_id=self.agent_metadata["id"],
|
||||||
|
name=name,
|
||||||
|
total_steps_planned=total_steps,
|
||||||
|
current_learning_rate=learning_rate,
|
||||||
|
current_clip_range=clip_range,
|
||||||
|
notes=notes
|
||||||
|
)
|
||||||
|
|
||||||
|
update_run_pending(self.run_metadata["id"], False)
|
||||||
|
log_with_caller("info", f"Model '{name}' created and run marked as not pending")
|
||||||
|
|
||||||
|
await self.app.pop_screen()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
log_with_caller("error", f"Failed to create model: {e}")
|
||||||
|
|
||||||
|
elif event.button.id == "resume_training_btn":
|
||||||
|
log_with_caller("info", f"Resuming training for model: {self.model_metadata['name']}")
|
||||||
|
await self.app.pop_screen()
|
||||||
@ -9,4 +9,5 @@ tensorboard
|
|||||||
requests
|
requests
|
||||||
sqlite3
|
sqlite3
|
||||||
rich-pixels
|
rich-pixels
|
||||||
pillow
|
pillow
|
||||||
|
pyfiglet
|
||||||
Loading…
Reference in New Issue
Block a user