From 6b594f36a231276eb1dd5614fb0ba9ddb58112be Mon Sep 17 00:00:00 2001 From: puyuan1996 <2402552459@qq.com> Date: Sat, 16 Mar 2024 21:57:38 +0800 Subject: [PATCH] polish(pu): polish prepare_obs_stack4_for_gpt --- lzero/model/gpt_models/cfg_atari.py | 2 +- lzero/model/gpt_models/world_model.py | 4 ++-- lzero/policy/muzero_gpt.py | 4 ++-- lzero/policy/utils.py | 5 ++--- zoo/atari/config/atari_xzero_config_stack1.py | 6 +++--- .../config/atari_xzero_config_stack1_debug.py | 16 ++++------------ zoo/atari/envs/atari_lightzero_env.py | 1 - 7 files changed, 14 insertions(+), 24 deletions(-) diff --git a/lzero/model/gpt_models/cfg_atari.py b/lzero/model/gpt_models/cfg_atari.py index d9a1dd94b..59573e2e9 100644 --- a/lzero/model/gpt_models/cfg_atari.py +++ b/lzero/model/gpt_models/cfg_atari.py @@ -91,7 +91,7 @@ # 'action_shape': 18,# TODO:for Seaquest boxing Frostbite # 'action_shape': 9,# TODO:for mspacman # 'action_shape': 4,# TODO:for breakout - 'action_shape': 6,# TODO:for pong qbert + 'action_shape': 6, # TODO:for pong qbert 'max_cache_size':5000, # 'max_cache_size':50000, diff --git a/lzero/model/gpt_models/world_model.py b/lzero/model/gpt_models/world_model.py index b05c5b1d6..d53cf0777 100644 --- a/lzero/model/gpt_models/world_model.py +++ b/lzero/model/gpt_models/world_model.py @@ -704,8 +704,8 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer=None, **kwargs: Any) - reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) # 计算重建损失和感知损失 - latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) - perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) + latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1 + perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1 # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 4, 64, 64), reconstructed_images) # NOTE: for stack=4 # perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) # NOTE: for stack=4 diff --git a/lzero/policy/muzero_gpt.py b/lzero/policy/muzero_gpt.py index acfd40cd4..6f27e644d 100644 --- a/lzero/policy/muzero_gpt.py +++ b/lzero/policy/muzero_gpt.py @@ -17,7 +17,7 @@ from lzero.model import ImageTransforms from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, \ - prepare_obs, prepare_obs_for_gpt + prepare_obs, prepare_obs_stack4_for_gpt from line_profiler import line_profiler @@ -423,7 +423,7 @@ def _forward_learn_transformer(self, data: Tuple[torch.Tensor]) -> Dict[str, Uni target_reward, target_value, target_policy = target_batch if self._cfg.model.frame_stack_num == 4: - obs_batch, obs_target_batch = prepare_obs_for_gpt(obs_batch_ori, self._cfg) + obs_batch, obs_target_batch = prepare_obs_stack4_for_gpt(obs_batch_ori, self._cfg) else: obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) diff --git a/lzero/policy/utils.py b/lzero/policy/utils.py index f10e374ff..09a459491 100644 --- a/lzero/policy/utils.py +++ b/lzero/policy/utils.py @@ -286,7 +286,7 @@ def configure_optimizers( return optimizer -def prepare_obs_for_gpt(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, torch.Tensor]: +def prepare_obs_stack4_for_gpt_bkp(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, torch.Tensor]: """ 概述: 为模型准备观测数据,包括: @@ -361,7 +361,7 @@ def prepare_obs_for_gpt(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch return obs_batch, obs_target_batch -def prepare_obs_for_gpt_v2(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, torch.Tensor]: +def prepare_obs_stack4_for_gpt(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, torch.Tensor]: obs_batch_ori = torch.from_numpy(obs_batch_ori).to(cfg.device).float() obs_batch = obs_batch_ori[:, :cfg.model.frame_stack_num * (cfg.model.image_channel if cfg.model.model_type == 'conv' else cfg.model.observation_shape), ...] @@ -393,7 +393,6 @@ def prepare_obs(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, """ # Convert the numpy array of original observations to a PyTorch tensor and transfer it to the specified device. # Also, ensure the tensor is of the correct floating-point type for the model. - # obs_batch_ori = torch.from_numpy(obs_batch_ori).to(cfg.device).float() obs_batch_ori = torch.from_numpy(obs_batch_ori).to(cfg.device) # Calculate the dimension size to slice based on the model configuration. diff --git a/zoo/atari/config/atari_xzero_config_stack1.py b/zoo/atari/config/atari_xzero_config_stack1.py index ff26f8b1b..209ec1f3c 100644 --- a/zoo/atari/config/atari_xzero_config_stack1.py +++ b/zoo/atari/config/atari_xzero_config_stack1.py @@ -1,13 +1,13 @@ from easydict import EasyDict import torch -torch.cuda.set_device(3) +torch.cuda.set_device(0) # options={'PongNoFrameskip-v4', 'QbertNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SpaceInvadersNoFrameskip-v4', 'BreakoutNoFrameskip-v4', ...} -# env_name = 'PongNoFrameskip-v4' +env_name = 'PongNoFrameskip-v4' # env_name = 'MsPacmanNoFrameskip-v4' # env_name = 'QbertNoFrameskip-v4' # env_name = 'SeaquestNoFrameskip-v4' -env_name = 'BreakoutNoFrameskip-v4' # collect_env_steps=5e3 +# env_name = 'BreakoutNoFrameskip-v4' # collect_env_steps=5e3 # env_name = 'BoxingNoFrameskip-v4' # env_name = 'FrostbiteNoFrameskip-v4' diff --git a/zoo/atari/config/atari_xzero_config_stack1_debug.py b/zoo/atari/config/atari_xzero_config_stack1_debug.py index 2408f6800..780dc05bd 100644 --- a/zoo/atari/config/atari_xzero_config_stack1_debug.py +++ b/zoo/atari/config/atari_xzero_config_stack1_debug.py @@ -89,12 +89,12 @@ env=dict( stop_value=int(1e6), env_name=env_name, - # obs_shape=(4, 96, 96), - # obs_shape=(1, 96, 96), + # NOTE: for stack1 observation_shape=(3, 64, 64), gray_scale=False, + # NOTE: for stack4 # observation_shape=(4, 64, 64), # gray_scale=True, @@ -138,19 +138,13 @@ # transformer_start_after_envsteps=int(5e3), num_unroll_steps=num_unroll_steps, model=dict( - # observation_shape=(4, 96, 96), - # frame_stack_num=4, - # observation_shape=(1, 96, 96), - # image_channel=3, - # frame_stack_num=1, - # gray_scale=False, - + # NOTE: for stack1 observation_shape=(3, 64, 64), image_channel=3, frame_stack_num=1, gray_scale=False, - # NOTE: very important + # NOTE: for stack4 # observation_shape=(4, 64, 64), # image_channel=1, # frame_stack_num=4, @@ -168,8 +162,6 @@ # reward_support_size=21, # value_support_size=21, # support_scale=10, - embedding_dim=1024, - # embedding_dim=256, ), use_priority=False, cuda=True, diff --git a/zoo/atari/envs/atari_lightzero_env.py b/zoo/atari/envs/atari_lightzero_env.py index 3590100b3..b04238574 100644 --- a/zoo/atari/envs/atari_lightzero_env.py +++ b/zoo/atari/envs/atari_lightzero_env.py @@ -27,7 +27,6 @@ class AtariLightZeroEnv(BaseEnv): frame_skip=4, episode_life=True, clip_rewards=True, - # channel_last=True, channel_last=False, render_mode_human=False, scale=True,