From 1c13700db4e5135de150343740437705ebf187b9 Mon Sep 17 00:00:00 2001 From: puyuan1996 <2402552459@qq.com> Date: Fri, 15 Mar 2024 00:11:10 +0800 Subject: [PATCH] polish(pu): polish world model, use latent-groupkl-loss, no latent grad_scale, fixed-act-embedding --- lzero/entry/train_muzero_gpt.py | 14 +- lzero/model/common.py | 43 - lzero/model/gpt_models/cfg_atari.py | 16 +- lzero/model/gpt_models/utils.py | 44 - lzero/model/gpt_models/world_model.py | 775 ++++------- .../gpt_models/world_model_bkp20240316.py | 1157 +++++++++++++++++ lzero/model/muzero_gpt_model.py | 4 +- lzero/policy/muzero_gpt.py | 14 +- zoo/atari/config/atari_xzero_config_stack1.py | 96 +- .../config/atari_xzero_config_stack1_debug.py | 11 +- 10 files changed, 1430 insertions(+), 744 deletions(-) create mode 100644 lzero/model/gpt_models/world_model_bkp20240316.py diff --git a/lzero/entry/train_muzero_gpt.py b/lzero/entry/train_muzero_gpt.py index 5ce22f42c..b74c2cf58 100644 --- a/lzero/entry/train_muzero_gpt.py +++ b/lzero/entry/train_muzero_gpt.py @@ -156,14 +156,20 @@ def train_muzero_gpt( else: collect_kwargs['epsilon'] = 0.0 + # policy.last_batch_obs = torch.zeros([len(evaluator_env_cfg), cfg.policy.model.observation_shape[0], 64, 64]).to(cfg.policy.device) + # policy.last_batch_action = [-1 for _ in range(len(evaluator_env_cfg))] # stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) # Evaluate policy performance. if evaluator.should_eval(learner.train_iter): + policy.last_batch_obs = torch.zeros([len(evaluator_env_cfg), cfg.policy.model.observation_shape[0], 64, 64]).to(cfg.policy.device) + policy.last_batch_action = [-1 for _ in range(len(evaluator_env_cfg))] stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) if stop: break + policy.last_batch_obs = torch.zeros([len(collector_env_cfg), cfg.policy.model.observation_shape[0], 64, 64]).to(cfg.policy.device) + policy.last_batch_action = [-1 for _ in range(len(collector_env_cfg))] # Collect data by default config n_sample/n_episode. 
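Note on the hunk above: the two additions repeat the same reset of policy.last_batch_obs and policy.last_batch_action, once before evaluation and once before collection. A minimal sketch of how that shared pattern could be expressed as a single helper; reset_policy_batch_io is a hypothetical name, and the 64x64 frame size and the -1 "no previous action" marker are taken directly from the hunk above:

    import torch

    def reset_policy_batch_io(policy, num_envs: int, obs_channels: int, device: str) -> None:
        # Clear the per-environment context the policy carries between phases:
        # a zeroed observation stack and a "no previous action" marker per env.
        policy.last_batch_obs = torch.zeros([num_envs, obs_channels, 64, 64], device=device)
        policy.last_batch_action = [-1 for _ in range(num_envs)]

    # Mirroring the two call sites above:
    # reset_policy_batch_io(policy, len(evaluator_env_cfg), cfg.policy.model.observation_shape[0], cfg.policy.device)
    # reset_policy_batch_io(policy, len(collector_env_cfg), cfg.policy.model.observation_shape[0], cfg.policy.device)
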
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) if cfg.policy.update_per_collect is None: @@ -214,19 +220,11 @@ def train_muzero_gpt( policy._target_model.world_model.keys_values_wm_list.clear() # TODO: 只适用于recurrent_inference() batch_pad print('sample target_model past_keys_values_cache.clear()') - # del policy._learn_model.world_model.keys_values_wm - - # TODO: for batch world model ,to improve kv reuse, we can donot reset - # policy._learn_model.world_model.past_keys_values_cache.clear() # very important - # policy._eval_model.world_model.past_keys_values_cache.clear() # very important - policy._collect_model.world_model.past_keys_values_cache.clear() # very important policy._collect_model.world_model.keys_values_wm_list.clear() # TODO: 只适用于recurrent_inference() batch_pad torch.cuda.empty_cache() # TODO: NOTE - policy.last_batch_obs = torch.zeros([len(collector_env_cfg), cfg.policy.model.observation_shape[0], 64, 64]).to(cfg.policy.device) - policy.last_batch_action = [-1 for _ in range(len(collector_env_cfg))] # if collector.envstep > 0: # # TODO: only for debug diff --git a/lzero/model/common.py b/lzero/model/common.py index 3f7f1e5aa..3c2b6f673 100644 --- a/lzero/model/common.py +++ b/lzero/model/common.py @@ -139,25 +139,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return output -# EZ original -# def renormalize(inputs: torch.Tensor, first_dim: int = 1) -> torch.Tensor: -# """ -# Overview: -# Normalize the input data using the max-min-normalization. -# Arguments: -# - inputs (:obj:`torch.Tensor`): The input data needs to be normalized. -# - first_dim (:obj:`int`): The first dimension of flattening the input data. -# Returns: -# - output (:obj:`torch.Tensor`): The normalized data. -# """ -# if first_dim < 0: -# first_dim = len(inputs.shape) + first_dim -# flat_input = inputs.view(*inputs.shape[:first_dim], -1) -# max_val = torch.max(flat_input, first_dim, keepdim=True).values -# min_val = torch.min(flat_input, first_dim, keepdim=True).values -# flat_input = (flat_input - min_val) / (max_val - min_val) - -# return flat_input.view(*input.shape) def renormalize_min_max(x): # min-max # x is a 2D tensor of shape (batch_size, num_features) @@ -171,30 +152,6 @@ def renormalize_min_max(x): # min-max return x_scaled -# def renormalize(x): # z-score -# # x is a 2D tensor of shape (batch_size, num_features) -# # Compute the mean and standard deviation for each feature across the batch -# mean = torch.mean(x, dim=0, keepdim=True) -# std = torch.std(x, dim=0, keepdim=True) - -# # Apply z-score normalization -# x_normalized = (x - mean) / (std + 1e-8) # Add a small epsilon to avoid division by zero - -# return x_normalized - -# def renormalize(x): # robust scaling -# # x is a 2D tensor of shape (batch_size, num_features) -# # Compute the 1st and 3rd quartile -# q1 = torch.quantile(x, 0.25, dim=0, keepdim=True) -# q3 = torch.quantile(x, 0.75, dim=0, keepdim=True) - -# # Compute the interquartile range (IQR) -# iqr = q3 - q1 - -# # Apply robust scaling -# x_scaled = (x - q1) / (iqr + 1e-8) # Again, add epsilon to avoid division by zero - -# return x_scaled class SimNorm(nn.Module): """ diff --git a/lzero/model/gpt_models/cfg_atari.py b/lzero/model/gpt_models/cfg_atari.py index 414ca45de..7a2f99c78 100644 --- a/lzero/model/gpt_models/cfg_atari.py +++ b/lzero/model/gpt_models/cfg_atari.py @@ -81,17 +81,17 @@ 'embed_pdrop': 0.1, 'resid_pdrop': 0.1, 'attn_pdrop': 0.1, - "device": 'cuda:0', + "device": 'cuda:3', # "device": 'cpu', # 'support_size': 
21, 'support_size': 601, # 'action_shape': 18,# TODO:for multi-task - # 'action_shape': 18,# TODO:for Seaquest boxing + # 'action_shape': 18,# TODO:for Seaquest boxing Frostbite # 'action_shape': 9,# TODO:for mspacman - # 'action_shape': 4,# TODO:for breakout - 'action_shape': 6,# TODO:for pong qbert + 'action_shape': 4,# TODO:for breakout + # 'action_shape': 6,# TODO:for pong qbert 'max_cache_size':5000, # 'max_cache_size':50000, @@ -106,8 +106,12 @@ # 'latent_recon_loss_weight':0., # 'perceptual_loss_weight':0., - 'policy_entropy_weight': 0, - # 'policy_entropy_weight': 1e-4, + + # 'policy_entropy_weight': 0, + 'policy_entropy_weight': 1e-4, + + 'predict_latent_loss_type': 'group_kl', # 'mse' + # 'predict_latent_loss_type': 'mse', # 'mse' } from easydict import EasyDict diff --git a/lzero/model/gpt_models/utils.py b/lzero/model/gpt_models/utils.py index 494ad02ea..a467a895c 100644 --- a/lzero/model/gpt_models/utils.py +++ b/lzero/model/gpt_models/utils.py @@ -11,44 +11,6 @@ from .episode import Episode -# def configure_optimizer(model, learning_rate, weight_decay, *blacklist_module_names): -# """Credits to https://github.com/karpathy/minGPT""" -# # separate out all parameters to those that will and won't experience regularizing weight decay -# decay = set() -# no_decay = set() -# whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv1d) -# blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) -# for mn, m in model.named_modules(): -# for pn, p in m.named_parameters(): -# fpn = '%s.%s' % (mn, pn) if mn else pn # full param name -# if any([fpn.startswith(module_name) for module_name in blacklist_module_names]): -# no_decay.add(fpn) -# elif 'bias' in pn: -# # all biases will not be decayed -# no_decay.add(fpn) -# elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): -# # weights of whitelist modules will be weight decayed -# decay.add(fpn) -# elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): -# # weights of blacklist modules will NOT be weight decayed -# no_decay.add(fpn) - -# # validate that we considered every parameter -# param_dict = {pn: p for pn, p in model.named_parameters()} -# inter_params = decay & no_decay -# union_params = decay | no_decay -# assert len(inter_params) == 0, f"parameters {str(inter_params)} made it into both decay/no_decay sets!" -# assert len( -# param_dict.keys() - union_params) == 0, f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!" - -# # create the pytorch optimizer object -# optim_groups = [ -# {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay}, -# {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, -# ] -# optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate) -# return optimizer - from lzero.model.common import RepresentationNetwork def init_weights(module): @@ -114,7 +76,6 @@ def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, **kwarg # # self.ends_loss_weight = 1. # self.ends_loss_weight = 0. - # self.obs_loss_weight = 0.1 self.obs_loss_weight = 10 # self.obs_loss_weight = 2 @@ -125,9 +86,6 @@ def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, **kwarg # self.ends_loss_weight = 1. self.ends_loss_weight = 0. - # self.latent_kl_loss_weight = 0.1 # for lunarlander - self.latent_kl_loss_weight = 0. 
# for lunarlander - self.latent_recon_loss_weight = latent_recon_loss_weight self.perceptual_loss_weight = perceptual_loss_weight # self.latent_recon_loss_weight = 0.1 @@ -146,8 +104,6 @@ def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, **kwarg self.loss_total += self.value_loss_weight * v elif k == 'loss_ends': self.loss_total += self.ends_loss_weight * v - elif k == 'latent_kl_loss': - self.loss_total += self.latent_kl_loss_weight * v elif k == 'latent_recon_loss': self.loss_total += self.latent_recon_loss_weight * v elif k == 'perceptual_loss': diff --git a/lzero/model/gpt_models/world_model.py b/lzero/model/gpt_models/world_model.py index 3ed2f6926..37381f4c9 100644 --- a/lzero/model/gpt_models/world_model.py +++ b/lzero/model/gpt_models/world_model.py @@ -18,6 +18,7 @@ import torch.nn as nn import torch.nn.functional as F import torchvision +import collections from .kv_caching import KeysValues from .slicer import Embedder, Head, ActEmbedder @@ -26,15 +27,13 @@ from .utils import LossWithIntermediateLosses, init_weights from ding.torch_utils import to_device - -# from memory_profiler import profile from line_profiler import line_profiler import hashlib class SimNorm(nn.Module): """ - Simplicial normalization. - Adapted from https://arxiv.org/abs/2204.00616. + 简单单位向量归一化。 + 改编自 https://arxiv.org/abs/2204.00616. """ def __init__(self, simnorm_dim): @@ -43,7 +42,7 @@ def __init__(self, simnorm_dim): def forward(self, x): shp = x.shape - # Ensure that there is at least one simplex to normalize across. + # 确保至少有一个单纯形用于归一化。 if shp[1] != 0: x = x.view(*shp[:-1], -1, self.dim) x = F.softmax(x, dim=-1) @@ -54,9 +53,7 @@ def forward(self, x): def __repr__(self): return f"SimNorm(dim={self.dim})" -# def quantize_state(state, num_buckets=1000): def quantize_state(state, num_buckets=15): -# def quantize_state(state, num_buckets=10): """ 量化状态向量。 参数: @@ -65,40 +62,34 @@ def quantize_state(state, num_buckets=15): 返回: 量化后的状态向量的哈希值。 """ - # 使用np.digitize将状态向量的每个维度值映射到num_buckets个桶中 + # 使用np.digitize将状态向量的每个维度值映射到num_buckets个桶中 quantized_state = np.digitize(state, bins=np.linspace(0, 1, num=num_buckets)) # 使用更稳定的哈希函数 quantized_state_bytes = quantized_state.tobytes() hash_object = hashlib.sha256(quantized_state_bytes) return hash_object.hexdigest() - - @dataclass class WorldModelOutput: output_sequence: torch.FloatTensor - logits_observations: torch.FloatTensor + logits_observations: torch.FloatTensor logits_rewards: torch.FloatTensor logits_ends: torch.FloatTensor - logits_policy: torch.FloatTensor logits_value: torch.FloatTensor - class WorldModel(nn.Module): def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: TransformerConfig, tokenizer, representation_network=None) -> None: super().__init__() - # config.max_tokens = int(2*50) # TODO - self.tokenizer = tokenizer self.obs_vocab_size, self.act_vocab_size = obs_vocab_size, act_vocab_size self.config = config self.policy_entropy_weight = config.policy_entropy_weight + self.predict_latent_loss_type = config.predict_latent_loss_type self.transformer = Transformer(config) - # self.num_observations_tokens = 16 - self.num_observations_tokens = config.tokens_per_block -1 + self.num_observations_tokens = config.tokens_per_block - 1 self.latent_recon_loss_weight = config.latent_recon_loss_weight self.perceptual_loss_weight = config.perceptual_loss_weight @@ -109,28 +100,21 @@ def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: Transformer self.max_cache_size = config.max_cache_size self.env_num = 
config.env_num self.num_layers = config.num_layers - self.sim_norm = SimNorm(simnorm_dim=8) + self.sim_norm = SimNorm(simnorm_dim=8) # NOTE all_but_last_latent_state_pattern = torch.ones(config.tokens_per_block) - all_but_last_latent_state_pattern[-2] = 0 # 1,...,0,1 + all_but_last_latent_state_pattern[-2] = 0 # 1,...,0,1 act_tokens_pattern = torch.zeros(self.config.tokens_per_block) # 17 - act_tokens_pattern[-1] = 1 # 0,...,0,1 + act_tokens_pattern[-1] = 1 # 0,...,0,1 latent_state_pattern = 1 - act_tokens_pattern # 1,...,1,0 - # current latent state's policy value + # 当前latent state的策略值 value_policy_tokens_pattern = torch.zeros(config.tokens_per_block) value_policy_tokens_pattern[-2] = 1 # [0,...,1,0] - # next latent state's policy value - # value_policy_tokens_pattern = torch.zeros(config.tokens_per_block) - # value_policy_tokens_pattern[-1] = 1 # [0,...,0,1] - - obs_per_embdding_dim=config.embed_dim - # TODO self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim) - self.embedder = Embedder( max_blocks=config.max_blocks, block_masks=[act_tokens_pattern, latent_state_pattern], @@ -138,9 +122,9 @@ def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: Transformer ) self.act_embedding_table = nn.Embedding(act_vocab_size, config.embed_dim) - # self.act_embedding_table.weight.requires_grad = False # TODO: 测试效果 + self.act_embedding_table.weight.requires_grad = False # NOTE: 对于离散动作,使用fixed_act_embedding,效率更高 - self.obs_per_embdding_dim = config.embed_dim # 16*64=1024 + self.obs_per_embdding_dim = config.embed_dim # 16*64=1024 self.head_rewards = Head( max_blocks=config.max_blocks, @@ -151,48 +135,21 @@ def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: Transformer nn.Linear(config.embed_dim, self.support_size) ) ) - - self.head_ends = Head( - max_blocks=config.max_blocks, - block_mask=act_tokens_pattern, # 0,...,0,1 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - nn.ReLU(), - nn.Linear(config.embed_dim, 2) - ) - ) - - self.head_observations_for_root = Head( # TODO + self.head_observations = Head( # TODO max_blocks=config.max_blocks, - block_mask=latent_state_pattern, # 1,...,1,0 + block_mask=all_but_last_latent_state_pattern, # 1,...,0,1 # https://github.com/eloialonso/iris/issues/19 head_module=nn.Sequential( nn.Linear(config.embed_dim, config.embed_dim), - nn.BatchNorm1d(config.embed_dim), - nn.ReLU(), - nn.Linear(config.embed_dim, obs_per_embdding_dim) - ) - ) - - ###### TODO: 2层的性能, LeakyReLU->GELU ###### - self.head_observations = Head( # TODO - max_blocks=config.max_blocks, - block_mask=all_but_last_latent_state_pattern, # 1,...,0,1 # https://github.com/eloialonso/iris/issues/19 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - # nn.LeakyReLU(negative_slope=0.01), # TODO: 2 nn.GELU(), nn.Linear(config.embed_dim, self.obs_per_embdding_dim), - # nn.Tanh(), # TODO - # nn.Sigmoid(), # 这里添加Sigmoid函数 TODO self.sim_norm, ) ) self.head_policy = Head( - max_blocks=config.max_blocks, - block_mask=value_policy_tokens_pattern, # TODO: value_policy_tokens_pattern # [0,...,1,0] - head_module=nn.Sequential( # (8, 5, 128) + max_blocks=config.max_blocks, + block_mask=value_policy_tokens_pattern, # [0,...,1,0] + head_module=nn.Sequential( # (8, 5, 128) nn.Linear(config.embed_dim, config.embed_dim), - # nn.LeakyReLU(negative_slope=0.01), # TODO: 2 nn.GELU(), nn.Linear(config.embed_dim, self.action_shape) # TODO(pu); action shape ) @@ -201,41 +158,17 @@ def __init__(self, obs_vocab_size: int, 
act_vocab_size: int, config: Transformer max_blocks=config.max_blocks, block_mask=value_policy_tokens_pattern, head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - # nn.LeakyReLU(negative_slope=0.01), # TODO: 2 + nn.Linear(config.embed_dim, config.embed_dim), nn.GELU(), nn.Linear(config.embed_dim, self.support_size) # TODO(pu): action shape ) ) - ###### TODO: 单层的性能 ###### - # self.head_observations = Head( # TODO - # max_blocks=config.max_blocks, - # block_mask=all_but_last_latent_state_pattern, # 1,...,0,1 # https://github.com/eloialonso/iris/issues/19 - # head_module=nn.Sequential( - # nn.Linear(config.embed_dim, self.obs_per_embdding_dim), - # nn.Sigmoid(), # 这里添加Sigmoid函数 TODO - # ) - # ) - # self.head_policy = Head( - # max_blocks=config.max_blocks, - # block_mask=value_policy_tokens_pattern, # TODO: value_policy_tokens_pattern # [0,...,1,0] - # head_module=nn.Sequential( # (8, 5, 128) - # nn.Linear(config.embed_dim, self.action_shape) # TODO(pu); action shape - # ) - # ) - # self.head_value = Head( - # max_blocks=config.max_blocks, - # block_mask=value_policy_tokens_pattern, - # head_module=nn.Sequential( - # nn.Linear(config.embed_dim, self.support_size) # TODO(pu): action shape - # ) - # ) - self.apply(init_weights) - last_linear_layer_init_zero = True # TODO: is beneficial for convergence speed. + last_linear_layer_init_zero = True # TODO: 有利于收敛速度。 if last_linear_layer_init_zero: + # 将头部模块的最后一个线性层的权重和偏置初始化为零 for _, layer in enumerate(reversed(self.head_value.head_module)): if isinstance(layer, nn.Linear): nn.init.zeros_(layer.weight) @@ -247,50 +180,28 @@ def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: Transformer nn.init.zeros_(layer.weight) if layer.bias is not None: nn.init.zeros_(layer.bias) - break + break for _, layer in enumerate(reversed(self.head_observations.head_module)): if isinstance(layer, nn.Linear): nn.init.zeros_(layer.weight) - # layer.weight.data.fill_(0.5) # TODO:bug if layer.bias is not None: nn.init.zeros_(layer.bias) break - - import collections + # 使用collections.OrderedDict作为缓存结构,可以维持插入顺序 self.past_keys_values_cache = collections.OrderedDict() - self.past_policy_value_cache = collections.OrderedDict() - # TODO: Transformer更新后应该清除缓存 + # TODO: Transformer更新后应该清除缓存 self.keys_values_wm = self.transformer.generate_empty_keys_values(n=self.env_num, max_tokens=self.config.max_tokens) self.keys_values_wm_list = [] self.keys_values_wm_size_list = [] - - if self.num_observations_tokens==16: # k=16 + if self.num_observations_tokens == 16: # k=16 self.projection_input_dim = 128 - elif self.num_observations_tokens==1: # K=1 - self.projection_input_dim = self.obs_per_embdding_dim# for atari #TODO - # self.projection_input_dim = 256 # for cartpole - - - # self.proj_hid = 1024 - # self.proj_out = 1024 - # self.pred_hid = 512 - # self.pred_out = 1024 - # activation = nn.ReLU() - # self.projection = nn.Sequential( - # nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, - # nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, - # nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out) - # ) - # self.prediction_head = nn.Sequential( - # nn.Linear(self.proj_out, self.pred_hid), - # nn.BatchNorm1d(self.pred_hid), - # activation, - # nn.Linear(self.pred_hid, self.pred_out), - # ) + elif self.num_observations_tokens == 1: # K=1 + self.projection_input_dim = self.obs_per_embdding_dim # for atari #TODO + self.hit_count = 0 self.total_query_count = 0 
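Note: the past_keys_values_cache initialized here is an OrderedDict keyed by quantize_state (defined near the top of this file): each latent state vector is bucketed with np.digitize over [0, 1] and the bucket indices are hashed with SHA-256, so numerically close states collapse onto the same cache key. A self-contained sketch of that keying and lookup pattern; make_cache_key, lookup_or_insert and build_fn are hypothetical names, and the move_to_end/popitem LRU touches are an assumption beyond what this patch itself does (the real methods additionally deep-copy the cached KV object and move it between CPU and GPU):

    import collections
    import hashlib

    import numpy as np

    def make_cache_key(latent_state: np.ndarray, num_buckets: int = 15) -> str:
        # Bucketize each dimension so small numerical differences map to the
        # same key, then hash the bucket indices.
        buckets = np.digitize(latent_state, bins=np.linspace(0, 1, num=num_buckets))
        return hashlib.sha256(buckets.tobytes()).hexdigest()

    cache = collections.OrderedDict()   # cache_key -> cached keys/values object
    max_cache_size = 5000               # mirrors 'max_cache_size' in cfg_atari.py

    def lookup_or_insert(latent_state: np.ndarray, build_fn):
        key = make_cache_key(latent_state)
        if key in cache:
            cache.move_to_end(key)      # keep recently used entries at the back
            return cache[key]
        value = build_fn()              # e.g. run the transformer from an empty KV cache
        cache[key] = value
        if len(cache) > max_cache_size:
            cache.popitem(last=False)   # evict the oldest entry
        return value
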
self.length3_context_cnt = 0 @@ -299,63 +210,45 @@ def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: Transformer self.root_total_query_cnt = 0 self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) - - def __repr__(self) -> str: - return "world_model" - - #@profile - def forward(self, obs_embeddings_or_act_tokens, past_keys_values: Optional[KeysValues] = None, - is_root=False, kvcache_independent=False) -> WorldModelOutput: - + def forward(self, obs_embeddings_or_act_tokens, past_keys_values: Optional[KeysValues] = None, kvcache_independent=False) -> WorldModelOutput: if kvcache_independent: + # 根据past_keys_values获取每个样本的步骤数 prev_steps = 0 if past_keys_values is None else [past_kv.size for past_kv in past_keys_values] prev_steps = torch.tensor(prev_steps, device=self.device) - # 我们需要为每个样本生成一个序列的步骤indices,然后获取它们的位置嵌入 - # 首先扩展prev_steps至(num_steps, batch_size),这里num_steps=1 - # prev_steps = prev_steps.unsqueeze(0) - else: prev_steps = 0 if past_keys_values is None else past_keys_values.size - # print(f'prev_steps:{prev_steps}') if 'obs_embeddings' in obs_embeddings_or_act_tokens.keys(): obs_embeddings = obs_embeddings_or_act_tokens['obs_embeddings'] - if len(obs_embeddings.shape)==2: + if len(obs_embeddings.shape) == 2: obs_embeddings = obs_embeddings.unsqueeze(1) num_steps = obs_embeddings.size(1) # (B, T, E) if kvcache_independent: # 生成每个样本的步骤indices - # steps_indices = prev_steps + torch.arange(num_steps, device=obs_embeddings.device).unsqueeze(1) steps_indices = prev_steps + torch.arange(num_steps, device=obs_embeddings.device) - - # 步骤indices需要被reshape成一维,以便于embedding操作 - # steps_indices = steps_indices.view(-1) # 获取位置嵌入 position_embeddings = self.pos_emb(steps_indices) - # 由于我们要将它们加到obs_embeddings上,需要将位置嵌入reshape回(batch_size, num_steps, embedding_dim) + # 将位置嵌入reshape回(batch_size, num_steps, embedding_dim) position_embeddings = position_embeddings.view(-1, num_steps, position_embeddings.shape[-1]) - # 现在我们可以将位置嵌入加到obs_embeddings上了 + # 将位置嵌入加到obs_embeddings上 sequences = obs_embeddings + position_embeddings else: sequences = obs_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=obs_embeddings.device)) elif 'act_tokens' in obs_embeddings_or_act_tokens.keys(): act_tokens = obs_embeddings_or_act_tokens['act_tokens'] - if len(act_tokens.shape)==3: + if len(act_tokens.shape) == 3: act_tokens = act_tokens.squeeze(1) num_steps = act_tokens.size(1) # (B, T) act_embeddings = self.act_embedding_table(act_tokens) if kvcache_independent: # 生成每个样本的步骤indices - # steps_indices = prev_steps + torch.arange(num_steps, device=act_embeddings.device).unsqueeze(1) steps_indices = prev_steps + torch.arange(num_steps, device=act_embeddings.device) - # 步骤indices需要被reshape成一维,以便于embedding操作 - # steps_indices = steps_indices.view(-1) # 获取位置嵌入 position_embeddings = self.pos_emb(steps_indices) - # 由于我们要将它们加到obs_embeddings上,需要将位置嵌入reshape回(batch_size, num_steps, embedding_dim) + # 将位置嵌入reshape回(batch_size, num_steps, embedding_dim) position_embeddings = position_embeddings.view(-1, num_steps, position_embeddings.shape[-1]) - # 现在我们可以将位置嵌入加到obs_embeddings上了 + # 将位置嵌入加到obs_embeddings上 sequences = act_embeddings + position_embeddings else: sequences = act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=act_tokens.device)) @@ -363,20 +256,19 @@ def forward(self, obs_embeddings_or_act_tokens, past_keys_values: Optional[KeysV obs_embeddings_and_act_tokens = 
obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens'] # obs_embeddings: (B, L, K=16, E), act_tokens: (B, L, 1) obs_embeddings, act_tokens = obs_embeddings_and_act_tokens - if len(obs_embeddings.shape)==3: # for batch compute loss + if len(obs_embeddings.shape) == 3: # for batch compute loss obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, -1) - num_steps = int(obs_embeddings.size(1)*(obs_embeddings.size(2)+1)) # L(k+1) + num_steps = int(obs_embeddings.size(1) * (obs_embeddings.size(2) + 1)) # L(k+1) - # Generate action embeddings from action tokens - # (B, L, 1) -> (B, L, 1, E) - act_embeddings = self.act_embedding_table(act_tokens) + # 根据动作tokens生成动作嵌入 + act_embeddings = self.act_embedding_table(act_tokens) # (B, L, 1) -> (B, L, 1, E) - # 已知obs_embeddings的维度为 (B, L, K, E), act_embeddings的维度为(B, L, 1, E) 希望得到一个obs_act_embeddings向量的维度为 (B, L(K+1), E) - # 而且让得到的obs_act_embeddings的第2个维度的数据为:obs act, obs, act, ..., obs, act,即 L, 1, L,1 ... 这样的排列顺序。请给出高效的实现,用中文回答 + # 已知obs_embeddings的维度为 (B, L, K, E), act_embeddings的维度为(B, L, 1, E) + # 希望得到一个obs_act_embeddings向量的维度为 (B, L(K+1), E) + # 而且让得到的obs_act_embeddings的第2个维度的数据为:obs, act, obs, act, ..., 这样的排列顺序。 B, L, K, E = obs_embeddings.size() - # 初始化一个新的空tensor,用于存放最终的拼接结果 obs_act_embeddings = torch.empty(B, L * (K + 1), E, device=obs_embeddings.device) @@ -387,70 +279,34 @@ def forward(self, obs_embeddings_or_act_tokens, past_keys_values: Optional[KeysV act = act_embeddings[:, i, 0, :].unsqueeze(1) # Shape: (B, 1, E), 补充维度以便拼接 # 交替拼接obs和act - obs_act = torch.cat([obs, act], dim=1) # Shape: (B, K + 1, E) + obs_act = torch.cat([obs, act], dim=1) # Shape: (B, K + 1, E) # 将结果填充到最终的tensor中 obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act - # Add positional embeddings + # 添加位置嵌入 sequences = obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=obs_embeddings.device)) - - # print('transformer forward begin') 函数里面更新了update past_keys_values if kvcache_independent: x = [] for k, past_kv in enumerate(past_keys_values): x.append(self.transformer(sequences[k].unsqueeze(0), past_kv)) - x = torch.cat(x, dim=0) - # TODO: 在collect时,是一步一步的 obs act 传入的 - # prev_steps = prev_steps//1 + x = torch.cat(x, dim=0) else: x = self.transformer(sequences, past_keys_values) - # print('transformer forward done') + # 1,...,0,1 https://github.com/eloialonso/iris/issues/19 + logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps) - if is_root: - logits_observations = self.head_observations_for_root(x, num_steps=num_steps, prev_steps=prev_steps) - else: - # 1,...,0,1 https://github.com/eloialonso/iris/issues/19 - logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps) - logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps) - # logits_ends = self.head_ends(x, num_steps=num_steps, prev_steps=prev_steps) - logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps) logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps) # TODO: root reward value return WorldModelOutput(x, logits_observations, logits_rewards, None, logits_policy, logits_value) - # return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends, logits_policy, logits_value) - - @torch.no_grad() - def render_batch(self) -> List[Image.Image]: - frames = self.decode_latent_state().detach().cpu() - frames = rearrange(frames, 'b c h w -> b h w 
c').mul(255).numpy().astype(np.uint8) - return [Image.fromarray(frame) for frame in frames] - - # only foe inference now, now is invalid - @torch.no_grad() - def decode_latent_state(self) -> List[Image.Image]: - embedded_tokens = self.tokenizer.embedding(self.latent_state) # (B, K, E) - z = rearrange(embedded_tokens, 'b (h w) e -> b e h w', h=int(np.sqrt(self.num_observations_tokens))) - rec = self.tokenizer.decode(z, should_postprocess=True) # (B, C, H, W) - # TODO: for atari image - return torch.clamp(rec, 0, 1) - # for cartpole obs - # return rec - - - @torch.no_grad() - def render(self): - assert self.latent_state.shape == (1, self.num_observations_tokens) - return self.render_batch()[0] @torch.no_grad() - #@profile - def reset_from_initial_observations_v2(self, obs_act_dict: torch.FloatTensor) -> torch.FloatTensor: + def reset_from_initial_observations(self, obs_act_dict: torch.FloatTensor) -> torch.FloatTensor: if isinstance(obs_act_dict, dict): observations = obs_act_dict['obs'] buffer_action = obs_act_dict['action'] @@ -459,43 +315,39 @@ def reset_from_initial_observations_v2(self, obs_act_dict: torch.FloatTensor) -> observations = obs_act_dict buffer_action = None current_obs = None - obs_embeddings = self.tokenizer.encode_to_obs_embeddings(observations, should_preprocess=True) # (B, C, H, W) -> (B, K, E) + obs_embeddings = self.tokenizer.encode_to_obs_embeddings(observations, should_preprocess=True) # (B, C, H, W) -> (B, K, E) if current_obs is not None: - current_obs_embeddings = self.tokenizer.encode_to_obs_embeddings(current_obs, should_preprocess=True) # (B, C, H, W) -> (B, K, E) + current_obs_embeddings = self.tokenizer.encode_to_obs_embeddings(current_obs, should_preprocess=True) # (B, C, H, W) -> (B, K, E) self.latent_state = current_obs_embeddings - outputs_wm = self.refresh_keys_values_with_initial_latent_state_for_init_infer_v2(obs_embeddings, buffer_action, current_obs_embeddings) + outputs_wm = self.refresh_keys_values_with_initial_latent_state_for_init_infer(obs_embeddings, buffer_action, current_obs_embeddings) else: self.latent_state = obs_embeddings - outputs_wm = self.refresh_keys_values_with_initial_latent_state_for_init_infer_v2(obs_embeddings, buffer_action, None) - + outputs_wm = self.refresh_keys_values_with_initial_latent_state_for_init_infer(obs_embeddings, buffer_action, None) return outputs_wm, self.latent_state @torch.no_grad() - #@profile - def refresh_keys_values_with_initial_latent_state_for_init_infer_v2(self, latent_state: torch.LongTensor, buffer_action=None, current_obs_embeddings=None) -> torch.FloatTensor: + def refresh_keys_values_with_initial_latent_state_for_init_infer(self, latent_state: torch.LongTensor, buffer_action=None, current_obs_embeddings=None) -> torch.FloatTensor: n, num_observations_tokens, _ = latent_state.shape if n <= self.env_num: if buffer_action is None: # MCTS root节点: 需要准确的估计 value, policy_logits, 或许需要结合context的kv_cache进行更准确的估计,而不是当前的从0开始推理 self.keys_values_wm = self.transformer.generate_empty_keys_values(n, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False, kvcache_independent=False) + outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, kvcache_independent=False) self.keys_values_wm_size_list = [1 for i in range(n)] # 复制单个环境对应的 keys_values_wm 并存储 self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) for i in 
range(latent_state.size(0)): # 遍历每个环境 - state_single_env = latent_state[i] # 获取单个环境的 latent state - # cache_key = hash(state_single_env.detach().cpu().numpy()) # 计算哈希值 + state_single_env = latent_state[i] # 获取单个环境的 latent state quantized_state = state_single_env.detach().cpu().numpy() cache_key = quantize_state(quantized_state) # 使用量化后的状态计算哈希值 for layer in range(self.num_layers): - self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = self.keys_values_wm._keys_values[layer]._k_cache._cache[i].unsqueeze(0) # shape torch.Size([2, 100, 512]) + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = self.keys_values_wm._keys_values[layer]._k_cache._cache[i].unsqueeze(0) # shape torch.Size([2, 100, 512]) self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = self.keys_values_wm._keys_values[layer]._v_cache._cache[i].unsqueeze(0) - self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = self.keys_values_wm._keys_values[layer]._k_cache._size + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = self.keys_values_wm._keys_values[layer]._k_cache._size self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = self.keys_values_wm._keys_values[layer]._v_cache._size - # keys_values_wm_single_env[layer].update(self.keys_values_wm[layer]._k_cache._cache[i].unsqueeze(0), self.keys_values_wm[layer]._v_cache._cache[i].unsqueeze(0)) self.root_total_query_cnt += 1 if cache_key not in self.past_keys_values_cache: self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu')) @@ -506,23 +358,21 @@ def refresh_keys_values_with_initial_latent_state_for_init_infer_v2(self, latent print(f'root_hit_ratio:{root_hit_ratio}') print(f'root_hit_cnt:{self.root_hit_cnt}') print(f'root_hit find size {self.past_keys_values_cache[cache_key].size}') - if self.past_keys_values_cache[cache_key].size>1: - print(f'=='*20) + if self.past_keys_values_cache[cache_key].size > 1: + print(f'==' * 20) print(f'NOTE: root_hit find size > 1') - print(f'=='*20) + print(f'==' * 20) elif current_obs_embeddings is not None: if max(buffer_action) == -1: - # first step in one episode - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=8, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': current_obs_embeddings}, past_keys_values=self.keys_values_wm, is_root=False) + # 一集的第一步 + self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) + outputs_wm = self.forward({'obs_embeddings': current_obs_embeddings}, past_keys_values=self.keys_values_wm) # 复制单个环境对应的 keys_values_wm 并存储 self.update_cache(current_obs_embeddings) else: - # self.retrieve_or_generate_kvcache(latent_state, current_obs.shape[0]) # 假设 latest_state 是新的 latent_state,包含 ready_env_num 个环境的信息 - # ready_env_num = latent_state.shape[0] ready_env_num = current_obs_embeddings.shape[0] self.keys_values_wm_list = [] self.keys_values_wm_size_list = [] @@ -535,139 +385,57 @@ def refresh_keys_values_with_initial_latent_state_for_init_infer_v2(self, latent if matched_value is not None: # 如果找到匹配的值,将其添加到列表中 self.root_hit_cnt += 1 - if self.root_total_query_cnt>0 and self.root_total_query_cnt%1000==0: + if self.root_total_query_cnt > 0 and self.root_total_query_cnt % 1000 == 0: root_hit_ratio = self.root_hit_cnt / self.root_total_query_cnt print('root_total_query_cnt:', self.root_total_query_cnt) print(f'root_hit_ratio:{root_hit_ratio}') print(f'root_hit find size 
{self.past_keys_values_cache[cache_key].size}') - if self.past_keys_values_cache[cache_key].size>=7: - print(f'=='*20) + if self.past_keys_values_cache[cache_key].size >= 7: + print(f'==' * 20) print(f'NOTE: root_hit find size >= 7') - print(f'=='*20) + print(f'==' * 20) # 这里需要deepcopy因为在transformer的forward中会原地修改matched_value self.keys_values_wm_list.append(copy.deepcopy(self.to_device_for_kvcache(matched_value, 'cuda'))) self.keys_values_wm_size_list.append(matched_value.size) else: - # use zero reset + # 使用零值重置 self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) - # outputs_wm = self.forward({'obs_embeddings': torch.from_numpy(state_single_env).unsqueeze(0).to(self.device)}, past_keys_values=self.keys_values_wm_single_env, is_root=False) - outputs_wm = self.forward({'obs_embeddings': state_single_env.unsqueeze(0)}, past_keys_values=self.keys_values_wm_single_env, is_root=False) + outputs_wm = self.forward({'obs_embeddings': state_single_env.unsqueeze(0)}, past_keys_values=self.keys_values_wm_single_env) self.keys_values_wm_list.append(self.keys_values_wm_single_env) self.keys_values_wm_size_list.append(1) - - # print(f'NOTE: root {self.keys_values_wm_size_list}') - # print(f'=='*20) + # 输入self.keys_values_wm_list,输出为self.keys_values_wm self.trim_and_pad_kv_cache() buffer_action = buffer_action[:ready_env_num] buffer_action = torch.from_numpy(np.array(buffer_action)).to(latent_state.device) act_tokens = buffer_action.unsqueeze(-1) - # outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (latent_state, act_tokens)}, past_keys_values=self.keys_values_wm, is_root=False) - outputs_wm = self.forward({'act_tokens': act_tokens}, past_keys_values=self.keys_values_wm, is_root=False) - - outputs_wm = self.forward({'obs_embeddings': current_obs_embeddings}, past_keys_values=self.keys_values_wm, is_root=False) + outputs_wm = self.forward({'act_tokens': act_tokens}, past_keys_values=self.keys_values_wm) + + outputs_wm = self.forward({'obs_embeddings': current_obs_embeddings}, past_keys_values=self.keys_values_wm) # 复制单个环境对应的 keys_values_wm 并存储 self.update_cache(current_obs_embeddings) - elif n == int(256): - # TODO: n=256 means train tokenizer, 不需要计算target value + elif n == int(256): + # TODO: n=256 表示训练 tokenizer, 不需要计算target value self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - # print('init inference: not find matched_value! 
reset!') - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) - elif n > self.env_num and n != int(256) and buffer_action is not None: - # train时计算target value - # TODO: n=256 means train tokenizer + outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm) + elif n > self.env_num and n != int(256) and buffer_action is not None: + # 训练时计算 target value # [192, 16, 64] -> [32, 6, 16, 64] - latent_state = latent_state.contiguous().view(buffer_action.shape[0], -1, num_observations_tokens, self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 + latent_state = latent_state.contiguous().view(buffer_action.shape[0], -1, num_observations_tokens, self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 - # latent_state = latent_state.view(-1, self.config.max_blocks+1, num_observations_tokens) # (BL, K) latent_state = latent_state[:, :-1, :] - # latent_state = latent_state.reshape(32*6, num_observations_tokens) # (BL, K) buffer_action = torch.from_numpy(buffer_action).to(latent_state.device) act_tokens = rearrange(buffer_action, 'b l -> b l 1') - # # 选择每个样本的最后一步 - last_steps = act_tokens[:, -1:, :] # 这将选择最后一列并保持维度不变, 最后一步的target policy/value本身就没有用到 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - act_tokens = torch.cat((act_tokens, last_steps), dim=1) - - # print('init inference: unroll 5 steps!') 17*6=102 17*5=85 - outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (latent_state, act_tokens)}, is_root=False) - # 选择每个样本的最后一步 - last_steps_value = outputs_wm.logits_value[:, -1:, :] # 这将选择最后一列并保持维度不变 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1) - - last_steps_policy = outputs_wm.logits_policy[:, -1:, :] # 这将选择最后一列并保持维度不变 - outputs_wm.logits_policy = torch.cat((outputs_wm.logits_policy, last_steps_policy), dim=1) - - # Reshape your tensors - # outputs_wm.logits_value.shape (30,21) = (B*6, 21) - outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e') - outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e') - - return outputs_wm - - @torch.no_grad() - # #@profile - def reset_from_initial_observations(self, obs_act_dict: torch.FloatTensor) -> torch.FloatTensor: - if isinstance(obs_act_dict, dict): - # obs_act_dict = {'obs':obs, 'action':action_batch} - observations = obs_act_dict['obs'] - buffer_action = obs_act_dict['action'] - else: - observations = obs_act_dict - buffer_action = None - - obs_embeddings = self.tokenizer.encode_to_obs_embeddings(observations, should_preprocess=True) # (B, C, H, W) -> (B, K, E) - outputs_wm = self.refresh_keys_values_with_initial_latent_state_for_init_infer(obs_embeddings, buffer_action) - self.latent_state = obs_embeddings - - return outputs_wm, self.latent_state - - - @torch.no_grad() - # #@profile - def refresh_keys_values_with_initial_latent_state_for_init_infer(self, latent_state: torch.LongTensor, buffer_action=None) -> torch.FloatTensor: - n, num_observations_tokens, _ = latent_state.shape - if n <= self.env_num: - # MCTS root节点: 需要准确的估计 value, policy_logits 或许需要结合context的kv_cache进行更准确的估计,而不是当前的从0开始推理 - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) # Note: is_root=False - elif n == int(256): - # TODO: n=256 means train tokenizer, 不需要计算target 
value - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - # print('init inference: not find matched_value! reset!') - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) # Note: is_root=False - elif n > self.env_num and n != int(256) and buffer_action is not None: - # TODO: n=256 means train tokenizer - # TODO: for n=32*6=192 means 通过unroll 5 steps,计算target value - # latent_state = latent_state.reshape(32, 6, num_observations_tokens) # (BL, K) - # latent_state = latent_state.view(-1, 6, num_observations_tokens) # (BL, K) - - # [192, 16] -> [32, 6, 16] - # latent_state = latent_state.view(buffer_action.shape[0], -1, num_observations_tokens) # (BL, K) for unroll_step=1 - - # [192, 16, 64] -> [32, 6, 16, 64] - latent_state = latent_state.contiguous().view(buffer_action.shape[0], -1, num_observations_tokens, self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 - - # latent_state = latent_state.view(-1, self.config.max_blocks+1, num_observations_tokens) # (BL, K) - latent_state = latent_state[:, :-1, :] - # latent_state = latent_state.reshape(32*6, num_observations_tokens) # (BL, K) - buffer_action = torch.from_numpy(buffer_action).to(latent_state.device) - act_tokens = rearrange(buffer_action, 'b l -> b l 1') - - # # 选择每个样本的最后一步 last_steps = act_tokens[:, -1:, :] # 这将选择最后一列并保持维度不变, 最后一步的target policy/value本身就没有用到 # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps act_tokens = torch.cat((act_tokens, last_steps), dim=1) - # print('init inference: unroll 5 steps!') 17*6=102 17*5=85 - obs_embeddings = latent_state - outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) + outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (latent_state, act_tokens)}) # 选择每个样本的最后一步 last_steps_value = outputs_wm.logits_value[:, -1:, :] # 这将选择最后一列并保持维度不变 @@ -682,210 +450,143 @@ def refresh_keys_values_with_initial_latent_state_for_init_infer(self, latent_st outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e') outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e') - - # return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends, logits_policy, logits_value) return outputs_wm - - @torch.no_grad() - # #@profile - def refresh_keys_values_with_initial_latent_state(self, latent_state: torch.LongTensor, reset_indices=None) -> torch.FloatTensor: - n, num_observations_tokens, _ = latent_state.shape - assert num_observations_tokens == self.num_observations_tokens - # self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - if reset_indices is None: - self.keys_values_wm_list = [self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) for i in range(n)] - else: - for i in reset_indices: - self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': latent_state[i].unsqueeze(0).to(self.device)}, past_keys_values=self.keys_values_wm_single_env, is_root=False, kvcache_independent=False) - self.keys_values_wm_list[i] = self.keys_values_wm_single_env - self.keys_values_wm_size_list[i] = 1 - return None - @torch.no_grad() - #@profile def forward_initial_inference(self, obs_act_dict): - # if isinstance(obs_act_dict, dict): - # obs = obs_act_dict['obs'] - # else: - # obs = obs_act_dict - outputs_wm, latent_state = 
self.reset_from_initial_observations_v2(obs_act_dict) # root节点也有context - + outputs_wm, latent_state = self.reset_from_initial_observations(obs_act_dict) # root节点也有context return outputs_wm.output_sequence, latent_state, outputs_wm.logits_rewards, outputs_wm.logits_policy, outputs_wm.logits_value """ - 假设env_num=8 - 8个环境的kv_cache单独存储与寻找,都存储在一个dict中,在recurrent_inference时, - 由于不同环境找到的kv_cache的size不同,先根据最大size对kv_cache在前部补零,然后组成batch_size的kv_cache - 其内部也是通过batch执行transformer forward的推理 + 假设env_num=8 + 8个环境的kv_cache单独存储与寻找,都存储在一个dict中,在recurrent_inference时, + 由于不同环境找到的kv_cache的size不同,先根据最大size对kv_cache在前部补零,然后组成batch_size的kv_cache + 其内部也是通过batch执行transformer forward的推理 """ @torch.no_grad() - #@profile def forward_recurrent_inference(self, state_action_history): - # 一般来讲,在一次 MCTS search中,我们需要维护H长度的context来使用transformer进行推理。 - # 由于在一次search里面。agent最多访问sim个不同的节点,因此我们只需维护一个 {(state:kv_cache)}的列表。 - # 但如果假设环境是MDP的话,然后根据当前的 latest_state s_t 在这个列表中查找即可 - # TODO: 但如果假设环境是非MDP的话,需要维护一个 {(rootstate_action_history:kv_cache)}的列表? - - # latest_state = state_action_history[-1][0] - # action = state_action_history[-1][-1] + # 一般来讲,在一次 MCTS search中,我们需要维护H长度的context来使用transformer进行推理。 + # 由于在一次search里面。agent最多访问sim个不同的节点,因此我们只需维护一个 {(state:kv_cache)}的列表。 + # 但如果假设环境是MDP的话,然后根据当前的 latest_state s_t 在这个列表中查找即可 + # TODO: 但如果假设环境是非MDP的话,需要维护一个 {(rootstate_action_history:kv_cache)}的列表? latest_state, action = state_action_history[-1] - # 假设 latest_state 是新的 latent_state,包含 ready_env_num 个环境的信息 + # 假设 latest_state 是新的 latent_state,包含 ready_env_num 个环境的信息 ready_env_num = latest_state.shape[0] self.keys_values_wm_list = [] self.keys_values_wm_size_list = [] - self.retrieve_or_generate_kvcache(latest_state, ready_env_num) + self.retrieve_or_generate_kvcache(latest_state, ready_env_num) latent_state_list = [] - # output_sequence_list, latent_state_list = [], [] - - # reset_indices = [index for index, value in enumerate(self.keys_values_wm_size_list) if value + num_passes > self.config.max_tokens] - # self.refresh_keys_values_with_initial_latent_state(torch.tensor(latest_state, dtype=torch.float32).to(self.device), reset_indices) - - # token = action.clone().detach() if isinstance(action, torch.Tensor) else torch.tensor(action, dtype=torch.long) - # token = token.reshape(-1, 1).to(self.device) # (B, 1) - - # token = torch.tensor(action, dtype=torch.long).reshape(-1, 1).to(self.device) token = action.reshape(-1, 1) - - # print(self.keys_values_wm_size_list) # 获取self.keys_values_wm_size_list的最小值min_size min_size = min(self.keys_values_wm_size_list) if min_size >= self.config.max_tokens - 5: self.length3_context_cnt += len(self.keys_values_wm_size_list) if min_size >= 3: self.length2_context_cnt += len(self.keys_values_wm_size_list) - # if self.total_query_count>0 and self.total_query_count%1==0: - if self.total_query_count>0 and self.total_query_count%10000==0: - self.hit_freq = self.hit_count/(self.total_query_count) - # print('hit_freq:', self.hit_freq) - # print('hit_count:', self.hit_count) + # 打印统计信息 + if self.total_query_count > 0 and self.total_query_count % 10000 == 0: + self.hit_freq = self.hit_count / (self.total_query_count) print('total_query_count:', self.total_query_count) - # 如果总查询次数大于0,计算并打印cnt的比率 + # 如果总查询次数大于0,计算并打印cnt的比率 length3_context_cnt_ratio = self.length3_context_cnt / self.total_query_count print('>=3 node context_cnt:', self.length3_context_cnt) print('>=3 node context_cnt_ratio:', length3_context_cnt_ratio) length2_context_cnt_ratio = self.length2_context_cnt / self.total_query_count print('>=2 
node context_cnt_ratio:', length2_context_cnt_ratio) print('>=2 node context_cnt:', self.length2_context_cnt) - # print(self.keys_values_wm_size_list) - # 输入self.keys_values_wm_list,输出为self.keys_values_wm + # 输入self.keys_values_wm_list,输出为self.keys_values_wm self.trim_and_pad_kv_cache() - # print(f'NOTE: in search node {self.keys_values_wm_size_list}') - # for k in range(num_passes): # assumption that there is only one action token. - # num_passes = 1 + self.num_observations_tokens - for k in range(2): # assumption that there is only one action token. + for k in range(2): # 假设每次只有一个动作token。 # action_token obs_token, ..., obs_token 1+1 - if k==0: + if k == 0: obs_embeddings_or_act_tokens = {'act_tokens': token} else: obs_embeddings_or_act_tokens = {'obs_embeddings': token} - outputs_wm = self.forward(obs_embeddings_or_act_tokens, past_keys_values=self.keys_values_wm, is_root=False, kvcache_independent=False) - # if k==0, action_token self.head_observations 1,...,0,1 - # output_sequence_list.append(outputs_wm.output_sequence) + outputs_wm = self.forward(obs_embeddings_or_act_tokens, past_keys_values=self.keys_values_wm, kvcache_independent=False) if k == 0: - # if k==0, token is action_token outputs_wm.logits_rewards 是有值的 + # 如果k==0,token是action_token,outputs_wm.logits_rewards 是有值的 reward = outputs_wm.logits_rewards # (B,) if k < self.num_observations_tokens: - # 一共产生16个obs_token,每次产生一个 - # TODO: sample or argmax - # token = Categorical(logits=outputs_wm.logits_observations).sample() - # Use argmax to select the most likely token - # token = outputs_wm.logits_observations.argmax(-1, keepdim=True) + # 一共产生16个obs_token,每次产生一个 token = outputs_wm.logits_observations - # if len(token.shape) != 2: - # token = token.squeeze(-1) # Ensure the token tensor shape is (B, 1) if len(token.shape) != 3: token = token.unsqueeze(1) # (8,1024) -> (8,1,1024) latent_state_list.append(token) - # output_sequence = torch.cat(output_sequence_list, dim=1) # (B, 1 + K, E) - # Before updating self.latent_state, delete the old one to free memory + # 删除旧的self.latent_state以释放内存 del self.latent_state self.latent_state = torch.cat(latent_state_list, dim=1) # (B, K) self.update_cache(self.latent_state) - # TODO: 在计算结束后,是否需要更新最新的缓存. 是否需要deepcopy - - # if len(self.past_keys_values_cache) > self.max_cache_size: - # # TODO: lru_cache - # _, popped_kv_cache = self.past_keys_values_cache.popitem(last=False) - # del popped_kv_cache # 测试不要这一行的显存情况 - - # Example usage: - # Assuming `past_keys_values_cache` is a populated instance of `KeysValues` - # and `num_layers` is the number of transformer layers - # cuda_memory_gb = self.calculate_cuda_memory_gb(self.past_keys_values_cache, num_layers=2) - # print(f'len(self.past_keys_values_cache): {len(self.past_keys_values_cache)}, Memory used by past_keys_values_cache: {cuda_memory_gb:.2f} GB') return outputs_wm.output_sequence, self.latent_state, reward, outputs_wm.logits_policy, outputs_wm.logits_value def trim_and_pad_kv_cache(self): """ - This method trims and pads the key and value caches of the attention mechanism + This method trims and pads the key and value caches of the attention mechanism to a consistent size across all items in the batch, determined by the smallest cache size. 
""" - # Find the minimum size across all key-value sizes for padding/trimming + # 找到所有key-value尺寸中的最小尺寸,用于填充/修剪 min_size = min(self.keys_values_wm_size_list) - # Iterate over each layer of the transformer + # 遍历transformer的每一层 for layer in range(self.num_layers): - # Initialize lists to hold the trimmed and padded k and v caches + # 初始化列表来存储修剪和填充后的k和v缓存 kv_cache_k_list = [] kv_cache_v_list = [] - # Enumerate over the key-value pairs list + # 枚举key-value对列表 for idx, keys_values in enumerate(self.keys_values_wm_list): - # Retrieve the current layer's key and value caches + # 检索当前层的key和value缓存 k_cache = keys_values[layer]._k_cache._cache v_cache = keys_values[layer]._v_cache._cache - # Get the effective size of the current cache + # 获取当前缓存的有效尺寸 effective_size = self.keys_values_wm_size_list[idx] - # Calculate the size difference to trim + # 计算需要修剪的尺寸差异 trim_size = effective_size - min_size if effective_size > min_size else 0 - # If trimming is needed, remove 'trim_size' from the beginning of the cache + # 如果需要修剪,从缓存的开头移除'trim_size' if trim_size > 0: k_cache_trimmed = k_cache[:, :, trim_size:, :] v_cache_trimmed = v_cache[:, :, trim_size:, :] - # Pad the trimmed cache with zeros on the third dimension + # 在第三维上用零填充修剪后的缓存 k_cache_padded = F.pad(k_cache_trimmed, (0, 0, trim_size, 0), "constant", 0) v_cache_padded = F.pad(v_cache_trimmed, (0, 0, trim_size, 0), "constant", 0) else: k_cache_padded = k_cache v_cache_padded = v_cache - # Add the processed caches to the lists + # 将处理后的缓存添加到列表中 kv_cache_k_list.append(k_cache_padded) kv_cache_v_list.append(v_cache_padded) - # Stack the caches along the new dimension, and remove the extra dimension with squeeze() + # 沿新维度堆叠缓存,并用squeeze()移除额外维度 self.keys_values_wm._keys_values[layer]._k_cache._cache = torch.stack(kv_cache_k_list, dim=0).squeeze(1) self.keys_values_wm._keys_values[layer]._v_cache._cache = torch.stack(kv_cache_v_list, dim=0).squeeze(1) - # Update the cache size to the minimum size after trimming and padding + # 修剪和填充后,将缓存尺寸更新为最小尺寸 self.keys_values_wm._keys_values[layer]._k_cache._size = min_size self.keys_values_wm._keys_values[layer]._v_cache._size = min_size def update_cache(self, latent_state): - for i in range(latent_state.size(0)): # Iterate over each environment - state_single_env = latent_state[i] # Get the latent state for a single environment - quantized_state = state_single_env.detach().cpu().numpy() # Detach and move the state to CPU - cache_key = quantize_state(quantized_state) # Quantize state and compute its hash value as cache key + for i in range(latent_state.size(0)): # 遍历每个环境 + state_single_env = latent_state[i] # 获取单个环境的潜在状态 + quantized_state = state_single_env.detach().cpu().numpy() # 分离并将状态移至CPU + cache_key = quantize_state(quantized_state) # 量化状态并将其哈希值计算为缓存键 - # Copy keys and values from the global cache to a single environment cache + # 从全局缓存复制keys和values到单个环境缓存 for layer in range(self.num_layers): if self.keys_values_wm._keys_values[layer]._k_cache._size < self.config.max_tokens - 1: self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = self.keys_values_wm._keys_values[layer]._k_cache._cache[i].unsqueeze(0) # Shape torch.Size([2, 100, 512]) @@ -903,7 +604,7 @@ def update_cache(self, latent_state): v_cache_trimmed = v_cache_current[:, 2:self.config.max_tokens - 1, :] # 沿第3维填充后2步 - padding_size = (0, 0, 0, 3) #F.pad的参数(0, 0, 0, 2)指定了在每个维度上的填充量。这些参数是按(左, 右, 上, 下)的顺序给出的,对于三维张量来说,分别对应于(维度2左侧, 维度2右侧, 维度1左侧, 维度1右侧)的填充。 + padding_size = (0, 0, 0, 3) # F.pad的参数(0, 0, 0, 2)指定了在每个维度上的填充量。这些参数是按(左, 右, 上, 
下)的顺序给出的,对于三维张量来说,分别对应于(维度2左侧, 维度2右侧, 维度1左侧, 维度1右侧)的填充。 k_cache_padded = F.pad(k_cache_trimmed, padding_size, 'constant', 0) v_cache_padded = F.pad(v_cache_trimmed, padding_size, 'constant', 0) # 更新单环境cache @@ -913,16 +614,15 @@ def update_cache(self, latent_state): self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = self.config.max_tokens - 3 self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = self.config.max_tokens - 3 - - # Compare and store the larger cache + # 比较并存储较大的缓存 if cache_key in self.past_keys_values_cache: existing_kvcache = self.past_keys_values_cache[cache_key] - # Check if there is a size difference between existing cache and new cache + # 检查现有缓存和新缓存之间是否存在大小差异 if self.keys_values_wm_single_env.size > existing_kvcache.size and self.keys_values_wm_single_env.size < self.config.max_tokens - 1: - # Only store if size is less than max_tokens - 1 to avoid reset + # 仅在大小小于 max_tokens - 1 时存储,以避免重置 self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu')) elif self.keys_values_wm_single_env.size < self.config.max_tokens - 1: - # Only store if size is less than max_tokens - 1 to avoid reset + # 仅在大小小于 max_tokens - 1 时存储,以避免重置 self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu')) def retrieve_or_generate_kvcache(self, latent_state, ready_env_num): @@ -932,20 +632,20 @@ def retrieve_or_generate_kvcache(self, latent_state, ready_env_num): """ for i in range(ready_env_num): self.total_query_count += 1 - state_single_env = latent_state[i] # Get the latent state for a single environment - cache_key = quantize_state(state_single_env) # Compute the hash value using the quantized state - # Retrieve the cached value if it exists + state_single_env = latent_state[i] # 获取单个环境的潜在状态 + cache_key = quantize_state(state_single_env) # 使用量化后的状态计算哈希值 + # 如果存在,检索缓存值 matched_value = self.past_keys_values_cache.get(cache_key) if matched_value is not None: - # If a matching value is found, add it to the list + # 如果找到匹配值,将其添加到列表中 self.hit_count += 1 - # Deepcopy is needed because the transformer's forward may modify matched_value in place + # 需要深度拷贝,因为transformer的forward可能会就地修改matched_value self.keys_values_wm_list.append(copy.deepcopy(self.to_device_for_kvcache(matched_value, self.device))) self.keys_values_wm_size_list.append(matched_value.size) else: - # If no match is found, use a zero reset + # 如果没有找到匹配值,使用零重置 self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) - self.forward({'obs_embeddings': torch.from_numpy(state_single_env).unsqueeze(0).to(self.device)}, past_keys_values=self.keys_values_wm_single_env, is_root=False) + self.forward({'obs_embeddings': torch.from_numpy(state_single_env).unsqueeze(0).to(self.device)}, past_keys_values=self.keys_values_wm_single_env) self.keys_values_wm_list.append(self.keys_values_wm_single_env) self.keys_values_wm_size_list.append(1) @@ -960,7 +660,7 @@ def to_device_for_kvcache(self, keys_values: KeysValues, device: str) -> KeysVal Returns: - keys_values (KeysValues): The KeysValues object with its caches transferred to the specified device. 
""" - # Check if CUDA is available and select the first available CUDA device + # 检查CUDA是否可用并选择第一个可用的CUDA设备 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') for kv_cache in keys_values: @@ -968,7 +668,6 @@ def to_device_for_kvcache(self, keys_values: KeysValues, device: str) -> KeysVal kv_cache._v_cache._cache = kv_cache._v_cache._cache.to(device) return keys_values - # 计算显存使用量的函数 def calculate_cuda_memory_gb(self, past_keys_values_cache, num_layers: int): total_memory_bytes = 0 @@ -985,7 +684,7 @@ def calculate_cuda_memory_gb(self, past_keys_values_cache, num_layers: int): k_memory = torch.prod(torch.tensor(k_shape)) * 4 v_memory = torch.prod(torch.tensor(v_shape)) * 4 - # 累加keys和values缓存的内存 + # 累加keys和values缓存的内存 layer_memory = k_memory + v_memory total_memory_bytes += layer_memory.item() # .item()确保转换为Python标准数字 @@ -993,138 +692,125 @@ def calculate_cuda_memory_gb(self, past_keys_values_cache, num_layers: int): total_memory_gb = total_memory_bytes / (1024 ** 3) return total_memory_gb - #@profile def compute_loss(self, batch, target_tokenizer: Tokenizer=None, **kwargs: Any) -> LossWithIntermediateLosses: - # NOTE: 这里是需要梯度的 - #with torch.no_grad(): # TODO: 非常重要 - obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False) # (B, C, H, W) -> (B, K, E) - obs_embeddings.register_hook(lambda grad: grad * 1/5) # TODO:测试这句的作用 - - # Assume that 'cont_embeddings' and 'original_images' are available from prior code - # Decode the embeddings to reconstruct the images + # 将观察编码为潜在状态表示 + obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False) + + # 注册梯度钩子,用于梯度缩放。这里的作用是将梯度缩小为原来的1/5,有助于训练的稳定性。 + # 但是否必要取决于具体问题,需要通过实验来验证。 + # obs_embeddings.register_hook(lambda grad: grad * 1/5) + + # 从潜在状态表示重建观察 reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) - - # Calculate the reconstruction loss - latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 4, 64, 64), reconstructed_images) # TODO: for stack=4 - perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) # for stack=4 gray obs - - # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # TODO: for stack=1 - # perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # TODO: for stack=1 - # latent_recon_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) - # perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) - - latent_kl_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) - + # 计算重建损失和感知损失 + latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) + perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) + + # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 4, 64, 64), reconstructed_images) # NOTE: for stack=4 + # perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) # NOTE: for stack=4 + + # 动作tokens act_tokens = rearrange(batch['actions'], 'b l -> b l 1') - # TODO: 是否只用重建损失更新表征网络 非常重要 - outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, 
act_tokens)}, is_root=False) # 联合更新性能更好 - # outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens)}, is_root=False) - + # 前向传播,得到预测的观察、奖励和策略等 + outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}) + + # 为了训练稳定性,使用target_tokenizer计算真实的下一个潜在状态表示 with torch.no_grad(): - # 为了训练稳定性,world_model预测的next_latent_state是用target world_model 产生的 - traget_obs_embeddings = target_tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False) # (B, C, H, W) -> (B, K, E) + traget_obs_embeddings = target_tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False) + # 计算观察、奖励和结束标签 labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(traget_obs_embeddings, batch['rewards'], - batch['ends'], - batch['mask_padding']) - # labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(obs_embeddings, batch['rewards'], - # batch['ends'], - # batch['mask_padding']) - + batch['ends'], + batch['mask_padding']) + + # 重塑观察的logits和labels logits_observations = rearrange(outputs.logits_observations[:, :-1], 'b t o -> (b t) o') - # labels_observations = labels_observations.contiguous().view(-1, self.projection_input_dim) # TODO: - labels_observations = labels_observations.reshape(-1, self.projection_input_dim) # TODO: - + labels_observations = labels_observations.reshape(-1, self.projection_input_dim) + + # 计算观察的预测损失。这里提供了两种选择:MSE和Group KL + if self.predict_latent_loss_type == 'mse': + # MSE损失,直接比较logits和labels + loss_obs = torch.nn.functional.mse_loss(logits_observations, labels_observations.detach(), reduction='none').mean(-1) + elif self.predict_latent_loss_type == 'group_kl': + # Group KL损失,将特征分组,然后计算组内的KL散度 + batch_size, num_features = logits_observations.shape + group_size = 8 # TODO + num_groups = num_features // group_size + + logits_reshaped = logits_observations.reshape(batch_size, num_groups, group_size) + labels_reshaped = labels_observations.reshape(batch_size, num_groups, group_size) + + loss_obs = F.kl_div(logits_reshaped.log(), labels_reshaped, reduction='none').sum(dim=-1).mean(dim=-1) - loss_obs = torch.nn.functional.mse_loss(logits_observations, labels_observations.detach(), reduction='none').mean(-1) - - mask_padding_expanded = batch['mask_padding'][:, 1:].contiguous().view(-1) # TODO: - # mask_padding_expanded = batch['mask_padding'][:, 1:].reshape(-1) - # 应用mask到loss_obs - loss_obs = (loss_obs * mask_padding_expanded).mean(-1) + mask_padding_expanded = batch['mask_padding'][:, 1:].contiguous().view(-1) + loss_obs = (loss_obs * mask_padding_expanded).mean() + + # 计算策略和价值的标签 labels_policy, labels_value = self.compute_labels_world_model_value_policy(batch['target_value'], - batch['target_policy'], - batch['mask_padding']) + batch['target_policy'], + batch['mask_padding']) + # 计算奖励、策略和价值的损失 loss_rewards = self.compute_cross_entropy_loss(outputs, labels_rewards, batch, element='rewards') loss_policy, orig_policy_loss, policy_entropy = self.compute_cross_entropy_loss(outputs, labels_policy, batch, element='policy') loss_value = self.compute_cross_entropy_loss(outputs, labels_value, batch, element='value') - return LossWithIntermediateLosses(latent_recon_loss_weight=self.latent_recon_loss_weight, perceptual_loss_weight=self.perceptual_loss_weight, loss_obs=loss_obs, loss_rewards=loss_rewards, loss_value=loss_value, - loss_policy=loss_policy, latent_kl_loss=latent_kl_loss, latent_recon_loss=latent_recon_loss, perceptual_loss=perceptual_loss, 
orig_policy_loss=orig_policy_loss, policy_entropy=policy_entropy) - - # def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): - # # Assume outputs.logits_rewards and labels are your predictions and targets - # # And mask_padding is a boolean tensor with True at positions to keep and False at positions to ignore - # if element == 'rewards': - # logits = outputs.logits_rewards - # elif element == 'policy': - # logits = outputs.logits_policy - # elif element == 'value': - # logits = outputs.logits_value - - # # Reshape your tensors - # logits_rewards = rearrange(logits, 'b t e -> (b t) e') - # labels = labels.reshape(-1, labels.shape[-1]) # Assuming labels originally has shape [b, t, reward_dim] - - # # Reshape your mask. True means valid data. - # mask_padding = rearrange(batch['mask_padding'], 'b t -> (b t)') - - # loss_rewards = -(torch.log_softmax(logits_rewards, dim=1) * labels).sum(1) - # # loss_rewards = (loss_rewards * mask_padding.squeeze(-1).float()).mean() - # loss_rewards = (loss_rewards * mask_padding).mean() - - # return loss_rewards + return LossWithIntermediateLosses( + latent_recon_loss_weight=self.latent_recon_loss_weight, + perceptual_loss_weight=self.perceptual_loss_weight, + loss_obs=loss_obs, + loss_rewards=loss_rewards, + loss_value=loss_value, + loss_policy=loss_policy, + latent_recon_loss=latent_recon_loss, + perceptual_loss=perceptual_loss, + orig_policy_loss=orig_policy_loss, + policy_entropy=policy_entropy + ) def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): - # Assume outputs is an object with logits attributes for 'rewards', 'policy', and 'value' - # And labels is a tensor with targets to compare against. The batch is a dictionary - # with a mask to indicate valid timesteps. + # 假设outputs是一个具有'rewards'、'policy'和'value'的logits属性的对象 + # labels是一个与之比较的目标张量。batch是一个带有指示有效时间步的mask的字典。 logits = getattr(outputs, f'logits_{element}') - # Reshape your tensors + # 重塑你的张量 logits = rearrange(logits, 'b t e -> (b t) e') - labels = labels.reshape(-1, labels.shape[-1]) # Assuming labels originally has shape [batch, time, dim] + labels = labels.reshape(-1, labels.shape[-1]) # 假设labels最初的shape是 [batch, time, dim] - # Reshape your mask. True means valid data. 
+ # 重塑你的mask。True表示有效数据。 mask_padding = rearrange(batch['mask_padding'], 'b t -> (b t)') - # Compute the cross entropy loss + # 计算交叉熵损失 loss = -(torch.log_softmax(logits, dim=1) * labels).sum(1) loss = (loss * mask_padding).mean() if element == 'policy': - # Calculate policy entropy loss + # 计算策略熵损失 policy_entropy = self.compute_policy_entropy_loss(logits, mask_padding) - # Combine the losses with the specified weight + # 用指定的权重组合损失 combined_loss = loss - self.policy_entropy_weight * policy_entropy return combined_loss, loss, policy_entropy return loss def compute_policy_entropy_loss(self, logits, mask): - # Calculate the entropy of the policy + # 计算策略的熵 probs = torch.softmax(logits, dim=1) log_probs = torch.log_softmax(logits, dim=1) entropy = -(probs * log_probs).sum(1) - # Apply the mask and return the mean entropy loss + # 应用mask并返回平均熵损失 entropy_loss = (entropy * mask).mean() return entropy_loss def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, - mask_padding: torch.BoolTensor) -> Tuple[ - torch.Tensor, torch.Tensor, torch.Tensor]: - assert torch.all(ends.sum(dim=1) <= 1) # each sequence sample has at most 1 done + mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert torch.all(ends.sum(dim=1) <= 1) # 每个序列样本最多只有1个done mask_fill = torch.logical_not(mask_padding) - labels_observations = obs_embeddings.contiguous().view(rewards.shape[0], -1, self.projection_input_dim)[:, 1:] # self.projection_input_dim - - - # labels_rewards = (rewards.sign() + 1).masked_fill(mask_fill, -100).long() # Rewards clipped to {-1, 0, 1} TODO(pu) - + labels_observations = obs_embeddings.contiguous().view(rewards.shape[0], -1, self.projection_input_dim)[:, 1:] mask_fill_rewards = mask_fill.unsqueeze(-1).expand_as(rewards) labels_rewards = rewards.masked_fill(mask_fill_rewards, -100) @@ -1132,8 +818,7 @@ def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torc return labels_observations, labels_rewards.reshape(-1, self.support_size), labels_ends.reshape(-1) def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, target_policy: torch.Tensor, - mask_padding: torch.BoolTensor) -> Tuple[ - torch.Tensor, torch.Tensor]: + mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor]: mask_fill = torch.logical_not(mask_padding) mask_fill_policy = mask_fill.unsqueeze(-1).expand_as(target_policy) @@ -1143,35 +828,12 @@ def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, ta labels_value = target_value.masked_fill(mask_fill_value, -100) return labels_policy.reshape(-1, self.action_shape), labels_value.reshape(-1, self.support_size) # TODO(pu) - - def negative_cosine_similarity(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: - """ - Overview: - consistency loss function: the negative cosine similarity. - Arguments: - - x1 (:obj:`torch.Tensor`): shape (batch_size, dim), e.g. (256, 512) - - x2 (:obj:`torch.Tensor`): shape (batch_size, dim), e.g. (256, 512) - Returns: - (x1 * x2).sum(dim=1) is the cosine similarity between vector x1 and x2. - The cosine similarity always belongs to the interval [-1, 1]. - For example, two proportional vectors have a cosine similarity of 1, - two orthogonal vectors have a similarity of 0, - and two opposite vectors have a similarity of -1. - -(x1 * x2).sum(dim=1) is consistency loss, i.e. the negative cosine similarity. 
- Reference: - https://en.wikipedia.org/wiki/Cosine_similarity - """ - x1 = F.normalize(x1, p=2., dim=-1, eps=1e-5) - x2 = F.normalize(x2, p=2., dim=-1, eps=1e-5) - return -(x1 * x2).sum(dim=1) - - def render_img(self, obs: int, rec_img: int): import torch from PIL import Image import matplotlib.pyplot as plt - # 假设batch是一个字典,其中包含了observations键, + # 假设batch是一个字典,其中包含了observations键, # 并且它的形状是torch.Size([B, N, C, H, W]) # batch_observations = batch_for_gpt['observations'] # batch_observations = batch['observations'] @@ -1182,11 +844,9 @@ def render_img(self, obs: int, rec_img: int): # batch_observations = x.unsqueeze(0) # batch_observations = reconstructions.unsqueeze(0) - - B, N, C, H, W = batch_observations.shape # 自动检测维度 - # 分隔条的宽度(可以根据需要调整) + # 分隔条的宽度(可以根据需要调整) separator_width = 2 # 遍历每个样本 @@ -1194,10 +854,10 @@ def render_img(self, obs: int, rec_img: int): # 提取当前样本中的所有帧 frames = batch_observations[i] - # 计算拼接图像的总宽度(包括分隔条) + # 计算拼接图像的总宽度(包括分隔条) total_width = N * W + (N - 1) * separator_width - # 创建一个新的图像,其中包含分隔条 + # 创建一个新的图像,其中包含分隔条 concat_image = Image.new('RGB', (total_width, H), color='black') # 拼接每一帧及分隔条 @@ -1216,4 +876,7 @@ def render_img(self, obs: int, rec_img: int): plt.show() # 保存图像到文件 - concat_image.save(f'sample_{i+1}.png') \ No newline at end of file + concat_image.save(f'sample_{i+1}.png') + + def __repr__(self) -> str: + return "world_model" \ No newline at end of file diff --git a/lzero/model/gpt_models/world_model_bkp20240316.py b/lzero/model/gpt_models/world_model_bkp20240316.py new file mode 100644 index 000000000..e833b4335 --- /dev/null +++ b/lzero/model/gpt_models/world_model_bkp20240316.py @@ -0,0 +1,1157 @@ +import copy +from dataclasses import dataclass +import random +from typing import Any, Optional, Tuple +from typing import List, Optional, Union +import logging +# 设置日志记录级别为DEBUG +logging.getLogger().setLevel(logging.DEBUG) +from PIL import Image +from einops import rearrange +from einops import rearrange +import gym +from joblib import hash +import numpy as np +import torch +import torch +from torch.distributions.categorical import Categorical +import torch.nn as nn +import torch.nn.functional as F +import torchvision + +from .kv_caching import KeysValues +from .slicer import Embedder, Head, ActEmbedder +from .tokenizer import Tokenizer +from .transformer import Transformer, TransformerConfig +from .utils import LossWithIntermediateLosses, init_weights +from ding.torch_utils import to_device + + +# from memory_profiler import profile +from line_profiler import line_profiler +import hashlib + +class SimNorm(nn.Module): + """ + Simplicial normalization. + Adapted from https://arxiv.org/abs/2204.00616. + """ + + def __init__(self, simnorm_dim): + super().__init__() + self.dim = simnorm_dim + + def forward(self, x): + shp = x.shape + # Ensure that there is at least one simplex to normalize across. 
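+ # Illustrative shape walk-through (simnorm_dim=8 is what WorldModel below passes in; 1024 is just
+ # the example embedding width quoted for obs_per_embdding_dim):
+ #   x: (B, 1024) -> view to (B, 128, 8) -> softmax over the last dim -> view back to (B, 1024).
+ # Every consecutive group of 8 features then sums to 1, which is what the 'group_kl' branch of
+ # compute_loss relies on when it reshapes the latent into groups of 8 and applies a per-group KL.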
+ if shp[1] != 0: + x = x.view(*shp[:-1], -1, self.dim) + x = F.softmax(x, dim=-1) + return x.view(*shp) + else: + return x + + def __repr__(self): + return f"SimNorm(dim={self.dim})" + +# def quantize_state(state, num_buckets=1000): +def quantize_state(state, num_buckets=15): +# def quantize_state(state, num_buckets=10): + """ + 量化状态向量。 + 参数: + state: 要量化的状态向量。 + num_buckets: 量化的桶数。 + 返回: + 量化后的状态向量的哈希值。 + """ + # 使用np.digitize将状态向量的每个维度值映射到num_buckets个桶中 + quantized_state = np.digitize(state, bins=np.linspace(0, 1, num=num_buckets)) + # 使用更稳定的哈希函数 + quantized_state_bytes = quantized_state.tobytes() + hash_object = hashlib.sha256(quantized_state_bytes) + return hash_object.hexdigest() + + + +@dataclass +class WorldModelOutput: + output_sequence: torch.FloatTensor + logits_observations: torch.FloatTensor + logits_rewards: torch.FloatTensor + logits_ends: torch.FloatTensor + + logits_policy: torch.FloatTensor + logits_value: torch.FloatTensor + + +class WorldModel(nn.Module): + def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: TransformerConfig, tokenizer, representation_network=None) -> None: + super().__init__() + + # config.max_tokens = int(2*50) # TODO + + self.tokenizer = tokenizer + self.obs_vocab_size, self.act_vocab_size = obs_vocab_size, act_vocab_size + self.config = config + self.policy_entropy_weight = config.policy_entropy_weight + + self.predict_latent_loss_type = config.predict_latent_loss_type + + self.transformer = Transformer(config) + # self.num_observations_tokens = 16 + self.num_observations_tokens = config.tokens_per_block -1 + + self.latent_recon_loss_weight = config.latent_recon_loss_weight + self.perceptual_loss_weight = config.perceptual_loss_weight + + self.device = config.device + self.support_size = config.support_size + self.action_shape = config.action_shape + self.max_cache_size = config.max_cache_size + self.env_num = config.env_num + self.num_layers = config.num_layers + self.sim_norm = SimNorm(simnorm_dim=8) + + all_but_last_latent_state_pattern = torch.ones(config.tokens_per_block) + all_but_last_latent_state_pattern[-2] = 0 # 1,...,0,1 + + act_tokens_pattern = torch.zeros(self.config.tokens_per_block) # 17 + act_tokens_pattern[-1] = 1 # 0,...,0,1 + latent_state_pattern = 1 - act_tokens_pattern # 1,...,1,0 + + # current latent state's policy value + value_policy_tokens_pattern = torch.zeros(config.tokens_per_block) + value_policy_tokens_pattern[-2] = 1 # [0,...,1,0] + + # next latent state's policy value + # value_policy_tokens_pattern = torch.zeros(config.tokens_per_block) + # value_policy_tokens_pattern[-1] = 1 # [0,...,0,1] + + obs_per_embdding_dim=config.embed_dim + # TODO + self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim) + + + self.embedder = Embedder( + max_blocks=config.max_blocks, + block_masks=[act_tokens_pattern, latent_state_pattern], + embedding_tables=nn.ModuleList([nn.Embedding(act_vocab_size, config.embed_dim), nn.Embedding(obs_vocab_size, config.embed_dim)]) + ) + + self.act_embedding_table = nn.Embedding(act_vocab_size, config.embed_dim) + self.act_embedding_table.weight.requires_grad = False # TODO: 测试效果 + + self.obs_per_embdding_dim = config.embed_dim # 16*64=1024 + + self.head_rewards = Head( + max_blocks=config.max_blocks, + block_mask=act_tokens_pattern, # 0,...,0,1 + head_module=nn.Sequential( + nn.Linear(config.embed_dim, config.embed_dim), + nn.ReLU(), + nn.Linear(config.embed_dim, self.support_size) + ) + ) + + self.head_ends = Head( + max_blocks=config.max_blocks, + 
block_mask=act_tokens_pattern, # 0,...,0,1 + head_module=nn.Sequential( + nn.Linear(config.embed_dim, config.embed_dim), + nn.ReLU(), + nn.Linear(config.embed_dim, 2) + ) + ) + + self.head_observations_for_root = Head( # TODO + max_blocks=config.max_blocks, + block_mask=latent_state_pattern, # 1,...,1,0 + head_module=nn.Sequential( + nn.Linear(config.embed_dim, config.embed_dim), + nn.BatchNorm1d(config.embed_dim), + nn.ReLU(), + nn.Linear(config.embed_dim, obs_per_embdding_dim) + ) + ) + + ###### TODO: 2层的性能, LeakyReLU->GELU ###### + self.head_observations = Head( # TODO + max_blocks=config.max_blocks, + block_mask=all_but_last_latent_state_pattern, # 1,...,0,1 # https://github.com/eloialonso/iris/issues/19 + head_module=nn.Sequential( + nn.Linear(config.embed_dim, config.embed_dim), + nn.GELU(), + nn.Linear(config.embed_dim, self.obs_per_embdding_dim), + # nn.Sigmoid(), # 这里添加Sigmoid函数 TODO + self.sim_norm, + ) + ) + self.head_policy = Head( + max_blocks=config.max_blocks, + block_mask=value_policy_tokens_pattern, # TODO: value_policy_tokens_pattern # [0,...,1,0] + head_module=nn.Sequential( # (8, 5, 128) + nn.Linear(config.embed_dim, config.embed_dim), + nn.GELU(), + nn.Linear(config.embed_dim, self.action_shape) # TODO(pu); action shape + ) + ) + self.head_value = Head( + max_blocks=config.max_blocks, + block_mask=value_policy_tokens_pattern, + head_module=nn.Sequential( + nn.Linear(config.embed_dim, config.embed_dim), + nn.GELU(), + nn.Linear(config.embed_dim, self.support_size) # TODO(pu): action shape + ) + ) + + self.apply(init_weights) + + last_linear_layer_init_zero = True # TODO: is beneficial for convergence speed. + if last_linear_layer_init_zero: + for _, layer in enumerate(reversed(self.head_value.head_module)): + if isinstance(layer, nn.Linear): + nn.init.zeros_(layer.weight) + if layer.bias is not None: + nn.init.zeros_(layer.bias) + break + for _, layer in enumerate(reversed(self.head_rewards.head_module)): + if isinstance(layer, nn.Linear): + nn.init.zeros_(layer.weight) + if layer.bias is not None: + nn.init.zeros_(layer.bias) + break + for _, layer in enumerate(reversed(self.head_observations.head_module)): + if isinstance(layer, nn.Linear): + nn.init.zeros_(layer.weight) + # layer.weight.data.fill_(0.5) # TODO:bug + if layer.bias is not None: + nn.init.zeros_(layer.bias) + break + + + import collections + self.past_keys_values_cache = collections.OrderedDict() + self.past_policy_value_cache = collections.OrderedDict() + + # TODO: Transformer更新后应该清除缓存 + self.keys_values_wm = self.transformer.generate_empty_keys_values(n=self.env_num, max_tokens=self.config.max_tokens) + + self.keys_values_wm_list = [] + self.keys_values_wm_size_list = [] + + + if self.num_observations_tokens==16: # k=16 + self.projection_input_dim = 128 + elif self.num_observations_tokens==1: # K=1 + self.projection_input_dim = self.obs_per_embdding_dim# for atari #TODO + # self.projection_input_dim = 256 # for cartpole + + + self.hit_count = 0 + self.total_query_count = 0 + self.length3_context_cnt = 0 + self.length2_context_cnt = 0 + self.root_hit_cnt = 0 + self.root_total_query_cnt = 0 + self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) + + + def __repr__(self) -> str: + return "world_model" + + #@profile + def forward(self, obs_embeddings_or_act_tokens, past_keys_values: Optional[KeysValues] = None, + is_root=False, kvcache_independent=False) -> WorldModelOutput: + + if kvcache_independent: + prev_steps = 0 if past_keys_values is 
None else [past_kv.size for past_kv in past_keys_values] + prev_steps = torch.tensor(prev_steps, device=self.device) + # 我们需要为每个样本生成一个序列的步骤indices,然后获取它们的位置嵌入 + # 首先扩展prev_steps至(num_steps, batch_size),这里num_steps=1 + # prev_steps = prev_steps.unsqueeze(0) + + else: + prev_steps = 0 if past_keys_values is None else past_keys_values.size + # print(f'prev_steps:{prev_steps}') + + if 'obs_embeddings' in obs_embeddings_or_act_tokens.keys(): + obs_embeddings = obs_embeddings_or_act_tokens['obs_embeddings'] + if len(obs_embeddings.shape)==2: + obs_embeddings = obs_embeddings.unsqueeze(1) + num_steps = obs_embeddings.size(1) # (B, T, E) + if kvcache_independent: + # 生成每个样本的步骤indices + # steps_indices = prev_steps + torch.arange(num_steps, device=obs_embeddings.device).unsqueeze(1) + steps_indices = prev_steps + torch.arange(num_steps, device=obs_embeddings.device) + + # 步骤indices需要被reshape成一维,以便于embedding操作 + # steps_indices = steps_indices.view(-1) + # 获取位置嵌入 + position_embeddings = self.pos_emb(steps_indices) + # 由于我们要将它们加到obs_embeddings上,需要将位置嵌入reshape回(batch_size, num_steps, embedding_dim) + position_embeddings = position_embeddings.view(-1, num_steps, position_embeddings.shape[-1]) + # 现在我们可以将位置嵌入加到obs_embeddings上了 + sequences = obs_embeddings + position_embeddings + else: + sequences = obs_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=obs_embeddings.device)) + elif 'act_tokens' in obs_embeddings_or_act_tokens.keys(): + act_tokens = obs_embeddings_or_act_tokens['act_tokens'] + if len(act_tokens.shape)==3: + act_tokens = act_tokens.squeeze(1) + num_steps = act_tokens.size(1) # (B, T) + act_embeddings = self.act_embedding_table(act_tokens) + + if kvcache_independent: + # 生成每个样本的步骤indices + # steps_indices = prev_steps + torch.arange(num_steps, device=act_embeddings.device).unsqueeze(1) + steps_indices = prev_steps + torch.arange(num_steps, device=act_embeddings.device) + # 步骤indices需要被reshape成一维,以便于embedding操作 + # steps_indices = steps_indices.view(-1) + # 获取位置嵌入 + position_embeddings = self.pos_emb(steps_indices) + # 由于我们要将它们加到obs_embeddings上,需要将位置嵌入reshape回(batch_size, num_steps, embedding_dim) + position_embeddings = position_embeddings.view(-1, num_steps, position_embeddings.shape[-1]) + # 现在我们可以将位置嵌入加到obs_embeddings上了 + sequences = act_embeddings + position_embeddings + else: + sequences = act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=act_tokens.device)) + else: + obs_embeddings_and_act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens'] + # obs_embeddings: (B, L, K=16, E), act_tokens: (B, L, 1) + obs_embeddings, act_tokens = obs_embeddings_and_act_tokens + if len(obs_embeddings.shape)==3: # for batch compute loss + obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, -1) + + num_steps = int(obs_embeddings.size(1)*(obs_embeddings.size(2)+1)) # L(k+1) + + # Generate action embeddings from action tokens + # (B, L, 1) -> (B, L, 1, E) + act_embeddings = self.act_embedding_table(act_tokens) + + # 已知obs_embeddings的维度为 (B, L, K, E), act_embeddings的维度为(B, L, 1, E) 希望得到一个obs_act_embeddings向量的维度为 (B, L(K+1), E) + # 而且让得到的obs_act_embeddings的第2个维度的数据为:obs act, obs, act, ..., obs, act,即 L, 1, L,1 ... 
这样的排列顺序。请给出高效的实现,用中文回答 + + B, L, K, E = obs_embeddings.size() + + # 初始化一个新的空tensor,用于存放最终的拼接结果 + obs_act_embeddings = torch.empty(B, L * (K + 1), E, device=obs_embeddings.device) + + # 对每一个序列长度L进行循环 + for i in range(L): + # 获取当前时刻的obs和act embeddings + obs = obs_embeddings[:, i, :, :] # Shape: (B, K, E) + act = act_embeddings[:, i, 0, :].unsqueeze(1) # Shape: (B, 1, E), 补充维度以便拼接 + + # 交替拼接obs和act + obs_act = torch.cat([obs, act], dim=1) # Shape: (B, K + 1, E) + + # 将结果填充到最终的tensor中 + obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act + + # Add positional embeddings + sequences = obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=obs_embeddings.device)) + + + # print('transformer forward begin') 函数里面更新了update past_keys_values + if kvcache_independent: + x = [] + for k, past_kv in enumerate(past_keys_values): + x.append(self.transformer(sequences[k].unsqueeze(0), past_kv)) + x = torch.cat(x, dim=0) + # TODO: 在collect时,是一步一步的 obs act 传入的 + # prev_steps = prev_steps//1 + else: + x = self.transformer(sequences, past_keys_values) + + # print('transformer forward done') + + if is_root: + logits_observations = self.head_observations_for_root(x, num_steps=num_steps, prev_steps=prev_steps) + else: + # 1,...,0,1 https://github.com/eloialonso/iris/issues/19 + logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps) + + logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps) + # logits_ends = self.head_ends(x, num_steps=num_steps, prev_steps=prev_steps) + + logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps) + logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps) + + # TODO: root reward value + return WorldModelOutput(x, logits_observations, logits_rewards, None, logits_policy, logits_value) + # return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends, logits_policy, logits_value) + + + @torch.no_grad() + #@profile + def reset_from_initial_observations_v2(self, obs_act_dict: torch.FloatTensor) -> torch.FloatTensor: + if isinstance(obs_act_dict, dict): + observations = obs_act_dict['obs'] + buffer_action = obs_act_dict['action'] + current_obs = obs_act_dict['current_obs'] + else: + observations = obs_act_dict + buffer_action = None + current_obs = None + obs_embeddings = self.tokenizer.encode_to_obs_embeddings(observations, should_preprocess=True) # (B, C, H, W) -> (B, K, E) + + if current_obs is not None: + current_obs_embeddings = self.tokenizer.encode_to_obs_embeddings(current_obs, should_preprocess=True) # (B, C, H, W) -> (B, K, E) + self.latent_state = current_obs_embeddings + outputs_wm = self.refresh_keys_values_with_initial_latent_state_for_init_infer_v2(obs_embeddings, buffer_action, current_obs_embeddings) + else: + self.latent_state = obs_embeddings + outputs_wm = self.refresh_keys_values_with_initial_latent_state_for_init_infer_v2(obs_embeddings, buffer_action, None) + + return outputs_wm, self.latent_state + + @torch.no_grad() + #@profile + def refresh_keys_values_with_initial_latent_state_for_init_infer_v2(self, latent_state: torch.LongTensor, buffer_action=None, current_obs_embeddings=None) -> torch.FloatTensor: + n, num_observations_tokens, _ = latent_state.shape + if n <= self.env_num: + if buffer_action is None: + # MCTS root节点: 需要准确的估计 value, policy_logits, 或许需要结合context的kv_cache进行更准确的估计,而不是当前的从0开始推理 + self.keys_values_wm = self.transformer.generate_empty_keys_values(n, max_tokens=self.config.max_tokens) + 
outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False, kvcache_independent=False) + self.keys_values_wm_size_list = [1 for i in range(n)] + + # 复制单个环境对应的 keys_values_wm 并存储 + self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) + for i in range(latent_state.size(0)): # 遍历每个环境 + state_single_env = latent_state[i] # 获取单个环境的 latent state + # cache_key = hash(state_single_env.detach().cpu().numpy()) # 计算哈希值 + quantized_state = state_single_env.detach().cpu().numpy() + cache_key = quantize_state(quantized_state) # 使用量化后的状态计算哈希值 + for layer in range(self.num_layers): + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = self.keys_values_wm._keys_values[layer]._k_cache._cache[i].unsqueeze(0) # shape torch.Size([2, 100, 512]) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = self.keys_values_wm._keys_values[layer]._v_cache._cache[i].unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = self.keys_values_wm._keys_values[layer]._k_cache._size + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = self.keys_values_wm._keys_values[layer]._v_cache._size + # keys_values_wm_single_env[layer].update(self.keys_values_wm[layer]._k_cache._cache[i].unsqueeze(0), self.keys_values_wm[layer]._v_cache._cache[i].unsqueeze(0)) + self.root_total_query_cnt += 1 + if cache_key not in self.past_keys_values_cache: + self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu')) + else: + self.root_hit_cnt += 1 + root_hit_ratio = self.root_hit_cnt / self.root_total_query_cnt + print('root_total_query_cnt:', self.root_total_query_cnt) + print(f'root_hit_ratio:{root_hit_ratio}') + print(f'root_hit_cnt:{self.root_hit_cnt}') + print(f'root_hit find size {self.past_keys_values_cache[cache_key].size}') + if self.past_keys_values_cache[cache_key].size>1: + print(f'=='*20) + print(f'NOTE: root_hit find size > 1') + print(f'=='*20) + elif current_obs_embeddings is not None: + + if max(buffer_action) == -1: + # first step in one episode + self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) + outputs_wm = self.forward({'obs_embeddings': current_obs_embeddings}, past_keys_values=self.keys_values_wm, is_root=False) + + # 复制单个环境对应的 keys_values_wm 并存储 + self.update_cache(current_obs_embeddings) + else: + # self.retrieve_or_generate_kvcache(latent_state, current_obs.shape[0]) + # 假设 latest_state 是新的 latent_state,包含 ready_env_num 个环境的信息 + # ready_env_num = latent_state.shape[0] + ready_env_num = current_obs_embeddings.shape[0] + self.keys_values_wm_list = [] + self.keys_values_wm_size_list = [] + for i in range(ready_env_num): + state_single_env = latent_state[i] # 获取单个环境的 latent state + quantized_state = state_single_env.detach().cpu().numpy() + cache_key = quantize_state(quantized_state) # 使用量化后的状态计算哈希值 + matched_value = self.past_keys_values_cache.get(cache_key) # 检索缓存值 + self.root_total_query_cnt += 1 + if matched_value is not None: + # 如果找到匹配的值,将其添加到列表中 + self.root_hit_cnt += 1 + if self.root_total_query_cnt>0 and self.root_total_query_cnt%1000==0: + root_hit_ratio = self.root_hit_cnt / self.root_total_query_cnt + print('root_total_query_cnt:', self.root_total_query_cnt) + print(f'root_hit_ratio:{root_hit_ratio}') + print(f'root_hit find size {self.past_keys_values_cache[cache_key].size}') + if 
self.past_keys_values_cache[cache_key].size>=7: + print(f'=='*20) + print(f'NOTE: root_hit find size >= 7') + print(f'=='*20) + # 这里需要deepcopy因为在transformer的forward中会原地修改matched_value + self.keys_values_wm_list.append(copy.deepcopy(self.to_device_for_kvcache(matched_value, 'cuda'))) + self.keys_values_wm_size_list.append(matched_value.size) + else: + # use zero reset + self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) + # outputs_wm = self.forward({'obs_embeddings': torch.from_numpy(state_single_env).unsqueeze(0).to(self.device)}, past_keys_values=self.keys_values_wm_single_env, is_root=False) + outputs_wm = self.forward({'obs_embeddings': state_single_env.unsqueeze(0)}, past_keys_values=self.keys_values_wm_single_env, is_root=False) + self.keys_values_wm_list.append(self.keys_values_wm_single_env) + self.keys_values_wm_size_list.append(1) + + # print(f'NOTE: root {self.keys_values_wm_size_list}') + # print(f'=='*20) + # 输入self.keys_values_wm_list,输出为self.keys_values_wm + self.trim_and_pad_kv_cache() + + buffer_action = buffer_action[:ready_env_num] + buffer_action = torch.from_numpy(np.array(buffer_action)).to(latent_state.device) + act_tokens = buffer_action.unsqueeze(-1) + # outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (latent_state, act_tokens)}, past_keys_values=self.keys_values_wm, is_root=False) + outputs_wm = self.forward({'act_tokens': act_tokens}, past_keys_values=self.keys_values_wm, is_root=False) + + outputs_wm = self.forward({'obs_embeddings': current_obs_embeddings}, past_keys_values=self.keys_values_wm, is_root=False) + + # 复制单个环境对应的 keys_values_wm 并存储 + self.update_cache(current_obs_embeddings) + + elif n == int(256): + # TODO: n=256 means train tokenizer, 不需要计算target value + self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) + # print('init inference: not find matched_value! 
reset!') + outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) + elif n > self.env_num and n != int(256) and buffer_action is not None: + # train时计算target value + # TODO: n=256 means train tokenizer + # [192, 16, 64] -> [32, 6, 16, 64] + latent_state = latent_state.contiguous().view(buffer_action.shape[0], -1, num_observations_tokens, self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 + + # latent_state = latent_state.view(-1, self.config.max_blocks+1, num_observations_tokens) # (BL, K) + latent_state = latent_state[:, :-1, :] + # latent_state = latent_state.reshape(32*6, num_observations_tokens) # (BL, K) + buffer_action = torch.from_numpy(buffer_action).to(latent_state.device) + act_tokens = rearrange(buffer_action, 'b l -> b l 1') + + # # 选择每个样本的最后一步 + last_steps = act_tokens[:, -1:, :] # 这将选择最后一列并保持维度不变, 最后一步的target policy/value本身就没有用到 + # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps + act_tokens = torch.cat((act_tokens, last_steps), dim=1) + + # print('init inference: unroll 5 steps!') 17*6=102 17*5=85 + outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (latent_state, act_tokens)}, is_root=False) + + # 选择每个样本的最后一步 + last_steps_value = outputs_wm.logits_value[:, -1:, :] # 这将选择最后一列并保持维度不变 + # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps + outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1) + + last_steps_policy = outputs_wm.logits_policy[:, -1:, :] # 这将选择最后一列并保持维度不变 + outputs_wm.logits_policy = torch.cat((outputs_wm.logits_policy, last_steps_policy), dim=1) + + # Reshape your tensors + # outputs_wm.logits_value.shape (30,21) = (B*6, 21) + outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e') + outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e') + + return outputs_wm + + @torch.no_grad() + # #@profile + def reset_from_initial_observations(self, obs_act_dict: torch.FloatTensor) -> torch.FloatTensor: + if isinstance(obs_act_dict, dict): + # obs_act_dict = {'obs':obs, 'action':action_batch} + observations = obs_act_dict['obs'] + buffer_action = obs_act_dict['action'] + else: + observations = obs_act_dict + buffer_action = None + + obs_embeddings = self.tokenizer.encode_to_obs_embeddings(observations, should_preprocess=True) # (B, C, H, W) -> (B, K, E) + outputs_wm = self.refresh_keys_values_with_initial_latent_state_for_init_infer(obs_embeddings, buffer_action) + self.latent_state = obs_embeddings + + return outputs_wm, self.latent_state + + + @torch.no_grad() + # #@profile + def refresh_keys_values_with_initial_latent_state_for_init_infer(self, latent_state: torch.LongTensor, buffer_action=None) -> torch.FloatTensor: + n, num_observations_tokens, _ = latent_state.shape + if n <= self.env_num: + # MCTS root节点: 需要准确的估计 value, policy_logits 或许需要结合context的kv_cache进行更准确的估计,而不是当前的从0开始推理 + self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) + outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) # Note: is_root=False + elif n == int(256): + # TODO: n=256 means train tokenizer, 不需要计算target value + self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) + # print('init inference: not find matched_value! 
reset!') + outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) # Note: is_root=False + elif n > self.env_num and n != int(256) and buffer_action is not None: + # TODO: n=256 means train tokenizer + # TODO: for n=32*6=192 means 通过unroll 5 steps,计算target value + # latent_state = latent_state.reshape(32, 6, num_observations_tokens) # (BL, K) + # latent_state = latent_state.view(-1, 6, num_observations_tokens) # (BL, K) + + # [192, 16] -> [32, 6, 16] + # latent_state = latent_state.view(buffer_action.shape[0], -1, num_observations_tokens) # (BL, K) for unroll_step=1 + + # [192, 16, 64] -> [32, 6, 16, 64] + latent_state = latent_state.contiguous().view(buffer_action.shape[0], -1, num_observations_tokens, self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 + + # latent_state = latent_state.view(-1, self.config.max_blocks+1, num_observations_tokens) # (BL, K) + latent_state = latent_state[:, :-1, :] + # latent_state = latent_state.reshape(32*6, num_observations_tokens) # (BL, K) + buffer_action = torch.from_numpy(buffer_action).to(latent_state.device) + act_tokens = rearrange(buffer_action, 'b l -> b l 1') + + # # 选择每个样本的最后一步 + last_steps = act_tokens[:, -1:, :] # 这将选择最后一列并保持维度不变, 最后一步的target policy/value本身就没有用到 + # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps + act_tokens = torch.cat((act_tokens, last_steps), dim=1) + + # print('init inference: unroll 5 steps!') 17*6=102 17*5=85 + obs_embeddings = latent_state + outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) + + # 选择每个样本的最后一步 + last_steps_value = outputs_wm.logits_value[:, -1:, :] # 这将选择最后一列并保持维度不变 + # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps + outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1) + + last_steps_policy = outputs_wm.logits_policy[:, -1:, :] # 这将选择最后一列并保持维度不变 + outputs_wm.logits_policy = torch.cat((outputs_wm.logits_policy, last_steps_policy), dim=1) + + # Reshape your tensors + # outputs_wm.logits_value.shape (30,21) = (B*6, 21) + outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e') + outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e') + + + # return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends, logits_policy, logits_value) + return outputs_wm + + + @torch.no_grad() + # #@profile + def refresh_keys_values_with_initial_latent_state(self, latent_state: torch.LongTensor, reset_indices=None) -> torch.FloatTensor: + n, num_observations_tokens, _ = latent_state.shape + assert num_observations_tokens == self.num_observations_tokens + # self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) + if reset_indices is None: + self.keys_values_wm_list = [self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) for i in range(n)] + else: + for i in reset_indices: + self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) + outputs_wm = self.forward({'obs_embeddings': latent_state[i].unsqueeze(0).to(self.device)}, past_keys_values=self.keys_values_wm_single_env, is_root=False, kvcache_independent=False) + self.keys_values_wm_list[i] = self.keys_values_wm_single_env + self.keys_values_wm_size_list[i] = 1 + return None + + @torch.no_grad() + #@profile + def forward_initial_inference(self, obs_act_dict): + outputs_wm, latent_state = 
self.reset_from_initial_observations_v2(obs_act_dict) # root节点也有context + + # # 计算L1范数 + # l1_norm = torch.norm(latent_state.squeeze(1), p=1, dim=1) # 计算每个向量的L1范数 + # average_l1_norm = torch.mean(l1_norm) # 计算所有向量L1范数的平均值 + # # 计算L2范数 + # l2_norm = torch.norm(latent_state.squeeze(1), p=2, dim=1) # 计算每个向量的L2范数 + # average_l2_norm = torch.mean(l2_norm) # 计算所有向量L2范数的平均值 + # 打印结果 + # print(f'Average L1 norm: {average_l1_norm}') + # print(f'Average L2 norm: {average_l2_norm}') + # Average L1 norm: 96.0 + # Average L2 norm: 8.440537452697754 + + return outputs_wm.output_sequence, latent_state, outputs_wm.logits_rewards, outputs_wm.logits_policy, outputs_wm.logits_value + + """ + 假设env_num=8 + 8个环境的kv_cache单独存储与寻找,都存储在一个dict中,在recurrent_inference时, + 由于不同环境找到的kv_cache的size不同,先根据最大size对kv_cache在前部补零,然后组成batch_size的kv_cache + 其内部也是通过batch执行transformer forward的推理 + """ + + @torch.no_grad() + #@profile + def forward_recurrent_inference(self, state_action_history): + # 一般来讲,在一次 MCTS search中,我们需要维护H长度的context来使用transformer进行推理。 + # 由于在一次search里面。agent最多访问sim个不同的节点,因此我们只需维护一个 {(state:kv_cache)}的列表。 + # 但如果假设环境是MDP的话,然后根据当前的 latest_state s_t 在这个列表中查找即可 + # TODO: 但如果假设环境是非MDP的话,需要维护一个 {(rootstate_action_history:kv_cache)}的列表? + + latest_state, action = state_action_history[-1] + + # 假设 latest_state 是新的 latent_state,包含 ready_env_num 个环境的信息 + ready_env_num = latest_state.shape[0] + self.keys_values_wm_list = [] + self.keys_values_wm_size_list = [] + self.retrieve_or_generate_kvcache(latest_state, ready_env_num) + + latent_state_list = [] + # output_sequence_list, latent_state_list = [], [] + + # reset_indices = [index for index, value in enumerate(self.keys_values_wm_size_list) if value + num_passes > self.config.max_tokens] + # self.refresh_keys_values_with_initial_latent_state(torch.tensor(latest_state, dtype=torch.float32).to(self.device), reset_indices) + + token = action.reshape(-1, 1) + + # print(self.keys_values_wm_size_list) + # 获取self.keys_values_wm_size_list的最小值min_size + min_size = min(self.keys_values_wm_size_list) + if min_size >= self.config.max_tokens - 5: + self.length3_context_cnt += len(self.keys_values_wm_size_list) + if min_size >= 3: + self.length2_context_cnt += len(self.keys_values_wm_size_list) + # if self.total_query_count>0 and self.total_query_count%1==0: + if self.total_query_count>0 and self.total_query_count%10000==0: + self.hit_freq = self.hit_count/(self.total_query_count) + # print('hit_freq:', self.hit_freq) + # print('hit_count:', self.hit_count) + print('total_query_count:', self.total_query_count) + # 如果总查询次数大于0,计算并打印cnt的比率 + length3_context_cnt_ratio = self.length3_context_cnt / self.total_query_count + print('>=3 node context_cnt:', self.length3_context_cnt) + print('>=3 node context_cnt_ratio:', length3_context_cnt_ratio) + length2_context_cnt_ratio = self.length2_context_cnt / self.total_query_count + print('>=2 node context_cnt_ratio:', length2_context_cnt_ratio) + print('>=2 node context_cnt:', self.length2_context_cnt) + # print(self.keys_values_wm_size_list) + + # 输入self.keys_values_wm_list,输出为self.keys_values_wm + self.trim_and_pad_kv_cache() + # print(f'NOTE: in search node {self.keys_values_wm_size_list}') + + # for k in range(num_passes): # assumption that there is only one action token. + # num_passes = 1 + self.num_observations_tokens + for k in range(2): # assumption that there is only one action token. 
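+ # The two passes below cover one action token followed by one latent token:
+ #   k == 0: feed the action token; the heads at this position yield the predicted reward, and
+ #           logits_observations yields the predicted next latent state, which becomes `token`.
+ #   k == 1: feed that predicted latent back in, so the KV cache ends on an observation position;
+ #           the policy/value logits returned below come from this second pass.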
+ # action_token obs_token, ..., obs_token 1+1 + if k==0: + obs_embeddings_or_act_tokens = {'act_tokens': token} + else: + obs_embeddings_or_act_tokens = {'obs_embeddings': token} + + outputs_wm = self.forward(obs_embeddings_or_act_tokens, past_keys_values=self.keys_values_wm, is_root=False, kvcache_independent=False) + # if k==0, action_token self.head_observations 1,...,0,1 + # output_sequence_list.append(outputs_wm.output_sequence) + + if k == 0: + # if k==0, token is action_token outputs_wm.logits_rewards 是有值的 + reward = outputs_wm.logits_rewards # (B,) + + if k < self.num_observations_tokens: + # 一共产生16个obs_token,每次产生一个 + # TODO: sample or argmax + # token = Categorical(logits=outputs_wm.logits_observations).sample() + # Use argmax to select the most likely token + # token = outputs_wm.logits_observations.argmax(-1, keepdim=True) + token = outputs_wm.logits_observations + # if len(token.shape) != 2: + # token = token.squeeze(-1) # Ensure the token tensor shape is (B, 1) + if len(token.shape) != 3: + token = token.unsqueeze(1) # (8,1024) -> (8,1,1024) + latent_state_list.append(token) + + # output_sequence = torch.cat(output_sequence_list, dim=1) # (B, 1 + K, E) + # Before updating self.latent_state, delete the old one to free memory + del self.latent_state + self.latent_state = torch.cat(latent_state_list, dim=1) # (B, K) + + self.update_cache(self.latent_state) + # TODO: 在计算结束后,是否需要更新最新的缓存. 是否需要deepcopy + + # if len(self.past_keys_values_cache) > self.max_cache_size: + # # TODO: lru_cache + # _, popped_kv_cache = self.past_keys_values_cache.popitem(last=False) + # del popped_kv_cache # 测试不要这一行的显存情况 + + # Example usage: + # Assuming `past_keys_values_cache` is a populated instance of `KeysValues` + # and `num_layers` is the number of transformer layers + # cuda_memory_gb = self.calculate_cuda_memory_gb(self.past_keys_values_cache, num_layers=2) + # print(f'len(self.past_keys_values_cache): {len(self.past_keys_values_cache)}, Memory used by past_keys_values_cache: {cuda_memory_gb:.2f} GB') + + return outputs_wm.output_sequence, self.latent_state, reward, outputs_wm.logits_policy, outputs_wm.logits_value + + def trim_and_pad_kv_cache(self): + """ + This method trims and pads the key and value caches of the attention mechanism + to a consistent size across all items in the batch, determined by the smallest cache size. 
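+ Example (illustrative numbers): with keys_values_wm_size_list = [5, 3], min_size is 3, so the first
+ cache is trimmed by 2 steps at the front and then zero-padded back on the same side:
+     k = torch.randn(1, 8, 5, 64)              # (batch, heads, steps, dim); placeholder sizes
+     k_trimmed = k[:, :, 2:, :]                # drop the 2 earliest steps -> (1, 8, 3, 64)
+     k_padded = F.pad(k_trimmed, (0, 0, 2, 0)) # pad 2 zeros at the front of the step dim -> (1, 8, 5, 64)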
+ """ + + # Find the minimum size across all key-value sizes for padding/trimming + min_size = min(self.keys_values_wm_size_list) + + # Iterate over each layer of the transformer + for layer in range(self.num_layers): + # Initialize lists to hold the trimmed and padded k and v caches + kv_cache_k_list = [] + kv_cache_v_list = [] + + # Enumerate over the key-value pairs list + for idx, keys_values in enumerate(self.keys_values_wm_list): + # Retrieve the current layer's key and value caches + k_cache = keys_values[layer]._k_cache._cache + v_cache = keys_values[layer]._v_cache._cache + + # Get the effective size of the current cache + effective_size = self.keys_values_wm_size_list[idx] + # Calculate the size difference to trim + trim_size = effective_size - min_size if effective_size > min_size else 0 + + # If trimming is needed, remove 'trim_size' from the beginning of the cache + if trim_size > 0: + k_cache_trimmed = k_cache[:, :, trim_size:, :] + v_cache_trimmed = v_cache[:, :, trim_size:, :] + # Pad the trimmed cache with zeros on the third dimension + k_cache_padded = F.pad(k_cache_trimmed, (0, 0, trim_size, 0), "constant", 0) + v_cache_padded = F.pad(v_cache_trimmed, (0, 0, trim_size, 0), "constant", 0) + else: + k_cache_padded = k_cache + v_cache_padded = v_cache + + # Add the processed caches to the lists + kv_cache_k_list.append(k_cache_padded) + kv_cache_v_list.append(v_cache_padded) + + # Stack the caches along the new dimension, and remove the extra dimension with squeeze() + self.keys_values_wm._keys_values[layer]._k_cache._cache = torch.stack(kv_cache_k_list, dim=0).squeeze(1) + self.keys_values_wm._keys_values[layer]._v_cache._cache = torch.stack(kv_cache_v_list, dim=0).squeeze(1) + + # Update the cache size to the minimum size after trimming and padding + self.keys_values_wm._keys_values[layer]._k_cache._size = min_size + self.keys_values_wm._keys_values[layer]._v_cache._size = min_size + + def update_cache(self, latent_state): + for i in range(latent_state.size(0)): # Iterate over each environment + state_single_env = latent_state[i] # Get the latent state for a single environment + quantized_state = state_single_env.detach().cpu().numpy() # Detach and move the state to CPU + cache_key = quantize_state(quantized_state) # Quantize state and compute its hash value as cache key + + # Copy keys and values from the global cache to a single environment cache + for layer in range(self.num_layers): + if self.keys_values_wm._keys_values[layer]._k_cache._size < self.config.max_tokens - 1: + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = self.keys_values_wm._keys_values[layer]._k_cache._cache[i].unsqueeze(0) # Shape torch.Size([2, 100, 512]) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = self.keys_values_wm._keys_values[layer]._v_cache._cache[i].unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = self.keys_values_wm._keys_values[layer]._k_cache._size + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = self.keys_values_wm._keys_values[layer]._v_cache._size + elif self.keys_values_wm._keys_values[layer]._k_cache._size == self.config.max_tokens - 1: + # 裁剪和填充逻辑 + # 假设cache的维度是 [batch_size, num_heads, sequence_length, features] + k_cache_current = self.keys_values_wm._keys_values[layer]._k_cache._cache[i] + v_cache_current = self.keys_values_wm._keys_values[layer]._v_cache._cache[i] + + # 移除前2步并保留最近的max_tokens - 3步 + k_cache_trimmed = k_cache_current[:, 2:self.config.max_tokens - 1, :] + 
v_cache_trimmed = v_cache_current[:, 2:self.config.max_tokens - 1, :] + + # 沿第3维填充后2步 + padding_size = (0, 0, 0, 3) #F.pad的参数(0, 0, 0, 2)指定了在每个维度上的填充量。这些参数是按(左, 右, 上, 下)的顺序给出的,对于三维张量来说,分别对应于(维度2左侧, 维度2右侧, 维度1左侧, 维度1右侧)的填充。 + k_cache_padded = F.pad(k_cache_trimmed, padding_size, 'constant', 0) + v_cache_padded = F.pad(v_cache_trimmed, padding_size, 'constant', 0) + # 更新单环境cache + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = k_cache_padded.unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = v_cache_padded.unsqueeze(0) + + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = self.config.max_tokens - 3 + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = self.config.max_tokens - 3 + + + # Compare and store the larger cache + if cache_key in self.past_keys_values_cache: + existing_kvcache = self.past_keys_values_cache[cache_key] + # Check if there is a size difference between existing cache and new cache + if self.keys_values_wm_single_env.size > existing_kvcache.size and self.keys_values_wm_single_env.size < self.config.max_tokens - 1: + # Only store if size is less than max_tokens - 1 to avoid reset + self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu')) + elif self.keys_values_wm_single_env.size < self.config.max_tokens - 1: + # Only store if size is less than max_tokens - 1 to avoid reset + self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu')) + + def retrieve_or_generate_kvcache(self, latent_state, ready_env_num): + """ + This method iterates over the environments, retrieves a matching cache if available, + or generates a new one otherwise. It updates the lists with the keys_values caches and their sizes. + """ + for i in range(ready_env_num): + self.total_query_count += 1 + state_single_env = latent_state[i] # Get the latent state for a single environment + cache_key = quantize_state(state_single_env) # Compute the hash value using the quantized state + # Retrieve the cached value if it exists + matched_value = self.past_keys_values_cache.get(cache_key) + if matched_value is not None: + # If a matching value is found, add it to the list + self.hit_count += 1 + # Deepcopy is needed because the transformer's forward may modify matched_value in place + self.keys_values_wm_list.append(copy.deepcopy(self.to_device_for_kvcache(matched_value, self.device))) + self.keys_values_wm_size_list.append(matched_value.size) + else: + # If no match is found, use a zero reset + self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) + self.forward({'obs_embeddings': torch.from_numpy(state_single_env).unsqueeze(0).to(self.device)}, past_keys_values=self.keys_values_wm_single_env, is_root=False) + self.keys_values_wm_list.append(self.keys_values_wm_single_env) + self.keys_values_wm_size_list.append(1) + + def to_device_for_kvcache(self, keys_values: KeysValues, device: str) -> KeysValues: + """ + Transfer all KVCache objects within the KeysValues object to a certain device. + + Arguments: + - keys_values (KeysValues): The KeysValues object to be transferred. + - device (str): The device to transfer to. + + Returns: + - keys_values (KeysValues): The KeysValues object with its caches transferred to the specified device. 
+ """ + # Check if CUDA is available and select the first available CUDA device + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + for kv_cache in keys_values: + kv_cache._k_cache._cache = kv_cache._k_cache._cache.to(device) + kv_cache._v_cache._cache = kv_cache._v_cache._cache.to(device) + return keys_values + + + # 计算显存使用量的函数 + def calculate_cuda_memory_gb(self, past_keys_values_cache, num_layers: int): + total_memory_bytes = 0 + + # 遍历OrderedDict中所有的KeysValues实例 + for kv_instance in past_keys_values_cache.values(): + num_layers = len(kv_instance) # 获取层数 + for layer in range(num_layers): + kv_cache = kv_instance[layer] + k_shape = kv_cache._k_cache.shape # 获取keys缓存的形状 + v_shape = kv_cache._v_cache.shape # 获取values缓存的形状 + + # 计算元素个数并乘以每个元素的字节数 + k_memory = torch.prod(torch.tensor(k_shape)) * 4 + v_memory = torch.prod(torch.tensor(v_shape)) * 4 + + # 累加keys和values缓存的内存 + layer_memory = k_memory + v_memory + total_memory_bytes += layer_memory.item() # .item()确保转换为Python标准数字 + + # 将总内存从字节转换为吉字节 + total_memory_gb = total_memory_bytes / (1024 ** 3) + return total_memory_gb + + #@profile + def compute_loss(self, batch, target_tokenizer: Tokenizer=None, **kwargs: Any) -> LossWithIntermediateLosses: + # NOTE: 这里是需要梯度的 + #with torch.no_grad(): # TODO: 非常重要 + obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False) # (B, C, H, W) -> (B, K, E) + obs_embeddings.register_hook(lambda grad: grad * 1/5) # TODO:测试 gard scale 的作用 + + # Assume that 'cont_embeddings' and 'original_images' are available from prior code + # Decode the embeddings to reconstruct the images + reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) + + # Calculate the reconstruction loss + # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 4, 64, 64), reconstructed_images) # TODO: for stack=4 + # perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) # for stack=4 gray obs + + latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # TODO: for stack=1 + perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # TODO: for stack=1 + + # latent_recon_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) + # perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) + + latent_kl_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) + + act_tokens = rearrange(batch['actions'], 'b l -> b l 1') + + # TODO: 是否只用重建损失更新表征网络 非常重要 + outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) # 联合更新性能更好 + # outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens)}, is_root=False) + + with torch.no_grad(): + # 为了训练稳定性,world_model预测的next_latent_state是用target world_model 产生的 + traget_obs_embeddings = target_tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False) # (B, C, H, W) -> (B, K, E) + + labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(traget_obs_embeddings, batch['rewards'], + batch['ends'], + batch['mask_padding']) + # labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(obs_embeddings, batch['rewards'], + # batch['ends'], + # 
+
+        logits_observations = rearrange(outputs.logits_observations[:, :-1], 'b t o -> (b t) o')
+        # labels_observations = labels_observations.contiguous().view(-1, self.projection_input_dim)  # TODO:
+        labels_observations = labels_observations.reshape(-1, self.projection_input_dim)  # TODO:
+
+        if self.predict_latent_loss_type == 'mse':
+            loss_obs = torch.nn.functional.mse_loss(logits_observations, labels_observations.detach(), reduction='none').mean(-1)
+        elif self.predict_latent_loss_type == 'group_kl':
+            # Reshape logits_observations and labels_observations according to the group size
+            # batch_size = logits_observations.shape[0]
+            # group_size = 8
+            # num_groups = logits_observations.shape[1] // group_size
+
+            # logits_reshaped = logits_observations.view(batch_size, num_groups, group_size)
+            # labels_reshaped = labels_observations.view(batch_size, num_groups, group_size)
+
+            # Compute the KL-divergence loss within each group, then average over groups
+            # loss_obs = F.kl_div(logits_reshaped.log(), labels_reshaped, reduction='none').sum(-1).mean(1)
+
+            # Reshape logits_observations and labels_observations according to the group size
+            batch_size, num_features = logits_observations.shape
+            group_size = 8
+            num_groups = num_features // group_size
+
+            logits_reshaped = logits_observations.reshape(batch_size, num_groups, group_size)
+            labels_reshaped = labels_observations.reshape(batch_size, num_groups, group_size)
+
+            # Compute the KL-divergence loss within each group, then average over groups
+            # loss_obs = F.kl_div(F.log_softmax(logits_reshaped, dim=-1), labels_reshaped, reduction='none').sum(dim=-1).mean(dim=-1)  # bug
+            loss_obs = F.kl_div(logits_reshaped.log(), labels_reshaped, reduction='none').sum(dim=-1).mean(dim=-1)
+
+        mask_padding_expanded = batch['mask_padding'][:, 1:].contiguous().view(-1)  # TODO:
+        # mask_padding_expanded = batch['mask_padding'][:, 1:].reshape(-1)
+
+        # Apply the padding mask to loss_obs
+        loss_obs = (loss_obs * mask_padding_expanded).mean(-1)
+        labels_policy, labels_value = self.compute_labels_world_model_value_policy(batch['target_value'],
+                                                                                   batch['target_policy'],
+                                                                                   batch['mask_padding'])
+
+        loss_rewards = self.compute_cross_entropy_loss(outputs, labels_rewards, batch, element='rewards')
+        loss_policy, orig_policy_loss, policy_entropy = self.compute_cross_entropy_loss(outputs, labels_policy, batch, element='policy')
+        loss_value = self.compute_cross_entropy_loss(outputs, labels_value, batch, element='value')
+
+        return LossWithIntermediateLosses(latent_recon_loss_weight=self.latent_recon_loss_weight, perceptual_loss_weight=self.perceptual_loss_weight, loss_obs=loss_obs, loss_rewards=loss_rewards, loss_value=loss_value,
+                                          loss_policy=loss_policy, latent_kl_loss=latent_kl_loss, latent_recon_loss=latent_recon_loss, perceptual_loss=perceptual_loss, orig_policy_loss=orig_policy_loss, policy_entropy=policy_entropy)
+
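The reward, value, and policy targets are all categorical (support) distributions here, so compute_cross_entropy_loss below reduces to a cross entropy against soft targets in which padded timesteps are zeroed out by the mask. A standalone sketch of that pattern, with made-up shapes:

import torch

# Illustrative shapes: 6 flattened (batch * time) steps, 5 target categories.
logits = torch.randn(6, 5)
soft_targets = torch.softmax(torch.randn(6, 5), dim=1)    # each row sums to 1
mask = torch.tensor([1., 1., 1., 0., 0., 1.])              # 1 = valid step, 0 = padding

per_step = -(torch.log_softmax(logits, dim=1) * soft_targets).sum(1)  # cross entropy per step
loss = (per_step * mask).mean()                                        # masked steps contribute 0
print(loss.item())

As in the method below, the mean runs over all flattened steps, including the masked ones, so batches with more padding effectively down-weight the loss.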
+
+    def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'):
+        # Assume outputs is an object with logits attributes for 'rewards', 'policy', and 'value',
+        # and labels is a tensor of targets to compare against. The batch is a dictionary
+        # with a mask that indicates valid timesteps.
+
+        logits = getattr(outputs, f'logits_{element}')
+
+        # Reshape the tensors
+        logits = rearrange(logits, 'b t e -> (b t) e')
+        labels = labels.reshape(-1, labels.shape[-1])  # Assuming labels originally has shape [batch, time, dim]
+
+        # Reshape the mask. True means valid data.
+        mask_padding = rearrange(batch['mask_padding'], 'b t -> (b t)')
+
+        # Compute the cross entropy loss
+        loss = -(torch.log_softmax(logits, dim=1) * labels).sum(1)
+        loss = (loss * mask_padding).mean()
+
+        if element == 'policy':
+            # Calculate the policy entropy loss
+            policy_entropy = self.compute_policy_entropy_loss(logits, mask_padding)
+            # Combine the losses with the specified weight
+            combined_loss = loss - self.policy_entropy_weight * policy_entropy
+            return combined_loss, loss, policy_entropy
+
+        return loss
+
+    def compute_policy_entropy_loss(self, logits, mask):
+        # Calculate the entropy of the policy
+        probs = torch.softmax(logits, dim=1)
+        log_probs = torch.log_softmax(logits, dim=1)
+        entropy = -(probs * log_probs).sum(1)
+        # Apply the mask and return the mean entropy loss
+        entropy_loss = (entropy * mask).mean()
+        return entropy_loss
+
+    def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor,
+                                   mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        assert torch.all(ends.sum(dim=1) <= 1)  # each sequence sample has at most 1 done flag
+        mask_fill = torch.logical_not(mask_padding)
+        labels_observations = obs_embeddings.contiguous().view(rewards.shape[0], -1, self.projection_input_dim)[:, 1:]  # self.projection_input_dim
+
+        # labels_rewards = (rewards.sign() + 1).masked_fill(mask_fill, -100).long()  # Rewards clipped to {-1, 0, 1} TODO(pu)
+
+        mask_fill_rewards = mask_fill.unsqueeze(-1).expand_as(rewards)
+        labels_rewards = rewards.masked_fill(mask_fill_rewards, -100)
+
+        labels_ends = ends.masked_fill(mask_fill, -100)
+        return labels_observations, labels_rewards.reshape(-1, self.support_size), labels_ends.reshape(-1)
+
+    def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, target_policy: torch.Tensor,
+                                                mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor]:
+
+        mask_fill = torch.logical_not(mask_padding)
+        mask_fill_policy = mask_fill.unsqueeze(-1).expand_as(target_policy)
+        labels_policy = target_policy.masked_fill(mask_fill_policy, -100)
+
+        mask_fill_value = mask_fill.unsqueeze(-1).expand_as(target_value)
+        labels_value = target_value.masked_fill(mask_fill_value, -100)
+        return labels_policy.reshape(-1, self.action_shape), labels_value.reshape(-1, self.support_size)  # TODO(pu)
+
+
+    def negative_cosine_similarity(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
+        """
+        Overview:
+            Consistency loss function: the negative cosine similarity.
+        Arguments:
+            - x1 (:obj:`torch.Tensor`): shape (batch_size, dim), e.g. (256, 512)
+            - x2 (:obj:`torch.Tensor`): shape (batch_size, dim), e.g. (256, 512)
+        Returns:
+            (x1 * x2).sum(dim=1) is the cosine similarity between vectors x1 and x2.
+            The cosine similarity always belongs to the interval [-1, 1].
+            For example, two proportional vectors have a cosine similarity of 1,
+            two orthogonal vectors have a similarity of 0,
+            and two opposite vectors have a similarity of -1.
+            -(x1 * x2).sum(dim=1) is the consistency loss, i.e. the negative cosine similarity.
+        Reference:
+            https://en.wikipedia.org/wiki/Cosine_similarity
+        """
+        x1 = F.normalize(x1, p=2., dim=-1, eps=1e-5)
+        x2 = F.normalize(x2, p=2., dim=-1, eps=1e-5)
+        return -(x1 * x2).sum(dim=1)
+
+
+    def render_img(self, obs: torch.Tensor, rec_img: torch.Tensor):
+        import torch
+        from PIL import Image
+        import matplotlib.pyplot as plt
+
+        # Assume batch is a dict containing the key 'observations',
+        # whose shape is torch.Size([B, N, C, H, W])
+        # batch_observations = batch_for_gpt['observations']
+        # batch_observations = batch['observations']
+        batch_observations = obs.unsqueeze(0)
+        # batch_observations = rec_img.unsqueeze(0)
+
+        # batch_observations = observations.unsqueeze(0)
+        # batch_observations = x.unsqueeze(0)
+        # batch_observations = reconstructions.unsqueeze(0)
+
+        B, N, C, H, W = batch_observations.shape  # Infer the dimensions automatically
+
+        # Width of the separator bar (adjust as needed)
+        separator_width = 2
+
+        # Iterate over each sample
+        for i in range(B):
+            # Extract all frames of the current sample
+            frames = batch_observations[i]
+
+            # Total width of the concatenated image (including separators)
+            total_width = N * W + (N - 1) * separator_width
+
+            # Create a new image that includes the separator bars
+            concat_image = Image.new('RGB', (total_width, H), color='black')
+
+            # Paste each frame followed by a separator
+            for j in range(N):
+                frame = frames[j].permute(1, 2, 0).cpu().numpy()  # Convert to (H, W, C)
+                frame_image = Image.fromarray((frame * 255).astype('uint8'), 'RGB')
+
+                # Position of the current frame within the concatenated image
+                x_position = j * (W + separator_width)
+                concat_image.paste(frame_image, (x_position, 0))
+
+            # Display the image
+            plt.imshow(concat_image)
+            plt.title(f'Sample {i+1}')
+            plt.axis('off')  # Hide the axes
+            plt.show()
+
+            # Save the image to a file
+            concat_image.save(f'sample_{i+1}.png')
\ No newline at end of file
diff --git a/lzero/model/muzero_gpt_model.py b/lzero/model/muzero_gpt_model.py
index 85cdde640..ded983938 100644
--- a/lzero/model/muzero_gpt_model.py
+++ b/lzero/model/muzero_gpt_model.py
@@ -166,8 +166,8 @@ def __init__(
             embedding_dim=cfg.world_model.embed_dim,
         )
         # Instantiate the decoder
-        decoder_network = LatentDecoder(embedding_dim=cfg.world_model.embed_dim, output_shape=(4, 64, 64)) # TODO: For K=4
-        # decoder_network = LatentDecoder(embedding_dim=cfg.world_model.embed_dim, output_shape=(3, 64, 64)) # TODO: For K=1
+        # decoder_network = LatentDecoder(embedding_dim=cfg.world_model.embed_dim, output_shape=(4, 64, 64)) # TODO: For K=4
+        decoder_network = LatentDecoder(embedding_dim=cfg.world_model.embed_dim, output_shape=(3, 64, 64)) # TODO: For K=1
 
         Encoder = Encoder(cfg.tokenizer.encoder)
 
diff --git a/lzero/policy/muzero_gpt.py b/lzero/policy/muzero_gpt.py
index c5bb593ed..6878b7a83 100644
--- a/lzero/policy/muzero_gpt.py
+++ b/lzero/policy/muzero_gpt.py
@@ -510,7 +510,6 @@ def _forward_learn_transformer(self, data: Tuple[torch.Tensor]) -> Dict[str, Uni
         reward_loss = self.intermediate_losses['loss_rewards']
         policy_loss = self.intermediate_losses['loss_policy']
         value_loss = self.intermediate_losses['loss_value']
-        latent_kl_loss = self.intermediate_losses['latent_kl_loss']
         latent_recon_loss = self.intermediate_losses['latent_recon_loss']
         perceptual_loss = self.intermediate_losses['perceptual_loss']
         orig_policy_loss = self.intermediate_losses['orig_policy_loss']
@@ -539,7 +538,6 @@ def _forward_learn_transformer(self, data: Tuple[torch.Tensor]) -> Dict[str, Uni
 
         total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_(self._learn_model.world_model.parameters(), self._cfg.grad_clip_value)  # TODO
         # total_grad_norm_before_clip_rep_net = torch.nn.utils.clip_grad_norm_(self._learn_model.tokenizer.representation_network.parameters(), max_norm=1.0)
-        # print('total_grad_norm_before_clip_rep_net:', total_grad_norm_before_clip_rep_net)
 
 
@@ -577,7 +575,6 @@ def
_forward_learn_transformer(self, data: Tuple[torch.Tensor]) -> Dict[str, Uni 'weighted_total_loss': weighted_total_loss.item(), 'obs_loss': obs_loss, - 'latent_kl_loss': latent_kl_loss, 'latent_recon_loss':latent_recon_loss, 'perceptual_loss':perceptual_loss, 'policy_loss': policy_loss, @@ -752,6 +749,8 @@ def _init_eval(self) -> None: self._mcts_eval = MCTSCtree(self._cfg) else: self._mcts_eval = MCTSPtree(self._cfg) + self.last_batch_obs = torch.zeros([3,self._cfg.model.observation_shape[0],64,64]).to(self._cfg.device) + self.last_batch_action = [-1 for i in range(3)] def _get_target_obs_index_in_step_k(self, step): """ @@ -807,7 +806,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 with torch.no_grad(): # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)} # network_output = self._collect_model.initial_inference(data) - network_output = self._eval_model.initial_inference(data) + network_output = self._eval_model.initial_inference(self.last_batch_obs, self.last_batch_action, data) latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) if not self._eval_model.training: @@ -835,6 +834,8 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 if ready_env_id is None: ready_env_id = np.arange(active_eval_env_num) + + batch_action = [] for i, env_id in enumerate(ready_env_id): distributions, value = roots_visit_count_distributions[i], roots_values[i] @@ -862,6 +863,10 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 'predicted_value': pred_values[i], 'predicted_policy_logits': policy_logits[i], } + batch_action.append(action) + + self.last_batch_obs = data + self.last_batch_action = batch_action return output @@ -886,7 +891,6 @@ def _monitor_vars_learn(self) -> List[str]: 'policy_loss', 'orig_policy_loss', 'policy_entropy', - 'latent_kl_loss', 'latent_recon_loss', # 'policy_entropy', 'target_policy_entropy', diff --git a/zoo/atari/config/atari_xzero_config_stack1.py b/zoo/atari/config/atari_xzero_config_stack1.py index eef9ae663..ff26f8b1b 100644 --- a/zoo/atari/config/atari_xzero_config_stack1.py +++ b/zoo/atari/config/atari_xzero_config_stack1.py @@ -1,13 +1,13 @@ from easydict import EasyDict import torch -torch.cuda.set_device(0) +torch.cuda.set_device(3) # options={'PongNoFrameskip-v4', 'QbertNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SpaceInvadersNoFrameskip-v4', 'BreakoutNoFrameskip-v4', ...} -env_name = 'PongNoFrameskip-v4' +# env_name = 'PongNoFrameskip-v4' # env_name = 'MsPacmanNoFrameskip-v4' -# env_name = 'BreakoutNoFrameskip-v4' # env_name = 'QbertNoFrameskip-v4' # env_name = 'SeaquestNoFrameskip-v4' +env_name = 'BreakoutNoFrameskip-v4' # collect_env_steps=5e3 # env_name = 'BoxingNoFrameskip-v4' # env_name = 'FrostbiteNoFrameskip-v4' @@ -33,42 +33,23 @@ # ============================================================== collector_env_num = 8 n_episode = 8 -evaluator_env_num = 1 -update_per_collect = 1000 -# update_per_collect = None - - -# collector_env_num = 16 -# n_episode = 16 -# evaluator_env_num = 1 -# update_per_collect = 2000 - - -# update_per_collect = None -# model_update_ratio = 1 # for qbet squest -# model_update_ratio = 0.25 # for pong boxing -model_update_ratio = 0.125 # for pong boxing +evaluator_env_num = 3 +# update_per_collect = 1000 # for pong boxing +update_per_collect = None # for others +model_update_ratio = 0.25 num_simulations = 50 -# num_simulations = 100 - - - -# max_env_step = int(2e5) -max_env_step = 
int(10e6) +max_env_step = int(2e6) reanalyze_ratio = 0. -# reanalyze_ratio = 0.05 - +# reanalyze_ratio = 0.05 # TODO batch_size = 64 num_unroll_steps = 5 # num_unroll_steps = 10 - # eps_greedy_exploration_in_collect = True eps_greedy_exploration_in_collect = False - # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== @@ -81,24 +62,11 @@ # atari env action space # game_buffer_muzero_gpt task_id # TODO: muzero_gpt_model.py world_model.py (3,64,64) - exp_name=f'data_xzero_0307/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_new-rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_mcts-kv-reset-5-kvbatch-pad-min-quantize15-lsd768-nh8_fixroot_simnorm_latentw10_pew0_seed0', - # exp_name=f'data_xzero_0307/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_new-rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_mcts-kv-reset-5-kvbatch-pad-min-quantize15-lsd768-nh8_fixroot_simnorm_latentw10_pew1e-4_seed0', - - # exp_name=f'data_xzero_0306/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_new-rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_mcts-kv-reset-5-kvbatch-pad-min-quantize15-lsd768-nh4_fixroot_head-2-layer_mantrans-nobatch_seed0', - - # exp_name=f'data_xzero_0307/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_new-rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_mcts-kv-reset-5-kvbatch-pad-min-quantize15-lsd768-nh4_fixroot_head-2-layer_pttrans-batch_seed0', - # exp_name=f'data_xzero_0305/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_new-rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_mcts-kv-reset-5-kvbatch-pad-min-quantize15-lsd768-nh4_fixroot_head-1-layer-havebias_seed0', + exp_name=f'data_xzero_atari_0316/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_mcts-kvbatch-pad-min-quantize15-lsd768-nh8_simnorm_latentw10_pew1e-4_latent-groupkl_fixed-act-emb_nogradscale_seed0', - # exp_name=f'data_xzero_0305/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_new-rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_mcts-kv-reset-5-kvbatch-pad-min-quantize15-lsd768-nh4_fixroot_seed0', - - # exp_name=f'data_xzero_stack1_0226/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_new-rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_mcts-kv-reset-5-kv-forloop-quantize15-lsd768-nh4_noload_seed0', - - # exp_name=f'data_xzero_stack1_0226/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_new-rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_mcts-kv-reset-5-kvbatch-pad-min-quantize15-lsd1024-nh8_seed0', - # 
exp_name=f'data_xzero_stack1_0226/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_new-rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_mcts-kv-reset-5-kvbatch-pad-min-quantize15-lsd768-nh2_collect-clear200_train-clear20_noeval_search-toplay-nodeepcopy_seed0', - # exp_name=f'data_profile/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_new-rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_mcts-kv-reset-5-kvbatch-pad-min_collect-clear200_noeval_search-toplay-nodeepcopy_seed0', - # exp_name=f'data_xzero_stack1_0219/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_new-rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_contembdings_lsd1024_lr1e-4-reconlwperlw-005-minmax_mcts-kv-reset-5-kv-81_latent-soft-target-100_mantran_seed0', - # exp_name=f'data_xzero_stack1_0219/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_new-rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_contembdings_lsd1024_lr1e-4-reconlwperlw-005-minmax-jointtrain-true_mcs5e3_collectper200-clear_mcts-kv-reset-5-kv-81-base-fix_latent-sigmod_latent-soft-target-100_mantran_seed0', - env=dict( + # exp_name=f'data_xzero_0312/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_new-rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_mcts-kvbatch-pad-min-quantize15-lsd768-nh8_simnorm_latentw10_pew1e-4_latent-groupkl_nogradscale_seed0', + # exp_name=f'data_xzero_0307/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_new-rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_mcts-kv-reset-5-kvbatch-pad-min-quantize15-lsd768-nh8_fixroot_simnorm_latentw10_pew1e-4_seed0', + env=dict( stop_value=int(1e6), env_name=env_name, # obs_shape=(4, 96, 96), @@ -115,14 +83,13 @@ n_evaluator_episode=evaluator_env_num, manager=dict(shared_memory=False, ), # TODO: debug - # collect_max_episode_steps=int(100), - # eval_max_episode_steps=int(100), + # collect_max_episode_steps=int(50), + # eval_max_episode_steps=int(50), # TODO: run - collect_max_episode_steps=int(2e4), + # collect_max_episode_steps=int(5e3), # for breakout + collect_max_episode_steps=int(2e4), # for others eval_max_episode_steps=int(1e4), - # collect_max_episode_steps=int(2e4), # eval_max_episode_steps=int(108000), - # clip_rewards=False, clip_rewards=True, ), policy=dict( @@ -135,19 +102,10 @@ ), model_path=None, # model_path='/mnt/afs/niuyazhe/code/LightZero/data_xzero_stack1_0226/Pong_xzero_envnum8_ns50_upc1000-mur0.25_new-rr0.0_H5_bs64_stack1_mcts-kv-reset-5-kvbatch-pad-min-quantize15-lsd768-nh4_collect-clear200_train-clear20_noeval_search-toplay-nodeepcopy_seed0/ckpt/iteration_220000.pth.tar', - # model_path='/mnt/afs/niuyazhe/code/LightZero/data_xzero_stack1_0204/Pong_xzero_envnum8_ns50_upc1000-mur0.25_new-rr0.0_H5_bs64_stack1_contembdings_lsd1024_lr1e-4-reconlwperlw-005-minmax-jointtrain-true_mcs5e3_collectper200-clear_mcts-kv-reset-5-kv-88-base_latent-sigmod_latent-soft-target-100_mantran_seed0/ckpt/ckpt_best.pth.tar', - # 
model_path='/mnt/afs/niuyazhe/code/LightZero/data_xzero_stack1_0204/Pong_xzero_envnum8_ns50_upc1000-mur0.25_new-rr0.25_H5_bs64_stack1_contembdings_lsd1024_lr1e-4-reconlwperlw-005-minmax-jointtrain-true_mcs5e2_collectper200-clear_target-per20-clear_evalmax_latent-soft-target-001_mantran_seed0/ckpt/ckpt_best.pth.tar', - # model_path='/mnt/afs/niuyazhe/code/LightZero/data_mz_gpt_ctree_0113_k1/Pong_muzero_gpt_envnum8_ns50_upc1000-mur0.25_rr0_H5_bs32_stack1_contembdings_lsd1024_lr1e-4-gcv10-reconslossw005-minmax-jointtrain-true_mcs5e2_collectper200-clear_evalmax_seed0/ckpt/iteration_167000.pth.tar', - # model_path='/mnt/afs/niuyazhe/code/LightZero/data_mz_ctree/Pong_muzero_ns50_upc1000_rr0.0_46464_seed0_240110_140819/ckpt/iteration_60000.pth.tar', - # tokenizer_start_after_envsteps=int(9e9), # not train tokenizer tokenizer_start_after_envsteps=int(0), transformer_start_after_envsteps=int(0), - # tokenizer_start_after_envsteps=int(0), - # transformer_start_after_envsteps=int(2e4), # 20K - # transformer_start_after_envsteps=int(5e3), # 5K 1K-5K 4000步 update_per_collect_transformer=update_per_collect, update_per_collect_tokenizer=update_per_collect, - # transformer_start_after_envsteps=int(5e3), num_unroll_steps=num_unroll_steps, model=dict( # observation_shape=(4, 96, 96), @@ -180,14 +138,11 @@ # reward_support_size=21, # value_support_size=21, # support_scale=10, - # embedding_dim=1024, - # embedding_dim=256, ), use_priority=False, cuda=True, env_type='not_board_games', game_segment_length=400, - # game_segment_length=50, random_collect_episode_num=0, eps=dict( eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect, @@ -196,14 +151,9 @@ type='linear', start=1., end=0.01, - # decay=int(1e5), decay=int(1e4), # 10k - # decay=int(5e4), # 50k - # decay=int(5e3), # 5k ), - # TODO: NOTE - # use_augmentation=True, - use_augmentation=False, + use_augmentation=False, # NOTE update_per_collect=update_per_collect, model_update_ratio = model_update_ratio, batch_size=batch_size, @@ -216,20 +166,17 @@ optim_type='Adam', lr_piecewise_constant_decay=False, - # learning_rate=0.003, learning_rate=0.0001, target_update_freq=100, - grad_clip_value = 0.5, # TODO - # grad_clip_value = 10, # TODO + grad_clip_value = 0.5, # TODO: 10 num_simulations=num_simulations, reanalyze_ratio=reanalyze_ratio, ssl_loss_weight=2, # default is 0 n_episode=n_episode, - # eval_freq=int(5e3), - eval_freq=int(9e9), - # eval_freq=int(1e5), + # eval_freq=int(9e9), + eval_freq=int(1e4), replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, @@ -244,7 +191,6 @@ import_names=['zoo.atari.envs.atari_lightzero_env'], ), env_manager=dict(type='subprocess'), - # env_manager=dict(type='base'), policy=dict( type='muzero_gpt', import_names=['lzero.policy.muzero_gpt'], diff --git a/zoo/atari/config/atari_xzero_config_stack1_debug.py b/zoo/atari/config/atari_xzero_config_stack1_debug.py index b32096823..4775eb0c8 100644 --- a/zoo/atari/config/atari_xzero_config_stack1_debug.py +++ b/zoo/atari/config/atari_xzero_config_stack1_debug.py @@ -37,7 +37,7 @@ # collector_env_num = 1 # n_episode = 1 -evaluator_env_num = 1 +evaluator_env_num = 3 update_per_collect = 1000 # collector_env_num = 1 @@ -120,8 +120,8 @@ save_ckpt_after_run=True, ), ), - # model_path=None, - model_path='/mnt/afs/niuyazhe/code/LightZero/data_xzero_0307/Pong_xzero_envnum8_ns50_upc1000-mur0.125_new-rr0.0_H5_bs64_stack1_mcts-kv-reset-5-kvbatch-pad-min-quantize15-lsd768-nh8_fixroot_seed0/ckpt/iteration_60000.pth.tar', + model_path=None, + # model_path='/mnt/afs/niuyazhe/code/LightZero/data_xzero_0307/Pong_xzero_envnum8_ns50_upc1000-mur0.125_new-rr0.0_H5_bs64_stack1_mcts-kv-reset-5-kvbatch-pad-min-quantize15-lsd768-nh8_fixroot_seed0/ckpt/iteration_60000.pth.tar', # model_path='/mnt/afs/niuyazhe/code/LightZero/data_xzero_stack1_0204/Pong_xzero_envnum8_ns50_upc1000-mur0.25_new-rr0.0_H5_bs64_stack1_contembdings_lsd1024_lr1e-4-reconlwperlw-005-minmax-jointtrain-true_mcs5e3_collectper200-clear_mcts-kv-reset-5-kv-88-base_latent-sigmod_latent-soft-target-100_mantran_seed0/ckpt/ckpt_best.pth.tar', # model_path='/mnt/afs/niuyazhe/code/LightZero/data_xzero_stack1_0204/Pong_xzero_envnum8_ns50_upc1000-mur0.25_new-rr0.25_H5_bs64_stack1_contembdings_lsd1024_lr1e-4-reconlwperlw-005-minmax-jointtrain-true_mcs5e2_collectper200-clear_target-per20-clear_evalmax_latent-soft-target-001_mantran_seed0/ckpt/ckpt_best.pth.tar', # model_path='/mnt/afs/niuyazhe/code/LightZero/data_mz_gpt_ctree_0113_k1/Pong_muzero_gpt_envnum8_ns50_upc1000-mur0.25_rr0_H5_bs32_stack1_contembdings_lsd1024_lr1e-4-gcv10-reconslossw005-minmax-jointtrain-true_mcs5e2_collectper200-clear_evalmax_seed0/ckpt/iteration_167000.pth.tar', @@ -214,8 +214,9 @@ reanalyze_ratio=reanalyze_ratio, ssl_loss_weight=2, # default is 0 n_episode=n_episode, - # eval_freq=int(5e3), - eval_freq=int(9e9), + eval_freq=int(5e3), + # eval_freq=int(9e9), + # eval_freq=int(1), # eval_freq=int(1e5), replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. collector_env_num=collector_env_num,