polish(pu): polish xzero configs
puyuan1996 committed Mar 9, 2024
1 parent 32aaf2a commit a7d5d21
Showing 6 changed files with 23 additions and 39 deletions.
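In summary: the default CUDA device moves from cuda:1 to cuda:0, policy_entropy_weight is set to 0 (the previous 1e-4 is kept as a comment), the world model's reconstruction/perceptual losses and the latent decoder are switched from the single-RGB-frame (stack=1) branches to the 4-frame grayscale (stack=4) branches, two policy-loss diagnostics are logged as tensors instead of .item() scalars, and the stack4 config is stripped of commented-out hyperparameter experiments. Both experiment names gain a "pew0" tag recording the zero entropy weight.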
6 changes: 4 additions & 2 deletions lzero/model/gpt_models/cfg_atari.py
@@ -81,7 +81,7 @@
'embed_pdrop': 0.1,
'resid_pdrop': 0.1,
'attn_pdrop': 0.1,
"device": 'cuda:1',
"device": 'cuda:0',
# "device": 'cpu',
# 'support_size': 21,
'support_size': 601,
@@ -106,7 +106,9 @@

# 'latent_recon_loss_weight':0.,
# 'perceptual_loss_weight':0.,
'policy_entropy_weight': 1e-4,
'policy_entropy_weight': 0,
# 'policy_entropy_weight': 1e-4,

}
from easydict import EasyDict
cfg = EasyDict(cfg)
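The second hunk above turns the policy entropy bonus off (policy_entropy_weight: 1e-4 → 0, with the old value kept as a comment). For orientation only — this is not the repository's actual loss code, and the function and variable names are made up — an entropy-regularized policy loss typically looks like the sketch below, so setting the weight to 0 reduces it to the plain cross-entropy term:

import torch

def policy_loss_with_entropy(logits, target_policy, policy_entropy_weight=0.0):
    # Hypothetical sketch, not LightZero's implementation.
    log_probs = torch.log_softmax(logits, dim=-1)
    cross_entropy = -(target_policy * log_probs).sum(dim=-1).mean()   # fit the search target policy
    entropy = -(log_probs.exp() * log_probs).sum(dim=-1).mean()       # exploration bonus
    # With policy_entropy_weight = 0, as set in this commit, the bonus term vanishes.
    return cross_entropy - policy_entropy_weight * entropy

logits = torch.randn(8, 6)                          # 8 samples, 6 actions
target = torch.softmax(torch.randn(8, 6), dim=-1)   # stand-in for normalized visit counts
loss = policy_loss_with_entropy(logits, target, policy_entropy_weight=0.0)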
8 changes: 4 additions & 4 deletions lzero/model/gpt_models/world_model.py
@@ -1005,11 +1005,11 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer=None, **kwargs: Any) -
reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings)

# Calculate the reconstruction loss
# latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 4, 64, 64), reconstructed_images) # TODO: for stack=4
# perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) # for stack=4 gray obs
latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 4, 64, 64), reconstructed_images) # TODO: for stack=4
perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) # for stack=4 gray obs

latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # TODO: for stack=1
perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # TODO: for stack=1
# latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # TODO: for stack=1
# perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # TODO: for stack=1

# latent_recon_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype)
# perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype)
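This hunk re-activates the stack=4 branch (reconstruction loss over four stacked grayscale frames, perceptual loss fixed at 0) and comments out the stack=1 RGB branch. Below is a minimal sketch of the reshaping involved, with hypothetical names and an assumed L2 loss (the tokenizer's actual reconstruction_loss may differ); the perceptual term is plausibly zeroed for grayscale because perceptual metrics are usually defined on RGB images:

import torch
import torch.nn.functional as F

def latent_recon_loss_sketch(observations, reconstructed_images, frame_stack=4):
    # Hypothetical sketch. frame_stack=4: four stacked grayscale frames become 4 channels,
    # matching batch['observations'].reshape(-1, 4, 64, 64) in the active branch;
    # frame_stack=1 would use a single RGB frame (3 channels), the branch commented out above.
    channels = 4 if frame_stack == 4 else 3
    target = observations.reshape(-1, channels, 64, 64)
    return F.mse_loss(reconstructed_images, target)   # assumed L2; the real loss may differ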
4 changes: 2 additions & 2 deletions lzero/model/muzero_gpt_model.py
@@ -166,8 +166,8 @@ def __init__(
embedding_dim=cfg.world_model.embed_dim,
)
# Instantiate the decoder
# decoder_network = LatentDecoder(embedding_dim=cfg.world_model.embed_dim, output_shape=(4, 64, 64)) # TODO: For K=4
decoder_network = LatentDecoder(embedding_dim=cfg.world_model.embed_dim, output_shape=(3, 64, 64)) # TODO: For K=1
decoder_network = LatentDecoder(embedding_dim=cfg.world_model.embed_dim, output_shape=(4, 64, 64)) # TODO: For K=4
# decoder_network = LatentDecoder(embedding_dim=cfg.world_model.embed_dim, output_shape=(3, 64, 64)) # TODO: For K=1


Encoder = Encoder(cfg.tokenizer.encoder)
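Matching the world-model change, the decoder is now instantiated with output_shape=(4, 64, 64) (K=4 stacked grayscale frames) instead of (3, 64, 64) (a single RGB frame). A toy stand-in with hypothetical names, only to show what the output_shape argument controls — the repository's LatentDecoder is presumably a proper (de)convolutional network, and 768 is just an example embedding size suggested by the "lsd768" tag in the experiment names:

import torch
import torch.nn as nn

class TinyLatentDecoder(nn.Module):
    # Hypothetical stand-in for LatentDecoder.
    def __init__(self, embedding_dim, output_shape=(4, 64, 64)):
        super().__init__()
        self.output_shape = output_shape
        c, h, w = output_shape
        self.net = nn.Linear(embedding_dim, c * h * w)

    def forward(self, z):
        return self.net(z).view(-1, *self.output_shape)

decoder = TinyLatentDecoder(embedding_dim=768, output_shape=(4, 64, 64))  # K=4: stacked grayscale
print(decoder(torch.randn(2, 768)).shape)  # torch.Size([2, 4, 64, 64])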
4 changes: 2 additions & 2 deletions lzero/policy/muzero_gpt.py
@@ -581,8 +581,8 @@ def _forward_learn_transformer(self, data: Tuple[torch.Tensor]) -> Dict[str, Uni
'latent_recon_loss':latent_recon_loss,
'perceptual_loss':perceptual_loss,
'policy_loss': policy_loss,
'orig_policy_loss':orig_policy_loss.item(),
'policy_entropy':policy_entropy.item(),
'orig_policy_loss':orig_policy_loss,
'policy_entropy':policy_entropy,

'target_policy_entropy': average_target_policy_entropy,
# 'policy_entropy': - policy_entropy_loss.mean().item() / (self._cfg.num_unroll_steps + 1),
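Here orig_policy_loss and policy_entropy are placed in the returned log dict as tensors rather than Python floats (.item() is dropped). The commit does not say why; one common reason is that .item() forces a GPU-to-CPU synchronization, so the conversion can be deferred to logging time. A hedged sketch with hypothetical names:

import torch

def to_scalar(value):
    # Hypothetical helper: accept tensors or plain numbers when reading the log dict,
    # calling .item() (a GPU-to-CPU sync) only when the value is actually logged.
    return value.item() if isinstance(value, torch.Tensor) else float(value)

log_vars = {'orig_policy_loss': torch.tensor(0.37), 'policy_entropy': 1.2}
scalars = {k: to_scalar(v) for k, v in log_vars.items()}   # both entries end up as Python floats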
9 changes: 6 additions & 3 deletions zoo/atari/config/atari_xzero_config_stack1.py
@@ -1,6 +1,6 @@
from easydict import EasyDict
import torch
torch.cuda.set_device(1)
torch.cuda.set_device(0)

# options={'PongNoFrameskip-v4', 'QbertNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SpaceInvadersNoFrameskip-v4', 'BreakoutNoFrameskip-v4', ...}
env_name = 'PongNoFrameskip-v4'
@@ -81,7 +81,8 @@
# atari env action space
# game_buffer_muzero_gpt task_id
# TODO: muzero_gpt_model.py world_model.py (3,64,64)
exp_name=f'data_xzero_0307/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_new-rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_mcts-kv-reset-5-kvbatch-pad-min-quantize15-lsd768-nh8_fixroot_simnorm_latentw10_seed0',
exp_name=f'data_xzero_0307/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_new-rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_mcts-kv-reset-5-kvbatch-pad-min-quantize15-lsd768-nh8_fixroot_simnorm_latentw10_pew0_seed0',
# exp_name=f'data_xzero_0307/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_new-rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_mcts-kv-reset-5-kvbatch-pad-min-quantize15-lsd768-nh8_fixroot_simnorm_latentw10_pew1e-4_seed0',

# exp_name=f'data_xzero_0306/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_new-rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_mcts-kv-reset-5-kvbatch-pad-min-quantize15-lsd768-nh4_fixroot_head-2-layer_mantrans-nobatch_seed0',

@@ -264,4 +265,6 @@
# def run(max_env_step: int):
# train_muzero_gpt([main_config, create_config], seed=0, model_path=main_config.policy.model_path, max_env_step=max_env_step)
# import cProfile
# cProfile.run(f"run({100000})", filename="pong_xzero_cprofile_100k_envstep", sort="cumulative")
# cProfile.run(f"run({100000})", filename="pong_xzero_cprofile_100k_envstep", sort="cumulative")

# python -m line_profiler /mnt/afs/niuyazhe/code/LightZero/atari_xzero_config_stack1.py.lprof > atari_xzero_config_stack1.py.lprof.txt
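The new experiment name in this file appends a "pew0" tag, recording policy_entropy_weight = 0 from the cfg_atari.py change (the commented alternative keeps "pew1e-4"). For orientation only, a shortened illustration of how the f-string composes the tag — the long architecture tag is omitted and the numeric values below are examples, not necessarily this config's:

env_name = 'PongNoFrameskip-v4'
collector_env_num, num_simulations, update_per_collect = 8, 50, 1000      # example values only
model_update_ratio, reanalyze_ratio, num_unroll_steps, batch_size = 0.25, 0, 5, 64
tag = (f'{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}'
       f'_upc{update_per_collect}-mur{model_update_ratio}_new-rr{reanalyze_ratio}'
       f'_H{num_unroll_steps}_bs{batch_size}_stack1_pew0_seed0')
print(tag)  # Pong_xzero_envnum8_ns50_upc1000-mur0.25_new-rr0_H5_bs64_stack1_pew0_seed0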
31 changes: 5 additions & 26 deletions zoo/atari/config/atari_xzero_config_stack4.py
@@ -24,37 +24,16 @@
collector_env_num = 8
n_episode = 8
evaluator_env_num = 1
update_per_collect = 1000
# update_per_collect = None
model_update_ratio = 0.25

# update_per_collect = 500


# collector_env_num = 4
# n_episode = 4
# evaluator_env_num = 1
# update_per_collect = 500

# collector_env_num = 2
# n_episode = 2
# evaluator_env_num = 1
# update_per_collect = 250

# update_per_collect = 2000

# update_per_collect = None
# model_update_ratio = 0.25
num_simulations = 50
# num_simulations = 25
update_per_collect = 1000

# TODO: debug
# num_simulations = 1
max_env_step = int(10e6)
reanalyze_ratio = 0
# reanalyze_ratio = 0.05

batch_size = 64 # for num_head=2, emmbding_dim=128
num_unroll_steps = 5
max_env_step = int(10e6)
num_simulations = 50


# for debug
@@ -80,7 +59,7 @@
# TODO:
# muzero_gpt_model.py world_model.py stack (4,64,64)
# muzero: mcts_ctree, muzero_collector: empty_cache
exp_name=f'data_xzero_0307/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_new-rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack4_mcts-kv-reset-5-kvbatch-pad-min-quantize15-lsd768-nh4_fixroot_seed0',
exp_name=f'data_xzero_0307/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_new-rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack4_mcts-kv-reset-5-kvbatch-pad-min-quantize15-lsd768-nh8_fixroot_simnorm_latentw10_pew0_seed0',

# exp_name=f'data_xzero_0307/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_new-rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack4_mcts-kv-reset-5-kvbatch-pad-min-quantize15-lsd768-nh4_fixroot_head-2-layer_mantrans-nobatch_seed0',

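After this cleanup, the collection/training hyperparameters that remain active in the shown region of the stack4 config appear to be the following; this is a reading of the visible diff for orientation, not a verbatim excerpt of the file:

collector_env_num = 8
n_episode = 8
evaluator_env_num = 1
model_update_ratio = 0.25
update_per_collect = 1000
num_simulations = 50
max_env_step = int(10e6)
reanalyze_ratio = 0
batch_size = 64          # the file's own comment: for num_head=2, emmbding_dim=128
num_unroll_steps = 5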
