styling on training added.
This commit is contained in:
parent
4500bfd388
commit
fb6e2fcd65
10
agentm/assets/headers/training_setup.txt
Normal file
10
agentm/assets/headers/training_setup.txt
Normal file
@ -0,0 +1,10 @@
|
||||
|
||||
|
||||
████████╗██████╗ █████╗ ██╗███╗ ██╗██╗███╗ ██╗ ██████╗ ███████╗███████╗████████╗██╗ ██╗██████╗
|
||||
╚══██╔══╝██╔══██╗██╔══██╗██║████╗ ██║██║████╗ ██║██╔════╝ ██╔════╝██╔════╝╚══██╔══╝██║ ██║██╔══██╗
|
||||
██║ ██████╔╝███████║██║██╔██╗ ██║██║██╔██╗ ██║██║ ███╗ ███████╗█████╗ ██║ ██║ ██║██████╔╝
|
||||
██║ ██╔══██╗██╔══██║██║██║╚██╗██║██║██║╚██╗██║██║ ██║ ╚════██║██╔══╝ ██║ ██║ ██║██╔═══╝
|
||||
██║ ██║ ██║██║ ██║██║██║ ╚████║██║██║ ╚████║╚██████╔╝ ███████║███████╗ ██║ ╚██████╔╝██║
|
||||
╚═╝ ╚═╝ ╚═╝╚═╝ ╚═╝╚═╝╚═╝ ╚═══╝╚═╝╚═╝ ╚═══╝ ╚═════╝ ╚══════╝╚══════╝ ╚═╝ ╚═════╝ ╚═╝
|
||||
|
||||
|
||||
10
agentm/data/game_data/doapp.yaml
Normal file
10
agentm/data/game_data/doapp.yaml
Normal file
@ -0,0 +1,10 @@
|
||||
|
||||
|
||||
██████╗ ███████╗ █████╗ ██████╗ ██████╗ ██████╗ █████╗ ██╗ ██╗██╗ ██╗███████╗
|
||||
██╔══██╗██╔════╝██╔══██╗██╔══██╗ ██╔═══██╗██╔══██╗ ██╔══██╗██║ ██║██║ ██║██╔════╝
|
||||
██║ ██║█████╗ ███████║██║ ██║ ██║ ██║██████╔╝ ███████║██║ ██║██║ ██║█████╗
|
||||
██║ ██║██╔══╝ ██╔══██║██║ ██║ ██║ ██║██╔══██╗ ██╔══██║██║ ██║╚██╗ ██╔╝██╔══╝
|
||||
██████╔╝███████╗██║ ██║██████╔╝ ╚██████╔╝██║ ██║ ██║ ██║███████╗██║ ╚████╔╝ ███████╗
|
||||
╚═════╝ ╚══════╝╚═╝ ╚═╝╚═════╝ ╚═════╝ ╚═╝ ╚═╝ ╚═╝ ╚═╝╚══════╝╚═╝ ╚═══╝ ╚══════╝
|
||||
|
||||
|
||||
41
agentm/data/game_data/general.yaml
Normal file
41
agentm/data/game_data/general.yaml
Normal file
@ -0,0 +1,41 @@
|
||||
folders:
|
||||
parent_dir: "./results/"
|
||||
model_name: "default_model"
|
||||
|
||||
settings:
|
||||
step_ratio: 6
|
||||
frame_shape: [128, 128, 1]
|
||||
continue_game: 0.0
|
||||
action_space: "discrete"
|
||||
|
||||
wrappers_settings:
|
||||
wrappers:
|
||||
normalize_reward: true
|
||||
no_attack_buttons_combinations: true
|
||||
stack_frames: 4
|
||||
dilation: 1
|
||||
add_last_action: true
|
||||
stack_actions: 12
|
||||
scale: true
|
||||
exclude_image_scaling: true
|
||||
role_relative: true
|
||||
flatten: true
|
||||
filter_keys:
|
||||
- frame
|
||||
- action
|
||||
- stage
|
||||
- timer
|
||||
|
||||
policy_kwargs:
|
||||
net_arch: [64, 64]
|
||||
|
||||
ppo_settings:
|
||||
gamma: 0.94
|
||||
model_checkpoint: "0"
|
||||
learning_rate: [2.5e-4, 2.5e-6]
|
||||
clip_range: [0.15, 0.025]
|
||||
batch_size: 256
|
||||
n_epochs: 4
|
||||
n_steps: 128
|
||||
autosave_freq: 256
|
||||
time_steps: 512
|
||||
10
agentm/data/game_data/kof98umh.yaml
Normal file
10
agentm/data/game_data/kof98umh.yaml
Normal file
@ -0,0 +1,10 @@
|
||||
|
||||
|
||||
████████╗██╗ ██╗███████╗ ██╗ ██╗██╗███╗ ██╗ ██████╗ ██████╗ ███████╗ ███████╗██╗ ██████╗ ██╗ ██╗████████╗███████╗██████╗ ███████╗
|
||||
╚══██╔══╝██║ ██║██╔════╝ ██║ ██╔╝██║████╗ ██║██╔════╝ ██╔═══██╗██╔════╝ ██╔════╝██║██╔════╝ ██║ ██║╚══██╔══╝██╔════╝██╔══██╗██╔════╝
|
||||
██║ ███████║█████╗ █████╔╝ ██║██╔██╗ ██║██║ ███╗ ██║ ██║█████╗ █████╗ ██║██║ ███╗███████║ ██║ █████╗ ██████╔╝███████╗
|
||||
██║ ██╔══██║██╔══╝ ██╔═██╗ ██║██║╚██╗██║██║ ██║ ██║ ██║██╔══╝ ██╔══╝ ██║██║ ██║██╔══██║ ██║ ██╔══╝ ██╔══██╗╚════██║
|
||||
██║ ██║ ██║███████╗ ██║ ██╗██║██║ ╚████║╚██████╔╝ ╚██████╔╝██║ ██║ ██║╚██████╔╝██║ ██║ ██║ ███████╗██║ ██║███████║
|
||||
╚═╝ ╚═╝ ╚═╝╚══════╝ ╚═╝ ╚═╝╚═╝╚═╝ ╚═══╝ ╚═════╝ ╚═════╝ ╚═╝ ╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═╝ ╚═╝ ╚══════╝╚═╝ ╚═╝╚══════╝
|
||||
|
||||
|
||||
10
agentm/data/game_data/mvsc.yaml
Normal file
10
agentm/data/game_data/mvsc.yaml
Normal file
@ -0,0 +1,10 @@
|
||||
|
||||
|
||||
███╗ ███╗ █████╗ ██████╗ ██╗ ██╗███████╗██╗ ██╗ ██╗███████╗ ██████╗ █████╗ ██████╗ ██████╗ ██████╗ ███╗ ███╗
|
||||
████╗ ████║██╔══██╗██╔══██╗██║ ██║██╔════╝██║ ██║ ██║██╔════╝ ██╔════╝██╔══██╗██╔══██╗██╔════╝██╔═══██╗████╗ ████║
|
||||
██╔████╔██║███████║██████╔╝██║ ██║█████╗ ██║ ██║ ██║███████╗ ██║ ███████║██████╔╝██║ ██║ ██║██╔████╔██║
|
||||
██║╚██╔╝██║██╔══██║██╔══██╗╚██╗ ██╔╝██╔══╝ ██║ ╚██╗ ██╔╝╚════██║ ██║ ██╔══██║██╔═══╝ ██║ ██║ ██║██║╚██╔╝██║
|
||||
██║ ╚═╝ ██║██║ ██║██║ ██║ ╚████╔╝ ███████╗███████╗ ╚████╔╝ ███████║██╗ ╚██████╗██║ ██║██║ ╚██████╗╚██████╔╝██║ ╚═╝ ██║
|
||||
╚═╝ ╚═╝╚═╝ ╚═╝╚═╝ ╚═╝ ╚═══╝ ╚══════╝╚══════╝ ╚═══╝ ╚══════╝╚═╝ ╚═════╝╚═╝ ╚═╝╚═╝ ╚═════╝ ╚═════╝ ╚═╝ ╚═╝
|
||||
|
||||
|
||||
10
agentm/data/game_data/samsh5sp.yaml
Normal file
10
agentm/data/game_data/samsh5sp.yaml
Normal file
@ -0,0 +1,10 @@
|
||||
|
||||
|
||||
███████╗ █████╗ ███╗ ███╗██╗ ██╗██████╗ █████╗ ██╗ ███████╗██╗ ██╗ ██████╗ ██████╗ ██████╗ ██╗ ██╗███╗ ██╗ ██╗ ██╗
|
||||
██╔════╝██╔══██╗████╗ ████║██║ ██║██╔══██╗██╔══██╗██║ ██╔════╝██║ ██║██╔═══██╗██╔══██╗██╔═══██╗██║ ██║████╗ ██║ ██║ ██║
|
||||
███████╗███████║██╔████╔██║██║ ██║██████╔╝███████║██║ ███████╗███████║██║ ██║██║ ██║██║ ██║██║ █╗ ██║██╔██╗ ██║ ██║ ██║
|
||||
╚════██║██╔══██║██║╚██╔╝██║██║ ██║██╔══██╗██╔══██║██║ ╚════██║██╔══██║██║ ██║██║ ██║██║ ██║██║███╗██║██║╚██╗██║ ╚██╗ ██╔╝
|
||||
███████║██║ ██║██║ ╚═╝ ██║╚██████╔╝██║ ██║██║ ██║██║ ███████║██║ ██║╚██████╔╝██████╔╝╚██████╔╝╚███╔███╔╝██║ ╚████║ ╚████╔╝
|
||||
╚══════╝╚═╝ ╚═╝╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═╝╚═╝ ╚═╝╚═╝ ╚══════╝╚═╝ ╚═╝ ╚═════╝ ╚═════╝ ╚═════╝ ╚══╝╚══╝ ╚═╝ ╚═══╝ ╚═══╝
|
||||
|
||||
|
||||
88
agentm/data/game_data/sfiii3n.yaml
Normal file
88
agentm/data/game_data/sfiii3n.yaml
Normal file
@ -0,0 +1,88 @@
|
||||
character_mode: single
|
||||
|
||||
character_index:
|
||||
0: Alex
|
||||
1: Twelve
|
||||
2: Hugo
|
||||
3: Sean
|
||||
4: Makoto
|
||||
5: Elena
|
||||
6: Ibuki
|
||||
7: Chun-Li
|
||||
8: Dudley
|
||||
9: Necro
|
||||
10: Q
|
||||
11: Oro
|
||||
12: Urien
|
||||
13: Remy
|
||||
14: Ryu
|
||||
15: Gouki
|
||||
16: Yun
|
||||
17: Yang
|
||||
18: Ken
|
||||
19: Gill
|
||||
|
||||
stun_bar_lengths:
|
||||
Alex: 72
|
||||
Dudley: 72
|
||||
Elena: 64
|
||||
Gill: 72
|
||||
Gouki: 52
|
||||
Hugo: 72
|
||||
Ibuki: 64
|
||||
Ken: 64
|
||||
Makoto: 64
|
||||
Necro: 64
|
||||
Oro: 72
|
||||
Q: 72
|
||||
Remy: 52
|
||||
Ryu: 64
|
||||
Sean: 64
|
||||
Twelve: 64
|
||||
Urien: 64
|
||||
Yang: 64
|
||||
Yun: 64
|
||||
Chun-Li: 64
|
||||
|
||||
stamina_ratings:
|
||||
Alex: 1200
|
||||
Chun-Li: 1100
|
||||
Dudley: 1150
|
||||
Elena: 1100
|
||||
Gill: 1250
|
||||
Gouki: 950
|
||||
Hugo: 1300
|
||||
Ibuki: 1000
|
||||
Ken: 1150
|
||||
Makoto: 1150
|
||||
Necro: 1070
|
||||
Oro: 1100
|
||||
Q: 1200
|
||||
Remy: 1070
|
||||
Ryu: 1150
|
||||
Sean: 1120
|
||||
Twelve: 1050
|
||||
Urien: 1220
|
||||
Yang: 1020
|
||||
Yun: 1020
|
||||
|
||||
wrappers_settings:
|
||||
filter_keys:
|
||||
- own_health
|
||||
- opp_health
|
||||
- own_side
|
||||
- opp_side
|
||||
- own_character
|
||||
- opp_character
|
||||
- own_stun_bar
|
||||
- opp_stun_bar
|
||||
- own_stunned
|
||||
- opp_stunned
|
||||
- own_super_bar
|
||||
- opp_super_bar
|
||||
- own_super_type
|
||||
- opp_super_type
|
||||
- own_super_count
|
||||
- opp_super_count
|
||||
- own_super_max_count
|
||||
- opp_super_max_count
|
||||
10
agentm/data/game_data/soulclbr.yaml
Normal file
10
agentm/data/game_data/soulclbr.yaml
Normal file
@ -0,0 +1,10 @@
|
||||
|
||||
|
||||
███████╗ ██████╗ ██╗ ██╗██╗ ██████╗ █████╗ ██╗ ██╗██████╗ ██╗ ██╗██████╗
|
||||
██╔════╝██╔═══██╗██║ ██║██║ ██╔════╝██╔══██╗██║ ██║██╔══██╗██║ ██║██╔══██╗
|
||||
███████╗██║ ██║██║ ██║██║ ██║ ███████║██║ ██║██████╔╝██║ ██║██████╔╝
|
||||
╚════██║██║ ██║██║ ██║██║ ██║ ██╔══██║██║ ██║██╔══██╗██║ ██║██╔══██╗
|
||||
███████║╚██████╔╝╚██████╔╝███████╗ ╚██████╗██║ ██║███████╗██║██████╔╝╚██████╔╝██║ ██║
|
||||
╚══════╝ ╚═════╝ ╚═════╝ ╚══════╝ ╚═════╝╚═╝ ╚═╝╚══════╝╚═╝╚═════╝ ╚═════╝ ╚═╝ ╚═╝
|
||||
|
||||
|
||||
10
agentm/data/game_data/tektagt.yaml
Normal file
10
agentm/data/game_data/tektagt.yaml
Normal file
@ -0,0 +1,10 @@
|
||||
|
||||
|
||||
████████╗███████╗██╗ ██╗██╗ ██╗███████╗███╗ ██╗ ████████╗ █████╗ ██████╗ ████████╗ ██████╗ ██╗ ██╗██████╗ ███╗ ██╗ █████╗ ███╗ ███╗███████╗███╗ ██╗████████╗
|
||||
╚══██╔══╝██╔════╝██║ ██╔╝██║ ██╔╝██╔════╝████╗ ██║ ╚══██╔══╝██╔══██╗██╔════╝ ╚══██╔══╝██╔═══██╗██║ ██║██╔══██╗████╗ ██║██╔══██╗████╗ ████║██╔════╝████╗ ██║╚══██╔══╝
|
||||
██║ █████╗ █████╔╝ █████╔╝ █████╗ ██╔██╗ ██║ ██║ ███████║██║ ███╗ ██║ ██║ ██║██║ ██║██████╔╝██╔██╗ ██║███████║██╔████╔██║█████╗ ██╔██╗ ██║ ██║
|
||||
██║ ██╔══╝ ██╔═██╗ ██╔═██╗ ██╔══╝ ██║╚██╗██║ ██║ ██╔══██║██║ ██║ ██║ ██║ ██║██║ ██║██╔══██╗██║╚██╗██║██╔══██║██║╚██╔╝██║██╔══╝ ██║╚██╗██║ ██║
|
||||
██║ ███████╗██║ ██╗██║ ██╗███████╗██║ ╚████║ ██║ ██║ ██║╚██████╔╝ ██║ ╚██████╔╝╚██████╔╝██║ ██║██║ ╚████║██║ ██║██║ ╚═╝ ██║███████╗██║ ╚████║ ██║
|
||||
╚═╝ ╚══════╝╚═╝ ╚═╝╚═╝ ╚═╝╚══════╝╚═╝ ╚═══╝ ╚═╝ ╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═════╝ ╚═════╝ ╚═╝ ╚═╝╚═╝ ╚═══╝╚═╝ ╚═╝╚═╝ ╚═╝╚══════╝╚═╝ ╚═══╝ ╚═╝
|
||||
|
||||
|
||||
10
agentm/data/game_data/umk3.yaml
Normal file
10
agentm/data/game_data/umk3.yaml
Normal file
@ -0,0 +1,10 @@
|
||||
|
||||
|
||||
███╗ ███╗ ██████╗ ██████╗ ████████╗ █████╗ ██╗ ██╗ ██╗ ██████╗ ███╗ ███╗██████╗ █████╗ ████████╗ ██████╗
|
||||
████╗ ████║██╔═══██╗██╔══██╗╚══██╔══╝██╔══██╗██║ ██║ ██╔╝██╔═══██╗████╗ ████║██╔══██╗██╔══██╗╚══██╔══╝ ╚════██╗
|
||||
██╔████╔██║██║ ██║██████╔╝ ██║ ███████║██║ █████╔╝ ██║ ██║██╔████╔██║██████╔╝███████║ ██║ █████╔╝
|
||||
██║╚██╔╝██║██║ ██║██╔══██╗ ██║ ██╔══██║██║ ██╔═██╗ ██║ ██║██║╚██╔╝██║██╔══██╗██╔══██║ ██║ ╚═══██╗
|
||||
██║ ╚═╝ ██║╚██████╔╝██║ ██║ ██║ ██║ ██║███████╗ ██║ ██╗╚██████╔╝██║ ╚═╝ ██║██████╔╝██║ ██║ ██║ ██████╔╝
|
||||
╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═╝ ╚═╝ ╚═╝ ╚═╝╚══════╝ ╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═╝╚═════╝ ╚═╝ ╚═╝ ╚═╝ ╚═════╝
|
||||
|
||||
|
||||
10
agentm/data/game_data/xmvsf.yaml
Normal file
10
agentm/data/game_data/xmvsf.yaml
Normal file
@ -0,0 +1,10 @@
|
||||
|
||||
|
||||
██╗ ██╗ ███╗ ███╗███████╗███╗ ██╗ ██╗ ██╗███████╗ ███████╗████████╗██████╗ ███████╗███████╗████████╗ ███████╗██╗ ██████╗ ██╗ ██╗████████╗███████╗██████╗
|
||||
╚██╗██╔╝ ████╗ ████║██╔════╝████╗ ██║ ██║ ██║██╔════╝ ██╔════╝╚══██╔══╝██╔══██╗██╔════╝██╔════╝╚══██╔══╝ ██╔════╝██║██╔════╝ ██║ ██║╚══██╔══╝██╔════╝██╔══██╗
|
||||
╚███╔╝█████╗██╔████╔██║█████╗ ██╔██╗ ██║ ██║ ██║███████╗ ███████╗ ██║ ██████╔╝█████╗ █████╗ ██║ █████╗ ██║██║ ███╗███████║ ██║ █████╗ ██████╔╝
|
||||
██╔██╗╚════╝██║╚██╔╝██║██╔══╝ ██║╚██╗██║ ╚██╗ ██╔╝╚════██║ ╚════██║ ██║ ██╔══██╗██╔══╝ ██╔══╝ ██║ ██╔══╝ ██║██║ ██║██╔══██║ ██║ ██╔══╝ ██╔══██╗
|
||||
██╔╝ ██╗ ██║ ╚═╝ ██║███████╗██║ ╚████║ ╚████╔╝ ███████║ ███████║ ██║ ██║ ██║███████╗███████╗ ██║ ██║ ██║╚██████╔╝██║ ██║ ██║ ███████╗██║ ██║
|
||||
╚═╝ ╚═╝ ╚═╝ ╚═╝╚══════╝╚═╝ ╚═══╝ ╚═══╝ ╚══════╝ ╚══════╝ ╚═╝ ╚═╝ ╚═╝╚══════╝╚══════╝ ╚═╝ ╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═╝ ╚═╝ ╚══════╝╚═╝ ╚═╝
|
||||
|
||||
|
||||
82
agentm/logic/config_manager.py
Normal file
82
agentm/logic/config_manager.py
Normal file
@ -0,0 +1,82 @@
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class ConfigManager:
|
||||
def __init__(self, game_id: str, model_name: str = "default_model"):
|
||||
self.game_id = game_id
|
||||
self.model_name = model_name
|
||||
self.config_base_path = Path("agentm/data/game_data")
|
||||
|
||||
self.general = self._load_yaml("general.yaml")
|
||||
self.game_specific = self._load_yaml(f"{self.game_id}.yaml")
|
||||
|
||||
self.config = self._build_merged_config()
|
||||
|
||||
def _load_yaml(self, filename: str) -> dict:
|
||||
path = self.config_base_path / filename
|
||||
if not path.exists():
|
||||
return {}
|
||||
with open(path, "r") as f:
|
||||
return yaml.safe_load(f)
|
||||
|
||||
def _deep_merge(self, source: dict, overrides: dict) -> dict:
|
||||
for key, value in overrides.items():
|
||||
if key in source:
|
||||
if isinstance(source[key], dict) and isinstance(value, dict):
|
||||
self._deep_merge(source[key], value)
|
||||
elif isinstance(source[key], list) and isinstance(value, list):
|
||||
source[key] = list(dict.fromkeys(source[key] + value))
|
||||
else:
|
||||
source[key] = value
|
||||
else:
|
||||
source[key] = value
|
||||
return source
|
||||
|
||||
def _build_merged_config(self) -> dict:
|
||||
folders = self.general.get("folders", {}).copy()
|
||||
folders["model_name"] = self.model_name
|
||||
|
||||
settings = self.general.get("settings", {}).copy()
|
||||
settings["game_id"] = self.game_id
|
||||
settings = self._deep_merge(settings, self.game_specific.get("settings", {}))
|
||||
if "frame_shape" in settings and isinstance(settings["frame_shape"], list):
|
||||
settings["frame_shape"] = tuple(settings["frame_shape"])
|
||||
|
||||
wrappers_settings = self._deep_merge(
|
||||
self.general.get("wrappers_settings", {}).copy(),
|
||||
self.game_specific.get("wrappers_settings", {})
|
||||
)
|
||||
|
||||
policy_kwargs = self._deep_merge(
|
||||
self.general.get("policy_kwargs", {}).copy(),
|
||||
self.game_specific.get("policy_kwargs", {})
|
||||
)
|
||||
|
||||
ppo_settings = self._deep_merge(
|
||||
self.general.get("ppo_settings", {}).copy(),
|
||||
self.game_specific.get("ppo_settings", {})
|
||||
)
|
||||
|
||||
return {
|
||||
"folders": folders,
|
||||
"settings": settings,
|
||||
"wrappers_settings": wrappers_settings,
|
||||
"policy_kwargs": policy_kwargs,
|
||||
"ppo_settings": ppo_settings,
|
||||
}
|
||||
|
||||
def get(self) -> dict:
|
||||
"""Return the full merged config dictionary."""
|
||||
return self.config
|
||||
|
||||
def dump_to_file(self, out_path: Path):
|
||||
"""Write the config to a YAML file."""
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(out_path, "w") as f:
|
||||
yaml.dump(self.config, f, sort_keys=False)
|
||||
|
||||
def dump_to_string(self) -> str:
|
||||
"""Return the YAML string representation of the config."""
|
||||
return yaml.dump(self.config, sort_keys=False)
|
||||
@ -494,3 +494,19 @@ def update_run(
|
||||
f"UPDATE runs SET {', '.join(fields)} WHERE id = ?",
|
||||
tuple(values)
|
||||
)
|
||||
|
||||
def update_run_pending(run_id: int, pending: bool) -> None:
|
||||
"""
|
||||
Update the pending status of a run.
|
||||
|
||||
Args:
|
||||
run_id: The ID of the run to update.
|
||||
pending: New pending status (True/False).
|
||||
"""
|
||||
with get_db_conn() as conn:
|
||||
conn.execute("""
|
||||
UPDATE runs
|
||||
SET pending = ?, updated_at = ?
|
||||
WHERE id = ?
|
||||
""", (int(pending), datetime.utcnow().isoformat(), run_id))
|
||||
|
||||
@ -322,4 +322,42 @@ Button.game_button {
|
||||
min-height: 3;
|
||||
margin: 1 0;
|
||||
align: center middle;
|
||||
}
|
||||
}
|
||||
|
||||
#config_scroll_container {
|
||||
height: 70vh;
|
||||
overflow-y: auto;
|
||||
padding: 1;
|
||||
border-top: solid {{BORDER}};
|
||||
}
|
||||
|
||||
.top_aligned_layout {
|
||||
layout: vertical;
|
||||
align-horizontal: left;
|
||||
width: 100%;
|
||||
padding: 0 2;
|
||||
}
|
||||
|
||||
.form_column {
|
||||
layout: vertical;
|
||||
width: 100%;
|
||||
padding: 2;
|
||||
height: auto;
|
||||
}
|
||||
|
||||
Input, Select {
|
||||
width: 100%;
|
||||
min-width: 40;
|
||||
}
|
||||
|
||||
.section_label {
|
||||
text-style: bold;
|
||||
padding-top: 1;
|
||||
padding-bottom: 0;
|
||||
}
|
||||
|
||||
.input_row {
|
||||
layout: horizontal;
|
||||
align-vertical: middle;
|
||||
padding: 0 1;
|
||||
}
|
||||
|
||||
@ -322,4 +322,42 @@ Button.game_button {
|
||||
min-height: 3;
|
||||
margin: 1 0;
|
||||
align: center middle;
|
||||
}
|
||||
}
|
||||
|
||||
#config_scroll_container {
|
||||
height: 70vh;
|
||||
overflow-y: auto;
|
||||
padding: 1;
|
||||
border-top: solid #3a9bed;
|
||||
}
|
||||
|
||||
.top_aligned_layout {
|
||||
layout: vertical;
|
||||
align-horizontal: left;
|
||||
width: 100%;
|
||||
padding: 0 2;
|
||||
}
|
||||
|
||||
.form_column {
|
||||
layout: vertical;
|
||||
width: 100%;
|
||||
padding: 2;
|
||||
height: auto;
|
||||
}
|
||||
|
||||
Input, Select {
|
||||
width: 100%;
|
||||
min-width: 40;
|
||||
}
|
||||
|
||||
.section_label {
|
||||
text-style: bold;
|
||||
padding-top: 1;
|
||||
padding-bottom: 0;
|
||||
}
|
||||
|
||||
.input_row {
|
||||
layout: horizontal;
|
||||
align-vertical: middle;
|
||||
padding: 0 1;
|
||||
}
|
||||
|
||||
@ -62,6 +62,7 @@ class AgentHomeView(Screen):
|
||||
await self.app.push_screen(
|
||||
ModelSelectView(
|
||||
agent_metadata=self.agent_metadata,
|
||||
mode=action
|
||||
mode=action,
|
||||
run=self.run_metadata
|
||||
)
|
||||
)
|
||||
)
|
||||
@ -42,11 +42,11 @@ class ModelSelectView(Screen):
|
||||
("r", "refresh_models", "Refresh"),
|
||||
]
|
||||
|
||||
def __init__(self, agent_metadata: dict, mode: Literal["train", "eval", "submit"]):
|
||||
def __init__(self, agent_metadata: dict, mode: Literal["train", "eval", "submit"], run: dict):
|
||||
super().__init__()
|
||||
self.agent_metadata = agent_metadata
|
||||
self.mode = mode
|
||||
self.selected_model = None
|
||||
self.run_metadata = run
|
||||
|
||||
def compose(self):
|
||||
header_path = Path(__file__).parent.parent / "assets" / "headers" / "select_model.txt"
|
||||
@ -132,7 +132,7 @@ class ModelSelectView(Screen):
|
||||
log_with_caller("info", f"Model confirmed: {self.selected_model['name']} for mode={self.mode}")
|
||||
match self.mode:
|
||||
case "train":
|
||||
await self.app.push_screen(TrainingView(agent=self.agent_metadata, model=self.selected_model))
|
||||
await self.app.push_screen(TrainingView(agent=self.agent_metadata, model=None, run=self.run_metadata))
|
||||
case "eval":
|
||||
await self.app.push_screen(EvaluationView(agent=self.agent_metadata, model=self.selected_model))
|
||||
case "submit":
|
||||
@ -140,7 +140,10 @@ class ModelSelectView(Screen):
|
||||
|
||||
elif event.button.id == "create_model_btn":
|
||||
log_with_caller("info", f"Creating new model for agent: {self.agent_metadata['name']}")
|
||||
await self.app.push_screen(TrainingView(agent=self.agent_metadata, model=None)) # New model starts from training
|
||||
await self.app.push_screen(
|
||||
TrainingView(agent=self.agent_metadata, model=None, run=self.run_metadata)
|
||||
)
|
||||
|
||||
|
||||
async def action_refresh_models(self):
|
||||
await self.refresh_model_list()
|
||||
|
||||
@ -1,106 +1,152 @@
|
||||
from textual.screen import Screen
|
||||
from textual.widgets import Static, Input, Button
|
||||
from textual.containers import Vertical, Horizontal
|
||||
from agentm.theme.palette import get_theme
|
||||
from textual.widgets import Static, Input, Button, Checkbox, Label, Select
|
||||
from textual.containers import Vertical, VerticalScroll, Horizontal, Grid
|
||||
from agentm.components.footer import AgentMFooter
|
||||
from agentm.utils.logger import log_with_caller
|
||||
from agentm.logic.db_functions import insert_model, update_run_pending
|
||||
from datetime import datetime
|
||||
from agentm.theme.palette import get_theme
|
||||
from pathlib import Path
|
||||
|
||||
palette = get_theme()
|
||||
HEADER_PATH = Path("agentm/assets/headers/training_setup.txt")
|
||||
|
||||
class TrainingView(Screen):
|
||||
BINDINGS = [
|
||||
("escape", "app.pop_screen", "Back")
|
||||
]
|
||||
BINDINGS = [("escape", "app.pop_screen", "Back")]
|
||||
|
||||
def __init__(self, agent: dict, model: dict | None, run: dict):
|
||||
super().__init__()
|
||||
self.agent_metadata = agent
|
||||
self.model_metadata = model
|
||||
self.run_metadata = run
|
||||
self.is_pending_run = run.get("pending", True)
|
||||
|
||||
log_with_caller("info", f"Initialized TrainingView for agent: {agent['name']}")
|
||||
|
||||
self.model_name_input = Input(id="model_name")
|
||||
self.num_env_input = Input(id="num_env")
|
||||
|
||||
self.framework_dropdown = Select(
|
||||
id="framework",
|
||||
options=[("SB3", "SB3"), ("RAY", "RAY"), ("SHEEPRL", "SHEEPRL")],
|
||||
value="SB3"
|
||||
)
|
||||
self.algo_dropdown = Select(id="algo", options=[("PPO", "PPO")], value="PPO")
|
||||
|
||||
self.settings_inputs = {
|
||||
"step_ratio": Input(value="6", id="step_ratio"),
|
||||
"frame_shape": Input(value="[128, 128, 1]", id="frame_shape"),
|
||||
"continue_game": Input(value="0.0", id="continue_game"),
|
||||
"action_space": Input(value="discrete", id="action_space")
|
||||
}
|
||||
|
||||
self.wrapper_checkboxes = {
|
||||
"normalize_reward": Checkbox(label="Normalize Reward", value=True, id="normalize_reward"),
|
||||
"no_attack_buttons_combinations": Checkbox(label="No Attack Combos", value=True, id="no_attack_buttons_combinations"),
|
||||
"add_last_action": Checkbox(label="Add Last Action", value=True, id="add_last_action"),
|
||||
"scale": Checkbox(label="Scale", value=True, id="scale"),
|
||||
"exclude_image_scaling": Checkbox(label="Exclude Image Scaling", value=True, id="exclude_image_scaling"),
|
||||
"role_relative": Checkbox(label="Role Relative", value=True, id="role_relative"),
|
||||
"flatten": Checkbox(label="Flatten", value=True, id="flatten")
|
||||
}
|
||||
|
||||
self.wrapper_inputs = {
|
||||
"stack_frames": Input(value="4", id="stack_frames"),
|
||||
"stack_actions": Input(value="12", id="stack_actions"),
|
||||
"dilation": Input(value="1", id="dilation")
|
||||
}
|
||||
|
||||
self.filter_key_checkboxes = [
|
||||
Checkbox(label=key, value=True, id=f"fk_{key}") for key in ["frame", "action", "stage", "timer"]
|
||||
]
|
||||
|
||||
self.net_arch_input = Input(value="[64, 64]", id="net_arch")
|
||||
|
||||
self.ppo_inputs = {
|
||||
"gamma": Input(value="0.94", id="gamma"),
|
||||
"model_checkpoint": Input(value="0", id="model_checkpoint"),
|
||||
"learning_rate_start": Input(value="0.00025", id="lr_start"),
|
||||
"learning_rate_end": Input(value="0.0000025", id="lr_end"),
|
||||
"clip_range_start": Input(value="0.15", id="clip_start"),
|
||||
"clip_range_end": Input(value="0.025", id="clip_end"),
|
||||
"batch_size": Input(value="256", id="batch_size"),
|
||||
"n_epochs": Input(value="4", id="n_epochs"),
|
||||
"n_steps": Input(value="128", id="n_steps"),
|
||||
"autosave_freq": Input(value="256", id="autosave_freq"),
|
||||
"time_steps": Input(value="512", id="time_steps")
|
||||
}
|
||||
|
||||
def compose(self):
|
||||
log_with_caller("info", f"Opening TrainingView for agent='{self.agent_metadata['name']}', run='{self.run_metadata['name']}', pending={self.is_pending_run}")
|
||||
log_with_caller("debug", "Composing TrainingView layout...")
|
||||
|
||||
yield Static(f"[{palette.ACCENT} bold]Training: {self.agent_metadata['name']}[/]", classes="header")
|
||||
try:
|
||||
header_text = HEADER_PATH.read_text()
|
||||
except FileNotFoundError:
|
||||
log_with_caller("error", f"Missing header file at: {HEADER_PATH}")
|
||||
header_text = "[bold red]Missing Header File"
|
||||
|
||||
if self.is_pending_run:
|
||||
yield self.render_pending_setup()
|
||||
else:
|
||||
yield self.render_resume_options()
|
||||
yield Static(f"[{palette.ACCENT}]{header_text}[/{palette.ACCENT}]", classes="header")
|
||||
|
||||
def render_pending_setup(self):
|
||||
yield Static("[b]Initial Model Setup[/b]", classes="subheader")
|
||||
# Build scrollable form content
|
||||
yield VerticalScroll(
|
||||
Vertical(
|
||||
|
||||
self.name_input = Input(placeholder="Model Name", id="model_name")
|
||||
self.steps_input = Input(placeholder="Total Training Steps", id="total_steps")
|
||||
self.lr_input = Input(placeholder="Initial Learning Rate", id="learning_rate")
|
||||
self.clip_input = Input(placeholder="Initial Clip Range", id="clip_range")
|
||||
self.notes_input = Input(placeholder="Notes", id="notes")
|
||||
# Framework & Algorithm
|
||||
Static("Framework & Algorithm", classes="section_label"),
|
||||
Horizontal(
|
||||
Vertical(Label("Framework"), self.framework_dropdown),
|
||||
Vertical(Label("Algorithm"), self.algo_dropdown),
|
||||
classes="input_row"
|
||||
),
|
||||
|
||||
self.confirm_button = Button("✅ Create & Start Training", id="confirm_model_btn", classes="confirm_button")
|
||||
# Model name / num_env
|
||||
Static("Model Configuration", classes="section_label"),
|
||||
Horizontal(
|
||||
Vertical(Label("Model Name"), self.model_name_input),
|
||||
Vertical(Label("Num Envs"), self.num_env_input),
|
||||
classes="input_row"
|
||||
),
|
||||
|
||||
return Vertical(
|
||||
self.name_input,
|
||||
self.steps_input,
|
||||
self.lr_input,
|
||||
self.clip_input,
|
||||
self.notes_input,
|
||||
self.confirm_button,
|
||||
classes="centered_layout"
|
||||
# General Settings
|
||||
Static("General Settings", classes="section_label"),
|
||||
*[
|
||||
Horizontal(Vertical(Label(k)), v, classes="input_row")
|
||||
for k, v in self.settings_inputs.items()
|
||||
],
|
||||
|
||||
# Wrapper Checkboxes
|
||||
Static("Wrapper Flags", classes="section_label"),
|
||||
*[
|
||||
Horizontal(cb, classes="input_row")
|
||||
for cb in self.wrapper_checkboxes.values()
|
||||
],
|
||||
|
||||
# Wrapper Inputs
|
||||
*[
|
||||
Horizontal(Vertical(Label(k.replace("_", " ").title())), v, classes="input_row")
|
||||
for k, v in self.wrapper_inputs.items()
|
||||
],
|
||||
|
||||
# Filter Keys
|
||||
Static("Filter Keys", classes="section_label"),
|
||||
Grid(*self.filter_key_checkboxes, id="filter_grid"),
|
||||
|
||||
# Policy kwargs
|
||||
Static("Policy Architecture", classes="section_label"),
|
||||
Horizontal(Label("Net Arch"), self.net_arch_input, classes="input_row"),
|
||||
|
||||
# PPO Settings
|
||||
Static("PPO Settings", classes="section_label"),
|
||||
*[
|
||||
Horizontal(Label(k.replace("_", " ").title()), v, classes="input_row")
|
||||
for k, v in self.ppo_inputs.items()
|
||||
],
|
||||
|
||||
id="form_layout",
|
||||
classes="form_column"
|
||||
),
|
||||
id="config_scroll_container"
|
||||
)
|
||||
|
||||
def render_resume_options(self):
|
||||
model = self.model_metadata
|
||||
from rich.table import Table
|
||||
from rich.panel import Panel
|
||||
# Final button and footer
|
||||
yield Button("🔍 Review Config", id="review_config_btn", classes="confirm_button")
|
||||
yield AgentMFooter(compact=True)
|
||||
|
||||
table = Table.grid(padding=(0, 1))
|
||||
table.add_column("Key", style="bold underline")
|
||||
table.add_column("Value", style=palette.ACCENT, overflow="fold")
|
||||
|
||||
table.add_row("Name", model["name"])
|
||||
table.add_row("Steps", f"{model['total_steps_completed']} / {model['total_steps_planned']}")
|
||||
table.add_row("Reward", str(model.get("average_reward", "—")))
|
||||
table.add_row("Learning Rate", str(model.get("current_learning_rate", "—")))
|
||||
table.add_row("Clip Range", str(model.get("current_clip_range", "—")))
|
||||
table.add_row("Created", model.get("created_at", "—"))
|
||||
|
||||
return Vertical(
|
||||
Static(Panel(table, title="Model Info", border_style=palette.BORDER), classes="agent_info_box"),
|
||||
Button("⏯️ Resume Training", id="resume_training_btn", classes="confirm_button"),
|
||||
classes="centered_layout"
|
||||
)
|
||||
|
||||
async def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
if event.button.id == "confirm_model_btn":
|
||||
try:
|
||||
name = self.name_input.value.strip()
|
||||
total_steps = int(self.steps_input.value)
|
||||
learning_rate = float(self.lr_input.value)
|
||||
clip_range = float(self.clip_input.value)
|
||||
notes = self.notes_input.value.strip()
|
||||
|
||||
log_with_caller("info", f"Creating new model: {name} for agent_id={self.agent_metadata['id']}")
|
||||
|
||||
insert_model(
|
||||
agent_id=self.agent_metadata["id"],
|
||||
name=name,
|
||||
total_steps_planned=total_steps,
|
||||
current_learning_rate=learning_rate,
|
||||
current_clip_range=clip_range,
|
||||
notes=notes
|
||||
)
|
||||
|
||||
update_run_pending(self.run_metadata["id"], False)
|
||||
log_with_caller("info", f"Model '{name}' created and run marked as not pending")
|
||||
|
||||
await self.app.pop_screen()
|
||||
|
||||
except Exception as e:
|
||||
log_with_caller("error", f"Failed to create model: {e}")
|
||||
|
||||
elif event.button.id == "resume_training_btn":
|
||||
log_with_caller("info", f"Resuming training for model: {self.model_metadata['name']}")
|
||||
await self.app.pop_screen()
|
||||
log_with_caller("debug", "Finished composing TrainingView")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user