agent_m/agentm/views/model_select.py

150 lines
6.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from pathlib import Path
from typing import Literal

from rich.markup import escape
from textual.containers import Vertical, Horizontal, VerticalScroll
from textual.screen import Screen
from textual.widgets import Static, Button

from agentm.logic.db_functions import get_models_for_agent
from agentm.theme.palette import get_theme
from agentm.utils.logger import log_with_caller
from agentm.components.footer import AgentMFooter
# Import destination views (these are placeholders; replace as needed)
from agentm.views.training import TrainingView
from agentm.views.evaluation import EvaluationView
from agentm.views.submission import SubmissionView
# Resolve the active theme once at import time; ModelCard and ModelSelectView
# read its ACCENT and BORDER attributes when building markup and panels.
palette = get_theme()
class ModelCard(Static):
    """Clickable summary card for one model in the model list.

    Clicking the card forwards its ``model_data`` to
    ``parent_view.display_model_info`` so the owning screen can show details.
    """

    def __init__(self, model_data: dict, parent_view) -> None:
        """Create a card for ``model_data``.

        Args:
            model_data: Model record; ``name``, ``status``,
                ``total_steps_completed`` and ``total_steps_planned`` are
                required (indexed directly), ``average_reward`` and
                ``created_at`` are optional.
            parent_view: Object exposing an async
                ``display_model_info(model)`` method (in this file, the
                owning ``ModelSelectView``).
        """
        super().__init__(classes="agent_card")
        self.model_data = model_data
        self.parent_view = parent_view

    @staticmethod
    def _fmt_model_value(value) -> str:
        """Return *value* as markup-safe text; ``None`` becomes ''."""
        # Values come from the database and must be shown literally, so any
        # '[' is escaped rather than interpreted as a markup tag.
        return "" if value is None else escape(str(value))

    def render(self) -> str:
        """Build the card's markup from ``model_data``."""
        data = self.model_data
        fmt = self._fmt_model_value
        lines = [
            f"[bold {palette.ACCENT}]{fmt(data['name'])}[/]",
            f"[dim]Steps:[/] {fmt(data['total_steps_completed'])} / {fmt(data['total_steps_planned'])}",
            f"[dim]Avg Reward:[/] {fmt(data.get('average_reward'))}",
            f"[dim]Status:[/] {fmt(data['status'])}",
            f"[dim]Created:[/] {fmt(data.get('created_at'))}",
        ]
        return "\n".join(lines)

    async def on_click(self) -> None:
        """Show this card's model in the parent view's info panel."""
        await self.parent_view.display_model_info(self.model_data)
class ModelSelectView(Screen):
BINDINGS = [
("escape", "app.pop_screen", "Back"),
("r", "refresh_models", "Refresh"),
]
def __init__(self, agent_metadata: dict, mode: Literal["train", "eval", "submit"], run: dict):
super().__init__()
self.agent_metadata = agent_metadata
self.mode = mode
self.run_metadata = run
def compose(self):
header_path = Path(__file__).parent.parent / "assets" / "headers" / "select_model.txt"
try:
header_text = header_path.read_text()
except FileNotFoundError:
header_text = "=== SELECT MODEL ==="
self.header = Static(f"[{palette.ACCENT}]{header_text}[/{palette.ACCENT}]", classes="header")
self.subheader = Static(f"[b]{self.agent_metadata['name']}[/b]", classes="subheader")
self.model_list = VerticalScroll(id="agent_scroll")
self.model_info = Horizontal(id="agent_info_panel")
self.select_button = Button("✅ Select Model", id="select_model_btn", classes="confirm_button", disabled=True)
self.create_button = Button(" Create New Model", id="create_model_btn", classes="confirm_button")
yield Vertical(
self.header,
self.subheader,
self.model_list,
self.model_info,
self.select_button,
self.create_button,
AgentMFooter(compact=True),
id="agent_select_layout"
)
async def on_mount(self):
log_with_caller("debug", f"Mounted ModelSelectView for agent_id={self.agent_metadata['id']} mode={self.mode}")
await self.refresh_model_list()
async def refresh_model_list(self):
try:
for child in list(self.model_list.children):
await child.remove()
models = get_models_for_agent(self.agent_metadata["id"])
log_with_caller("info", f"Refreshed: {len(models)} models found for agent {self.agent_metadata['name']}")
if models:
for model in models:
await self.model_list.mount(ModelCard(model, self))
else:
await self.model_list.mount(Static("[dim]No models found for this agent.[/dim]"))
self.selected_model = None
await self.display_model_info(None)
except Exception as e:
log_with_caller("error", f"Error rendering model list: {e}")
async def display_model_info(self, model: dict | None):
self.selected_model = model
await self.model_info.remove_children()
if not model:
await self.model_info.mount(Static("[dim]Select a model to view details[/dim]"))
self.select_button.disabled = True
return
from rich.panel import Panel
from rich.table import Table
table = Table.grid(padding=(0, 1))
table.add_column("Key", style="bold underline")
table.add_column("Value", style=palette.ACCENT, overflow="fold")
table.add_row("Name", model["name"])
table.add_row("Status", model["status"])
table.add_row("Steps", f"{model['total_steps_completed']} / {model['total_steps_planned']}")
table.add_row("Avg Reward", str(model.get("average_reward", "")))
table.add_row("Learning Rate", str(model.get("current_learning_rate", "")))
table.add_row("Clip Range", str(model.get("current_clip_range", "")))
table.add_row("Checkpoint", model.get("checkpoint_path", ""))
table.add_row("Created", model.get("created_at", ""))
info_panel = Static(Panel(table, title="Model Info", border_style=palette.BORDER), classes="agent_info_box")
await self.model_info.mount(info_panel)
self.select_button.disabled = False
async def on_button_pressed(self, event: Button.Pressed) -> None:
if event.button.id == "select_model_btn" and self.selected_model:
log_with_caller("info", f"Model confirmed: {self.selected_model['name']} for mode={self.mode}")
match self.mode:
case "train":
await self.app.push_screen(TrainingView(agent=self.agent_metadata, model=None, run=self.run_metadata))
case "eval":
await self.app.push_screen(EvaluationView(agent=self.agent_metadata, model=self.selected_model))
case "submit":
await self.app.push_screen(SubmissionView(agent=self.agent_metadata, model=self.selected_model))
elif event.button.id == "create_model_btn":
log_with_caller("info", f"Creating new model for agent: {self.agent_metadata['name']}")
await self.app.push_screen(
TrainingView(agent=self.agent_metadata, model=None, run=self.run_metadata)
)
async def action_refresh_models(self):
await self.refresh_model_list()