- Removed commented-out header styles from styles.base.tcss and styles.tcss. - Added new styles for danger buttons and agent selection views in styles.base.tcss and styles.tcss. - Implemented AgentHomeView to manage agent actions and display metadata. - Created AgentSelectView for selecting agents with a new layout and functionality. - Added CreateAgentView for creating new agents with input validation. - Removed obsolete eval.py and replaced it with evaluation.py. - Developed GameSelectView for selecting games with a dynamic loading interface. - Introduced ModelSelectView for selecting models associated with agents. - Created SelectRunView for managing runs associated with agents. - Added SubmissionView and TrainingView for handling model training and submission processes. - Updated requirements.txt to include pyfiglet for ASCII art rendering.
147 lines
6.1 KiB
Python
147 lines
6.1 KiB
Python
from typing import Literal
|
||
from textual.screen import Screen
|
||
from textual.containers import Vertical, Horizontal, VerticalScroll
|
||
from textual.widgets import Static, Button
|
||
from agentm.logic.db_functions import get_models_for_agent
|
||
from agentm.theme.palette import get_theme
|
||
from agentm.utils.logger import log_with_caller
|
||
from agentm.components.footer import AgentMFooter
|
||
from pathlib import Path
|
||
|
||
# Import destination views (these are placeholders; replace as needed)
|
||
from agentm.views.training import TrainingView
|
||
from agentm.views.evaluation import EvaluationView
|
||
from agentm.views.submission import SubmissionView
|
||
|
||
# Resolve the active theme palette once, when this module is imported.
# ModelCard and ModelSelectView read palette.ACCENT and palette.BORDER from
# this module-level object when building their markup and panels.
palette = get_theme()
|
||
|
||
|
||
class ModelCard(Static):
    """Clickable card summarising one model that belongs to an agent.

    Clicking the card forwards its model record to the owning view's
    ``display_model_info`` coroutine.

    Args:
        model_data: Model record, presumably one dict per row from
            ``get_models_for_agent`` (confirm the key set against the query).
            Reads ``name``, ``total_steps_completed``, ``total_steps_planned``,
            ``average_reward``, ``status`` and ``created_at``.
        parent_view: Object exposing an async ``display_model_info(model)``
            method, e.g. ``ModelSelectView``.
    """

    def __init__(self, model_data: dict, parent_view):
        super().__init__(classes="agent_card")
        self.model_data = model_data
        self.parent_view = parent_view

    @staticmethod
    def _display_value(value) -> str:
        """Format one field value for safe inclusion in markup.

        ``None`` (a missing key, or a key present with a NULL value) becomes an
        em dash. Anything else is stringified and markup-escaped, so text
        like ``[v2]`` in a model name is shown literally instead of being
        parsed as a style tag.
        """
        from rich.markup import escape

        return "—" if value is None else escape(str(value))

    def render(self) -> str:
        """Return the card markup: name, step progress, reward, status, creation time."""
        data = self.model_data
        show = self._display_value
        lines = (
            f"[bold {palette.ACCENT}]{show(data.get('name'))}[/]",
            f"[dim]Steps:[/] {show(data.get('total_steps_completed'))} / {show(data.get('total_steps_planned'))}",
            f"[dim]Avg Reward:[/] {show(data.get('average_reward'))}",
            f"[dim]Status:[/] {show(data.get('status'))}",
            f"[dim]Created:[/] {show(data.get('created_at'))}",
        )
        return "\n".join(lines)

    async def on_click(self):
        """Show this card's model in the parent view's info panel."""
        await self.parent_view.display_model_info(self.model_data)
|
||
|
||
|
||
class ModelSelectView(Screen):
|
||
BINDINGS = [
|
||
("escape", "app.pop_screen", "Back"),
|
||
("r", "refresh_models", "Refresh"),
|
||
]
|
||
|
||
def __init__(self, agent_metadata: dict, mode: Literal["train", "eval", "submit"]):
|
||
super().__init__()
|
||
self.agent_metadata = agent_metadata
|
||
self.mode = mode
|
||
self.selected_model = None
|
||
|
||
def compose(self):
|
||
header_path = Path(__file__).parent.parent / "assets" / "headers" / "select_model.txt"
|
||
try:
|
||
header_text = header_path.read_text()
|
||
except FileNotFoundError:
|
||
header_text = "=== SELECT MODEL ==="
|
||
|
||
self.header = Static(f"[{palette.ACCENT}]{header_text}[/{palette.ACCENT}]", classes="header")
|
||
self.subheader = Static(f"[b]{self.agent_metadata['name']}[/b]", classes="subheader")
|
||
|
||
self.model_list = VerticalScroll(id="agent_scroll")
|
||
self.model_info = Horizontal(id="agent_info_panel")
|
||
|
||
self.select_button = Button("✅ Select Model", id="select_model_btn", classes="confirm_button", disabled=True)
|
||
self.create_button = Button("➕ Create New Model", id="create_model_btn", classes="confirm_button")
|
||
|
||
yield Vertical(
|
||
self.header,
|
||
self.subheader,
|
||
self.model_list,
|
||
self.model_info,
|
||
self.select_button,
|
||
self.create_button,
|
||
AgentMFooter(compact=True),
|
||
id="agent_select_layout"
|
||
)
|
||
|
||
async def on_mount(self):
|
||
log_with_caller("debug", f"Mounted ModelSelectView for agent_id={self.agent_metadata['id']} mode={self.mode}")
|
||
await self.refresh_model_list()
|
||
|
||
async def refresh_model_list(self):
|
||
try:
|
||
for child in list(self.model_list.children):
|
||
await child.remove()
|
||
|
||
models = get_models_for_agent(self.agent_metadata["id"])
|
||
log_with_caller("info", f"Refreshed: {len(models)} models found for agent {self.agent_metadata['name']}")
|
||
|
||
if models:
|
||
for model in models:
|
||
await self.model_list.mount(ModelCard(model, self))
|
||
else:
|
||
await self.model_list.mount(Static("[dim]No models found for this agent.[/dim]"))
|
||
|
||
self.selected_model = None
|
||
await self.display_model_info(None)
|
||
except Exception as e:
|
||
log_with_caller("error", f"Error rendering model list: {e}")
|
||
|
||
async def display_model_info(self, model: dict | None):
|
||
self.selected_model = model
|
||
await self.model_info.remove_children()
|
||
|
||
if not model:
|
||
await self.model_info.mount(Static("[dim]Select a model to view details[/dim]"))
|
||
self.select_button.disabled = True
|
||
return
|
||
|
||
from rich.panel import Panel
|
||
from rich.table import Table
|
||
|
||
table = Table.grid(padding=(0, 1))
|
||
table.add_column("Key", style="bold underline")
|
||
table.add_column("Value", style=palette.ACCENT, overflow="fold")
|
||
|
||
table.add_row("Name", model["name"])
|
||
table.add_row("Status", model["status"])
|
||
table.add_row("Steps", f"{model['total_steps_completed']} / {model['total_steps_planned']}")
|
||
table.add_row("Avg Reward", str(model.get("average_reward", "—")))
|
||
table.add_row("Learning Rate", str(model.get("current_learning_rate", "—")))
|
||
table.add_row("Clip Range", str(model.get("current_clip_range", "—")))
|
||
table.add_row("Checkpoint", model.get("checkpoint_path", "—"))
|
||
table.add_row("Created", model.get("created_at", "—"))
|
||
|
||
info_panel = Static(Panel(table, title="Model Info", border_style=palette.BORDER), classes="agent_info_box")
|
||
await self.model_info.mount(info_panel)
|
||
self.select_button.disabled = False
|
||
|
||
async def on_button_pressed(self, event: Button.Pressed) -> None:
|
||
if event.button.id == "select_model_btn" and self.selected_model:
|
||
log_with_caller("info", f"Model confirmed: {self.selected_model['name']} for mode={self.mode}")
|
||
match self.mode:
|
||
case "train":
|
||
await self.app.push_screen(TrainingView(agent=self.agent_metadata, model=self.selected_model))
|
||
case "eval":
|
||
await self.app.push_screen(EvaluationView(agent=self.agent_metadata, model=self.selected_model))
|
||
case "submit":
|
||
await self.app.push_screen(SubmissionView(agent=self.agent_metadata, model=self.selected_model))
|
||
|
||
elif event.button.id == "create_model_btn":
|
||
log_with_caller("info", f"Creating new model for agent: {self.agent_metadata['name']}")
|
||
await self.app.push_screen(TrainingView(agent=self.agent_metadata, model=None)) # New model starts from training
|
||
|
||
async def action_refresh_models(self):
|
||
await self.refresh_model_list()
|