agent_m/agentm/logic/db_functions.py
mscrnt 4500bfd388 Refactor agent management views and styles
- Removed commented-out header styles from styles.base.tcss and styles.tcss.
- Added new styles for danger buttons and agent selection views in styles.base.tcss and styles.tcss.
- Implemented AgentHomeView to manage agent actions and display metadata.
- Created AgentSelectView for selecting agents with a new layout and functionality.
- Added CreateAgentView for creating new agents with input validation.
- Removed obsolete eval.py and replaced it with evaluation.py.
- Developed GameSelectView for selecting games with a dynamic loading interface.
- Introduced ModelSelectView for selecting models associated with agents.
- Created SelectRunView for managing runs associated with agents.
- Added SubmissionView and TrainingView for handling model training and submission processes.
- Updated requirements.txt to include pyfiglet for ASCII art rendering.
2025-05-26 07:55:58 -07:00

497 lines
14 KiB
Python

import json
from datetime import datetime
from agentm.logic.db import get_db_conn
def get_cached_rom(rom_file: str) -> dict | None:
    """
    Look up the verified metadata stored for a single ROM file.

    Only rows with ``verified = 1`` are considered. The ``characters`` and
    ``keywords`` columns are decoded from JSON; a NULL or empty column is
    returned as an empty list.

    Args:
        rom_file: The filename of the ROM (e.g., 'sfiii3n.zip').

    Returns:
        A dictionary of ROM metadata if a verified row exists, otherwise None.
    """
    query = """
        SELECT sha256, verified_at, title, game_id,
               difficulty_min, difficulty_max, characters, keywords
        FROM roms WHERE rom_file = ? AND verified = 1
    """
    with get_db_conn() as conn:
        record = conn.execute(query, (rom_file,)).fetchone()
        if not record:
            return None

        # Column order mirrors the SELECT list above.
        (
            sha256,
            verified_at,
            title,
            game_id,
            difficulty_min,
            difficulty_max,
            characters_json,
            keywords_json,
        ) = record

        return {
            "sha256": sha256,
            "verified": True,  # guaranteed by the WHERE clause
            "verified_at": verified_at,
            "title": title,
            "rom_file": rom_file,
            "game_id": game_id,
            "difficulty_min": difficulty_min,
            "difficulty_max": difficulty_max,
            "characters": json.loads(characters_json) if characters_json else [],
            "keywords": json.loads(keywords_json) if keywords_json else [],
        }
def get_all_verified_roms() -> list[dict]:
    """
    Collect the metadata of every verified ROM, ordered by title (ascending).

    The ``characters`` and ``keywords`` columns are decoded from JSON; a NULL
    or empty column becomes an empty list.

    Returns:
        A list of dictionaries containing ROM metadata (empty if no ROM is
        verified).
    """
    query = """
        SELECT sha256, verified_at, title, game_id, rom_file,
               difficulty_min, difficulty_max, characters, keywords
        FROM roms WHERE verified = 1
        ORDER BY title ASC
    """
    with get_db_conn() as conn:
        records = conn.execute(query).fetchall()

        roms = []
        for record in records:
            # Column order mirrors the SELECT list above.
            (
                sha256,
                verified_at,
                title,
                game_id,
                rom_file,
                difficulty_min,
                difficulty_max,
                characters_json,
                keywords_json,
            ) = record
            roms.append({
                "sha256": sha256,
                "verified": True,  # guaranteed by the WHERE clause
                "verified_at": verified_at,
                "title": title,
                "game_id": game_id,
                "rom_file": rom_file,
                "difficulty_min": difficulty_min,
                "difficulty_max": difficulty_max,
                "characters": json.loads(characters_json) if characters_json else [],
                "keywords": json.loads(keywords_json) if keywords_json else [],
            })
        return roms
def upsert_rom_record(
    title: str,
    rom_file: str,
    game_id: str,
    sha256: str,
    difficulty_min: int | None = None,
    difficulty_max: int | None = None,
    characters: list[str] | None = None,
    keywords: list[str] | None = None,
) -> None:
    """
    Insert or replace a verified ROM entry in the database.

    The row is always written with ``verified = 1`` and ``verified_at`` set to
    the current UTC time (naive ISO-8601 string). ``characters`` and
    ``keywords`` are stored as JSON arrays; ``None`` is stored as ``[]``.

    NOTE(review): which existing row gets replaced depends on the unique
    constraints of the ``roms`` table, which are not visible here — confirm
    against the schema.

    Args:
        title: Game title.
        rom_file: ROM file name.
        game_id: Game ID used by DIAMBRA.
        sha256: SHA256 checksum of the ROM.
        difficulty_min: Minimum difficulty, or None if unknown.
        difficulty_max: Maximum difficulty, or None if unknown.
        characters: List of characters, or None for an empty list.
        keywords: List of keywords, or None for an empty list.
    """
    verified_at = datetime.utcnow().isoformat()
    with get_db_conn() as conn:
        conn.execute("""
            INSERT OR REPLACE INTO roms (
                title,
                rom_file,
                game_id,
                sha256,
                difficulty_min,
                difficulty_max,
                characters,
                keywords,
                verified,
                verified_at
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 1, ?)
        """, (
            title,
            rom_file,
            game_id,
            sha256,
            difficulty_min,
            difficulty_max,
            json.dumps(characters or []),
            json.dumps(keywords or []),
            verified_at,
        ))
def get_agents_for_game(game_id: str) -> list[dict]:
    """
    List the agents registered for one game, newest first.

    Rows are converted with ``dict(row)``.
    NOTE(review): this presumably relies on ``get_db_conn`` configuring a
    mapping-style row factory (e.g. ``sqlite3.Row``) — confirm.

    Args:
        game_id: The unique game ID (e.g., 'sfiii3n').

    Returns:
        A list of agent metadata dictionaries ordered by ``created_at``
        descending.
    """
    query = """
        SELECT id, name, game_id, created_at, last_updated, agent_type, framework, is_imitation
        FROM agents
        WHERE game_id = ?
        ORDER BY created_at DESC
    """
    with get_db_conn() as conn:
        records = conn.execute(query, (game_id,)).fetchall()
        return list(map(dict, records))
def insert_agent(
    name: str,
    game_id: str,
    config_json: str,
    agent_type: str = "PPO",
    framework: str = "SB3",
    is_imitation: bool = False,
    dataset_path: str = "",
    notes: str = ""
) -> None:
    """
    Insert a new agent into the database.

    ``created_at`` and ``last_updated`` are both set to the same UTC
    timestamp (naive ISO-8601 string), so a freshly inserted agent has
    identical creation and update times.

    Args:
        name: Agent name.
        game_id: Game ID this agent is associated with.
        config_json: JSON string of the training config.
        agent_type: Algorithm type (e.g., PPO, BC).
        framework: Framework name (e.g., SB3, Ray).
        is_imitation: Whether this is imitation learning; stored as 0/1.
        dataset_path: Optional path to dataset (for IL).
        notes: Developer notes or comments.
    """
    # Take the timestamp once: two separate utcnow() calls would leave
    # created_at and last_updated a few microseconds apart.
    now = datetime.utcnow().isoformat()
    with get_db_conn() as conn:
        conn.execute("""
            INSERT INTO agents (
                name,
                game_id,
                config_json,
                agent_type,
                framework,
                is_imitation,
                dataset_path,
                notes,
                created_at,
                last_updated
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """, (
            name,
            game_id,
            config_json,
            agent_type,
            framework,
            int(is_imitation),
            dataset_path,
            notes,
            now,
            now,
        ))
def delete_agent_by_id(agent_id: int) -> None:
    """
    Remove the agent row with the given ID from the ``agents`` table.

    Args:
        agent_id: The unique agent ID to delete.
    """
    statement = "DELETE FROM agents WHERE id = ?"
    params = (agent_id,)
    with get_db_conn() as conn:
        conn.execute(statement, params)
def get_agent_by_id(agent_id: int) -> dict | None:
    """
    Load one agent's full metadata (including config and notes) by ID.

    Args:
        agent_id: The unique agent ID.

    Returns:
        A dictionary of agent metadata if found, otherwise None.
    """
    query = """
        SELECT id, name, game_id, created_at, last_updated, agent_type,
               framework, is_imitation, dataset_path, config_json, notes
        FROM agents
        WHERE id = ?
    """
    with get_db_conn() as conn:
        record = conn.execute(query, (agent_id,)).fetchone()
        return dict(record) if record else None
from datetime import datetime
from agentm.logic.db import get_db_conn
def insert_run(
    agent_id: int,
    name: str,
    config_yaml: str,
    notes: str = ""
) -> None:
    """
    Insert a new run associated with an agent.

    ``created_at`` and ``updated_at`` are both set to the same UTC timestamp
    (naive ISO-8601 string).

    Args:
        agent_id: The agent this run belongs to.
        name: The name of the run.
        config_yaml: Serialized YAML config (frozen) for this run.
        notes: Optional notes about the run.
    """
    # Take the timestamp once so both columns hold exactly the same value;
    # two utcnow() calls would differ by a few microseconds.
    now = datetime.utcnow().isoformat()
    with get_db_conn() as conn:
        conn.execute("""
            INSERT INTO runs (
                agent_id,
                name,
                config_yaml,
                notes,
                created_at,
                updated_at
            ) VALUES (?, ?, ?, ?, ?, ?)
        """, (
            agent_id,
            name,
            config_yaml,
            notes,
            now,
            now,
        ))
def get_runs_for_agent(agent_id: int) -> list[dict]:
    """
    List every run that belongs to an agent, newest first.

    Args:
        agent_id: The ID of the agent.

    Returns:
        A list of run dictionaries ordered by ``created_at`` descending.
    """
    query = """
        SELECT * FROM runs
        WHERE agent_id = ?
        ORDER BY created_at DESC
    """
    with get_db_conn() as conn:
        records = conn.execute(query, (agent_id,)).fetchall()
        return list(map(dict, records))
def get_run_by_id(run_id: int) -> dict | None:
    """
    Load a single run row by its ID.

    Args:
        run_id: The ID of the run.

    Returns:
        A run dictionary or None if not found.
    """
    query = """
        SELECT * FROM runs
        WHERE id = ?
    """
    with get_db_conn() as conn:
        record = conn.execute(query, (run_id,)).fetchone()
        if record:
            return dict(record)
        return None
def delete_run_by_id(run_id: int) -> None:
    """
    Remove the run row with the given ID from the ``runs`` table.

    Args:
        run_id: The ID of the run to delete.
    """
    statement = """
        DELETE FROM runs
        WHERE id = ?
    """
    with get_db_conn() as conn:
        conn.execute(statement, (run_id,))
from datetime import datetime
from agentm.logic.db import get_db_conn
def insert_model(
    agent_id: int,
    name: str,
    total_steps_planned: int,
    notes: str = "",
    parent_model_id: int | None = None,
    average_reward: float | None = None,
    current_learning_rate: float | None = None,
    current_clip_range: float | None = None,
    config_patch_yaml: str = "",
    checkpoint_path: str = "",
    is_frozen: bool = False,
    status: str = "pending"
) -> None:
    """
    Insert a new model entry.

    The row starts with ``total_steps_completed = 0``. ``created_at`` and
    ``updated_at`` are both set to the same UTC timestamp (naive ISO-8601
    string).

    Args:
        agent_id: ID of the agent the model belongs to.
        name: Model name.
        total_steps_planned: Value stored in ``total_steps_planned``.
        notes: Free-form notes.
        parent_model_id: ID stored in ``parent_model_id``, or None.
        average_reward: Value for ``average_reward``, or None.
        current_learning_rate: Value for ``current_learning_rate``, or None.
        current_clip_range: Value for ``current_clip_range``, or None.
        config_patch_yaml: Text stored in ``config_patch_yaml``.
        checkpoint_path: Text stored in ``checkpoint_path``.
        is_frozen: Flag stored as 0/1 in ``is_frozen``.
        status: Text stored in ``status``.
    """
    # Take the timestamp once so created_at and updated_at match exactly;
    # two utcnow() calls would differ by a few microseconds.
    now = datetime.utcnow().isoformat()
    with get_db_conn() as conn:
        conn.execute("""
            INSERT INTO models (
                agent_id, parent_model_id, name, notes,
                total_steps_planned, total_steps_completed, average_reward,
                current_learning_rate, current_clip_range,
                config_patch_yaml, checkpoint_path, is_frozen,
                status, created_at, updated_at
            ) VALUES (?, ?, ?, ?, ?, 0, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """, (
            agent_id, parent_model_id, name, notes,
            total_steps_planned, average_reward,
            current_learning_rate, current_clip_range,
            config_patch_yaml, checkpoint_path, int(is_frozen),
            status, now, now,
        ))
def get_models_for_agent(agent_id: int, include_children: bool = True) -> list[dict]:
    """
    List every model row that belongs to an agent, newest first.

    Args:
        agent_id: The ID of the agent.
        include_children: Accepted for API compatibility but not used by the
            query; all of the agent's models are returned regardless of its
            value.

    Returns:
        A list of model dictionaries ordered by ``created_at`` descending.
    """
    query = """
        SELECT * FROM models
        WHERE agent_id = ?
        ORDER BY created_at DESC
    """
    with get_db_conn() as conn:
        records = conn.execute(query, (agent_id,)).fetchall()
        return list(map(dict, records))
def get_model_by_id(model_id: int) -> dict | None:
    """
    Load a single model row by its ID.

    Args:
        model_id: The ID of the model.

    Returns:
        The model as a dictionary, or None if no row has that ID.
    """
    query = "SELECT * FROM models WHERE id = ?"
    with get_db_conn() as conn:
        record = conn.execute(query, (model_id,)).fetchone()
        if record:
            return dict(record)
        return None
def update_model_progress(
    model_id: int,
    total_steps_completed: int,
    average_reward: float | None = None,
    current_learning_rate: float | None = None,
    current_clip_range: float | None = None,
    checkpoint_path: str | None = None,
    status: str | None = None
):
    """
    Record training progress for a model.

    ``total_steps_completed`` and ``updated_at`` are always written; each of
    the optional arguments is written only when it is not None, so passing
    None leaves the stored value untouched.

    Args:
        model_id: ID of the model row to update.
        total_steps_completed: New value for ``total_steps_completed``.
        average_reward: New ``average_reward``, or None to keep the old one.
        current_learning_rate: New ``current_learning_rate``, or None.
        current_clip_range: New ``current_clip_range``, or None.
        checkpoint_path: New ``checkpoint_path``, or None.
        status: New ``status``, or None.
    """
    # Insertion order fixes the order of the SET clause and its parameters.
    optional_columns = {
        "average_reward": average_reward,
        "current_learning_rate": current_learning_rate,
        "current_clip_range": current_clip_range,
        "checkpoint_path": checkpoint_path,
        "status": status,
    }
    supplied = {
        column: value
        for column, value in optional_columns.items()
        if value is not None
    }

    # Column names come only from the literal mapping above, never from
    # caller input, so interpolating them into the statement is safe.
    assignments = ["total_steps_completed = ?"]
    assignments.extend(f"{column} = ?" for column in supplied)
    assignments.append("updated_at = ?")

    params = (
        total_steps_completed,
        *supplied.values(),
        datetime.utcnow().isoformat(),
        model_id,
    )
    with get_db_conn() as conn:
        conn.execute(f"""
            UPDATE models SET {', '.join(assignments)}
            WHERE id = ?
        """, params)
def delete_model_by_id(model_id: int) -> None:
    """
    Remove the model row with the given ID from the ``models`` table.

    Args:
        model_id: The ID of the model to delete.
    """
    statement = "DELETE FROM models WHERE id = ?"
    params = (model_id,)
    with get_db_conn() as conn:
        conn.execute(statement, params)
def get_models_for_run(run_id: int) -> list[dict]:
    """
    Return the models reachable from a run.

    Models are linked to agents rather than to runs, so the run is first
    resolved to its ``agent_id`` and then every model of that agent is
    returned (newest first).

    Args:
        run_id: The ID of the run.

    Returns:
        A list of model dicts; empty if the run does not exist.
    """
    with get_db_conn() as conn:
        # Step 1: resolve the run to the agent that owns it.
        run_row = conn.execute(
            "SELECT agent_id FROM runs WHERE id = ?", (run_id,)
        ).fetchone()
        if not run_row:
            return []
        owner_id = run_row["agent_id"]

        # Step 2: fetch every model of that agent.
        model_rows = conn.execute("""
            SELECT * FROM models
            WHERE agent_id = ?
            ORDER BY created_at DESC
        """, (owner_id,)).fetchall()
        return list(map(dict, model_rows))
def update_run(
    run_id: int,
    name: str | None = None,
    config_yaml: str | None = None,
    notes: str | None = None,
    pending: bool | None = None
) -> None:
    """
    Update any combination of a run's editable fields.

    Arguments left as None are not touched. If every optional argument is
    None the function returns without issuing a statement; otherwise
    ``updated_at`` is refreshed along with the supplied fields.

    Args:
        run_id: ID of the run row to update.
        name: New run name, or None to keep the current one.
        config_yaml: New YAML config text, or None.
        notes: New notes, or None.
        pending: New pending flag (stored as 0/1), or None.
    """
    # Insertion order fixes the order of the SET clause and its parameters.
    candidates = {
        "name": name,
        "config_yaml": config_yaml,
        "notes": notes,
        "pending": None if pending is None else int(pending),
    }
    changes = {
        column: value
        for column, value in candidates.items()
        if value is not None
    }
    if not changes:
        return  # Nothing to update

    # Column names come only from the literal mapping above, never from
    # caller input, so interpolating them into the statement is safe.
    assignments = [f"{column} = ?" for column in changes]
    assignments.append("updated_at = ?")
    params = (*changes.values(), datetime.utcnow().isoformat(), run_id)

    with get_db_conn() as conn:
        conn.execute(
            f"UPDATE runs SET {', '.join(assignments)} WHERE id = ?",
            params
        )