512 lines
14 KiB
Python
512 lines
14 KiB
Python
import json
|
|
from datetime import datetime
|
|
from agentm.logic.db import get_db_conn
|
|
|
|
|
|
def get_cached_rom(rom_file: str) -> dict | None:
    """
    Look up a single verified ROM by its filename.

    Args:
        rom_file: The filename of the ROM (e.g., 'sfiii3n.zip').

    Returns:
        A metadata dictionary when a row with this ``rom_file`` exists and is
        marked ``verified = 1``; otherwise None. The ``characters`` and
        ``keywords`` entries are decoded from JSON text, defaulting to an
        empty list when the stored value is empty or NULL.
    """
    with get_db_conn() as conn:
        row = conn.execute(
            """
            SELECT sha256, verified_at, title, game_id,
                   difficulty_min, difficulty_max, characters, keywords
            FROM roms WHERE rom_file = ? AND verified = 1
            """,
            (rom_file,),
        ).fetchone()

        if not row:
            return None

        # Column order mirrors the SELECT list above.
        (sha256, verified_at, title, game_id,
         difficulty_min, difficulty_max, characters, keywords) = row

        return {
            "sha256": sha256,
            "verified": True,  # the WHERE clause only admits verified rows
            "verified_at": verified_at,
            "title": title,
            "rom_file": rom_file,
            "game_id": game_id,
            "difficulty_min": difficulty_min,
            "difficulty_max": difficulty_max,
            "characters": json.loads(characters) if characters else [],
            "keywords": json.loads(keywords) if keywords else [],
        }
|
|
|
|
|
|
def get_all_verified_roms() -> list[dict]:
    """
    List every ROM marked as verified, sorted alphabetically by title.

    Returns:
        One metadata dictionary per verified row, in ``title ASC`` order.
        The JSON list columns (``characters``, ``keywords``) are decoded, with
        an empty list substituted for empty/NULL values.
    """
    def _json_list(raw):
        # Empty/NULL columns become [] instead of being passed to json.loads.
        return json.loads(raw) if raw else []

    with get_db_conn() as conn:
        rows = conn.execute(
            """
            SELECT sha256, verified_at, title, game_id, rom_file,
                   difficulty_min, difficulty_max, characters, keywords
            FROM roms WHERE verified = 1
            ORDER BY title ASC
            """
        ).fetchall()

    verified_roms = []
    for (sha256, verified_at, title, game_id, rom_file,
         difficulty_min, difficulty_max, characters, keywords) in rows:
        verified_roms.append({
            "sha256": sha256,
            "verified": True,  # only verified rows are selected
            "verified_at": verified_at,
            "title": title,
            "game_id": game_id,
            "rom_file": rom_file,
            "difficulty_min": difficulty_min,
            "difficulty_max": difficulty_max,
            "characters": _json_list(characters),
            "keywords": _json_list(keywords),
        })
    return verified_roms
|
|
|
|
|
|
def upsert_rom_record(
    title: str,
    rom_file: str,
    game_id: str,
    sha256: str,
    difficulty_min: int | None = None,
    difficulty_max: int | None = None,
    characters: list[str] | None = None,
    keywords: list[str] | None = None,
) -> None:
    """
    Insert or replace a ROM row and mark it as verified.

    The row is written with ``verified = 1`` and ``verified_at`` set to the
    current UTC time as a naive ISO-8601 string. ``characters`` and
    ``keywords`` are stored as JSON arrays; ``None`` is stored as ``"[]"``.

    Uses SQLite ``INSERT OR REPLACE``: if the new row violates a PRIMARY KEY or
    UNIQUE constraint, the conflicting existing row is deleted before the
    insert. Which columns carry such constraints is defined by the ``roms``
    schema (not visible here) — presumably ``rom_file``; confirm.

    Args:
        title: Game title.
        rom_file: ROM file name.
        game_id: Game ID used by DIAMBRA.
        sha256: SHA256 checksum of the ROM.
        difficulty_min: Minimum difficulty, or None if unknown.
        difficulty_max: Maximum difficulty, or None if unknown.
        characters: List of characters, or None for an empty list.
        keywords: List of keywords, or None for an empty list.
    """
    with get_db_conn() as conn:
        conn.execute(
            """
            INSERT OR REPLACE INTO roms (
                title,
                rom_file,
                game_id,
                sha256,
                difficulty_min,
                difficulty_max,
                characters,
                keywords,
                verified,
                verified_at
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, 1, ?)
            """,
            (
                title,
                rom_file,
                game_id,
                sha256,
                difficulty_min,
                difficulty_max,
                json.dumps(characters or []),
                json.dumps(keywords or []),
                datetime.utcnow().isoformat(),
            ),
        )
|
|
|
|
def get_agents_for_game(game_id: str) -> list[dict]:
    """
    Return summary rows for every agent registered against ``game_id``.

    Args:
        game_id: The unique game ID (e.g., 'sfiii3n').

    Returns:
        A list of dicts, newest ``created_at`` first, holding id, name,
        game_id, created_at, last_updated, agent_type, framework and
        is_imitation. ``config_json``, ``dataset_path`` and ``notes`` are not
        selected. (Relies on ``get_db_conn`` yielding rows that ``dict()`` can
        convert — presumably ``sqlite3.Row``.)
    """
    sql = """
        SELECT id, name, game_id, created_at, last_updated, agent_type, framework, is_imitation
        FROM agents
        WHERE game_id = ?
        ORDER BY created_at DESC
    """
    with get_db_conn() as conn:
        rows = conn.execute(sql, (game_id,)).fetchall()
        return list(map(dict, rows))
|
|
|
|
|
|
def insert_agent(
    name: str,
    game_id: str,
    config_json: str,
    agent_type: str = "PPO",
    framework: str = "SB3",
    is_imitation: bool = False,
    dataset_path: str = "",
    notes: str = ""
) -> None:
    """
    Insert a new agent into the database.

    ``is_imitation`` is stored as the integer 0/1. ``created_at`` and
    ``last_updated`` are both set to the same naive UTC ISO-8601 timestamp.

    Args:
        name: Agent name.
        game_id: Game ID this agent is associated with.
        config_json: JSON string of the training config.
        agent_type: Algorithm type (e.g., PPO, BC).
        framework: Framework name (e.g., SB3, Ray).
        is_imitation: Whether this is imitation learning.
        dataset_path: Optional path to dataset (for IL).
        notes: Developer notes or comments.
    """
    # Take a single timestamp so a freshly inserted agent has
    # created_at == last_updated; two separate utcnow() calls can differ.
    now = datetime.utcnow().isoformat()
    with get_db_conn() as conn:
        conn.execute(
            """
            INSERT INTO agents (
                name,
                game_id,
                config_json,
                agent_type,
                framework,
                is_imitation,
                dataset_path,
                notes,
                created_at,
                last_updated
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                name,
                game_id,
                config_json,
                agent_type,
                framework,
                int(is_imitation),
                dataset_path,
                notes,
                now,
                now,
            ),
        )
|
|
|
|
def delete_agent_by_id(agent_id: int) -> None:
    """
    Remove the agent row whose ``id`` column equals ``agent_id``.

    Deleting an ID that does not exist affects no rows and raises nothing.
    Whether dependent rows (runs, models) are also removed depends on the
    schema's foreign-key configuration, which is not visible here.

    Args:
        agent_id: The unique agent ID to delete.
    """
    params = (agent_id,)
    with get_db_conn() as conn:
        conn.execute("DELETE FROM agents WHERE id = ?", params)
|
|
|
|
|
|
def get_agent_by_id(agent_id: int) -> dict | None:
    """
    Load the full metadata of one agent.

    Unlike ``get_agents_for_game``, this also selects ``dataset_path``,
    ``config_json`` and ``notes``.

    Args:
        agent_id: The unique agent ID.

    Returns:
        A dict of the selected columns, or None when no agent has that ID.
    """
    with get_db_conn() as conn:
        found = conn.execute(
            """
            SELECT id, name, game_id, created_at, last_updated, agent_type,
                   framework, is_imitation, dataset_path, config_json, notes
            FROM agents
            WHERE id = ?
            """,
            (agent_id,),
        ).fetchone()
        return dict(found) if found else None
|
|
|
|
from datetime import datetime
|
|
from agentm.logic.db import get_db_conn
|
|
|
|
|
|
def insert_run(
    agent_id: int,
    name: str,
    config_yaml: str,
    notes: str = ""
) -> None:
    """
    Insert a new run associated with an agent.

    ``created_at`` and ``updated_at`` are both set to the same naive UTC
    ISO-8601 timestamp. Columns not listed here (e.g. ``pending``) take
    their schema defaults.

    Args:
        agent_id: The agent this run belongs to.
        name: The name of the run.
        config_yaml: Serialized YAML config (frozen) for this run.
        notes: Optional notes about the run.
    """
    # One timestamp for both columns so a new run has created_at == updated_at.
    now = datetime.utcnow().isoformat()
    with get_db_conn() as conn:
        conn.execute(
            """
            INSERT INTO runs (
                agent_id,
                name,
                config_yaml,
                notes,
                created_at,
                updated_at
            ) VALUES (?, ?, ?, ?, ?, ?)
            """,
            (agent_id, name, config_yaml, notes, now, now),
        )
|
|
|
|
|
|
def get_runs_for_agent(agent_id: int) -> list[dict]:
    """
    Fetch every run recorded for ``agent_id``, newest first.

    Args:
        agent_id: The ID of the agent.

    Returns:
        A list of run dicts containing all ``runs`` columns, ordered by
        ``created_at`` descending. Empty if the agent has no runs.
    """
    with get_db_conn() as conn:
        cursor = conn.execute(
            "SELECT * FROM runs WHERE agent_id = ? ORDER BY created_at DESC",
            (agent_id,),
        )
        runs = [dict(record) for record in cursor]
    return runs
|
|
|
|
|
|
def get_run_by_id(run_id: int) -> dict | None:
    """
    Look up one run by its ID.

    Args:
        run_id: The ID of the run.

    Returns:
        A dict of all ``runs`` columns for that run, or None if it does not exist.
    """
    with get_db_conn() as conn:
        match = conn.execute("SELECT * FROM runs WHERE id = ?", (run_id,)).fetchone()
        if not match:
            return None
        return dict(match)
|
|
|
|
|
|
def delete_run_by_id(run_id: int) -> None:
    """
    Remove the run whose ``id`` equals ``run_id``.

    A missing ID affects no rows and is not reported as an error.

    Args:
        run_id: The ID of the run to delete.
    """
    statement = "DELETE FROM runs WHERE id = ?"
    with get_db_conn() as conn:
        conn.execute(statement, (run_id,))
|
|
|
|
from datetime import datetime
|
|
from agentm.logic.db import get_db_conn
|
|
|
|
def insert_model(
    agent_id: int,
    name: str,
    total_steps_planned: int,
    notes: str = "",
    parent_model_id: int | None = None,
    average_reward: float | None = None,
    current_learning_rate: float | None = None,
    current_clip_range: float | None = None,
    config_patch_yaml: str = "",
    checkpoint_path: str = "",
    is_frozen: bool = False,
    status: str = "pending"
) -> None:
    """
    Insert a new model entry for an agent.

    ``total_steps_completed`` always starts at 0, and ``is_frozen`` is stored
    as the integer 0/1. ``created_at`` and ``updated_at`` receive the same
    naive UTC ISO-8601 timestamp.

    Args:
        agent_id: ID of the agent that owns the model.
        name: Model name.
        total_steps_planned: Number of training steps planned for this model.
        notes: Free-form notes.
        parent_model_id: ID of the model this one derives from, if any.
        average_reward: Initial average reward, if known.
        current_learning_rate: Initial learning rate, if known.
        current_clip_range: Initial clip range, if known.
        config_patch_yaml: YAML text stored with the model (presumably a
            patch applied over the run/agent config — confirm with callers).
        checkpoint_path: Checkpoint path, if any.
        is_frozen: Whether the model is marked frozen.
        status: Initial status string (default "pending").
    """
    # One timestamp for both columns so a new model has created_at == updated_at.
    now = datetime.utcnow().isoformat()
    with get_db_conn() as conn:
        conn.execute(
            """
            INSERT INTO models (
                agent_id, parent_model_id, name, notes,
                total_steps_planned, total_steps_completed, average_reward,
                current_learning_rate, current_clip_range,
                config_patch_yaml, checkpoint_path, is_frozen,
                status, created_at, updated_at
            ) VALUES (?, ?, ?, ?, ?, 0, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                agent_id, parent_model_id, name, notes,
                total_steps_planned, average_reward,
                current_learning_rate, current_clip_range,
                config_patch_yaml, checkpoint_path, int(is_frozen),
                status, now, now,
            ),
        )
|
|
|
|
def get_models_for_agent(agent_id: int, include_children: bool = True) -> list[dict]:
    """
    Fetch models for a given agent, newest ``created_at`` first.

    Args:
        agent_id: ID of the agent whose models to return.
        include_children: When True (default), return every model of the
            agent. When False, return only root models, i.e. those whose
            ``parent_model_id`` is NULL. Previously this flag was accepted
            but ignored.

    Returns:
        A list of dicts containing all ``models`` columns.
    """
    query = "SELECT * FROM models WHERE agent_id = ?"
    if not include_children:
        # A "child" model is one derived from another (non-NULL parent_model_id).
        query += " AND parent_model_id IS NULL"
    query += " ORDER BY created_at DESC"

    with get_db_conn() as conn:
        cur = conn.execute(query, (agent_id,))
        return [dict(row) for row in cur.fetchall()]
|
|
|
|
def get_model_by_id(model_id: int) -> dict | None:
    """
    Fetch a single model by ID.

    Args:
        model_id: The ID of the model.

    Returns:
        A dict of all ``models`` columns, or None when no such model exists.
    """
    with get_db_conn() as conn:
        hit = conn.execute(
            "SELECT * FROM models WHERE id = ?", (model_id,)
        ).fetchone()
        if hit:
            return dict(hit)
        return None
|
|
|
|
def update_model_progress(
    model_id: int,
    total_steps_completed: int,
    average_reward: float | None = None,
    current_learning_rate: float | None = None,
    current_clip_range: float | None = None,
    checkpoint_path: str | None = None,
    status: str | None = None
):
    """
    Record training progress for a model.

    ``total_steps_completed`` and ``updated_at`` are always written. Every
    other argument is written only when it is not None. Passing None leaves
    that column unchanged, so this function cannot clear a column to NULL.

    Args:
        model_id: ID of the model to update.
        total_steps_completed: Steps completed so far.
        average_reward: New average reward, or None to keep the current one.
        current_learning_rate: New learning rate, or None to keep it.
        current_clip_range: New clip range, or None to keep it.
        checkpoint_path: New checkpoint path, or None to keep it.
        status: New status string, or None to keep it.
    """
    optional_columns = (
        ("average_reward", average_reward),
        ("current_learning_rate", current_learning_rate),
        ("current_clip_range", current_clip_range),
        ("checkpoint_path", checkpoint_path),
        ("status", status),
    )

    # Insertion order of this dict fixes the order of the SET clause and params.
    assignments = {"total_steps_completed": total_steps_completed}
    assignments.update(
        (column, value) for column, value in optional_columns if value is not None
    )
    assignments["updated_at"] = datetime.utcnow().isoformat()

    # Column names come only from the fixed literals above, so interpolating
    # them is safe; all values are still bound as parameters.
    set_clause = ", ".join(f"{column} = ?" for column in assignments)
    params = (*assignments.values(), model_id)

    with get_db_conn() as conn:
        conn.execute(f"UPDATE models SET {set_clause} WHERE id = ?", params)
|
|
|
|
def delete_model_by_id(model_id: int) -> None:
    """
    Delete the model record whose ``id`` equals ``model_id``.

    A missing ID affects no rows and is not reported as an error.

    Args:
        model_id: ID of the model to delete.
    """
    statement, params = "DELETE FROM models WHERE id = ?", (model_id,)
    with get_db_conn() as conn:
        conn.execute(statement, params)
|
|
|
|
def get_models_for_run(run_id: int) -> list[dict]:
    """
    Return the models belonging to the agent that owns ``run_id``.

    The ``models`` table has no run reference; models are keyed by
    ``agent_id`` only. This resolves the run's ``agent_id`` and returns *all*
    of that agent's models, newest ``created_at`` first. No per-run filtering
    is performed, so models from other runs of the same agent are included.

    Args:
        run_id: The ID of the run.

    Returns:
        A list of model dicts, or an empty list when the run does not exist
        (the subquery then yields NULL, which matches no ``agent_id``).
    """
    with get_db_conn() as conn:
        cur = conn.execute(
            """
            SELECT * FROM models
            WHERE agent_id = (SELECT agent_id FROM runs WHERE id = ?)
            ORDER BY created_at DESC
            """,
            (run_id,),
        )
        return [dict(r) for r in cur.fetchall()]
|
|
|
|
|
|
def update_run(
    run_id: int,
    name: str | None = None,
    config_yaml: str | None = None,
    notes: str | None = None,
    pending: bool | None = None
) -> None:
    """
    Patch selected columns of a run.

    Only arguments that are not None are written. ``pending`` is stored as
    the integer 0/1. If anything is written, ``updated_at`` is set to the
    current naive UTC ISO-8601 time. If every argument is None, no query is
    issued and ``updated_at`` is left untouched.

    Args:
        run_id: ID of the run to update.
        name: New run name.
        config_yaml: New serialized YAML config.
        notes: New notes text.
        pending: New pending flag.
    """
    candidates = (
        ("name", name),
        ("config_yaml", config_yaml),
        ("notes", notes),
        ("pending", None if pending is None else int(pending)),
    )
    # Dict insertion order fixes the SET clause / parameter order.
    changes = {column: value for column, value in candidates if value is not None}
    if not changes:
        return  # Nothing to update

    changes["updated_at"] = datetime.utcnow().isoformat()
    set_clause = ", ".join(f"{column} = ?" for column in changes)

    with get_db_conn() as conn:
        conn.execute(
            f"UPDATE runs SET {set_clause} WHERE id = ?",
            (*changes.values(), run_id),
        )
|
|
|
|
def update_run_pending(run_id: int, pending: bool) -> None:
    """
    Set a run's ``pending`` flag and refresh its ``updated_at`` timestamp.

    Args:
        run_id: The ID of the run to update.
        pending: New pending status, stored as the integer 1 (True) or 0 (False).
    """
    with get_db_conn() as conn:
        stamp = datetime.utcnow().isoformat()
        conn.execute(
            "UPDATE runs SET pending = ?, updated_at = ? WHERE id = ?",
            (int(pending), stamp, run_id),
        )
|
|
|