From fb6e2fcd65628b7fca3e24a8d4169062dbf2316e Mon Sep 17 00:00:00 2001 From: mscrnt Date: Tue, 27 May 2025 06:19:16 -0700 Subject: [PATCH] styling on training added. --- agentm/assets/headers/training_setup.txt | 10 + agentm/data/game_data/doapp.yaml | 10 + agentm/data/game_data/general.yaml | 41 ++++ agentm/data/game_data/kof98umh.yaml | 10 + agentm/data/game_data/mvsc.yaml | 10 + agentm/data/game_data/samsh5sp.yaml | 10 + agentm/data/game_data/sfiii3n.yaml | 88 ++++++++ agentm/data/game_data/soulclbr.yaml | 10 + agentm/data/game_data/tektagt.yaml | 10 + agentm/data/game_data/umk3.yaml | 10 + agentm/data/game_data/xmvsf.yaml | 10 + .../logic/{config_builder.py => __init__.py} | 0 agentm/logic/config_manager.py | 82 +++++++ agentm/logic/db_functions.py | 16 ++ agentm/theme/styles.base.tcss | 40 +++- agentm/theme/styles.tcss | 40 +++- agentm/views/agent_home.py | 5 +- agentm/views/model_select.py | 11 +- agentm/views/training.py | 212 +++++++++++------- 19 files changed, 534 insertions(+), 91 deletions(-) create mode 100644 agentm/assets/headers/training_setup.txt create mode 100644 agentm/data/game_data/doapp.yaml create mode 100644 agentm/data/game_data/general.yaml create mode 100644 agentm/data/game_data/kof98umh.yaml create mode 100644 agentm/data/game_data/mvsc.yaml create mode 100644 agentm/data/game_data/samsh5sp.yaml create mode 100644 agentm/data/game_data/sfiii3n.yaml create mode 100644 agentm/data/game_data/soulclbr.yaml create mode 100644 agentm/data/game_data/tektagt.yaml create mode 100644 agentm/data/game_data/umk3.yaml create mode 100644 agentm/data/game_data/xmvsf.yaml rename agentm/logic/{config_builder.py => __init__.py} (100%) create mode 100644 agentm/logic/config_manager.py diff --git a/agentm/assets/headers/training_setup.txt b/agentm/assets/headers/training_setup.txt new file mode 100644 index 0000000..0ecbe87 --- /dev/null +++ b/agentm/assets/headers/training_setup.txt @@ -0,0 +1,10 @@ + + +████████╗██████╗ █████╗ ██╗███╗ ██╗██╗███╗ ██╗ ██████╗ ███████╗███████╗████████╗██╗ ██╗██████╗ +╚══██╔══╝██╔══██╗██╔══██╗██║████╗ ██║██║████╗ ██║██╔════╝ ██╔════╝██╔════╝╚══██╔══╝██║ ██║██╔══██╗ + ██║ ██████╔╝███████║██║██╔██╗ ██║██║██╔██╗ ██║██║ ███╗ ███████╗█████╗ ██║ ██║ ██║██████╔╝ + ██║ ██╔══██╗██╔══██║██║██║╚██╗██║██║██║╚██╗██║██║ ██║ ╚════██║██╔══╝ ██║ ██║ ██║██╔═══╝ + ██║ ██║ ██║██║ ██║██║██║ ╚████║██║██║ ╚████║╚██████╔╝ ███████║███████╗ ██║ ╚██████╔╝██║ + ╚═╝ ╚═╝ ╚═╝╚═╝ ╚═╝╚═╝╚═╝ ╚═══╝╚═╝╚═╝ ╚═══╝ ╚═════╝ ╚══════╝╚══════╝ ╚═╝ ╚═════╝ ╚═╝ + + diff --git a/agentm/data/game_data/doapp.yaml b/agentm/data/game_data/doapp.yaml new file mode 100644 index 0000000..f672620 --- /dev/null +++ b/agentm/data/game_data/doapp.yaml @@ -0,0 +1,10 @@ + + +██████╗ ███████╗ █████╗ ██████╗ ██████╗ ██████╗ █████╗ ██╗ ██╗██╗ ██╗███████╗ +██╔══██╗██╔════╝██╔══██╗██╔══██╗ ██╔═══██╗██╔══██╗ ██╔══██╗██║ ██║██║ ██║██╔════╝ +██║ ██║█████╗ ███████║██║ ██║ ██║ ██║██████╔╝ ███████║██║ ██║██║ ██║█████╗ +██║ ██║██╔══╝ ██╔══██║██║ ██║ ██║ ██║██╔══██╗ ██╔══██║██║ ██║╚██╗ ██╔╝██╔══╝ +██████╔╝███████╗██║ ██║██████╔╝ ╚██████╔╝██║ ██║ ██║ ██║███████╗██║ ╚████╔╝ ███████╗ +╚═════╝ ╚══════╝╚═╝ ╚═╝╚═════╝ ╚═════╝ ╚═╝ ╚═╝ ╚═╝ ╚═╝╚══════╝╚═╝ ╚═══╝ ╚══════╝ + + diff --git a/agentm/data/game_data/general.yaml b/agentm/data/game_data/general.yaml new file mode 100644 index 0000000..443229c --- /dev/null +++ b/agentm/data/game_data/general.yaml @@ -0,0 +1,41 @@ +folders: + parent_dir: "./results/" + model_name: "default_model" + +settings: + step_ratio: 6 + frame_shape: [128, 128, 1] + continue_game: 0.0 + action_space: "discrete" + +wrappers_settings: + wrappers: + normalize_reward: true + no_attack_buttons_combinations: true + stack_frames: 4 + dilation: 1 + add_last_action: true + stack_actions: 12 + scale: true + exclude_image_scaling: true + role_relative: true + flatten: true + filter_keys: + - frame + - action + - stage + - timer + +policy_kwargs: + net_arch: [64, 64] + +ppo_settings: + gamma: 0.94 + model_checkpoint: "0" + learning_rate: [2.5e-4, 2.5e-6] + clip_range: [0.15, 0.025] + batch_size: 256 + n_epochs: 4 + n_steps: 128 + autosave_freq: 256 + time_steps: 512 diff --git a/agentm/data/game_data/kof98umh.yaml b/agentm/data/game_data/kof98umh.yaml new file mode 100644 index 0000000..3ffb8d8 --- /dev/null +++ b/agentm/data/game_data/kof98umh.yaml @@ -0,0 +1,10 @@ + + +████████╗██╗ ██╗███████╗ ██╗ ██╗██╗███╗ ██╗ ██████╗ ██████╗ ███████╗ ███████╗██╗ ██████╗ ██╗ ██╗████████╗███████╗██████╗ ███████╗ +╚══██╔══╝██║ ██║██╔════╝ ██║ ██╔╝██║████╗ ██║██╔════╝ ██╔═══██╗██╔════╝ ██╔════╝██║██╔════╝ ██║ ██║╚══██╔══╝██╔════╝██╔══██╗██╔════╝ + ██║ ███████║█████╗ █████╔╝ ██║██╔██╗ ██║██║ ███╗ ██║ ██║█████╗ █████╗ ██║██║ ███╗███████║ ██║ █████╗ ██████╔╝███████╗ + ██║ ██╔══██║██╔══╝ ██╔═██╗ ██║██║╚██╗██║██║ ██║ ██║ ██║██╔══╝ ██╔══╝ ██║██║ ██║██╔══██║ ██║ ██╔══╝ ██╔══██╗╚════██║ + ██║ ██║ ██║███████╗ ██║ ██╗██║██║ ╚████║╚██████╔╝ ╚██████╔╝██║ ██║ ██║╚██████╔╝██║ ██║ ██║ ███████╗██║ ██║███████║ + ╚═╝ ╚═╝ ╚═╝╚══════╝ ╚═╝ ╚═╝╚═╝╚═╝ ╚═══╝ ╚═════╝ ╚═════╝ ╚═╝ ╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═╝ ╚═╝ ╚══════╝╚═╝ ╚═╝╚══════╝ + + diff --git a/agentm/data/game_data/mvsc.yaml b/agentm/data/game_data/mvsc.yaml new file mode 100644 index 0000000..ccafabe --- /dev/null +++ b/agentm/data/game_data/mvsc.yaml @@ -0,0 +1,10 @@ + + +███╗ ███╗ █████╗ ██████╗ ██╗ ██╗███████╗██╗ ██╗ ██╗███████╗ ██████╗ █████╗ ██████╗ ██████╗ ██████╗ ███╗ ███╗ +████╗ ████║██╔══██╗██╔══██╗██║ ██║██╔════╝██║ ██║ ██║██╔════╝ ██╔════╝██╔══██╗██╔══██╗██╔════╝██╔═══██╗████╗ ████║ +██╔████╔██║███████║██████╔╝██║ ██║█████╗ ██║ ██║ ██║███████╗ ██║ ███████║██████╔╝██║ ██║ ██║██╔████╔██║ +██║╚██╔╝██║██╔══██║██╔══██╗╚██╗ ██╔╝██╔══╝ ██║ ╚██╗ ██╔╝╚════██║ ██║ ██╔══██║██╔═══╝ ██║ ██║ ██║██║╚██╔╝██║ +██║ ╚═╝ ██║██║ ██║██║ ██║ ╚████╔╝ ███████╗███████╗ ╚████╔╝ ███████║██╗ ╚██████╗██║ ██║██║ ╚██████╗╚██████╔╝██║ ╚═╝ ██║ +╚═╝ ╚═╝╚═╝ ╚═╝╚═╝ ╚═╝ ╚═══╝ ╚══════╝╚══════╝ ╚═══╝ ╚══════╝╚═╝ ╚═════╝╚═╝ ╚═╝╚═╝ ╚═════╝ ╚═════╝ ╚═╝ ╚═╝ + + diff --git a/agentm/data/game_data/samsh5sp.yaml b/agentm/data/game_data/samsh5sp.yaml new file mode 100644 index 0000000..e5477ee --- /dev/null +++ b/agentm/data/game_data/samsh5sp.yaml @@ -0,0 +1,10 @@ + + +███████╗ █████╗ ███╗ ███╗██╗ ██╗██████╗ █████╗ ██╗ ███████╗██╗ ██╗ ██████╗ ██████╗ ██████╗ ██╗ ██╗███╗ ██╗ ██╗ ██╗ +██╔════╝██╔══██╗████╗ ████║██║ ██║██╔══██╗██╔══██╗██║ ██╔════╝██║ ██║██╔═══██╗██╔══██╗██╔═══██╗██║ ██║████╗ ██║ ██║ ██║ +███████╗███████║██╔████╔██║██║ ██║██████╔╝███████║██║ ███████╗███████║██║ ██║██║ ██║██║ ██║██║ █╗ ██║██╔██╗ ██║ ██║ ██║ +╚════██║██╔══██║██║╚██╔╝██║██║ ██║██╔══██╗██╔══██║██║ ╚════██║██╔══██║██║ ██║██║ ██║██║ ██║██║███╗██║██║╚██╗██║ ╚██╗ ██╔╝ +███████║██║ ██║██║ ╚═╝ ██║╚██████╔╝██║ ██║██║ ██║██║ ███████║██║ ██║╚██████╔╝██████╔╝╚██████╔╝╚███╔███╔╝██║ ╚████║ ╚████╔╝ +╚══════╝╚═╝ ╚═╝╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═╝╚═╝ ╚═╝╚═╝ ╚══════╝╚═╝ ╚═╝ ╚═════╝ ╚═════╝ ╚═════╝ ╚══╝╚══╝ ╚═╝ ╚═══╝ ╚═══╝ + + diff --git a/agentm/data/game_data/sfiii3n.yaml b/agentm/data/game_data/sfiii3n.yaml new file mode 100644 index 0000000..1c48453 --- /dev/null +++ b/agentm/data/game_data/sfiii3n.yaml @@ -0,0 +1,88 @@ +character_mode: single + +character_index: + 0: Alex + 1: Twelve + 2: Hugo + 3: Sean + 4: Makoto + 5: Elena + 6: Ibuki + 7: Chun-Li + 8: Dudley + 9: Necro + 10: Q + 11: Oro + 12: Urien + 13: Remy + 14: Ryu + 15: Gouki + 16: Yun + 17: Yang + 18: Ken + 19: Gill + +stun_bar_lengths: + Alex: 72 + Dudley: 72 + Elena: 64 + Gill: 72 + Gouki: 52 + Hugo: 72 + Ibuki: 64 + Ken: 64 + Makoto: 64 + Necro: 64 + Oro: 72 + Q: 72 + Remy: 52 + Ryu: 64 + Sean: 64 + Twelve: 64 + Urien: 64 + Yang: 64 + Yun: 64 + Chun-Li: 64 + +stamina_ratings: + Alex: 1200 + Chun-Li: 1100 + Dudley: 1150 + Elena: 1100 + Gill: 1250 + Gouki: 950 + Hugo: 1300 + Ibuki: 1000 + Ken: 1150 + Makoto: 1150 + Necro: 1070 + Oro: 1100 + Q: 1200 + Remy: 1070 + Ryu: 1150 + Sean: 1120 + Twelve: 1050 + Urien: 1220 + Yang: 1020 + Yun: 1020 + +wrappers_settings: + filter_keys: + - own_health + - opp_health + - own_side + - opp_side + - own_character + - opp_character + - own_stun_bar + - opp_stun_bar + - own_stunned + - opp_stunned + - own_super_bar + - opp_super_bar + - own_super_type + - opp_super_type + - own_super_count + - opp_super_count + - own_super_max_count + - opp_super_max_count diff --git a/agentm/data/game_data/soulclbr.yaml b/agentm/data/game_data/soulclbr.yaml new file mode 100644 index 0000000..93108a4 --- /dev/null +++ b/agentm/data/game_data/soulclbr.yaml @@ -0,0 +1,10 @@ + + +███████╗ ██████╗ ██╗ ██╗██╗ ██████╗ █████╗ ██╗ ██╗██████╗ ██╗ ██╗██████╗ +██╔════╝██╔═══██╗██║ ██║██║ ██╔════╝██╔══██╗██║ ██║██╔══██╗██║ ██║██╔══██╗ +███████╗██║ ██║██║ ██║██║ ██║ ███████║██║ ██║██████╔╝██║ ██║██████╔╝ +╚════██║██║ ██║██║ ██║██║ ██║ ██╔══██║██║ ██║██╔══██╗██║ ██║██╔══██╗ +███████║╚██████╔╝╚██████╔╝███████╗ ╚██████╗██║ ██║███████╗██║██████╔╝╚██████╔╝██║ ██║ +╚══════╝ ╚═════╝ ╚═════╝ ╚══════╝ ╚═════╝╚═╝ ╚═╝╚══════╝╚═╝╚═════╝ ╚═════╝ ╚═╝ ╚═╝ + + diff --git a/agentm/data/game_data/tektagt.yaml b/agentm/data/game_data/tektagt.yaml new file mode 100644 index 0000000..d2ddf65 --- /dev/null +++ b/agentm/data/game_data/tektagt.yaml @@ -0,0 +1,10 @@ + + +████████╗███████╗██╗ ██╗██╗ ██╗███████╗███╗ ██╗ ████████╗ █████╗ ██████╗ ████████╗ ██████╗ ██╗ ██╗██████╗ ███╗ ██╗ █████╗ ███╗ ███╗███████╗███╗ ██╗████████╗ +╚══██╔══╝██╔════╝██║ ██╔╝██║ ██╔╝██╔════╝████╗ ██║ ╚══██╔══╝██╔══██╗██╔════╝ ╚══██╔══╝██╔═══██╗██║ ██║██╔══██╗████╗ ██║██╔══██╗████╗ ████║██╔════╝████╗ ██║╚══██╔══╝ + ██║ █████╗ █████╔╝ █████╔╝ █████╗ ██╔██╗ ██║ ██║ ███████║██║ ███╗ ██║ ██║ ██║██║ ██║██████╔╝██╔██╗ ██║███████║██╔████╔██║█████╗ ██╔██╗ ██║ ██║ + ██║ ██╔══╝ ██╔═██╗ ██╔═██╗ ██╔══╝ ██║╚██╗██║ ██║ ██╔══██║██║ ██║ ██║ ██║ ██║██║ ██║██╔══██╗██║╚██╗██║██╔══██║██║╚██╔╝██║██╔══╝ ██║╚██╗██║ ██║ + ██║ ███████╗██║ ██╗██║ ██╗███████╗██║ ╚████║ ██║ ██║ ██║╚██████╔╝ ██║ ╚██████╔╝╚██████╔╝██║ ██║██║ ╚████║██║ ██║██║ ╚═╝ ██║███████╗██║ ╚████║ ██║ + ╚═╝ ╚══════╝╚═╝ ╚═╝╚═╝ ╚═╝╚══════╝╚═╝ ╚═══╝ ╚═╝ ╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═════╝ ╚═════╝ ╚═╝ ╚═╝╚═╝ ╚═══╝╚═╝ ╚═╝╚═╝ ╚═╝╚══════╝╚═╝ ╚═══╝ ╚═╝ + + diff --git a/agentm/data/game_data/umk3.yaml b/agentm/data/game_data/umk3.yaml new file mode 100644 index 0000000..8ade7d5 --- /dev/null +++ b/agentm/data/game_data/umk3.yaml @@ -0,0 +1,10 @@ + + +███╗ ███╗ ██████╗ ██████╗ ████████╗ █████╗ ██╗ ██╗ ██╗ ██████╗ ███╗ ███╗██████╗ █████╗ ████████╗ ██████╗ +████╗ ████║██╔═══██╗██╔══██╗╚══██╔══╝██╔══██╗██║ ██║ ██╔╝██╔═══██╗████╗ ████║██╔══██╗██╔══██╗╚══██╔══╝ ╚════██╗ +██╔████╔██║██║ ██║██████╔╝ ██║ ███████║██║ █████╔╝ ██║ ██║██╔████╔██║██████╔╝███████║ ██║ █████╔╝ +██║╚██╔╝██║██║ ██║██╔══██╗ ██║ ██╔══██║██║ ██╔═██╗ ██║ ██║██║╚██╔╝██║██╔══██╗██╔══██║ ██║ ╚═══██╗ +██║ ╚═╝ ██║╚██████╔╝██║ ██║ ██║ ██║ ██║███████╗ ██║ ██╗╚██████╔╝██║ ╚═╝ ██║██████╔╝██║ ██║ ██║ ██████╔╝ +╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═╝ ╚═╝ ╚═╝ ╚═╝╚══════╝ ╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═╝╚═════╝ ╚═╝ ╚═╝ ╚═╝ ╚═════╝ + + diff --git a/agentm/data/game_data/xmvsf.yaml b/agentm/data/game_data/xmvsf.yaml new file mode 100644 index 0000000..d18487a --- /dev/null +++ b/agentm/data/game_data/xmvsf.yaml @@ -0,0 +1,10 @@ + + +██╗ ██╗ ███╗ ███╗███████╗███╗ ██╗ ██╗ ██╗███████╗ ███████╗████████╗██████╗ ███████╗███████╗████████╗ ███████╗██╗ ██████╗ ██╗ ██╗████████╗███████╗██████╗ +╚██╗██╔╝ ████╗ ████║██╔════╝████╗ ██║ ██║ ██║██╔════╝ ██╔════╝╚══██╔══╝██╔══██╗██╔════╝██╔════╝╚══██╔══╝ ██╔════╝██║██╔════╝ ██║ ██║╚══██╔══╝██╔════╝██╔══██╗ + ╚███╔╝█████╗██╔████╔██║█████╗ ██╔██╗ ██║ ██║ ██║███████╗ ███████╗ ██║ ██████╔╝█████╗ █████╗ ██║ █████╗ ██║██║ ███╗███████║ ██║ █████╗ ██████╔╝ + ██╔██╗╚════╝██║╚██╔╝██║██╔══╝ ██║╚██╗██║ ╚██╗ ██╔╝╚════██║ ╚════██║ ██║ ██╔══██╗██╔══╝ ██╔══╝ ██║ ██╔══╝ ██║██║ ██║██╔══██║ ██║ ██╔══╝ ██╔══██╗ +██╔╝ ██╗ ██║ ╚═╝ ██║███████╗██║ ╚████║ ╚████╔╝ ███████║ ███████║ ██║ ██║ ██║███████╗███████╗ ██║ ██║ ██║╚██████╔╝██║ ██║ ██║ ███████╗██║ ██║ +╚═╝ ╚═╝ ╚═╝ ╚═╝╚══════╝╚═╝ ╚═══╝ ╚═══╝ ╚══════╝ ╚══════╝ ╚═╝ ╚═╝ ╚═╝╚══════╝╚══════╝ ╚═╝ ╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═╝ ╚═╝ ╚══════╝╚═╝ ╚═╝ + + diff --git a/agentm/logic/config_builder.py b/agentm/logic/__init__.py similarity index 100% rename from agentm/logic/config_builder.py rename to agentm/logic/__init__.py diff --git a/agentm/logic/config_manager.py b/agentm/logic/config_manager.py new file mode 100644 index 0000000..e6f478a --- /dev/null +++ b/agentm/logic/config_manager.py @@ -0,0 +1,82 @@ +import yaml +from pathlib import Path +from typing import Optional + + +class ConfigManager: + def __init__(self, game_id: str, model_name: str = "default_model"): + self.game_id = game_id + self.model_name = model_name + self.config_base_path = Path("agentm/data/game_data") + + self.general = self._load_yaml("general.yaml") + self.game_specific = self._load_yaml(f"{self.game_id}.yaml") + + self.config = self._build_merged_config() + + def _load_yaml(self, filename: str) -> dict: + path = self.config_base_path / filename + if not path.exists(): + return {} + with open(path, "r") as f: + return yaml.safe_load(f) + + def _deep_merge(self, source: dict, overrides: dict) -> dict: + for key, value in overrides.items(): + if key in source: + if isinstance(source[key], dict) and isinstance(value, dict): + self._deep_merge(source[key], value) + elif isinstance(source[key], list) and isinstance(value, list): + source[key] = list(dict.fromkeys(source[key] + value)) + else: + source[key] = value + else: + source[key] = value + return source + + def _build_merged_config(self) -> dict: + folders = self.general.get("folders", {}).copy() + folders["model_name"] = self.model_name + + settings = self.general.get("settings", {}).copy() + settings["game_id"] = self.game_id + settings = self._deep_merge(settings, self.game_specific.get("settings", {})) + if "frame_shape" in settings and isinstance(settings["frame_shape"], list): + settings["frame_shape"] = tuple(settings["frame_shape"]) + + wrappers_settings = self._deep_merge( + self.general.get("wrappers_settings", {}).copy(), + self.game_specific.get("wrappers_settings", {}) + ) + + policy_kwargs = self._deep_merge( + self.general.get("policy_kwargs", {}).copy(), + self.game_specific.get("policy_kwargs", {}) + ) + + ppo_settings = self._deep_merge( + self.general.get("ppo_settings", {}).copy(), + self.game_specific.get("ppo_settings", {}) + ) + + return { + "folders": folders, + "settings": settings, + "wrappers_settings": wrappers_settings, + "policy_kwargs": policy_kwargs, + "ppo_settings": ppo_settings, + } + + def get(self) -> dict: + """Return the full merged config dictionary.""" + return self.config + + def dump_to_file(self, out_path: Path): + """Write the config to a YAML file.""" + out_path.parent.mkdir(parents=True, exist_ok=True) + with open(out_path, "w") as f: + yaml.dump(self.config, f, sort_keys=False) + + def dump_to_string(self) -> str: + """Return the YAML string representation of the config.""" + return yaml.dump(self.config, sort_keys=False) diff --git a/agentm/logic/db_functions.py b/agentm/logic/db_functions.py index 0d1b1c1..0b0e4fd 100644 --- a/agentm/logic/db_functions.py +++ b/agentm/logic/db_functions.py @@ -494,3 +494,19 @@ def update_run( f"UPDATE runs SET {', '.join(fields)} WHERE id = ?", tuple(values) ) + +def update_run_pending(run_id: int, pending: bool) -> None: + """ + Update the pending status of a run. + + Args: + run_id: The ID of the run to update. + pending: New pending status (True/False). + """ + with get_db_conn() as conn: + conn.execute(""" + UPDATE runs + SET pending = ?, updated_at = ? + WHERE id = ? + """, (int(pending), datetime.utcnow().isoformat(), run_id)) + \ No newline at end of file diff --git a/agentm/theme/styles.base.tcss b/agentm/theme/styles.base.tcss index 4389a7c..a8ac9cc 100644 --- a/agentm/theme/styles.base.tcss +++ b/agentm/theme/styles.base.tcss @@ -322,4 +322,42 @@ Button.game_button { min-height: 3; margin: 1 0; align: center middle; -} \ No newline at end of file +} + +#config_scroll_container { + height: 70vh; + overflow-y: auto; + padding: 1; + border-top: solid {{BORDER}}; +} + +.top_aligned_layout { + layout: vertical; + align-horizontal: left; + width: 100%; + padding: 0 2; +} + +.form_column { + layout: vertical; + width: 100%; + padding: 2; + height: auto; +} + +Input, Select { + width: 100%; + min-width: 40; +} + +.section_label { + text-style: bold; + padding-top: 1; + padding-bottom: 0; +} + +.input_row { + layout: horizontal; + align-vertical: middle; + padding: 0 1; +} diff --git a/agentm/theme/styles.tcss b/agentm/theme/styles.tcss index d96b860..4ddd5e3 100644 --- a/agentm/theme/styles.tcss +++ b/agentm/theme/styles.tcss @@ -322,4 +322,42 @@ Button.game_button { min-height: 3; margin: 1 0; align: center middle; -} \ No newline at end of file +} + +#config_scroll_container { + height: 70vh; + overflow-y: auto; + padding: 1; + border-top: solid #3a9bed; +} + +.top_aligned_layout { + layout: vertical; + align-horizontal: left; + width: 100%; + padding: 0 2; +} + +.form_column { + layout: vertical; + width: 100%; + padding: 2; + height: auto; +} + +Input, Select { + width: 100%; + min-width: 40; +} + +.section_label { + text-style: bold; + padding-top: 1; + padding-bottom: 0; +} + +.input_row { + layout: horizontal; + align-vertical: middle; + padding: 0 1; +} diff --git a/agentm/views/agent_home.py b/agentm/views/agent_home.py index 3974a39..05dd7ea 100644 --- a/agentm/views/agent_home.py +++ b/agentm/views/agent_home.py @@ -62,6 +62,7 @@ class AgentHomeView(Screen): await self.app.push_screen( ModelSelectView( agent_metadata=self.agent_metadata, - mode=action + mode=action, + run=self.run_metadata ) - ) + ) \ No newline at end of file diff --git a/agentm/views/model_select.py b/agentm/views/model_select.py index 0e945e2..8a04568 100644 --- a/agentm/views/model_select.py +++ b/agentm/views/model_select.py @@ -42,11 +42,11 @@ class ModelSelectView(Screen): ("r", "refresh_models", "Refresh"), ] - def __init__(self, agent_metadata: dict, mode: Literal["train", "eval", "submit"]): + def __init__(self, agent_metadata: dict, mode: Literal["train", "eval", "submit"], run: dict): super().__init__() self.agent_metadata = agent_metadata self.mode = mode - self.selected_model = None + self.run_metadata = run def compose(self): header_path = Path(__file__).parent.parent / "assets" / "headers" / "select_model.txt" @@ -132,7 +132,7 @@ class ModelSelectView(Screen): log_with_caller("info", f"Model confirmed: {self.selected_model['name']} for mode={self.mode}") match self.mode: case "train": - await self.app.push_screen(TrainingView(agent=self.agent_metadata, model=self.selected_model)) + await self.app.push_screen(TrainingView(agent=self.agent_metadata, model=None, run=self.run_metadata)) case "eval": await self.app.push_screen(EvaluationView(agent=self.agent_metadata, model=self.selected_model)) case "submit": @@ -140,7 +140,10 @@ class ModelSelectView(Screen): elif event.button.id == "create_model_btn": log_with_caller("info", f"Creating new model for agent: {self.agent_metadata['name']}") - await self.app.push_screen(TrainingView(agent=self.agent_metadata, model=None)) # New model starts from training + await self.app.push_screen( + TrainingView(agent=self.agent_metadata, model=None, run=self.run_metadata) + ) + async def action_refresh_models(self): await self.refresh_model_list() diff --git a/agentm/views/training.py b/agentm/views/training.py index 0fdc834..ebf4845 100644 --- a/agentm/views/training.py +++ b/agentm/views/training.py @@ -1,106 +1,152 @@ from textual.screen import Screen -from textual.widgets import Static, Input, Button -from textual.containers import Vertical, Horizontal -from agentm.theme.palette import get_theme +from textual.widgets import Static, Input, Button, Checkbox, Label, Select +from textual.containers import Vertical, VerticalScroll, Horizontal, Grid +from agentm.components.footer import AgentMFooter from agentm.utils.logger import log_with_caller -from agentm.logic.db_functions import insert_model, update_run_pending -from datetime import datetime +from agentm.theme.palette import get_theme +from pathlib import Path palette = get_theme() +HEADER_PATH = Path("agentm/assets/headers/training_setup.txt") class TrainingView(Screen): - BINDINGS = [ - ("escape", "app.pop_screen", "Back") - ] + BINDINGS = [("escape", "app.pop_screen", "Back")] def __init__(self, agent: dict, model: dict | None, run: dict): super().__init__() self.agent_metadata = agent self.model_metadata = model self.run_metadata = run - self.is_pending_run = run.get("pending", True) + + log_with_caller("info", f"Initialized TrainingView for agent: {agent['name']}") + + self.model_name_input = Input(id="model_name") + self.num_env_input = Input(id="num_env") + + self.framework_dropdown = Select( + id="framework", + options=[("SB3", "SB3"), ("RAY", "RAY"), ("SHEEPRL", "SHEEPRL")], + value="SB3" + ) + self.algo_dropdown = Select(id="algo", options=[("PPO", "PPO")], value="PPO") + + self.settings_inputs = { + "step_ratio": Input(value="6", id="step_ratio"), + "frame_shape": Input(value="[128, 128, 1]", id="frame_shape"), + "continue_game": Input(value="0.0", id="continue_game"), + "action_space": Input(value="discrete", id="action_space") + } + + self.wrapper_checkboxes = { + "normalize_reward": Checkbox(label="Normalize Reward", value=True, id="normalize_reward"), + "no_attack_buttons_combinations": Checkbox(label="No Attack Combos", value=True, id="no_attack_buttons_combinations"), + "add_last_action": Checkbox(label="Add Last Action", value=True, id="add_last_action"), + "scale": Checkbox(label="Scale", value=True, id="scale"), + "exclude_image_scaling": Checkbox(label="Exclude Image Scaling", value=True, id="exclude_image_scaling"), + "role_relative": Checkbox(label="Role Relative", value=True, id="role_relative"), + "flatten": Checkbox(label="Flatten", value=True, id="flatten") + } + + self.wrapper_inputs = { + "stack_frames": Input(value="4", id="stack_frames"), + "stack_actions": Input(value="12", id="stack_actions"), + "dilation": Input(value="1", id="dilation") + } + + self.filter_key_checkboxes = [ + Checkbox(label=key, value=True, id=f"fk_{key}") for key in ["frame", "action", "stage", "timer"] + ] + + self.net_arch_input = Input(value="[64, 64]", id="net_arch") + + self.ppo_inputs = { + "gamma": Input(value="0.94", id="gamma"), + "model_checkpoint": Input(value="0", id="model_checkpoint"), + "learning_rate_start": Input(value="0.00025", id="lr_start"), + "learning_rate_end": Input(value="0.0000025", id="lr_end"), + "clip_range_start": Input(value="0.15", id="clip_start"), + "clip_range_end": Input(value="0.025", id="clip_end"), + "batch_size": Input(value="256", id="batch_size"), + "n_epochs": Input(value="4", id="n_epochs"), + "n_steps": Input(value="128", id="n_steps"), + "autosave_freq": Input(value="256", id="autosave_freq"), + "time_steps": Input(value="512", id="time_steps") + } def compose(self): - log_with_caller("info", f"Opening TrainingView for agent='{self.agent_metadata['name']}', run='{self.run_metadata['name']}', pending={self.is_pending_run}") + log_with_caller("debug", "Composing TrainingView layout...") - yield Static(f"[{palette.ACCENT} bold]Training: {self.agent_metadata['name']}[/]", classes="header") + try: + header_text = HEADER_PATH.read_text() + except FileNotFoundError: + log_with_caller("error", f"Missing header file at: {HEADER_PATH}") + header_text = "[bold red]Missing Header File" - if self.is_pending_run: - yield self.render_pending_setup() - else: - yield self.render_resume_options() + yield Static(f"[{palette.ACCENT}]{header_text}[/{palette.ACCENT}]", classes="header") - def render_pending_setup(self): - yield Static("[b]Initial Model Setup[/b]", classes="subheader") + # Build scrollable form content + yield VerticalScroll( + Vertical( - self.name_input = Input(placeholder="Model Name", id="model_name") - self.steps_input = Input(placeholder="Total Training Steps", id="total_steps") - self.lr_input = Input(placeholder="Initial Learning Rate", id="learning_rate") - self.clip_input = Input(placeholder="Initial Clip Range", id="clip_range") - self.notes_input = Input(placeholder="Notes", id="notes") + # Framework & Algorithm + Static("Framework & Algorithm", classes="section_label"), + Horizontal( + Vertical(Label("Framework"), self.framework_dropdown), + Vertical(Label("Algorithm"), self.algo_dropdown), + classes="input_row" + ), - self.confirm_button = Button("✅ Create & Start Training", id="confirm_model_btn", classes="confirm_button") + # Model name / num_env + Static("Model Configuration", classes="section_label"), + Horizontal( + Vertical(Label("Model Name"), self.model_name_input), + Vertical(Label("Num Envs"), self.num_env_input), + classes="input_row" + ), - return Vertical( - self.name_input, - self.steps_input, - self.lr_input, - self.clip_input, - self.notes_input, - self.confirm_button, - classes="centered_layout" + # General Settings + Static("General Settings", classes="section_label"), + *[ + Horizontal(Vertical(Label(k)), v, classes="input_row") + for k, v in self.settings_inputs.items() + ], + + # Wrapper Checkboxes + Static("Wrapper Flags", classes="section_label"), + *[ + Horizontal(cb, classes="input_row") + for cb in self.wrapper_checkboxes.values() + ], + + # Wrapper Inputs + *[ + Horizontal(Vertical(Label(k.replace("_", " ").title())), v, classes="input_row") + for k, v in self.wrapper_inputs.items() + ], + + # Filter Keys + Static("Filter Keys", classes="section_label"), + Grid(*self.filter_key_checkboxes, id="filter_grid"), + + # Policy kwargs + Static("Policy Architecture", classes="section_label"), + Horizontal(Label("Net Arch"), self.net_arch_input, classes="input_row"), + + # PPO Settings + Static("PPO Settings", classes="section_label"), + *[ + Horizontal(Label(k.replace("_", " ").title()), v, classes="input_row") + for k, v in self.ppo_inputs.items() + ], + + id="form_layout", + classes="form_column" + ), + id="config_scroll_container" ) - def render_resume_options(self): - model = self.model_metadata - from rich.table import Table - from rich.panel import Panel + # Final button and footer + yield Button("🔍 Review Config", id="review_config_btn", classes="confirm_button") + yield AgentMFooter(compact=True) - table = Table.grid(padding=(0, 1)) - table.add_column("Key", style="bold underline") - table.add_column("Value", style=palette.ACCENT, overflow="fold") - - table.add_row("Name", model["name"]) - table.add_row("Steps", f"{model['total_steps_completed']} / {model['total_steps_planned']}") - table.add_row("Reward", str(model.get("average_reward", "—"))) - table.add_row("Learning Rate", str(model.get("current_learning_rate", "—"))) - table.add_row("Clip Range", str(model.get("current_clip_range", "—"))) - table.add_row("Created", model.get("created_at", "—")) - - return Vertical( - Static(Panel(table, title="Model Info", border_style=palette.BORDER), classes="agent_info_box"), - Button("⏯️ Resume Training", id="resume_training_btn", classes="confirm_button"), - classes="centered_layout" - ) - - async def on_button_pressed(self, event: Button.Pressed) -> None: - if event.button.id == "confirm_model_btn": - try: - name = self.name_input.value.strip() - total_steps = int(self.steps_input.value) - learning_rate = float(self.lr_input.value) - clip_range = float(self.clip_input.value) - notes = self.notes_input.value.strip() - - log_with_caller("info", f"Creating new model: {name} for agent_id={self.agent_metadata['id']}") - - insert_model( - agent_id=self.agent_metadata["id"], - name=name, - total_steps_planned=total_steps, - current_learning_rate=learning_rate, - current_clip_range=clip_range, - notes=notes - ) - - update_run_pending(self.run_metadata["id"], False) - log_with_caller("info", f"Model '{name}' created and run marked as not pending") - - await self.app.pop_screen() - - except Exception as e: - log_with_caller("error", f"Failed to create model: {e}") - - elif event.button.id == "resume_training_btn": - log_with_caller("info", f"Resuming training for model: {self.model_metadata['name']}") - await self.app.pop_screen() + log_with_caller("debug", "Finished composing TrainingView")