"""Database access helpers for ROMs, agents, runs, and models.

Every public function opens its own connection via ``get_db_conn()`` and runs
parameterized SQL against the ``roms``, ``agents``, ``runs`` and ``models``
tables. Timestamps are stored as naive ISO-8601 strings in UTC.
"""

import json
from datetime import datetime, timezone

from agentm.logic.db import get_db_conn


def _utc_now_iso() -> str:
    """Return the current UTC time as a naive ISO-8601 string.

    Produces exactly the same text as the deprecated
    ``datetime.utcnow().isoformat()`` (no ``+00:00`` suffix), so new values
    remain comparable/sortable alongside rows written by older code.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()


def _json_list(raw: str | None) -> list:
    """Decode a JSON-encoded list column, treating NULL/empty text as ``[]``."""
    return json.loads(raw) if raw else []


# ---------------------------------------------------------------------------
# ROMs
# ---------------------------------------------------------------------------

def get_cached_rom(rom_file: str) -> dict | None:
    """
    Retrieve verified ROM metadata from the database by ROM filename.

    Args:
        rom_file: The filename of the ROM (e.g., 'sfiii3n.zip').

    Returns:
        A dictionary of ROM metadata if a verified row exists, otherwise None.
        ``characters`` and ``keywords`` are decoded from JSON (``[]`` if empty).
    """
    with get_db_conn() as conn:
        cur = conn.execute(
            """
            SELECT sha256, verified_at, title, game_id,
                   difficulty_min, difficulty_max, characters, keywords
            FROM roms
            WHERE rom_file = ? AND verified = 1
            """,
            (rom_file,),
        )
        row = cur.fetchone()
        if row:
            return {
                "sha256": row[0],
                "verified": True,
                "verified_at": row[1],
                "title": row[2],
                "rom_file": rom_file,
                "game_id": row[3],
                "difficulty_min": row[4],
                "difficulty_max": row[5],
                "characters": _json_list(row[6]),
                "keywords": _json_list(row[7]),
            }
    return None


def get_all_verified_roms() -> list[dict]:
    """
    Return all verified ROMs, ordered by title (ascending).

    Returns:
        A list of dictionaries containing ROM metadata; ``characters`` and
        ``keywords`` are decoded from JSON (``[]`` if empty).
    """
    with get_db_conn() as conn:
        cur = conn.execute(
            """
            SELECT sha256, verified_at, title, game_id, rom_file,
                   difficulty_min, difficulty_max, characters, keywords
            FROM roms
            WHERE verified = 1
            ORDER BY title ASC
            """
        )
        rows = cur.fetchall()
    return [
        {
            "sha256": row[0],
            "verified": True,
            "verified_at": row[1],
            "title": row[2],
            "game_id": row[3],
            "rom_file": row[4],
            "difficulty_min": row[5],
            "difficulty_max": row[6],
            "characters": _json_list(row[7]),
            "keywords": _json_list(row[8]),
        }
        for row in rows
    ]


def upsert_rom_record(
    title: str,
    rom_file: str,
    game_id: str,
    sha256: str,
    difficulty_min: int | None = None,
    difficulty_max: int | None = None,
    characters: list[str] | None = None,
    keywords: list[str] | None = None,
) -> None:
    """
    Insert or replace a ROM entry, marking it verified with the current time.

    Args:
        title: Game title.
        rom_file: ROM file name.
        game_id: Game ID used by DIAMBRA.
        sha256: SHA256 checksum of the ROM.
        difficulty_min: Minimum difficulty.
        difficulty_max: Maximum difficulty.
        characters: List of characters (stored as JSON; None -> []).
        keywords: List of keywords (stored as JSON; None -> []).
    """
    with get_db_conn() as conn:
        conn.execute(
            """
            INSERT OR REPLACE INTO roms (
                title, rom_file, game_id, sha256,
                difficulty_min, difficulty_max,
                characters, keywords, verified, verified_at
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 1, ?)
            """,
            (
                title,
                rom_file,
                game_id,
                sha256,
                difficulty_min,
                difficulty_max,
                json.dumps(characters or []),
                json.dumps(keywords or []),
                _utc_now_iso(),
            ),
        )


# ---------------------------------------------------------------------------
# Agents
# ---------------------------------------------------------------------------

def get_agents_for_game(game_id: str) -> list[dict]:
    """
    Retrieve all agents associated with a given game_id, newest first.

    Args:
        game_id: The unique game ID (e.g., 'sfiii3n').

    Returns:
        A list of agent metadata dictionaries.
    """
    with get_db_conn() as conn:
        cur = conn.execute(
            """
            SELECT id, name, game_id, created_at, last_updated,
                   agent_type, framework, is_imitation
            FROM agents
            WHERE game_id = ?
            ORDER BY created_at DESC
            """,
            (game_id,),
        )
        return [dict(row) for row in cur.fetchall()]


def insert_agent(
    name: str,
    game_id: str,
    config_json: str,
    agent_type: str = "PPO",
    framework: str = "SB3",
    is_imitation: bool = False,
    dataset_path: str = "",
    notes: str = "",
) -> None:
    """
    Insert a new agent into the database.

    Args:
        name: Agent name.
        game_id: Game ID this agent is associated with.
        config_json: JSON string of the training config.
        agent_type: Algorithm type (e.g., PPO, BC).
        framework: Framework name (e.g., SB3, Ray).
        is_imitation: Whether this is imitation learning (stored as 0/1).
        dataset_path: Optional path to dataset (for IL).
        notes: Developer notes or comments.
    """
    # One timestamp so created_at and last_updated are identical on insert.
    now = _utc_now_iso()
    with get_db_conn() as conn:
        conn.execute(
            """
            INSERT INTO agents (
                name, game_id, config_json, agent_type, framework,
                is_imitation, dataset_path, notes, created_at, last_updated
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                name,
                game_id,
                config_json,
                agent_type,
                framework,
                int(is_imitation),
                dataset_path,
                notes,
                now,
                now,
            ),
        )


def delete_agent_by_id(agent_id: int) -> None:
    """
    Permanently delete an agent row from the database by ID.

    Only the ``agents`` row is deleted here; whether related runs/models are
    removed depends on the schema's foreign-key settings (not visible here).

    Args:
        agent_id: The unique agent ID to delete.
    """
    with get_db_conn() as conn:
        conn.execute("DELETE FROM agents WHERE id = ?", (agent_id,))


def get_agent_by_id(agent_id: int) -> dict | None:
    """
    Fetch a single agent's metadata from the database by ID.

    Args:
        agent_id: The unique agent ID.

    Returns:
        A dictionary of agent metadata if found, otherwise None.
    """
    with get_db_conn() as conn:
        cur = conn.execute(
            """
            SELECT id, name, game_id, created_at, last_updated,
                   agent_type, framework, is_imitation,
                   dataset_path, config_json, notes
            FROM agents
            WHERE id = ?
            """,
            (agent_id,),
        )
        row = cur.fetchone()
        return dict(row) if row else None


# ---------------------------------------------------------------------------
# Runs
# ---------------------------------------------------------------------------

def insert_run(
    agent_id: int,
    name: str,
    config_yaml: str,
    notes: str = "",
) -> None:
    """
    Insert a new run associated with an agent.

    Args:
        agent_id: The agent this run belongs to.
        name: The name of the run.
        config_yaml: Serialized YAML config (frozen) for this run.
        notes: Optional notes about the run.
    """
    # One timestamp so created_at and updated_at are identical on insert.
    now = _utc_now_iso()
    with get_db_conn() as conn:
        conn.execute(
            """
            INSERT INTO runs (
                agent_id, name, config_yaml, notes, created_at, updated_at
            ) VALUES (?, ?, ?, ?, ?, ?)
            """,
            (agent_id, name, config_yaml, notes, now, now),
        )


def get_runs_for_agent(agent_id: int) -> list[dict]:
    """
    Get all runs associated with a specific agent, newest first.

    Args:
        agent_id: The ID of the agent.

    Returns:
        A list of run dictionaries (all columns).
    """
    with get_db_conn() as conn:
        cur = conn.execute(
            """
            SELECT * FROM runs
            WHERE agent_id = ?
            ORDER BY created_at DESC
            """,
            (agent_id,),
        )
        return [dict(row) for row in cur.fetchall()]


def get_run_by_id(run_id: int) -> dict | None:
    """
    Get a specific run by its ID.

    Args:
        run_id: The ID of the run.

    Returns:
        A run dictionary (all columns) or None if not found.
    """
    with get_db_conn() as conn:
        cur = conn.execute("SELECT * FROM runs WHERE id = ?", (run_id,))
        row = cur.fetchone()
        return dict(row) if row else None


def delete_run_by_id(run_id: int) -> None:
    """
    Delete a run permanently from the database.

    Args:
        run_id: The ID of the run to delete.
    """
    with get_db_conn() as conn:
        conn.execute("DELETE FROM runs WHERE id = ?", (run_id,))


def update_run(
    run_id: int,
    name: str | None = None,
    config_yaml: str | None = None,
    notes: str | None = None,
    pending: bool | None = None,
) -> None:
    """
    Update a run by ID with any combination of fields.

    Only arguments that are not None are written. If no field is supplied the
    call is a no-op (``updated_at`` is not touched either).

    Args:
        run_id: The ID of the run to update.
        name: New run name.
        config_yaml: New serialized YAML config.
        notes: New notes text.
        pending: New pending flag (stored as 0/1).
    """
    # Column names are fixed literals below, never caller input, so building
    # the SET clause with an f-string is safe; values stay parameterized.
    candidates = (
        ("name", name),
        ("config_yaml", config_yaml),
        ("notes", notes),
        ("pending", None if pending is None else int(pending)),
    )
    fields = [f"{column} = ?" for column, value in candidates if value is not None]
    values = [value for _, value in candidates if value is not None]

    if not fields:
        return  # Nothing to update

    fields.append("updated_at = ?")
    values.append(_utc_now_iso())
    values.append(run_id)

    with get_db_conn() as conn:
        conn.execute(
            f"UPDATE runs SET {', '.join(fields)} WHERE id = ?",
            tuple(values),
        )


def update_run_pending(run_id: int, pending: bool) -> None:
    """
    Update the pending status of a run (and its ``updated_at`` timestamp).

    Args:
        run_id: The ID of the run to update.
        pending: New pending status (True/False).
    """
    # Equivalent to: UPDATE runs SET pending = ?, updated_at = ? WHERE id = ?
    update_run(run_id, pending=pending)


# ---------------------------------------------------------------------------
# Models
# ---------------------------------------------------------------------------

def insert_model(
    agent_id: int,
    name: str,
    total_steps_planned: int,
    notes: str = "",
    parent_model_id: int | None = None,
    average_reward: float | None = None,
    current_learning_rate: float | None = None,
    current_clip_range: float | None = None,
    config_patch_yaml: str = "",
    checkpoint_path: str = "",
    is_frozen: bool = False,
    status: str = "pending",
) -> None:
    """
    Insert a new model entry with ``total_steps_completed`` initialised to 0.

    Args:
        agent_id: The agent this model belongs to.
        name: Model name.
        total_steps_planned: Number of training steps planned.
        notes: Optional notes.
        parent_model_id: Optional ID of the model this one derives from.
        average_reward: Optional initial average reward.
        current_learning_rate: Optional initial learning rate.
        current_clip_range: Optional initial clip range.
        config_patch_yaml: YAML config patch text.
        checkpoint_path: Path to the model checkpoint.
        is_frozen: Whether the model is frozen (stored as 0/1).
        status: Initial status string.
    """
    # One timestamp so created_at and updated_at are identical on insert.
    now = _utc_now_iso()
    with get_db_conn() as conn:
        conn.execute(
            """
            INSERT INTO models (
                agent_id, parent_model_id, name, notes,
                total_steps_planned, total_steps_completed,
                average_reward, current_learning_rate, current_clip_range,
                config_patch_yaml, checkpoint_path, is_frozen, status,
                created_at, updated_at
            ) VALUES (?, ?, ?, ?, ?, 0, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                agent_id,
                parent_model_id,
                name,
                notes,
                total_steps_planned,
                average_reward,
                current_learning_rate,
                current_clip_range,
                config_patch_yaml,
                checkpoint_path,
                int(is_frozen),
                status,
                now,
                now,
            ),
        )


def get_models_for_agent(agent_id: int, include_children: bool = True) -> list[dict]:
    """
    Fetch all models for a given agent, newest first.

    Args:
        agent_id: The ID of the agent.
        include_children: Currently ignored; every model row for the agent is
            returned regardless. Kept for API compatibility.

    Returns:
        A list of model dictionaries (all columns).
    """
    with get_db_conn() as conn:
        cur = conn.execute(
            """
            SELECT * FROM models
            WHERE agent_id = ?
            ORDER BY created_at DESC
            """,
            (agent_id,),
        )
        return [dict(row) for row in cur.fetchall()]


def get_model_by_id(model_id: int) -> dict | None:
    """
    Fetch a single model by ID.

    Args:
        model_id: The ID of the model.

    Returns:
        A model dictionary (all columns) or None if not found.
    """
    with get_db_conn() as conn:
        cur = conn.execute("SELECT * FROM models WHERE id = ?", (model_id,))
        row = cur.fetchone()
        return dict(row) if row else None


def update_model_progress(
    model_id: int,
    total_steps_completed: int,
    average_reward: float | None = None,
    current_learning_rate: float | None = None,
    current_clip_range: float | None = None,
    checkpoint_path: str | None = None,
    status: str | None = None,
) -> None:
    """
    Update a model's step progress and, optionally, its training stats.

    ``total_steps_completed`` and ``updated_at`` are always written; every
    other argument is written only when it is not None.

    Args:
        model_id: The ID of the model to update.
        total_steps_completed: Steps completed so far.
        average_reward: New average reward.
        current_learning_rate: New learning rate.
        current_clip_range: New clip range.
        checkpoint_path: New checkpoint path.
        status: New status string.
    """
    fields = ["total_steps_completed = ?"]
    values: list = [total_steps_completed]

    # Column names are fixed literals, never caller input, so interpolating
    # them into the SET clause is safe; values stay parameterized.
    optional = (
        ("average_reward", average_reward),
        ("current_learning_rate", current_learning_rate),
        ("current_clip_range", current_clip_range),
        ("checkpoint_path", checkpoint_path),
        ("status", status),
    )
    for column, value in optional:
        if value is not None:
            fields.append(f"{column} = ?")
            values.append(value)

    fields.append("updated_at = ?")
    values.append(_utc_now_iso())
    values.append(model_id)

    with get_db_conn() as conn:
        conn.execute(
            f"UPDATE models SET {', '.join(fields)} WHERE id = ?",
            tuple(values),
        )


def delete_model_by_id(model_id: int) -> None:
    """
    Delete a model record.

    Args:
        model_id: The ID of the model to delete.
    """
    with get_db_conn() as conn:
        conn.execute("DELETE FROM models WHERE id = ?", (model_id,))


def get_models_for_run(run_id: int) -> list[dict]:
    """
    Retrieve all models belonging to the agent that owns a given run.

    Models are linked to agents (not runs) in the schema, so this resolves the
    run's ``agent_id`` and returns every model of that agent, newest first.

    Args:
        run_id: The ID of the run.

    Returns:
        A list of model dicts; empty if the run does not exist.
    """
    with get_db_conn() as conn:
        # If the run is missing the subquery yields NULL, and "agent_id = NULL"
        # matches nothing, so an unknown run returns [].
        cur = conn.execute(
            """
            SELECT * FROM models
            WHERE agent_id = (SELECT agent_id FROM runs WHERE id = ?)
            ORDER BY created_at DESC
            """,
            (run_id,),
        )
        return [dict(row) for row in cur.fetchall()]