Skip to content

Commit

Permalink
polish(pu): polish prepare_obs_stack4_for_gpt
Browse files (browse the repository at this point in the history)
  • Loading branch information
puyuan1996 committed Mar 16, 2024
1 parent 8354885 commit 6b594f3
Show file tree
Hide file tree
Showing 7 changed files with 14 additions and 24 deletions.
2 changes: 1 addition & 1 deletion lzero/model/gpt_models/cfg_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@
# 'action_shape': 18,# TODO:for Seaquest boxing Frostbite
# 'action_shape': 9,# TODO:for mspacman
# 'action_shape': 4,# TODO:for breakout
'action_shape': 6,# TODO:for pong qbert
'action_shape': 6, # TODO:for pong qbert

'max_cache_size':5000,
# 'max_cache_size':50000,
Expand Down
4 changes: 2 additions & 2 deletions lzero/model/gpt_models/world_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,8 +704,8 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer=None, **kwargs: Any) -
reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings)

# 计算重建损失和感知损失
latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images)
perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images)
latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1
perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1

# latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 4, 64, 64), reconstructed_images) # NOTE: for stack=4
# perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) # NOTE: for stack=4
Expand Down
4 changes: 2 additions & 2 deletions lzero/policy/muzero_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from lzero.model import ImageTransforms
from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \
DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, \
prepare_obs, prepare_obs_for_gpt
prepare_obs, prepare_obs_stack4_for_gpt
from line_profiler import line_profiler


Expand Down Expand Up @@ -423,7 +423,7 @@ def _forward_learn_transformer(self, data: Tuple[torch.Tensor]) -> Dict[str, Uni
target_reward, target_value, target_policy = target_batch

if self._cfg.model.frame_stack_num == 4:
obs_batch, obs_target_batch = prepare_obs_for_gpt(obs_batch_ori, self._cfg)
obs_batch, obs_target_batch = prepare_obs_stack4_for_gpt(obs_batch_ori, self._cfg)
else:
obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg)

Expand Down
5 changes: 2 additions & 3 deletions lzero/policy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def configure_optimizers(

return optimizer

def prepare_obs_for_gpt(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, torch.Tensor]:
def prepare_obs_stack4_for_gpt_bkp(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, torch.Tensor]:
"""
概述:
为模型准备观测数据,包括:
Expand Down Expand Up @@ -361,7 +361,7 @@ def prepare_obs_for_gpt(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch

return obs_batch, obs_target_batch

def prepare_obs_for_gpt_v2(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, torch.Tensor]:
def prepare_obs_stack4_for_gpt(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, torch.Tensor]:
obs_batch_ori = torch.from_numpy(obs_batch_ori).to(cfg.device).float()
obs_batch = obs_batch_ori[:, :cfg.model.frame_stack_num * (cfg.model.image_channel if cfg.model.model_type == 'conv' else cfg.model.observation_shape), ...]

Expand Down Expand Up @@ -393,7 +393,6 @@ def prepare_obs(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor,
"""
# Convert the numpy array of original observations to a PyTorch tensor and transfer it to the specified device.
# NOTE: no `.float()` cast is applied here, so the tensor keeps the dtype of the input numpy array.
# obs_batch_ori = torch.from_numpy(obs_batch_ori).to(cfg.device).float()
obs_batch_ori = torch.from_numpy(obs_batch_ori).to(cfg.device)

# Calculate the dimension size to slice based on the model configuration.
Expand Down
6 changes: 3 additions & 3 deletions zoo/atari/config/atari_xzero_config_stack1.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from easydict import EasyDict
import torch
torch.cuda.set_device(3)
torch.cuda.set_device(0)

# options={'PongNoFrameskip-v4', 'QbertNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SpaceInvadersNoFrameskip-v4', 'BreakoutNoFrameskip-v4', ...}
# env_name = 'PongNoFrameskip-v4'
env_name = 'PongNoFrameskip-v4'
# env_name = 'MsPacmanNoFrameskip-v4'
# env_name = 'QbertNoFrameskip-v4'
# env_name = 'SeaquestNoFrameskip-v4'
env_name = 'BreakoutNoFrameskip-v4' # collect_env_steps=5e3
# env_name = 'BreakoutNoFrameskip-v4' # collect_env_steps=5e3
# env_name = 'BoxingNoFrameskip-v4'
# env_name = 'FrostbiteNoFrameskip-v4'

Expand Down
16 changes: 4 additions & 12 deletions zoo/atari/config/atari_xzero_config_stack1_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,12 @@
env=dict(
stop_value=int(1e6),
env_name=env_name,
# obs_shape=(4, 96, 96),
# obs_shape=(1, 96, 96),

# NOTE: for stack1
observation_shape=(3, 64, 64),
gray_scale=False,

# NOTE: for stack4
# observation_shape=(4, 64, 64),
# gray_scale=True,

Expand Down Expand Up @@ -138,19 +138,13 @@
# transformer_start_after_envsteps=int(5e3),
num_unroll_steps=num_unroll_steps,
model=dict(
# observation_shape=(4, 96, 96),
# frame_stack_num=4,
# observation_shape=(1, 96, 96),
# image_channel=3,
# frame_stack_num=1,
# gray_scale=False,

# NOTE: for stack1
observation_shape=(3, 64, 64),
image_channel=3,
frame_stack_num=1,
gray_scale=False,

# NOTE: very important
# NOTE: for stack4
# observation_shape=(4, 64, 64),
# image_channel=1,
# frame_stack_num=4,
Expand All @@ -168,8 +162,6 @@
# reward_support_size=21,
# value_support_size=21,
# support_scale=10,
embedding_dim=1024,
# embedding_dim=256,
),
use_priority=False,
cuda=True,
Expand Down
1 change: 0 additions & 1 deletion zoo/atari/envs/atari_lightzero_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ class AtariLightZeroEnv(BaseEnv):
frame_skip=4,
episode_life=True,
clip_rewards=True,
# channel_last=True,
channel_last=False,
render_mode_human=False,
scale=True,
Expand Down

0 comments on commit 6b594f3

Please sign in to comment.