styling on training added.
This commit is contained in:
parent
4500bfd388
commit
fb6e2fcd65
10
agentm/assets/headers/training_setup.txt
Normal file
10
agentm/assets/headers/training_setup.txt
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
|
||||||
|
|
||||||
|
████████╗██████╗ █████╗ ██╗███╗ ██╗██╗███╗ ██╗ ██████╗ ███████╗███████╗████████╗██╗ ██╗██████╗
|
||||||
|
╚══██╔══╝██╔══██╗██╔══██╗██║████╗ ██║██║████╗ ██║██╔════╝ ██╔════╝██╔════╝╚══██╔══╝██║ ██║██╔══██╗
|
||||||
|
██║ ██████╔╝███████║██║██╔██╗ ██║██║██╔██╗ ██║██║ ███╗ ███████╗█████╗ ██║ ██║ ██║██████╔╝
|
||||||
|
██║ ██╔══██╗██╔══██║██║██║╚██╗██║██║██║╚██╗██║██║ ██║ ╚════██║██╔══╝ ██║ ██║ ██║██╔═══╝
|
||||||
|
██║ ██║ ██║██║ ██║██║██║ ╚████║██║██║ ╚████║╚██████╔╝ ███████║███████╗ ██║ ╚██████╔╝██║
|
||||||
|
╚═╝ ╚═╝ ╚═╝╚═╝ ╚═╝╚═╝╚═╝ ╚═══╝╚═╝╚═╝ ╚═══╝ ╚═════╝ ╚══════╝╚══════╝ ╚═╝ ╚═════╝ ╚═╝
|
||||||
|
|
||||||
|
|
||||||
10
agentm/data/game_data/doapp.yaml
Normal file
10
agentm/data/game_data/doapp.yaml
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
|
||||||
|
|
||||||
|
██████╗ ███████╗ █████╗ ██████╗ ██████╗ ██████╗ █████╗ ██╗ ██╗██╗ ██╗███████╗
|
||||||
|
██╔══██╗██╔════╝██╔══██╗██╔══██╗ ██╔═══██╗██╔══██╗ ██╔══██╗██║ ██║██║ ██║██╔════╝
|
||||||
|
██║ ██║█████╗ ███████║██║ ██║ ██║ ██║██████╔╝ ███████║██║ ██║██║ ██║█████╗
|
||||||
|
██║ ██║██╔══╝ ██╔══██║██║ ██║ ██║ ██║██╔══██╗ ██╔══██║██║ ██║╚██╗ ██╔╝██╔══╝
|
||||||
|
██████╔╝███████╗██║ ██║██████╔╝ ╚██████╔╝██║ ██║ ██║ ██║███████╗██║ ╚████╔╝ ███████╗
|
||||||
|
╚═════╝ ╚══════╝╚═╝ ╚═╝╚═════╝ ╚═════╝ ╚═╝ ╚═╝ ╚═╝ ╚═╝╚══════╝╚═╝ ╚═══╝ ╚══════╝
|
||||||
|
|
||||||
|
|
||||||
41
agentm/data/game_data/general.yaml
Normal file
41
agentm/data/game_data/general.yaml
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
folders:
|
||||||
|
parent_dir: "./results/"
|
||||||
|
model_name: "default_model"
|
||||||
|
|
||||||
|
settings:
|
||||||
|
step_ratio: 6
|
||||||
|
frame_shape: [128, 128, 1]
|
||||||
|
continue_game: 0.0
|
||||||
|
action_space: "discrete"
|
||||||
|
|
||||||
|
wrappers_settings:
|
||||||
|
wrappers:
|
||||||
|
normalize_reward: true
|
||||||
|
no_attack_buttons_combinations: true
|
||||||
|
stack_frames: 4
|
||||||
|
dilation: 1
|
||||||
|
add_last_action: true
|
||||||
|
stack_actions: 12
|
||||||
|
scale: true
|
||||||
|
exclude_image_scaling: true
|
||||||
|
role_relative: true
|
||||||
|
flatten: true
|
||||||
|
filter_keys:
|
||||||
|
- frame
|
||||||
|
- action
|
||||||
|
- stage
|
||||||
|
- timer
|
||||||
|
|
||||||
|
policy_kwargs:
|
||||||
|
net_arch: [64, 64]
|
||||||
|
|
||||||
|
ppo_settings:
|
||||||
|
gamma: 0.94
|
||||||
|
model_checkpoint: "0"
|
||||||
|
learning_rate: [2.5e-4, 2.5e-6]
|
||||||
|
clip_range: [0.15, 0.025]
|
||||||
|
batch_size: 256
|
||||||
|
n_epochs: 4
|
||||||
|
n_steps: 128
|
||||||
|
autosave_freq: 256
|
||||||
|
time_steps: 512
|
||||||
10
agentm/data/game_data/kof98umh.yaml
Normal file
10
agentm/data/game_data/kof98umh.yaml
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
|
||||||
|
|
||||||
|
████████╗██╗ ██╗███████╗ ██╗ ██╗██╗███╗ ██╗ ██████╗ ██████╗ ███████╗ ███████╗██╗ ██████╗ ██╗ ██╗████████╗███████╗██████╗ ███████╗
|
||||||
|
╚══██╔══╝██║ ██║██╔════╝ ██║ ██╔╝██║████╗ ██║██╔════╝ ██╔═══██╗██╔════╝ ██╔════╝██║██╔════╝ ██║ ██║╚══██╔══╝██╔════╝██╔══██╗██╔════╝
|
||||||
|
██║ ███████║█████╗ █████╔╝ ██║██╔██╗ ██║██║ ███╗ ██║ ██║█████╗ █████╗ ██║██║ ███╗███████║ ██║ █████╗ ██████╔╝███████╗
|
||||||
|
██║ ██╔══██║██╔══╝ ██╔═██╗ ██║██║╚██╗██║██║ ██║ ██║ ██║██╔══╝ ██╔══╝ ██║██║ ██║██╔══██║ ██║ ██╔══╝ ██╔══██╗╚════██║
|
||||||
|
██║ ██║ ██║███████╗ ██║ ██╗██║██║ ╚████║╚██████╔╝ ╚██████╔╝██║ ██║ ██║╚██████╔╝██║ ██║ ██║ ███████╗██║ ██║███████║
|
||||||
|
╚═╝ ╚═╝ ╚═╝╚══════╝ ╚═╝ ╚═╝╚═╝╚═╝ ╚═══╝ ╚═════╝ ╚═════╝ ╚═╝ ╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═╝ ╚═╝ ╚══════╝╚═╝ ╚═╝╚══════╝
|
||||||
|
|
||||||
|
|
||||||
10
agentm/data/game_data/mvsc.yaml
Normal file
10
agentm/data/game_data/mvsc.yaml
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
|
||||||
|
|
||||||
|
███╗ ███╗ █████╗ ██████╗ ██╗ ██╗███████╗██╗ ██╗ ██╗███████╗ ██████╗ █████╗ ██████╗ ██████╗ ██████╗ ███╗ ███╗
|
||||||
|
████╗ ████║██╔══██╗██╔══██╗██║ ██║██╔════╝██║ ██║ ██║██╔════╝ ██╔════╝██╔══██╗██╔══██╗██╔════╝██╔═══██╗████╗ ████║
|
||||||
|
██╔████╔██║███████║██████╔╝██║ ██║█████╗ ██║ ██║ ██║███████╗ ██║ ███████║██████╔╝██║ ██║ ██║██╔████╔██║
|
||||||
|
██║╚██╔╝██║██╔══██║██╔══██╗╚██╗ ██╔╝██╔══╝ ██║ ╚██╗ ██╔╝╚════██║ ██║ ██╔══██║██╔═══╝ ██║ ██║ ██║██║╚██╔╝██║
|
||||||
|
██║ ╚═╝ ██║██║ ██║██║ ██║ ╚████╔╝ ███████╗███████╗ ╚████╔╝ ███████║██╗ ╚██████╗██║ ██║██║ ╚██████╗╚██████╔╝██║ ╚═╝ ██║
|
||||||
|
╚═╝ ╚═╝╚═╝ ╚═╝╚═╝ ╚═╝ ╚═══╝ ╚══════╝╚══════╝ ╚═══╝ ╚══════╝╚═╝ ╚═════╝╚═╝ ╚═╝╚═╝ ╚═════╝ ╚═════╝ ╚═╝ ╚═╝
|
||||||
|
|
||||||
|
|
||||||
10
agentm/data/game_data/samsh5sp.yaml
Normal file
10
agentm/data/game_data/samsh5sp.yaml
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
|
||||||
|
|
||||||
|
███████╗ █████╗ ███╗ ███╗██╗ ██╗██████╗ █████╗ ██╗ ███████╗██╗ ██╗ ██████╗ ██████╗ ██████╗ ██╗ ██╗███╗ ██╗ ██╗ ██╗
|
||||||
|
██╔════╝██╔══██╗████╗ ████║██║ ██║██╔══██╗██╔══██╗██║ ██╔════╝██║ ██║██╔═══██╗██╔══██╗██╔═══██╗██║ ██║████╗ ██║ ██║ ██║
|
||||||
|
███████╗███████║██╔████╔██║██║ ██║██████╔╝███████║██║ ███████╗███████║██║ ██║██║ ██║██║ ██║██║ █╗ ██║██╔██╗ ██║ ██║ ██║
|
||||||
|
╚════██║██╔══██║██║╚██╔╝██║██║ ██║██╔══██╗██╔══██║██║ ╚════██║██╔══██║██║ ██║██║ ██║██║ ██║██║███╗██║██║╚██╗██║ ╚██╗ ██╔╝
|
||||||
|
███████║██║ ██║██║ ╚═╝ ██║╚██████╔╝██║ ██║██║ ██║██║ ███████║██║ ██║╚██████╔╝██████╔╝╚██████╔╝╚███╔███╔╝██║ ╚████║ ╚████╔╝
|
||||||
|
╚══════╝╚═╝ ╚═╝╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═╝╚═╝ ╚═╝╚═╝ ╚══════╝╚═╝ ╚═╝ ╚═════╝ ╚═════╝ ╚═════╝ ╚══╝╚══╝ ╚═╝ ╚═══╝ ╚═══╝
|
||||||
|
|
||||||
|
|
||||||
88
agentm/data/game_data/sfiii3n.yaml
Normal file
88
agentm/data/game_data/sfiii3n.yaml
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
character_mode: single
|
||||||
|
|
||||||
|
character_index:
|
||||||
|
0: Alex
|
||||||
|
1: Twelve
|
||||||
|
2: Hugo
|
||||||
|
3: Sean
|
||||||
|
4: Makoto
|
||||||
|
5: Elena
|
||||||
|
6: Ibuki
|
||||||
|
7: Chun-Li
|
||||||
|
8: Dudley
|
||||||
|
9: Necro
|
||||||
|
10: Q
|
||||||
|
11: Oro
|
||||||
|
12: Urien
|
||||||
|
13: Remy
|
||||||
|
14: Ryu
|
||||||
|
15: Gouki
|
||||||
|
16: Yun
|
||||||
|
17: Yang
|
||||||
|
18: Ken
|
||||||
|
19: Gill
|
||||||
|
|
||||||
|
stun_bar_lengths:
|
||||||
|
Alex: 72
|
||||||
|
Dudley: 72
|
||||||
|
Elena: 64
|
||||||
|
Gill: 72
|
||||||
|
Gouki: 52
|
||||||
|
Hugo: 72
|
||||||
|
Ibuki: 64
|
||||||
|
Ken: 64
|
||||||
|
Makoto: 64
|
||||||
|
Necro: 64
|
||||||
|
Oro: 72
|
||||||
|
Q: 72
|
||||||
|
Remy: 52
|
||||||
|
Ryu: 64
|
||||||
|
Sean: 64
|
||||||
|
Twelve: 64
|
||||||
|
Urien: 64
|
||||||
|
Yang: 64
|
||||||
|
Yun: 64
|
||||||
|
Chun-Li: 64
|
||||||
|
|
||||||
|
stamina_ratings:
|
||||||
|
Alex: 1200
|
||||||
|
Chun-Li: 1100
|
||||||
|
Dudley: 1150
|
||||||
|
Elena: 1100
|
||||||
|
Gill: 1250
|
||||||
|
Gouki: 950
|
||||||
|
Hugo: 1300
|
||||||
|
Ibuki: 1000
|
||||||
|
Ken: 1150
|
||||||
|
Makoto: 1150
|
||||||
|
Necro: 1070
|
||||||
|
Oro: 1100
|
||||||
|
Q: 1200
|
||||||
|
Remy: 1070
|
||||||
|
Ryu: 1150
|
||||||
|
Sean: 1120
|
||||||
|
Twelve: 1050
|
||||||
|
Urien: 1220
|
||||||
|
Yang: 1020
|
||||||
|
Yun: 1020
|
||||||
|
|
||||||
|
wrappers_settings:
|
||||||
|
filter_keys:
|
||||||
|
- own_health
|
||||||
|
- opp_health
|
||||||
|
- own_side
|
||||||
|
- opp_side
|
||||||
|
- own_character
|
||||||
|
- opp_character
|
||||||
|
- own_stun_bar
|
||||||
|
- opp_stun_bar
|
||||||
|
- own_stunned
|
||||||
|
- opp_stunned
|
||||||
|
- own_super_bar
|
||||||
|
- opp_super_bar
|
||||||
|
- own_super_type
|
||||||
|
- opp_super_type
|
||||||
|
- own_super_count
|
||||||
|
- opp_super_count
|
||||||
|
- own_super_max_count
|
||||||
|
- opp_super_max_count
|
||||||
10
agentm/data/game_data/soulclbr.yaml
Normal file
10
agentm/data/game_data/soulclbr.yaml
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
|
||||||
|
|
||||||
|
███████╗ ██████╗ ██╗ ██╗██╗ ██████╗ █████╗ ██╗ ██╗██████╗ ██╗ ██╗██████╗
|
||||||
|
██╔════╝██╔═══██╗██║ ██║██║ ██╔════╝██╔══██╗██║ ██║██╔══██╗██║ ██║██╔══██╗
|
||||||
|
███████╗██║ ██║██║ ██║██║ ██║ ███████║██║ ██║██████╔╝██║ ██║██████╔╝
|
||||||
|
╚════██║██║ ██║██║ ██║██║ ██║ ██╔══██║██║ ██║██╔══██╗██║ ██║██╔══██╗
|
||||||
|
███████║╚██████╔╝╚██████╔╝███████╗ ╚██████╗██║ ██║███████╗██║██████╔╝╚██████╔╝██║ ██║
|
||||||
|
╚══════╝ ╚═════╝ ╚═════╝ ╚══════╝ ╚═════╝╚═╝ ╚═╝╚══════╝╚═╝╚═════╝ ╚═════╝ ╚═╝ ╚═╝
|
||||||
|
|
||||||
|
|
||||||
10
agentm/data/game_data/tektagt.yaml
Normal file
10
agentm/data/game_data/tektagt.yaml
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
|
||||||
|
|
||||||
|
████████╗███████╗██╗ ██╗██╗ ██╗███████╗███╗ ██╗ ████████╗ █████╗ ██████╗ ████████╗ ██████╗ ██╗ ██╗██████╗ ███╗ ██╗ █████╗ ███╗ ███╗███████╗███╗ ██╗████████╗
|
||||||
|
╚══██╔══╝██╔════╝██║ ██╔╝██║ ██╔╝██╔════╝████╗ ██║ ╚══██╔══╝██╔══██╗██╔════╝ ╚══██╔══╝██╔═══██╗██║ ██║██╔══██╗████╗ ██║██╔══██╗████╗ ████║██╔════╝████╗ ██║╚══██╔══╝
|
||||||
|
██║ █████╗ █████╔╝ █████╔╝ █████╗ ██╔██╗ ██║ ██║ ███████║██║ ███╗ ██║ ██║ ██║██║ ██║██████╔╝██╔██╗ ██║███████║██╔████╔██║█████╗ ██╔██╗ ██║ ██║
|
||||||
|
██║ ██╔══╝ ██╔═██╗ ██╔═██╗ ██╔══╝ ██║╚██╗██║ ██║ ██╔══██║██║ ██║ ██║ ██║ ██║██║ ██║██╔══██╗██║╚██╗██║██╔══██║██║╚██╔╝██║██╔══╝ ██║╚██╗██║ ██║
|
||||||
|
██║ ███████╗██║ ██╗██║ ██╗███████╗██║ ╚████║ ██║ ██║ ██║╚██████╔╝ ██║ ╚██████╔╝╚██████╔╝██║ ██║██║ ╚████║██║ ██║██║ ╚═╝ ██║███████╗██║ ╚████║ ██║
|
||||||
|
╚═╝ ╚══════╝╚═╝ ╚═╝╚═╝ ╚═╝╚══════╝╚═╝ ╚═══╝ ╚═╝ ╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═════╝ ╚═════╝ ╚═╝ ╚═╝╚═╝ ╚═══╝╚═╝ ╚═╝╚═╝ ╚═╝╚══════╝╚═╝ ╚═══╝ ╚═╝
|
||||||
|
|
||||||
|
|
||||||
10
agentm/data/game_data/umk3.yaml
Normal file
10
agentm/data/game_data/umk3.yaml
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
|
||||||
|
|
||||||
|
███╗ ███╗ ██████╗ ██████╗ ████████╗ █████╗ ██╗ ██╗ ██╗ ██████╗ ███╗ ███╗██████╗ █████╗ ████████╗ ██████╗
|
||||||
|
████╗ ████║██╔═══██╗██╔══██╗╚══██╔══╝██╔══██╗██║ ██║ ██╔╝██╔═══██╗████╗ ████║██╔══██╗██╔══██╗╚══██╔══╝ ╚════██╗
|
||||||
|
██╔████╔██║██║ ██║██████╔╝ ██║ ███████║██║ █████╔╝ ██║ ██║██╔████╔██║██████╔╝███████║ ██║ █████╔╝
|
||||||
|
██║╚██╔╝██║██║ ██║██╔══██╗ ██║ ██╔══██║██║ ██╔═██╗ ██║ ██║██║╚██╔╝██║██╔══██╗██╔══██║ ██║ ╚═══██╗
|
||||||
|
██║ ╚═╝ ██║╚██████╔╝██║ ██║ ██║ ██║ ██║███████╗ ██║ ██╗╚██████╔╝██║ ╚═╝ ██║██████╔╝██║ ██║ ██║ ██████╔╝
|
||||||
|
╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═╝ ╚═╝ ╚═╝ ╚═╝╚══════╝ ╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═╝╚═════╝ ╚═╝ ╚═╝ ╚═╝ ╚═════╝
|
||||||
|
|
||||||
|
|
||||||
10
agentm/data/game_data/xmvsf.yaml
Normal file
10
agentm/data/game_data/xmvsf.yaml
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
|
||||||
|
|
||||||
|
██╗ ██╗ ███╗ ███╗███████╗███╗ ██╗ ██╗ ██╗███████╗ ███████╗████████╗██████╗ ███████╗███████╗████████╗ ███████╗██╗ ██████╗ ██╗ ██╗████████╗███████╗██████╗
|
||||||
|
╚██╗██╔╝ ████╗ ████║██╔════╝████╗ ██║ ██║ ██║██╔════╝ ██╔════╝╚══██╔══╝██╔══██╗██╔════╝██╔════╝╚══██╔══╝ ██╔════╝██║██╔════╝ ██║ ██║╚══██╔══╝██╔════╝██╔══██╗
|
||||||
|
╚███╔╝█████╗██╔████╔██║█████╗ ██╔██╗ ██║ ██║ ██║███████╗ ███████╗ ██║ ██████╔╝█████╗ █████╗ ██║ █████╗ ██║██║ ███╗███████║ ██║ █████╗ ██████╔╝
|
||||||
|
██╔██╗╚════╝██║╚██╔╝██║██╔══╝ ██║╚██╗██║ ╚██╗ ██╔╝╚════██║ ╚════██║ ██║ ██╔══██╗██╔══╝ ██╔══╝ ██║ ██╔══╝ ██║██║ ██║██╔══██║ ██║ ██╔══╝ ██╔══██╗
|
||||||
|
██╔╝ ██╗ ██║ ╚═╝ ██║███████╗██║ ╚████║ ╚████╔╝ ███████║ ███████║ ██║ ██║ ██║███████╗███████╗ ██║ ██║ ██║╚██████╔╝██║ ██║ ██║ ███████╗██║ ██║
|
||||||
|
╚═╝ ╚═╝ ╚═╝ ╚═╝╚══════╝╚═╝ ╚═══╝ ╚═══╝ ╚══════╝ ╚══════╝ ╚═╝ ╚═╝ ╚═╝╚══════╝╚══════╝ ╚═╝ ╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═╝ ╚═╝ ╚══════╝╚═╝ ╚═╝
|
||||||
|
|
||||||
|
|
||||||
82
agentm/logic/config_manager.py
Normal file
82
agentm/logic/config_manager.py
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
import yaml
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigManager:
|
||||||
|
def __init__(self, game_id: str, model_name: str = "default_model"):
|
||||||
|
self.game_id = game_id
|
||||||
|
self.model_name = model_name
|
||||||
|
self.config_base_path = Path("agentm/data/game_data")
|
||||||
|
|
||||||
|
self.general = self._load_yaml("general.yaml")
|
||||||
|
self.game_specific = self._load_yaml(f"{self.game_id}.yaml")
|
||||||
|
|
||||||
|
self.config = self._build_merged_config()
|
||||||
|
|
||||||
|
def _load_yaml(self, filename: str) -> dict:
|
||||||
|
path = self.config_base_path / filename
|
||||||
|
if not path.exists():
|
||||||
|
return {}
|
||||||
|
with open(path, "r") as f:
|
||||||
|
return yaml.safe_load(f)
|
||||||
|
|
||||||
|
def _deep_merge(self, source: dict, overrides: dict) -> dict:
|
||||||
|
for key, value in overrides.items():
|
||||||
|
if key in source:
|
||||||
|
if isinstance(source[key], dict) and isinstance(value, dict):
|
||||||
|
self._deep_merge(source[key], value)
|
||||||
|
elif isinstance(source[key], list) and isinstance(value, list):
|
||||||
|
source[key] = list(dict.fromkeys(source[key] + value))
|
||||||
|
else:
|
||||||
|
source[key] = value
|
||||||
|
else:
|
||||||
|
source[key] = value
|
||||||
|
return source
|
||||||
|
|
||||||
|
def _build_merged_config(self) -> dict:
|
||||||
|
folders = self.general.get("folders", {}).copy()
|
||||||
|
folders["model_name"] = self.model_name
|
||||||
|
|
||||||
|
settings = self.general.get("settings", {}).copy()
|
||||||
|
settings["game_id"] = self.game_id
|
||||||
|
settings = self._deep_merge(settings, self.game_specific.get("settings", {}))
|
||||||
|
if "frame_shape" in settings and isinstance(settings["frame_shape"], list):
|
||||||
|
settings["frame_shape"] = tuple(settings["frame_shape"])
|
||||||
|
|
||||||
|
wrappers_settings = self._deep_merge(
|
||||||
|
self.general.get("wrappers_settings", {}).copy(),
|
||||||
|
self.game_specific.get("wrappers_settings", {})
|
||||||
|
)
|
||||||
|
|
||||||
|
policy_kwargs = self._deep_merge(
|
||||||
|
self.general.get("policy_kwargs", {}).copy(),
|
||||||
|
self.game_specific.get("policy_kwargs", {})
|
||||||
|
)
|
||||||
|
|
||||||
|
ppo_settings = self._deep_merge(
|
||||||
|
self.general.get("ppo_settings", {}).copy(),
|
||||||
|
self.game_specific.get("ppo_settings", {})
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"folders": folders,
|
||||||
|
"settings": settings,
|
||||||
|
"wrappers_settings": wrappers_settings,
|
||||||
|
"policy_kwargs": policy_kwargs,
|
||||||
|
"ppo_settings": ppo_settings,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get(self) -> dict:
|
||||||
|
"""Return the full merged config dictionary."""
|
||||||
|
return self.config
|
||||||
|
|
||||||
|
def dump_to_file(self, out_path: Path):
|
||||||
|
"""Write the config to a YAML file."""
|
||||||
|
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(out_path, "w") as f:
|
||||||
|
yaml.dump(self.config, f, sort_keys=False)
|
||||||
|
|
||||||
|
def dump_to_string(self) -> str:
|
||||||
|
"""Return the YAML string representation of the config."""
|
||||||
|
return yaml.dump(self.config, sort_keys=False)
|
||||||
@ -494,3 +494,19 @@ def update_run(
|
|||||||
f"UPDATE runs SET {', '.join(fields)} WHERE id = ?",
|
f"UPDATE runs SET {', '.join(fields)} WHERE id = ?",
|
||||||
tuple(values)
|
tuple(values)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def update_run_pending(run_id: int, pending: bool) -> None:
|
||||||
|
"""
|
||||||
|
Update the pending status of a run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_id: The ID of the run to update.
|
||||||
|
pending: New pending status (True/False).
|
||||||
|
"""
|
||||||
|
with get_db_conn() as conn:
|
||||||
|
conn.execute("""
|
||||||
|
UPDATE runs
|
||||||
|
SET pending = ?, updated_at = ?
|
||||||
|
WHERE id = ?
|
||||||
|
""", (int(pending), datetime.utcnow().isoformat(), run_id))
|
||||||
|
|
||||||
@ -322,4 +322,42 @@ Button.game_button {
|
|||||||
min-height: 3;
|
min-height: 3;
|
||||||
margin: 1 0;
|
margin: 1 0;
|
||||||
align: center middle;
|
align: center middle;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#config_scroll_container {
|
||||||
|
height: 70vh;
|
||||||
|
overflow-y: auto;
|
||||||
|
padding: 1;
|
||||||
|
border-top: solid {{BORDER}};
|
||||||
|
}
|
||||||
|
|
||||||
|
.top_aligned_layout {
|
||||||
|
layout: vertical;
|
||||||
|
align-horizontal: left;
|
||||||
|
width: 100%;
|
||||||
|
padding: 0 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
.form_column {
|
||||||
|
layout: vertical;
|
||||||
|
width: 100%;
|
||||||
|
padding: 2;
|
||||||
|
height: auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
Input, Select {
|
||||||
|
width: 100%;
|
||||||
|
min-width: 40;
|
||||||
|
}
|
||||||
|
|
||||||
|
.section_label {
|
||||||
|
text-style: bold;
|
||||||
|
padding-top: 1;
|
||||||
|
padding-bottom: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.input_row {
|
||||||
|
layout: horizontal;
|
||||||
|
align-vertical: middle;
|
||||||
|
padding: 0 1;
|
||||||
|
}
|
||||||
|
|||||||
@ -322,4 +322,42 @@ Button.game_button {
|
|||||||
min-height: 3;
|
min-height: 3;
|
||||||
margin: 1 0;
|
margin: 1 0;
|
||||||
align: center middle;
|
align: center middle;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#config_scroll_container {
|
||||||
|
height: 70vh;
|
||||||
|
overflow-y: auto;
|
||||||
|
padding: 1;
|
||||||
|
border-top: solid #3a9bed;
|
||||||
|
}
|
||||||
|
|
||||||
|
.top_aligned_layout {
|
||||||
|
layout: vertical;
|
||||||
|
align-horizontal: left;
|
||||||
|
width: 100%;
|
||||||
|
padding: 0 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
.form_column {
|
||||||
|
layout: vertical;
|
||||||
|
width: 100%;
|
||||||
|
padding: 2;
|
||||||
|
height: auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
Input, Select {
|
||||||
|
width: 100%;
|
||||||
|
min-width: 40;
|
||||||
|
}
|
||||||
|
|
||||||
|
.section_label {
|
||||||
|
text-style: bold;
|
||||||
|
padding-top: 1;
|
||||||
|
padding-bottom: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.input_row {
|
||||||
|
layout: horizontal;
|
||||||
|
align-vertical: middle;
|
||||||
|
padding: 0 1;
|
||||||
|
}
|
||||||
|
|||||||
@ -62,6 +62,7 @@ class AgentHomeView(Screen):
|
|||||||
await self.app.push_screen(
|
await self.app.push_screen(
|
||||||
ModelSelectView(
|
ModelSelectView(
|
||||||
agent_metadata=self.agent_metadata,
|
agent_metadata=self.agent_metadata,
|
||||||
mode=action
|
mode=action,
|
||||||
|
run=self.run_metadata
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -42,11 +42,11 @@ class ModelSelectView(Screen):
|
|||||||
("r", "refresh_models", "Refresh"),
|
("r", "refresh_models", "Refresh"),
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, agent_metadata: dict, mode: Literal["train", "eval", "submit"]):
|
def __init__(self, agent_metadata: dict, mode: Literal["train", "eval", "submit"], run: dict):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.agent_metadata = agent_metadata
|
self.agent_metadata = agent_metadata
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
self.selected_model = None
|
self.run_metadata = run
|
||||||
|
|
||||||
def compose(self):
|
def compose(self):
|
||||||
header_path = Path(__file__).parent.parent / "assets" / "headers" / "select_model.txt"
|
header_path = Path(__file__).parent.parent / "assets" / "headers" / "select_model.txt"
|
||||||
@ -132,7 +132,7 @@ class ModelSelectView(Screen):
|
|||||||
log_with_caller("info", f"Model confirmed: {self.selected_model['name']} for mode={self.mode}")
|
log_with_caller("info", f"Model confirmed: {self.selected_model['name']} for mode={self.mode}")
|
||||||
match self.mode:
|
match self.mode:
|
||||||
case "train":
|
case "train":
|
||||||
await self.app.push_screen(TrainingView(agent=self.agent_metadata, model=self.selected_model))
|
await self.app.push_screen(TrainingView(agent=self.agent_metadata, model=None, run=self.run_metadata))
|
||||||
case "eval":
|
case "eval":
|
||||||
await self.app.push_screen(EvaluationView(agent=self.agent_metadata, model=self.selected_model))
|
await self.app.push_screen(EvaluationView(agent=self.agent_metadata, model=self.selected_model))
|
||||||
case "submit":
|
case "submit":
|
||||||
@ -140,7 +140,10 @@ class ModelSelectView(Screen):
|
|||||||
|
|
||||||
elif event.button.id == "create_model_btn":
|
elif event.button.id == "create_model_btn":
|
||||||
log_with_caller("info", f"Creating new model for agent: {self.agent_metadata['name']}")
|
log_with_caller("info", f"Creating new model for agent: {self.agent_metadata['name']}")
|
||||||
await self.app.push_screen(TrainingView(agent=self.agent_metadata, model=None)) # New model starts from training
|
await self.app.push_screen(
|
||||||
|
TrainingView(agent=self.agent_metadata, model=None, run=self.run_metadata)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def action_refresh_models(self):
|
async def action_refresh_models(self):
|
||||||
await self.refresh_model_list()
|
await self.refresh_model_list()
|
||||||
|
|||||||
@ -1,106 +1,152 @@
|
|||||||
from textual.screen import Screen
|
from textual.screen import Screen
|
||||||
from textual.widgets import Static, Input, Button
|
from textual.widgets import Static, Input, Button, Checkbox, Label, Select
|
||||||
from textual.containers import Vertical, Horizontal
|
from textual.containers import Vertical, VerticalScroll, Horizontal, Grid
|
||||||
from agentm.theme.palette import get_theme
|
from agentm.components.footer import AgentMFooter
|
||||||
from agentm.utils.logger import log_with_caller
|
from agentm.utils.logger import log_with_caller
|
||||||
from agentm.logic.db_functions import insert_model, update_run_pending
|
from agentm.theme.palette import get_theme
|
||||||
from datetime import datetime
|
from pathlib import Path
|
||||||
|
|
||||||
palette = get_theme()
|
palette = get_theme()
|
||||||
|
HEADER_PATH = Path("agentm/assets/headers/training_setup.txt")
|
||||||
|
|
||||||
class TrainingView(Screen):
|
class TrainingView(Screen):
|
||||||
BINDINGS = [
|
BINDINGS = [("escape", "app.pop_screen", "Back")]
|
||||||
("escape", "app.pop_screen", "Back")
|
|
||||||
]
|
|
||||||
|
|
||||||
def __init__(self, agent: dict, model: dict | None, run: dict):
|
def __init__(self, agent: dict, model: dict | None, run: dict):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.agent_metadata = agent
|
self.agent_metadata = agent
|
||||||
self.model_metadata = model
|
self.model_metadata = model
|
||||||
self.run_metadata = run
|
self.run_metadata = run
|
||||||
self.is_pending_run = run.get("pending", True)
|
|
||||||
|
log_with_caller("info", f"Initialized TrainingView for agent: {agent['name']}")
|
||||||
|
|
||||||
|
self.model_name_input = Input(id="model_name")
|
||||||
|
self.num_env_input = Input(id="num_env")
|
||||||
|
|
||||||
|
self.framework_dropdown = Select(
|
||||||
|
id="framework",
|
||||||
|
options=[("SB3", "SB3"), ("RAY", "RAY"), ("SHEEPRL", "SHEEPRL")],
|
||||||
|
value="SB3"
|
||||||
|
)
|
||||||
|
self.algo_dropdown = Select(id="algo", options=[("PPO", "PPO")], value="PPO")
|
||||||
|
|
||||||
|
self.settings_inputs = {
|
||||||
|
"step_ratio": Input(value="6", id="step_ratio"),
|
||||||
|
"frame_shape": Input(value="[128, 128, 1]", id="frame_shape"),
|
||||||
|
"continue_game": Input(value="0.0", id="continue_game"),
|
||||||
|
"action_space": Input(value="discrete", id="action_space")
|
||||||
|
}
|
||||||
|
|
||||||
|
self.wrapper_checkboxes = {
|
||||||
|
"normalize_reward": Checkbox(label="Normalize Reward", value=True, id="normalize_reward"),
|
||||||
|
"no_attack_buttons_combinations": Checkbox(label="No Attack Combos", value=True, id="no_attack_buttons_combinations"),
|
||||||
|
"add_last_action": Checkbox(label="Add Last Action", value=True, id="add_last_action"),
|
||||||
|
"scale": Checkbox(label="Scale", value=True, id="scale"),
|
||||||
|
"exclude_image_scaling": Checkbox(label="Exclude Image Scaling", value=True, id="exclude_image_scaling"),
|
||||||
|
"role_relative": Checkbox(label="Role Relative", value=True, id="role_relative"),
|
||||||
|
"flatten": Checkbox(label="Flatten", value=True, id="flatten")
|
||||||
|
}
|
||||||
|
|
||||||
|
self.wrapper_inputs = {
|
||||||
|
"stack_frames": Input(value="4", id="stack_frames"),
|
||||||
|
"stack_actions": Input(value="12", id="stack_actions"),
|
||||||
|
"dilation": Input(value="1", id="dilation")
|
||||||
|
}
|
||||||
|
|
||||||
|
self.filter_key_checkboxes = [
|
||||||
|
Checkbox(label=key, value=True, id=f"fk_{key}") for key in ["frame", "action", "stage", "timer"]
|
||||||
|
]
|
||||||
|
|
||||||
|
self.net_arch_input = Input(value="[64, 64]", id="net_arch")
|
||||||
|
|
||||||
|
self.ppo_inputs = {
|
||||||
|
"gamma": Input(value="0.94", id="gamma"),
|
||||||
|
"model_checkpoint": Input(value="0", id="model_checkpoint"),
|
||||||
|
"learning_rate_start": Input(value="0.00025", id="lr_start"),
|
||||||
|
"learning_rate_end": Input(value="0.0000025", id="lr_end"),
|
||||||
|
"clip_range_start": Input(value="0.15", id="clip_start"),
|
||||||
|
"clip_range_end": Input(value="0.025", id="clip_end"),
|
||||||
|
"batch_size": Input(value="256", id="batch_size"),
|
||||||
|
"n_epochs": Input(value="4", id="n_epochs"),
|
||||||
|
"n_steps": Input(value="128", id="n_steps"),
|
||||||
|
"autosave_freq": Input(value="256", id="autosave_freq"),
|
||||||
|
"time_steps": Input(value="512", id="time_steps")
|
||||||
|
}
|
||||||
|
|
||||||
def compose(self):
|
def compose(self):
|
||||||
log_with_caller("info", f"Opening TrainingView for agent='{self.agent_metadata['name']}', run='{self.run_metadata['name']}', pending={self.is_pending_run}")
|
log_with_caller("debug", "Composing TrainingView layout...")
|
||||||
|
|
||||||
yield Static(f"[{palette.ACCENT} bold]Training: {self.agent_metadata['name']}[/]", classes="header")
|
try:
|
||||||
|
header_text = HEADER_PATH.read_text()
|
||||||
|
except FileNotFoundError:
|
||||||
|
log_with_caller("error", f"Missing header file at: {HEADER_PATH}")
|
||||||
|
header_text = "[bold red]Missing Header File"
|
||||||
|
|
||||||
if self.is_pending_run:
|
yield Static(f"[{palette.ACCENT}]{header_text}[/{palette.ACCENT}]", classes="header")
|
||||||
yield self.render_pending_setup()
|
|
||||||
else:
|
|
||||||
yield self.render_resume_options()
|
|
||||||
|
|
||||||
def render_pending_setup(self):
|
# Build scrollable form content
|
||||||
yield Static("[b]Initial Model Setup[/b]", classes="subheader")
|
yield VerticalScroll(
|
||||||
|
Vertical(
|
||||||
|
|
||||||
self.name_input = Input(placeholder="Model Name", id="model_name")
|
# Framework & Algorithm
|
||||||
self.steps_input = Input(placeholder="Total Training Steps", id="total_steps")
|
Static("Framework & Algorithm", classes="section_label"),
|
||||||
self.lr_input = Input(placeholder="Initial Learning Rate", id="learning_rate")
|
Horizontal(
|
||||||
self.clip_input = Input(placeholder="Initial Clip Range", id="clip_range")
|
Vertical(Label("Framework"), self.framework_dropdown),
|
||||||
self.notes_input = Input(placeholder="Notes", id="notes")
|
Vertical(Label("Algorithm"), self.algo_dropdown),
|
||||||
|
classes="input_row"
|
||||||
|
),
|
||||||
|
|
||||||
self.confirm_button = Button("✅ Create & Start Training", id="confirm_model_btn", classes="confirm_button")
|
# Model name / num_env
|
||||||
|
Static("Model Configuration", classes="section_label"),
|
||||||
|
Horizontal(
|
||||||
|
Vertical(Label("Model Name"), self.model_name_input),
|
||||||
|
Vertical(Label("Num Envs"), self.num_env_input),
|
||||||
|
classes="input_row"
|
||||||
|
),
|
||||||
|
|
||||||
return Vertical(
|
# General Settings
|
||||||
self.name_input,
|
Static("General Settings", classes="section_label"),
|
||||||
self.steps_input,
|
*[
|
||||||
self.lr_input,
|
Horizontal(Vertical(Label(k)), v, classes="input_row")
|
||||||
self.clip_input,
|
for k, v in self.settings_inputs.items()
|
||||||
self.notes_input,
|
],
|
||||||
self.confirm_button,
|
|
||||||
classes="centered_layout"
|
# Wrapper Checkboxes
|
||||||
|
Static("Wrapper Flags", classes="section_label"),
|
||||||
|
*[
|
||||||
|
Horizontal(cb, classes="input_row")
|
||||||
|
for cb in self.wrapper_checkboxes.values()
|
||||||
|
],
|
||||||
|
|
||||||
|
# Wrapper Inputs
|
||||||
|
*[
|
||||||
|
Horizontal(Vertical(Label(k.replace("_", " ").title())), v, classes="input_row")
|
||||||
|
for k, v in self.wrapper_inputs.items()
|
||||||
|
],
|
||||||
|
|
||||||
|
# Filter Keys
|
||||||
|
Static("Filter Keys", classes="section_label"),
|
||||||
|
Grid(*self.filter_key_checkboxes, id="filter_grid"),
|
||||||
|
|
||||||
|
# Policy kwargs
|
||||||
|
Static("Policy Architecture", classes="section_label"),
|
||||||
|
Horizontal(Label("Net Arch"), self.net_arch_input, classes="input_row"),
|
||||||
|
|
||||||
|
# PPO Settings
|
||||||
|
Static("PPO Settings", classes="section_label"),
|
||||||
|
*[
|
||||||
|
Horizontal(Label(k.replace("_", " ").title()), v, classes="input_row")
|
||||||
|
for k, v in self.ppo_inputs.items()
|
||||||
|
],
|
||||||
|
|
||||||
|
id="form_layout",
|
||||||
|
classes="form_column"
|
||||||
|
),
|
||||||
|
id="config_scroll_container"
|
||||||
)
|
)
|
||||||
|
|
||||||
def render_resume_options(self):
|
# Final button and footer
|
||||||
model = self.model_metadata
|
yield Button("🔍 Review Config", id="review_config_btn", classes="confirm_button")
|
||||||
from rich.table import Table
|
yield AgentMFooter(compact=True)
|
||||||
from rich.panel import Panel
|
|
||||||
|
|
||||||
table = Table.grid(padding=(0, 1))
|
log_with_caller("debug", "Finished composing TrainingView")
|
||||||
table.add_column("Key", style="bold underline")
|
|
||||||
table.add_column("Value", style=palette.ACCENT, overflow="fold")
|
|
||||||
|
|
||||||
table.add_row("Name", model["name"])
|
|
||||||
table.add_row("Steps", f"{model['total_steps_completed']} / {model['total_steps_planned']}")
|
|
||||||
table.add_row("Reward", str(model.get("average_reward", "—")))
|
|
||||||
table.add_row("Learning Rate", str(model.get("current_learning_rate", "—")))
|
|
||||||
table.add_row("Clip Range", str(model.get("current_clip_range", "—")))
|
|
||||||
table.add_row("Created", model.get("created_at", "—"))
|
|
||||||
|
|
||||||
return Vertical(
|
|
||||||
Static(Panel(table, title="Model Info", border_style=palette.BORDER), classes="agent_info_box"),
|
|
||||||
Button("⏯️ Resume Training", id="resume_training_btn", classes="confirm_button"),
|
|
||||||
classes="centered_layout"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def on_button_pressed(self, event: Button.Pressed) -> None:
|
|
||||||
if event.button.id == "confirm_model_btn":
|
|
||||||
try:
|
|
||||||
name = self.name_input.value.strip()
|
|
||||||
total_steps = int(self.steps_input.value)
|
|
||||||
learning_rate = float(self.lr_input.value)
|
|
||||||
clip_range = float(self.clip_input.value)
|
|
||||||
notes = self.notes_input.value.strip()
|
|
||||||
|
|
||||||
log_with_caller("info", f"Creating new model: {name} for agent_id={self.agent_metadata['id']}")
|
|
||||||
|
|
||||||
insert_model(
|
|
||||||
agent_id=self.agent_metadata["id"],
|
|
||||||
name=name,
|
|
||||||
total_steps_planned=total_steps,
|
|
||||||
current_learning_rate=learning_rate,
|
|
||||||
current_clip_range=clip_range,
|
|
||||||
notes=notes
|
|
||||||
)
|
|
||||||
|
|
||||||
update_run_pending(self.run_metadata["id"], False)
|
|
||||||
log_with_caller("info", f"Model '{name}' created and run marked as not pending")
|
|
||||||
|
|
||||||
await self.app.pop_screen()
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
log_with_caller("error", f"Failed to create model: {e}")
|
|
||||||
|
|
||||||
elif event.button.id == "resume_training_btn":
|
|
||||||
log_with_caller("info", f"Resuming training for model: {self.model_metadata['name']}")
|
|
||||||
await self.app.pop_screen()
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user