from typing import Literal from textual.screen import Screen from textual.containers import Vertical, Horizontal, VerticalScroll from textual.widgets import Static, Button from agentm.logic.db_functions import get_models_for_agent from agentm.theme.palette import get_theme from agentm.utils.logger import log_with_caller from agentm.components.footer import AgentMFooter from pathlib import Path # Import destination views (these are placeholders; replace as needed) from agentm.views.training import TrainingView from agentm.views.evaluation import EvaluationView from agentm.views.submission import SubmissionView # Ensure the palette is loaded palette = get_theme() class ModelCard(Static): def __init__(self, model_data: dict, parent_view): super().__init__(classes="agent_card") self.model_data = model_data self.parent_view = parent_view def render(self) -> str: return f""" [bold {palette.ACCENT}]{self.model_data['name']}[/] [dim]Steps:[/] {self.model_data['total_steps_completed']} / {self.model_data['total_steps_planned']} [dim]Avg Reward:[/] {self.model_data.get('average_reward', '—')} [dim]Status:[/] {self.model_data['status']} [dim]Created:[/] {self.model_data.get('created_at', '—')} """.strip() async def on_click(self): await self.parent_view.display_model_info(self.model_data) class ModelSelectView(Screen): BINDINGS = [ ("escape", "app.pop_screen", "Back"), ("r", "refresh_models", "Refresh"), ] def __init__(self, agent_metadata: dict, mode: Literal["train", "eval", "submit"], run: dict): super().__init__() self.agent_metadata = agent_metadata self.mode = mode self.run_metadata = run def compose(self): header_path = Path(__file__).parent.parent / "assets" / "headers" / "select_model.txt" try: header_text = header_path.read_text() except FileNotFoundError: header_text = "=== SELECT MODEL ===" self.header = Static(f"[{palette.ACCENT}]{header_text}[/{palette.ACCENT}]", classes="header") self.subheader = Static(f"[b]{self.agent_metadata['name']}[/b]", classes="subheader") self.model_list = VerticalScroll(id="agent_scroll") self.model_info = Horizontal(id="agent_info_panel") self.select_button = Button("✅ Select Model", id="select_model_btn", classes="confirm_button", disabled=True) self.create_button = Button("➕ Create New Model", id="create_model_btn", classes="confirm_button") yield Vertical( self.header, self.subheader, self.model_list, self.model_info, self.select_button, self.create_button, AgentMFooter(compact=True), id="agent_select_layout" ) async def on_mount(self): log_with_caller("debug", f"Mounted ModelSelectView for agent_id={self.agent_metadata['id']} mode={self.mode}") await self.refresh_model_list() async def refresh_model_list(self): try: for child in list(self.model_list.children): await child.remove() models = get_models_for_agent(self.agent_metadata["id"]) log_with_caller("info", f"Refreshed: {len(models)} models found for agent {self.agent_metadata['name']}") if models: for model in models: await self.model_list.mount(ModelCard(model, self)) else: await self.model_list.mount(Static("[dim]No models found for this agent.[/dim]")) self.selected_model = None await self.display_model_info(None) except Exception as e: log_with_caller("error", f"Error rendering model list: {e}") async def display_model_info(self, model: dict | None): self.selected_model = model await self.model_info.remove_children() if not model: await self.model_info.mount(Static("[dim]Select a model to view details[/dim]")) self.select_button.disabled = True return from rich.panel import Panel from rich.table import Table table = Table.grid(padding=(0, 1)) table.add_column("Key", style="bold underline") table.add_column("Value", style=palette.ACCENT, overflow="fold") table.add_row("Name", model["name"]) table.add_row("Status", model["status"]) table.add_row("Steps", f"{model['total_steps_completed']} / {model['total_steps_planned']}") table.add_row("Avg Reward", str(model.get("average_reward", "—"))) table.add_row("Learning Rate", str(model.get("current_learning_rate", "—"))) table.add_row("Clip Range", str(model.get("current_clip_range", "—"))) table.add_row("Checkpoint", model.get("checkpoint_path", "—")) table.add_row("Created", model.get("created_at", "—")) info_panel = Static(Panel(table, title="Model Info", border_style=palette.BORDER), classes="agent_info_box") await self.model_info.mount(info_panel) self.select_button.disabled = False async def on_button_pressed(self, event: Button.Pressed) -> None: if event.button.id == "select_model_btn" and self.selected_model: log_with_caller("info", f"Model confirmed: {self.selected_model['name']} for mode={self.mode}") match self.mode: case "train": await self.app.push_screen(TrainingView(agent=self.agent_metadata, model=None, run=self.run_metadata)) case "eval": await self.app.push_screen(EvaluationView(agent=self.agent_metadata, model=self.selected_model)) case "submit": await self.app.push_screen(SubmissionView(agent=self.agent_metadata, model=self.selected_model)) elif event.button.id == "create_model_btn": log_with_caller("info", f"Creating new model for agent: {self.agent_metadata['name']}") await self.app.push_screen( TrainingView(agent=self.agent_metadata, model=None, run=self.run_metadata) ) async def action_refresh_models(self): await self.refresh_model_list()