233 lines
8.8 KiB
Python
233 lines
8.8 KiB
Python
from textual.screen import Screen
|
|
from textual.widgets import Static, Button
|
|
from textual.containers import Vertical, Horizontal, HorizontalScroll
|
|
from textual.message import Message
|
|
from textual.reactive import reactive
|
|
from textual.widget import Widget
|
|
from rich.panel import Panel
|
|
from rich.console import Group
|
|
from rich.table import Table
|
|
from rich.rule import Rule
|
|
from rich.markup import escape
|
|
from PIL import Image, ImageFilter
|
|
import os
|
|
|
|
from rich_pixels import Pixels
|
|
from rich_pixels._renderer import HalfcellRenderer
|
|
|
|
from agentm.utils.logger import log_with_caller
|
|
from agentm.logic.roms import get_verified_roms, GAME_FILES
|
|
from agentm.theme.palette import get_theme
|
|
from agentm.components.footer import AgentMFooter
|
|
|
|
# Active colour theme, resolved once when the module is imported. Its
# ACCENT / SUCCESS / BORDER attributes are interpolated into the Rich markup
# and border styles used by the widgets in this module.
palette = get_theme()
|
|
|
|
|
|
class GameSelected(Message):
    """Message carrying the metadata of the game the user picked.

    Args:
        sender: The widget posting the message. Kept on the instance as
            ``sender_widget`` for handlers that want it; it is no longer
            forwarded to ``Message.__init__``.
        metadata: The selected game's metadata dict.
    """

    def __init__(self, sender: Widget, metadata: dict):
        self.sender_widget = sender
        self.metadata = metadata
        # Textual versions that provide run_worker/HorizontalScroll (both used
        # in this module) define Message.__init__ without a sender parameter,
        # so forwarding it raises TypeError. The parameter is kept above so
        # existing call sites keep working.
        super().__init__()
|
|
|
|
|
|
class ProgressWidget(Widget):
    """One-line status widget showing ``message`` in bold accent colour.

    ``message`` is a Textual reactive, so assigning to it refreshes the widget.
    """

    message = reactive("Loading...")

    def render(self) -> str:
        """Return the current message wrapped in accent-coloured markup.

        The message is escaped so text containing square brackets (for
        example ROM titles) is shown literally instead of parsed as markup.
        """
        return f"[bold {palette.ACCENT}]{escape(str(self.message))}[/]"
|
|
|
|
|
|
|
|
class GameAccordion(Static):
    """Clickable card showing one verified game's title and cover image.

    Clicking the card fills the parent view's shared info panel and confirm
    button with this game's metadata and highlights the card.

    Args:
        title: Display title; shown upper-cased.
        rom_file: ROM file name, used to derive the widget id.
        metadata: Game metadata dict. ``display_info`` reads ``title``,
            ``game_id`` and ``sha256`` directly, and ``difficulty_min``,
            ``difficulty_max``, ``characters`` and ``keywords`` via ``get``;
            ``image_path`` is read via ``get`` at construction.
        parent_view: Owning view. Its ``shared_info_box`` and
            ``shared_confirm_button`` are updated and its ``selected_game``
            is set when the card is clicked, and its ``highlight_selected``
            is called.
    """

    def __init__(self, title: str, rom_file: str, metadata: dict, parent_view):
        self.title = title
        self.rom_file = rom_file
        self.metadata = metadata
        self.parent_view = parent_view
        self.safe_id = self._make_safe_id(rom_file)

        image_path = os.path.abspath(self.metadata.get("image_path", ""))
        self.image_renderable = self.load_image_scaled(image_path)

        # Title
        self.title_label = Static(
            f"[b {palette.ACCENT}]{escape(self.title.upper())}[/]\n",
            classes="game_title",
            markup=True,
        )

        super().__init__(id=f"accordion_{self.safe_id}", classes="game_card")

    @staticmethod
    def _make_safe_id(rom_file: str) -> str:
        """Map *rom_file* to an id fragment using only ``[A-Za-z0-9_]``.

        Every other character (including ``.`` and ``-``) becomes ``_`` so the
        resulting widget id is always a valid Textual identifier.
        """
        return "".join(
            ch if (ch.isascii() and ch.isalnum()) or ch == "_" else "_"
            for ch in rom_file
        )

    def load_image_scaled(self, path: str):
        """Load the image at *path*, shrink it and wrap it for terminal display.

        The image is converted to RGBA, scaled by 0.098 with Lanczos
        resampling, sharpened with an unsharp mask and rendered with
        ``HalfcellRenderer``. Sets ``self._render_width`` and
        ``self._render_height`` as a side effect.

        Returns:
            A ``Pixels`` renderable, or a markup string describing the error
            if the image could not be opened or processed.
        """
        try:
            with Image.open(path) as img:
                if img.mode != "RGBA":
                    img = img.convert("RGBA")

                scale_factor = 0.098
                # Clamp to 1px: PIL rejects zero-sized resize targets, which
                # very small source images would otherwise produce.
                target_width = max(1, int(img.width * scale_factor))
                target_height = max(1, int(img.height * scale_factor))

                resized = img.resize(
                    (target_width, target_height),
                    resample=Image.Resampling.LANCZOS,
                )
                resized = resized.filter(
                    ImageFilter.UnsharpMask(radius=1, percent=150, threshold=3)
                )

                self._render_width = target_width
                # Half-cell rendering packs two pixel rows into one text row.
                self._render_height = target_height // 2

                return Pixels.from_image(
                    resized,
                    renderer=HalfcellRenderer(default_color="black"),
                )
        except Exception as e:  # UI fallback: show the error instead of the image
            self._render_width = 24
            self._render_height = 16
            # Escape the exception text (it may contain brackets) and close
            # the [dim] tag properly.
            return f"[red]Failed to load image[/red]\n[dim]{escape(str(e))}[/dim]"

    def compose(self):
        """Yield the title label followed by the cover image."""
        yield self.title_label
        yield Static(self.image_renderable)

    async def on_click(self):
        """Show this game's details and highlight this card."""
        await self.display_info()
        self.parent_view.highlight_selected(self)

    @staticmethod
    def _plain(value) -> str:
        """Stringify *value* and escape it so it renders literally in markup."""
        return escape(str(value))

    async def display_info(self):
        """Show this game's metadata in the parent view and arm its confirm button.

        Also stores the metadata as the parent view's ``selected_game``.
        """
        meta = self.metadata
        log_with_caller("debug", f"Showing shared info for {self.rom_file}")

        table = Table.grid(expand=True)
        table.add_column(ratio=1)
        table.add_column()
        # Metadata values are data, not markup: stringify + escape each one so
        # non-string values render and brackets are not parsed as tags.
        table.add_row("[b]Title:[/b]", self._plain(meta["title"]))
        table.add_row("[b]Game ID:[/b]", self._plain(meta["game_id"]))
        table.add_row(
            "[b]Difficulty:[/b]",
            self._plain(f"{meta.get('difficulty_min')} - {meta.get('difficulty_max')}"),
        )
        table.add_row(
            "[b]Characters:[/b]",
            self._plain(", ".join(str(c) for c in meta.get("characters", []))),
        )
        table.add_row(
            "[b]Keywords:[/b]",
            self._plain(", ".join(str(k) for k in meta.get("keywords", []))),
        )
        table.add_row("[b]SHA256:[/b]", self._plain(meta["sha256"]))

        self.parent_view.shared_info_box.update(
            Panel(
                Group(table, Rule(style="dim")),
                title="Game Info",
                border_style=palette.BORDER,
                expand=True,
            )
        )
        # Button labels given as str are parsed as markup too.
        self.parent_view.shared_confirm_button.label = (
            f"✅ Confirm {self._plain(meta['title'])}"
        )
        self.parent_view.shared_confirm_button.disabled = False
        self.parent_view.selected_game = meta
|
|
|
|
|
|
class HomeView(Screen):
    """Landing screen: logo, intro text and a picker of verified ROMs.

    ROM verification runs in a background thread worker while a progress line
    is shown. When it finishes, the loading area is replaced by a horizontally
    scrolling row of game cards plus a shared info panel and confirm button.
    Pressing confirm pushes the ``"training"`` screen.
    """

    BINDINGS = [("escape", "app.quit", "Quit")]

    def highlight_selected(self, selected_widget: GameAccordion):
        """Make *selected_widget* the only card carrying ``game_card_clicked``."""
        for card in self.rom_scroll_row.children:
            if isinstance(card, GameAccordion):
                card.remove_class("game_card_clicked")
        selected_widget.add_class("game_card_clicked")

    def compose(self):
        """Build the static layout; the ROM picker is mounted after verification."""
        self.logo = Static(
            f"[bold {palette.ACCENT}]\n\n"
            " █████╗ ██████╗ ███████╗ ███╗ ██╗ ████████╗ ███╗ ███╗\n"
            "██╔══██╗ ██╔════╝ ██╔════╝ ████╗ ██║ ╚══██╔══╝ ████╗ ████║\n"
            "███████║ ██║ ███╗ █████╗ ██╔██╗██║ ██║ ██╔████╔██║\n"
            "██╔══██║ ██║ ██║ ██╔══╝ ██║╚████║ ██║ ██║╚██╔╝██║\n"
            "██║ ██║ ╚██████╔╝ ███████╗ ██║ ╚███║ ██║ ██║ ██║\n"
            "╚═╝ ╚═╝ ╚═════╝ ╚══════╝ ╚═╝ ╚══╝ ╚═╝ ╚═╝ ╚═╝\n[/]",
            classes="header",
            expand=False,
        )

        self.welcome_text = Static(
            "This is an unofficial DIAMBRA launcher to help you easily train, evaluate, and submit RL agents for fighting games.",
            classes="body",
            expand=False,
        )

        self.progress_text = ProgressWidget()

        self.loading_container = Vertical(
            Static(f"[bold {palette.SUCCESS}]LOADING...[/]", expand=True),
            self.progress_text,
            id="loading_container",
        )

        # This will be the main container we later modify
        self.dynamic_container = Vertical(self.loading_container, id="dynamic_content")

        yield Vertical(
            self.logo,
            self.welcome_text,
            self.dynamic_container,
            AgentMFooter(compact=True),
            id="home_screen_container",
            classes="centered_layout",
        )

    async def on_mount(self):
        """Reset the selection and start ROM verification in a thread worker."""
        log_with_caller("debug", "HomeView mounted. Starting ROM verification.")
        self.selected_game = None
        self.run_worker(self.run_verification, thread=True, exclusive=True, name="rom-verification")

    def _set_progress(self, text: str) -> None:
        """Update the progress line (must run on the UI thread)."""
        self.progress_text.message = text

    def run_verification(self):
        """Verify ROMs on the worker thread and hand results to the UI thread.

        Widgets are only touched through ``app.call_from_thread`` because this
        method runs off the event loop.
        """
        import time

        total = len(GAME_FILES)
        verified_roms = get_verified_roms()

        for idx, rom in enumerate(verified_roms, start=1):
            self.app.call_from_thread(
                self._set_progress,
                f"Processing {rom['title']} ({idx}/{total})",
            )
            # Short pause so each progress message is briefly visible.
            time.sleep(0.01)

        self.app.call_from_thread(self.display_verified_roms, verified_roms)

    async def display_verified_roms(self, verified_roms):
        """Swap the loading indicator for the game picker.

        Builds the shared info panel, the initially disabled confirm button,
        and a horizontal row with one ``GameAccordion`` per verified ROM.

        Args:
            verified_roms: ROM metadata dicts from ``get_verified_roms``; each
                must provide ``title`` and ``rom_file``.
        """
        log_with_caller("info", f"ROM verification complete. Total: {len(verified_roms)}")

        # Say so explicitly when there is nothing to pick from.
        hint = (
            "[dim]Select a Game From Above to Start[/dim]"
            if verified_roms
            else "[dim]No verified ROMs found[/dim]"
        )

        self.shared_info_box = Static(
            Panel(
                hint,
                title="Game Info",
                border_style=palette.BORDER,
                expand=True,
            ),
            id="game_info_box",
            classes="game_info",
            expand=True,
        )
        self.shared_confirm_button = Button(
            "✅ Confirm",
            id="confirm_button",
            classes="confirm_button",
            disabled=True,
        )
        self.rom_scroll_row = HorizontalScroll(id="rom_scroll_row", classes="rom_row")

        new_content = Vertical(
            self.rom_scroll_row,
            Horizontal(
                self.shared_info_box,
                self.shared_confirm_button,
                id="info_row",
                classes="info_confirm_row",
            ),
        )

        # Replace loading content with new UI below logo and welcome
        dynamic_container = self.query_one("#dynamic_content")
        await dynamic_container.remove_children()
        await dynamic_container.mount(new_content)

        # Populate games with a single mount instead of one await per card.
        cards = [
            GameAccordion(
                title=rom["title"],
                rom_file=rom["rom_file"],
                metadata=rom,
                parent_view=self,
            )
            for rom in verified_roms
        ]
        if cards:
            await self.rom_scroll_row.mount(*cards)

    async def on_button_pressed(self, event: Button.Pressed) -> None:
        """Open the training screen once a game has been confirmed."""
        if event.button.id == "confirm_button" and self.selected_game:
            # push_screen's second positional parameter is a *result callback*,
            # not data for the new screen: passing the metadata dict there made
            # Textual try to call the dict when the screen was dismissed, and
            # the dict never reached the screen. The selection remains
            # available as this screen's ``selected_game``.
            await self.app.push_screen("training")
|