Skip to content

Commit

Permalink
polish(pu): polish memory env settings and config
Browse files Browse the repository at this point in the history
  • Loading branch information
jiayilee65 committed Apr 11, 2024
1 parent 74356a0 commit 5fa006c
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 50 deletions.
44 changes: 30 additions & 14 deletions lzero/model/gpt_models/cfg_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@
"max_tokens": 2 * 10, # TODO: horizon:8
# "context_length": 20,
# "context_length_for_recurrent": 20,
"context_length": 6, # TODO
"context_length_for_recurrent": 6,
"context_length": 8, # TODO
"context_length_for_recurrent": 8,
"recurrent_keep_deepth": 100,
"gru_gating": False,
# "gru_gating": True,
Expand All @@ -59,23 +59,44 @@
# "gru_gating": False,
# # "gru_gating": True,

# 'action_shape': 18, # TODO:for multi-task

"device": 'cuda:3',
# 'action_shape': 6, # TODO:for pong qbert
# 'action_shape': 9,# TODO:for mspacman
# 'action_shape': 18,# TODO:for Seaquest boxing Frostbite
'action_shape': 4,# TODO:for breakout

# 'embed_dim':512, # TODO:for atari
# 'embed_dim':256, # TODO:for atari
# 'embed_dim':1024, # TODO:for atari
'embed_dim': 768, # TODO:for atari
'group_size': 8, # NOTE

'attention': 'causal',

"device": 'cuda:3',

'num_layers': 1, # TODO:for atari debug
# 'num_layers': 1, # TODO:for atari debug
# 'num_layers': 2, # TODO:for atari debug
# 'num_layers': 4, # TODO:for atari debug
# 'num_layers': 6, # TODO:for atari debug
# 'num_layers': 12, # TODO:for atari debug
'num_layers': 6, # TODO:for atari debug

# 'num_layers': 8, # TODO:for atari debug
'num_heads': 8,
'embed_dim': 768, # TODO:for atari

# 'num_layers': 12, # TODO:Gpt2 Base
# 'num_heads': 12, # TODO:Gpt2 Base
# 'embed_dim': 768, # TODO:Gpt2 Base

# 'num_layers': 12, # TODO:Gato Medium
# 'num_heads': 12, # TODO:Gato Medium
# 'embed_dim': 1536, # TODO:Gato Medium

# 'num_layers': 8, # TODO:Gato Base
# 'num_heads': 24, # TODO:Gato Base
# 'embed_dim': 768, # TODO:Gato Base

# 'num_layers': 24, # TODO:Gato Large
# 'num_heads': 16, # TODO:Gato Large
# 'embed_dim': 2048, # TODO:Gato Large

'embed_pdrop': 0.1,
'resid_pdrop': 0.1,
Expand All @@ -84,12 +105,7 @@
# 'support_size': 601,
'support_size': 101, # TODO

# 'action_shape': 18, # TODO:for multi-task

# 'action_shape': 6, # TODO:for pong qbert
# 'action_shape': 9,# TODO:for mspacman
# 'action_shape': 18,# TODO:for Seaquest boxing Frostbite
'action_shape': 4,# TODO:for breakout

'max_cache_size':5000,
# 'max_cache_size':50000,
Expand Down
24 changes: 18 additions & 6 deletions lzero/model/gpt_models/cfg_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,28 @@


# 'embed_dim': 64, # TODO:for memory # same as <Transformer shine in RL> paper
'embed_dim': 96, # TODO:for memory # same as <Transformer shine in RL> paper
# 'embed_dim': 96, # TODO:for memory # same as <Transformer shine in RL> paper
'group_size': 8, # NOTE

"device": 'cuda:4',
"device": 'cuda:5',
'attention': 'causal',
# 'num_layers': 1,
'num_layers': 2, # same as <Transformer shine in RL> paper
'num_layers': 4,
# 'num_layers': 2, # same as <Transformer shine in RL> paper
# 'num_layers': 4,
'num_layers': 6,
'num_heads': 8,
# 'embed_dim': 96, # TODO:
'embed_dim': 768, # TODO:Gpt2 Base


# 'num_layers': 8, # TODO:for atari debug
# 'num_heads': 8,
# 'embed_dim': 768, # TODO:for atari

# 'num_layers': 12, # TODO:Gpt2 Base
# 'num_heads': 12, # TODO:Gpt2 Base
# 'embed_dim': 768, # TODO:Gpt2 Base

'gru_gating': False,

'embed_pdrop': 0.1,
Expand All @@ -87,8 +99,8 @@
# 'predict_latent_loss_type': 'mse',

'obs_type': 'image_memory', # 'vector', 'image'
# 'gamma': 1, # 0.5, 0.9, 0.99, 0.999
'gamma': 1.2, # 0.5, 0.9, 0.99, 0.999
'gamma': 1, # 0.5, 0.9, 0.99, 0.999
# 'gamma': 1.2, # 0.5, 0.9, 0.99, 0.999


}
Expand Down
29 changes: 9 additions & 20 deletions zoo/atari/config/atari_xzero_config_stack1.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from easydict import EasyDict


import torch
torch.cuda.set_device(3)
# ==== NOTE: action_shape must be set in cfg_atari =====

# options={'PongNoFrameskip-v4', 'QbertNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SpaceInvadersNoFrameskip-v4', 'BreakoutNoFrameskip-v4', ...}
Expand Down Expand Up @@ -47,11 +48,11 @@
n_episode = 8
evaluator_env_num = 3

# model_update_ratio = 0.25
model_update_ratio = 0.5
model_update_ratio = 0.25
# model_update_ratio = 0.5
num_simulations = 50
# max_env_step = int(1e6)
max_env_step = int(5e5)
max_env_step = int(1e6)
# max_env_step = int(5e5)


reanalyze_ratio = 0.
Expand All @@ -68,27 +69,15 @@
# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================
import torch
torch.cuda.set_device(3)


atari_xzero_config = dict(
# TODO:
# mcts_ctree
# muzero_collector/evaluator: empty_cache
exp_name=f'data_xzero_atari_0408/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_lsd768-nh8_grugating-false_latent-groupkl_conleninit{6}-conlenrecur{6}clear-gamma1_nlayer1-steplosslog_seed0',
# exp_name=f'data_xzero_atari_0407/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_lsd768-nh8_grugating-false_latent-groupkl_conleninit{20}-conlenrecur{20}clear-gamma1.5_nlayer5-steplosslog_seed0',

# exp_name=f'data_xzero_atari_0407/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_lsd768-nlayer1-nh8_grugating-false_latent-groupkl_conleninit{8}-conlenrecur{8}clear-fixposemb_seed0',
# exp_name=f'data_xzero_atari_0407/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_lsd768-nlayer1-nh8_grugating-false_latent-groupkl_conleninit{20}-conlenrecur{20}clear-onlyinitreset_seed0',

# exp_name=f'data_xzero_atari_0404/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_lsd768-nlayer1-nh8_grugating-false_latent-groupkl_conleninit{6}-conlenrecur{6}clear-reset_seed0',
exp_name=f'data_xzero_atari_0410/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_lsd768-nh8_grugating-false_latent-groupkl_conleninit{8}-conlenrecur{8}clear-true_gamma1_nlayer6-steplosslog_seed0',
# exp_name=f'data_xzero_atari_0407/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_grugating-false_latent-groupkl_conleninit{20}-conlenrecur{20}clear-gamma1_lsd1536-nlayer12-nh12_steplosslog_seed0',

# exp_name=f'data_xzero_atari_0404/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_lsd768-nlayer1-nh8_grugating-false_latent-groupkl_conleninit{10}-conlenrecur{10}clear-alwayslateset_seed0',
# exp_name=f'data_xzero_atari_0404/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_lsd768-nlayer1-nh8_grugating-false_latent-groupkl_conleninit{6}-conlenrecur{6}clear-keepdepth2_seed0',
# exp_name=f'data_xzero_atari_0403/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_lsd768-nlayer1-nh8_grugating-false_latent-groupkl_conleninit{2*num_unroll_steps}-conlenrecur{2*num_unroll_steps}clear-keepdepth0_seed0',
# exp_name=f'data_xzero_atari_0401/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_mcts-kvbatch-pad-min-quantize15-lsd768-nlayer1-nh8_grugating-false_simnorm_latentw10_pew1e-4_latent-groupkl_soft005_eps20k_nogradscale_gcv5_seed0',
# exp_name=f'data_xzero_atari_0403/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_mcts-kvbatch-pad-min-quantize15-lsd768-nlayer1-nh8_grugating-false_simnorm_latentw10_pew1e-4_latent-groupkl_soft005_eps20k_nogradscale_gcv5_conleninit{16}-conlenrecur{16}clear-keepdepth1_seed0',
# exp_name=f'data_xzero_atari_0401/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_mcts-kvbatch-pad-min-quantize15-lsd768-nlayer2-nh8_simnorm_latentw10_pew1e-4_latent-groupkl_soft005_eps20k_nogradscale_gcv5_contextlength10_seed0',
env=dict(
stop_value=int(1e6),
env_name=env_name,
Expand Down
12 changes: 8 additions & 4 deletions zoo/memory/config/memory_xzero_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
# key_to_door [2, 60, 120, 250, 500]

# max_env_step = int(3e6)
max_env_step = int(1e6)
# max_env_step = int(1e6)
max_env_step = int(5e5)


# ==== NOTE: action_shape must be set in cfg_memory =====
# ==== NOTE: policy_entropy_weight must be set in cfg_memory =====
Expand Down Expand Up @@ -50,11 +52,13 @@
# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================
torch.cuda.set_device(4)
torch.cuda.set_device(5)
memory_xzero_config = dict(
# mcts_ctree.py muzero_collector muzero_evaluator
exp_name=f'data_memory_{env_id}_0408/{env_id}_memlen-{memory_length}_xzero_H{num_unroll_steps}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_bs{batch_size}'
f'_eps-20k_seed{seed}_eval{evaluator_env_num}_train-with-episode_conleninit{32}-conlenrecur{32}clear_gamma1.2_nl2-nh8_emd96',
exp_name=f'data_memory_{env_id}_0410/{env_id}_memlen-{memory_length}_xzero_H{num_unroll_steps}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_bs{batch_size}'
f'_eps-20k_seed{seed}_eval{evaluator_env_num}_train-with-episode_conleninit{32}-conlenrecur{32}clear_gamma1_nl6-nh8-emd768_phase3-random-colormap-bce_phase1-fixed-target-pos_random-target-color',
# exp_name=f'data_memory_{env_id}_0410/{env_id}_memlen-{memory_length}_xzero_H{num_unroll_steps}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_bs{batch_size}'
# f'_eps-20k_seed{seed}_eval{evaluator_env_num}_train-with-episode_conleninit{32}-conlenrecur{32}clear_gamma1_nl6-nh8-emd96-true_fixed-colormap-bce_fixed-target-b',
env=dict(
stop_value=int(1e6),
env_id=env_id,
Expand Down
3 changes: 2 additions & 1 deletion zoo/memory/envs/pycolab_tvt/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,8 @@ def keep_n_characters_in_grid(grid, character, n, backdrop_char=BACKGROUND):
num_empty_positions = char_positions.shape[0] - n
if num_empty_positions < 0:
raise ValueError("Not enough characters `{}` in grid.".format(character))
empty_pos = np.random.permutation(char_positions)[:num_empty_positions]
# empty_pos = np.random.permutation(char_positions)[:num_empty_positions]
empty_pos = char_positions[:num_empty_positions] # TODO: phase1-fixed-target-pos 在1.(exploration phase) target_color 使用固定的位置

# Remove characters.
grid = [list(row) for row in grid]
Expand Down
15 changes: 10 additions & 5 deletions zoo/memory/envs/pycolab_tvt/visual_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
]

# MAX_FRAMES_PER_PHASE = {"explore": 15, "distractor": 30, "reward": 15}
MAX_FRAMES_PER_PHASE = {"explore": 2, "distractor": 1, "reward": 15}
MAX_FRAMES_PER_PHASE = {"explore": 2, "distractor": 0, "reward": 15}


class Game(game.AbstractGame):
Expand All @@ -76,7 +76,6 @@ def __init__(
final_reward=10.0,
respawn_every=common.DEFAULT_APPLE_RESPAWN_TIME,
crop=True,
# crop=False,
max_frames=MAX_FRAMES_PER_PHASE,
EXPLORE_GRID=PASSIVE_EXPLORE_GRID,
):
Expand All @@ -92,10 +91,12 @@ def __init__(
self._episode_length = sum(self._max_frames.values())
self._num_actions = common.NUM_ACTIONS
self._colours = common.FIXED_COLOURS.copy()
shuffled_symbol_colour_map = common.get_shuffled_symbol_colour_map(rng, SYMBOLS_TO_SHUFFLE) # TODO:b c e (分别对应左 中 右位置) 的颜色随机
# shuffled_symbol_colour_map = {'b': (0, 0, 1000), 'c': (1000, 0, 0), 'e': (0, 1000, 0)} # TODO:phase3-fixed-colormap-bce b c e (分别对应左 中 右位置) 的颜色固定为:蓝色 红色 绿色
print(f'shuffled_symbol_colour_map: {shuffled_symbol_colour_map}')
self._colours.update(
common.get_shuffled_symbol_colour_map(rng, SYMBOLS_TO_SHUFFLE)
shuffled_symbol_colour_map
)

self._extra_observation_fields = ["chapter_reward_as_string"]

@property
Expand Down Expand Up @@ -181,7 +182,11 @@ def make_episode(self):
croppers = common.get_cropper()
else:
croppers = None
target_char = self._rng.choice(SYMBOLS_TO_SHUFFLE)
target_char = self._rng.choice(SYMBOLS_TO_SHUFFLE) # TODO:随机目标颜色
# target_char = 'b' # TODO:固定目标颜色为左上角位置的颜色
print(f"self._rng: {self._rng}")
print(f"symbols_to_shuffle: {SYMBOLS_TO_SHUFFLE}")
print(f"target_char: {target_char}")
return storytelling.Story(
[
lambda: self._make_explore_phase(target_char),
Expand Down

0 comments on commit 5fa006c

Please sign in to comment.