diff --git a/lzero/entry/train_muzero_gpt.py b/lzero/entry/train_muzero_gpt.py
index b74c2cf58..000fff3de 100644
--- a/lzero/entry/train_muzero_gpt.py
+++ b/lzero/entry/train_muzero_gpt.py
@@ -156,9 +156,9 @@ def train_muzero_gpt(
         else:
             collect_kwargs['epsilon'] = 0.0

-        # policy.last_batch_obs = torch.zeros([len(evaluator_env_cfg), cfg.policy.model.observation_shape[0], 64, 64]).to(cfg.policy.device)
-        # policy.last_batch_action = [-1 for _ in range(len(evaluator_env_cfg))]
-        # stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+        policy.last_batch_obs = torch.zeros([len(evaluator_env_cfg), cfg.policy.model.observation_shape[0], 64, 64]).to(cfg.policy.device)
+        policy.last_batch_action = [-1 for _ in range(len(evaluator_env_cfg))]
+        stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)

         # Evaluate policy performance.
         if evaluator.should_eval(learner.train_iter):
diff --git a/lzero/model/gpt_models/slicer_v0_bkp.py b/lzero/model/gpt_models/slicer_v0_bkp.py
deleted file mode 100644
index 6566271fa..000000000
--- a/lzero/model/gpt_models/slicer_v0_bkp.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import math
-from typing import List
-
-import torch
-import torch.nn as nn
-
-
-class Slicer(nn.Module):
-    def __init__(self, max_blocks: int, block_mask: torch.Tensor) -> None:
-        super().__init__()
-        self.block_size = block_mask.size(0)
-        self.num_kept_tokens = block_mask.sum().long().item()
-        kept_indices = torch.where(block_mask)[0].repeat(max_blocks)
-        offsets = torch.arange(max_blocks).repeat_interleave(self.num_kept_tokens)
-        self.register_buffer('indices', kept_indices + block_mask.size(0) * offsets)
-
-    def compute_slice(self, num_steps: int, prev_steps: int = 0) -> torch.Tensor:
-        total_steps = num_steps + prev_steps
-        num_blocks = math.ceil(total_steps / self.block_size)
-        indices = self.indices[:num_blocks * self.num_kept_tokens]
-        return indices[torch.logical_and(prev_steps <= indices, indices < total_steps)] - prev_steps
-
-    def forward(self, *args, **kwargs):
-        raise NotImplementedError
-
-
-class Head(Slicer):
-    def __init__(self, max_blocks: int, block_mask: torch.Tensor, head_module: nn.Module) -> None:
-        super().__init__(max_blocks, block_mask)
-        assert isinstance(head_module, nn.Module)
-        self.head_module = head_module
-
-    def forward(self, x: torch.Tensor, num_steps: int, prev_steps: int) -> torch.Tensor:
-        x_sliced = x[:, self.compute_slice(num_steps, prev_steps)]  # x is (B, T, E)
-        return self.head_module(x_sliced)
-
-
-class Embedder(nn.Module):
-    def __init__(self, max_blocks: int, block_masks: List[torch.Tensor], embedding_tables: List[nn.Embedding]) -> None:
-        super().__init__()
-        assert len(block_masks) == len(embedding_tables)
-        assert (sum(block_masks) == 1).all()  # block mask are a partition of a block
-        self.embedding_dim = embedding_tables[0].embedding_dim
-        assert all([e.embedding_dim == self.embedding_dim for e in embedding_tables])
-        self.embedding_tables = embedding_tables
-        self.slicers = [Slicer(max_blocks, block_mask) for block_mask in block_masks]
-
-    def forward(self, tokens: torch.Tensor, num_steps: int, prev_steps: int) -> torch.Tensor:
-        assert tokens.ndim == 2  # x is (B, T)
-        output = torch.zeros(*tokens.size(), self.embedding_dim, device=tokens.device)
-        for slicer, emb in zip(self.slicers, self.embedding_tables):
-            s = slicer.compute_slice(num_steps, prev_steps)
-            output[:, s] = emb(tokens[:, s])
-        return output
diff --git a/lzero/model/gpt_models/slicer_v1.py b/lzero/model/gpt_models/slicer_v1.py
deleted file mode 100644
index d259e9e98..000000000
--- a/lzero/model/gpt_models/slicer_v1.py
+++ /dev/null
@@ -1,66 +0,0 @@
-import math
-from typing import List
-
-import torch
-import torch.nn as nn
-
-from collections import defaultdict
-
-class Slicer(nn.Module):
-    def __init__(self, max_blocks: int, block_mask: torch.Tensor) -> None:
-        super().__init__()
-        self.block_size = block_mask.size(0)
-        self.num_kept_tokens = block_mask.sum().long().item()
-        kept_indices = torch.where(block_mask)[0].repeat(max_blocks)
-        offsets = torch.arange(max_blocks).repeat_interleave(self.num_kept_tokens)
-        # Indices of the kept tokens among all 17*20 tokens.
-        self.register_buffer('indices', kept_indices + block_mask.size(0) * offsets)
-        self.cache = defaultdict(torch.Tensor)
-
-    def compute_slice(self, num_steps: int, prev_steps: int = 0) -> torch.Tensor:
-        cache_key = (num_steps, prev_steps)
-        if cache_key not in self.cache:
-            total_steps = num_steps + prev_steps
-            num_blocks = math.ceil(total_steps / self.block_size)
-            indices = self.indices[:num_blocks * self.num_kept_tokens]
-            result = indices[torch.logical_and(prev_steps <= indices, indices < total_steps)] - prev_steps
-            self.cache[cache_key] = result
-        return self.cache[cache_key]
-        # total_steps = num_steps + prev_steps
-        # num_blocks = math.ceil(total_steps / self.block_size)
-        # indices = self.indices[:num_blocks * self.num_kept_tokens]
-        # result = indices[torch.logical_and(prev_steps <= indices, indices < total_steps)] - prev_steps
-        # return result
-
-    def forward(self, *args, **kwargs):
-        raise NotImplementedError
-
-
-class Head(Slicer):
-    def __init__(self, max_blocks: int, block_mask: torch.Tensor, head_module: nn.Module) -> None:
-        super().__init__(max_blocks, block_mask)
-        assert isinstance(head_module, nn.Module)
-        self.head_module = head_module
-
-    def forward(self, x: torch.Tensor, num_steps: int, prev_steps: int) -> torch.Tensor:
-        x_sliced = x[:, self.compute_slice(num_steps, prev_steps)]  # x is (B, T, E)
-        return self.head_module(x_sliced)
-
-
-class Embedder(nn.Module):
-    def __init__(self, max_blocks: int, block_masks: List[torch.Tensor], embedding_tables: List[nn.Embedding]) -> None:
-        super().__init__()
-        assert len(block_masks) == len(embedding_tables)
-        assert (sum(block_masks) == 1).all()  # block mask are a partition of a block
-        self.embedding_dim = embedding_tables[0].embedding_dim
-        assert all([e.embedding_dim == self.embedding_dim for e in embedding_tables])
-        self.embedding_tables = embedding_tables
-        self.slicers = [Slicer(max_blocks, block_mask) for block_mask in block_masks]
-
-    def forward(self, tokens: torch.Tensor, num_steps: int, prev_steps: int) -> torch.Tensor:
-        assert tokens.ndim == 2  # x is (B, T)
-        output = torch.zeros(*tokens.size(), self.embedding_dim, device=tokens.device)
-        for slicer, emb in zip(self.slicers, self.embedding_tables):
-            s = slicer.compute_slice(num_steps, prev_steps)
-            output[:, s] = emb(tokens[:, s])
-        return output
diff --git a/lzero/model/gpt_models/slicer_v2.py b/lzero/model/gpt_models/slicer_v2.py
deleted file mode 100644
index 5f7523901..000000000
--- a/lzero/model/gpt_models/slicer_v2.py
+++ /dev/null
@@ -1,69 +0,0 @@
-import math
-from typing import List
-
-import torch
-import torch.nn as nn
-
-class Slicer(nn.Module):
-    def __init__(self, max_blocks: int, block_mask: torch.Tensor) -> None:
-        super().__init__()
-        self.block_size = block_mask.size(0)
-        self.num_kept_tokens = block_mask.sum().long().item()
-        kept_indices =
torch.where(block_mask)[0].repeat(max_blocks) - offsets = torch.arange(max_blocks).repeat_interleave(self.num_kept_tokens) - self.register_buffer('indices', kept_indices + block_mask.size(0) * offsets) - self.cache: Dict[str, torch.Tensor] = {} - self.precompute_slices() - - def precompute_slices(self) -> None: - for num_steps in range(self.block_size*20): - for prev_steps in range(self.block_size*20): - cache_key = f"{num_steps}_{prev_steps}" - total_steps = num_steps + prev_steps - num_blocks = math.ceil(total_steps / self.block_size) - indices = self.indices[:num_blocks * self.num_kept_tokens] - result = indices[torch.logical_and(prev_steps <= indices, indices < total_steps)] - prev_steps - self.cache[cache_key] = result - - def compute_slice(self, num_steps: int, prev_steps: int = 0) -> torch.Tensor: - cache_key = f"{num_steps}_{prev_steps}" - if cache_key in self.cache: - return self.cache[cache_key] - else: - # Handle the case where cache_key is not in self.cache - # You could return a default value, raise an exception, or compute the result on the fly - # For example, to raise an exception: - raise ValueError(f"Cache key {cache_key} not found in precomputed slices") - - def forward(self, *args, **kwargs): - raise NotImplementedError - - -class Head(Slicer): - def __init__(self, max_blocks: int, block_mask: torch.Tensor, head_module: nn.Module) -> None: - super().__init__(max_blocks, block_mask) - assert isinstance(head_module, nn.Module) - self.head_module = head_module - - def forward(self, x: torch.Tensor, num_steps: int, prev_steps: int) -> torch.Tensor: - x_sliced = x[:, self.compute_slice(num_steps, prev_steps)] # x is (B, T, E) - return self.head_module(x_sliced) - - -class Embedder(nn.Module): - def __init__(self, max_blocks: int, block_masks: List[torch.Tensor], embedding_tables: List[nn.Embedding]) -> None: - super().__init__() - assert len(block_masks) == len(embedding_tables) - assert (sum(block_masks) == 1).all() # block mask are a partition of a block - self.embedding_dim = embedding_tables[0].embedding_dim - assert all([e.embedding_dim == self.embedding_dim for e in embedding_tables]) - self.embedding_tables = embedding_tables - self.slicers = [Slicer(max_blocks, block_mask) for block_mask in block_masks] - - def forward(self, tokens: torch.Tensor, num_steps: int, prev_steps: int) -> torch.Tensor: - assert tokens.ndim == 2 # x is (B, T) - output = torch.zeros(*tokens.size(), self.embedding_dim, device=tokens.device) - for slicer, emb in zip(self.slicers, self.embedding_tables): - s = slicer.compute_slice(num_steps, prev_steps) - output[:, s] = emb(tokens[:, s]) - return output diff --git a/lzero/model/gpt_models/slicer_v3.py b/lzero/model/gpt_models/slicer_v3.py deleted file mode 100644 index 41ad48145..000000000 --- a/lzero/model/gpt_models/slicer_v3.py +++ /dev/null @@ -1,62 +0,0 @@ -import math -from typing import List - -import torch -import torch.nn as nn - -class Slicer(nn.Module): - def __init__(self, max_blocks: int, block_mask: torch.Tensor) -> None: - super().__init__() - self.block_size = block_mask.size(0) - self.num_kept_tokens = block_mask.sum().long().item() - kept_indices = torch.where(block_mask)[0].repeat(max_blocks) - offsets = torch.arange(max_blocks).repeat_interleave(self.num_kept_tokens) - self.indices = kept_indices + block_mask.size(0) * offsets - - print("precompute_slices() begin") - self.cache = {} - max_steps = max_blocks * self.block_size - for num_steps in range(max_steps): - for prev_steps in range(max_steps): - total_steps = num_steps + 
prev_steps - num_blocks = math.ceil(total_steps / self.block_size) - indices = self.indices[:num_blocks * self.num_kept_tokens] - result = indices[torch.logical_and(prev_steps <= indices, indices < total_steps)] - prev_steps - self.cache[(num_steps, prev_steps)] = result - print("precompute_slices() done") - - def compute_slice(self, num_steps: int, prev_steps: int = 0) -> torch.Tensor: - return self.cache[(num_steps, prev_steps)] - - def forward(self, *args, **kwargs): - raise NotImplementedError - - -class Head(Slicer): - def __init__(self, max_blocks: int, block_mask: torch.Tensor, head_module: nn.Module) -> None: - super().__init__(max_blocks, block_mask) - assert isinstance(head_module, nn.Module) - self.head_module = head_module - - def forward(self, x: torch.Tensor, num_steps: int, prev_steps: int) -> torch.Tensor: - x_sliced = x[:, self.compute_slice(num_steps, prev_steps)] # x is (B, T, E) - return self.head_module(x_sliced) - - -class Embedder(nn.Module): - def __init__(self, max_blocks: int, block_masks: List[torch.Tensor], embedding_tables: List[nn.Embedding]) -> None: - super().__init__() - assert len(block_masks) == len(embedding_tables) - assert (sum(block_masks) == 1).all() # block mask are a partition of a block - self.embedding_dim = embedding_tables[0].embedding_dim - assert all([e.embedding_dim == self.embedding_dim for e in embedding_tables]) - self.embedding_tables = embedding_tables - self.slicers = [Slicer(max_blocks, block_mask) for block_mask in block_masks] - - def forward(self, tokens: torch.Tensor, num_steps: int, prev_steps: int) -> torch.Tensor: - assert tokens.ndim == 2 # x is (B, T) - output = torch.zeros(*tokens.size(), self.embedding_dim, device=tokens.device) - for slicer, emb in zip(self.slicers, self.embedding_tables): - s = slicer.compute_slice(num_steps, prev_steps) - output[:, s] = emb(tokens[:, s]) - return output diff --git a/lzero/model/gpt_models/test_slicer_time_v0.py b/lzero/model/gpt_models/test_slicer_time_v0.py deleted file mode 100644 index 306652e64..000000000 --- a/lzero/model/gpt_models/test_slicer_time_v0.py +++ /dev/null @@ -1,35 +0,0 @@ -import torch -import pytest -from slicer import Slicer, Head, Embedder - - -def test_slicer_time(): - max_blocks = 20 - act_tokens_pattern = torch.zeros(17) # 17 - act_tokens_pattern[-1] = 1 # 0,...,0,1 - block_mask = act_tokens_pattern - - slicer = Slicer(max_blocks, block_mask) - - import timeit - start_time = timeit.default_timer() - # code you want to evaluate - # Test slice computation - slice_ = slicer.compute_slice(num_steps=5, prev_steps=2) - - end_time = timeit.default_timer() - execution_time = end_time - start_time - print(f"Executed the function in {execution_time} seconds") - - assert torch.equal(slice_, torch.tensor([])) - - # Test caching - cache_key = (5, 2) - assert cache_key in slicer.cache - assert torch.equal(slice_, slicer.cache[cache_key]) - - - -if __name__ == "__main__": - test_slicer_time() - diff --git a/lzero/model/gpt_models/world_model_batch_nopad_max.py b/lzero/model/gpt_models/world_model_batch_nopad_max.py deleted file mode 100644 index f2b0b1f44..000000000 --- a/lzero/model/gpt_models/world_model_batch_nopad_max.py +++ /dev/null @@ -1,1013 +0,0 @@ -import copy -from dataclasses import dataclass -import random -from typing import Any, Optional, Tuple -from typing import List, Optional, Union -import logging -# 设置日志记录级别为DEBUG -logging.getLogger().setLevel(logging.DEBUG) -from PIL import Image -from einops import rearrange -from einops import rearrange -import gym -from 
joblib import hash -import numpy as np -import torch -import torch -from torch.distributions.categorical import Categorical -import torch.nn as nn -import torch.nn.functional as F -import torchvision - -from .kv_caching import KeysValues -from .slicer import Embedder, Head, ActEmbedder -from .tokenizer import Tokenizer -from .transformer import Transformer, TransformerConfig -from .utils import LossWithIntermediateLosses, init_weights -from ding.torch_utils import to_device -# from memory_profiler import profile -from line_profiler import line_profiler - -@dataclass -class WorldModelOutput: - output_sequence: torch.FloatTensor - logits_observations: torch.FloatTensor - logits_rewards: torch.FloatTensor - logits_ends: torch.FloatTensor - - logits_policy: torch.FloatTensor - logits_value: torch.FloatTensor - - -class WorldModel(nn.Module): - def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: TransformerConfig, tokenizer, representation_network=None) -> None: - super().__init__() - - # config.max_tokens = int(2*50) # TODO - - self.tokenizer = tokenizer - self.obs_vocab_size, self.act_vocab_size = obs_vocab_size, act_vocab_size - self.config = config - - self.transformer = Transformer(config) - # self.num_observations_tokens = 16 - self.num_observations_tokens = config.tokens_per_block -1 - - self.latent_recon_loss_weight = config.latent_recon_loss_weight - self.perceptual_loss_weight = config.perceptual_loss_weight - - self.device = config.device - self.support_size = config.support_size - self.action_shape = config.action_shape - self.max_cache_size = config.max_cache_size - self.env_num = config.env_num - self.num_layers = config.num_layers - - - all_but_last_latent_state_pattern = torch.ones(config.tokens_per_block) - all_but_last_latent_state_pattern[-2] = 0 # 1,...,0,1 - - act_tokens_pattern = torch.zeros(self.config.tokens_per_block) # 17 - act_tokens_pattern[-1] = 1 # 0,...,0,1 - latent_state_pattern = 1 - act_tokens_pattern # 1,...,1,0 - - # current latent state's policy value - value_policy_tokens_pattern = torch.zeros(config.tokens_per_block) - value_policy_tokens_pattern[-2] = 1 # [0,...,1,0] - - # next latent state's policy value - # value_policy_tokens_pattern = torch.zeros(config.tokens_per_block) - # value_policy_tokens_pattern[-1] = 1 # [0,...,0,1] - - obs_per_embdding_dim=config.embed_dim - - self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim) - - - self.embedder = Embedder( - max_blocks=config.max_blocks, - block_masks=[act_tokens_pattern, latent_state_pattern], - embedding_tables=nn.ModuleList([nn.Embedding(act_vocab_size, config.embed_dim), nn.Embedding(obs_vocab_size, config.embed_dim)]) - ) - - # self.act_embedder = ActEmbedder( - # max_blocks=config.max_blocks, - # block_masks=[act_tokens_pattern], - # embedding_tables=nn.ModuleList([nn.Embedding(act_vocab_size, config.embed_dim)]) - # ) - - self.act_embedding_table = nn.Embedding(act_vocab_size, config.embed_dim) - - # self.head_observations = Head( - # max_blocks=config.max_blocks, - # block_mask=all_but_last_latent_state_pattern, # 1,...,0,1 # https://github.com/eloialonso/iris/issues/19 - # head_module=nn.Sequential( - # nn.Linear(config.embed_dim, config.embed_dim), - # nn.ReLU(), - # nn.Linear(config.embed_dim, obs_vocab_size) - # ) - # ) - self.obs_per_embdding_dim = config.embed_dim # 16*64=1024 - self.head_observations = Head( # TODO - max_blocks=config.max_blocks, - block_mask=all_but_last_latent_state_pattern, # 1,...,0,1 # https://github.com/eloialonso/iris/issues/19 - 
head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - # nn.BatchNorm1d(config.embed_dim), - # nn.ReLU(), - # nn.Linear(config.embed_dim, obs_vocab_size) - nn.LeakyReLU(negative_slope=0.01), # TODO: 2 - nn.Linear(config.embed_dim, self.obs_per_embdding_dim), - # nn.Tanh(), # TODO - nn.Sigmoid(), # 这里添加Sigmoid函数 TODO - ) - ) - - self.head_rewards = Head( - max_blocks=config.max_blocks, - block_mask=act_tokens_pattern, # 0,...,0,1 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - nn.ReLU(), - nn.Linear(config.embed_dim, self.support_size) - ) - ) - - self.head_ends = Head( - max_blocks=config.max_blocks, - block_mask=act_tokens_pattern, # 0,...,0,1 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - nn.ReLU(), - nn.Linear(config.embed_dim, 2) - ) - ) - - self.head_observations_for_root = Head( # TODO - max_blocks=config.max_blocks, - block_mask=latent_state_pattern, # 1,...,1,0 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - nn.BatchNorm1d(config.embed_dim), - nn.ReLU(), - # nn.Linear(config.embed_dim, obs_vocab_size) - nn.Linear(config.embed_dim, obs_per_embdding_dim) - ) - ) - self.head_policy = Head( - max_blocks=config.max_blocks, - block_mask=value_policy_tokens_pattern, # TODO: value_policy_tokens_pattern # [0,...,1,0] - head_module=nn.Sequential( # (8, 5, 128) - # nn.BatchNorm1d(config.embed_dim), # TODO: 1 - nn.Linear(config.embed_dim, config.embed_dim), - # nn.ReLU(), - nn.LeakyReLU(negative_slope=0.01), # TODO: 2 - nn.Linear(config.embed_dim, self.action_shape) # TODO(pu); action shape - ) - ) - self.head_value = Head( - max_blocks=config.max_blocks, - block_mask=value_policy_tokens_pattern, - head_module=nn.Sequential( - # nn.BatchNorm1d(config.embed_dim), # TODO: 1 - nn.Linear(config.embed_dim, config.embed_dim), - # nn.ReLU(), - nn.LeakyReLU(negative_slope=0.01), # TODO: 2 - nn.Linear(config.embed_dim, self.support_size) # TODO(pu): action shape - ) - ) - - self.apply(init_weights) - - last_linear_layer_init_zero = True # TODO: is beneficial for convergence speed. - if last_linear_layer_init_zero: - # TODO: policy init : 3 - # Locate the last linear layer and initialize its weights and biases to 0. 
- # for _, layer in enumerate(reversed(self.head_policy.head_module)): - # if isinstance(layer, nn.Linear): - # nn.init.zeros_(layer.weight) - # nn.init.zeros_(layer.bias) - # break - for _, layer in enumerate(reversed(self.head_value.head_module)): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - nn.init.zeros_(layer.bias) - break - for _, layer in enumerate(reversed(self.head_rewards.head_module)): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - nn.init.zeros_(layer.bias) - break - for _, layer in enumerate(reversed(self.head_observations.head_module)): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - nn.init.zeros_(layer.bias) - break - - - import collections - self.past_keys_values_cache = collections.OrderedDict() - self.past_policy_value_cache = collections.OrderedDict() - - # TODO: Transformer更新后应该清除缓存 - # NOTE - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=self.env_num, max_tokens=self.config.max_tokens) - - self.keys_values_wm_list = [] - self.keys_values_wm_size_list = [] - - - if self.num_observations_tokens==16: # k=16 - self.projection_input_dim = 128 - elif self.num_observations_tokens==1: # K=1 - self.projection_input_dim = self.obs_per_embdding_dim# for atari #TODO - # self.projection_input_dim = 256 # for cartpole - - - self.proj_hid = 1024 - self.proj_out = 1024 - self.pred_hid = 512 - self.pred_out = 1024 - activation = nn.ReLU() - self.projection = nn.Sequential( - nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, - nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, - nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out) - ) - self.prediction_head = nn.Sequential( - nn.Linear(self.proj_out, self.pred_hid), - nn.BatchNorm1d(self.pred_hid), - activation, - nn.Linear(self.pred_hid, self.pred_out), - ) - self.hit_count = 0 - self.total_query_count = 0 - - - - def __repr__(self) -> str: - return "world_model" - - # @profile - def forward(self, obs_embeddings_or_act_tokens, past_keys_values: Optional[KeysValues] = None, - is_root=False, kvcache_independent=False) -> WorldModelOutput: - - if kvcache_independent: - prev_steps = 0 if past_keys_values is None else [past_kv.size for past_kv in past_keys_values] - prev_steps = torch.tensor(prev_steps, device=self.device) - # 我们需要为每个样本生成一个序列的步骤indices,然后获取它们的位置嵌入 - # 首先扩展prev_steps至(num_steps, batch_size),这里num_steps=1 - # prev_steps = prev_steps.unsqueeze(0) - - else: - prev_steps = 0 if past_keys_values is None else past_keys_values.size - # print(f'prev_steps:{prev_steps}') - - if 'obs_embeddings' in obs_embeddings_or_act_tokens.keys(): - obs_embeddings = obs_embeddings_or_act_tokens['obs_embeddings'] - if len(obs_embeddings.shape)==2: - obs_embeddings = obs_embeddings.unsqueeze(1) - num_steps = obs_embeddings.size(1) # (B, T, E) - if kvcache_independent: - # 生成每个样本的步骤indices - # steps_indices = prev_steps + torch.arange(num_steps, device=obs_embeddings.device).unsqueeze(1) - steps_indices = prev_steps + torch.arange(num_steps, device=obs_embeddings.device) - - # 步骤indices需要被reshape成一维,以便于embedding操作 - # steps_indices = steps_indices.view(-1) - # 获取位置嵌入 - position_embeddings = self.pos_emb(steps_indices) - # 由于我们要将它们加到obs_embeddings上,需要将位置嵌入reshape回(batch_size, num_steps, embedding_dim) - position_embeddings = position_embeddings.view(-1, num_steps, position_embeddings.shape[-1]) - # 现在我们可以将位置嵌入加到obs_embeddings上了 - sequences = obs_embeddings + 
position_embeddings - else: - sequences = obs_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=obs_embeddings.device)) - elif 'act_tokens' in obs_embeddings_or_act_tokens.keys(): - act_tokens = obs_embeddings_or_act_tokens['act_tokens'] - num_steps = act_tokens.size(1) # (B, T) - act_embeddings = self.act_embedding_table(act_tokens) - - if kvcache_independent: - # 生成每个样本的步骤indices - # steps_indices = prev_steps + torch.arange(num_steps, device=act_embeddings.device).unsqueeze(1) - steps_indices = prev_steps + torch.arange(num_steps, device=act_embeddings.device) - # 步骤indices需要被reshape成一维,以便于embedding操作 - # steps_indices = steps_indices.view(-1) - # 获取位置嵌入 - position_embeddings = self.pos_emb(steps_indices) - # 由于我们要将它们加到obs_embeddings上,需要将位置嵌入reshape回(batch_size, num_steps, embedding_dim) - position_embeddings = position_embeddings.view(-1, num_steps, position_embeddings.shape[-1]) - # 现在我们可以将位置嵌入加到obs_embeddings上了 - sequences = act_embeddings + position_embeddings - else: - sequences = act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=act_tokens.device)) - else: - obs_embeddings_and_act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens'] - # obs_embeddings: (B, L, K=16, E), act_tokens: (B, L, 1) - obs_embeddings, act_tokens = obs_embeddings_and_act_tokens - if len(obs_embeddings.shape)==3: # for batch compute loss - obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, -1) - - num_steps = int(obs_embeddings.size(1)*(obs_embeddings.size(2)+1)) # L(k+1) - # assert num_steps <= self.config.max_tokens - # Rearrange observation embeddings from (B, L, K, E) to (B, L*K, E) - # obs_embeddings = rearrange(obs_embeddings, 'b l k e -> b (l k) e') - - # Generate action embeddings from action tokens - # (B, L, 1) -> (B, L, 1, E) - act_embeddings = self.act_embedding_table(act_tokens) - - # 已知obs_embeddings的维度为 (B, L, K, E), act_embeddings的维度为(B, L, 1, E) 希望得到一个obs_act_embeddings向量的维度为 (B, L(K+1), E) - # 而且让得到的obs_act_embeddings的第2个维度的数据为:obs act, obs, act, ..., obs, act,即 L, 1, L,1 ... 
这样的排列顺序。请给出高效的实现,用中文回答 - - B, L, K, E = obs_embeddings.size() - # _, _, _, _ = act_embeddings.size() - - # 初始化一个新的空tensor,用于存放最终的拼接结果 - obs_act_embeddings = torch.empty(B, L * (K + 1), E, device=obs_embeddings.device) - - # 对每一个序列长度L进行循环 - for i in range(L): - # 获取当前时刻的obs和act embeddings - obs = obs_embeddings[:, i, :, :] # Shape: (B, K, E) - act = act_embeddings[:, i, 0, :].unsqueeze(1) # Shape: (B, 1, E), 补充维度以便拼接 - - # 交替拼接obs和act - obs_act = torch.cat([obs, act], dim=1) # Shape: (B, K + 1, E) - - # 将结果填充到最终的tensor中 - obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act - - # 确保形状正确无误 - # assert obs_act_embeddings.shape == (B, L * (K + 1), E) - - # Add positional embeddings - sequences = obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=obs_embeddings.device)) - - - # print('transformer forward begin') 函数里面更新了update past_keys_values - if kvcache_independent: - x = [] - for k, past_kv in enumerate(past_keys_values): - x.append(self.transformer(sequences[k].unsqueeze(0), past_kv)) - # self.keys_values_wm_list[k] = past_kv # NOTE: todo - x = torch.cat(x, dim=0) - - # TODO: 在collect时,是一步一步的 obs act 传入的 - # prev_steps = prev_steps//1 - - else: - x = self.transformer(sequences, past_keys_values) - - # print('transformer forward done') - - - if is_root: - logits_observations = self.head_observations_for_root(x, num_steps=num_steps, prev_steps=prev_steps) - else: - # 1,...,0,1 https://github.com/eloialonso/iris/issues/19 - logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps) - - logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps) - logits_ends = self.head_ends(x, num_steps=num_steps, prev_steps=prev_steps) - # return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends) - - logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps) - logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps) - - # TODO: root reward value - return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends, logits_policy, logits_value) - - @torch.no_grad() - def render_batch(self) -> List[Image.Image]: - frames = self.decode_latent_state().detach().cpu() - frames = rearrange(frames, 'b c h w -> b h w c').mul(255).numpy().astype(np.uint8) - return [Image.fromarray(frame) for frame in frames] - - # only foe inference now, now is invalid - @torch.no_grad() - def decode_latent_state(self) -> List[Image.Image]: - embedded_tokens = self.tokenizer.embedding(self.latent_state) # (B, K, E) - z = rearrange(embedded_tokens, 'b (h w) e -> b e h w', h=int(np.sqrt(self.num_observations_tokens))) - rec = self.tokenizer.decode(z, should_postprocess=True) # (B, C, H, W) - # TODO: for atari image - return torch.clamp(rec, 0, 1) - # for cartpole obs - # return rec - - - @torch.no_grad() - def render(self): - assert self.latent_state.shape == (1, self.num_observations_tokens) - return self.render_batch()[0] - - @torch.no_grad() - def reset(self) -> torch.FloatTensor: - assert self.env is not None - obs = torchvision.transforms.functional.to_tensor(self.env.reset()).to(self.device).unsqueeze( - 0) # (1, C, H, W) in [0., 1.] 
- return self.reset_from_initial_observations(obs) - - - @torch.no_grad() - # @profile - def reset_from_initial_observations_v2(self, obs_act_dict: torch.FloatTensor) -> torch.FloatTensor: - if isinstance(obs_act_dict, dict): - observations = obs_act_dict['obs'] - buffer_action = obs_act_dict['action'] - else: - observations = obs_act_dict - buffer_action = None - obs_embeddings = self.tokenizer.encode_to_obs_embeddings(observations, should_preprocess=True) # (B, C, H, W) -> (B, K, E) - - outputs_wm = self.refresh_keys_values_with_initial_latent_state_for_init_infer_v2(obs_embeddings, buffer_action) - self.latent_state = obs_embeddings - - return outputs_wm, self.latent_state - - @torch.no_grad() - # @profile - def refresh_keys_values_with_initial_latent_state_for_init_infer_v2(self, latent_state: torch.LongTensor, buffer_action=None) -> torch.FloatTensor: - n, num_observations_tokens, _ = latent_state.shape - if n <= self.env_num: - # MCTS root节点: 需要准确的估计 value, policy_logits, 或许需要结合context的kv_cache进行更准确的估计,而不是当前的从0开始推理 - # self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - # outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) - # Compute the hash of latest_state - # latest_state = latent_state.detach().cpu().numpy() - # 假设 latest_state 是新的 latent_state,包含 ready_env_num 个环境的信息 - # ready_env_num = latest_state.shape[0] - # keys_values_wm_list = [] - # self.keys_values_wm_size_list = [] - # for i in range(ready_env_num): - # self.total_query_count += 1 - # state_single_env = latest_state[i] # 获取单个环境的 latent state - # hash_latest_state = hash(state_single_env) # 计算哈希值 - # matched_value = self.past_keys_values_cache.get(hash_latest_state) # 检索缓存值 - # if matched_value is not None: - # self.hit_count += 1 - # # 如果找到匹配的值,将其添加到列表中 - # keys_values_wm_list.append(copy.deepcopy(self.to_device_for_kvcache(matched_value, 'cuda'))) - # self.keys_values_wm_size_list.append(matched_value.size) - - # keys_values_wm_list.append(self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens)) - # self.keys_values_wm_size_list.append(0) - # else: - # # use zero - # keys_values_wm_list.append(self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens)) - # self.keys_values_wm_size_list.append(0) - # self.keys_values_wm_list = keys_values_wm_list - - # use zero - # self.keys_values_wm_list = [self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) for i in range(n)] - # outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm_list, is_root=False, kvcache_independent=True) - # self.keys_values_wm_size_list = [1 for i in range(n)] - - - self.keys_values_wm = self.transformer.generate_empty_keys_values(n, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False, kvcache_independent=False) - self.keys_values_wm_size_list = [1 for i in range(n)] - - for i in range(latent_state.size(0)): # 遍历每个环境 - state_single_env = latent_state[i] # 获取单个环境的 latent state - cache_key = hash(state_single_env.detach().cpu().numpy()) # 计算哈希值 - # 复制单个环境对应的 keys_values_wm 并存储 - keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) - for layer in range(self.num_layers): - keys_values_wm_single_env._keys_values[layer]._k_cache._cache = 
self.keys_values_wm._keys_values[layer]._k_cache._cache[i].unsqueeze(0) # shape torch.Size([2, 100, 512]) - keys_values_wm_single_env._keys_values[layer]._v_cache._cache = self.keys_values_wm._keys_values[layer]._v_cache._cache[i].unsqueeze(0) - keys_values_wm_single_env._keys_values[layer]._k_cache._size = self.keys_values_wm._keys_values[layer]._k_cache._size - keys_values_wm_single_env._keys_values[layer]._v_cache._size = self.keys_values_wm._keys_values[layer]._v_cache._size - # keys_values_wm_single_env[layer].update(self.keys_values_wm[layer]._k_cache._cache[i].unsqueeze(0), self.keys_values_wm[layer]._v_cache._cache[i].unsqueeze(0)) - self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(keys_values_wm_single_env, 'cpu')) - del keys_values_wm_single_env - - - elif n == int(256): - # TODO: n=256 means train tokenizer, 不需要计算target value - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - # print('init inference: not find matched_value! reset!') - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) - elif n > self.env_num and n != int(256) and buffer_action is not None: - # train时计算target value - # TODO: n=256 means train tokenizer - # [192, 16, 64] -> [32, 6, 16, 64] - latent_state = latent_state.contiguous().view(buffer_action.shape[0], -1, num_observations_tokens, self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 - - # latent_state = latent_state.view(-1, self.config.max_blocks+1, num_observations_tokens) # (BL, K) - latent_state = latent_state[:, :-1, :] - # latent_state = latent_state.reshape(32*6, num_observations_tokens) # (BL, K) - buffer_action = torch.from_numpy(buffer_action).to(latent_state.device) - act_tokens = rearrange(buffer_action, 'b l -> b l 1') - - # # 选择每个样本的最后一步 - last_steps = act_tokens[:, -1:, :] # 这将选择最后一列并保持维度不变, 最后一步的target policy/value本身就没有用到 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - act_tokens = torch.cat((act_tokens, last_steps), dim=1) - - # print('init inference: unroll 5 steps!') 17*6=102 17*5=85 - obs_embeddings = latent_state - outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) - - # 选择每个样本的最后一步 - last_steps_value = outputs_wm.logits_value[:, -1:, :] # 这将选择最后一列并保持维度不变 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1) - - last_steps_policy = outputs_wm.logits_policy[:, -1:, :] # 这将选择最后一列并保持维度不变 - outputs_wm.logits_policy = torch.cat((outputs_wm.logits_policy, last_steps_policy), dim=1) - - # Reshape your tensors - # outputs_wm.logits_value.shape (30,21) = (B*6, 21) - outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e') - outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e') - - return outputs_wm - - @torch.no_grad() - # @profile - def reset_from_initial_observations(self, obs_act_dict: torch.FloatTensor) -> torch.FloatTensor: - if isinstance(obs_act_dict, dict): - # obs_act_dict = {'obs':obs, 'action':action_batch} - observations = obs_act_dict['obs'] - buffer_action = obs_act_dict['action'] - else: - observations = obs_act_dict - buffer_action = None - - obs_embeddings = self.tokenizer.encode_to_obs_embeddings(observations, should_preprocess=True) # (B, C, H, W) -> (B, K, E) - outputs_wm = self.refresh_keys_values_with_initial_latent_state_for_init_infer(obs_embeddings, buffer_action) - 
self.latent_state = obs_embeddings - - return outputs_wm, self.latent_state - - - @torch.no_grad() - # @profile - def refresh_keys_values_with_initial_latent_state_for_init_infer(self, latent_state: torch.LongTensor, buffer_action=None) -> torch.FloatTensor: - n, num_observations_tokens, _ = latent_state.shape - if n <= self.env_num: - # MCTS root节点: 需要准确的估计 value, policy_logits 或许需要结合context的kv_cache进行更准确的估计,而不是当前的从0开始推理 - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) # Note: is_root=False - elif n == int(256): - # TODO: n=256 means train tokenizer, 不需要计算target value - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - # print('init inference: not find matched_value! reset!') - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) # Note: is_root=False - elif n > self.env_num and n != int(256) and buffer_action is not None: - # TODO: n=256 means train tokenizer - # TODO: for n=32*6=192 means 通过unroll 5 steps,计算target value - # latent_state = latent_state.reshape(32, 6, num_observations_tokens) # (BL, K) - # latent_state = latent_state.view(-1, 6, num_observations_tokens) # (BL, K) - - # [192, 16] -> [32, 6, 16] - # latent_state = latent_state.view(buffer_action.shape[0], -1, num_observations_tokens) # (BL, K) for unroll_step=1 - - # [192, 16, 64] -> [32, 6, 16, 64] - latent_state = latent_state.contiguous().view(buffer_action.shape[0], -1, num_observations_tokens, self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 - - # latent_state = latent_state.view(-1, self.config.max_blocks+1, num_observations_tokens) # (BL, K) - latent_state = latent_state[:, :-1, :] - # latent_state = latent_state.reshape(32*6, num_observations_tokens) # (BL, K) - buffer_action = torch.from_numpy(buffer_action).to(latent_state.device) - act_tokens = rearrange(buffer_action, 'b l -> b l 1') - - # # 选择每个样本的最后一步 - last_steps = act_tokens[:, -1:, :] # 这将选择最后一列并保持维度不变, 最后一步的target policy/value本身就没有用到 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - act_tokens = torch.cat((act_tokens, last_steps), dim=1) - - # print('init inference: unroll 5 steps!') 17*6=102 17*5=85 - obs_embeddings = latent_state - outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) - - # 选择每个样本的最后一步 - last_steps_value = outputs_wm.logits_value[:, -1:, :] # 这将选择最后一列并保持维度不变 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1) - - last_steps_policy = outputs_wm.logits_policy[:, -1:, :] # 这将选择最后一列并保持维度不变 - outputs_wm.logits_policy = torch.cat((outputs_wm.logits_policy, last_steps_policy), dim=1) - - # Reshape your tensors - # outputs_wm.logits_value.shape (30,21) = (B*6, 21) - outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e') - outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e') - - - # return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends, logits_policy, logits_value) - return outputs_wm - - - @torch.no_grad() - # @profile - def refresh_keys_values_with_initial_latent_state(self, latent_state: torch.LongTensor, reset_indices=None) -> torch.FloatTensor: - n, num_observations_tokens, _ = latent_state.shape - assert num_observations_tokens == 
self.num_observations_tokens - # self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - if reset_indices is None: - self.keys_values_wm_list = [self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) for i in range(n)] - else: - for i in reset_indices: - self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': latent_state[i].unsqueeze(0).to(self.device)}, past_keys_values=self.keys_values_wm_single_env, is_root=False, kvcache_independent=False) - self.keys_values_wm_list[i] = self.keys_values_wm_single_env - self.keys_values_wm_size_list[i] = 1 - return None - - @torch.no_grad() - # @profile - def forward_initial_inference(self, obs_act_dict: torch.LongTensor, should_predict_next_obs: bool = True): - if isinstance(obs_act_dict, dict): - obs = obs_act_dict['obs'] - else: - obs = obs_act_dict - outputs_wm, latent_state = self.reset_from_initial_observations_v2(obs_act_dict) # TODO - # outputs_wm, latent_state = self.reset_from_initial_observations(obs_act_dict) # 从零开始 - - return outputs_wm.output_sequence, latent_state, outputs_wm.logits_rewards, outputs_wm.logits_policy, outputs_wm.logits_value - - - """ - past-kv-dict-batch envnum8 latest multi-step - fix init infer - 把8个样本的self.keys_values_wm 看做一个整体来寻找 - - TODO:很多时候都是执行的refresh_keys_values_with_initial_latent_state,导致没有充分利用序列建模能力? - """ - - - @torch.no_grad() - # @profile - def forward_recurrent_inference(self, state_action_history, should_predict_next_obs: bool = True): - # 一般来讲,在一次 MCTS search中,我们需要维护H长度的context来使用transformer进行推理。 - # 由于在一次search里面。agent最多访问sim个不同的节点,因此我们只需维护一个 {(state:kv_cache)}的列表。 - # 但如果假设环境是MDP的话,然后根据当前的 latest_state s_t 在这个列表中查找即可 - # TODO: 但如果假设环境是非MDP的话,需要维护一个 {(rootstate_action_history:kv_cache)}的列表? 
- - if self.total_query_count>0: - self.hit_freq = self.hit_count/(self.total_query_count) - print('hit_freq:', self.hit_freq) - print('hit_count:', self.hit_count) - print('total_query_count:', self.total_query_count) - - latest_state = state_action_history[-1][0] - - # 假设 latest_state 是新的 latent_state,包含 ready_env_num 个环境的信息 - ready_env_num = latest_state.shape[0] - keys_values_wm_list = [] - self.keys_values_wm_size_list = [] - for i in range(ready_env_num): - self.total_query_count += 1 - state_single_env = latest_state[i] # 获取单个环境的 latent state - hash_latest_state = hash(state_single_env) # 计算哈希值 - matched_value = self.past_keys_values_cache.get(hash_latest_state) # 检索缓存值 - if matched_value is not None: - self.hit_count += 1 - # 如果找到匹配的值,将其添加到列表中 - keys_values_wm_list.append(copy.deepcopy(self.to_device_for_kvcache(matched_value, 'cuda'))) - self.keys_values_wm_size_list.append(matched_value.size) - else: - # use zero - # keys_values_wm_list.append(self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens)) - # self.keys_values_wm_size_list.append(0) - - self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': torch.from_numpy(state_single_env).unsqueeze(0).to(self.device)}, past_keys_values=self.keys_values_wm_single_env, is_root=False) - keys_values_wm_list.append(self.keys_values_wm_single_env) - self.keys_values_wm_size_list.append(1) - - self.keys_values_wm_list = keys_values_wm_list - num_passes = 1 + self.num_observations_tokens if should_predict_next_obs else 1 - output_sequence, latent_state = [], [] - - print(self.keys_values_wm_size_list) - reset_indices = [index for index, value in enumerate(self.keys_values_wm_size_list) if value + num_passes > self.config.max_tokens] - self.refresh_keys_values_with_initial_latent_state(torch.tensor(latest_state, dtype=torch.float32).to(self.device), reset_indices) - - action = state_action_history[-1][-1] - token = action.clone().detach() if isinstance(action, torch.Tensor) else torch.tensor(action, dtype=torch.long) - token = token.reshape(-1, 1).to(self.device) # (B, 1) - - # 将每个环境独立的kv_cache拼接为一个batch_size为env_num的总的kv_cache - for layer in range(self.num_layers): - kv_cache_k_list = [] - kv_cache_v_list = [] - for keys_values in self.keys_values_wm_list: - kv_cache_k_list.append(keys_values[layer]._k_cache._cache) - kv_cache_v_list.append(keys_values[layer]._v_cache._cache) - - self.keys_values_wm._keys_values[layer]._k_cache._cache = torch.stack(kv_cache_k_list, dim=0).squeeze(1) - self.keys_values_wm._keys_values[layer]._v_cache._cache = torch.stack(kv_cache_v_list, dim=0).squeeze(1) - # self.keys_values_wm._keys_values[layer].update(torch.stack(kv_cache_k_list, dim=0).squeeze(1), torch.stack(kv_cache_v_list, dim=0).squeeze(1)) - self.keys_values_wm._keys_values[layer]._k_cache._size = max(self.keys_values_wm_size_list) # TODO: very important - self.keys_values_wm._keys_values[layer]._v_cache._size = max(self.keys_values_wm_size_list) - - for k in range(num_passes): # assumption that there is only one action token. 
- # action_token obs_token, ..., obs_token 1+1 - if k==0: - obs_embeddings_or_act_tokens = {'act_tokens': token} - else: - obs_embeddings_or_act_tokens = {'obs_embeddings': token} - - # outputs_wm = self.forward(obs_embeddings_or_act_tokens, past_keys_values=self.keys_values_wm_list, is_root=False, kvcache_independent=True) - outputs_wm = self.forward(obs_embeddings_or_act_tokens, past_keys_values=self.keys_values_wm, is_root=False, kvcache_independent=False) - # if k==0, action_token self.head_observations 1,...,0,1 - output_sequence.append(outputs_wm.output_sequence) - - if k == 0: - # if k==0, token is action_token outputs_wm.logits_rewards 是有值的 - reward = outputs_wm.logits_rewards # (B,) - - if k < self.num_observations_tokens: - # 一共产生16个obs_token,每次产生一个 - # TODO: sample or argmax - # token = Categorical(logits=outputs_wm.logits_observations).sample() - # Use argmax to select the most likely token - # token = outputs_wm.logits_observations.argmax(-1, keepdim=True) - token = outputs_wm.logits_observations - # if len(token.shape) != 2: - # token = token.squeeze(-1) # Ensure the token tensor shape is (B, 1) - if len(token.shape) != 3: - token = token.unsqueeze(1) # (8,1024) -> (8,1,1024) - latent_state.append(token) - - output_sequence = torch.cat(output_sequence, dim=1) # (B, 1 + K, E) - # Before updating self.latent_state, delete the old one to free memory - del self.latent_state - self.latent_state = torch.cat(latent_state, dim=1) # (B, K) - latent_state = self.latent_state - - # # TODO: 在计算结束后,是否需要更新最新的缓存. 是否需要deepcopy - # for i in range(latent_state.shape[0]): # 遍历每个环境 - # state_single_env = latent_state[i] # 获取单个环境的 latent state - # cache_key = hash(state_single_env.detach().cpu().numpy()) # 计算哈希值 - # # 复制单个环境对应的 keys_values_wm 并存储 - # # print([self.keys_values_wm_list[i].size for i in range(8)]) - # self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_list[i], 'cpu')) - - # del keys_values_wm_single_env - for i in range(latent_state.size(0)): # 遍历每个环境 - state_single_env = latent_state[i] # 获取单个环境的 latent state - cache_key = hash(state_single_env.detach().cpu().numpy()) # 计算哈希值 - # 复制单个环境对应的 keys_values_wm 并存储 - keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) - for layer in range(self.num_layers): - keys_values_wm_single_env._keys_values[layer]._k_cache._cache = self.keys_values_wm._keys_values[layer]._k_cache._cache[i].unsqueeze(0) # shape torch.Size([2, 100, 512]) - keys_values_wm_single_env._keys_values[layer]._v_cache._cache = self.keys_values_wm._keys_values[layer]._v_cache._cache[i].unsqueeze(0) - keys_values_wm_single_env._keys_values[layer]._k_cache._size = self.keys_values_wm._keys_values[layer]._k_cache._size - keys_values_wm_single_env._keys_values[layer]._v_cache._size = self.keys_values_wm._keys_values[layer]._v_cache._size - # keys_values_wm_single_env[layer].update(self.keys_values_wm[layer]._k_cache._cache[i].unsqueeze(0), self.keys_values_wm[layer]._v_cache._cache[i].unsqueeze(0)) - self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(keys_values_wm_single_env, 'cpu')) - del keys_values_wm_single_env - - # outputs_wm.logits_policy, outputs_wm.logits_value - if len(self.past_keys_values_cache) > self.max_cache_size: - # TODO: lru_cache - _, popped_kv_cache = self.past_keys_values_cache.popitem(last=False) - del popped_kv_cache # 不要这一行 - - # Example usage: - # Assuming `past_keys_values_cache` is a populated instance of 
`KeysValues` - # and `num_layers` is the number of transformer layers - # cuda_memory_gb = self.calculate_cuda_memory_gb(self.past_keys_values_cache, num_layers=2) - # print(f'len(self.past_keys_values_cache): {len(self.past_keys_values_cache)}, Memory used by past_keys_values_cache: {cuda_memory_gb:.2f} GB') - - return outputs_wm.output_sequence, self.latent_state, reward, outputs_wm.logits_policy, outputs_wm.logits_value - - - def to_device_for_kvcache(self, keys_values: KeysValues, device: str) -> KeysValues: - """ - Transfer all KVCache objects within the KeysValues object to a certain device. - - Arguments: - - keys_values (KeysValues): The KeysValues object to be transferred. - - device (str): The device to transfer to. - - Returns: - - keys_values (KeysValues): The KeysValues object with its caches transferred to the specified device. - """ - # Check if CUDA is available and select the first available CUDA device - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - - for kv_cache in keys_values: - kv_cache._k_cache._cache = kv_cache._k_cache._cache.to(device) - kv_cache._v_cache._cache = kv_cache._v_cache._cache.to(device) - return keys_values - - - # 计算显存使用量的函数 - def calculate_cuda_memory_gb(self, past_keys_values_cache, num_layers: int): - total_memory_bytes = 0 - - # 遍历OrderedDict中所有的KeysValues实例 - for kv_instance in past_keys_values_cache.values(): - num_layers = len(kv_instance) # 获取层数 - for layer in range(num_layers): - kv_cache = kv_instance[layer] - k_shape = kv_cache._k_cache.shape # 获取keys缓存的形状 - v_shape = kv_cache._v_cache.shape # 获取values缓存的形状 - - # 计算元素个数并乘以每个元素的字节数 - k_memory = torch.prod(torch.tensor(k_shape)) * 4 - v_memory = torch.prod(torch.tensor(v_shape)) * 4 - - # 累加keys和values缓存的内存 - layer_memory = k_memory + v_memory - total_memory_bytes += layer_memory.item() # .item()确保转换为Python标准数字 - - # 将总内存从字节转换为吉字节 - total_memory_gb = total_memory_bytes / (1024 ** 3) - return total_memory_gb - - # @profile - def compute_loss(self, batch, target_tokenizer: Tokenizer=None, **kwargs: Any) -> LossWithIntermediateLosses: - # NOTE: 这里是需要梯度的 - #with torch.no_grad(): # TODO: 非常重要 - obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False) # (B, C, H, W) -> (B, K, E) - obs_embeddings.register_hook(lambda grad: grad * 1/5) # TODO:只提供重建损失更新表征网络 - # obs_embeddings.register_hook(lambda grad: grad * 1) # TODO:只提供重建损失更新表征网络 - - # Assume that 'cont_embeddings' and 'original_images' are available from prior code - # Decode the embeddings to reconstruct the images - reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) - # Calculate the reconstruction loss - # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 4, 64, 64), reconstructed_images) # TODO: for stack=4 - # perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) # for stack=4 gray obs - - latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # TODO: for stack=1 - perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # TODO: for stack=1 - - # latent_recon_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) - # perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) - - latent_kl_loss = torch.tensor(0., device=batch['observations'].device, 
dtype=batch['observations'].dtype) - - - act_tokens = rearrange(batch['actions'], 'b l -> b l 1') - - # TODO: 是否只用重建损失更新表征网络 非常重要 - outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) - # outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens)}, is_root=False) - - with torch.no_grad(): - traget_obs_embeddings = target_tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False) # (B, C, H, W) -> (B, K, E) - - labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(traget_obs_embeddings, batch['rewards'], - batch['ends'], - batch['mask_padding']) - # labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(obs_embeddings, batch['rewards'], - # batch['ends'], - # batch['mask_padding']) - - logits_observations = rearrange(outputs.logits_observations[:, :-1], 'b t o -> (b t) o') - # labels_observations = labels_observations.contiguous().view(-1, self.projection_input_dim) # TODO: - labels_observations = labels_observations.reshape(-1, self.projection_input_dim) # TODO: - - - loss_obs = torch.nn.functional.mse_loss(logits_observations, labels_observations.detach(), reduction='none').mean(-1) - mask_padding_expanded = batch['mask_padding'][:, 1:].contiguous().view(-1) # TODO: - # mask_padding_expanded = batch['mask_padding'][:, 1:].reshape(-1) - - # 应用mask到loss_obs - loss_obs = (loss_obs * mask_padding_expanded).mean(-1) - labels_policy, labels_value = self.compute_labels_world_model_value_policy(batch['target_value'], - batch['target_policy'], - batch['mask_padding']) - - loss_rewards = self.compute_cross_entropy_loss(outputs, labels_rewards, batch, element='rewards') - loss_policy = self.compute_cross_entropy_loss(outputs, labels_policy, batch, element='policy') - loss_value = self.compute_cross_entropy_loss(outputs, labels_value, batch, element='value') - - return LossWithIntermediateLosses(latent_recon_loss_weight=self.latent_recon_loss_weight, perceptual_loss_weight=self.perceptual_loss_weight, loss_obs=loss_obs, loss_rewards=loss_rewards, loss_value=loss_value, - loss_policy=loss_policy, latent_kl_loss=latent_kl_loss, latent_recon_loss=latent_recon_loss, perceptual_loss=perceptual_loss) - - def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): - # Assume outputs.logits_rewards and labels are your predictions and targets - # And mask_padding is a boolean tensor with True at positions to keep and False at positions to ignore - - if element == 'rewards': - logits = outputs.logits_rewards - elif element == 'policy': - logits = outputs.logits_policy - elif element == 'value': - logits = outputs.logits_value - - # Reshape your tensors - logits_rewards = rearrange(logits, 'b t e -> (b t) e') - labels = labels.reshape(-1, labels.shape[-1]) # Assuming labels originally has shape [b, t, reward_dim] - - # Reshape your mask. True means valid data. 
- mask_padding = rearrange(batch['mask_padding'], 'b t -> (b t)') - - loss_rewards = -(torch.log_softmax(logits_rewards, dim=1) * labels).sum(1) - # loss_rewards = (loss_rewards * mask_padding.squeeze(-1).float()).mean() - loss_rewards = (loss_rewards * mask_padding).mean() - - - return loss_rewards - - def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, - mask_padding: torch.BoolTensor) -> Tuple[ - torch.Tensor, torch.Tensor, torch.Tensor]: - assert torch.all(ends.sum(dim=1) <= 1) # each sequence sample has at most 1 done - mask_fill = torch.logical_not(mask_padding) - labels_observations = obs_embeddings.contiguous().view(rewards.shape[0], -1, self.projection_input_dim)[:, 1:] # self.projection_input_dim - - - # labels_rewards = (rewards.sign() + 1).masked_fill(mask_fill, -100).long() # Rewards clipped to {-1, 0, 1} TODO(pu) - - mask_fill_rewards = mask_fill.unsqueeze(-1).expand_as(rewards) - labels_rewards = rewards.masked_fill(mask_fill_rewards, -100) - - labels_ends = ends.masked_fill(mask_fill, -100) - return labels_observations, labels_rewards.reshape(-1, self.support_size), labels_ends.reshape(-1) - - def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, target_policy: torch.Tensor, - mask_padding: torch.BoolTensor) -> Tuple[ - torch.Tensor, torch.Tensor]: - - mask_fill = torch.logical_not(mask_padding) - mask_fill_policy = mask_fill.unsqueeze(-1).expand_as(target_policy) - labels_policy = target_policy.masked_fill(mask_fill_policy, -100) - - mask_fill_value = mask_fill.unsqueeze(-1).expand_as(target_value) - labels_value = target_value.masked_fill(mask_fill_value, -100) - return labels_policy.reshape(-1, self.action_shape), labels_value.reshape(-1, self.support_size) # TODO(pu) - - - def negative_cosine_similarity(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: - """ - Overview: - consistency loss function: the negative cosine similarity. - Arguments: - - x1 (:obj:`torch.Tensor`): shape (batch_size, dim), e.g. (256, 512) - - x2 (:obj:`torch.Tensor`): shape (batch_size, dim), e.g. (256, 512) - Returns: - (x1 * x2).sum(dim=1) is the cosine similarity between vector x1 and x2. - The cosine similarity always belongs to the interval [-1, 1]. - For example, two proportional vectors have a cosine similarity of 1, - two orthogonal vectors have a similarity of 0, - and two opposite vectors have a similarity of -1. - -(x1 * x2).sum(dim=1) is consistency loss, i.e. the negative cosine similarity. 
- Reference: - https://en.wikipedia.org/wiki/Cosine_similarity - """ - x1 = F.normalize(x1, p=2., dim=-1, eps=1e-5) - x2 = F.normalize(x2, p=2., dim=-1, eps=1e-5) - return -(x1 * x2).sum(dim=1) - - - def render_img(self, obs: int, rec_img: int): - import torch - from PIL import Image - import matplotlib.pyplot as plt - - # 假设batch是一个字典,其中包含了observations键, - # 并且它的形状是torch.Size([B, N, C, H, W]) - # batch_observations = batch_for_gpt['observations'] - # batch_observations = batch['observations'] - batch_observations = obs.unsqueeze(0) - # batch_observations = rec_img.unsqueeze(0) - - # batch_observations = observations.unsqueeze(0) - # batch_observations = x.unsqueeze(0) - # batch_observations = reconstructions.unsqueeze(0) - - - - B, N, C, H, W = batch_observations.shape # 自动检测维度 - - # 分隔条的宽度(可以根据需要调整) - separator_width = 2 - - # 遍历每个样本 - for i in range(B): - # 提取当前样本中的所有帧 - frames = batch_observations[i] - - # 计算拼接图像的总宽度(包括分隔条) - total_width = N * W + (N - 1) * separator_width - - # 创建一个新的图像,其中包含分隔条 - concat_image = Image.new('RGB', (total_width, H), color='black') - - # 拼接每一帧及分隔条 - for j in range(N): - frame = frames[j].permute(1, 2, 0).cpu().numpy() # 转换为(H, W, C) - frame_image = Image.fromarray((frame * 255).astype('uint8'), 'RGB') - - # 计算当前帧在拼接图像中的位置 - x_position = j * (W + separator_width) - concat_image.paste(frame_image, (x_position, 0)) - - # 显示图像 - plt.imshow(concat_image) - plt.title(f'Sample {i+1}') - plt.axis('off') # 关闭坐标轴显示 - plt.show() - - # 保存图像到文件 - concat_image.save(f'sample_{i+1}.png') \ No newline at end of file diff --git a/lzero/model/gpt_models/world_model_batch_pad_max.py b/lzero/model/gpt_models/world_model_batch_pad_max.py deleted file mode 100644 index 7b6933572..000000000 --- a/lzero/model/gpt_models/world_model_batch_pad_max.py +++ /dev/null @@ -1,992 +0,0 @@ -import copy -from dataclasses import dataclass -import random -from typing import Any, Optional, Tuple -from typing import List, Optional, Union -import logging -# 设置日志记录级别为DEBUG -logging.getLogger().setLevel(logging.DEBUG) -from PIL import Image -from einops import rearrange -from einops import rearrange -import gym -from joblib import hash -import numpy as np -import torch -import torch -from torch.distributions.categorical import Categorical -import torch.nn as nn -import torch.nn.functional as F -import torchvision - -from .kv_caching import KeysValues -from .slicer import Embedder, Head, ActEmbedder -from .tokenizer import Tokenizer -from .transformer import Transformer, TransformerConfig -from .utils import LossWithIntermediateLosses, init_weights -from ding.torch_utils import to_device -# from memory_profiler import profile -from line_profiler import line_profiler - -@dataclass -class WorldModelOutput: - output_sequence: torch.FloatTensor - logits_observations: torch.FloatTensor - logits_rewards: torch.FloatTensor - logits_ends: torch.FloatTensor - - logits_policy: torch.FloatTensor - logits_value: torch.FloatTensor - - -class WorldModel(nn.Module): - def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: TransformerConfig, tokenizer, representation_network=None) -> None: - super().__init__() - - # config.max_tokens = int(2*50) # TODO - - self.tokenizer = tokenizer - self.obs_vocab_size, self.act_vocab_size = obs_vocab_size, act_vocab_size - self.config = config - - self.transformer = Transformer(config) - # self.num_observations_tokens = 16 - self.num_observations_tokens = config.tokens_per_block -1 - - self.latent_recon_loss_weight = config.latent_recon_loss_weight - 
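# Worked illustration (assuming tokens_per_block == 17, i.e. 16 observation
# tokens followed by 1 action token per block) of the block-level patterns
# constructed further below in this constructor; each pattern marks the slots
# within a block that a given head reads from or predicts:
#   act_tokens_pattern                 -> [0]*16 + [1]      (the action token)
#   latent_state_pattern               -> [1]*16 + [0]      (the observation tokens)
#   all_but_last_latent_state_pattern  -> [1]*15 + [0, 1]   (skip the last observation token)
#   value_policy_tokens_pattern        -> [0]*15 + [1, 0]   (the last observation token)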
self.perceptual_loss_weight = config.perceptual_loss_weight - - self.device = config.device - self.support_size = config.support_size - self.action_shape = config.action_shape - self.max_cache_size = config.max_cache_size - self.env_num = config.env_num - self.num_layers = config.num_layers - - - all_but_last_latent_state_pattern = torch.ones(config.tokens_per_block) - all_but_last_latent_state_pattern[-2] = 0 # 1,...,0,1 - - act_tokens_pattern = torch.zeros(self.config.tokens_per_block) # 17 - act_tokens_pattern[-1] = 1 # 0,...,0,1 - latent_state_pattern = 1 - act_tokens_pattern # 1,...,1,0 - - # current latent state's policy value - value_policy_tokens_pattern = torch.zeros(config.tokens_per_block) - value_policy_tokens_pattern[-2] = 1 # [0,...,1,0] - - # next latent state's policy value - # value_policy_tokens_pattern = torch.zeros(config.tokens_per_block) - # value_policy_tokens_pattern[-1] = 1 # [0,...,0,1] - - obs_per_embdding_dim=config.embed_dim - - self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim) - - - self.embedder = Embedder( - max_blocks=config.max_blocks, - block_masks=[act_tokens_pattern, latent_state_pattern], - embedding_tables=nn.ModuleList([nn.Embedding(act_vocab_size, config.embed_dim), nn.Embedding(obs_vocab_size, config.embed_dim)]) - ) - - # self.act_embedder = ActEmbedder( - # max_blocks=config.max_blocks, - # block_masks=[act_tokens_pattern], - # embedding_tables=nn.ModuleList([nn.Embedding(act_vocab_size, config.embed_dim)]) - # ) - - self.act_embedding_table = nn.Embedding(act_vocab_size, config.embed_dim) - - # self.head_observations = Head( - # max_blocks=config.max_blocks, - # block_mask=all_but_last_latent_state_pattern, # 1,...,0,1 # https://github.com/eloialonso/iris/issues/19 - # head_module=nn.Sequential( - # nn.Linear(config.embed_dim, config.embed_dim), - # nn.ReLU(), - # nn.Linear(config.embed_dim, obs_vocab_size) - # ) - # ) - self.obs_per_embdding_dim = config.embed_dim # 16*64=1024 - self.head_observations = Head( # TODO - max_blocks=config.max_blocks, - block_mask=all_but_last_latent_state_pattern, # 1,...,0,1 # https://github.com/eloialonso/iris/issues/19 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - # nn.BatchNorm1d(config.embed_dim), - # nn.ReLU(), - # nn.Linear(config.embed_dim, obs_vocab_size) - nn.LeakyReLU(negative_slope=0.01), # TODO: 2 - nn.Linear(config.embed_dim, self.obs_per_embdding_dim), - # nn.Tanh(), # TODO - nn.Sigmoid(), # 这里添加Sigmoid函数 TODO - ) - ) - - self.head_rewards = Head( - max_blocks=config.max_blocks, - block_mask=act_tokens_pattern, # 0,...,0,1 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - nn.ReLU(), - nn.Linear(config.embed_dim, self.support_size) - ) - ) - - self.head_ends = Head( - max_blocks=config.max_blocks, - block_mask=act_tokens_pattern, # 0,...,0,1 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - nn.ReLU(), - nn.Linear(config.embed_dim, 2) - ) - ) - - self.head_observations_for_root = Head( # TODO - max_blocks=config.max_blocks, - block_mask=latent_state_pattern, # 1,...,1,0 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - nn.BatchNorm1d(config.embed_dim), - nn.ReLU(), - # nn.Linear(config.embed_dim, obs_vocab_size) - nn.Linear(config.embed_dim, obs_per_embdding_dim) - ) - ) - self.head_policy = Head( - max_blocks=config.max_blocks, - block_mask=value_policy_tokens_pattern, # TODO: value_policy_tokens_pattern # [0,...,1,0] - head_module=nn.Sequential( # (8, 5, 128) - # 
nn.BatchNorm1d(config.embed_dim), # TODO: 1 - nn.Linear(config.embed_dim, config.embed_dim), - # nn.ReLU(), - nn.LeakyReLU(negative_slope=0.01), # TODO: 2 - nn.Linear(config.embed_dim, self.action_shape) # TODO(pu); action shape - ) - ) - self.head_value = Head( - max_blocks=config.max_blocks, - block_mask=value_policy_tokens_pattern, - head_module=nn.Sequential( - # nn.BatchNorm1d(config.embed_dim), # TODO: 1 - nn.Linear(config.embed_dim, config.embed_dim), - # nn.ReLU(), - nn.LeakyReLU(negative_slope=0.01), # TODO: 2 - nn.Linear(config.embed_dim, self.support_size) # TODO(pu): action shape - ) - ) - - self.apply(init_weights) - - last_linear_layer_init_zero = True # TODO: is beneficial for convergence speed. - if last_linear_layer_init_zero: - # TODO: policy init : 3 - # Locate the last linear layer and initialize its weights and biases to 0. - # for _, layer in enumerate(reversed(self.head_policy.head_module)): - # if isinstance(layer, nn.Linear): - # nn.init.zeros_(layer.weight) - # nn.init.zeros_(layer.bias) - # break - for _, layer in enumerate(reversed(self.head_value.head_module)): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - nn.init.zeros_(layer.bias) - break - for _, layer in enumerate(reversed(self.head_rewards.head_module)): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - nn.init.zeros_(layer.bias) - break - for _, layer in enumerate(reversed(self.head_observations.head_module)): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - nn.init.zeros_(layer.bias) - break - - - import collections - self.past_keys_values_cache = collections.OrderedDict() - self.past_policy_value_cache = collections.OrderedDict() - - # TODO: Transformer更新后应该清除缓存 - # NOTE - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=self.env_num, max_tokens=self.config.max_tokens) - - self.keys_values_wm_list = [] - self.keys_values_wm_size_list = [] - - - if self.num_observations_tokens==16: # k=16 - self.projection_input_dim = 128 - elif self.num_observations_tokens==1: # K=1 - self.projection_input_dim = self.obs_per_embdding_dim# for atari #TODO - # self.projection_input_dim = 256 # for cartpole - - - self.proj_hid = 1024 - self.proj_out = 1024 - self.pred_hid = 512 - self.pred_out = 1024 - activation = nn.ReLU() - self.projection = nn.Sequential( - nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, - nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, - nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out) - ) - self.prediction_head = nn.Sequential( - nn.Linear(self.proj_out, self.pred_hid), - nn.BatchNorm1d(self.pred_hid), - activation, - nn.Linear(self.pred_hid, self.pred_out), - ) - self.hit_count = 0 - self.total_query_count = 0 - - - - def __repr__(self) -> str: - return "world_model" - - # @profile - def forward(self, obs_embeddings_or_act_tokens, past_keys_values: Optional[KeysValues] = None, - is_root=False, kvcache_independent=False) -> WorldModelOutput: - - if kvcache_independent: - prev_steps = 0 if past_keys_values is None else [past_kv.size for past_kv in past_keys_values] - prev_steps = torch.tensor(prev_steps, device=self.device) - # 我们需要为每个样本生成一个序列的步骤indices,然后获取它们的位置嵌入 - # 首先扩展prev_steps至(num_steps, batch_size),这里num_steps=1 - # prev_steps = prev_steps.unsqueeze(0) - - else: - prev_steps = 0 if past_keys_values is None else past_keys_values.size - # print(f'prev_steps:{prev_steps}') - - if 'obs_embeddings' in 
obs_embeddings_or_act_tokens.keys(): - obs_embeddings = obs_embeddings_or_act_tokens['obs_embeddings'] - if len(obs_embeddings.shape)==2: - obs_embeddings = obs_embeddings.unsqueeze(1) - num_steps = obs_embeddings.size(1) # (B, T, E) - if kvcache_independent: - # 生成每个样本的步骤indices - # steps_indices = prev_steps + torch.arange(num_steps, device=obs_embeddings.device).unsqueeze(1) - steps_indices = prev_steps + torch.arange(num_steps, device=obs_embeddings.device) - - # 步骤indices需要被reshape成一维,以便于embedding操作 - # steps_indices = steps_indices.view(-1) - # 获取位置嵌入 - position_embeddings = self.pos_emb(steps_indices) - # 由于我们要将它们加到obs_embeddings上,需要将位置嵌入reshape回(batch_size, num_steps, embedding_dim) - position_embeddings = position_embeddings.view(-1, num_steps, position_embeddings.shape[-1]) - # 现在我们可以将位置嵌入加到obs_embeddings上了 - sequences = obs_embeddings + position_embeddings - else: - sequences = obs_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=obs_embeddings.device)) - elif 'act_tokens' in obs_embeddings_or_act_tokens.keys(): - act_tokens = obs_embeddings_or_act_tokens['act_tokens'] - num_steps = act_tokens.size(1) # (B, T) - act_embeddings = self.act_embedding_table(act_tokens) - - if kvcache_independent: - # 生成每个样本的步骤indices - # steps_indices = prev_steps + torch.arange(num_steps, device=act_embeddings.device).unsqueeze(1) - steps_indices = prev_steps + torch.arange(num_steps, device=act_embeddings.device) - # 步骤indices需要被reshape成一维,以便于embedding操作 - # steps_indices = steps_indices.view(-1) - # 获取位置嵌入 - position_embeddings = self.pos_emb(steps_indices) - # 由于我们要将它们加到obs_embeddings上,需要将位置嵌入reshape回(batch_size, num_steps, embedding_dim) - position_embeddings = position_embeddings.view(-1, num_steps, position_embeddings.shape[-1]) - # 现在我们可以将位置嵌入加到obs_embeddings上了 - sequences = act_embeddings + position_embeddings - else: - sequences = act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=act_tokens.device)) - else: - obs_embeddings_and_act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens'] - # obs_embeddings: (B, L, K=16, E), act_tokens: (B, L, 1) - obs_embeddings, act_tokens = obs_embeddings_and_act_tokens - if len(obs_embeddings.shape)==3: # for batch compute loss - obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, -1) - - num_steps = int(obs_embeddings.size(1)*(obs_embeddings.size(2)+1)) # L(k+1) - # assert num_steps <= self.config.max_tokens - # Rearrange observation embeddings from (B, L, K, E) to (B, L*K, E) - # obs_embeddings = rearrange(obs_embeddings, 'b l k e -> b (l k) e') - - # Generate action embeddings from action tokens - # (B, L, 1) -> (B, L, 1, E) - act_embeddings = self.act_embedding_table(act_tokens) - - # 已知obs_embeddings的维度为 (B, L, K, E), act_embeddings的维度为(B, L, 1, E) 希望得到一个obs_act_embeddings向量的维度为 (B, L(K+1), E) - # 而且让得到的obs_act_embeddings的第2个维度的数据为:obs act, obs, act, ..., obs, act,即 L, 1, L,1 ... 
这样的排列顺序。请给出高效的实现,用中文回答 - - B, L, K, E = obs_embeddings.size() - # _, _, _, _ = act_embeddings.size() - - # 初始化一个新的空tensor,用于存放最终的拼接结果 - obs_act_embeddings = torch.empty(B, L * (K + 1), E, device=obs_embeddings.device) - - # 对每一个序列长度L进行循环 - for i in range(L): - # 获取当前时刻的obs和act embeddings - obs = obs_embeddings[:, i, :, :] # Shape: (B, K, E) - act = act_embeddings[:, i, 0, :].unsqueeze(1) # Shape: (B, 1, E), 补充维度以便拼接 - - # 交替拼接obs和act - obs_act = torch.cat([obs, act], dim=1) # Shape: (B, K + 1, E) - - # 将结果填充到最终的tensor中 - obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act - - # 确保形状正确无误 - # assert obs_act_embeddings.shape == (B, L * (K + 1), E) - - # Add positional embeddings - sequences = obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=obs_embeddings.device)) - - - # print('transformer forward begin') 函数里面更新了update past_keys_values - if kvcache_independent: - x = [] - for k, past_kv in enumerate(past_keys_values): - x.append(self.transformer(sequences[k].unsqueeze(0), past_kv)) - # self.keys_values_wm_list[k] = past_kv # NOTE: todo - x = torch.cat(x, dim=0) - - # TODO: 在collect时,是一步一步的 obs act 传入的 - # prev_steps = prev_steps//1 - - else: - x = self.transformer(sequences, past_keys_values) - - # print('transformer forward done') - - - if is_root: - logits_observations = self.head_observations_for_root(x, num_steps=num_steps, prev_steps=prev_steps) - else: - # 1,...,0,1 https://github.com/eloialonso/iris/issues/19 - logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps) - - logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps) - logits_ends = self.head_ends(x, num_steps=num_steps, prev_steps=prev_steps) - # return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends) - - logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps) - logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps) - - # TODO: root reward value - return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends, logits_policy, logits_value) - - @torch.no_grad() - def render_batch(self) -> List[Image.Image]: - frames = self.decode_latent_state().detach().cpu() - frames = rearrange(frames, 'b c h w -> b h w c').mul(255).numpy().astype(np.uint8) - return [Image.fromarray(frame) for frame in frames] - - # only foe inference now, now is invalid - @torch.no_grad() - def decode_latent_state(self) -> List[Image.Image]: - embedded_tokens = self.tokenizer.embedding(self.latent_state) # (B, K, E) - z = rearrange(embedded_tokens, 'b (h w) e -> b e h w', h=int(np.sqrt(self.num_observations_tokens))) - rec = self.tokenizer.decode(z, should_postprocess=True) # (B, C, H, W) - # TODO: for atari image - return torch.clamp(rec, 0, 1) - # for cartpole obs - # return rec - - - @torch.no_grad() - def render(self): - assert self.latent_state.shape == (1, self.num_observations_tokens) - return self.render_batch()[0] - - @torch.no_grad() - def reset(self) -> torch.FloatTensor: - assert self.env is not None - obs = torchvision.transforms.functional.to_tensor(self.env.reset()).to(self.device).unsqueeze( - 0) # (1, C, H, W) in [0., 1.] 
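# For reference, a standalone sketch (dummy frame, not taken from self.env) of
# the conversion relied on above: to_tensor maps an HWC uint8 array in
# [0, 255] to a CHW float tensor in [0., 1.], and unsqueeze(0) then adds the
# batch dimension.
demo_frame = np.zeros((64, 64, 3), dtype=np.uint8)                  # (H, W, C), uint8 in [0, 255]
demo_obs = torchvision.transforms.functional.to_tensor(demo_frame)  # (3, 64, 64), float in [0., 1.]
demo_obs = demo_obs.unsqueeze(0)                                    # (1, 3, 64, 64)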
- return self.reset_from_initial_observations(obs) - - - @torch.no_grad() - # @profile - def reset_from_initial_observations_v2(self, obs_act_dict: torch.FloatTensor) -> torch.FloatTensor: - if isinstance(obs_act_dict, dict): - observations = obs_act_dict['obs'] - buffer_action = obs_act_dict['action'] - else: - observations = obs_act_dict - buffer_action = None - obs_embeddings = self.tokenizer.encode_to_obs_embeddings(observations, should_preprocess=True) # (B, C, H, W) -> (B, K, E) - - outputs_wm = self.refresh_keys_values_with_initial_latent_state_for_init_infer_v2(obs_embeddings, buffer_action) - self.latent_state = obs_embeddings - - return outputs_wm, self.latent_state - - @torch.no_grad() - # @profile - def refresh_keys_values_with_initial_latent_state_for_init_infer_v2(self, latent_state: torch.LongTensor, buffer_action=None) -> torch.FloatTensor: - n, num_observations_tokens, _ = latent_state.shape - if n <= self.env_num: - # MCTS root节点: 需要准确的估计 value, policy_logits, 或许需要结合context的kv_cache进行更准确的估计,而不是当前的从0开始推理 - self.keys_values_wm = self.transformer.generate_empty_keys_values(n, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False, kvcache_independent=False) - self.keys_values_wm_size_list = [1 for i in range(n)] - - # 复制单个环境对应的 keys_values_wm 并存储 - self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) - for i in range(latent_state.size(0)): # 遍历每个环境 - state_single_env = latent_state[i] # 获取单个环境的 latent state - cache_key = hash(state_single_env.detach().cpu().numpy()) # 计算哈希值 - for layer in range(self.num_layers): - self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = self.keys_values_wm._keys_values[layer]._k_cache._cache[i].unsqueeze(0) # shape torch.Size([2, 100, 512]) - self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = self.keys_values_wm._keys_values[layer]._v_cache._cache[i].unsqueeze(0) - self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = self.keys_values_wm._keys_values[layer]._k_cache._size - self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = self.keys_values_wm._keys_values[layer]._v_cache._size - # keys_values_wm_single_env[layer].update(self.keys_values_wm[layer]._k_cache._cache[i].unsqueeze(0), self.keys_values_wm[layer]._v_cache._cache[i].unsqueeze(0)) - self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu')) - # del keys_values_wm_single_env - - elif n == int(256): - # TODO: n=256 means train tokenizer, 不需要计算target value - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - # print('init inference: not find matched_value! 
reset!') - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) - elif n > self.env_num and n != int(256) and buffer_action is not None: - # train时计算target value - # TODO: n=256 means train tokenizer - # [192, 16, 64] -> [32, 6, 16, 64] - latent_state = latent_state.contiguous().view(buffer_action.shape[0], -1, num_observations_tokens, self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 - - # latent_state = latent_state.view(-1, self.config.max_blocks+1, num_observations_tokens) # (BL, K) - latent_state = latent_state[:, :-1, :] - # latent_state = latent_state.reshape(32*6, num_observations_tokens) # (BL, K) - buffer_action = torch.from_numpy(buffer_action).to(latent_state.device) - act_tokens = rearrange(buffer_action, 'b l -> b l 1') - - # # 选择每个样本的最后一步 - last_steps = act_tokens[:, -1:, :] # 这将选择最后一列并保持维度不变, 最后一步的target policy/value本身就没有用到 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - act_tokens = torch.cat((act_tokens, last_steps), dim=1) - - # print('init inference: unroll 5 steps!') 17*6=102 17*5=85 - obs_embeddings = latent_state - outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) - - # 选择每个样本的最后一步 - last_steps_value = outputs_wm.logits_value[:, -1:, :] # 这将选择最后一列并保持维度不变 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1) - - last_steps_policy = outputs_wm.logits_policy[:, -1:, :] # 这将选择最后一列并保持维度不变 - outputs_wm.logits_policy = torch.cat((outputs_wm.logits_policy, last_steps_policy), dim=1) - - # Reshape your tensors - # outputs_wm.logits_value.shape (30,21) = (B*6, 21) - outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e') - outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e') - - return outputs_wm - - @torch.no_grad() - # @profile - def reset_from_initial_observations(self, obs_act_dict: torch.FloatTensor) -> torch.FloatTensor: - if isinstance(obs_act_dict, dict): - # obs_act_dict = {'obs':obs, 'action':action_batch} - observations = obs_act_dict['obs'] - buffer_action = obs_act_dict['action'] - else: - observations = obs_act_dict - buffer_action = None - - obs_embeddings = self.tokenizer.encode_to_obs_embeddings(observations, should_preprocess=True) # (B, C, H, W) -> (B, K, E) - outputs_wm = self.refresh_keys_values_with_initial_latent_state_for_init_infer(obs_embeddings, buffer_action) - self.latent_state = obs_embeddings - - return outputs_wm, self.latent_state - - - @torch.no_grad() - # @profile - def refresh_keys_values_with_initial_latent_state_for_init_infer(self, latent_state: torch.LongTensor, buffer_action=None) -> torch.FloatTensor: - n, num_observations_tokens, _ = latent_state.shape - if n <= self.env_num: - # MCTS root节点: 需要准确的估计 value, policy_logits 或许需要结合context的kv_cache进行更准确的估计,而不是当前的从0开始推理 - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) # Note: is_root=False - elif n == int(256): - # TODO: n=256 means train tokenizer, 不需要计算target value - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - # print('init inference: not find matched_value! 
reset!') - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) # Note: is_root=False - elif n > self.env_num and n != int(256) and buffer_action is not None: - # TODO: n=256 means train tokenizer - # TODO: for n=32*6=192 means 通过unroll 5 steps,计算target value - # latent_state = latent_state.reshape(32, 6, num_observations_tokens) # (BL, K) - # latent_state = latent_state.view(-1, 6, num_observations_tokens) # (BL, K) - - # [192, 16] -> [32, 6, 16] - # latent_state = latent_state.view(buffer_action.shape[0], -1, num_observations_tokens) # (BL, K) for unroll_step=1 - - # [192, 16, 64] -> [32, 6, 16, 64] - latent_state = latent_state.contiguous().view(buffer_action.shape[0], -1, num_observations_tokens, self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 - - # latent_state = latent_state.view(-1, self.config.max_blocks+1, num_observations_tokens) # (BL, K) - latent_state = latent_state[:, :-1, :] - # latent_state = latent_state.reshape(32*6, num_observations_tokens) # (BL, K) - buffer_action = torch.from_numpy(buffer_action).to(latent_state.device) - act_tokens = rearrange(buffer_action, 'b l -> b l 1') - - # # 选择每个样本的最后一步 - last_steps = act_tokens[:, -1:, :] # 这将选择最后一列并保持维度不变, 最后一步的target policy/value本身就没有用到 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - act_tokens = torch.cat((act_tokens, last_steps), dim=1) - - # print('init inference: unroll 5 steps!') 17*6=102 17*5=85 - obs_embeddings = latent_state - outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) - - # 选择每个样本的最后一步 - last_steps_value = outputs_wm.logits_value[:, -1:, :] # 这将选择最后一列并保持维度不变 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1) - - last_steps_policy = outputs_wm.logits_policy[:, -1:, :] # 这将选择最后一列并保持维度不变 - outputs_wm.logits_policy = torch.cat((outputs_wm.logits_policy, last_steps_policy), dim=1) - - # Reshape your tensors - # outputs_wm.logits_value.shape (30,21) = (B*6, 21) - outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e') - outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e') - - - # return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends, logits_policy, logits_value) - return outputs_wm - - - @torch.no_grad() - # @profile - def refresh_keys_values_with_initial_latent_state(self, latent_state: torch.LongTensor, reset_indices=None) -> torch.FloatTensor: - n, num_observations_tokens, _ = latent_state.shape - assert num_observations_tokens == self.num_observations_tokens - # self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - if reset_indices is None: - self.keys_values_wm_list = [self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) for i in range(n)] - else: - for i in reset_indices: - self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': latent_state[i].unsqueeze(0).to(self.device)}, past_keys_values=self.keys_values_wm_single_env, is_root=False, kvcache_independent=False) - self.keys_values_wm_list[i] = self.keys_values_wm_single_env - self.keys_values_wm_size_list[i] = 1 - return None - - @torch.no_grad() - # @profile - def forward_initial_inference(self, obs_act_dict: torch.LongTensor, should_predict_next_obs: bool = True): - if 
isinstance(obs_act_dict, dict): - obs = obs_act_dict['obs'] - else: - obs = obs_act_dict - outputs_wm, latent_state = self.reset_from_initial_observations_v2(obs_act_dict) # TODO - # outputs_wm, latent_state = self.reset_from_initial_observations(obs_act_dict) # 从零开始 - - return outputs_wm.output_sequence, latent_state, outputs_wm.logits_rewards, outputs_wm.logits_policy, outputs_wm.logits_value - - """ - 假设env_num=8 - 8个环境的kv_cache单独存储与寻找,都存储在一个dict中,在recurrent_inference时, - 由于不同环境找到的kv_cache的size不同,先根据最大size对kv_cache在前部补零,然后组成batch_size的kv_cache - 其内部也是通过batch执行transformer forward的推理 - """ - - @torch.no_grad() - # @profile - def forward_recurrent_inference(self, state_action_history, should_predict_next_obs: bool = True): - # 一般来讲,在一次 MCTS search中,我们需要维护H长度的context来使用transformer进行推理。 - # 由于在一次search里面。agent最多访问sim个不同的节点,因此我们只需维护一个 {(state:kv_cache)}的列表。 - # 但如果假设环境是MDP的话,然后根据当前的 latest_state s_t 在这个列表中查找即可 - # TODO: 但如果假设环境是非MDP的话,需要维护一个 {(rootstate_action_history:kv_cache)}的列表? - - if self.total_query_count>0 and self.total_query_count%10000==0: - self.hit_freq = self.hit_count/(self.total_query_count) - print('hit_freq:', self.hit_freq) - print('hit_count:', self.hit_count) - print('total_query_count:', self.total_query_count) - print(self.keys_values_wm_size_list) - - latest_state = state_action_history[-1][0] - - # 假设 latest_state 是新的 latent_state,包含 ready_env_num 个环境的信息 - ready_env_num = latest_state.shape[0] - keys_values_wm_list = [] - self.keys_values_wm_size_list = [] - for i in range(ready_env_num): - self.total_query_count += 1 - state_single_env = latest_state[i] # 获取单个环境的 latent state - hash_latest_state = hash(state_single_env) # 计算哈希值 - matched_value = self.past_keys_values_cache.get(hash_latest_state) # 检索缓存值 - if matched_value is not None: - self.hit_count += 1 - # 如果找到匹配的值,将其添加到列表中 - keys_values_wm_list.append(copy.deepcopy(self.to_device_for_kvcache(matched_value, 'cuda'))) - self.keys_values_wm_size_list.append(matched_value.size) - else: - # use zero - self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': torch.from_numpy(state_single_env).unsqueeze(0).to(self.device)}, past_keys_values=self.keys_values_wm_single_env, is_root=False) - keys_values_wm_list.append(self.keys_values_wm_single_env) - self.keys_values_wm_size_list.append(1) - - self.keys_values_wm_list = keys_values_wm_list - num_passes = 1 + self.num_observations_tokens if should_predict_next_obs else 1 - output_sequence, latent_state = [], [] - - # print(self.keys_values_wm_size_list) - reset_indices = [index for index, value in enumerate(self.keys_values_wm_size_list) if value + num_passes > self.config.max_tokens] - self.refresh_keys_values_with_initial_latent_state(torch.tensor(latest_state, dtype=torch.float32).to(self.device), reset_indices) - - action = state_action_history[-1][-1] - token = action.clone().detach() if isinstance(action, torch.Tensor) else torch.tensor(action, dtype=torch.long) - token = token.reshape(-1, 1).to(self.device) # (B, 1) - - # 将每个环境独立的kv_cache拼接为一个batch_size为env_num的总的kv_cache - # 计算最大值 - max_size = max(self.keys_values_wm_size_list) - for layer in range(self.num_layers): - # 每层的k和v缓存列表 - kv_cache_k_list = [] - kv_cache_v_list = [] - for idx, keys_values in enumerate(self.keys_values_wm_list): - # 获取当前层的k和v缓存 - k_cache = keys_values[layer]._k_cache._cache - v_cache = keys_values[layer]._v_cache._cache - - # 计算有效大小和需要填充的零的大小 - effective_size = 
self.keys_values_wm_size_list[idx] - padding_size = max_size - effective_size - # 如果需要填充,则在第三维之前填充零 - if padding_size > 0: - # 使用F.pad进行填充,pad的四个数字分别代表在最后两维前后填充的大小 - k_cache_padded = F.pad(k_cache, (0, 0, padding_size, 0), "constant", 0)[:,:,:self.config.max_tokens,:] - v_cache_padded = F.pad(v_cache, (0, 0, padding_size, 0), "constant", 0)[:,:,:self.config.max_tokens,:] - else: - k_cache_padded = k_cache - v_cache_padded = v_cache - # 将处理过的缓存添加到列表中 - kv_cache_k_list.append(k_cache_padded) - kv_cache_v_list.append(v_cache_padded) - - # 使用torch.stack()将列表中的缓存堆叠起来,形成新的缓存 - self.keys_values_wm._keys_values[layer]._k_cache._cache = torch.stack(kv_cache_k_list, dim=0).squeeze(1) - self.keys_values_wm._keys_values[layer]._v_cache._cache = torch.stack(kv_cache_v_list, dim=0).squeeze(1) - # 更新缓存的尺寸 - self.keys_values_wm._keys_values[layer]._k_cache._size = max_size - self.keys_values_wm._keys_values[layer]._v_cache._size = max_size - - for k in range(num_passes): # assumption that there is only one action token. - # action_token obs_token, ..., obs_token 1+1 - if k==0: - obs_embeddings_or_act_tokens = {'act_tokens': token} - else: - obs_embeddings_or_act_tokens = {'obs_embeddings': token} - - # outputs_wm = self.forward(obs_embeddings_or_act_tokens, past_keys_values=self.keys_values_wm_list, is_root=False, kvcache_independent=True) - outputs_wm = self.forward(obs_embeddings_or_act_tokens, past_keys_values=self.keys_values_wm, is_root=False, kvcache_independent=False) - # if k==0, action_token self.head_observations 1,...,0,1 - output_sequence.append(outputs_wm.output_sequence) - - if k == 0: - # if k==0, token is action_token outputs_wm.logits_rewards 是有值的 - reward = outputs_wm.logits_rewards # (B,) - - if k < self.num_observations_tokens: - # 一共产生16个obs_token,每次产生一个 - # TODO: sample or argmax - # token = Categorical(logits=outputs_wm.logits_observations).sample() - # Use argmax to select the most likely token - # token = outputs_wm.logits_observations.argmax(-1, keepdim=True) - token = outputs_wm.logits_observations - # if len(token.shape) != 2: - # token = token.squeeze(-1) # Ensure the token tensor shape is (B, 1) - if len(token.shape) != 3: - token = token.unsqueeze(1) # (8,1024) -> (8,1,1024) - latent_state.append(token) - - output_sequence = torch.cat(output_sequence, dim=1) # (B, 1 + K, E) - # Before updating self.latent_state, delete the old one to free memory - # del self.latent_state - self.latent_state = torch.cat(latent_state, dim=1) # (B, K) - - # # TODO: 在计算结束后,是否需要更新最新的缓存. 
是否需要deepcopy - # for i in range(latent_state.shape[0]): # 遍历每个环境 - # state_single_env = latent_state[i] # 获取单个环境的 latent state - # cache_key = hash(state_single_env.detach().cpu().numpy()) # 计算哈希值 - # # 复制单个环境对应的 keys_values_wm 并存储 - # # print([self.keys_values_wm_list[i].size for i in range(8)]) - # self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_list[i], 'cpu')) - - # keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) - for i in range(self.latent_state.size(0)): # 遍历每个环境 - state_single_env = self.latent_state[i] # 获取单个环境的 latent state - cache_key = hash(state_single_env.detach().cpu().numpy()) # 计算哈希值 - # 复制单个环境对应的 keys_values_wm 并存储 - for layer in range(self.num_layers): - self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = self.keys_values_wm._keys_values[layer]._k_cache._cache[i].unsqueeze(0) # shape torch.Size([2, 100, 512]) - self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = self.keys_values_wm._keys_values[layer]._v_cache._cache[i].unsqueeze(0) - self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = self.keys_values_wm._keys_values[layer]._k_cache._size - self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = self.keys_values_wm._keys_values[layer]._v_cache._size - # keys_values_wm_single_env[layer].update(self.keys_values_wm[layer]._k_cache._cache[i].unsqueeze(0), self.keys_values_wm[layer]._v_cache._cache[i].unsqueeze(0)) - self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu')) - # del keys_values_wm_single_env - - # outputs_wm.logits_policy, outputs_wm.logits_value - if len(self.past_keys_values_cache) > self.max_cache_size: - # TODO: lru_cache - _, popped_kv_cache = self.past_keys_values_cache.popitem(last=False) - del popped_kv_cache # 不要这一行 - - # Example usage: - # Assuming `past_keys_values_cache` is a populated instance of `KeysValues` - # and `num_layers` is the number of transformer layers - # cuda_memory_gb = self.calculate_cuda_memory_gb(self.past_keys_values_cache, num_layers=2) - # print(f'len(self.past_keys_values_cache): {len(self.past_keys_values_cache)}, Memory used by past_keys_values_cache: {cuda_memory_gb:.2f} GB') - - return outputs_wm.output_sequence, self.latent_state, reward, outputs_wm.logits_policy, outputs_wm.logits_value - - - def to_device_for_kvcache(self, keys_values: KeysValues, device: str) -> KeysValues: - """ - Transfer all KVCache objects within the KeysValues object to a certain device. - - Arguments: - - keys_values (KeysValues): The KeysValues object to be transferred. - - device (str): The device to transfer to. - - Returns: - - keys_values (KeysValues): The KeysValues object with its caches transferred to the specified device. 
- """ - # Check if CUDA is available and select the first available CUDA device - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - - for kv_cache in keys_values: - kv_cache._k_cache._cache = kv_cache._k_cache._cache.to(device) - kv_cache._v_cache._cache = kv_cache._v_cache._cache.to(device) - return keys_values - - - # 计算显存使用量的函数 - def calculate_cuda_memory_gb(self, past_keys_values_cache, num_layers: int): - total_memory_bytes = 0 - - # 遍历OrderedDict中所有的KeysValues实例 - for kv_instance in past_keys_values_cache.values(): - num_layers = len(kv_instance) # 获取层数 - for layer in range(num_layers): - kv_cache = kv_instance[layer] - k_shape = kv_cache._k_cache.shape # 获取keys缓存的形状 - v_shape = kv_cache._v_cache.shape # 获取values缓存的形状 - - # 计算元素个数并乘以每个元素的字节数 - k_memory = torch.prod(torch.tensor(k_shape)) * 4 - v_memory = torch.prod(torch.tensor(v_shape)) * 4 - - # 累加keys和values缓存的内存 - layer_memory = k_memory + v_memory - total_memory_bytes += layer_memory.item() # .item()确保转换为Python标准数字 - - # 将总内存从字节转换为吉字节 - total_memory_gb = total_memory_bytes / (1024 ** 3) - return total_memory_gb - - # @profile - def compute_loss(self, batch, target_tokenizer: Tokenizer=None, **kwargs: Any) -> LossWithIntermediateLosses: - # NOTE: 这里是需要梯度的 - #with torch.no_grad(): # TODO: 非常重要 - obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False) # (B, C, H, W) -> (B, K, E) - obs_embeddings.register_hook(lambda grad: grad * 1/5) # TODO:只提供重建损失更新表征网络 - # obs_embeddings.register_hook(lambda grad: grad * 1) # TODO:只提供重建损失更新表征网络 - - # Assume that 'cont_embeddings' and 'original_images' are available from prior code - # Decode the embeddings to reconstruct the images - reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) - # Calculate the reconstruction loss - # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 4, 64, 64), reconstructed_images) # TODO: for stack=4 - # perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) # for stack=4 gray obs - - latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # TODO: for stack=1 - perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # TODO: for stack=1 - - # latent_recon_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) - # perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) - - latent_kl_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) - - - act_tokens = rearrange(batch['actions'], 'b l -> b l 1') - - # TODO: 是否只用重建损失更新表征网络 非常重要 - outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) - # outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens)}, is_root=False) - - with torch.no_grad(): - traget_obs_embeddings = target_tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False) # (B, C, H, W) -> (B, K, E) - - labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(traget_obs_embeddings, batch['rewards'], - batch['ends'], - batch['mask_padding']) - # labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(obs_embeddings, batch['rewards'], - # batch['ends'], - # 
batch['mask_padding']) - - logits_observations = rearrange(outputs.logits_observations[:, :-1], 'b t o -> (b t) o') - # labels_observations = labels_observations.contiguous().view(-1, self.projection_input_dim) # TODO: - labels_observations = labels_observations.reshape(-1, self.projection_input_dim) # TODO: - - - loss_obs = torch.nn.functional.mse_loss(logits_observations, labels_observations.detach(), reduction='none').mean(-1) - mask_padding_expanded = batch['mask_padding'][:, 1:].contiguous().view(-1) # TODO: - # mask_padding_expanded = batch['mask_padding'][:, 1:].reshape(-1) - - # 应用mask到loss_obs - loss_obs = (loss_obs * mask_padding_expanded).mean(-1) - labels_policy, labels_value = self.compute_labels_world_model_value_policy(batch['target_value'], - batch['target_policy'], - batch['mask_padding']) - - loss_rewards = self.compute_cross_entropy_loss(outputs, labels_rewards, batch, element='rewards') - loss_policy = self.compute_cross_entropy_loss(outputs, labels_policy, batch, element='policy') - loss_value = self.compute_cross_entropy_loss(outputs, labels_value, batch, element='value') - - return LossWithIntermediateLosses(latent_recon_loss_weight=self.latent_recon_loss_weight, perceptual_loss_weight=self.perceptual_loss_weight, loss_obs=loss_obs, loss_rewards=loss_rewards, loss_value=loss_value, - loss_policy=loss_policy, latent_kl_loss=latent_kl_loss, latent_recon_loss=latent_recon_loss, perceptual_loss=perceptual_loss) - - def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): - # Assume outputs.logits_rewards and labels are your predictions and targets - # And mask_padding is a boolean tensor with True at positions to keep and False at positions to ignore - - if element == 'rewards': - logits = outputs.logits_rewards - elif element == 'policy': - logits = outputs.logits_policy - elif element == 'value': - logits = outputs.logits_value - - # Reshape your tensors - logits_rewards = rearrange(logits, 'b t e -> (b t) e') - labels = labels.reshape(-1, labels.shape[-1]) # Assuming labels originally has shape [b, t, reward_dim] - - # Reshape your mask. True means valid data. 
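# Clarifying sketch (dummy tensors, illustrative only): since the labels here
# are full distributions rather than class indices, the -100 fill applied in
# compute_labels_world_model* does not behave like an ignore_index; it is the
# multiplication by mask_padding below that actually removes padded steps from
# the loss.
demo_logits = torch.randn(3, 21)                                          # (b*t, support_size)
demo_labels = torch.full((3, 21), -100.0)                                 # a fully padded target step
demo_ce = -(torch.log_softmax(demo_logits, dim=1) * demo_labels).sum(1)   # nonzero, not ignored by the -100 fill
demo_mask = torch.zeros(3)                                                # padding mask marks these steps invalid
assert float((demo_ce * demo_mask).sum()) == 0.0                          # zeroed out only by the mask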
- mask_padding = rearrange(batch['mask_padding'], 'b t -> (b t)') - - loss_rewards = -(torch.log_softmax(logits_rewards, dim=1) * labels).sum(1) - # loss_rewards = (loss_rewards * mask_padding.squeeze(-1).float()).mean() - loss_rewards = (loss_rewards * mask_padding).mean() - - - return loss_rewards - - def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, - mask_padding: torch.BoolTensor) -> Tuple[ - torch.Tensor, torch.Tensor, torch.Tensor]: - assert torch.all(ends.sum(dim=1) <= 1) # each sequence sample has at most 1 done - mask_fill = torch.logical_not(mask_padding) - labels_observations = obs_embeddings.contiguous().view(rewards.shape[0], -1, self.projection_input_dim)[:, 1:] # self.projection_input_dim - - - # labels_rewards = (rewards.sign() + 1).masked_fill(mask_fill, -100).long() # Rewards clipped to {-1, 0, 1} TODO(pu) - - mask_fill_rewards = mask_fill.unsqueeze(-1).expand_as(rewards) - labels_rewards = rewards.masked_fill(mask_fill_rewards, -100) - - labels_ends = ends.masked_fill(mask_fill, -100) - return labels_observations, labels_rewards.reshape(-1, self.support_size), labels_ends.reshape(-1) - - def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, target_policy: torch.Tensor, - mask_padding: torch.BoolTensor) -> Tuple[ - torch.Tensor, torch.Tensor]: - - mask_fill = torch.logical_not(mask_padding) - mask_fill_policy = mask_fill.unsqueeze(-1).expand_as(target_policy) - labels_policy = target_policy.masked_fill(mask_fill_policy, -100) - - mask_fill_value = mask_fill.unsqueeze(-1).expand_as(target_value) - labels_value = target_value.masked_fill(mask_fill_value, -100) - return labels_policy.reshape(-1, self.action_shape), labels_value.reshape(-1, self.support_size) # TODO(pu) - - - def negative_cosine_similarity(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: - """ - Overview: - consistency loss function: the negative cosine similarity. - Arguments: - - x1 (:obj:`torch.Tensor`): shape (batch_size, dim), e.g. (256, 512) - - x2 (:obj:`torch.Tensor`): shape (batch_size, dim), e.g. (256, 512) - Returns: - (x1 * x2).sum(dim=1) is the cosine similarity between vector x1 and x2. - The cosine similarity always belongs to the interval [-1, 1]. - For example, two proportional vectors have a cosine similarity of 1, - two orthogonal vectors have a similarity of 0, - and two opposite vectors have a similarity of -1. - -(x1 * x2).sum(dim=1) is consistency loss, i.e. the negative cosine similarity. 
- Reference: - https://en.wikipedia.org/wiki/Cosine_similarity - """ - x1 = F.normalize(x1, p=2., dim=-1, eps=1e-5) - x2 = F.normalize(x2, p=2., dim=-1, eps=1e-5) - return -(x1 * x2).sum(dim=1) - - - def render_img(self, obs: int, rec_img: int): - import torch - from PIL import Image - import matplotlib.pyplot as plt - - # 假设batch是一个字典,其中包含了observations键, - # 并且它的形状是torch.Size([B, N, C, H, W]) - # batch_observations = batch_for_gpt['observations'] - # batch_observations = batch['observations'] - batch_observations = obs.unsqueeze(0) - # batch_observations = rec_img.unsqueeze(0) - - # batch_observations = observations.unsqueeze(0) - # batch_observations = x.unsqueeze(0) - # batch_observations = reconstructions.unsqueeze(0) - - - - B, N, C, H, W = batch_observations.shape # 自动检测维度 - - # 分隔条的宽度(可以根据需要调整) - separator_width = 2 - - # 遍历每个样本 - for i in range(B): - # 提取当前样本中的所有帧 - frames = batch_observations[i] - - # 计算拼接图像的总宽度(包括分隔条) - total_width = N * W + (N - 1) * separator_width - - # 创建一个新的图像,其中包含分隔条 - concat_image = Image.new('RGB', (total_width, H), color='black') - - # 拼接每一帧及分隔条 - for j in range(N): - frame = frames[j].permute(1, 2, 0).cpu().numpy() # 转换为(H, W, C) - frame_image = Image.fromarray((frame * 255).astype('uint8'), 'RGB') - - # 计算当前帧在拼接图像中的位置 - x_position = j * (W + separator_width) - concat_image.paste(frame_image, (x_position, 0)) - - # 显示图像 - plt.imshow(concat_image) - plt.title(f'Sample {i+1}') - plt.axis('off') # 关闭坐标轴显示 - plt.show() - - # 保存图像到文件 - concat_image.save(f'sample_{i+1}.png') \ No newline at end of file diff --git a/lzero/model/gpt_models/world_model_batch_pad_min.py b/lzero/model/gpt_models/world_model_batch_pad_min.py deleted file mode 100644 index cb178d0cb..000000000 --- a/lzero/model/gpt_models/world_model_batch_pad_min.py +++ /dev/null @@ -1,999 +0,0 @@ -import copy -from dataclasses import dataclass -import random -from typing import Any, Optional, Tuple -from typing import List, Optional, Union -import logging -# 设置日志记录级别为DEBUG -logging.getLogger().setLevel(logging.DEBUG) -from PIL import Image -from einops import rearrange -from einops import rearrange -import gym -from joblib import hash -import numpy as np -import torch -import torch -from torch.distributions.categorical import Categorical -import torch.nn as nn -import torch.nn.functional as F -import torchvision - -from .kv_caching import KeysValues -from .slicer import Embedder, Head, ActEmbedder -from .tokenizer import Tokenizer -from .transformer import Transformer, TransformerConfig -from .utils import LossWithIntermediateLosses, init_weights -from ding.torch_utils import to_device -# from memory_profiler import profile -from line_profiler import line_profiler - -@dataclass -class WorldModelOutput: - output_sequence: torch.FloatTensor - logits_observations: torch.FloatTensor - logits_rewards: torch.FloatTensor - logits_ends: torch.FloatTensor - - logits_policy: torch.FloatTensor - logits_value: torch.FloatTensor - - -class WorldModel(nn.Module): - def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: TransformerConfig, tokenizer, representation_network=None) -> None: - super().__init__() - - # config.max_tokens = int(2*50) # TODO - - self.tokenizer = tokenizer - self.obs_vocab_size, self.act_vocab_size = obs_vocab_size, act_vocab_size - self.config = config - - self.transformer = Transformer(config) - # self.num_observations_tokens = 16 - self.num_observations_tokens = config.tokens_per_block -1 - - self.latent_recon_loss_weight = config.latent_recon_loss_weight - 
self.perceptual_loss_weight = config.perceptual_loss_weight - - self.device = config.device - self.support_size = config.support_size - self.action_shape = config.action_shape - self.max_cache_size = config.max_cache_size - self.env_num = config.env_num - self.num_layers = config.num_layers - - - all_but_last_latent_state_pattern = torch.ones(config.tokens_per_block) - all_but_last_latent_state_pattern[-2] = 0 # 1,...,0,1 - - act_tokens_pattern = torch.zeros(self.config.tokens_per_block) # 17 - act_tokens_pattern[-1] = 1 # 0,...,0,1 - latent_state_pattern = 1 - act_tokens_pattern # 1,...,1,0 - - # current latent state's policy value - value_policy_tokens_pattern = torch.zeros(config.tokens_per_block) - value_policy_tokens_pattern[-2] = 1 # [0,...,1,0] - - # next latent state's policy value - # value_policy_tokens_pattern = torch.zeros(config.tokens_per_block) - # value_policy_tokens_pattern[-1] = 1 # [0,...,0,1] - - obs_per_embdding_dim=config.embed_dim - - self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim) - - - self.embedder = Embedder( - max_blocks=config.max_blocks, - block_masks=[act_tokens_pattern, latent_state_pattern], - embedding_tables=nn.ModuleList([nn.Embedding(act_vocab_size, config.embed_dim), nn.Embedding(obs_vocab_size, config.embed_dim)]) - ) - - # self.act_embedder = ActEmbedder( - # max_blocks=config.max_blocks, - # block_masks=[act_tokens_pattern], - # embedding_tables=nn.ModuleList([nn.Embedding(act_vocab_size, config.embed_dim)]) - # ) - - self.act_embedding_table = nn.Embedding(act_vocab_size, config.embed_dim) - - # self.head_observations = Head( - # max_blocks=config.max_blocks, - # block_mask=all_but_last_latent_state_pattern, # 1,...,0,1 # https://github.com/eloialonso/iris/issues/19 - # head_module=nn.Sequential( - # nn.Linear(config.embed_dim, config.embed_dim), - # nn.ReLU(), - # nn.Linear(config.embed_dim, obs_vocab_size) - # ) - # ) - self.obs_per_embdding_dim = config.embed_dim # 16*64=1024 - self.head_observations = Head( # TODO - max_blocks=config.max_blocks, - block_mask=all_but_last_latent_state_pattern, # 1,...,0,1 # https://github.com/eloialonso/iris/issues/19 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - # nn.BatchNorm1d(config.embed_dim), - # nn.ReLU(), - # nn.Linear(config.embed_dim, obs_vocab_size) - nn.LeakyReLU(negative_slope=0.01), # TODO: 2 - nn.Linear(config.embed_dim, self.obs_per_embdding_dim), - # nn.Tanh(), # TODO - nn.Sigmoid(), # 这里添加Sigmoid函数 TODO - ) - ) - - self.head_rewards = Head( - max_blocks=config.max_blocks, - block_mask=act_tokens_pattern, # 0,...,0,1 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - nn.ReLU(), - nn.Linear(config.embed_dim, self.support_size) - ) - ) - - self.head_ends = Head( - max_blocks=config.max_blocks, - block_mask=act_tokens_pattern, # 0,...,0,1 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - nn.ReLU(), - nn.Linear(config.embed_dim, 2) - ) - ) - - self.head_observations_for_root = Head( # TODO - max_blocks=config.max_blocks, - block_mask=latent_state_pattern, # 1,...,1,0 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - nn.BatchNorm1d(config.embed_dim), - nn.ReLU(), - # nn.Linear(config.embed_dim, obs_vocab_size) - nn.Linear(config.embed_dim, obs_per_embdding_dim) - ) - ) - self.head_policy = Head( - max_blocks=config.max_blocks, - block_mask=value_policy_tokens_pattern, # TODO: value_policy_tokens_pattern # [0,...,1,0] - head_module=nn.Sequential( # (8, 5, 128) - # 
nn.BatchNorm1d(config.embed_dim), # TODO: 1 - nn.Linear(config.embed_dim, config.embed_dim), - # nn.ReLU(), - nn.LeakyReLU(negative_slope=0.01), # TODO: 2 - nn.Linear(config.embed_dim, self.action_shape) # TODO(pu); action shape - ) - ) - self.head_value = Head( - max_blocks=config.max_blocks, - block_mask=value_policy_tokens_pattern, - head_module=nn.Sequential( - # nn.BatchNorm1d(config.embed_dim), # TODO: 1 - nn.Linear(config.embed_dim, config.embed_dim), - # nn.ReLU(), - nn.LeakyReLU(negative_slope=0.01), # TODO: 2 - nn.Linear(config.embed_dim, self.support_size) # TODO(pu): action shape - ) - ) - - self.apply(init_weights) - - last_linear_layer_init_zero = True # TODO: is beneficial for convergence speed. - if last_linear_layer_init_zero: - # TODO: policy init : 3 - # Locate the last linear layer and initialize its weights and biases to 0. - # for _, layer in enumerate(reversed(self.head_policy.head_module)): - # if isinstance(layer, nn.Linear): - # nn.init.zeros_(layer.weight) - # nn.init.zeros_(layer.bias) - # break - for _, layer in enumerate(reversed(self.head_value.head_module)): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - nn.init.zeros_(layer.bias) - break - for _, layer in enumerate(reversed(self.head_rewards.head_module)): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - nn.init.zeros_(layer.bias) - break - for _, layer in enumerate(reversed(self.head_observations.head_module)): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - nn.init.zeros_(layer.bias) - break - - - import collections - self.past_keys_values_cache = collections.OrderedDict() - self.past_policy_value_cache = collections.OrderedDict() - - # TODO: Transformer更新后应该清除缓存 - # NOTE - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=self.env_num, max_tokens=self.config.max_tokens) - - self.keys_values_wm_list = [] - self.keys_values_wm_size_list = [] - - - if self.num_observations_tokens==16: # k=16 - self.projection_input_dim = 128 - elif self.num_observations_tokens==1: # K=1 - self.projection_input_dim = self.obs_per_embdding_dim# for atari #TODO - # self.projection_input_dim = 256 # for cartpole - - - self.proj_hid = 1024 - self.proj_out = 1024 - self.pred_hid = 512 - self.pred_out = 1024 - activation = nn.ReLU() - self.projection = nn.Sequential( - nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, - nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, - nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out) - ) - self.prediction_head = nn.Sequential( - nn.Linear(self.proj_out, self.pred_hid), - nn.BatchNorm1d(self.pred_hid), - activation, - nn.Linear(self.pred_hid, self.pred_out), - ) - self.hit_count = 0 - self.total_query_count = 0 - - - - def __repr__(self) -> str: - return "world_model" - - # @profile - def forward(self, obs_embeddings_or_act_tokens, past_keys_values: Optional[KeysValues] = None, - is_root=False, kvcache_independent=False) -> WorldModelOutput: - - if kvcache_independent: - prev_steps = 0 if past_keys_values is None else [past_kv.size for past_kv in past_keys_values] - prev_steps = torch.tensor(prev_steps, device=self.device) - # 我们需要为每个样本生成一个序列的步骤indices,然后获取它们的位置嵌入 - # 首先扩展prev_steps至(num_steps, batch_size),这里num_steps=1 - # prev_steps = prev_steps.unsqueeze(0) - - else: - prev_steps = 0 if past_keys_values is None else past_keys_values.size - # print(f'prev_steps:{prev_steps}') - - if 'obs_embeddings' in 
obs_embeddings_or_act_tokens.keys(): - obs_embeddings = obs_embeddings_or_act_tokens['obs_embeddings'] - if len(obs_embeddings.shape)==2: - obs_embeddings = obs_embeddings.unsqueeze(1) - num_steps = obs_embeddings.size(1) # (B, T, E) - if kvcache_independent: - # 生成每个样本的步骤indices - # steps_indices = prev_steps + torch.arange(num_steps, device=obs_embeddings.device).unsqueeze(1) - steps_indices = prev_steps + torch.arange(num_steps, device=obs_embeddings.device) - - # 步骤indices需要被reshape成一维,以便于embedding操作 - # steps_indices = steps_indices.view(-1) - # 获取位置嵌入 - position_embeddings = self.pos_emb(steps_indices) - # 由于我们要将它们加到obs_embeddings上,需要将位置嵌入reshape回(batch_size, num_steps, embedding_dim) - position_embeddings = position_embeddings.view(-1, num_steps, position_embeddings.shape[-1]) - # 现在我们可以将位置嵌入加到obs_embeddings上了 - sequences = obs_embeddings + position_embeddings - else: - sequences = obs_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=obs_embeddings.device)) - elif 'act_tokens' in obs_embeddings_or_act_tokens.keys(): - act_tokens = obs_embeddings_or_act_tokens['act_tokens'] - num_steps = act_tokens.size(1) # (B, T) - act_embeddings = self.act_embedding_table(act_tokens) - - if kvcache_independent: - # 生成每个样本的步骤indices - # steps_indices = prev_steps + torch.arange(num_steps, device=act_embeddings.device).unsqueeze(1) - steps_indices = prev_steps + torch.arange(num_steps, device=act_embeddings.device) - # 步骤indices需要被reshape成一维,以便于embedding操作 - # steps_indices = steps_indices.view(-1) - # 获取位置嵌入 - position_embeddings = self.pos_emb(steps_indices) - # 由于我们要将它们加到obs_embeddings上,需要将位置嵌入reshape回(batch_size, num_steps, embedding_dim) - position_embeddings = position_embeddings.view(-1, num_steps, position_embeddings.shape[-1]) - # 现在我们可以将位置嵌入加到obs_embeddings上了 - sequences = act_embeddings + position_embeddings - else: - sequences = act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=act_tokens.device)) - else: - obs_embeddings_and_act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens'] - # obs_embeddings: (B, L, K=16, E), act_tokens: (B, L, 1) - obs_embeddings, act_tokens = obs_embeddings_and_act_tokens - if len(obs_embeddings.shape)==3: # for batch compute loss - obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, -1) - - num_steps = int(obs_embeddings.size(1)*(obs_embeddings.size(2)+1)) # L(k+1) - # assert num_steps <= self.config.max_tokens - # Rearrange observation embeddings from (B, L, K, E) to (B, L*K, E) - # obs_embeddings = rearrange(obs_embeddings, 'b l k e -> b (l k) e') - - # Generate action embeddings from action tokens - # (B, L, 1) -> (B, L, 1, E) - act_embeddings = self.act_embedding_table(act_tokens) - - # 已知obs_embeddings的维度为 (B, L, K, E), act_embeddings的维度为(B, L, 1, E) 希望得到一个obs_act_embeddings向量的维度为 (B, L(K+1), E) - # 而且让得到的obs_act_embeddings的第2个维度的数据为:obs act, obs, act, ..., obs, act,即 L, 1, L,1 ... 
这样的排列顺序。请给出高效的实现,用中文回答 - - B, L, K, E = obs_embeddings.size() - # _, _, _, _ = act_embeddings.size() - - # 初始化一个新的空tensor,用于存放最终的拼接结果 - obs_act_embeddings = torch.empty(B, L * (K + 1), E, device=obs_embeddings.device) - - # 对每一个序列长度L进行循环 - for i in range(L): - # 获取当前时刻的obs和act embeddings - obs = obs_embeddings[:, i, :, :] # Shape: (B, K, E) - act = act_embeddings[:, i, 0, :].unsqueeze(1) # Shape: (B, 1, E), 补充维度以便拼接 - - # 交替拼接obs和act - obs_act = torch.cat([obs, act], dim=1) # Shape: (B, K + 1, E) - - # 将结果填充到最终的tensor中 - obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act - - # 确保形状正确无误 - # assert obs_act_embeddings.shape == (B, L * (K + 1), E) - - # Add positional embeddings - sequences = obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=obs_embeddings.device)) - - - # print('transformer forward begin') 函数里面更新了update past_keys_values - if kvcache_independent: - x = [] - for k, past_kv in enumerate(past_keys_values): - x.append(self.transformer(sequences[k].unsqueeze(0), past_kv)) - # self.keys_values_wm_list[k] = past_kv # NOTE: todo - x = torch.cat(x, dim=0) - - # TODO: 在collect时,是一步一步的 obs act 传入的 - # prev_steps = prev_steps//1 - - else: - x = self.transformer(sequences, past_keys_values) - - # print('transformer forward done') - - - if is_root: - logits_observations = self.head_observations_for_root(x, num_steps=num_steps, prev_steps=prev_steps) - else: - # 1,...,0,1 https://github.com/eloialonso/iris/issues/19 - logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps) - - logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps) - logits_ends = self.head_ends(x, num_steps=num_steps, prev_steps=prev_steps) - # return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends) - - logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps) - logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps) - - # TODO: root reward value - return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends, logits_policy, logits_value) - - @torch.no_grad() - def render_batch(self) -> List[Image.Image]: - frames = self.decode_latent_state().detach().cpu() - frames = rearrange(frames, 'b c h w -> b h w c').mul(255).numpy().astype(np.uint8) - return [Image.fromarray(frame) for frame in frames] - - # only foe inference now, now is invalid - @torch.no_grad() - def decode_latent_state(self) -> List[Image.Image]: - embedded_tokens = self.tokenizer.embedding(self.latent_state) # (B, K, E) - z = rearrange(embedded_tokens, 'b (h w) e -> b e h w', h=int(np.sqrt(self.num_observations_tokens))) - rec = self.tokenizer.decode(z, should_postprocess=True) # (B, C, H, W) - # TODO: for atari image - return torch.clamp(rec, 0, 1) - # for cartpole obs - # return rec - - - @torch.no_grad() - def render(self): - assert self.latent_state.shape == (1, self.num_observations_tokens) - return self.render_batch()[0] - - @torch.no_grad() - def reset(self) -> torch.FloatTensor: - assert self.env is not None - obs = torchvision.transforms.functional.to_tensor(self.env.reset()).to(self.device).unsqueeze( - 0) # (1, C, H, W) in [0., 1.] 
- return self.reset_from_initial_observations(obs) - - - @torch.no_grad() - # @profile - def reset_from_initial_observations_v2(self, obs_act_dict: torch.FloatTensor) -> torch.FloatTensor: - if isinstance(obs_act_dict, dict): - observations = obs_act_dict['obs'] - buffer_action = obs_act_dict['action'] - else: - observations = obs_act_dict - buffer_action = None - obs_embeddings = self.tokenizer.encode_to_obs_embeddings(observations, should_preprocess=True) # (B, C, H, W) -> (B, K, E) - - outputs_wm = self.refresh_keys_values_with_initial_latent_state_for_init_infer_v2(obs_embeddings, buffer_action) - self.latent_state = obs_embeddings - - return outputs_wm, self.latent_state - - @torch.no_grad() - # @profile - def refresh_keys_values_with_initial_latent_state_for_init_infer_v2(self, latent_state: torch.LongTensor, buffer_action=None) -> torch.FloatTensor: - n, num_observations_tokens, _ = latent_state.shape - if n <= self.env_num: - # MCTS root节点: 需要准确的估计 value, policy_logits, 或许需要结合context的kv_cache进行更准确的估计,而不是当前的从0开始推理 - self.keys_values_wm = self.transformer.generate_empty_keys_values(n, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False, kvcache_independent=False) - self.keys_values_wm_size_list = [1 for i in range(n)] - - # 复制单个环境对应的 keys_values_wm 并存储 - self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) - for i in range(latent_state.size(0)): # 遍历每个环境 - state_single_env = latent_state[i] # 获取单个环境的 latent state - cache_key = hash(state_single_env.detach().cpu().numpy()) # 计算哈希值 - for layer in range(self.num_layers): - self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = self.keys_values_wm._keys_values[layer]._k_cache._cache[i].unsqueeze(0) # shape torch.Size([2, 100, 512]) - self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = self.keys_values_wm._keys_values[layer]._v_cache._cache[i].unsqueeze(0) - self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = self.keys_values_wm._keys_values[layer]._k_cache._size - self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = self.keys_values_wm._keys_values[layer]._v_cache._size - # keys_values_wm_single_env[layer].update(self.keys_values_wm[layer]._k_cache._cache[i].unsqueeze(0), self.keys_values_wm[layer]._v_cache._cache[i].unsqueeze(0)) - if cache_key not in self.past_keys_values_cache: - self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu')) - - elif n == int(256): - # TODO: n=256 means train tokenizer, 不需要计算target value - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - # print('init inference: not find matched_value! 
reset!') - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) - elif n > self.env_num and n != int(256) and buffer_action is not None: - # train时计算target value - # TODO: n=256 means train tokenizer - # [192, 16, 64] -> [32, 6, 16, 64] - latent_state = latent_state.contiguous().view(buffer_action.shape[0], -1, num_observations_tokens, self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 - - # latent_state = latent_state.view(-1, self.config.max_blocks+1, num_observations_tokens) # (BL, K) - latent_state = latent_state[:, :-1, :] - # latent_state = latent_state.reshape(32*6, num_observations_tokens) # (BL, K) - buffer_action = torch.from_numpy(buffer_action).to(latent_state.device) - act_tokens = rearrange(buffer_action, 'b l -> b l 1') - - # # 选择每个样本的最后一步 - last_steps = act_tokens[:, -1:, :] # 这将选择最后一列并保持维度不变, 最后一步的target policy/value本身就没有用到 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - act_tokens = torch.cat((act_tokens, last_steps), dim=1) - - # print('init inference: unroll 5 steps!') 17*6=102 17*5=85 - obs_embeddings = latent_state - outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) - - # 选择每个样本的最后一步 - last_steps_value = outputs_wm.logits_value[:, -1:, :] # 这将选择最后一列并保持维度不变 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1) - - last_steps_policy = outputs_wm.logits_policy[:, -1:, :] # 这将选择最后一列并保持维度不变 - outputs_wm.logits_policy = torch.cat((outputs_wm.logits_policy, last_steps_policy), dim=1) - - # Reshape your tensors - # outputs_wm.logits_value.shape (30,21) = (B*6, 21) - outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e') - outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e') - - return outputs_wm - - @torch.no_grad() - # @profile - def reset_from_initial_observations(self, obs_act_dict: torch.FloatTensor) -> torch.FloatTensor: - if isinstance(obs_act_dict, dict): - # obs_act_dict = {'obs':obs, 'action':action_batch} - observations = obs_act_dict['obs'] - buffer_action = obs_act_dict['action'] - else: - observations = obs_act_dict - buffer_action = None - - obs_embeddings = self.tokenizer.encode_to_obs_embeddings(observations, should_preprocess=True) # (B, C, H, W) -> (B, K, E) - outputs_wm = self.refresh_keys_values_with_initial_latent_state_for_init_infer(obs_embeddings, buffer_action) - self.latent_state = obs_embeddings - - return outputs_wm, self.latent_state - - - @torch.no_grad() - # @profile - def refresh_keys_values_with_initial_latent_state_for_init_infer(self, latent_state: torch.LongTensor, buffer_action=None) -> torch.FloatTensor: - n, num_observations_tokens, _ = latent_state.shape - if n <= self.env_num: - # MCTS root节点: 需要准确的估计 value, policy_logits 或许需要结合context的kv_cache进行更准确的估计,而不是当前的从0开始推理 - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) # Note: is_root=False - elif n == int(256): - # TODO: n=256 means train tokenizer, 不需要计算target value - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - # print('init inference: not find matched_value! 
reset!') - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) # Note: is_root=False - elif n > self.env_num and n != int(256) and buffer_action is not None: - # TODO: n=256 means train tokenizer - # TODO: for n=32*6=192 means 通过unroll 5 steps,计算target value - # latent_state = latent_state.reshape(32, 6, num_observations_tokens) # (BL, K) - # latent_state = latent_state.view(-1, 6, num_observations_tokens) # (BL, K) - - # [192, 16] -> [32, 6, 16] - # latent_state = latent_state.view(buffer_action.shape[0], -1, num_observations_tokens) # (BL, K) for unroll_step=1 - - # [192, 16, 64] -> [32, 6, 16, 64] - latent_state = latent_state.contiguous().view(buffer_action.shape[0], -1, num_observations_tokens, self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 - - # latent_state = latent_state.view(-1, self.config.max_blocks+1, num_observations_tokens) # (BL, K) - latent_state = latent_state[:, :-1, :] - # latent_state = latent_state.reshape(32*6, num_observations_tokens) # (BL, K) - buffer_action = torch.from_numpy(buffer_action).to(latent_state.device) - act_tokens = rearrange(buffer_action, 'b l -> b l 1') - - # # 选择每个样本的最后一步 - last_steps = act_tokens[:, -1:, :] # 这将选择最后一列并保持维度不变, 最后一步的target policy/value本身就没有用到 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - act_tokens = torch.cat((act_tokens, last_steps), dim=1) - - # print('init inference: unroll 5 steps!') 17*6=102 17*5=85 - obs_embeddings = latent_state - outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) - - # 选择每个样本的最后一步 - last_steps_value = outputs_wm.logits_value[:, -1:, :] # 这将选择最后一列并保持维度不变 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1) - - last_steps_policy = outputs_wm.logits_policy[:, -1:, :] # 这将选择最后一列并保持维度不变 - outputs_wm.logits_policy = torch.cat((outputs_wm.logits_policy, last_steps_policy), dim=1) - - # Reshape your tensors - # outputs_wm.logits_value.shape (30,21) = (B*6, 21) - outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e') - outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e') - - - # return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends, logits_policy, logits_value) - return outputs_wm - - - @torch.no_grad() - # @profile - def refresh_keys_values_with_initial_latent_state(self, latent_state: torch.LongTensor, reset_indices=None) -> torch.FloatTensor: - n, num_observations_tokens, _ = latent_state.shape - assert num_observations_tokens == self.num_observations_tokens - # self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - if reset_indices is None: - self.keys_values_wm_list = [self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) for i in range(n)] - else: - for i in reset_indices: - self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': latent_state[i].unsqueeze(0).to(self.device)}, past_keys_values=self.keys_values_wm_single_env, is_root=False, kvcache_independent=False) - self.keys_values_wm_list[i] = self.keys_values_wm_single_env - self.keys_values_wm_size_list[i] = 1 - return None - - @torch.no_grad() - # @profile - def forward_initial_inference(self, obs_act_dict: torch.LongTensor, should_predict_next_obs: bool = True): - if 
isinstance(obs_act_dict, dict):
-            obs = obs_act_dict['obs']
-        else:
-            obs = obs_act_dict
-        outputs_wm, latent_state = self.reset_from_initial_observations_v2(obs_act_dict)  # TODO
-        # outputs_wm, latent_state = self.reset_from_initial_observations(obs_act_dict)  # start from scratch
-
-        return outputs_wm.output_sequence, latent_state, outputs_wm.logits_rewards, outputs_wm.logits_policy, outputs_wm.logits_value
-
-    """
-    Suppose env_num=8.
-    The kv_cache of each of the 8 environments is stored and looked up independently, all in one dict. During recurrent_inference,
-    because the kv_caches found for different environments have different sizes, the kv_caches are first zero-padded at the front
-    and then assembled into a kv_cache of batch_size; internally the transformer forward pass is also executed on this batch.
-    """
-
-    @torch.no_grad()
-    # @profile
-    def forward_recurrent_inference(self, state_action_history, should_predict_next_obs: bool = True):
-        # Generally speaking, within one MCTS search we need to maintain a context of length H for transformer inference.
-        # Since the agent visits at most `sim` distinct nodes within one search, we only need to maintain one {state: kv_cache} mapping.
-        # If the environment is assumed to be an MDP, it suffices to look up the current latest_state s_t in this mapping.
-        # TODO: if the environment is assumed to be non-Markovian, do we need to maintain a {root_state_action_history: kv_cache} mapping instead?
-
-        if self.total_query_count>0 and self.total_query_count%99999==0:
-        # if self.total_query_count>0 and self.total_query_count%1==0:
-            self.hit_freq = self.hit_count/(self.total_query_count)
-            print('hit_freq:', self.hit_freq)
-            print('hit_count:', self.hit_count)
-            print('total_query_count:', self.total_query_count)
-            print(self.keys_values_wm_size_list)
-
-        latest_state = state_action_history[-1][0]
-
-        # Assume latest_state is the new latent_state, containing information for ready_env_num environments
-        ready_env_num = latest_state.shape[0]
-        self.keys_values_wm_list = []
-        self.keys_values_wm_size_list = []
-        for i in range(ready_env_num):
-            self.total_query_count += 1
-            state_single_env = latest_state[i]  # get the latent state of a single environment
-            hash_latest_state = hash(state_single_env)  # compute the hash value
-            matched_value = self.past_keys_values_cache.get(hash_latest_state)  # look up the cached value
-            if matched_value is not None:
-                # if a matching value is found, append it to the list
-                self.hit_count += 1
-                # a deepcopy is required here because matched_value is modified in place inside the transformer forward
-                self.keys_values_wm_list.append(copy.deepcopy(self.to_device_for_kvcache(matched_value, 'cuda')))
-                self.keys_values_wm_size_list.append(matched_value.size)
-            else:
-                # use zero reset
-                self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens)
-                outputs_wm = self.forward({'obs_embeddings': torch.from_numpy(state_single_env).unsqueeze(0).to(self.device)}, past_keys_values=self.keys_values_wm_single_env, is_root=False)
-                self.keys_values_wm_list.append(self.keys_values_wm_single_env)
-                self.keys_values_wm_size_list.append(1)
-
-        num_passes = 1 + self.num_observations_tokens if should_predict_next_obs else 1
-        output_sequence, latent_state = [], []
-
-        # print(self.keys_values_wm_size_list)
-        reset_indices = [index for index, value in enumerate(self.keys_values_wm_size_list) if value + num_passes > self.config.max_tokens]
-        self.refresh_keys_values_with_initial_latent_state(torch.tensor(latest_state, dtype=torch.float32).to(self.device), reset_indices)
-
-        action = state_action_history[-1][-1]
-        token = action.clone().detach() if isinstance(action, torch.Tensor) else torch.tensor(action, dtype=torch.long)
-        token = token.reshape(-1, 1).to(self.device)  # (B, 1)
-
-        # get min_size, the minimum value in self.keys_values_wm_size_list
-        min_size = min(self.keys_values_wm_size_list)
-        for layer in range(self.num_layers):
-            # lists of k and v caches for this layer
-            kv_cache_k_list = []
-            kv_cache_v_list = []
-
-            for idx, keys_values in enumerate(self.keys_values_wm_list):
-                # get the k and v caches of the current layer
-                k_cache = keys_values[layer]._k_cache._cache
-                v_cache = keys_values[layer]._v_cache._cache
-
-                # get the effective size of the current cache
-                effective_size = self.keys_values_wm_size_list[idx]
-                # compute how many steps need to be trimmed
-                trim_size = effective_size - min_size if effective_size > min_size else 0
-
-                # if trimming is needed, cut off the first trim_size steps
-                if trim_size > 0:
-                    k_cache_trimmed = k_cache[:, :, trim_size:, :]
-                    v_cache_trimmed = v_cache[:, :, trim_size:, :]
-                    # pad with zeros along the third dimension
-                    k_cache_padded = F.pad(k_cache_trimmed, (0, 0, trim_size, 0), "constant", 0)
-                    v_cache_padded = F.pad(v_cache_trimmed, (0, 0, trim_size, 0), "constant", 0)
-                else:
-                    k_cache_padded = k_cache
-                    v_cache_padded = v_cache
-
-                # append the processed caches to the lists
-                kv_cache_k_list.append(k_cache_padded)
-                kv_cache_v_list.append(v_cache_padded)
-
-            # stack the caches in the lists with torch.stack() to form the new batched cache
-            self.keys_values_wm._keys_values[layer]._k_cache._cache = torch.stack(kv_cache_k_list, dim=0).squeeze(1)
-            self.keys_values_wm._keys_values[layer]._v_cache._cache = torch.stack(kv_cache_v_list, dim=0).squeeze(1)
-
-            # update the cache size to min_size
-            self.keys_values_wm._keys_values[layer]._k_cache._size = min_size
-            self.keys_values_wm._keys_values[layer]._v_cache._size = min_size
-        # del self.keys_values_wm_list
-
-        for k in range(num_passes):  # assumption that there is only one action token.
-            # action_token obs_token, ..., obs_token 1+1
-            if k==0:
-                obs_embeddings_or_act_tokens = {'act_tokens': token}
-            else:
-                obs_embeddings_or_act_tokens = {'obs_embeddings': token}
-
-            outputs_wm = self.forward(obs_embeddings_or_act_tokens, past_keys_values=self.keys_values_wm, is_root=False, kvcache_independent=False)
-            # if k==0, action_token self.head_observations 1,...,0,1
-            output_sequence.append(outputs_wm.output_sequence)
-
-            if k == 0:
-                # if k==0, the token is the action token and outputs_wm.logits_rewards is valid
-                reward = outputs_wm.logits_rewards  # (B,)
-
-            if k < self.num_observations_tokens:
-                # 16 obs tokens are produced in total, one per pass
-                # TODO: sample or argmax
-                # token = Categorical(logits=outputs_wm.logits_observations).sample()
-                # Use argmax to select the most likely token
-                # token = outputs_wm.logits_observations.argmax(-1, keepdim=True)
-                token = outputs_wm.logits_observations
-                # if len(token.shape) != 2:
-                #     token = token.squeeze(-1)  # Ensure the token tensor shape is (B, 1)
-                if len(token.shape) != 3:
-                    token = token.unsqueeze(1)  # (8,1024) -> (8,1,1024)
-                latent_state.append(token)
-
-        output_sequence = torch.cat(output_sequence, dim=1)  # (B, 1 + K, E)
-        # Before updating self.latent_state, delete the old one to free memory
-        # del self.latent_state
-        self.latent_state = torch.cat(latent_state, dim=1)  # (B, K)
-
-        # TODO: after the computation finishes, should the latest cache be updated, and is a deepcopy needed?
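-        # NOTE: `self.keys_values_wm_single_env` used below is the single-environment scratch container that was most
-        # recently created above (in the cache-miss branch) or during an earlier initial-inference call; the loop
-        # overwrites its per-layer cache tensors with slices of the batched `self.keys_values_wm` and then deep-copies
-        # the result to CPU for storage in `self.past_keys_values_cache`.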
-        for i in range(self.latent_state.size(0)):  # iterate over each environment
-            state_single_env = self.latent_state[i]  # get the latent state of a single environment
-            cache_key = hash(state_single_env.detach().cpu().numpy())  # compute the hash value
-            # copy the keys_values_wm corresponding to this single environment and store it
-            for layer in range(self.num_layers):
-                self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = self.keys_values_wm._keys_values[layer]._k_cache._cache[i].unsqueeze(0)  # shape torch.Size([2, 100, 512])
-                self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = self.keys_values_wm._keys_values[layer]._v_cache._cache[i].unsqueeze(0)
-                self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = self.keys_values_wm._keys_values[layer]._k_cache._size
-                self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = self.keys_values_wm._keys_values[layer]._v_cache._size
-            # compare the sizes and store the larger cache
-            if cache_key in self.past_keys_values_cache:
-                existing_kvcache = self.past_keys_values_cache[cache_key]
-                # check whether the existing cache and the new cache differ in size
-                if self.keys_values_wm_single_env._keys_values[0]._k_cache._size > existing_kvcache.size and self.keys_values_wm_single_env._keys_values[0]._k_cache._size < self.config.max_tokens-1:
-                    # caches whose size equals self.config.max_tokens-1 are not stored: using a kv_cache that has reached the maximum size triggers a reset, which would mean always starting from scratch
-                    self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu'))
-            elif self.keys_values_wm_single_env._keys_values[0]._k_cache._size < self.config.max_tokens-1:
-                # caches whose size equals self.config.max_tokens-1 are not stored: using a kv_cache that has reached the maximum size triggers a reset, which would mean always starting from scratch
-                self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu'))
-
-        # del self.keys_values_wm
-        # outputs_wm.logits_policy, outputs_wm.logits_value
-        if len(self.past_keys_values_cache) > self.max_cache_size:
-            # TODO: lru_cache
-            _, popped_kv_cache = self.past_keys_values_cache.popitem(last=False)
-            del popped_kv_cache  # this line is not needed
-
-        # Example usage:
-        # Assuming `past_keys_values_cache` is a populated instance of `KeysValues`
-        # and `num_layers` is the number of transformer layers
-        # cuda_memory_gb = self.calculate_cuda_memory_gb(self.past_keys_values_cache, num_layers=2)
-        # print(f'len(self.past_keys_values_cache): {len(self.past_keys_values_cache)}, Memory used by past_keys_values_cache: {cuda_memory_gb:.2f} GB')
-
-        return outputs_wm.output_sequence, self.latent_state, reward, outputs_wm.logits_policy, outputs_wm.logits_value
-
-
-    def to_device_for_kvcache(self, keys_values: KeysValues, device: str) -> KeysValues:
-        """
-        Transfer all KVCache objects within the KeysValues object to a certain device.
-
-        Arguments:
-            - keys_values (KeysValues): The KeysValues object to be transferred.
-            - device (str): The device to transfer to.
-
-        Returns:
-            - keys_values (KeysValues): The KeysValues object with its caches transferred to the specified device.
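-
-        Note:
-            - As written, the ``device`` argument is overridden inside the function body: the caches are always moved to
-              CUDA when it is available and to CPU otherwise, so e.g. ``self.to_device_for_kvcache(kv, 'cpu')`` still
-              targets CUDA on a GPU machine.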
- """ - # Check if CUDA is available and select the first available CUDA device - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - - for kv_cache in keys_values: - kv_cache._k_cache._cache = kv_cache._k_cache._cache.to(device) - kv_cache._v_cache._cache = kv_cache._v_cache._cache.to(device) - return keys_values - - - # 计算显存使用量的函数 - def calculate_cuda_memory_gb(self, past_keys_values_cache, num_layers: int): - total_memory_bytes = 0 - - # 遍历OrderedDict中所有的KeysValues实例 - for kv_instance in past_keys_values_cache.values(): - num_layers = len(kv_instance) # 获取层数 - for layer in range(num_layers): - kv_cache = kv_instance[layer] - k_shape = kv_cache._k_cache.shape # 获取keys缓存的形状 - v_shape = kv_cache._v_cache.shape # 获取values缓存的形状 - - # 计算元素个数并乘以每个元素的字节数 - k_memory = torch.prod(torch.tensor(k_shape)) * 4 - v_memory = torch.prod(torch.tensor(v_shape)) * 4 - - # 累加keys和values缓存的内存 - layer_memory = k_memory + v_memory - total_memory_bytes += layer_memory.item() # .item()确保转换为Python标准数字 - - # 将总内存从字节转换为吉字节 - total_memory_gb = total_memory_bytes / (1024 ** 3) - return total_memory_gb - - # @profile - def compute_loss(self, batch, target_tokenizer: Tokenizer=None, **kwargs: Any) -> LossWithIntermediateLosses: - # NOTE: 这里是需要梯度的 - #with torch.no_grad(): # TODO: 非常重要 - obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False) # (B, C, H, W) -> (B, K, E) - obs_embeddings.register_hook(lambda grad: grad * 1/5) # TODO:只提供重建损失更新表征网络 - # obs_embeddings.register_hook(lambda grad: grad * 1) # TODO:只提供重建损失更新表征网络 - - # Assume that 'cont_embeddings' and 'original_images' are available from prior code - # Decode the embeddings to reconstruct the images - reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) - # Calculate the reconstruction loss - # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 4, 64, 64), reconstructed_images) # TODO: for stack=4 - # perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) # for stack=4 gray obs - - latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # TODO: for stack=1 - perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # TODO: for stack=1 - - # latent_recon_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) - # perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) - - latent_kl_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) - - - act_tokens = rearrange(batch['actions'], 'b l -> b l 1') - - # TODO: 是否只用重建损失更新表征网络 非常重要 - outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) - # outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens)}, is_root=False) - - with torch.no_grad(): - traget_obs_embeddings = target_tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False) # (B, C, H, W) -> (B, K, E) - - labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(traget_obs_embeddings, batch['rewards'], - batch['ends'], - batch['mask_padding']) - # labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(obs_embeddings, batch['rewards'], - # batch['ends'], - # 
batch['mask_padding']) - - logits_observations = rearrange(outputs.logits_observations[:, :-1], 'b t o -> (b t) o') - # labels_observations = labels_observations.contiguous().view(-1, self.projection_input_dim) # TODO: - labels_observations = labels_observations.reshape(-1, self.projection_input_dim) # TODO: - - - loss_obs = torch.nn.functional.mse_loss(logits_observations, labels_observations.detach(), reduction='none').mean(-1) - mask_padding_expanded = batch['mask_padding'][:, 1:].contiguous().view(-1) # TODO: - # mask_padding_expanded = batch['mask_padding'][:, 1:].reshape(-1) - - # 应用mask到loss_obs - loss_obs = (loss_obs * mask_padding_expanded).mean(-1) - labels_policy, labels_value = self.compute_labels_world_model_value_policy(batch['target_value'], - batch['target_policy'], - batch['mask_padding']) - - loss_rewards = self.compute_cross_entropy_loss(outputs, labels_rewards, batch, element='rewards') - loss_policy = self.compute_cross_entropy_loss(outputs, labels_policy, batch, element='policy') - loss_value = self.compute_cross_entropy_loss(outputs, labels_value, batch, element='value') - - return LossWithIntermediateLosses(latent_recon_loss_weight=self.latent_recon_loss_weight, perceptual_loss_weight=self.perceptual_loss_weight, loss_obs=loss_obs, loss_rewards=loss_rewards, loss_value=loss_value, - loss_policy=loss_policy, latent_kl_loss=latent_kl_loss, latent_recon_loss=latent_recon_loss, perceptual_loss=perceptual_loss) - - def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): - # Assume outputs.logits_rewards and labels are your predictions and targets - # And mask_padding is a boolean tensor with True at positions to keep and False at positions to ignore - - if element == 'rewards': - logits = outputs.logits_rewards - elif element == 'policy': - logits = outputs.logits_policy - elif element == 'value': - logits = outputs.logits_value - - # Reshape your tensors - logits_rewards = rearrange(logits, 'b t e -> (b t) e') - labels = labels.reshape(-1, labels.shape[-1]) # Assuming labels originally has shape [b, t, reward_dim] - - # Reshape your mask. True means valid data. 
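- # The labels here are soft target distributions over the categorical support (not class indices), so the loss below is
- # a soft-label cross entropy: loss_i = -sum_c labels[i, c] * log_softmax(logits)[i, c]. Padded positions are zeroed out
- # by `mask_padding` before the final mean is taken.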
- mask_padding = rearrange(batch['mask_padding'], 'b t -> (b t)') - - loss_rewards = -(torch.log_softmax(logits_rewards, dim=1) * labels).sum(1) - # loss_rewards = (loss_rewards * mask_padding.squeeze(-1).float()).mean() - loss_rewards = (loss_rewards * mask_padding).mean() - - - return loss_rewards - - def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, - mask_padding: torch.BoolTensor) -> Tuple[ - torch.Tensor, torch.Tensor, torch.Tensor]: - assert torch.all(ends.sum(dim=1) <= 1) # each sequence sample has at most 1 done - mask_fill = torch.logical_not(mask_padding) - labels_observations = obs_embeddings.contiguous().view(rewards.shape[0], -1, self.projection_input_dim)[:, 1:] # self.projection_input_dim - - - # labels_rewards = (rewards.sign() + 1).masked_fill(mask_fill, -100).long() # Rewards clipped to {-1, 0, 1} TODO(pu) - - mask_fill_rewards = mask_fill.unsqueeze(-1).expand_as(rewards) - labels_rewards = rewards.masked_fill(mask_fill_rewards, -100) - - labels_ends = ends.masked_fill(mask_fill, -100) - return labels_observations, labels_rewards.reshape(-1, self.support_size), labels_ends.reshape(-1) - - def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, target_policy: torch.Tensor, - mask_padding: torch.BoolTensor) -> Tuple[ - torch.Tensor, torch.Tensor]: - - mask_fill = torch.logical_not(mask_padding) - mask_fill_policy = mask_fill.unsqueeze(-1).expand_as(target_policy) - labels_policy = target_policy.masked_fill(mask_fill_policy, -100) - - mask_fill_value = mask_fill.unsqueeze(-1).expand_as(target_value) - labels_value = target_value.masked_fill(mask_fill_value, -100) - return labels_policy.reshape(-1, self.action_shape), labels_value.reshape(-1, self.support_size) # TODO(pu) - - - def negative_cosine_similarity(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: - """ - Overview: - consistency loss function: the negative cosine similarity. - Arguments: - - x1 (:obj:`torch.Tensor`): shape (batch_size, dim), e.g. (256, 512) - - x2 (:obj:`torch.Tensor`): shape (batch_size, dim), e.g. (256, 512) - Returns: - (x1 * x2).sum(dim=1) is the cosine similarity between vector x1 and x2. - The cosine similarity always belongs to the interval [-1, 1]. - For example, two proportional vectors have a cosine similarity of 1, - two orthogonal vectors have a similarity of 0, - and two opposite vectors have a similarity of -1. - -(x1 * x2).sum(dim=1) is consistency loss, i.e. the negative cosine similarity. 
- Reference: - https://en.wikipedia.org/wiki/Cosine_similarity - """ - x1 = F.normalize(x1, p=2., dim=-1, eps=1e-5) - x2 = F.normalize(x2, p=2., dim=-1, eps=1e-5) - return -(x1 * x2).sum(dim=1) - - - def render_img(self, obs: int, rec_img: int): - import torch - from PIL import Image - import matplotlib.pyplot as plt - - # 假设batch是一个字典,其中包含了observations键, - # 并且它的形状是torch.Size([B, N, C, H, W]) - # batch_observations = batch_for_gpt['observations'] - # batch_observations = batch['observations'] - batch_observations = obs.unsqueeze(0) - # batch_observations = rec_img.unsqueeze(0) - - # batch_observations = observations.unsqueeze(0) - # batch_observations = x.unsqueeze(0) - # batch_observations = reconstructions.unsqueeze(0) - - - - B, N, C, H, W = batch_observations.shape # 自动检测维度 - - # 分隔条的宽度(可以根据需要调整) - separator_width = 2 - - # 遍历每个样本 - for i in range(B): - # 提取当前样本中的所有帧 - frames = batch_observations[i] - - # 计算拼接图像的总宽度(包括分隔条) - total_width = N * W + (N - 1) * separator_width - - # 创建一个新的图像,其中包含分隔条 - concat_image = Image.new('RGB', (total_width, H), color='black') - - # 拼接每一帧及分隔条 - for j in range(N): - frame = frames[j].permute(1, 2, 0).cpu().numpy() # 转换为(H, W, C) - frame_image = Image.fromarray((frame * 255).astype('uint8'), 'RGB') - - # 计算当前帧在拼接图像中的位置 - x_position = j * (W + separator_width) - concat_image.paste(frame_image, (x_position, 0)) - - # 显示图像 - plt.imshow(concat_image) - plt.title(f'Sample {i+1}') - plt.axis('off') # 关闭坐标轴显示 - plt.show() - - # 保存图像到文件 - concat_image.save(f'sample_{i+1}.png') \ No newline at end of file diff --git a/lzero/model/gpt_models/world_model_batch_pad_min_quantize.py b/lzero/model/gpt_models/world_model_batch_pad_min_quantize.py deleted file mode 100644 index cda9f6bc2..000000000 --- a/lzero/model/gpt_models/world_model_batch_pad_min_quantize.py +++ /dev/null @@ -1,1051 +0,0 @@ -import copy -from dataclasses import dataclass -import random -from typing import Any, Optional, Tuple -from typing import List, Optional, Union -import logging -# 设置日志记录级别为DEBUG -logging.getLogger().setLevel(logging.DEBUG) -from PIL import Image -from einops import rearrange -from einops import rearrange -import gym -from joblib import hash -import numpy as np -import torch -import torch -from torch.distributions.categorical import Categorical -import torch.nn as nn -import torch.nn.functional as F -import torchvision - -from .kv_caching import KeysValues -from .slicer import Embedder, Head, ActEmbedder -from .tokenizer import Tokenizer -from .transformer import Transformer, TransformerConfig -from .utils import LossWithIntermediateLosses, init_weights -from ding.torch_utils import to_device -# from memory_profiler import profile -from line_profiler import line_profiler -import hashlib -# def quantize_state(state, num_buckets=1000): -def quantize_state(state, num_buckets=15): -# def quantize_state(state, num_buckets=10): - """ - 量化状态向量。 - 参数: - state: 要量化的状态向量。 - num_buckets: 量化的桶数。 - 返回: - 量化后的状态向量的哈希值。 - """ - # 使用np.digitize将状态向量的每个维度值映射到num_buckets个桶中 - quantized_state = np.digitize(state, bins=np.linspace(0, 1, num=num_buckets)) - # 使用更稳定的哈希函数 - quantized_state_bytes = quantized_state.tobytes() - hash_object = hashlib.sha256(quantized_state_bytes) - return hash_object.hexdigest() - -@dataclass -class WorldModelOutput: - output_sequence: torch.FloatTensor - logits_observations: torch.FloatTensor - logits_rewards: torch.FloatTensor - logits_ends: torch.FloatTensor - - logits_policy: torch.FloatTensor - logits_value: torch.FloatTensor - - -class WorldModel(nn.Module): - def 
__init__(self, obs_vocab_size: int, act_vocab_size: int, config: TransformerConfig, tokenizer, representation_network=None) -> None: - super().__init__() - - # config.max_tokens = int(2*50) # TODO - - self.tokenizer = tokenizer - self.obs_vocab_size, self.act_vocab_size = obs_vocab_size, act_vocab_size - self.config = config - - self.transformer = Transformer(config) - # self.num_observations_tokens = 16 - self.num_observations_tokens = config.tokens_per_block -1 - - self.latent_recon_loss_weight = config.latent_recon_loss_weight - self.perceptual_loss_weight = config.perceptual_loss_weight - - self.device = config.device - self.support_size = config.support_size - self.action_shape = config.action_shape - self.max_cache_size = config.max_cache_size - self.env_num = config.env_num - self.num_layers = config.num_layers - - - all_but_last_latent_state_pattern = torch.ones(config.tokens_per_block) - all_but_last_latent_state_pattern[-2] = 0 # 1,...,0,1 - - act_tokens_pattern = torch.zeros(self.config.tokens_per_block) # 17 - act_tokens_pattern[-1] = 1 # 0,...,0,1 - latent_state_pattern = 1 - act_tokens_pattern # 1,...,1,0 - - # current latent state's policy value - value_policy_tokens_pattern = torch.zeros(config.tokens_per_block) - value_policy_tokens_pattern[-2] = 1 # [0,...,1,0] - - # next latent state's policy value - # value_policy_tokens_pattern = torch.zeros(config.tokens_per_block) - # value_policy_tokens_pattern[-1] = 1 # [0,...,0,1] - - obs_per_embdding_dim=config.embed_dim - - self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim) - - - self.embedder = Embedder( - max_blocks=config.max_blocks, - block_masks=[act_tokens_pattern, latent_state_pattern], - embedding_tables=nn.ModuleList([nn.Embedding(act_vocab_size, config.embed_dim), nn.Embedding(obs_vocab_size, config.embed_dim)]) - ) - - # self.act_embedder = ActEmbedder( - # max_blocks=config.max_blocks, - # block_masks=[act_tokens_pattern], - # embedding_tables=nn.ModuleList([nn.Embedding(act_vocab_size, config.embed_dim)]) - # ) - - self.act_embedding_table = nn.Embedding(act_vocab_size, config.embed_dim) - - # self.head_observations = Head( - # max_blocks=config.max_blocks, - # block_mask=all_but_last_latent_state_pattern, # 1,...,0,1 # https://github.com/eloialonso/iris/issues/19 - # head_module=nn.Sequential( - # nn.Linear(config.embed_dim, config.embed_dim), - # nn.ReLU(), - # nn.Linear(config.embed_dim, obs_vocab_size) - # ) - # ) - self.obs_per_embdding_dim = config.embed_dim # 16*64=1024 - self.head_observations = Head( # TODO - max_blocks=config.max_blocks, - block_mask=all_but_last_latent_state_pattern, # 1,...,0,1 # https://github.com/eloialonso/iris/issues/19 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - # nn.BatchNorm1d(config.embed_dim), - # nn.ReLU(), - # nn.Linear(config.embed_dim, obs_vocab_size) - nn.LeakyReLU(negative_slope=0.01), # TODO: 2 - nn.Linear(config.embed_dim, self.obs_per_embdding_dim), - # nn.Tanh(), # TODO - nn.Sigmoid(), # 这里添加Sigmoid函数 TODO - ) - ) - - self.head_rewards = Head( - max_blocks=config.max_blocks, - block_mask=act_tokens_pattern, # 0,...,0,1 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - nn.ReLU(), - nn.Linear(config.embed_dim, self.support_size) - ) - ) - - self.head_ends = Head( - max_blocks=config.max_blocks, - block_mask=act_tokens_pattern, # 0,...,0,1 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - nn.ReLU(), - nn.Linear(config.embed_dim, 2) - ) - ) - - 
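- # The heads below read the transformer output at fixed positions inside each block of K observation tokens followed by
- # one action token: head_observations_for_root is evaluated at every observation-token position (latent_state_pattern),
- # while head_policy and head_value are evaluated at the last observation-token position (value_policy_tokens_pattern),
- # i.e. policy and value are predicted from the current latent state rather than from the action token.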
self.head_observations_for_root = Head( # TODO - max_blocks=config.max_blocks, - block_mask=latent_state_pattern, # 1,...,1,0 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - nn.BatchNorm1d(config.embed_dim), - nn.ReLU(), - # nn.Linear(config.embed_dim, obs_vocab_size) - nn.Linear(config.embed_dim, obs_per_embdding_dim) - ) - ) - self.head_policy = Head( - max_blocks=config.max_blocks, - block_mask=value_policy_tokens_pattern, # TODO: value_policy_tokens_pattern # [0,...,1,0] - head_module=nn.Sequential( # (8, 5, 128) - # nn.BatchNorm1d(config.embed_dim), # TODO: 1 - nn.Linear(config.embed_dim, config.embed_dim), - # nn.ReLU(), - nn.LeakyReLU(negative_slope=0.01), # TODO: 2 - nn.Linear(config.embed_dim, self.action_shape) # TODO(pu); action shape - ) - ) - self.head_value = Head( - max_blocks=config.max_blocks, - block_mask=value_policy_tokens_pattern, - head_module=nn.Sequential( - # nn.BatchNorm1d(config.embed_dim), # TODO: 1 - nn.Linear(config.embed_dim, config.embed_dim), - # nn.ReLU(), - nn.LeakyReLU(negative_slope=0.01), # TODO: 2 - nn.Linear(config.embed_dim, self.support_size) # TODO(pu): action shape - ) - ) - - self.apply(init_weights) - - last_linear_layer_init_zero = True # TODO: is beneficial for convergence speed. - if last_linear_layer_init_zero: - # TODO: policy init : 3 - # Locate the last linear layer and initialize its weights and biases to 0. - # for _, layer in enumerate(reversed(self.head_policy.head_module)): - # if isinstance(layer, nn.Linear): - # nn.init.zeros_(layer.weight) - # nn.init.zeros_(layer.bias) - # break - for _, layer in enumerate(reversed(self.head_value.head_module)): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - nn.init.zeros_(layer.bias) - break - for _, layer in enumerate(reversed(self.head_rewards.head_module)): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - nn.init.zeros_(layer.bias) - break - for _, layer in enumerate(reversed(self.head_observations.head_module)): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - nn.init.zeros_(layer.bias) - break - - - import collections - self.past_keys_values_cache = collections.OrderedDict() - self.past_policy_value_cache = collections.OrderedDict() - - # TODO: Transformer更新后应该清除缓存 - # NOTE - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=self.env_num, max_tokens=self.config.max_tokens) - - self.keys_values_wm_list = [] - self.keys_values_wm_size_list = [] - - - if self.num_observations_tokens==16: # k=16 - self.projection_input_dim = 128 - elif self.num_observations_tokens==1: # K=1 - self.projection_input_dim = self.obs_per_embdding_dim# for atari #TODO - # self.projection_input_dim = 256 # for cartpole - - - self.proj_hid = 1024 - self.proj_out = 1024 - self.pred_hid = 512 - self.pred_out = 1024 - activation = nn.ReLU() - self.projection = nn.Sequential( - nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, - nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, - nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out) - ) - self.prediction_head = nn.Sequential( - nn.Linear(self.proj_out, self.pred_hid), - nn.BatchNorm1d(self.pred_hid), - activation, - nn.Linear(self.pred_hid, self.pred_out), - ) - self.hit_count = 0 - self.total_query_count = 0 - self.length3_context_cnt = 0 - self.length2_context_cnt = 0 - self.root_hit_cnt = 0 - self.root_total_query_cnt = 0 - - - - def __repr__(self) -> str: - return 
"world_model" - - # @profile - def forward(self, obs_embeddings_or_act_tokens, past_keys_values: Optional[KeysValues] = None, - is_root=False, kvcache_independent=False) -> WorldModelOutput: - - if kvcache_independent: - prev_steps = 0 if past_keys_values is None else [past_kv.size for past_kv in past_keys_values] - prev_steps = torch.tensor(prev_steps, device=self.device) - # 我们需要为每个样本生成一个序列的步骤indices,然后获取它们的位置嵌入 - # 首先扩展prev_steps至(num_steps, batch_size),这里num_steps=1 - # prev_steps = prev_steps.unsqueeze(0) - - else: - prev_steps = 0 if past_keys_values is None else past_keys_values.size - # print(f'prev_steps:{prev_steps}') - - if 'obs_embeddings' in obs_embeddings_or_act_tokens.keys(): - obs_embeddings = obs_embeddings_or_act_tokens['obs_embeddings'] - if len(obs_embeddings.shape)==2: - obs_embeddings = obs_embeddings.unsqueeze(1) - num_steps = obs_embeddings.size(1) # (B, T, E) - if kvcache_independent: - # 生成每个样本的步骤indices - # steps_indices = prev_steps + torch.arange(num_steps, device=obs_embeddings.device).unsqueeze(1) - steps_indices = prev_steps + torch.arange(num_steps, device=obs_embeddings.device) - - # 步骤indices需要被reshape成一维,以便于embedding操作 - # steps_indices = steps_indices.view(-1) - # 获取位置嵌入 - position_embeddings = self.pos_emb(steps_indices) - # 由于我们要将它们加到obs_embeddings上,需要将位置嵌入reshape回(batch_size, num_steps, embedding_dim) - position_embeddings = position_embeddings.view(-1, num_steps, position_embeddings.shape[-1]) - # 现在我们可以将位置嵌入加到obs_embeddings上了 - sequences = obs_embeddings + position_embeddings - else: - sequences = obs_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=obs_embeddings.device)) - elif 'act_tokens' in obs_embeddings_or_act_tokens.keys(): - act_tokens = obs_embeddings_or_act_tokens['act_tokens'] - num_steps = act_tokens.size(1) # (B, T) - act_embeddings = self.act_embedding_table(act_tokens) - - if kvcache_independent: - # 生成每个样本的步骤indices - # steps_indices = prev_steps + torch.arange(num_steps, device=act_embeddings.device).unsqueeze(1) - steps_indices = prev_steps + torch.arange(num_steps, device=act_embeddings.device) - # 步骤indices需要被reshape成一维,以便于embedding操作 - # steps_indices = steps_indices.view(-1) - # 获取位置嵌入 - position_embeddings = self.pos_emb(steps_indices) - # 由于我们要将它们加到obs_embeddings上,需要将位置嵌入reshape回(batch_size, num_steps, embedding_dim) - position_embeddings = position_embeddings.view(-1, num_steps, position_embeddings.shape[-1]) - # 现在我们可以将位置嵌入加到obs_embeddings上了 - sequences = act_embeddings + position_embeddings - else: - sequences = act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=act_tokens.device)) - else: - obs_embeddings_and_act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens'] - # obs_embeddings: (B, L, K=16, E), act_tokens: (B, L, 1) - obs_embeddings, act_tokens = obs_embeddings_and_act_tokens - if len(obs_embeddings.shape)==3: # for batch compute loss - obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, -1) - - num_steps = int(obs_embeddings.size(1)*(obs_embeddings.size(2)+1)) # L(k+1) - # assert num_steps <= self.config.max_tokens - # Rearrange observation embeddings from (B, L, K, E) to (B, L*K, E) - # obs_embeddings = rearrange(obs_embeddings, 'b l k e -> b (l k) e') - - # Generate action embeddings from action tokens - # (B, L, 1) -> (B, L, 1, E) - act_embeddings = self.act_embedding_table(act_tokens) - - # 已知obs_embeddings的维度为 (B, L, K, E), act_embeddings的维度为(B, L, 1, E) 希望得到一个obs_act_embeddings向量的维度为 (B, 
L(K+1), E) - # 而且让得到的obs_act_embeddings的第2个维度的数据为:obs act, obs, act, ..., obs, act,即 L, 1, L,1 ... 这样的排列顺序。请给出高效的实现,用中文回答 - - B, L, K, E = obs_embeddings.size() - # _, _, _, _ = act_embeddings.size() - - # 初始化一个新的空tensor,用于存放最终的拼接结果 - obs_act_embeddings = torch.empty(B, L * (K + 1), E, device=obs_embeddings.device) - - # 对每一个序列长度L进行循环 - for i in range(L): - # 获取当前时刻的obs和act embeddings - obs = obs_embeddings[:, i, :, :] # Shape: (B, K, E) - act = act_embeddings[:, i, 0, :].unsqueeze(1) # Shape: (B, 1, E), 补充维度以便拼接 - - # 交替拼接obs和act - obs_act = torch.cat([obs, act], dim=1) # Shape: (B, K + 1, E) - - # 将结果填充到最终的tensor中 - obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act - - # 确保形状正确无误 - # assert obs_act_embeddings.shape == (B, L * (K + 1), E) - - # Add positional embeddings - sequences = obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=obs_embeddings.device)) - - - # print('transformer forward begin') 函数里面更新了update past_keys_values - if kvcache_independent: - x = [] - for k, past_kv in enumerate(past_keys_values): - x.append(self.transformer(sequences[k].unsqueeze(0), past_kv)) - # self.keys_values_wm_list[k] = past_kv # NOTE: todo - x = torch.cat(x, dim=0) - - # TODO: 在collect时,是一步一步的 obs act 传入的 - # prev_steps = prev_steps//1 - - else: - x = self.transformer(sequences, past_keys_values) - - # print('transformer forward done') - - - if is_root: - logits_observations = self.head_observations_for_root(x, num_steps=num_steps, prev_steps=prev_steps) - else: - # 1,...,0,1 https://github.com/eloialonso/iris/issues/19 - logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps) - - logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps) - logits_ends = self.head_ends(x, num_steps=num_steps, prev_steps=prev_steps) - # return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends) - - logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps) - logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps) - - # TODO: root reward value - return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends, logits_policy, logits_value) - - @torch.no_grad() - def render_batch(self) -> List[Image.Image]: - frames = self.decode_latent_state().detach().cpu() - frames = rearrange(frames, 'b c h w -> b h w c').mul(255).numpy().astype(np.uint8) - return [Image.fromarray(frame) for frame in frames] - - # only foe inference now, now is invalid - @torch.no_grad() - def decode_latent_state(self) -> List[Image.Image]: - embedded_tokens = self.tokenizer.embedding(self.latent_state) # (B, K, E) - z = rearrange(embedded_tokens, 'b (h w) e -> b e h w', h=int(np.sqrt(self.num_observations_tokens))) - rec = self.tokenizer.decode(z, should_postprocess=True) # (B, C, H, W) - # TODO: for atari image - return torch.clamp(rec, 0, 1) - # for cartpole obs - # return rec - - - @torch.no_grad() - def render(self): - assert self.latent_state.shape == (1, self.num_observations_tokens) - return self.render_batch()[0] - - @torch.no_grad() - def reset(self) -> torch.FloatTensor: - assert self.env is not None - obs = torchvision.transforms.functional.to_tensor(self.env.reset()).to(self.device).unsqueeze( - 0) # (1, C, H, W) in [0., 1.] 
- return self.reset_from_initial_observations(obs) - - - @torch.no_grad() - # @profile - def reset_from_initial_observations_v2(self, obs_act_dict: torch.FloatTensor) -> torch.FloatTensor: - if isinstance(obs_act_dict, dict): - observations = obs_act_dict['obs'] - buffer_action = obs_act_dict['action'] - else: - observations = obs_act_dict - buffer_action = None - obs_embeddings = self.tokenizer.encode_to_obs_embeddings(observations, should_preprocess=True) # (B, C, H, W) -> (B, K, E) - - outputs_wm = self.refresh_keys_values_with_initial_latent_state_for_init_infer_v2(obs_embeddings, buffer_action) - self.latent_state = obs_embeddings - - return outputs_wm, self.latent_state - - @torch.no_grad() - # @profile - def refresh_keys_values_with_initial_latent_state_for_init_infer_v2(self, latent_state: torch.LongTensor, buffer_action=None) -> torch.FloatTensor: - n, num_observations_tokens, _ = latent_state.shape - if n <= self.env_num: - # MCTS root节点: 需要准确的估计 value, policy_logits, 或许需要结合context的kv_cache进行更准确的估计,而不是当前的从0开始推理 - self.keys_values_wm = self.transformer.generate_empty_keys_values(n, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False, kvcache_independent=False) - self.keys_values_wm_size_list = [1 for i in range(n)] - - # 复制单个环境对应的 keys_values_wm 并存储 - self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) - for i in range(latent_state.size(0)): # 遍历每个环境 - state_single_env = latent_state[i] # 获取单个环境的 latent state - # cache_key = hash(state_single_env.detach().cpu().numpy()) # 计算哈希值 - quantized_state = state_single_env.detach().cpu().numpy() - cache_key = quantize_state(quantized_state) # 使用量化后的状态计算哈希值 - for layer in range(self.num_layers): - self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = self.keys_values_wm._keys_values[layer]._k_cache._cache[i].unsqueeze(0) # shape torch.Size([2, 100, 512]) - self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = self.keys_values_wm._keys_values[layer]._v_cache._cache[i].unsqueeze(0) - self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = self.keys_values_wm._keys_values[layer]._k_cache._size - self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = self.keys_values_wm._keys_values[layer]._v_cache._size - # keys_values_wm_single_env[layer].update(self.keys_values_wm[layer]._k_cache._cache[i].unsqueeze(0), self.keys_values_wm[layer]._v_cache._cache[i].unsqueeze(0)) - self.root_total_query_cnt += 1 - if cache_key not in self.past_keys_values_cache: - self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu')) - else: - self.root_hit_cnt += 1 - root_hit_ratio = self.root_hit_cnt / self.root_total_query_cnt - print('root_total_query_cnt:', self.root_total_query_cnt) - print(f'root_hit_ratio:{root_hit_ratio}') - print(f'root_hit find size {self.past_keys_values_cache[cache_key].size}') - if self.past_keys_values_cache[cache_key].size>1: - print(f'=='*20) - print(f'NOTE: root_hit find size > 1') - print(f'=='*20) - - elif n == int(256): - # TODO: n=256 means train tokenizer, 不需要计算target value - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - # print('init inference: not find matched_value! 
reset!') - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) - elif n > self.env_num and n != int(256) and buffer_action is not None: - # train时计算target value - # TODO: n=256 means train tokenizer - # [192, 16, 64] -> [32, 6, 16, 64] - latent_state = latent_state.contiguous().view(buffer_action.shape[0], -1, num_observations_tokens, self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 - - # latent_state = latent_state.view(-1, self.config.max_blocks+1, num_observations_tokens) # (BL, K) - latent_state = latent_state[:, :-1, :] - # latent_state = latent_state.reshape(32*6, num_observations_tokens) # (BL, K) - buffer_action = torch.from_numpy(buffer_action).to(latent_state.device) - act_tokens = rearrange(buffer_action, 'b l -> b l 1') - - # # 选择每个样本的最后一步 - last_steps = act_tokens[:, -1:, :] # 这将选择最后一列并保持维度不变, 最后一步的target policy/value本身就没有用到 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - act_tokens = torch.cat((act_tokens, last_steps), dim=1) - - # print('init inference: unroll 5 steps!') 17*6=102 17*5=85 - obs_embeddings = latent_state - outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) - - # 选择每个样本的最后一步 - last_steps_value = outputs_wm.logits_value[:, -1:, :] # 这将选择最后一列并保持维度不变 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1) - - last_steps_policy = outputs_wm.logits_policy[:, -1:, :] # 这将选择最后一列并保持维度不变 - outputs_wm.logits_policy = torch.cat((outputs_wm.logits_policy, last_steps_policy), dim=1) - - # Reshape your tensors - # outputs_wm.logits_value.shape (30,21) = (B*6, 21) - outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e') - outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e') - - return outputs_wm - - @torch.no_grad() - # @profile - def reset_from_initial_observations(self, obs_act_dict: torch.FloatTensor) -> torch.FloatTensor: - if isinstance(obs_act_dict, dict): - # obs_act_dict = {'obs':obs, 'action':action_batch} - observations = obs_act_dict['obs'] - buffer_action = obs_act_dict['action'] - else: - observations = obs_act_dict - buffer_action = None - - obs_embeddings = self.tokenizer.encode_to_obs_embeddings(observations, should_preprocess=True) # (B, C, H, W) -> (B, K, E) - outputs_wm = self.refresh_keys_values_with_initial_latent_state_for_init_infer(obs_embeddings, buffer_action) - self.latent_state = obs_embeddings - - return outputs_wm, self.latent_state - - - @torch.no_grad() - # @profile - def refresh_keys_values_with_initial_latent_state_for_init_infer(self, latent_state: torch.LongTensor, buffer_action=None) -> torch.FloatTensor: - n, num_observations_tokens, _ = latent_state.shape - if n <= self.env_num: - # MCTS root节点: 需要准确的估计 value, policy_logits 或许需要结合context的kv_cache进行更准确的估计,而不是当前的从0开始推理 - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) # Note: is_root=False - elif n == int(256): - # TODO: n=256 means train tokenizer, 不需要计算target value - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - # print('init inference: not find matched_value! 
reset!') - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) # Note: is_root=False - elif n > self.env_num and n != int(256) and buffer_action is not None: - # TODO: n=256 means train tokenizer - # TODO: for n=32*6=192 means 通过unroll 5 steps,计算target value - # latent_state = latent_state.reshape(32, 6, num_observations_tokens) # (BL, K) - # latent_state = latent_state.view(-1, 6, num_observations_tokens) # (BL, K) - - # [192, 16] -> [32, 6, 16] - # latent_state = latent_state.view(buffer_action.shape[0], -1, num_observations_tokens) # (BL, K) for unroll_step=1 - - # [192, 16, 64] -> [32, 6, 16, 64] - latent_state = latent_state.contiguous().view(buffer_action.shape[0], -1, num_observations_tokens, self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 - - # latent_state = latent_state.view(-1, self.config.max_blocks+1, num_observations_tokens) # (BL, K) - latent_state = latent_state[:, :-1, :] - # latent_state = latent_state.reshape(32*6, num_observations_tokens) # (BL, K) - buffer_action = torch.from_numpy(buffer_action).to(latent_state.device) - act_tokens = rearrange(buffer_action, 'b l -> b l 1') - - # # 选择每个样本的最后一步 - last_steps = act_tokens[:, -1:, :] # 这将选择最后一列并保持维度不变, 最后一步的target policy/value本身就没有用到 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - act_tokens = torch.cat((act_tokens, last_steps), dim=1) - - # print('init inference: unroll 5 steps!') 17*6=102 17*5=85 - obs_embeddings = latent_state - outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) - - # 选择每个样本的最后一步 - last_steps_value = outputs_wm.logits_value[:, -1:, :] # 这将选择最后一列并保持维度不变 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1) - - last_steps_policy = outputs_wm.logits_policy[:, -1:, :] # 这将选择最后一列并保持维度不变 - outputs_wm.logits_policy = torch.cat((outputs_wm.logits_policy, last_steps_policy), dim=1) - - # Reshape your tensors - # outputs_wm.logits_value.shape (30,21) = (B*6, 21) - outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e') - outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e') - - - # return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends, logits_policy, logits_value) - return outputs_wm - - - @torch.no_grad() - # @profile - def refresh_keys_values_with_initial_latent_state(self, latent_state: torch.LongTensor, reset_indices=None) -> torch.FloatTensor: - n, num_observations_tokens, _ = latent_state.shape - assert num_observations_tokens == self.num_observations_tokens - # self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - if reset_indices is None: - self.keys_values_wm_list = [self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) for i in range(n)] - else: - for i in reset_indices: - self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': latent_state[i].unsqueeze(0).to(self.device)}, past_keys_values=self.keys_values_wm_single_env, is_root=False, kvcache_independent=False) - self.keys_values_wm_list[i] = self.keys_values_wm_single_env - self.keys_values_wm_size_list[i] = 1 - return None - - @torch.no_grad() - # @profile - def forward_initial_inference(self, obs_act_dict: torch.LongTensor, should_predict_next_obs: bool = True): - if 
isinstance(obs_act_dict, dict): - obs = obs_act_dict['obs'] - else: - obs = obs_act_dict - outputs_wm, latent_state = self.reset_from_initial_observations_v2(obs_act_dict) # TODO - # outputs_wm, latent_state = self.reset_from_initial_observations(obs_act_dict) # start from scratch - - return outputs_wm.output_sequence, latent_state, outputs_wm.logits_rewards, outputs_wm.logits_policy, outputs_wm.logits_value - - """ - Assume env_num=8. - The kv_cache of each of the 8 environments is stored and looked up independently, all kept in one dict. During recurrent_inference, - since the kv_caches found for different environments have different sizes, they are first zero-padded at the front according to the largest size and then assembled into a batch_size kv_cache; - internally the transformer forward inference is likewise executed as a batch. - """ - - @torch.no_grad() - # @profile - def forward_recurrent_inference(self, state_action_history, should_predict_next_obs: bool = True): - # In general, within one MCTS search we need to maintain a context of length H for transformer inference. - # Since the agent visits at most sim distinct nodes within one search, we only need to maintain a {(state: kv_cache)} list. - # If the environment is assumed to be an MDP, it suffices to look up this list with the current latest_state s_t. - # TODO: if the environment is assumed to be non-MDP, do we need to maintain a {(rootstate_action_history: kv_cache)} list instead? - - latest_state = state_action_history[-1][0] - - # Assume latest_state is the new latent_state, containing information for ready_env_num environments - ready_env_num = latest_state.shape[0] - self.keys_values_wm_list = [] - self.keys_values_wm_size_list = [] - for i in range(ready_env_num): - self.total_query_count += 1 - state_single_env = latest_state[i] # latent state of a single environment - cache_key = quantize_state(state_single_env) # compute the hash key from the quantized state - matched_value = self.past_keys_values_cache.get(cache_key) # retrieve the cached value - if matched_value is not None: - # if a matching value is found, add it to the list - self.hit_count += 1 - # deepcopy is needed here because the transformer forward modifies matched_value in place - self.keys_values_wm_list.append(copy.deepcopy(self.to_device_for_kvcache(matched_value, 'cuda'))) - self.keys_values_wm_size_list.append(matched_value.size) - else: - # use zero reset - self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': torch.from_numpy(state_single_env).unsqueeze(0).to(self.device)}, past_keys_values=self.keys_values_wm_single_env, is_root=False) - self.keys_values_wm_list.append(self.keys_values_wm_single_env) - self.keys_values_wm_size_list.append(1) - - num_passes = 1 + self.num_observations_tokens if should_predict_next_obs else 1 - output_sequence, latent_state = [], [] - - - # reset_indices = [index for index, value in enumerate(self.keys_values_wm_size_list) if value + num_passes > self.config.max_tokens] - # self.refresh_keys_values_with_initial_latent_state(torch.tensor(latest_state, dtype=torch.float32).to(self.device), reset_indices) - - action = state_action_history[-1][-1] - token = action.clone().detach() if isinstance(action, torch.Tensor) else torch.tensor(action, dtype=torch.long) - token = token.reshape(-1, 1).to(self.device) # (B, 1) - - # print(self.keys_values_wm_size_list) - # get the minimum value min_size of self.keys_values_wm_size_list - min_size = min(self.keys_values_wm_size_list) - if min_size >= self.config.max_tokens - 5: - self.length3_context_cnt += len(self.keys_values_wm_size_list) - if min_size >= 3: - self.length2_context_cnt += len(self.keys_values_wm_size_list) - # if max(self.keys_values_wm_size_list) == 7: - # print('max(self.keys_values_wm_size_list) == 7') - # if self.total_query_count>0 and self.total_query_count%1==0: - if self.total_query_count>0 and self.total_query_count%10000==0: - self.hit_freq = self.hit_count/(self.total_query_count) - # print('hit_freq:', self.hit_freq) - # print('hit_count:', self.hit_count) - print('total_query_count:',
self.total_query_count) - # 如果总查询次数大于0,计算并打印cnt的比率 - length3_context_cnt_ratio = self.length3_context_cnt / self.total_query_count - print('>=3 node context_cnt:', self.length3_context_cnt) - print('>=3 node context_cnt_ratio:', length3_context_cnt_ratio) - length2_context_cnt_ratio = self.length2_context_cnt / self.total_query_count - print('>=2 node context_cnt_ratio:', length2_context_cnt_ratio) - print('>=2 node context_cnt:', self.length2_context_cnt) - # print(self.keys_values_wm_size_list) - - for layer in range(self.num_layers): - # 每层的k和v缓存列表 - kv_cache_k_list = [] - kv_cache_v_list = [] - - for idx, keys_values in enumerate(self.keys_values_wm_list): - # 获取当前层的k和v缓存 - k_cache = keys_values[layer]._k_cache._cache - v_cache = keys_values[layer]._v_cache._cache - - # 获取当前缓存的有效尺寸 - effective_size = self.keys_values_wm_size_list[idx] - # 计算需要截去的尺寸 - trim_size = effective_size - min_size if effective_size > min_size else 0 - - # 如果需要截去部分,则截去前面的(trim_size)步 - if trim_size > 0: - k_cache_trimmed = k_cache[:, :, trim_size:, :] - v_cache_trimmed = v_cache[:, :, trim_size:, :] - # 在第三维之后补零 - k_cache_padded = F.pad(k_cache_trimmed, (0, 0, trim_size, 0), "constant", 0) - v_cache_padded = F.pad(v_cache_trimmed, (0, 0, trim_size, 0), "constant", 0) - else: - k_cache_padded = k_cache - v_cache_padded = v_cache - - # 将处理过的缓存添加到列表中 - kv_cache_k_list.append(k_cache_padded) - kv_cache_v_list.append(v_cache_padded) - - # 使用torch.stack()将列表中的缓存堆叠起来,形成新的缓存 - self.keys_values_wm._keys_values[layer]._k_cache._cache = torch.stack(kv_cache_k_list, dim=0).squeeze(1) - self.keys_values_wm._keys_values[layer]._v_cache._cache = torch.stack(kv_cache_v_list, dim=0).squeeze(1) - - # 更新缓存的尺寸为min_size - self.keys_values_wm._keys_values[layer]._k_cache._size = min_size - self.keys_values_wm._keys_values[layer]._v_cache._size = min_size - # del self.keys_values_wm_list - - for k in range(num_passes): # assumption that there is only one action token. - # action_token obs_token, ..., obs_token 1+1 - if k==0: - obs_embeddings_or_act_tokens = {'act_tokens': token} - else: - obs_embeddings_or_act_tokens = {'obs_embeddings': token} - - outputs_wm = self.forward(obs_embeddings_or_act_tokens, past_keys_values=self.keys_values_wm, is_root=False, kvcache_independent=False) - # if k==0, action_token self.head_observations 1,...,0,1 - output_sequence.append(outputs_wm.output_sequence) - - if k == 0: - # if k==0, token is action_token outputs_wm.logits_rewards 是有值的 - reward = outputs_wm.logits_rewards # (B,) - - if k < self.num_observations_tokens: - # 一共产生16个obs_token,每次产生一个 - # TODO: sample or argmax - # token = Categorical(logits=outputs_wm.logits_observations).sample() - # Use argmax to select the most likely token - # token = outputs_wm.logits_observations.argmax(-1, keepdim=True) - token = outputs_wm.logits_observations - # if len(token.shape) != 2: - # token = token.squeeze(-1) # Ensure the token tensor shape is (B, 1) - if len(token.shape) != 3: - token = token.unsqueeze(1) # (8,1024) -> (8,1,1024) - latent_state.append(token) - - output_sequence = torch.cat(output_sequence, dim=1) # (B, 1 + K, E) - # Before updating self.latent_state, delete the old one to free memory - del self.latent_state - self.latent_state = torch.cat(latent_state, dim=1) # (B, K) - - # # TODO: 在计算结束后,是否需要更新最新的缓存. 
是否需要deepcopy - for i in range(self.latent_state.size(0)): # 遍历每个环境 - state_single_env = self.latent_state[i] # 获取单个环境的 latent state - # cache_key = hash(state_single_env.detach().cpu().numpy()) # 计算哈希值 - quantized_state = state_single_env.detach().cpu().numpy() - cache_key = quantize_state(quantized_state) # 使用量化后的状态计算哈希值 - # 复制单个环境对应的 keys_values_wm 并存储 - for layer in range(self.num_layers): - self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = self.keys_values_wm._keys_values[layer]._k_cache._cache[i].unsqueeze(0) # shape torch.Size([2, 100, 512]) - self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = self.keys_values_wm._keys_values[layer]._v_cache._cache[i].unsqueeze(0) - self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = self.keys_values_wm._keys_values[layer]._k_cache._size - self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = self.keys_values_wm._keys_values[layer]._v_cache._size - # 比较并存储大小较大的 cache - if cache_key in self.past_keys_values_cache: - existing_kvcache = self.past_keys_values_cache[cache_key] - # 判断现有cache和新cache中的size是否有不同 - if self.keys_values_wm_single_env._keys_values[0]._k_cache._size > existing_kvcache.size and self.keys_values_wm_single_env._keys_values[0]._k_cache._size < self.config.max_tokens-1: - # 具有size为self.config.max_tokens-1 不用存,因为使用达到最大size的kv_cache,而这会导致reset,从而一直是从零开始的 - self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu')) - elif self.keys_values_wm_single_env._keys_values[0]._k_cache._size < self.config.max_tokens-1: - # 具有size为self.config.max_tokens-1 不用存,因为使用达到最大size的kv_cache,而这会导致reset,从而一直是从零开始的 - self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu')) - - # del self.keys_values_wm - # outputs_wm.logits_policy, outputs_wm.logits_value - if len(self.past_keys_values_cache) > self.max_cache_size: - # TODO: lru_cache - _, popped_kv_cache = self.past_keys_values_cache.popitem(last=False) - del popped_kv_cache # 不要这一行 - # print('len(self.past_keys_values_cache) > self.max_cache_size') - - # Example usage: - # Assuming `past_keys_values_cache` is a populated instance of `KeysValues` - # and `num_layers` is the number of transformer layers - # cuda_memory_gb = self.calculate_cuda_memory_gb(self.past_keys_values_cache, num_layers=2) - # print(f'len(self.past_keys_values_cache): {len(self.past_keys_values_cache)}, Memory used by past_keys_values_cache: {cuda_memory_gb:.2f} GB') - - return outputs_wm.output_sequence, self.latent_state, reward, outputs_wm.logits_policy, outputs_wm.logits_value - - - def to_device_for_kvcache(self, keys_values: KeysValues, device: str) -> KeysValues: - """ - Transfer all KVCache objects within the KeysValues object to a certain device. - - Arguments: - - keys_values (KeysValues): The KeysValues object to be transferred. - - device (str): The device to transfer to. - - Returns: - - keys_values (KeysValues): The KeysValues object with its caches transferred to the specified device. 
- """ - # Check if CUDA is available and select the first available CUDA device - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - - for kv_cache in keys_values: - kv_cache._k_cache._cache = kv_cache._k_cache._cache.to(device) - kv_cache._v_cache._cache = kv_cache._v_cache._cache.to(device) - return keys_values - - - # 计算显存使用量的函数 - def calculate_cuda_memory_gb(self, past_keys_values_cache, num_layers: int): - total_memory_bytes = 0 - - # 遍历OrderedDict中所有的KeysValues实例 - for kv_instance in past_keys_values_cache.values(): - num_layers = len(kv_instance) # 获取层数 - for layer in range(num_layers): - kv_cache = kv_instance[layer] - k_shape = kv_cache._k_cache.shape # 获取keys缓存的形状 - v_shape = kv_cache._v_cache.shape # 获取values缓存的形状 - - # 计算元素个数并乘以每个元素的字节数 - k_memory = torch.prod(torch.tensor(k_shape)) * 4 - v_memory = torch.prod(torch.tensor(v_shape)) * 4 - - # 累加keys和values缓存的内存 - layer_memory = k_memory + v_memory - total_memory_bytes += layer_memory.item() # .item()确保转换为Python标准数字 - - # 将总内存从字节转换为吉字节 - total_memory_gb = total_memory_bytes / (1024 ** 3) - return total_memory_gb - - # @profile - def compute_loss(self, batch, target_tokenizer: Tokenizer=None, **kwargs: Any) -> LossWithIntermediateLosses: - # NOTE: 这里是需要梯度的 - #with torch.no_grad(): # TODO: 非常重要 - obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False) # (B, C, H, W) -> (B, K, E) - obs_embeddings.register_hook(lambda grad: grad * 1/5) # TODO:只提供重建损失更新表征网络 - # obs_embeddings.register_hook(lambda grad: grad * 1) # TODO:只提供重建损失更新表征网络 - - # Assume that 'cont_embeddings' and 'original_images' are available from prior code - # Decode the embeddings to reconstruct the images - reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) - # Calculate the reconstruction loss - # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 4, 64, 64), reconstructed_images) # TODO: for stack=4 - # perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) # for stack=4 gray obs - - latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # TODO: for stack=1 - perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # TODO: for stack=1 - - # latent_recon_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) - # perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) - - latent_kl_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) - - - act_tokens = rearrange(batch['actions'], 'b l -> b l 1') - - # TODO: 是否只用重建损失更新表征网络 非常重要 - outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) - # outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens)}, is_root=False) - - with torch.no_grad(): - traget_obs_embeddings = target_tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False) # (B, C, H, W) -> (B, K, E) - - labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(traget_obs_embeddings, batch['rewards'], - batch['ends'], - batch['mask_padding']) - # labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(obs_embeddings, batch['rewards'], - # batch['ends'], - # 
batch['mask_padding']) - - logits_observations = rearrange(outputs.logits_observations[:, :-1], 'b t o -> (b t) o') - # labels_observations = labels_observations.contiguous().view(-1, self.projection_input_dim) # TODO: - labels_observations = labels_observations.reshape(-1, self.projection_input_dim) # TODO: - - - loss_obs = torch.nn.functional.mse_loss(logits_observations, labels_observations.detach(), reduction='none').mean(-1) - mask_padding_expanded = batch['mask_padding'][:, 1:].contiguous().view(-1) # TODO: - # mask_padding_expanded = batch['mask_padding'][:, 1:].reshape(-1) - - # 应用mask到loss_obs - loss_obs = (loss_obs * mask_padding_expanded).mean(-1) - labels_policy, labels_value = self.compute_labels_world_model_value_policy(batch['target_value'], - batch['target_policy'], - batch['mask_padding']) - - loss_rewards = self.compute_cross_entropy_loss(outputs, labels_rewards, batch, element='rewards') - loss_policy = self.compute_cross_entropy_loss(outputs, labels_policy, batch, element='policy') - loss_value = self.compute_cross_entropy_loss(outputs, labels_value, batch, element='value') - - return LossWithIntermediateLosses(latent_recon_loss_weight=self.latent_recon_loss_weight, perceptual_loss_weight=self.perceptual_loss_weight, loss_obs=loss_obs, loss_rewards=loss_rewards, loss_value=loss_value, - loss_policy=loss_policy, latent_kl_loss=latent_kl_loss, latent_recon_loss=latent_recon_loss, perceptual_loss=perceptual_loss) - - def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): - # Assume outputs.logits_rewards and labels are your predictions and targets - # And mask_padding is a boolean tensor with True at positions to keep and False at positions to ignore - - if element == 'rewards': - logits = outputs.logits_rewards - elif element == 'policy': - logits = outputs.logits_policy - elif element == 'value': - logits = outputs.logits_value - - # Reshape your tensors - logits_rewards = rearrange(logits, 'b t e -> (b t) e') - labels = labels.reshape(-1, labels.shape[-1]) # Assuming labels originally has shape [b, t, reward_dim] - - # Reshape your mask. True means valid data. 
- mask_padding = rearrange(batch['mask_padding'], 'b t -> (b t)') - - loss_rewards = -(torch.log_softmax(logits_rewards, dim=1) * labels).sum(1) - # loss_rewards = (loss_rewards * mask_padding.squeeze(-1).float()).mean() - loss_rewards = (loss_rewards * mask_padding).mean() - - - return loss_rewards - - def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, - mask_padding: torch.BoolTensor) -> Tuple[ - torch.Tensor, torch.Tensor, torch.Tensor]: - assert torch.all(ends.sum(dim=1) <= 1) # each sequence sample has at most 1 done - mask_fill = torch.logical_not(mask_padding) - labels_observations = obs_embeddings.contiguous().view(rewards.shape[0], -1, self.projection_input_dim)[:, 1:] # self.projection_input_dim - - - # labels_rewards = (rewards.sign() + 1).masked_fill(mask_fill, -100).long() # Rewards clipped to {-1, 0, 1} TODO(pu) - - mask_fill_rewards = mask_fill.unsqueeze(-1).expand_as(rewards) - labels_rewards = rewards.masked_fill(mask_fill_rewards, -100) - - labels_ends = ends.masked_fill(mask_fill, -100) - return labels_observations, labels_rewards.reshape(-1, self.support_size), labels_ends.reshape(-1) - - def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, target_policy: torch.Tensor, - mask_padding: torch.BoolTensor) -> Tuple[ - torch.Tensor, torch.Tensor]: - - mask_fill = torch.logical_not(mask_padding) - mask_fill_policy = mask_fill.unsqueeze(-1).expand_as(target_policy) - labels_policy = target_policy.masked_fill(mask_fill_policy, -100) - - mask_fill_value = mask_fill.unsqueeze(-1).expand_as(target_value) - labels_value = target_value.masked_fill(mask_fill_value, -100) - return labels_policy.reshape(-1, self.action_shape), labels_value.reshape(-1, self.support_size) # TODO(pu) - - - def negative_cosine_similarity(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: - """ - Overview: - consistency loss function: the negative cosine similarity. - Arguments: - - x1 (:obj:`torch.Tensor`): shape (batch_size, dim), e.g. (256, 512) - - x2 (:obj:`torch.Tensor`): shape (batch_size, dim), e.g. (256, 512) - Returns: - (x1 * x2).sum(dim=1) is the cosine similarity between vector x1 and x2. - The cosine similarity always belongs to the interval [-1, 1]. - For example, two proportional vectors have a cosine similarity of 1, - two orthogonal vectors have a similarity of 0, - and two opposite vectors have a similarity of -1. - -(x1 * x2).sum(dim=1) is consistency loss, i.e. the negative cosine similarity. 
- Reference: - https://en.wikipedia.org/wiki/Cosine_similarity - """ - x1 = F.normalize(x1, p=2., dim=-1, eps=1e-5) - x2 = F.normalize(x2, p=2., dim=-1, eps=1e-5) - return -(x1 * x2).sum(dim=1) - - - def render_img(self, obs: int, rec_img: int): - import torch - from PIL import Image - import matplotlib.pyplot as plt - - # 假设batch是一个字典,其中包含了observations键, - # 并且它的形状是torch.Size([B, N, C, H, W]) - # batch_observations = batch_for_gpt['observations'] - # batch_observations = batch['observations'] - batch_observations = obs.unsqueeze(0) - # batch_observations = rec_img.unsqueeze(0) - - # batch_observations = observations.unsqueeze(0) - # batch_observations = x.unsqueeze(0) - # batch_observations = reconstructions.unsqueeze(0) - - - - B, N, C, H, W = batch_observations.shape # 自动检测维度 - - # 分隔条的宽度(可以根据需要调整) - separator_width = 2 - - # 遍历每个样本 - for i in range(B): - # 提取当前样本中的所有帧 - frames = batch_observations[i] - - # 计算拼接图像的总宽度(包括分隔条) - total_width = N * W + (N - 1) * separator_width - - # 创建一个新的图像,其中包含分隔条 - concat_image = Image.new('RGB', (total_width, H), color='black') - - # 拼接每一帧及分隔条 - for j in range(N): - frame = frames[j].permute(1, 2, 0).cpu().numpy() # 转换为(H, W, C) - frame_image = Image.fromarray((frame * 255).astype('uint8'), 'RGB') - - # 计算当前帧在拼接图像中的位置 - x_position = j * (W + separator_width) - concat_image.paste(frame_image, (x_position, 0)) - - # 显示图像 - plt.imshow(concat_image) - plt.title(f'Sample {i+1}') - plt.axis('off') # 关闭坐标轴显示 - plt.show() - - # 保存图像到文件 - concat_image.save(f'sample_{i+1}.png') \ No newline at end of file diff --git a/lzero/model/gpt_models/world_model_envnum1_kv-latent-1-env.py b/lzero/model/gpt_models/world_model_envnum1_kv-latent-1-env.py deleted file mode 100644 index 15b988f5f..000000000 --- a/lzero/model/gpt_models/world_model_envnum1_kv-latent-1-env.py +++ /dev/null @@ -1,1009 +0,0 @@ -import copy -from dataclasses import dataclass -import random -from typing import Any, Optional, Tuple -from typing import List, Optional, Union -import logging -# 设置日志记录级别为DEBUG -logging.getLogger().setLevel(logging.DEBUG) -from PIL import Image -from einops import rearrange -from einops import rearrange -import gym -from joblib import hash -import numpy as np -import torch -import torch -from torch.distributions.categorical import Categorical -import torch.nn as nn -import torch.nn.functional as F -import torchvision - -from .kv_caching import KeysValues -from .slicer import Embedder, Head, ActEmbedder -from .tokenizer import Tokenizer -from .transformer import Transformer, TransformerConfig -from .utils import LossWithIntermediateLosses, init_weights -from ding.torch_utils import to_device -# from memory_profiler import profile -from line_profiler import line_profiler - -@dataclass -class WorldModelOutput: - output_sequence: torch.FloatTensor - logits_observations: torch.FloatTensor - logits_rewards: torch.FloatTensor - logits_ends: torch.FloatTensor - - logits_policy: torch.FloatTensor - logits_value: torch.FloatTensor - - -class WorldModel(nn.Module): - def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: TransformerConfig, tokenizer, representation_network=None) -> None: - super().__init__() - - # config.max_tokens = int(2*50) # TODO - - self.tokenizer = tokenizer - self.obs_vocab_size, self.act_vocab_size = obs_vocab_size, act_vocab_size - self.config = config - - self.transformer = Transformer(config) - # self.num_observations_tokens = 16 - self.num_observations_tokens = config.tokens_per_block -1 - - self.latent_recon_loss_weight = 
config.latent_recon_loss_weight - self.perceptual_loss_weight = config.perceptual_loss_weight - - self.device = config.device - self.support_size = config.support_size - self.action_shape = config.action_shape - self.max_cache_size = config.max_cache_size - self.env_num = config.env_num - - - all_but_last_obs_tokens_pattern = torch.ones(config.tokens_per_block) - all_but_last_obs_tokens_pattern[-2] = 0 # 1,...,0,1 - - act_tokens_pattern = torch.zeros(self.config.tokens_per_block) # 17 - act_tokens_pattern[-1] = 1 # 0,...,0,1 - obs_tokens_pattern = 1 - act_tokens_pattern # 1,...,1,0 - - # current latent state's policy value - value_policy_tokens_pattern = torch.zeros(config.tokens_per_block) - value_policy_tokens_pattern[-2] = 1 # [0,...,1,0] - - # next latent state's policy value - # value_policy_tokens_pattern = torch.zeros(config.tokens_per_block) - # value_policy_tokens_pattern[-1] = 1 # [0,...,0,1] - - obs_per_embdding_dim=config.embed_dim - - self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim) - - - self.embedder = Embedder( - max_blocks=config.max_blocks, - block_masks=[act_tokens_pattern, obs_tokens_pattern], - embedding_tables=nn.ModuleList([nn.Embedding(act_vocab_size, config.embed_dim), nn.Embedding(obs_vocab_size, config.embed_dim)]) - ) - - # self.act_embedder = ActEmbedder( - # max_blocks=config.max_blocks, - # block_masks=[act_tokens_pattern], - # embedding_tables=nn.ModuleList([nn.Embedding(act_vocab_size, config.embed_dim)]) - # ) - - self.act_embedding_table = nn.Embedding(act_vocab_size, config.embed_dim) - - # self.head_observations = Head( - # max_blocks=config.max_blocks, - # block_mask=all_but_last_obs_tokens_pattern, # 1,...,0,1 # https://github.com/eloialonso/iris/issues/19 - # head_module=nn.Sequential( - # nn.Linear(config.embed_dim, config.embed_dim), - # nn.ReLU(), - # nn.Linear(config.embed_dim, obs_vocab_size) - # ) - # ) - self.obs_per_embdding_dim = config.embed_dim # 16*64=1024 - self.head_observations = Head( # TODO - max_blocks=config.max_blocks, - block_mask=all_but_last_obs_tokens_pattern, # 1,...,0,1 # https://github.com/eloialonso/iris/issues/19 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - # nn.BatchNorm1d(config.embed_dim), - # nn.ReLU(), - # nn.Linear(config.embed_dim, obs_vocab_size) - nn.LeakyReLU(negative_slope=0.01), # TODO: 2 - nn.Linear(config.embed_dim, self.obs_per_embdding_dim), - # nn.Tanh(), # TODO - nn.Sigmoid(), # 这里添加Sigmoid函数 TODO - ) - ) - - self.head_rewards = Head( - max_blocks=config.max_blocks, - block_mask=act_tokens_pattern, # 0,...,0,1 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - nn.ReLU(), - nn.Linear(config.embed_dim, self.support_size) - ) - ) - - self.head_ends = Head( - max_blocks=config.max_blocks, - block_mask=act_tokens_pattern, # 0,...,0,1 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - nn.ReLU(), - nn.Linear(config.embed_dim, 2) - ) - ) - - self.head_observations_for_root = Head( # TODO - max_blocks=config.max_blocks, - block_mask=obs_tokens_pattern, # 1,...,1,0 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - nn.BatchNorm1d(config.embed_dim), - nn.ReLU(), - # nn.Linear(config.embed_dim, obs_vocab_size) - nn.Linear(config.embed_dim, obs_per_embdding_dim) - ) - ) - self.head_policy = Head( - max_blocks=config.max_blocks, - block_mask=value_policy_tokens_pattern, # TODO: value_policy_tokens_pattern # [0,...,1,0] - head_module=nn.Sequential( # (8, 5, 128) - # 
nn.BatchNorm1d(config.embed_dim), # TODO: 1 - nn.Linear(config.embed_dim, config.embed_dim), - # nn.ReLU(), - nn.LeakyReLU(negative_slope=0.01), # TODO: 2 - nn.Linear(config.embed_dim, self.action_shape) # TODO(pu); action shape - ) - ) - self.head_value = Head( - max_blocks=config.max_blocks, - block_mask=value_policy_tokens_pattern, - head_module=nn.Sequential( - # nn.BatchNorm1d(config.embed_dim), # TODO: 1 - nn.Linear(config.embed_dim, config.embed_dim), - # nn.ReLU(), - nn.LeakyReLU(negative_slope=0.01), # TODO: 2 - nn.Linear(config.embed_dim, self.support_size) # TODO(pu): action shape - ) - ) - - self.apply(init_weights) - - last_linear_layer_init_zero = True # TODO: is beneficial for convergence speed. - if last_linear_layer_init_zero: - # TODO: policy init : 3 - # Locate the last linear layer and initialize its weights and biases to 0. - # for _, layer in enumerate(reversed(self.head_policy.head_module)): - # if isinstance(layer, nn.Linear): - # nn.init.zeros_(layer.weight) - # nn.init.zeros_(layer.bias) - # break - for _, layer in enumerate(reversed(self.head_value.head_module)): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - nn.init.zeros_(layer.bias) - break - for _, layer in enumerate(reversed(self.head_rewards.head_module)): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - nn.init.zeros_(layer.bias) - break - for _, layer in enumerate(reversed(self.head_observations.head_module)): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - nn.init.zeros_(layer.bias) - break - - - import collections - self.past_keys_values_cache = collections.OrderedDict() - self.past_policy_value_cache = collections.OrderedDict() - - # TODO: Transformer更新后应该清除缓存 - # NOTE - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=8, max_tokens=self.config.max_tokens) - - if self.num_observations_tokens==16: # k=16 - self.projection_input_dim = 128 - elif self.num_observations_tokens==1: # K=1 - self.projection_input_dim = self.obs_per_embdding_dim# for atari #TODO - # self.projection_input_dim = 256 # for cartpole - - - self.proj_hid = 1024 - self.proj_out = 1024 - self.pred_hid = 512 - self.pred_out = 1024 - activation = nn.ReLU() - self.projection = nn.Sequential( - nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, - nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, - nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out) - ) - self.prediction_head = nn.Sequential( - nn.Linear(self.proj_out, self.pred_hid), - nn.BatchNorm1d(self.pred_hid), - activation, - nn.Linear(self.pred_hid, self.pred_out), - ) - self.hit_count = 0 - self.total_query_count = 0 - - - - def __repr__(self) -> str: - return "world_model" - - - # def forward(self, tokens: torch.LongTensor, past_keys_values: Optional[KeysValues] = None, - # is_root=False) -> WorldModelOutput: - # def forward(self, obs_embeddings, act_tokens, past_keys_values: Optional[KeysValues] = None, - # is_root=False) -> WorldModelOutput: - # @profile - def forward(self, obs_embeddings_or_act_tokens, past_keys_values: Optional[KeysValues] = None, - is_root=False) -> WorldModelOutput: - - prev_steps = 0 if past_keys_values is None else past_keys_values.size - # print(f'prev_steps:{prev_steps}') - - # sequences = self.embedder(tokens, num_steps, prev_steps) + self.pos_emb(prev_steps + torch.arange(num_steps, device=tokens.device)) - if 'obs_embeddings' in obs_embeddings_or_act_tokens.keys(): - 
obs_embeddings = obs_embeddings_or_act_tokens['obs_embeddings'] - num_steps = obs_embeddings.size(1) # (B, T, E) - # if prev_steps>0: - # prev_steps = prev_steps+1 # TODO: NOTE: 在collect的每一步,执行init_infer时,不reset kv_cache - sequences = obs_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=obs_embeddings.device)) - elif 'act_tokens' in obs_embeddings_or_act_tokens.keys(): - act_tokens = obs_embeddings_or_act_tokens['act_tokens'] - num_steps = act_tokens.size(1) # (B, T) - # act_embeddings = self.act_embedder(act_tokens, num_steps, prev_steps) - act_embeddings = self.act_embedding_table(act_tokens) - - sequences = act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=act_tokens.device)) - else: - obs_embeddings_and_act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens'] - # obs_embeddings: (B, L, K=16, E), act_tokens: (B, L, 1) - obs_embeddings, act_tokens = obs_embeddings_and_act_tokens - if len(obs_embeddings.shape)==3: # for batch compute loss - obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, -1) - - num_steps = int(obs_embeddings.size(1)*(obs_embeddings.size(2)+1)) # L(k+1) - # assert num_steps <= self.config.max_tokens - # Rearrange observation embeddings from (B, L, K, E) to (B, L*K, E) - # obs_embeddings = rearrange(obs_embeddings, 'b l k e -> b (l k) e') - - # Generate action embeddings from action tokens - # (B, L, 1) -> (B, L, 1, E) - act_embeddings = self.act_embedding_table(act_tokens) - - # 已知obs_embeddings的维度为 (B, L, K, E), act_embeddings的维度为(B, L, 1, E) 希望得到一个obs_act_embeddings向量的维度为 (B, L(K+1), E) - # 而且让得到的obs_act_embeddings的第2个维度的数据为:obs act, obs, act, ..., obs, act,即 L, 1, L,1 ... 这样的排列顺序。请给出高效的实现,用中文回答 - - B, L, K, E = obs_embeddings.size() - # _, _, _, _ = act_embeddings.size() - - # 初始化一个新的空tensor,用于存放最终的拼接结果 - obs_act_embeddings = torch.empty(B, L * (K + 1), E, device=obs_embeddings.device) - - # 对每一个序列长度L进行循环 - for i in range(L): - # 获取当前时刻的obs和act embeddings - obs = obs_embeddings[:, i, :, :] # Shape: (B, K, E) - act = act_embeddings[:, i, 0, :].unsqueeze(1) # Shape: (B, 1, E), 补充维度以便拼接 - - # 交替拼接obs和act - obs_act = torch.cat([obs, act], dim=1) # Shape: (B, K + 1, E) - - # 将结果填充到最终的tensor中 - obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act - - # 确保形状正确无误 - # assert obs_act_embeddings.shape == (B, L * (K + 1), E) - - # Add positional embeddings - sequences = obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=obs_embeddings.device)) - - - # print('transformer forward begin') 函数里面更新了update past_keys_values - x = self.transformer(sequences, past_keys_values) - # print('transformer forward done') - - if is_root: - logits_observations = self.head_observations_for_root(x, num_steps=num_steps, prev_steps=prev_steps) - else: - # 1,...,0,1 https://github.com/eloialonso/iris/issues/19 - logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps) - - logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps) - logits_ends = self.head_ends(x, num_steps=num_steps, prev_steps=prev_steps) - # return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends) - - logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps) - logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps) - - # TODO: root reward value - return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends, logits_policy, logits_value) - - 
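The per-step loop in the forward branch above builds the interleaved (B, L*(K+1), E) sequence of observation and action embeddings. A minimal standalone sketch of the same ordering, assuming illustrative shapes and not taken from the deleted file, shows that one concatenation plus a reshape produces the identical obs, ..., obs, act layout per step:

import torch

# Illustrative shapes only: B batch, L unroll steps, K observation tokens per step, E embedding dim.
B, L, K, E = 2, 3, 16, 64
obs_embeddings = torch.randn(B, L, K, E)  # per-step observation token embeddings
act_embeddings = torch.randn(B, L, 1, E)  # per-step action embedding

# Concatenate along the per-step token axis, then flatten the step axis.
# Each step contributes [obs_1, ..., obs_K, act], so the result has shape (B, L*(K+1), E).
obs_act_embeddings = torch.cat([obs_embeddings, act_embeddings], dim=2).view(B, L * (K + 1), E)
assert obs_act_embeddings.shape == (B, L * (K + 1), E)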
@torch.no_grad() - def render_batch(self) -> List[Image.Image]: - frames = self.decode_obs_tokens().detach().cpu() - frames = rearrange(frames, 'b c h w -> b h w c').mul(255).numpy().astype(np.uint8) - return [Image.fromarray(frame) for frame in frames] - - # only foe inference now, now is invalid - @torch.no_grad() - def decode_obs_tokens(self) -> List[Image.Image]: - embedded_tokens = self.tokenizer.embedding(self.obs_tokens) # (B, K, E) - z = rearrange(embedded_tokens, 'b (h w) e -> b e h w', h=int(np.sqrt(self.num_observations_tokens))) - rec = self.tokenizer.decode(z, should_postprocess=True) # (B, C, H, W) - # TODO: for atari image - return torch.clamp(rec, 0, 1) - # for cartpole obs - # return rec - - - @torch.no_grad() - def render(self): - assert self.obs_tokens.shape == (1, self.num_observations_tokens) - return self.render_batch()[0] - - @torch.no_grad() - def reset(self) -> torch.FloatTensor: - assert self.env is not None - obs = torchvision.transforms.functional.to_tensor(self.env.reset()).to(self.device).unsqueeze( - 0) # (1, C, H, W) in [0., 1.] - return self.reset_from_initial_observations(obs) - - - @torch.no_grad() - # @profile - def reset_from_initial_observations_v2(self, obs_act_dict: torch.FloatTensor) -> torch.FloatTensor: - if isinstance(obs_act_dict, dict): - observations = obs_act_dict['obs'] - buffer_action = obs_act_dict['action'] - else: - observations = obs_act_dict - buffer_action = None - obs_embeddings = self.tokenizer.encode_to_obs_embeddings(observations, should_preprocess=True) # (B, C, H, W) -> (B, K, E) - - outputs_wm = self.refresh_keys_values_with_initial_obs_tokens_for_init_infer_v2(obs_embeddings, buffer_action) - self.obs_tokens = obs_embeddings - - return outputs_wm, self.obs_tokens - - @torch.no_grad() - # @profile - def refresh_keys_values_with_initial_obs_tokens_for_init_infer_v2(self, obs_tokens: torch.LongTensor, buffer_action=None) -> torch.FloatTensor: - n, num_observations_tokens, _ = obs_tokens.shape - - if n <= self.env_num: - # MCTS root节点: 需要准确的估计 value, policy_logits 或许需要结合context的kv_cache进行更准确的估计,而不是当前的从0开始推理 - # self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': obs_tokens}, past_keys_values=self.keys_values_wm, is_root=False) # Note: is_root=False - self.total_query_count += 1 - # Compute the hash of latest_state - hash_latest_state = hash(obs_tokens.detach().cpu().numpy()) - matched_value = self.past_keys_values_cache.get(hash_latest_state) - if matched_value is not None: - self.keys_values_wm_find = copy.deepcopy(self.to_device_for_kvcache(matched_value, 'cuda') ) - self.hit_count += 1 - # self.total_query_count += 1 - # print('recurrent_inference:find matched_value!') - # NOTE: very important, 相当于policy value由单步计算得到,往后的推理,基于context - # TODO: policy value也从缓存中找 - self.keys_values_wm = self.keys_values_wm_find - - elif n == int(256): - # TODO: n=256 means train tokenizer, 不需要计算target value - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - # print('init inference: not find matched_value! 
reset!') - outputs_wm = self.forward({'obs_embeddings': obs_tokens}, past_keys_values=self.keys_values_wm, is_root=False) # Note: is_root=False - elif n > self.env_num and n != int(256) and buffer_action is not None: - # TODO: n=256 means train tokenizer - # [192, 16, 64] -> [32, 6, 16, 64] - obs_tokens = obs_tokens.contiguous().view(buffer_action.shape[0], -1, num_observations_tokens, self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 - - # obs_tokens = obs_tokens.view(-1, self.config.max_blocks+1, num_observations_tokens) # (BL, K) - obs_tokens = obs_tokens[:, :-1, :] - # obs_tokens = obs_tokens.reshape(32*6, num_observations_tokens) # (BL, K) - buffer_action = torch.from_numpy(buffer_action).to(obs_tokens.device) - act_tokens = rearrange(buffer_action, 'b l -> b l 1') - - # # 选择每个样本的最后一步 - last_steps = act_tokens[:, -1:, :] # 这将选择最后一列并保持维度不变, 最后一步的target policy/value本身就没有用到 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - act_tokens = torch.cat((act_tokens, last_steps), dim=1) - - # print('init inference: unroll 5 steps!') 17*6=102 17*5=85 - obs_embeddings = obs_tokens - outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) - - # 选择每个样本的最后一步 - last_steps_value = outputs_wm.logits_value[:, -1:, :] # 这将选择最后一列并保持维度不变 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1) - - last_steps_policy = outputs_wm.logits_policy[:, -1:, :] # 这将选择最后一列并保持维度不变 - outputs_wm.logits_policy = torch.cat((outputs_wm.logits_policy, last_steps_policy), dim=1) - - # Reshape your tensors - # outputs_wm.logits_value.shape (30,21) = (B*6, 21) - outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e') - outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e') - - return outputs_wm - - @torch.no_grad() - # @profile - def reset_from_initial_observations(self, obs_act_dict: torch.FloatTensor) -> torch.FloatTensor: - if isinstance(obs_act_dict, dict): - # obs_act_dict = {'obs':obs, 'action':action_batch} - observations = obs_act_dict['obs'] - buffer_action = obs_act_dict['action'] - else: - observations = obs_act_dict - buffer_action = None - - # NOTE: should_preprocess=True is important - # obs_tokens = self.tokenizer.encode(observations, should_preprocess=True).tokens # (B, C, H, W) -> (B, K) - # _, num_observations_tokens = obs_tokens.shape - - obs_embeddings = self.tokenizer.encode_to_obs_embeddings(observations, should_preprocess=True) # (B, C, H, W) -> (B, K, E) - # num_observations_tokens = obs_embeddings.shape[1] - - # if self.num_observations_tokens is None: - # self._num_observations_tokens = num_observations_tokens - - # outputs_wm = self.refresh_keys_values_with_initial_obs_tokens_for_init_infer(obs_tokens, buffer_action) - outputs_wm = self.refresh_keys_values_with_initial_obs_tokens_for_init_infer(obs_embeddings, buffer_action) - - self.obs_tokens = obs_embeddings - - # return outputs_wm, self.decode_obs_tokens(), self.obs_tokens - return outputs_wm, self.obs_tokens - - - @torch.no_grad() - # @profile - def refresh_keys_values_with_initial_obs_tokens_for_init_infer(self, obs_tokens: torch.LongTensor, buffer_action=None) -> torch.FloatTensor: - n, num_observations_tokens, _ = obs_tokens.shape - # assert num_observations_tokens == self.num_observations_tokens - # self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - - if n <= self.env_num: - # Compute the hash of 
obs_tokens - # cache_key = hash(obs_tokens.detach().cpu().numpy()) - # # Try to get the value associated with the hash of latest_state - # matched_value = self.past_keys_values_cache.get(cache_key) - # if matched_value is not None: - # # If a matching value is found, do something with it - # self.keys_values_wm = copy.deepcopy(matched_value) - # print('init inference: find matched_value!') - # else: - # self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - # # print('init inference: not find matched_value! reset!') - - # MCTS root节点: 需要准确的估计 value, policy_logits 或许需要结合context的kv_cache进行更准确的估计,而不是当前的从0开始推理 - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': obs_tokens}, past_keys_values=self.keys_values_wm, is_root=False) # Note: is_root=False - elif n == int(256): - # TODO: n=256 means train tokenizer, 不需要计算target value - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - # print('init inference: not find matched_value! reset!') - outputs_wm = self.forward({'obs_embeddings': obs_tokens}, past_keys_values=self.keys_values_wm, is_root=False) # Note: is_root=False - # elif n > self.env_num and n != int(256) and buffer_action is not None: - # # transformer只能unroll 5步 - # # TODO: n=256 means train tokenizer - # # TODO: for n=32*6=192 means 通过unroll 5 steps,计算target value - # # [192, 16, 64] -> [32, 6, 16, 64] - # obs_tokens = obs_tokens.contiguous().view(buffer_action.shape[0], -1, num_observations_tokens, self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 - # buffer_action = torch.from_numpy(buffer_action).to(obs_tokens.device) - # act_tokens = rearrange(buffer_action, 'b l -> b l 1') - # # 将5步动作的最后一步,重复一次,以拼接为6步的动作 - # act_tokens = torch.cat((act_tokens, act_tokens[:, -1:, :]), dim=1) - # obs_embeddings = obs_tokens - # outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) - # # Reshape your tensors - # # outputs_wm.logits_value.shape (30,21) = (B*6, 21) - # outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e') - # outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e') - elif n > self.env_num and n != int(256) and buffer_action is not None: - # TODO: n=256 means train tokenizer - # TODO: for n=32*6=192 means 通过unroll 5 steps,计算target value - # obs_tokens = obs_tokens.reshape(32, 6, num_observations_tokens) # (BL, K) - # obs_tokens = obs_tokens.view(-1, 6, num_observations_tokens) # (BL, K) - - # [192, 16] -> [32, 6, 16] - # obs_tokens = obs_tokens.view(buffer_action.shape[0], -1, num_observations_tokens) # (BL, K) for unroll_step=1 - - # [192, 16, 64] -> [32, 6, 16, 64] - obs_tokens = obs_tokens.contiguous().view(buffer_action.shape[0], -1, num_observations_tokens, self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 - - # obs_tokens = obs_tokens.view(-1, self.config.max_blocks+1, num_observations_tokens) # (BL, K) - obs_tokens = obs_tokens[:, :-1, :] - # obs_tokens = obs_tokens.reshape(32*6, num_observations_tokens) # (BL, K) - buffer_action = torch.from_numpy(buffer_action).to(obs_tokens.device) - act_tokens = rearrange(buffer_action, 'b l -> b l 1') - - # # 选择每个样本的最后一步 - last_steps = act_tokens[:, -1:, :] # 这将选择最后一列并保持维度不变, 最后一步的target policy/value本身就没有用到 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - act_tokens = torch.cat((act_tokens, last_steps), dim=1) - - # 
print('init inference: unroll 5 steps!') 17*6=102 17*5=85 - obs_embeddings = obs_tokens - outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) - - # 选择每个样本的最后一步 - last_steps_value = outputs_wm.logits_value[:, -1:, :] # 这将选择最后一列并保持维度不变 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1) - - last_steps_policy = outputs_wm.logits_policy[:, -1:, :] # 这将选择最后一列并保持维度不变 - outputs_wm.logits_policy = torch.cat((outputs_wm.logits_policy, last_steps_policy), dim=1) - - # Reshape your tensors - # outputs_wm.logits_value.shape (30,21) = (B*6, 21) - outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e') - outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e') - - - # return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends, logits_policy, logits_value) - return outputs_wm - - - @torch.no_grad() - # @profile - def refresh_keys_values_with_initial_obs_tokens(self, obs_tokens: torch.LongTensor) -> torch.FloatTensor: - n, num_observations_tokens, _ = obs_tokens.shape - assert num_observations_tokens == self.num_observations_tokens - # self.keys_values_wm = self.world_model.transformer.generate_empty_keys_values(n=n, max_tokens=self.world_model.config.max_tokens) - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - obs_embeddings_or_act_tokens = {'obs_embeddings': obs_tokens} - outputs_wm = self.forward(obs_embeddings_or_act_tokens, past_keys_values=self.keys_values_wm, is_root=False) # Note: is_root=False - - # return outputs_wm.output_sequence # (B, K, E) - return outputs_wm - - @torch.no_grad() - # @profile - def forward_initial_inference(self, obs_act_dict: torch.LongTensor, should_predict_next_obs: bool = True): - - if isinstance(obs_act_dict, dict): - # obs_act_dict = {'obs':obs, 'action':action_batch} - obs = obs_act_dict['obs'] - else: - obs = obs_act_dict - - if len(obs[0].shape) == 3: - # obs is a 3-dimensional image, for atari - pass - # elif len(obs[0].shape) == 1: - # # TODO(): for cartpole, 4 -> 4,64,64 - # # obs is a 1-dimensional vector - # original_shape = list(obs.shape) - # desired_shape = original_shape + [64, 64] - # expanded_observations = obs.unsqueeze(-1).unsqueeze(-1) - # expanded_observations = expanded_observations.expand(*desired_shape) - # obs = expanded_observations - - # obs_act_dict['obs'] = obs - - # for cartpole, 4 -> 3,64,64 - # obs is a 1-dimensional vector - # original_shape = list(obs.shape) - # desired_shape = original_shape[:-1] + [3, 64, 64] # 修改最后一个维度为3,然后添加64和64 - # repeated_observations = obs.repeat(1, int(3*64*64/original_shape[-1])) # 将最后一个维度复制到3,64,64 - # obs = repeated_observations.view(*desired_shape) # 重新调整形状到3,64,64 - - - # self.hit_count = 0 - # self.total_query_count = 0 - - outputs_wm, obs_tokens = self.reset_from_initial_observations_v2(obs_act_dict) # TODO - # outputs_wm, obs_tokens = self.reset_from_initial_observations(obs_act_dict) # 从零开始 - - - if self.keys_values_wm.size > 0: - # Depending on the shape of obs_tokens, create a cache key and store a deep copy of keys_values_wm - # if obs_tokens.shape[0] == 1: - # # This branch will be executed only when env_num=1 - # # cache_key = hash(obs_tokens.squeeze(0).detach().cpu().numpy()) - # cache_key = hash(obs_tokens.detach().cpu().numpy()) - # self.past_keys_values_cache[cache_key] = 
copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm, 'cpu')) - # elif obs_tokens.shape[0] == self.env_num: - # elif obs_tokens.shape[0] > self.env_num: - # elif obs_tokens.shape[0] > 1 and obs_tokens.shape[0] <= self.env_num: - # This branch will be executed only when env_num=8 - cache_key = hash(obs_tokens.detach().cpu().numpy()) - # Store the KV_cache for all 8 samples together - self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm, 'cpu')) - - # return outputs_wm.output_sequence, outputs_wm.logits_observations, outputs_wm.logits_rewards, outputs_wm.logits_policy, outputs_wm.logits_value - return outputs_wm.output_sequence, obs_tokens, outputs_wm.logits_rewards, outputs_wm.logits_policy, outputs_wm.logits_value - - - """ - past-kv-dict-batch envnum8 latest multi-step - fix init infer - 把8个样本的self.keys_values_wm 看做一个整体来寻找 - - TODO:很多时候都是执行的refresh_keys_values_with_initial_obs_tokens,导致没有充分利用序列建模能力? - """ - - - @torch.no_grad() - # @profile - def forward_recurrent_inference(self, state_action_history, should_predict_next_obs: bool = True): - # 一般来讲,在一次 MCTS search中,我们需要维护H长度的context来使用transformer进行推理。 - # 由于在一次search里面。agent最多访问sim个不同的节点,因此我们只需维护一个 {(state:kv_cache)}的列表。 - # 但如果假设环境是MDP的话,然后根据当前的 latest_state s_t 在这个列表中查找即可 - # TODO: 但如果假设环境是非MDP的话,需要维护一个 {(rootstate_action_history:kv_cache)}的列表? - - # if self.total_query_count>0: - # self.hit_freq = self.hit_count/self.total_query_count - # print('hit_freq:', self.hit_freq) - # print('hit_count:', self.hit_count) - # print('total_query_count:', self.total_query_count) - - - latest_state = state_action_history[-1][0] - - # Compute the hash of latest_state - hash_latest_state = hash(latest_state) - - # Try to get the value associated with the hash of latest_state - matched_value = self.past_keys_values_cache.get(hash_latest_state) - self.total_query_count += 1 - if matched_value is not None: - # If a matching value is found, do something with it - # self.keys_values_wm = copy.deepcopy(matched_value) - self.keys_values_wm = copy.deepcopy(self.to_device_for_kvcache(matched_value, 'cuda') ) - self.hit_count += 1 - # print('recurrent_inference:find matched_value!') - # TODO:####### - else: - # If no matching value is found, handle the case accordingly - # NOTE: very important - _ = self.refresh_keys_values_with_initial_obs_tokens(torch.tensor(latest_state, dtype=torch.float32).to(self.device)) - # Depending on the shape of obs_tokens, create a cache key and store a deep copy of keys_values_wm - # This branch will be executed only when env_num=1 - # cache_key = hash(latest_state.squeeze(0)) - self.past_keys_values_cache[hash_latest_state] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm, 'cpu')) - # print('recurrent_inference:not find matched_value!') - - - assert self.keys_values_wm is not None and self.num_observations_tokens is not None - - num_passes = 1 + self.num_observations_tokens if should_predict_next_obs else 1 - - output_sequence, obs_tokens = [], [] - - if self.keys_values_wm.size + num_passes > self.config.max_tokens: - del self.keys_values_wm # TODO - # TODO: the impact - _ = self.refresh_keys_values_with_initial_obs_tokens(torch.tensor(latest_state, dtype=torch.float32).to(self.device)) - # Depending on the shape of obs_tokens, create a cache key and store a deep copy of keys_values_wm - self.past_keys_values_cache[hash(latest_state)] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm, 'cpu')) - - # if self.keys_values_wm.size>5: - # print('debug 
self.keys_values_wm.size ') - - # TODO - action = state_action_history[-1][-1] - - token = action.clone().detach() if isinstance(action, torch.Tensor) else torch.tensor(action, dtype=torch.long) - token = token.reshape(-1, 1).to(self.device) # (B, 1) - - for k in range(num_passes): # assumption that there is only one action token. - # action_token obs_token, ..., obs_token 1+16 - - # obs is in token level - # act_token num_steps=1, prev_steps=16 - # obs_token_0 num_steps=1, prev_steps=17 - # obs_token_1 num_steps=1, prev_steps=18 - if k==0: - obs_embeddings_or_act_tokens = {'act_tokens': token} - else: - obs_embeddings_or_act_tokens = {'obs_embeddings': token} - outputs_wm = self.forward(obs_embeddings_or_act_tokens, past_keys_values=self.keys_values_wm, is_root=False) - # if k==0, action_token self.head_observations 1,...,0,1 - output_sequence.append(outputs_wm.output_sequence) - - if k == 0: - # if k==0, token is action_token outputs_wm.logits_rewards 是有值的 - # reward = Categorical(logits=outputs_wm.logits_rewards).sample().float().cpu().numpy().reshape(-1) - 1 # (B,) - # done = Categorical(logits=outputs_wm.logits_ends).sample().cpu().numpy().astype(bool).reshape(-1) # (B,) - reward = outputs_wm.logits_rewards # (B,) - - if k < self.num_observations_tokens: - # 一共产生16个obs_token,每次产生一个 - # TODO: sample or argmax - # token = Categorical(logits=outputs_wm.logits_observations).sample() - # Use argmax to select the most likely token - # token = outputs_wm.logits_observations.argmax(-1, keepdim=True) - token = outputs_wm.logits_observations - - if len(token.shape) != 2: - token = token.squeeze(-1) # Ensure the token tensor shape is (B, 1) - obs_tokens.append(token) - - output_sequence = torch.cat(output_sequence, dim=1) # (B, 1 + K, E) - # Before updating self.obs_tokens, delete the old one to free memory - del self.obs_tokens - self.obs_tokens = torch.cat(obs_tokens, dim=1) # (B, K) - - # obs = self.decode_obs_tokens() if should_predict_next_obs else None - - # cache_key = hash(self.obs_tokens.detach().cpu().numpy()) - cache_key = hash(self.obs_tokens.detach().cpu().numpy()) - - # TODO: 在计算结束后,是否需要更新最新的缓存. 是否需要deepcopy - self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm, 'cpu')) - # outputs_wm.logits_policy, outputs_wm.logits_value - if len(self.past_keys_values_cache) > self.max_cache_size: - # TODO: lru_cache - # self.past_keys_values_cache.popitem(last=False) # Removes the earliest inserted item - # popitem返回一个键值对,其中第二个元素是值 - _, popped_kv_cache = self.past_keys_values_cache.popitem(last=False) - # 如果popped_kv_cache是一个包含张量或复杂对象的容器,您可能需要进一步删除这些对象 - # 例如: - del popped_kv_cache # 不要这一行 - # torch.cuda.empty_cache() # 请注意,频繁调用可能会影响性能, 先del反而清除不掉占用的2MB缓存 - - # Example usage: - # Assuming `past_keys_values_cache` is a populated instance of `KeysValues` - # and `num_layers` is the number of transformer layers - # cuda_memory_gb = self.calculate_cuda_memory_gb(self.past_keys_values_cache, num_layers=2) - # print(f'len(self.past_keys_values_cache): {len(self.past_keys_values_cache)}, Memory used by past_keys_values_cache: {cuda_memory_gb:.2f} GB') - - return outputs_wm.output_sequence, self.obs_tokens, reward, outputs_wm.logits_policy, outputs_wm.logits_value - - - def to_device_for_kvcache(self, keys_values: KeysValues, device: str) -> KeysValues: - """ - Transfer all KVCache objects within the KeysValues object to a certain device. - - Arguments: - - keys_values (KeysValues): The KeysValues object to be transferred. 
- - device (str): The device to transfer to. - - Returns: - - keys_values (KeysValues): The KeysValues object with its caches transferred to the specified device. - """ - # Check if CUDA is available and select the first available CUDA device - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - - for kv_cache in keys_values: - kv_cache._k_cache._cache = kv_cache._k_cache._cache.to(device) - kv_cache._v_cache._cache = kv_cache._v_cache._cache.to(device) - return keys_values - - - # 计算显存使用量的函数 - def calculate_cuda_memory_gb(self, past_keys_values_cache, num_layers: int): - total_memory_bytes = 0 - - # 遍历OrderedDict中所有的KeysValues实例 - for kv_instance in past_keys_values_cache.values(): - num_layers = len(kv_instance) # 获取层数 - for layer in range(num_layers): - kv_cache = kv_instance[layer] - k_shape = kv_cache._k_cache.shape # 获取keys缓存的形状 - v_shape = kv_cache._v_cache.shape # 获取values缓存的形状 - - # 计算元素个数并乘以每个元素的字节数 - k_memory = torch.prod(torch.tensor(k_shape)) * 4 - v_memory = torch.prod(torch.tensor(v_shape)) * 4 - - # 累加keys和values缓存的内存 - layer_memory = k_memory + v_memory - total_memory_bytes += layer_memory.item() # .item()确保转换为Python标准数字 - - # 将总内存从字节转换为吉字节 - total_memory_gb = total_memory_bytes / (1024 ** 3) - return total_memory_gb - - # @profile - def compute_loss(self, batch, target_tokenizer: Tokenizer=None, **kwargs: Any) -> LossWithIntermediateLosses: - - # if len(batch['observations'][0, 0].shape) == 3: - # # obs is a 3-dimensional image - # pass - - # NOTE: 这里是需要梯度的 - #with torch.no_grad(): # TODO: 非常重要 - obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False) # (B, C, H, W) -> (B, K, E) - obs_embeddings.register_hook(lambda grad: grad * 1/5) # TODO:只提供重建损失更新表征网络 - # obs_embeddings.register_hook(lambda grad: grad * 1) # TODO:只提供重建损失更新表征网络 - - # Assume that 'cont_embeddings' and 'original_images' are available from prior code - # Decode the embeddings to reconstruct the images - reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) - # Calculate the reconstruction loss - # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 4, 64, 64), reconstructed_images) # TODO: for stack=4 - latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # TODO: for stack=1 - perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # TODO: for stack=1 - - # latent_recon_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) - # perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) - - latent_kl_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) - - - act_tokens = rearrange(batch['actions'], 'b l -> b l 1') - - # TODO: 是否只用重建损失更新表征网络 非常重要 - outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) - # outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens)}, is_root=False) - - with torch.no_grad(): - traget_obs_embeddings = target_tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False) # (B, C, H, W) -> (B, K, E) - - labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(traget_obs_embeddings, batch['rewards'], - batch['ends'], - batch['mask_padding']) - # labels_observations, 
labels_rewards, labels_ends = self.compute_labels_world_model(obs_embeddings, batch['rewards'], - # batch['ends'], - # batch['mask_padding']) - - logits_observations = rearrange(outputs.logits_observations[:, :-1], 'b t o -> (b t) o') - # labels_observations = labels_observations.contiguous().view(-1, self.projection_input_dim) # TODO: - labels_observations = labels_observations.reshape(-1, self.projection_input_dim) # TODO: - - - loss_obs = torch.nn.functional.mse_loss(logits_observations, labels_observations.detach(), reduction='none').mean(-1) - mask_padding_expanded = batch['mask_padding'][:, 1:].contiguous().view(-1) # TODO: - # mask_padding_expanded = batch['mask_padding'][:, 1:].reshape(-1) - - # 应用mask到loss_obs - loss_obs = (loss_obs * mask_padding_expanded).mean(-1) - labels_policy, labels_value = self.compute_labels_world_model_value_policy(batch['target_value'], - batch['target_policy'], - batch['mask_padding']) - - loss_rewards = self.compute_cross_entropy_loss(outputs, labels_rewards, batch, element='rewards') - loss_policy = self.compute_cross_entropy_loss(outputs, labels_policy, batch, element='policy') - loss_value = self.compute_cross_entropy_loss(outputs, labels_value, batch, element='value') - - return LossWithIntermediateLosses(latent_recon_loss_weight=self.latent_recon_loss_weight, perceptual_loss_weight=self.perceptual_loss_weight, loss_obs=loss_obs, loss_rewards=loss_rewards, loss_value=loss_value, - loss_policy=loss_policy, latent_kl_loss=latent_kl_loss, latent_recon_loss=latent_recon_loss, perceptual_loss=perceptual_loss) - - def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): - # Assume outputs.logits_rewards and labels are your predictions and targets - # And mask_padding is a boolean tensor with True at positions to keep and False at positions to ignore - - if element == 'rewards': - logits = outputs.logits_rewards - elif element == 'policy': - logits = outputs.logits_policy - elif element == 'value': - logits = outputs.logits_value - - # Reshape your tensors - logits_rewards = rearrange(logits, 'b t e -> (b t) e') - labels = labels.reshape(-1, labels.shape[-1]) # Assuming labels originally has shape [b, t, reward_dim] - - # Reshape your mask. True means valid data. 
- mask_padding = rearrange(batch['mask_padding'], 'b t -> (b t)') - - loss_rewards = -(torch.log_softmax(logits_rewards, dim=1) * labels).sum(1) - # loss_rewards = (loss_rewards * mask_padding.squeeze(-1).float()).mean() - loss_rewards = (loss_rewards * mask_padding).mean() - - - return loss_rewards - - def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, - mask_padding: torch.BoolTensor) -> Tuple[ - torch.Tensor, torch.Tensor, torch.Tensor]: - assert torch.all(ends.sum(dim=1) <= 1) # each sequence sample has at most 1 done - mask_fill = torch.logical_not(mask_padding) - labels_observations = obs_embeddings.contiguous().view(rewards.shape[0], -1, self.projection_input_dim)[:, 1:] # self.projection_input_dim - - - # labels_rewards = (rewards.sign() + 1).masked_fill(mask_fill, -100).long() # Rewards clipped to {-1, 0, 1} TODO(pu) - - mask_fill_rewards = mask_fill.unsqueeze(-1).expand_as(rewards) - labels_rewards = rewards.masked_fill(mask_fill_rewards, -100) - - labels_ends = ends.masked_fill(mask_fill, -100) - return labels_observations, labels_rewards.reshape(-1, self.support_size), labels_ends.reshape(-1) - - def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, target_policy: torch.Tensor, - mask_padding: torch.BoolTensor) -> Tuple[ - torch.Tensor, torch.Tensor]: - - mask_fill = torch.logical_not(mask_padding) - mask_fill_policy = mask_fill.unsqueeze(-1).expand_as(target_policy) - labels_policy = target_policy.masked_fill(mask_fill_policy, -100) - - mask_fill_value = mask_fill.unsqueeze(-1).expand_as(target_value) - labels_value = target_value.masked_fill(mask_fill_value, -100) - return labels_policy.reshape(-1, self.action_shape), labels_value.reshape(-1, self.support_size) # TODO(pu) - - - def negative_cosine_similarity(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: - """ - Overview: - consistency loss function: the negative cosine similarity. - Arguments: - - x1 (:obj:`torch.Tensor`): shape (batch_size, dim), e.g. (256, 512) - - x2 (:obj:`torch.Tensor`): shape (batch_size, dim), e.g. (256, 512) - Returns: - (x1 * x2).sum(dim=1) is the cosine similarity between vector x1 and x2. - The cosine similarity always belongs to the interval [-1, 1]. - For example, two proportional vectors have a cosine similarity of 1, - two orthogonal vectors have a similarity of 0, - and two opposite vectors have a similarity of -1. - -(x1 * x2).sum(dim=1) is consistency loss, i.e. the negative cosine similarity. 
- Reference: - https://en.wikipedia.org/wiki/Cosine_similarity - """ - x1 = F.normalize(x1, p=2., dim=-1, eps=1e-5) - x2 = F.normalize(x2, p=2., dim=-1, eps=1e-5) - return -(x1 * x2).sum(dim=1) - - - def render_img(self, obs: int, rec_img: int): - import torch - from PIL import Image - import matplotlib.pyplot as plt - - # 假设batch是一个字典,其中包含了observations键, - # 并且它的形状是torch.Size([B, N, C, H, W]) - # batch_observations = batch_for_gpt['observations'] - # batch_observations = batch['observations'] - batch_observations = obs.unsqueeze(0) - # batch_observations = rec_img.unsqueeze(0) - - # batch_observations = observations.unsqueeze(0) - # batch_observations = x.unsqueeze(0) - # batch_observations = reconstructions.unsqueeze(0) - - - - B, N, C, H, W = batch_observations.shape # 自动检测维度 - - # 分隔条的宽度(可以根据需要调整) - separator_width = 2 - - # 遍历每个样本 - for i in range(B): - # 提取当前样本中的所有帧 - frames = batch_observations[i] - - # 计算拼接图像的总宽度(包括分隔条) - total_width = N * W + (N - 1) * separator_width - - # 创建一个新的图像,其中包含分隔条 - concat_image = Image.new('RGB', (total_width, H), color='black') - - # 拼接每一帧及分隔条 - for j in range(N): - frame = frames[j].permute(1, 2, 0).cpu().numpy() # 转换为(H, W, C) - frame_image = Image.fromarray((frame * 255).astype('uint8'), 'RGB') - - # 计算当前帧在拼接图像中的位置 - x_position = j * (W + separator_width) - concat_image.paste(frame_image, (x_position, 0)) - - # 显示图像 - plt.imshow(concat_image) - plt.title(f'Sample {i+1}') - plt.axis('off') # 关闭坐标轴显示 - plt.show() - - # 保存图像到文件 - concat_image.save(f'sample_{i+1}.png') \ No newline at end of file diff --git a/lzero/model/gpt_models/world_model_envnum8_kv-latent-1-env_fix.py b/lzero/model/gpt_models/world_model_envnum8_kv-latent-1-env_fix.py deleted file mode 100644 index 48f1c5a98..000000000 --- a/lzero/model/gpt_models/world_model_envnum8_kv-latent-1-env_fix.py +++ /dev/null @@ -1,1014 +0,0 @@ -import copy -from dataclasses import dataclass -import random -from typing import Any, Optional, Tuple -from typing import List, Optional, Union -import logging -# 设置日志记录级别为DEBUG -logging.getLogger().setLevel(logging.DEBUG) -from PIL import Image -from einops import rearrange -from einops import rearrange -import gym -from joblib import hash -import numpy as np -import torch -import torch -from torch.distributions.categorical import Categorical -import torch.nn as nn -import torch.nn.functional as F -import torchvision - -from .kv_caching import KeysValues -from .slicer import Embedder, Head, ActEmbedder -from .tokenizer import Tokenizer -from .transformer import Transformer, TransformerConfig -from .utils import LossWithIntermediateLosses, init_weights -from ding.torch_utils import to_device -# from memory_profiler import profile -from line_profiler import line_profiler - -@dataclass -class WorldModelOutput: - output_sequence: torch.FloatTensor - logits_observations: torch.FloatTensor - logits_rewards: torch.FloatTensor - logits_ends: torch.FloatTensor - - logits_policy: torch.FloatTensor - logits_value: torch.FloatTensor - - -class WorldModel(nn.Module): - def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: TransformerConfig, tokenizer, representation_network=None) -> None: - super().__init__() - - # config.max_tokens = int(2*50) # TODO - - self.tokenizer = tokenizer - self.obs_vocab_size, self.act_vocab_size = obs_vocab_size, act_vocab_size - self.config = config - - self.transformer = Transformer(config) - # self.num_observations_tokens = 16 - self.num_observations_tokens = config.tokens_per_block -1 - - self.latent_recon_loss_weight = 
config.latent_recon_loss_weight - self.perceptual_loss_weight = config.perceptual_loss_weight - - self.device = config.device - self.support_size = config.support_size - self.action_shape = config.action_shape - self.max_cache_size = config.max_cache_size - self.env_num = config.env_num - self.num_layers = config.num_layers - - - all_but_last_latent_state_pattern = torch.ones(config.tokens_per_block) - all_but_last_latent_state_pattern[-2] = 0 # 1,...,0,1 - - act_tokens_pattern = torch.zeros(self.config.tokens_per_block) # 17 - act_tokens_pattern[-1] = 1 # 0,...,0,1 - latent_state_pattern = 1 - act_tokens_pattern # 1,...,1,0 - - # current latent state's policy value - value_policy_tokens_pattern = torch.zeros(config.tokens_per_block) - value_policy_tokens_pattern[-2] = 1 # [0,...,1,0] - - # next latent state's policy value - # value_policy_tokens_pattern = torch.zeros(config.tokens_per_block) - # value_policy_tokens_pattern[-1] = 1 # [0,...,0,1] - - obs_per_embdding_dim=config.embed_dim - - self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim) - - - self.embedder = Embedder( - max_blocks=config.max_blocks, - block_masks=[act_tokens_pattern, latent_state_pattern], - embedding_tables=nn.ModuleList([nn.Embedding(act_vocab_size, config.embed_dim), nn.Embedding(obs_vocab_size, config.embed_dim)]) - ) - - # self.act_embedder = ActEmbedder( - # max_blocks=config.max_blocks, - # block_masks=[act_tokens_pattern], - # embedding_tables=nn.ModuleList([nn.Embedding(act_vocab_size, config.embed_dim)]) - # ) - - self.act_embedding_table = nn.Embedding(act_vocab_size, config.embed_dim) - - # self.head_observations = Head( - # max_blocks=config.max_blocks, - # block_mask=all_but_last_latent_state_pattern, # 1,...,0,1 # https://github.com/eloialonso/iris/issues/19 - # head_module=nn.Sequential( - # nn.Linear(config.embed_dim, config.embed_dim), - # nn.ReLU(), - # nn.Linear(config.embed_dim, obs_vocab_size) - # ) - # ) - self.obs_per_embdding_dim = config.embed_dim # 16*64=1024 - self.head_observations = Head( # TODO - max_blocks=config.max_blocks, - block_mask=all_but_last_latent_state_pattern, # 1,...,0,1 # https://github.com/eloialonso/iris/issues/19 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - # nn.BatchNorm1d(config.embed_dim), - # nn.ReLU(), - # nn.Linear(config.embed_dim, obs_vocab_size) - nn.LeakyReLU(negative_slope=0.01), # TODO: 2 - nn.Linear(config.embed_dim, self.obs_per_embdding_dim), - # nn.Tanh(), # TODO - nn.Sigmoid(), # 这里添加Sigmoid函数 TODO - - ) - ) - - self.head_rewards = Head( - max_blocks=config.max_blocks, - block_mask=act_tokens_pattern, # 0,...,0,1 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - nn.ReLU(), - nn.Linear(config.embed_dim, self.support_size) - ) - ) - - self.head_ends = Head( - max_blocks=config.max_blocks, - block_mask=act_tokens_pattern, # 0,...,0,1 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - nn.ReLU(), - nn.Linear(config.embed_dim, 2) - ) - ) - - self.head_observations_for_root = Head( # TODO - max_blocks=config.max_blocks, - block_mask=latent_state_pattern, # 1,...,1,0 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - nn.BatchNorm1d(config.embed_dim), - nn.ReLU(), - # nn.Linear(config.embed_dim, obs_vocab_size) - nn.Linear(config.embed_dim, obs_per_embdding_dim) - ) - ) - self.head_policy = Head( - max_blocks=config.max_blocks, - block_mask=value_policy_tokens_pattern, # TODO: value_policy_tokens_pattern # [0,...,1,0] - 
head_module=nn.Sequential( # (8, 5, 128) - # nn.BatchNorm1d(config.embed_dim), # TODO: 1 - nn.Linear(config.embed_dim, config.embed_dim), - # nn.ReLU(), - nn.LeakyReLU(negative_slope=0.01), # TODO: 2 - nn.Linear(config.embed_dim, self.action_shape) # TODO(pu); action shape - ) - ) - self.head_value = Head( - max_blocks=config.max_blocks, - block_mask=value_policy_tokens_pattern, - head_module=nn.Sequential( - # nn.BatchNorm1d(config.embed_dim), # TODO: 1 - nn.Linear(config.embed_dim, config.embed_dim), - # nn.ReLU(), - nn.LeakyReLU(negative_slope=0.01), # TODO: 2 - nn.Linear(config.embed_dim, self.support_size) # TODO(pu): action shape - ) - ) - - self.apply(init_weights) - - last_linear_layer_init_zero = True # TODO: is beneficial for convergence speed. - if last_linear_layer_init_zero: - # TODO: policy init : 3 - # Locate the last linear layer and initialize its weights and biases to 0. - # for _, layer in enumerate(reversed(self.head_policy.head_module)): - # if isinstance(layer, nn.Linear): - # nn.init.zeros_(layer.weight) - # nn.init.zeros_(layer.bias) - # break - for _, layer in enumerate(reversed(self.head_value.head_module)): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - nn.init.zeros_(layer.bias) - break - for _, layer in enumerate(reversed(self.head_rewards.head_module)): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - nn.init.zeros_(layer.bias) - break - for _, layer in enumerate(reversed(self.head_observations.head_module)): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - nn.init.zeros_(layer.bias) - break - - - import collections - self.past_keys_values_cache = collections.OrderedDict() - self.past_policy_value_cache = collections.OrderedDict() - - # TODO: Transformer更新后应该清除缓存 - # NOTE - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=8, max_tokens=self.config.max_tokens) - - if self.num_observations_tokens==16: # k=16 - self.projection_input_dim = 128 - elif self.num_observations_tokens==1: # K=1 - self.projection_input_dim = self.obs_per_embdding_dim# for atari #TODO - # self.projection_input_dim = 256 # for cartpole - - - self.proj_hid = 1024 - self.proj_out = 1024 - self.pred_hid = 512 - self.pred_out = 1024 - activation = nn.ReLU() - self.projection = nn.Sequential( - nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, - nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, - nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out) - ) - self.prediction_head = nn.Sequential( - nn.Linear(self.proj_out, self.pred_hid), - nn.BatchNorm1d(self.pred_hid), - activation, - nn.Linear(self.pred_hid, self.pred_out), - ) - self.hit_count = 0 - self.total_query_count = 0 - - - - def __repr__(self) -> str: - return "world_model" - - - # def forward(self, tokens: torch.LongTensor, past_keys_values: Optional[KeysValues] = None, - # is_root=False) -> WorldModelOutput: - # def forward(self, obs_embeddings, act_tokens, past_keys_values: Optional[KeysValues] = None, - # is_root=False) -> WorldModelOutput: - # @profile - def forward(self, obs_embeddings_or_act_tokens, past_keys_values: Optional[KeysValues] = None, - is_root=False) -> WorldModelOutput: - - prev_steps = 0 if past_keys_values is None else past_keys_values.size - # print(f'prev_steps:{prev_steps}') - - # sequences = self.embedder(tokens, num_steps, prev_steps) + self.pos_emb(prev_steps + torch.arange(num_steps, device=tokens.device)) - if 'obs_embeddings' in 
obs_embeddings_or_act_tokens.keys(): - obs_embeddings = obs_embeddings_or_act_tokens['obs_embeddings'] - num_steps = obs_embeddings.size(1) # (B, T, E) - # if prev_steps>0: - # prev_steps = prev_steps+1 # TODO: NOTE: 在collect的每一步,执行init_infer时,不reset kv_cache - sequences = obs_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=obs_embeddings.device)) - elif 'act_tokens' in obs_embeddings_or_act_tokens.keys(): - act_tokens = obs_embeddings_or_act_tokens['act_tokens'] - num_steps = act_tokens.size(1) # (B, T) - # act_embeddings = self.act_embedder(act_tokens, num_steps, prev_steps) - act_embeddings = self.act_embedding_table(act_tokens) - - sequences = act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=act_tokens.device)) - else: - obs_embeddings_and_act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens'] - # obs_embeddings: (B, L, K=16, E), act_tokens: (B, L, 1) - obs_embeddings, act_tokens = obs_embeddings_and_act_tokens - if len(obs_embeddings.shape)==3: # for batch compute loss - obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, -1) - - num_steps = int(obs_embeddings.size(1)*(obs_embeddings.size(2)+1)) # L(k+1) - # assert num_steps <= self.config.max_tokens - # Rearrange observation embeddings from (B, L, K, E) to (B, L*K, E) - # obs_embeddings = rearrange(obs_embeddings, 'b l k e -> b (l k) e') - - # Generate action embeddings from action tokens - # (B, L, 1) -> (B, L, 1, E) - act_embeddings = self.act_embedding_table(act_tokens) - - # 已知obs_embeddings的维度为 (B, L, K, E), act_embeddings的维度为(B, L, 1, E) 希望得到一个obs_act_embeddings向量的维度为 (B, L(K+1), E) - # 而且让得到的obs_act_embeddings的第2个维度的数据为:obs act, obs, act, ..., obs, act,即 L, 1, L,1 ... 
这样的排列顺序。请给出高效的实现,用中文回答 - - B, L, K, E = obs_embeddings.size() - # _, _, _, _ = act_embeddings.size() - - # 初始化一个新的空tensor,用于存放最终的拼接结果 - obs_act_embeddings = torch.empty(B, L * (K + 1), E, device=obs_embeddings.device) - - # 对每一个序列长度L进行循环 - for i in range(L): - # 获取当前时刻的obs和act embeddings - obs = obs_embeddings[:, i, :, :] # Shape: (B, K, E) - act = act_embeddings[:, i, 0, :].unsqueeze(1) # Shape: (B, 1, E), 补充维度以便拼接 - - # 交替拼接obs和act - obs_act = torch.cat([obs, act], dim=1) # Shape: (B, K + 1, E) - - # 将结果填充到最终的tensor中 - obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act - - # 确保形状正确无误 - # assert obs_act_embeddings.shape == (B, L * (K + 1), E) - - # Add positional embeddings - sequences = obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=obs_embeddings.device)) - - - # print('transformer forward begin') 函数里面更新了update past_keys_values - x = self.transformer(sequences, past_keys_values) - # print('transformer forward done') - - if is_root: - logits_observations = self.head_observations_for_root(x, num_steps=num_steps, prev_steps=prev_steps) - else: - # 1,...,0,1 https://github.com/eloialonso/iris/issues/19 - logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps) - - logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps) - logits_ends = self.head_ends(x, num_steps=num_steps, prev_steps=prev_steps) - # return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends) - - logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps) - logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps) - - # TODO: root reward value - return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends, logits_policy, logits_value) - - @torch.no_grad() - def render_batch(self) -> List[Image.Image]: - frames = self.decode_latent_state().detach().cpu() - frames = rearrange(frames, 'b c h w -> b h w c').mul(255).numpy().astype(np.uint8) - return [Image.fromarray(frame) for frame in frames] - - # only foe inference now, now is invalid - @torch.no_grad() - def decode_latent_state(self) -> List[Image.Image]: - embedded_tokens = self.tokenizer.embedding(self.latent_state) # (B, K, E) - z = rearrange(embedded_tokens, 'b (h w) e -> b e h w', h=int(np.sqrt(self.num_observations_tokens))) - rec = self.tokenizer.decode(z, should_postprocess=True) # (B, C, H, W) - # TODO: for atari image - return torch.clamp(rec, 0, 1) - # for cartpole obs - # return rec - - - @torch.no_grad() - def render(self): - assert self.latent_state.shape == (1, self.num_observations_tokens) - return self.render_batch()[0] - - @torch.no_grad() - def reset(self) -> torch.FloatTensor: - assert self.env is not None - obs = torchvision.transforms.functional.to_tensor(self.env.reset()).to(self.device).unsqueeze( - 0) # (1, C, H, W) in [0., 1.] 
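[Editor's aside] The per-timestep loop above interleaves the K observation embeddings with the single action embedding of each step to form the (B, L*(K+1), E) sequence. Under the same shape assumptions (obs (B, L, K, E), act (B, L, 1, E)), an equivalent vectorized sketch, not taken from this file:

import torch

B, L, K, E = 2, 3, 16, 128                      # illustrative sizes (assumptions)
obs_embeddings = torch.randn(B, L, K, E)
act_embeddings = torch.randn(B, L, 1, E)

# Concatenate per step, then flatten the step axis: per-step order is obs_0..obs_{K-1}, act.
obs_act = torch.cat([obs_embeddings, act_embeddings], dim=2)   # (B, L, K+1, E)
obs_act_embeddings = obs_act.reshape(B, L * (K + 1), E)        # (B, L*(K+1), E), same layout as the loop
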
- return self.reset_from_initial_observations(obs) - - - @torch.no_grad() - # @profile - def reset_from_initial_observations_v2(self, obs_act_dict: torch.FloatTensor) -> torch.FloatTensor: - if isinstance(obs_act_dict, dict): - observations = obs_act_dict['obs'] - buffer_action = obs_act_dict['action'] - else: - observations = obs_act_dict - buffer_action = None - obs_embeddings = self.tokenizer.encode_to_obs_embeddings(observations, should_preprocess=True) # (B, C, H, W) -> (B, K, E) - - outputs_wm = self.refresh_keys_values_with_initial_latent_state_for_init_infer_v2(obs_embeddings, buffer_action) - self.latent_state = obs_embeddings - - return outputs_wm, self.latent_state - - @torch.no_grad() - # @profile - def refresh_keys_values_with_initial_latent_state_for_init_infer_v2(self, latent_state: torch.LongTensor, buffer_action=None) -> torch.FloatTensor: - n, num_observations_tokens, _ = latent_state.shape - - if n <= self.env_num: - # MCTS root节点: 需要准确的估计 value, policy_logits 或许需要结合context的kv_cache进行更准确的估计,而不是当前的从0开始推理 - # self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) # Note: is_root=False - self.total_query_count += 1 - # Compute the hash of latest_state - latest_state = latent_state.detach().cpu().numpy() - - # 假设 latest_state 是新的 latent_state,包含 ready_env_num 个环境的信息 - ready_env_num = latest_state.shape[0] - keys_values_wm_list = [] - for i in range(ready_env_num): - state_single_env = latest_state[i] # 获取单个环境的 latent state - hash_latest_state = hash(state_single_env) # 计算哈希值 - matched_value = self.past_keys_values_cache.get(hash_latest_state) # 检索缓存值 - if matched_value is not None: - self.hit_count += 1 - # 如果找到匹配的值,将其添加到列表中 - keys_values_wm_list.append(copy.deepcopy(self.to_device_for_kvcache(matched_value, 'cuda'))) - else: - # use zero - keys_values_wm_list.append(self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens)) - - # self.keys_values_wm = keys_values_wm_list - - for layer in range(self.num_layers): - kv_cache_k_list = [] - kv_cache_v_list = [] - for keys_values in keys_values_wm_list: - kv_cache_k_list.append(keys_values[layer]._k_cache._cache) - kv_cache_v_list.append(keys_values[layer]._v_cache._cache) - self.keys_values_wm[layer]._k_cache._cache = torch.stack(kv_cache_k_list, dim=0).squeeze(1) - self.keys_values_wm[layer]._v_cache._cache = torch.stack(kv_cache_v_list, dim=0).squeeze(1) - - - elif n == int(256): - # TODO: n=256 means train tokenizer, 不需要计算target value - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - # print('init inference: not find matched_value! 
reset!') - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) # Note: is_root=False - elif n > self.env_num and n != int(256) and buffer_action is not None: - # TODO: n=256 means train tokenizer - # [192, 16, 64] -> [32, 6, 16, 64] - latent_state = latent_state.contiguous().view(buffer_action.shape[0], -1, num_observations_tokens, self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 - - # latent_state = latent_state.view(-1, self.config.max_blocks+1, num_observations_tokens) # (BL, K) - latent_state = latent_state[:, :-1, :] - # latent_state = latent_state.reshape(32*6, num_observations_tokens) # (BL, K) - buffer_action = torch.from_numpy(buffer_action).to(latent_state.device) - act_tokens = rearrange(buffer_action, 'b l -> b l 1') - - # # 选择每个样本的最后一步 - last_steps = act_tokens[:, -1:, :] # 这将选择最后一列并保持维度不变, 最后一步的target policy/value本身就没有用到 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - act_tokens = torch.cat((act_tokens, last_steps), dim=1) - - # print('init inference: unroll 5 steps!') 17*6=102 17*5=85 - obs_embeddings = latent_state - outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) - - # 选择每个样本的最后一步 - last_steps_value = outputs_wm.logits_value[:, -1:, :] # 这将选择最后一列并保持维度不变 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1) - - last_steps_policy = outputs_wm.logits_policy[:, -1:, :] # 这将选择最后一列并保持维度不变 - outputs_wm.logits_policy = torch.cat((outputs_wm.logits_policy, last_steps_policy), dim=1) - - # Reshape your tensors - # outputs_wm.logits_value.shape (30,21) = (B*6, 21) - outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e') - outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e') - - return outputs_wm - - @torch.no_grad() - # @profile - def reset_from_initial_observations(self, obs_act_dict: torch.FloatTensor) -> torch.FloatTensor: - if isinstance(obs_act_dict, dict): - # obs_act_dict = {'obs':obs, 'action':action_batch} - observations = obs_act_dict['obs'] - buffer_action = obs_act_dict['action'] - else: - observations = obs_act_dict - buffer_action = None - - # NOTE: should_preprocess=True is important - # latent_state = self.tokenizer.encode(observations, should_preprocess=True).tokens # (B, C, H, W) -> (B, K) - # _, num_observations_tokens = latent_state.shape - - obs_embeddings = self.tokenizer.encode_to_obs_embeddings(observations, should_preprocess=True) # (B, C, H, W) -> (B, K, E) - # num_observations_tokens = obs_embeddings.shape[1] - - # if self.num_observations_tokens is None: - # self._num_observations_tokens = num_observations_tokens - - # outputs_wm = self.refresh_keys_values_with_initial_latent_state_for_init_infer(latent_state, buffer_action) - outputs_wm = self.refresh_keys_values_with_initial_latent_state_for_init_infer(obs_embeddings, buffer_action) - - self.latent_state = obs_embeddings - - # return outputs_wm, self.decode_latent_state(), self.latent_state - return outputs_wm, self.latent_state - - - @torch.no_grad() - # @profile - def refresh_keys_values_with_initial_latent_state_for_init_infer(self, latent_state: torch.LongTensor, buffer_action=None) -> torch.FloatTensor: - n, num_observations_tokens, _ = latent_state.shape - # assert num_observations_tokens == self.num_observations_tokens - # self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - - if n 
<= self.env_num: - # Compute the hash of latent_state - # cache_key = hash(latent_state.detach().cpu().numpy()) - # # Try to get the value associated with the hash of latest_state - # matched_value = self.past_keys_values_cache.get(cache_key) - # if matched_value is not None: - # # If a matching value is found, do something with it - # self.keys_values_wm = copy.deepcopy(matched_value) - # print('init inference: find matched_value!') - # else: - # self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - # # print('init inference: not find matched_value! reset!') - - # MCTS root节点: 需要准确的估计 value, policy_logits 或许需要结合context的kv_cache进行更准确的估计,而不是当前的从0开始推理 - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) # Note: is_root=False - elif n == int(256): - # TODO: n=256 means train tokenizer, 不需要计算target value - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - # print('init inference: not find matched_value! reset!') - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) # Note: is_root=False - # elif n > self.env_num and n != int(256) and buffer_action is not None: - # # transformer只能unroll 5步 - # # TODO: n=256 means train tokenizer - # # TODO: for n=32*6=192 means 通过unroll 5 steps,计算target value - # # [192, 16, 64] -> [32, 6, 16, 64] - # latent_state = latent_state.contiguous().view(buffer_action.shape[0], -1, num_observations_tokens, self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 - # buffer_action = torch.from_numpy(buffer_action).to(latent_state.device) - # act_tokens = rearrange(buffer_action, 'b l -> b l 1') - # # 将5步动作的最后一步,重复一次,以拼接为6步的动作 - # act_tokens = torch.cat((act_tokens, act_tokens[:, -1:, :]), dim=1) - # obs_embeddings = latent_state - # outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) - # # Reshape your tensors - # # outputs_wm.logits_value.shape (30,21) = (B*6, 21) - # outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e') - # outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e') - elif n > self.env_num and n != int(256) and buffer_action is not None: - # TODO: n=256 means train tokenizer - # TODO: for n=32*6=192 means 通过unroll 5 steps,计算target value - # latent_state = latent_state.reshape(32, 6, num_observations_tokens) # (BL, K) - # latent_state = latent_state.view(-1, 6, num_observations_tokens) # (BL, K) - - # [192, 16] -> [32, 6, 16] - # latent_state = latent_state.view(buffer_action.shape[0], -1, num_observations_tokens) # (BL, K) for unroll_step=1 - - # [192, 16, 64] -> [32, 6, 16, 64] - latent_state = latent_state.contiguous().view(buffer_action.shape[0], -1, num_observations_tokens, self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 - - # latent_state = latent_state.view(-1, self.config.max_blocks+1, num_observations_tokens) # (BL, K) - latent_state = latent_state[:, :-1, :] - # latent_state = latent_state.reshape(32*6, num_observations_tokens) # (BL, K) - buffer_action = torch.from_numpy(buffer_action).to(latent_state.device) - act_tokens = rearrange(buffer_action, 'b l -> b l 1') - - # # 选择每个样本的最后一步 - last_steps = act_tokens[:, -1:, :] # 这将选择最后一列并保持维度不变, 最后一步的target policy/value本身就没有用到 - # 
使用torch.cat在第二个维度上连接原始act_tokens和last_steps - act_tokens = torch.cat((act_tokens, last_steps), dim=1) - - # print('init inference: unroll 5 steps!') 17*6=102 17*5=85 - obs_embeddings = latent_state - outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) - - # 选择每个样本的最后一步 - last_steps_value = outputs_wm.logits_value[:, -1:, :] # 这将选择最后一列并保持维度不变 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1) - - last_steps_policy = outputs_wm.logits_policy[:, -1:, :] # 这将选择最后一列并保持维度不变 - outputs_wm.logits_policy = torch.cat((outputs_wm.logits_policy, last_steps_policy), dim=1) - - # Reshape your tensors - # outputs_wm.logits_value.shape (30,21) = (B*6, 21) - outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e') - outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e') - - - # return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends, logits_policy, logits_value) - return outputs_wm - - - @torch.no_grad() - # @profile - def refresh_keys_values_with_initial_latent_state(self, latent_state: torch.LongTensor) -> torch.FloatTensor: - n, num_observations_tokens, _ = latent_state.shape - assert num_observations_tokens == self.num_observations_tokens - # self.keys_values_wm = self.world_model.transformer.generate_empty_keys_values(n=n, max_tokens=self.world_model.config.max_tokens) - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - obs_embeddings_or_act_tokens = {'obs_embeddings': latent_state} - outputs_wm = self.forward(obs_embeddings_or_act_tokens, past_keys_values=self.keys_values_wm, is_root=False) # Note: is_root=False - - # return outputs_wm.output_sequence # (B, K, E) - return outputs_wm - - @torch.no_grad() - # @profile - def forward_initial_inference(self, obs_act_dict: torch.LongTensor, should_predict_next_obs: bool = True): - - if isinstance(obs_act_dict, dict): - # obs_act_dict = {'obs':obs, 'action':action_batch} - obs = obs_act_dict['obs'] - else: - obs = obs_act_dict - - if len(obs[0].shape) == 3: - # obs is a 3-dimensional image, for atari - pass - # elif len(obs[0].shape) == 1: - # # TODO(): for cartpole, 4 -> 4,64,64 - # # obs is a 1-dimensional vector - # original_shape = list(obs.shape) - # desired_shape = original_shape + [64, 64] - # expanded_observations = obs.unsqueeze(-1).unsqueeze(-1) - # expanded_observations = expanded_observations.expand(*desired_shape) - # obs = expanded_observations - - # obs_act_dict['obs'] = obs - - # for cartpole, 4 -> 3,64,64 - # obs is a 1-dimensional vector - # original_shape = list(obs.shape) - # desired_shape = original_shape[:-1] + [3, 64, 64] # 修改最后一个维度为3,然后添加64和64 - # repeated_observations = obs.repeat(1, int(3*64*64/original_shape[-1])) # 将最后一个维度复制到3,64,64 - # obs = repeated_observations.view(*desired_shape) # 重新调整形状到3,64,64 - - - - outputs_wm, latent_state = self.reset_from_initial_observations_v2(obs_act_dict) # TODO - # outputs_wm, latent_state = self.reset_from_initial_observations(obs_act_dict) # 从零开始 - - return outputs_wm.output_sequence, latent_state, outputs_wm.logits_rewards, outputs_wm.logits_policy, outputs_wm.logits_value - - - """ - past-kv-dict-batch envnum8 latest multi-step - fix init infer - 把8个样本的self.keys_values_wm 看做一个整体来寻找 - - TODO:很多时候都是执行的refresh_keys_values_with_initial_latent_state,导致没有充分利用序列建模能力? 
- """ - - - @torch.no_grad() - # @profile - def forward_recurrent_inference(self, state_action_history, should_predict_next_obs: bool = True): - # 一般来讲,在一次 MCTS search中,我们需要维护H长度的context来使用transformer进行推理。 - # 由于在一次search里面。agent最多访问sim个不同的节点,因此我们只需维护一个 {(state:kv_cache)}的列表。 - # 但如果假设环境是MDP的话,然后根据当前的 latest_state s_t 在这个列表中查找即可 - # TODO: 但如果假设环境是非MDP的话,需要维护一个 {(rootstate_action_history:kv_cache)}的列表? - - # if self.total_query_count>0: - # self.hit_freq = self.hit_count/self.total_query_count - # print('hit_freq:', self.hit_freq) - # print('hit_count:', self.hit_count) - # print('total_query_count:', self.total_query_count) - - self.total_query_count += 1 - latest_state = state_action_history[-1][0] - - # 假设 latest_state 是新的 latent_state,包含 ready_env_num 个环境的信息 - ready_env_num = latest_state.shape[0] - keys_values_wm_list = [] - for i in range(ready_env_num): - state_single_env = latest_state[i] # 获取单个环境的 latent state - hash_latest_state = hash(state_single_env) # 计算哈希值 - matched_value = self.past_keys_values_cache.get(hash_latest_state) # 检索缓存值 - if matched_value is not None: - self.hit_count += 1 - # 如果找到匹配的值,将其添加到列表中 - keys_values_wm_list.append(copy.deepcopy(self.to_device_for_kvcache(matched_value, 'cuda'))) - else: - # use zero - keys_values_wm_list.append(self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens)) - - # self.keys_values_wm <- keys_values_wm_list - for layer in range(self.num_layers): - kv_cache_k_list = [] - kv_cache_v_list = [] - for keys_values in keys_values_wm_list: - kv_cache_k_list.append(keys_values[layer]._k_cache._cache) - kv_cache_v_list.append(keys_values[layer]._v_cache._cache) - self.keys_values_wm[layer]._k_cache._cache = torch.stack(kv_cache_k_list, dim=0).squeeze(1) - self.keys_values_wm[layer]._v_cache._cache = torch.stack(kv_cache_v_list, dim=0).squeeze(1) - - - assert self.keys_values_wm is not None and self.num_observations_tokens is not None - - num_passes = 1 + self.num_observations_tokens if should_predict_next_obs else 1 - - output_sequence, latent_state = [], [] - - if self.keys_values_wm.size + num_passes > self.config.max_tokens: - del self.keys_values_wm # TODO - # TODO: the impact - _ = self.refresh_keys_values_with_initial_latent_state(torch.tensor(latest_state, dtype=torch.float32).to(self.device)) - # Depending on the shape of latent_state, create a cache key and store a deep copy of keys_values_wm - self.past_keys_values_cache[hash(latest_state)] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm, 'cpu')) - - # if self.keys_values_wm.size>5: - # print('debug self.keys_values_wm.size ') - - # TODO - action = state_action_history[-1][-1] - - token = action.clone().detach() if isinstance(action, torch.Tensor) else torch.tensor(action, dtype=torch.long) - token = token.reshape(-1, 1).to(self.device) # (B, 1) - - for k in range(num_passes): # assumption that there is only one action token. 
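[Editor's aside] The lookup above keys each environment's latent state by its joblib hash and falls back to a freshly generated empty KV cache on a miss; insertion later in this method evicts the oldest entry once max_cache_size is exceeded. A minimal sketch of that bounded, hash-keyed cache pattern (function names and the size bound are illustrative assumptions, not this file's API):

import collections
import numpy as np
from joblib import hash as joblib_hash

max_cache_size = 4                                  # assumed small bound for illustration
cache = collections.OrderedDict()

def put(latent_state_single_env: np.ndarray, kv_entry) -> None:
    # Store one environment's KV cache under the hash of its latent state.
    cache[joblib_hash(latent_state_single_env)] = kv_entry
    if len(cache) > max_cache_size:
        cache.popitem(last=False)                   # evict the oldest entry (FIFO-style)

def get(latent_state_single_env: np.ndarray):
    # Returns None on a miss, in which case the caller builds an empty KV cache instead.
    return cache.get(joblib_hash(latent_state_single_env))
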
- # action_token obs_token, ..., obs_token 1+16 - - # obs is in token level - # act_token num_steps=1, prev_steps=16 - # obs_token_0 num_steps=1, prev_steps=17 - # obs_token_1 num_steps=1, prev_steps=18 - if k==0: - obs_embeddings_or_act_tokens = {'act_tokens': token} - else: - obs_embeddings_or_act_tokens = {'obs_embeddings': token} - outputs_wm = self.forward(obs_embeddings_or_act_tokens, past_keys_values=self.keys_values_wm, is_root=False) - # if k==0, action_token self.head_observations 1,...,0,1 - output_sequence.append(outputs_wm.output_sequence) - - if k == 0: - # if k==0, token is action_token outputs_wm.logits_rewards 是有值的 - # reward = Categorical(logits=outputs_wm.logits_rewards).sample().float().cpu().numpy().reshape(-1) - 1 # (B,) - # done = Categorical(logits=outputs_wm.logits_ends).sample().cpu().numpy().astype(bool).reshape(-1) # (B,) - reward = outputs_wm.logits_rewards # (B,) - - if k < self.num_observations_tokens: - # 一共产生16个obs_token,每次产生一个 - # TODO: sample or argmax - # token = Categorical(logits=outputs_wm.logits_observations).sample() - # Use argmax to select the most likely token - # token = outputs_wm.logits_observations.argmax(-1, keepdim=True) - token = outputs_wm.logits_observations - - if len(token.shape) != 2: - token = token.squeeze(-1) # Ensure the token tensor shape is (B, 1) - latent_state.append(token) - - output_sequence = torch.cat(output_sequence, dim=1) # (B, 1 + K, E) - # Before updating self.latent_state, delete the old one to free memory - del self.latent_state - self.latent_state = torch.cat(latent_state, dim=1) # (B, K) - latent_state = self.latent_state - - # cache_key = hash(latent_state.detach().cpu().numpy()) - # # TODO: 在计算结束后,是否需要更新最新的缓存. 是否需要deepcopy - # self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm, 'cpu')) - - for i in range(latent_state.size(0)): # 遍历每个环境 - state_single_env = latent_state[i] # 获取单个环境的 latent state - cache_key = hash(state_single_env.detach().cpu().numpy()) # 计算哈希值 - # 复制单个环境对应的 keys_values_wm 并存储 - keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) - keys_values_wm_single_env[0]._k_cache._cache = self.keys_values_wm[0]._k_cache._cache[i].unsqueeze(0) # shape torch.Size([2, 100, 512]) - keys_values_wm_single_env[0]._v_cache._cache = self.keys_values_wm[0]._v_cache._cache[i].unsqueeze(0) - self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(keys_values_wm_single_env, 'cpu')) - - # outputs_wm.logits_policy, outputs_wm.logits_value - if len(self.past_keys_values_cache) > self.max_cache_size: - # TODO: lru_cache - _, popped_kv_cache = self.past_keys_values_cache.popitem(last=False) - del popped_kv_cache # 不要这一行 - - # Example usage: - # Assuming `past_keys_values_cache` is a populated instance of `KeysValues` - # and `num_layers` is the number of transformer layers - # cuda_memory_gb = self.calculate_cuda_memory_gb(self.past_keys_values_cache, num_layers=2) - # print(f'len(self.past_keys_values_cache): {len(self.past_keys_values_cache)}, Memory used by past_keys_values_cache: {cuda_memory_gb:.2f} GB') - - return outputs_wm.output_sequence, self.latent_state, reward, outputs_wm.logits_policy, outputs_wm.logits_value - - - def to_device_for_kvcache(self, keys_values: KeysValues, device: str) -> KeysValues: - """ - Transfer all KVCache objects within the KeysValues object to a certain device. - - Arguments: - - keys_values (KeysValues): The KeysValues object to be transferred. 
- - device (str): The device to transfer to. - - Returns: - - keys_values (KeysValues): The KeysValues object with its caches transferred to the specified device. - """ - # Check if CUDA is available and select the first available CUDA device - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - - for kv_cache in keys_values: - kv_cache._k_cache._cache = kv_cache._k_cache._cache.to(device) - kv_cache._v_cache._cache = kv_cache._v_cache._cache.to(device) - return keys_values - - - # 计算显存使用量的函数 - def calculate_cuda_memory_gb(self, past_keys_values_cache, num_layers: int): - total_memory_bytes = 0 - - # 遍历OrderedDict中所有的KeysValues实例 - for kv_instance in past_keys_values_cache.values(): - num_layers = len(kv_instance) # 获取层数 - for layer in range(num_layers): - kv_cache = kv_instance[layer] - k_shape = kv_cache._k_cache.shape # 获取keys缓存的形状 - v_shape = kv_cache._v_cache.shape # 获取values缓存的形状 - - # 计算元素个数并乘以每个元素的字节数 - k_memory = torch.prod(torch.tensor(k_shape)) * 4 - v_memory = torch.prod(torch.tensor(v_shape)) * 4 - - # 累加keys和values缓存的内存 - layer_memory = k_memory + v_memory - total_memory_bytes += layer_memory.item() # .item()确保转换为Python标准数字 - - # 将总内存从字节转换为吉字节 - total_memory_gb = total_memory_bytes / (1024 ** 3) - return total_memory_gb - - # @profile - def compute_loss(self, batch, target_tokenizer: Tokenizer=None, **kwargs: Any) -> LossWithIntermediateLosses: - - # if len(batch['observations'][0, 0].shape) == 3: - # # obs is a 3-dimensional image - # pass - - # NOTE: 这里是需要梯度的 - #with torch.no_grad(): # TODO: 非常重要 - obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False) # (B, C, H, W) -> (B, K, E) - obs_embeddings.register_hook(lambda grad: grad * 1/5) # TODO:只提供重建损失更新表征网络 - # obs_embeddings.register_hook(lambda grad: grad * 1) # TODO:只提供重建损失更新表征网络 - - # Assume that 'cont_embeddings' and 'original_images' are available from prior code - # Decode the embeddings to reconstruct the images - reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) - # Calculate the reconstruction loss - # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 4, 64, 64), reconstructed_images) # TODO: for stack=4 - latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # TODO: for stack=1 - perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # TODO: for stack=1 - - # latent_recon_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) - # perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) - - latent_kl_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) - - - act_tokens = rearrange(batch['actions'], 'b l -> b l 1') - - # TODO: 是否只用重建损失更新表征网络 非常重要 - outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) - # outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens)}, is_root=False) - - with torch.no_grad(): - traget_obs_embeddings = target_tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False) # (B, C, H, W) -> (B, K, E) - - labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(traget_obs_embeddings, batch['rewards'], - batch['ends'], - batch['mask_padding']) - # labels_observations, 
labels_rewards, labels_ends = self.compute_labels_world_model(obs_embeddings, batch['rewards'], - # batch['ends'], - # batch['mask_padding']) - - logits_observations = rearrange(outputs.logits_observations[:, :-1], 'b t o -> (b t) o') - # labels_observations = labels_observations.contiguous().view(-1, self.projection_input_dim) # TODO: - labels_observations = labels_observations.reshape(-1, self.projection_input_dim) # TODO: - - - loss_obs = torch.nn.functional.mse_loss(logits_observations, labels_observations.detach(), reduction='none').mean(-1) - mask_padding_expanded = batch['mask_padding'][:, 1:].contiguous().view(-1) # TODO: - # mask_padding_expanded = batch['mask_padding'][:, 1:].reshape(-1) - - # 应用mask到loss_obs - loss_obs = (loss_obs * mask_padding_expanded).mean(-1) - labels_policy, labels_value = self.compute_labels_world_model_value_policy(batch['target_value'], - batch['target_policy'], - batch['mask_padding']) - - loss_rewards = self.compute_cross_entropy_loss(outputs, labels_rewards, batch, element='rewards') - loss_policy = self.compute_cross_entropy_loss(outputs, labels_policy, batch, element='policy') - loss_value = self.compute_cross_entropy_loss(outputs, labels_value, batch, element='value') - - return LossWithIntermediateLosses(latent_recon_loss_weight=self.latent_recon_loss_weight, perceptual_loss_weight=self.perceptual_loss_weight, loss_obs=loss_obs, loss_rewards=loss_rewards, loss_value=loss_value, - loss_policy=loss_policy, latent_kl_loss=latent_kl_loss, latent_recon_loss=latent_recon_loss, perceptual_loss=perceptual_loss) - - def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): - # Assume outputs.logits_rewards and labels are your predictions and targets - # And mask_padding is a boolean tensor with True at positions to keep and False at positions to ignore - - if element == 'rewards': - logits = outputs.logits_rewards - elif element == 'policy': - logits = outputs.logits_policy - elif element == 'value': - logits = outputs.logits_value - - # Reshape your tensors - logits_rewards = rearrange(logits, 'b t e -> (b t) e') - labels = labels.reshape(-1, labels.shape[-1]) # Assuming labels originally has shape [b, t, reward_dim] - - # Reshape your mask. True means valid data. 
- mask_padding = rearrange(batch['mask_padding'], 'b t -> (b t)') - - loss_rewards = -(torch.log_softmax(logits_rewards, dim=1) * labels).sum(1) - # loss_rewards = (loss_rewards * mask_padding.squeeze(-1).float()).mean() - loss_rewards = (loss_rewards * mask_padding).mean() - - - return loss_rewards - - def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, - mask_padding: torch.BoolTensor) -> Tuple[ - torch.Tensor, torch.Tensor, torch.Tensor]: - assert torch.all(ends.sum(dim=1) <= 1) # each sequence sample has at most 1 done - mask_fill = torch.logical_not(mask_padding) - labels_observations = obs_embeddings.contiguous().view(rewards.shape[0], -1, self.projection_input_dim)[:, 1:] # self.projection_input_dim - - - # labels_rewards = (rewards.sign() + 1).masked_fill(mask_fill, -100).long() # Rewards clipped to {-1, 0, 1} TODO(pu) - - mask_fill_rewards = mask_fill.unsqueeze(-1).expand_as(rewards) - labels_rewards = rewards.masked_fill(mask_fill_rewards, -100) - - labels_ends = ends.masked_fill(mask_fill, -100) - return labels_observations, labels_rewards.reshape(-1, self.support_size), labels_ends.reshape(-1) - - def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, target_policy: torch.Tensor, - mask_padding: torch.BoolTensor) -> Tuple[ - torch.Tensor, torch.Tensor]: - - mask_fill = torch.logical_not(mask_padding) - mask_fill_policy = mask_fill.unsqueeze(-1).expand_as(target_policy) - labels_policy = target_policy.masked_fill(mask_fill_policy, -100) - - mask_fill_value = mask_fill.unsqueeze(-1).expand_as(target_value) - labels_value = target_value.masked_fill(mask_fill_value, -100) - return labels_policy.reshape(-1, self.action_shape), labels_value.reshape(-1, self.support_size) # TODO(pu) - - - def negative_cosine_similarity(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: - """ - Overview: - consistency loss function: the negative cosine similarity. - Arguments: - - x1 (:obj:`torch.Tensor`): shape (batch_size, dim), e.g. (256, 512) - - x2 (:obj:`torch.Tensor`): shape (batch_size, dim), e.g. (256, 512) - Returns: - (x1 * x2).sum(dim=1) is the cosine similarity between vector x1 and x2. - The cosine similarity always belongs to the interval [-1, 1]. - For example, two proportional vectors have a cosine similarity of 1, - two orthogonal vectors have a similarity of 0, - and two opposite vectors have a similarity of -1. - -(x1 * x2).sum(dim=1) is consistency loss, i.e. the negative cosine similarity. 
- Reference: - https://en.wikipedia.org/wiki/Cosine_similarity - """ - x1 = F.normalize(x1, p=2., dim=-1, eps=1e-5) - x2 = F.normalize(x2, p=2., dim=-1, eps=1e-5) - return -(x1 * x2).sum(dim=1) - - - def render_img(self, obs: int, rec_img: int): - import torch - from PIL import Image - import matplotlib.pyplot as plt - - # 假设batch是一个字典,其中包含了observations键, - # 并且它的形状是torch.Size([B, N, C, H, W]) - # batch_observations = batch_for_gpt['observations'] - # batch_observations = batch['observations'] - batch_observations = obs.unsqueeze(0) - # batch_observations = rec_img.unsqueeze(0) - - # batch_observations = observations.unsqueeze(0) - # batch_observations = x.unsqueeze(0) - # batch_observations = reconstructions.unsqueeze(0) - - - - B, N, C, H, W = batch_observations.shape # 自动检测维度 - - # 分隔条的宽度(可以根据需要调整) - separator_width = 2 - - # 遍历每个样本 - for i in range(B): - # 提取当前样本中的所有帧 - frames = batch_observations[i] - - # 计算拼接图像的总宽度(包括分隔条) - total_width = N * W + (N - 1) * separator_width - - # 创建一个新的图像,其中包含分隔条 - concat_image = Image.new('RGB', (total_width, H), color='black') - - # 拼接每一帧及分隔条 - for j in range(N): - frame = frames[j].permute(1, 2, 0).cpu().numpy() # 转换为(H, W, C) - frame_image = Image.fromarray((frame * 255).astype('uint8'), 'RGB') - - # 计算当前帧在拼接图像中的位置 - x_position = j * (W + separator_width) - concat_image.paste(frame_image, (x_position, 0)) - - # 显示图像 - plt.imshow(concat_image) - plt.title(f'Sample {i+1}') - plt.axis('off') # 关闭坐标轴显示 - plt.show() - - # 保存图像到文件 - concat_image.save(f'sample_{i+1}.png') \ No newline at end of file diff --git a/lzero/model/gpt_models/world_model_envnum8_kv-latent-1-env_fix2.py b/lzero/model/gpt_models/world_model_envnum8_kv-latent-1-env_fix2.py deleted file mode 100644 index e9c4a1c3d..000000000 --- a/lzero/model/gpt_models/world_model_envnum8_kv-latent-1-env_fix2.py +++ /dev/null @@ -1,1023 +0,0 @@ -import copy -from dataclasses import dataclass -import random -from typing import Any, Optional, Tuple -from typing import List, Optional, Union -import logging -# 设置日志记录级别为DEBUG -logging.getLogger().setLevel(logging.DEBUG) -from PIL import Image -from einops import rearrange -from einops import rearrange -import gym -from joblib import hash -import numpy as np -import torch -import torch -from torch.distributions.categorical import Categorical -import torch.nn as nn -import torch.nn.functional as F -import torchvision - -from .kv_caching import KeysValues -from .slicer import Embedder, Head, ActEmbedder -from .tokenizer import Tokenizer -from .transformer import Transformer, TransformerConfig -from .utils import LossWithIntermediateLosses, init_weights -from ding.torch_utils import to_device -# from memory_profiler import profile -from line_profiler import line_profiler - -@dataclass -class WorldModelOutput: - output_sequence: torch.FloatTensor - logits_observations: torch.FloatTensor - logits_rewards: torch.FloatTensor - logits_ends: torch.FloatTensor - - logits_policy: torch.FloatTensor - logits_value: torch.FloatTensor - - -class WorldModel(nn.Module): - def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: TransformerConfig, tokenizer, representation_network=None) -> None: - super().__init__() - - # config.max_tokens = int(2*50) # TODO - - self.tokenizer = tokenizer - self.obs_vocab_size, self.act_vocab_size = obs_vocab_size, act_vocab_size - self.config = config - - self.transformer = Transformer(config) - # self.num_observations_tokens = 16 - self.num_observations_tokens = config.tokens_per_block -1 - - self.latent_recon_loss_weight = 
config.latent_recon_loss_weight - self.perceptual_loss_weight = config.perceptual_loss_weight - - self.device = config.device - self.support_size = config.support_size - self.action_shape = config.action_shape - self.max_cache_size = config.max_cache_size - self.env_num = config.env_num - self.num_layers = config.num_layers - - - all_but_last_latent_state_pattern = torch.ones(config.tokens_per_block) - all_but_last_latent_state_pattern[-2] = 0 # 1,...,0,1 - - act_tokens_pattern = torch.zeros(self.config.tokens_per_block) # 17 - act_tokens_pattern[-1] = 1 # 0,...,0,1 - latent_state_pattern = 1 - act_tokens_pattern # 1,...,1,0 - - # current latent state's policy value - value_policy_tokens_pattern = torch.zeros(config.tokens_per_block) - value_policy_tokens_pattern[-2] = 1 # [0,...,1,0] - - # next latent state's policy value - # value_policy_tokens_pattern = torch.zeros(config.tokens_per_block) - # value_policy_tokens_pattern[-1] = 1 # [0,...,0,1] - - obs_per_embdding_dim=config.embed_dim - - self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim) - - - self.embedder = Embedder( - max_blocks=config.max_blocks, - block_masks=[act_tokens_pattern, latent_state_pattern], - embedding_tables=nn.ModuleList([nn.Embedding(act_vocab_size, config.embed_dim), nn.Embedding(obs_vocab_size, config.embed_dim)]) - ) - - # self.act_embedder = ActEmbedder( - # max_blocks=config.max_blocks, - # block_masks=[act_tokens_pattern], - # embedding_tables=nn.ModuleList([nn.Embedding(act_vocab_size, config.embed_dim)]) - # ) - - self.act_embedding_table = nn.Embedding(act_vocab_size, config.embed_dim) - - # self.head_observations = Head( - # max_blocks=config.max_blocks, - # block_mask=all_but_last_latent_state_pattern, # 1,...,0,1 # https://github.com/eloialonso/iris/issues/19 - # head_module=nn.Sequential( - # nn.Linear(config.embed_dim, config.embed_dim), - # nn.ReLU(), - # nn.Linear(config.embed_dim, obs_vocab_size) - # ) - # ) - self.obs_per_embdding_dim = config.embed_dim # 16*64=1024 - self.head_observations = Head( # TODO - max_blocks=config.max_blocks, - block_mask=all_but_last_latent_state_pattern, # 1,...,0,1 # https://github.com/eloialonso/iris/issues/19 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - # nn.BatchNorm1d(config.embed_dim), - # nn.ReLU(), - # nn.Linear(config.embed_dim, obs_vocab_size) - nn.LeakyReLU(negative_slope=0.01), # TODO: 2 - nn.Linear(config.embed_dim, self.obs_per_embdding_dim), - # nn.Tanh(), # TODO - nn.Sigmoid(), # 这里添加Sigmoid函数 TODO - ) - ) - - self.head_rewards = Head( - max_blocks=config.max_blocks, - block_mask=act_tokens_pattern, # 0,...,0,1 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - nn.ReLU(), - nn.Linear(config.embed_dim, self.support_size) - ) - ) - - self.head_ends = Head( - max_blocks=config.max_blocks, - block_mask=act_tokens_pattern, # 0,...,0,1 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - nn.ReLU(), - nn.Linear(config.embed_dim, 2) - ) - ) - - self.head_observations_for_root = Head( # TODO - max_blocks=config.max_blocks, - block_mask=latent_state_pattern, # 1,...,1,0 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - nn.BatchNorm1d(config.embed_dim), - nn.ReLU(), - # nn.Linear(config.embed_dim, obs_vocab_size) - nn.Linear(config.embed_dim, obs_per_embdding_dim) - ) - ) - self.head_policy = Head( - max_blocks=config.max_blocks, - block_mask=value_policy_tokens_pattern, # TODO: value_policy_tokens_pattern # [0,...,1,0] - 
head_module=nn.Sequential( # (8, 5, 128) - # nn.BatchNorm1d(config.embed_dim), # TODO: 1 - nn.Linear(config.embed_dim, config.embed_dim), - # nn.ReLU(), - nn.LeakyReLU(negative_slope=0.01), # TODO: 2 - nn.Linear(config.embed_dim, self.action_shape) # TODO(pu); action shape - ) - ) - self.head_value = Head( - max_blocks=config.max_blocks, - block_mask=value_policy_tokens_pattern, - head_module=nn.Sequential( - # nn.BatchNorm1d(config.embed_dim), # TODO: 1 - nn.Linear(config.embed_dim, config.embed_dim), - # nn.ReLU(), - nn.LeakyReLU(negative_slope=0.01), # TODO: 2 - nn.Linear(config.embed_dim, self.support_size) # TODO(pu): action shape - ) - ) - - self.apply(init_weights) - - last_linear_layer_init_zero = True # TODO: is beneficial for convergence speed. - if last_linear_layer_init_zero: - # TODO: policy init : 3 - # Locate the last linear layer and initialize its weights and biases to 0. - # for _, layer in enumerate(reversed(self.head_policy.head_module)): - # if isinstance(layer, nn.Linear): - # nn.init.zeros_(layer.weight) - # nn.init.zeros_(layer.bias) - # break - for _, layer in enumerate(reversed(self.head_value.head_module)): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - nn.init.zeros_(layer.bias) - break - for _, layer in enumerate(reversed(self.head_rewards.head_module)): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - nn.init.zeros_(layer.bias) - break - for _, layer in enumerate(reversed(self.head_observations.head_module)): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - nn.init.zeros_(layer.bias) - break - - - import collections - self.past_keys_values_cache = collections.OrderedDict() - self.past_policy_value_cache = collections.OrderedDict() - - # TODO: Transformer更新后应该清除缓存 - # NOTE - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=8, max_tokens=self.config.max_tokens) - - if self.num_observations_tokens==16: # k=16 - self.projection_input_dim = 128 - elif self.num_observations_tokens==1: # K=1 - self.projection_input_dim = self.obs_per_embdding_dim# for atari #TODO - # self.projection_input_dim = 256 # for cartpole - - - self.proj_hid = 1024 - self.proj_out = 1024 - self.pred_hid = 512 - self.pred_out = 1024 - activation = nn.ReLU() - self.projection = nn.Sequential( - nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, - nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, - nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out) - ) - self.prediction_head = nn.Sequential( - nn.Linear(self.proj_out, self.pred_hid), - nn.BatchNorm1d(self.pred_hid), - activation, - nn.Linear(self.pred_hid, self.pred_out), - ) - self.hit_count = 0 - self.total_query_count = 0 - - - - def __repr__(self) -> str: - return "world_model" - - # @profile - def forward(self, obs_embeddings_or_act_tokens, past_keys_values: Optional[KeysValues] = None, - is_root=False) -> WorldModelOutput: - - prev_steps = 0 if past_keys_values is None else past_keys_values.size - # print(f'prev_steps:{prev_steps}') - - # sequences = self.embedder(tokens, num_steps, prev_steps) + self.pos_emb(prev_steps + torch.arange(num_steps, device=tokens.device)) - if 'obs_embeddings' in obs_embeddings_or_act_tokens.keys(): - obs_embeddings = obs_embeddings_or_act_tokens['obs_embeddings'] - num_steps = obs_embeddings.size(1) # (B, T, E) - # if prev_steps>0: - # prev_steps = prev_steps+1 # TODO: NOTE: 在collect的每一步,执行init_infer时,不reset kv_cache - sequences = 
obs_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=obs_embeddings.device)) - elif 'act_tokens' in obs_embeddings_or_act_tokens.keys(): - act_tokens = obs_embeddings_or_act_tokens['act_tokens'] - num_steps = act_tokens.size(1) # (B, T) - # act_embeddings = self.act_embedder(act_tokens, num_steps, prev_steps) - act_embeddings = self.act_embedding_table(act_tokens) - - sequences = act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=act_tokens.device)) - else: - obs_embeddings_and_act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens'] - # obs_embeddings: (B, L, K=16, E), act_tokens: (B, L, 1) - obs_embeddings, act_tokens = obs_embeddings_and_act_tokens - if len(obs_embeddings.shape)==3: # for batch compute loss - obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, -1) - - num_steps = int(obs_embeddings.size(1)*(obs_embeddings.size(2)+1)) # L(k+1) - # assert num_steps <= self.config.max_tokens - # Rearrange observation embeddings from (B, L, K, E) to (B, L*K, E) - # obs_embeddings = rearrange(obs_embeddings, 'b l k e -> b (l k) e') - - # Generate action embeddings from action tokens - # (B, L, 1) -> (B, L, 1, E) - act_embeddings = self.act_embedding_table(act_tokens) - - # 已知obs_embeddings的维度为 (B, L, K, E), act_embeddings的维度为(B, L, 1, E) 希望得到一个obs_act_embeddings向量的维度为 (B, L(K+1), E) - # 而且让得到的obs_act_embeddings的第2个维度的数据为:obs act, obs, act, ..., obs, act,即 L, 1, L,1 ... 这样的排列顺序。请给出高效的实现,用中文回答 - - B, L, K, E = obs_embeddings.size() - # _, _, _, _ = act_embeddings.size() - - # 初始化一个新的空tensor,用于存放最终的拼接结果 - obs_act_embeddings = torch.empty(B, L * (K + 1), E, device=obs_embeddings.device) - - # 对每一个序列长度L进行循环 - for i in range(L): - # 获取当前时刻的obs和act embeddings - obs = obs_embeddings[:, i, :, :] # Shape: (B, K, E) - act = act_embeddings[:, i, 0, :].unsqueeze(1) # Shape: (B, 1, E), 补充维度以便拼接 - - # 交替拼接obs和act - obs_act = torch.cat([obs, act], dim=1) # Shape: (B, K + 1, E) - - # 将结果填充到最终的tensor中 - obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act - - # 确保形状正确无误 - # assert obs_act_embeddings.shape == (B, L * (K + 1), E) - - # Add positional embeddings - sequences = obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=obs_embeddings.device)) - - - # print('transformer forward begin') 函数里面更新了update past_keys_values - x = self.transformer(sequences, past_keys_values) - # print('transformer forward done') - - if is_root: - logits_observations = self.head_observations_for_root(x, num_steps=num_steps, prev_steps=prev_steps) - else: - # 1,...,0,1 https://github.com/eloialonso/iris/issues/19 - logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps) - - logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps) - logits_ends = self.head_ends(x, num_steps=num_steps, prev_steps=prev_steps) - # return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends) - - logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps) - logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps) - - # TODO: root reward value - return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends, logits_policy, logits_value) - - @torch.no_grad() - def render_batch(self) -> List[Image.Image]: - frames = self.decode_latent_state().detach().cpu() - frames = rearrange(frames, 'b c h w -> b h w c').mul(255).numpy().astype(np.uint8) - return [Image.fromarray(frame) 
-
- @torch.no_grad()
- def render_batch(self) -> List[Image.Image]:
- frames = self.decode_latent_state().detach().cpu()
- frames = rearrange(frames, 'b c h w -> b h w c').mul(255).numpy().astype(np.uint8)
- return [Image.fromarray(frame) for frame in frames]
-
- # only for inference now; currently invalid/unused
- @torch.no_grad()
- def decode_latent_state(self) -> List[Image.Image]:
- embedded_tokens = self.tokenizer.embedding(self.latent_state)  # (B, K, E)
- z = rearrange(embedded_tokens, 'b (h w) e -> b e h w', h=int(np.sqrt(self.num_observations_tokens)))
- rec = self.tokenizer.decode(z, should_postprocess=True)  # (B, C, H, W)
- # TODO: for atari image
- return torch.clamp(rec, 0, 1)
- # for cartpole obs
- # return rec
-
-
- @torch.no_grad()
- def render(self):
- assert self.latent_state.shape == (1, self.num_observations_tokens)
- return self.render_batch()[0]
-
- @torch.no_grad()
- def reset(self) -> torch.FloatTensor:
- assert self.env is not None
- obs = torchvision.transforms.functional.to_tensor(self.env.reset()).to(self.device).unsqueeze(0)  # (1, C, H, W) in [0., 1.]
- return self.reset_from_initial_observations(obs)
-
-
- @torch.no_grad()
- # @profile
- def reset_from_initial_observations_v2(self, obs_act_dict: torch.FloatTensor) -> torch.FloatTensor:
- if isinstance(obs_act_dict, dict):
- observations = obs_act_dict['obs']
- buffer_action = obs_act_dict['action']
- else:
- observations = obs_act_dict
- buffer_action = None
- obs_embeddings = self.tokenizer.encode_to_obs_embeddings(observations, should_preprocess=True)  # (B, C, H, W) -> (B, K, E)
-
- outputs_wm = self.refresh_keys_values_with_initial_latent_state_for_init_infer_v2(obs_embeddings, buffer_action)
- self.latent_state = obs_embeddings
-
- return outputs_wm, self.latent_state
-
- @torch.no_grad()
- # @profile
- def refresh_keys_values_with_initial_latent_state_for_init_infer_v2(self, latent_state: torch.LongTensor, buffer_action=None) -> torch.FloatTensor:
- n, num_observations_tokens, _ = latent_state.shape
-
- if n <= self.env_num:
- # MCTS root node: we need accurate value/policy_logits estimates; ideally this would reuse the contextual kv_cache
- # for a better estimate instead of the current inference from scratch.
- # self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens)
-
- self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens)
- outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False)  # Note: is_root=False
- # Compute the hash of latest_state
- latest_state = latent_state.detach().cpu().numpy()
-
- # latest_state is the new latent_state, holding information for ready_env_num environments
- ready_env_num = latest_state.shape[0]
- keys_values_wm_list = []
- for i in range(ready_env_num):
- self.total_query_count += 1
- state_single_env = latest_state[i]  # latent state of a single environment
- hash_latest_state = hash(state_single_env)  # compute its hash
- matched_value = self.past_keys_values_cache.get(hash_latest_state)  # look up the cached value
- if matched_value is not None:
- self.hit_count += 1
- # if a matching cache entry is found, append it to the list
- keys_values_wm_list.append(copy.deepcopy(self.to_device_for_kvcache(matched_value, 'cuda')))
- else:
- # use zero (an empty kv cache)
- keys_values_wm_list.append(self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens))
-
- # self.keys_values_wm = keys_values_wm_list
-
- for layer in range(self.num_layers):
- kv_cache_k_list = []
- kv_cache_v_list = []
- for keys_values in keys_values_wm_list:
- kv_cache_k_list.append(keys_values[layer]._k_cache._cache)
- kv_cache_v_list.append(keys_values[layer]._v_cache._cache)
- # self.keys_values_wm._keys_values[layer]._k_cache._cache = torch.stack(kv_cache_k_list, dim=0).squeeze(1)
- # self.keys_values_wm._keys_values[layer]._v_cache._cache = torch.stack(kv_cache_v_list, dim=0).squeeze(1)
-
- self.keys_values_wm._keys_values[layer].update(torch.stack(kv_cache_k_list, dim=0).squeeze(1), torch.stack(kv_cache_v_list, dim=0).squeeze(1))
-
- del kv_cache_k_list
- del kv_cache_v_list
-
-
- elif n == int(256):
- # TODO: n=256 means tokenizer training; target values are not needed here
- self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens)
- # print('init inference: not find matched_value! reset!')
- outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False)  # Note: is_root=False
- elif n > self.env_num and n != int(256) and buffer_action is not None:
- # compute target values during training by unrolling (n=256 would mean tokenizer training)
- # [192, 16, 64] -> [32, 6, 16, 64]
- latent_state = latent_state.contiguous().view(buffer_action.shape[0], -1, num_observations_tokens, self.obs_per_embdding_dim)  # (BL, K) for unroll_step=1
-
- # latent_state = latent_state.view(-1, self.config.max_blocks+1, num_observations_tokens) # (BL, K)
- latent_state = latent_state[:, :-1, :]
- # latent_state = latent_state.reshape(32*6, num_observations_tokens) # (BL, K)
- buffer_action = torch.from_numpy(buffer_action).to(latent_state.device)
- act_tokens = rearrange(buffer_action, 'b l -> b l 1')
-
- # Select the last step of each sample
- last_steps = act_tokens[:, -1:, :]  # keeps the last column without dropping the dim; the last step's target policy/value is never used anyway
- # Concatenate the original act_tokens and last_steps along the second dimension
- act_tokens = torch.cat((act_tokens, last_steps), dim=1)
-
- # print('init inference: unroll 5 steps!') 17*6=102 17*5=85
- obs_embeddings = latent_state
- outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False)
-
- # Select the last step of each sample
- last_steps_value = outputs_wm.logits_value[:, -1:, :]  # keeps the last column without dropping the dim
- # Concatenate the original logits and the repeated last step along the second dimension
- outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1)
-
- last_steps_policy = outputs_wm.logits_policy[:, -1:, :]  # keeps the last column without dropping the dim
- outputs_wm.logits_policy = torch.cat((outputs_wm.logits_policy, last_steps_policy), dim=1)
-
- # Reshape the tensors
- # outputs_wm.logits_value.shape (30,21) = (B*6, 21)
- outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e')
- outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e')
-
- return outputs_wm
-
- @torch.no_grad()
- # @profile
- def reset_from_initial_observations(self, obs_act_dict: torch.FloatTensor) -> torch.FloatTensor:
- if isinstance(obs_act_dict, dict):
- # obs_act_dict = {'obs':obs, 'action':action_batch}
- observations = obs_act_dict['obs']
- buffer_action = obs_act_dict['action']
- else:
- observations = obs_act_dict
- buffer_action = None
-
- # NOTE: should_preprocess=True is important
- # latent_state = self.tokenizer.encode(observations, should_preprocess=True).tokens # (B, C, H, W) -> (B, K)
- # _, num_observations_tokens = latent_state.shape
-
- obs_embeddings = self.tokenizer.encode_to_obs_embeddings(observations, should_preprocess=True)  # (B, C, H, W) -> (B, K, E)
- # num_observations_tokens = obs_embeddings.shape[1]
-
- # if self.num_observations_tokens is None:
- # self._num_observations_tokens = num_observations_tokens
-
- # outputs_wm = self.refresh_keys_values_with_initial_latent_state_for_init_infer(latent_state, buffer_action)
- outputs_wm = self.refresh_keys_values_with_initial_latent_state_for_init_infer(obs_embeddings, buffer_action)
-
- self.latent_state = obs_embeddings
-
- # return outputs_wm, self.decode_latent_state(), self.latent_state
- return outputs_wm, self.latent_state
-
-
- @torch.no_grad()
- # @profile
- def refresh_keys_values_with_initial_latent_state_for_init_infer(self, latent_state: torch.LongTensor, buffer_action=None) -> torch.FloatTensor:
- n, num_observations_tokens, _ = latent_state.shape
- # assert num_observations_tokens == self.num_observations_tokens
- # self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens)
-
- if n <= self.env_num:
- # Compute the hash of latent_state
- # cache_key = hash(latent_state.detach().cpu().numpy())
- # # Try to get the value associated with the hash of latest_state
- # matched_value = self.past_keys_values_cache.get(cache_key)
- # if matched_value is not None:
- # # If a matching value is found, do something with it
- # self.keys_values_wm = copy.deepcopy(matched_value)
- # print('init inference: find matched_value!')
- # else:
- # self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens)
- # # print('init inference: not find matched_value! reset!')
-
- # MCTS root node: we need accurate value/policy_logits estimates; ideally this would reuse the contextual kv_cache
- # for a better estimate instead of the current inference from scratch.
- self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens)
- outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False)  # Note: is_root=False
- elif n == int(256):
- # TODO: n=256 means tokenizer training; target values are not needed here
- self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens)
- # print('init inference: not find matched_value! reset!')
- outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False)  # Note: is_root=False
- # elif n > self.env_num and n != int(256) and buffer_action is not None:
- # # the transformer can only unroll 5 steps
- # # TODO: n=256 means train tokenizer
- # # TODO: n=32*6=192 means computing target values by unrolling 5 steps
- # # [192, 16, 64] -> [32, 6, 16, 64]
- # latent_state = latent_state.contiguous().view(buffer_action.shape[0], -1, num_observations_tokens, self.obs_per_embdding_dim) # (BL, K) for unroll_step=1
- # buffer_action = torch.from_numpy(buffer_action).to(latent_state.device)
- # act_tokens = rearrange(buffer_action, 'b l -> b l 1')
- # # repeat the last of the 5 actions once to pad the sequence to 6 actions
- # act_tokens = torch.cat((act_tokens, act_tokens[:, -1:, :]), dim=1)
- # obs_embeddings = latent_state
- # outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False)
- # # Reshape the tensors
- # # outputs_wm.logits_value.shape (30,21) = (B*6, 21)
- # outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e')
- # outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e')
- elif n > self.env_num and n != int(256) and buffer_action is not None:
- # TODO: n=256 means train tokenizer
- # TODO: n=32*6=192 means computing target values by unrolling 5 steps
- # latent_state = latent_state.reshape(32, 6, num_observations_tokens) # (BL, K)
- # latent_state = latent_state.view(-1, 6, num_observations_tokens) # (BL, K)
-
- # [192, 16] -> [32, 6, 16]
- # latent_state = latent_state.view(buffer_action.shape[0], -1, num_observations_tokens) # (BL, K) for unroll_step=1
-
- # [192, 16, 64] -> [32, 6, 16, 64]
- latent_state = latent_state.contiguous().view(buffer_action.shape[0], -1, num_observations_tokens, self.obs_per_embdding_dim)  # (BL, K) for unroll_step=1
-
- # latent_state = latent_state.view(-1, self.config.max_blocks+1, num_observations_tokens) # (BL, K)
- latent_state = latent_state[:, :-1, :]
- # latent_state = latent_state.reshape(32*6, num_observations_tokens) # (BL, K)
- buffer_action = torch.from_numpy(buffer_action).to(latent_state.device)
- act_tokens = rearrange(buffer_action, 'b l -> b l 1')
-
- # Select the last step of each sample
- last_steps = act_tokens[:, -1:, :]  # keeps the last column without dropping the dim; the last step's target policy/value is never used anyway
- # Concatenate the original act_tokens and last_steps along the second dimension
- act_tokens = torch.cat((act_tokens, last_steps), dim=1)
-
- # print('init inference: unroll 5 steps!') 17*6=102 17*5=85
- obs_embeddings = latent_state
- outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False)
-
- # Select the last step of each sample
- last_steps_value = outputs_wm.logits_value[:, -1:, :]  # keeps the last column without dropping the dim
- # Concatenate the original logits and the repeated last step along the second dimension
- outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1)
-
- last_steps_policy = outputs_wm.logits_policy[:, -1:, :]  # keeps the last column without dropping the dim
- outputs_wm.logits_policy = torch.cat((outputs_wm.logits_policy, last_steps_policy), dim=1)
-
- # Reshape the tensors
- # outputs_wm.logits_value.shape (30,21) = (B*6, 21)
- outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e')
- outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e')
-
-
- # return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends, logits_policy, logits_value)
- return outputs_wm
-
-
- @torch.no_grad()
- # @profile
- def refresh_keys_values_with_initial_latent_state(self, latent_state: torch.LongTensor) -> torch.FloatTensor:
- n, num_observations_tokens, _ = latent_state.shape
- assert num_observations_tokens == self.num_observations_tokens
- # self.keys_values_wm = self.world_model.transformer.generate_empty_keys_values(n=n, max_tokens=self.world_model.config.max_tokens)
- self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens)
- obs_embeddings_or_act_tokens = {'obs_embeddings': latent_state}
- outputs_wm = self.forward(obs_embeddings_or_act_tokens, past_keys_values=self.keys_values_wm, is_root=False)  # Note: is_root=False
-
- # return outputs_wm.output_sequence # (B, K, E)
- return outputs_wm
-
- @torch.no_grad()
- # @profile
- def forward_initial_inference(self, obs_act_dict: torch.LongTensor, should_predict_next_obs: bool = True):
-
- if isinstance(obs_act_dict, dict):
- # obs_act_dict = {'obs':obs, 'action':action_batch}
- obs = obs_act_dict['obs']
- else:
- obs = obs_act_dict
-
- if len(obs[0].shape) == 3:
- # obs is a 3-dimensional image, for atari
- pass
- # elif len(obs[0].shape) == 1:
- # # TODO(): for cartpole, 4 -> 4,64,64
- # # obs is a 1-dimensional vector
- # original_shape = list(obs.shape)
- # desired_shape = original_shape + [64, 64]
- # expanded_observations = obs.unsqueeze(-1).unsqueeze(-1)
- # expanded_observations = expanded_observations.expand(*desired_shape)
- # obs = expanded_observations
-
- # obs_act_dict['obs'] = obs
-
- # for cartpole, 4 -> 3,64,64
- # obs is a 1-dimensional vector
- # original_shape = list(obs.shape)
- # desired_shape = original_shape[:-1] + [3, 64, 64] # change the last dim to 3, then append 64 and 64
- # repeated_observations = obs.repeat(1, int(3*64*64/original_shape[-1])) # tile the last dim up to 3*64*64
- # obs = repeated_observations.view(*desired_shape) # reshape to 3,64,64
-
-
-
- outputs_wm, latent_state = self.reset_from_initial_observations_v2(obs_act_dict)  # TODO
- # outputs_wm, latent_state = self.reset_from_initial_observations(obs_act_dict)  # start from scratch (no cached context)
-
- return outputs_wm.output_sequence, latent_state, outputs_wm.logits_rewards, outputs_wm.logits_policy, outputs_wm.logits_value
-
-
- """
- past-kv-dict-batch, env_num=8, latest multi-step
- fix init infer
- The self.keys_values_wm of the 8 samples is treated and looked up as a single whole.
-
- TODO: most of the time refresh_keys_values_with_initial_latent_state ends up being executed, so the sequence-modeling capacity is not fully exploited?
- """
-
-
- @torch.no_grad()
- # @profile
- def forward_recurrent_inference(self, state_action_history, should_predict_next_obs: bool = True):
- # In general, during one MCTS search we need to maintain a context of length H for transformer inference.
- # Within one search the agent visits at most `sim` distinct nodes, so it suffices to maintain a {state: kv_cache} map.
- # Assuming the environment is an MDP, we can simply look up the current latest_state s_t in this map.
- # TODO: if the environment is non-Markovian, a {root_state_action_history: kv_cache} map would be needed instead?
-
- # if self.total_query_count>0:
- # self.hit_freq = self.hit_count/(self.total_query_count)
- # print('hit_freq:', self.hit_freq)
- # print('hit_count:', self.hit_count)
- # print('total_query_count:', self.total_query_count)
-
- latest_state = state_action_history[-1][0]
-
- # latest_state is the new latent_state, holding information for ready_env_num environments
- ready_env_num = latest_state.shape[0]
- keys_values_wm_list = []
- for i in range(ready_env_num):
- self.total_query_count += 1
- state_single_env = latest_state[i]  # latent state of a single environment
- hash_latest_state = hash(state_single_env)  # compute its hash
- matched_value = self.past_keys_values_cache.get(hash_latest_state)  # look up the cached value
- if matched_value is not None:
- self.hit_count += 1
- # if a matching cache entry is found, append it to the list
- keys_values_wm_list.append(copy.deepcopy(self.to_device_for_kvcache(matched_value, 'cuda')))
- else:
- # use zero (an empty kv cache)
- keys_values_wm_list.append(self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens))
-
- # self.keys_values_wm <- keys_values_wm_list
- for layer in range(self.num_layers):
- kv_cache_k_list = []
- kv_cache_v_list = []
- for keys_values in keys_values_wm_list:  # batch_size
- kv_cache_k_list.append(keys_values[layer]._k_cache._cache)
- kv_cache_v_list.append(keys_values[layer]._v_cache._cache)
- # self.keys_values_wm._keys_values[layer]._k_cache._cache = torch.stack(kv_cache_k_list, dim=0).squeeze(1)
- # self.keys_values_wm._keys_values[layer]._v_cache._cache = torch.stack(kv_cache_v_list, dim=0).squeeze(1)
-
- self.keys_values_wm._keys_values[layer].update(torch.stack(kv_cache_k_list, dim=0).squeeze(1), torch.stack(kv_cache_v_list, dim=0).squeeze(1))
-
- del kv_cache_k_list
- del kv_cache_v_list
-
-
- assert self.keys_values_wm is not None and self.num_observations_tokens is not None
-
- num_passes = 1 + self.num_observations_tokens if should_predict_next_obs else 1
-
- output_sequence, latent_state = [], []
-
- if self.keys_values_wm.size + num_passes > self.config.max_tokens:
- del self.keys_values_wm  # TODO
- # TODO: the impact
- _ = self.refresh_keys_values_with_initial_latent_state(torch.tensor(latest_state, dtype=torch.float32).to(self.device))
- # Depending on the shape of latent_state, create a cache key and store a deep copy of keys_values_wm
- # self.past_keys_values_cache[hash(latest_state)] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm, 'cpu'))
-
- # if self.keys_values_wm.size>5:
- # print('debug self.keys_values_wm.size ')
-
- # TODO
- action = state_action_history[-1][-1]
-
- token = action.clone().detach() if isinstance(action, torch.Tensor) else torch.tensor(action, dtype=torch.long)
- token = token.reshape(-1, 1).to(self.device)  # (B, 1)
-
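# Editorial note: the loop below performs num_passes = 1 + K forward passes per recurrent step;
# pass 0 feeds the action token, and each of the remaining K passes feeds back the predicted
# observation embedding, so the KV cache grows by (K + 1) entries per recurrent step.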
- for k in range(num_passes):  # assumption: there is only one action token
- # one action token followed by obs tokens: act_token, obs_token_0, ..., obs_token_15 (1+16 passes)
-
- # obs is at token level
- # act_token: num_steps=1, prev_steps=16
- # obs_token_0: num_steps=1, prev_steps=17
- # obs_token_1: num_steps=1, prev_steps=18
- if k==0:
- obs_embeddings_or_act_tokens = {'act_tokens': token}
- else:
- obs_embeddings_or_act_tokens = {'obs_embeddings': token}
- outputs_wm = self.forward(obs_embeddings_or_act_tokens, past_keys_values=self.keys_values_wm, is_root=False)
- # if k==0, action_token self.head_observations 1,...,0,1
- output_sequence.append(outputs_wm.output_sequence)
-
- if k == 0:
- # if k==0, token is the action token, so outputs_wm.logits_rewards carries valid values
- # reward = Categorical(logits=outputs_wm.logits_rewards).sample().float().cpu().numpy().reshape(-1) - 1 # (B,)
- # done = Categorical(logits=outputs_wm.logits_ends).sample().cpu().numpy().astype(bool).reshape(-1) # (B,)
- reward = outputs_wm.logits_rewards  # (B,)
-
- if k < self.num_observations_tokens:
- # 16 obs tokens are produced in total, one per pass
- # TODO: sample or argmax
- # token = Categorical(logits=outputs_wm.logits_observations).sample()
- # Use argmax to select the most likely token
- # token = outputs_wm.logits_observations.argmax(-1, keepdim=True)
- token = outputs_wm.logits_observations
-
- if len(token.shape) != 2:
- token = token.squeeze(-1)  # Ensure the token tensor shape is (B, 1)
- latent_state.append(token)
-
- output_sequence = torch.cat(output_sequence, dim=1)  # (B, 1 + K, E)
- # Before updating self.latent_state, delete the old one to free memory
- del self.latent_state
- self.latent_state = torch.cat(latent_state, dim=1)  # (B, K)
- latent_state = self.latent_state
-
- # cache_key = hash(latent_state.detach().cpu().numpy())
- # # TODO: should the latest cache be updated after the computation? Is a deepcopy needed?
- # self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm, 'cpu'))
-
- for i in range(latent_state.size(0)):  # iterate over each environment
- state_single_env = latent_state[i]  # latent state of a single environment
- cache_key = hash(state_single_env.detach().cpu().numpy())  # compute its hash
- # copy the keys_values_wm slice of this single environment and store it
- keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens)
- for layer in range(self.num_layers):
- # keys_values_wm_single_env[layer]._k_cache._cache = self.keys_values_wm[layer]._k_cache._cache[i].unsqueeze(0) # shape torch.Size([2, 100, 512])
- # keys_values_wm_single_env[layer]._v_cache._cache = self.keys_values_wm[layer]._v_cache._cache[i].unsqueeze(0)
-
- keys_values_wm_single_env[layer].update(self.keys_values_wm[layer]._k_cache._cache[i].unsqueeze(0), self.keys_values_wm[layer]._v_cache._cache[i].unsqueeze(0))
-
-
- self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(keys_values_wm_single_env, 'cpu'))
- del keys_values_wm_single_env
-
- # outputs_wm.logits_policy, outputs_wm.logits_value
- if len(self.past_keys_values_cache) > self.max_cache_size:
- # TODO: lru_cache
- _, popped_kv_cache = self.past_keys_values_cache.popitem(last=False)
- del popped_kv_cache  # TODO: this line is probably unnecessary
-
- # Example usage:
- # Assuming `past_keys_values_cache` is a populated instance of `KeysValues`
- # and `num_layers` is the number of transformer layers
- # cuda_memory_gb = self.calculate_cuda_memory_gb(self.past_keys_values_cache, num_layers=2)
- # print(f'len(self.past_keys_values_cache): {len(self.past_keys_values_cache)}, Memory used by past_keys_values_cache: {cuda_memory_gb:.2f} GB')
-
- return outputs_wm.output_sequence, self.latent_state, reward, outputs_wm.logits_policy, outputs_wm.logits_value
-
-
- def to_device_for_kvcache(self, keys_values: KeysValues, device: str) -> KeysValues:
- """
- Transfer all KVCache objects within the KeysValues object to a certain device.
-
- Arguments:
- - keys_values (KeysValues): The KeysValues object to be transferred.
- - device (str): The device to transfer to.
- Returns:
- - keys_values (KeysValues): The KeysValues object with its caches transferred to the specified device.
- """
- # Respect the requested device, falling back to CPU when CUDA is unavailable.
- # (The original code always chose CUDA here when available, which silently defeated CPU offloading.)
- device = torch.device(device if torch.cuda.is_available() else 'cpu')
-
- for kv_cache in keys_values:
- kv_cache._k_cache._cache = kv_cache._k_cache._cache.to(device)
- kv_cache._v_cache._cache = kv_cache._v_cache._cache.to(device)
- return keys_values
-
-
- # Helper that estimates the CUDA memory occupied by the kv-cache dict
- def calculate_cuda_memory_gb(self, past_keys_values_cache, num_layers: int):
- total_memory_bytes = 0
-
- # Iterate over all KeysValues instances in the OrderedDict
- for kv_instance in past_keys_values_cache.values():
- num_layers = len(kv_instance)  # number of layers
- for layer in range(num_layers):
- kv_cache = kv_instance[layer]
- k_shape = kv_cache._k_cache.shape  # shape of the key cache
- v_shape = kv_cache._v_cache.shape  # shape of the value cache
-
- # number of elements times bytes per element (float32 = 4 bytes)
- k_memory = torch.prod(torch.tensor(k_shape)) * 4
- v_memory = torch.prod(torch.tensor(v_shape)) * 4
-
- # accumulate the key and value cache memory
- layer_memory = k_memory + v_memory
- total_memory_bytes += layer_memory.item()  # .item() converts to a plain Python number
-
- # convert the total from bytes to gigabytes
- total_memory_gb = total_memory_bytes / (1024 ** 3)
- return total_memory_gb
-
- # @profile
- def compute_loss(self, batch, target_tokenizer: Tokenizer=None, **kwargs: Any) -> LossWithIntermediateLosses:
-
- # if len(batch['observations'][0, 0].shape) == 3:
- # # obs is a 3-dimensional image
- # pass
-
- # NOTE: gradients are required here
- # with torch.no_grad(): # TODO: very important
- obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False)  # (B, C, H, W) -> (B, K, E)
- obs_embeddings.register_hook(lambda grad: grad * 1/5)  # TODO: let only the reconstruction loss update the representation network
- # obs_embeddings.register_hook(lambda grad: grad * 1)  # TODO: let only the reconstruction loss update the representation network
-
- # Assume that 'cont_embeddings' and 'original_images' are available from prior code
- # Decode the embeddings to reconstruct the images
- reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings)
- # Calculate the reconstruction loss
- # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 4, 64, 64), reconstructed_images) # TODO: for stack=4
- latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images)  # TODO: for stack=1
- perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images)  # TODO: for stack=1
-
- # latent_recon_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype)
- # perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype)
-
- latent_kl_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype)
-
-
- act_tokens = rearrange(batch['actions'], 'b l -> b l 1')
-
- # TODO: whether to update the representation network with the reconstruction loss only -- very important
- outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False)
- # outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens)}, is_root=False)
-
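# Editorial note: the regression targets below are produced by a separate target_tokenizer under
# torch.no_grad() (and detached again in the MSE term), so the latent prediction loss does not
# backpropagate into the online encoder through its targets.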
- with torch.no_grad():
- target_obs_embeddings = target_tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False)  # (B, C, H, W) -> (B, K, E)
-
- labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(target_obs_embeddings, batch['rewards'],
- batch['ends'],
- batch['mask_padding'])
- # labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(obs_embeddings, batch['rewards'],
- # batch['ends'],
- # batch['mask_padding'])
-
- logits_observations = rearrange(outputs.logits_observations[:, :-1], 'b t o -> (b t) o')
- # labels_observations = labels_observations.contiguous().view(-1, self.projection_input_dim) # TODO:
- labels_observations = labels_observations.reshape(-1, self.projection_input_dim)  # TODO:
-
-
- loss_obs = torch.nn.functional.mse_loss(logits_observations, labels_observations.detach(), reduction='none').mean(-1)
- mask_padding_expanded = batch['mask_padding'][:, 1:].contiguous().view(-1)  # TODO:
- # mask_padding_expanded = batch['mask_padding'][:, 1:].reshape(-1)
-
- # Apply the mask to loss_obs
- loss_obs = (loss_obs * mask_padding_expanded).mean(-1)
- labels_policy, labels_value = self.compute_labels_world_model_value_policy(batch['target_value'],
- batch['target_policy'],
- batch['mask_padding'])
-
- loss_rewards = self.compute_cross_entropy_loss(outputs, labels_rewards, batch, element='rewards')
- loss_policy = self.compute_cross_entropy_loss(outputs, labels_policy, batch, element='policy')
- loss_value = self.compute_cross_entropy_loss(outputs, labels_value, batch, element='value')
-
- return LossWithIntermediateLosses(latent_recon_loss_weight=self.latent_recon_loss_weight, perceptual_loss_weight=self.perceptual_loss_weight, loss_obs=loss_obs, loss_rewards=loss_rewards, loss_value=loss_value,
- loss_policy=loss_policy, latent_kl_loss=latent_kl_loss, latent_recon_loss=latent_recon_loss, perceptual_loss=perceptual_loss)
-
- def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'):
- # Assume outputs.logits_<element> and labels are the predictions and targets,
- # and mask_padding is a boolean tensor with True at positions to keep and False at positions to ignore
-
- if element == 'rewards':
- logits = outputs.logits_rewards
- elif element == 'policy':
- logits = outputs.logits_policy
- elif element == 'value':
- logits = outputs.logits_value
-
- # Reshape the tensors
- logits_rewards = rearrange(logits, 'b t e -> (b t) e')
- labels = labels.reshape(-1, labels.shape[-1])  # assuming labels originally have shape [b, t, reward_dim]
-
- # Reshape the mask. True means valid data.
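# mask_padding has shape (B, T); flattening it to (B*T,) keeps it aligned with the flattened
# logits/labels above, so padded timesteps are zeroed out before the mean is taken.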
- mask_padding = rearrange(batch['mask_padding'], 'b t -> (b t)') - - loss_rewards = -(torch.log_softmax(logits_rewards, dim=1) * labels).sum(1) - # loss_rewards = (loss_rewards * mask_padding.squeeze(-1).float()).mean() - loss_rewards = (loss_rewards * mask_padding).mean() - - - return loss_rewards - - def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, - mask_padding: torch.BoolTensor) -> Tuple[ - torch.Tensor, torch.Tensor, torch.Tensor]: - assert torch.all(ends.sum(dim=1) <= 1) # each sequence sample has at most 1 done - mask_fill = torch.logical_not(mask_padding) - labels_observations = obs_embeddings.contiguous().view(rewards.shape[0], -1, self.projection_input_dim)[:, 1:] # self.projection_input_dim - - - # labels_rewards = (rewards.sign() + 1).masked_fill(mask_fill, -100).long() # Rewards clipped to {-1, 0, 1} TODO(pu) - - mask_fill_rewards = mask_fill.unsqueeze(-1).expand_as(rewards) - labels_rewards = rewards.masked_fill(mask_fill_rewards, -100) - - labels_ends = ends.masked_fill(mask_fill, -100) - return labels_observations, labels_rewards.reshape(-1, self.support_size), labels_ends.reshape(-1) - - def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, target_policy: torch.Tensor, - mask_padding: torch.BoolTensor) -> Tuple[ - torch.Tensor, torch.Tensor]: - - mask_fill = torch.logical_not(mask_padding) - mask_fill_policy = mask_fill.unsqueeze(-1).expand_as(target_policy) - labels_policy = target_policy.masked_fill(mask_fill_policy, -100) - - mask_fill_value = mask_fill.unsqueeze(-1).expand_as(target_value) - labels_value = target_value.masked_fill(mask_fill_value, -100) - return labels_policy.reshape(-1, self.action_shape), labels_value.reshape(-1, self.support_size) # TODO(pu) - - - def negative_cosine_similarity(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: - """ - Overview: - consistency loss function: the negative cosine similarity. - Arguments: - - x1 (:obj:`torch.Tensor`): shape (batch_size, dim), e.g. (256, 512) - - x2 (:obj:`torch.Tensor`): shape (batch_size, dim), e.g. (256, 512) - Returns: - (x1 * x2).sum(dim=1) is the cosine similarity between vector x1 and x2. - The cosine similarity always belongs to the interval [-1, 1]. - For example, two proportional vectors have a cosine similarity of 1, - two orthogonal vectors have a similarity of 0, - and two opposite vectors have a similarity of -1. - -(x1 * x2).sum(dim=1) is consistency loss, i.e. the negative cosine similarity. 
- Reference:
- https://en.wikipedia.org/wiki/Cosine_similarity
- """
- x1 = F.normalize(x1, p=2., dim=-1, eps=1e-5)
- x2 = F.normalize(x2, p=2., dim=-1, eps=1e-5)
- return -(x1 * x2).sum(dim=1)
-
-
- def render_img(self, obs: torch.Tensor, rec_img: torch.Tensor):
- import torch
- from PIL import Image
- import matplotlib.pyplot as plt
-
- # Assume batch is a dict containing an 'observations' key
- # whose value has shape torch.Size([B, N, C, H, W])
- # batch_observations = batch_for_gpt['observations']
- # batch_observations = batch['observations']
- batch_observations = obs.unsqueeze(0)
- # batch_observations = rec_img.unsqueeze(0)
-
- # batch_observations = observations.unsqueeze(0)
- # batch_observations = x.unsqueeze(0)
- # batch_observations = reconstructions.unsqueeze(0)
-
-
-
- B, N, C, H, W = batch_observations.shape  # infer the dimensions automatically
-
- # width of the separator bar (adjust as needed)
- separator_width = 2
-
- # iterate over each sample
- for i in range(B):
- # extract all frames of the current sample
- frames = batch_observations[i]
-
- # total width of the concatenated image (including separators)
- total_width = N * W + (N - 1) * separator_width
-
- # create a new image that includes the separator bars
- concat_image = Image.new('RGB', (total_width, H), color='black')
-
- # paste each frame followed by a separator
- for j in range(N):
- frame = frames[j].permute(1, 2, 0).cpu().numpy()  # convert to (H, W, C)
- frame_image = Image.fromarray((frame * 255).astype('uint8'), 'RGB')
-
- # position of the current frame within the concatenated image
- x_position = j * (W + separator_width)
- concat_image.paste(frame_image, (x_position, 0))
-
- # display the image
- plt.imshow(concat_image)
- plt.title(f'Sample {i+1}')
- plt.axis('off')  # hide the axes
- plt.show()
-
- # save the image to a file
- concat_image.save(f'sample_{i+1}.png')
\ No newline at end of file
diff --git a/lzero/model/gpt_models/world_model_envnum8_kv-latent-1-env_fix3.py b/lzero/model/gpt_models/world_model_envnum8_kv-latent-1-env_fix3.py deleted file mode 100644 index bf08f3138..000000000 --- a/lzero/model/gpt_models/world_model_envnum8_kv-latent-1-env_fix3.py +++ /dev/null @@ -1,1014 +0,0 @@
-import copy
-from dataclasses import dataclass
-import random
-from typing import Any, Optional, Tuple
-from typing import List, Optional, Union
-import logging
-# Set the logging level to DEBUG
-logging.getLogger().setLevel(logging.DEBUG)
-from PIL import Image
-from einops import rearrange
-from einops import rearrange
-import gym
-from joblib import hash
-import numpy as np
-import torch
-import torch
-from torch.distributions.categorical import Categorical
-import torch.nn as nn
-import torch.nn.functional as F
-import torchvision
-
-from .kv_caching import KeysValues
-from .slicer import Embedder, Head, ActEmbedder
-from .tokenizer import Tokenizer
-from .transformer import Transformer, TransformerConfig
-from .utils import LossWithIntermediateLosses, init_weights
-from ding.torch_utils import to_device
-# from memory_profiler import profile
-from line_profiler import line_profiler
-
-@dataclass
-class WorldModelOutput:
- output_sequence: torch.FloatTensor
- logits_observations: torch.FloatTensor
- logits_rewards: torch.FloatTensor
- logits_ends: torch.FloatTensor
-
- logits_policy: torch.FloatTensor
- logits_value: torch.FloatTensor
-
-
-class WorldModel(nn.Module):
- def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: TransformerConfig, tokenizer, representation_network=None) -> None:
- super().__init__()
-
- # config.max_tokens = int(2*50) # TODO
-
- self.tokenizer = tokenizer
- self.obs_vocab_size, self.act_vocab_size = obs_vocab_size, act_vocab_size
- self.config = config
-
- self.transformer = Transformer(config)
- # self.num_observations_tokens = 16
- self.num_observations_tokens = config.tokens_per_block -1
-
- self.latent_recon_loss_weight =
config.latent_recon_loss_weight - self.perceptual_loss_weight = config.perceptual_loss_weight - - self.device = config.device - self.support_size = config.support_size - self.action_shape = config.action_shape - self.max_cache_size = config.max_cache_size - self.env_num = config.env_num - self.num_layers = config.num_layers - - - all_but_last_latent_state_pattern = torch.ones(config.tokens_per_block) - all_but_last_latent_state_pattern[-2] = 0 # 1,...,0,1 - - act_tokens_pattern = torch.zeros(self.config.tokens_per_block) # 17 - act_tokens_pattern[-1] = 1 # 0,...,0,1 - latent_state_pattern = 1 - act_tokens_pattern # 1,...,1,0 - - # current latent state's policy value - value_policy_tokens_pattern = torch.zeros(config.tokens_per_block) - value_policy_tokens_pattern[-2] = 1 # [0,...,1,0] - - # next latent state's policy value - # value_policy_tokens_pattern = torch.zeros(config.tokens_per_block) - # value_policy_tokens_pattern[-1] = 1 # [0,...,0,1] - - obs_per_embdding_dim=config.embed_dim - - self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim) - - - self.embedder = Embedder( - max_blocks=config.max_blocks, - block_masks=[act_tokens_pattern, latent_state_pattern], - embedding_tables=nn.ModuleList([nn.Embedding(act_vocab_size, config.embed_dim), nn.Embedding(obs_vocab_size, config.embed_dim)]) - ) - - # self.act_embedder = ActEmbedder( - # max_blocks=config.max_blocks, - # block_masks=[act_tokens_pattern], - # embedding_tables=nn.ModuleList([nn.Embedding(act_vocab_size, config.embed_dim)]) - # ) - - self.act_embedding_table = nn.Embedding(act_vocab_size, config.embed_dim) - - # self.head_observations = Head( - # max_blocks=config.max_blocks, - # block_mask=all_but_last_latent_state_pattern, # 1,...,0,1 # https://github.com/eloialonso/iris/issues/19 - # head_module=nn.Sequential( - # nn.Linear(config.embed_dim, config.embed_dim), - # nn.ReLU(), - # nn.Linear(config.embed_dim, obs_vocab_size) - # ) - # ) - self.obs_per_embdding_dim = config.embed_dim # 16*64=1024 - self.head_observations = Head( # TODO - max_blocks=config.max_blocks, - block_mask=all_but_last_latent_state_pattern, # 1,...,0,1 # https://github.com/eloialonso/iris/issues/19 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - # nn.BatchNorm1d(config.embed_dim), - # nn.ReLU(), - # nn.Linear(config.embed_dim, obs_vocab_size) - nn.LeakyReLU(negative_slope=0.01), # TODO: 2 - nn.Linear(config.embed_dim, self.obs_per_embdding_dim), - # nn.Tanh(), # TODO - nn.Sigmoid(), # 这里添加Sigmoid函数 TODO - ) - ) - - self.head_rewards = Head( - max_blocks=config.max_blocks, - block_mask=act_tokens_pattern, # 0,...,0,1 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - nn.ReLU(), - nn.Linear(config.embed_dim, self.support_size) - ) - ) - - self.head_ends = Head( - max_blocks=config.max_blocks, - block_mask=act_tokens_pattern, # 0,...,0,1 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - nn.ReLU(), - nn.Linear(config.embed_dim, 2) - ) - ) - - self.head_observations_for_root = Head( # TODO - max_blocks=config.max_blocks, - block_mask=latent_state_pattern, # 1,...,1,0 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - nn.BatchNorm1d(config.embed_dim), - nn.ReLU(), - # nn.Linear(config.embed_dim, obs_vocab_size) - nn.Linear(config.embed_dim, obs_per_embdding_dim) - ) - ) - self.head_policy = Head( - max_blocks=config.max_blocks, - block_mask=value_policy_tokens_pattern, # TODO: value_policy_tokens_pattern # [0,...,1,0] - 
head_module=nn.Sequential( # (8, 5, 128) - # nn.BatchNorm1d(config.embed_dim), # TODO: 1 - nn.Linear(config.embed_dim, config.embed_dim), - # nn.ReLU(), - nn.LeakyReLU(negative_slope=0.01), # TODO: 2 - nn.Linear(config.embed_dim, self.action_shape) # TODO(pu); action shape - ) - ) - self.head_value = Head( - max_blocks=config.max_blocks, - block_mask=value_policy_tokens_pattern, - head_module=nn.Sequential( - # nn.BatchNorm1d(config.embed_dim), # TODO: 1 - nn.Linear(config.embed_dim, config.embed_dim), - # nn.ReLU(), - nn.LeakyReLU(negative_slope=0.01), # TODO: 2 - nn.Linear(config.embed_dim, self.support_size) # TODO(pu): action shape - ) - ) - - self.apply(init_weights) - - last_linear_layer_init_zero = True # TODO: is beneficial for convergence speed. - if last_linear_layer_init_zero: - # TODO: policy init : 3 - # Locate the last linear layer and initialize its weights and biases to 0. - # for _, layer in enumerate(reversed(self.head_policy.head_module)): - # if isinstance(layer, nn.Linear): - # nn.init.zeros_(layer.weight) - # nn.init.zeros_(layer.bias) - # break - for _, layer in enumerate(reversed(self.head_value.head_module)): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - nn.init.zeros_(layer.bias) - break - for _, layer in enumerate(reversed(self.head_rewards.head_module)): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - nn.init.zeros_(layer.bias) - break - for _, layer in enumerate(reversed(self.head_observations.head_module)): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - nn.init.zeros_(layer.bias) - break - - - import collections - self.past_keys_values_cache = collections.OrderedDict() - self.past_policy_value_cache = collections.OrderedDict() - - # TODO: Transformer更新后应该清除缓存 - # NOTE - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=self.env_num, max_tokens=self.config.max_tokens) - - self.keys_values_wm_list = [] - self.keys_values_wm_size_list = [] - - - if self.num_observations_tokens==16: # k=16 - self.projection_input_dim = 128 - elif self.num_observations_tokens==1: # K=1 - self.projection_input_dim = self.obs_per_embdding_dim# for atari #TODO - # self.projection_input_dim = 256 # for cartpole - - - self.proj_hid = 1024 - self.proj_out = 1024 - self.pred_hid = 512 - self.pred_out = 1024 - activation = nn.ReLU() - self.projection = nn.Sequential( - nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, - nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, - nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out) - ) - self.prediction_head = nn.Sequential( - nn.Linear(self.proj_out, self.pred_hid), - nn.BatchNorm1d(self.pred_hid), - activation, - nn.Linear(self.pred_hid, self.pred_out), - ) - self.hit_count = 0 - self.total_query_count = 0 - - - - def __repr__(self) -> str: - return "world_model" - - # @profile - def forward(self, obs_embeddings_or_act_tokens, past_keys_values: Optional[KeysValues] = None, - is_root=False, kvcache_independent=False) -> WorldModelOutput: - - if kvcache_independent: - prev_steps = 0 if past_keys_values is None else [past_kv.size for past_kv in past_keys_values] - prev_steps = torch.tensor(prev_steps, device=self.device) - # 我们需要为每个样本生成一个序列的步骤indices,然后获取它们的位置嵌入 - # 首先扩展prev_steps至(num_steps, batch_size),这里num_steps=1 - # prev_steps = prev_steps.unsqueeze(0) - - else: - prev_steps = 0 if past_keys_values is None else past_keys_values.size - # 
print(f'prev_steps:{prev_steps}') - - if 'obs_embeddings' in obs_embeddings_or_act_tokens.keys(): - obs_embeddings = obs_embeddings_or_act_tokens['obs_embeddings'] - if len(obs_embeddings.shape)==2: - obs_embeddings = obs_embeddings.unsqueeze(1) - num_steps = obs_embeddings.size(1) # (B, T, E) - if kvcache_independent: - # 生成每个样本的步骤indices - # steps_indices = prev_steps + torch.arange(num_steps, device=obs_embeddings.device).unsqueeze(1) - steps_indices = prev_steps + torch.arange(num_steps, device=obs_embeddings.device) - - # 步骤indices需要被reshape成一维,以便于embedding操作 - # steps_indices = steps_indices.view(-1) - # 获取位置嵌入 - position_embeddings = self.pos_emb(steps_indices) - # 由于我们要将它们加到obs_embeddings上,需要将位置嵌入reshape回(batch_size, num_steps, embedding_dim) - position_embeddings = position_embeddings.view(-1, num_steps, position_embeddings.shape[-1]) - # 现在我们可以将位置嵌入加到obs_embeddings上了 - sequences = obs_embeddings + position_embeddings - else: - sequences = obs_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=obs_embeddings.device)) - elif 'act_tokens' in obs_embeddings_or_act_tokens.keys(): - act_tokens = obs_embeddings_or_act_tokens['act_tokens'] - num_steps = act_tokens.size(1) # (B, T) - act_embeddings = self.act_embedding_table(act_tokens) - - if kvcache_independent: - # 生成每个样本的步骤indices - # steps_indices = prev_steps + torch.arange(num_steps, device=act_embeddings.device).unsqueeze(1) - steps_indices = prev_steps + torch.arange(num_steps, device=act_embeddings.device) - # 步骤indices需要被reshape成一维,以便于embedding操作 - # steps_indices = steps_indices.view(-1) - # 获取位置嵌入 - position_embeddings = self.pos_emb(steps_indices) - # 由于我们要将它们加到obs_embeddings上,需要将位置嵌入reshape回(batch_size, num_steps, embedding_dim) - position_embeddings = position_embeddings.view(-1, num_steps, position_embeddings.shape[-1]) - # 现在我们可以将位置嵌入加到obs_embeddings上了 - sequences = act_embeddings + position_embeddings - else: - sequences = act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=act_tokens.device)) - else: - obs_embeddings_and_act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens'] - # obs_embeddings: (B, L, K=16, E), act_tokens: (B, L, 1) - obs_embeddings, act_tokens = obs_embeddings_and_act_tokens - if len(obs_embeddings.shape)==3: # for batch compute loss - obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, -1) - - num_steps = int(obs_embeddings.size(1)*(obs_embeddings.size(2)+1)) # L(k+1) - # assert num_steps <= self.config.max_tokens - # Rearrange observation embeddings from (B, L, K, E) to (B, L*K, E) - # obs_embeddings = rearrange(obs_embeddings, 'b l k e -> b (l k) e') - - # Generate action embeddings from action tokens - # (B, L, 1) -> (B, L, 1, E) - act_embeddings = self.act_embedding_table(act_tokens) - - # 已知obs_embeddings的维度为 (B, L, K, E), act_embeddings的维度为(B, L, 1, E) 希望得到一个obs_act_embeddings向量的维度为 (B, L(K+1), E) - # 而且让得到的obs_act_embeddings的第2个维度的数据为:obs act, obs, act, ..., obs, act,即 L, 1, L,1 ... 
这样的排列顺序。请给出高效的实现,用中文回答 - - B, L, K, E = obs_embeddings.size() - # _, _, _, _ = act_embeddings.size() - - # 初始化一个新的空tensor,用于存放最终的拼接结果 - obs_act_embeddings = torch.empty(B, L * (K + 1), E, device=obs_embeddings.device) - - # 对每一个序列长度L进行循环 - for i in range(L): - # 获取当前时刻的obs和act embeddings - obs = obs_embeddings[:, i, :, :] # Shape: (B, K, E) - act = act_embeddings[:, i, 0, :].unsqueeze(1) # Shape: (B, 1, E), 补充维度以便拼接 - - # 交替拼接obs和act - obs_act = torch.cat([obs, act], dim=1) # Shape: (B, K + 1, E) - - # 将结果填充到最终的tensor中 - obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act - - # 确保形状正确无误 - # assert obs_act_embeddings.shape == (B, L * (K + 1), E) - - # Add positional embeddings - sequences = obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=obs_embeddings.device)) - - - # print('transformer forward begin') 函数里面更新了update past_keys_values - if kvcache_independent: - x = [] - for k, past_kv in enumerate(past_keys_values): - x.append(self.transformer(sequences[k].unsqueeze(0), past_kv)) - # self.keys_values_wm_list[k] = past_kv # NOTE: todo - x = torch.cat(x, dim=0) - - # TODO: 在collect时,是一步一步的 obs act 传入的 - # prev_steps = prev_steps//1 - - else: - x = self.transformer(sequences, past_keys_values) - - # print('transformer forward done') - - - if is_root: - logits_observations = self.head_observations_for_root(x, num_steps=num_steps, prev_steps=prev_steps) - else: - # 1,...,0,1 https://github.com/eloialonso/iris/issues/19 - logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps) - - logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps) - logits_ends = self.head_ends(x, num_steps=num_steps, prev_steps=prev_steps) - # return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends) - - logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps) - logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps) - - # TODO: root reward value - return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends, logits_policy, logits_value) - - @torch.no_grad() - def render_batch(self) -> List[Image.Image]: - frames = self.decode_latent_state().detach().cpu() - frames = rearrange(frames, 'b c h w -> b h w c').mul(255).numpy().astype(np.uint8) - return [Image.fromarray(frame) for frame in frames] - - # only foe inference now, now is invalid - @torch.no_grad() - def decode_latent_state(self) -> List[Image.Image]: - embedded_tokens = self.tokenizer.embedding(self.latent_state) # (B, K, E) - z = rearrange(embedded_tokens, 'b (h w) e -> b e h w', h=int(np.sqrt(self.num_observations_tokens))) - rec = self.tokenizer.decode(z, should_postprocess=True) # (B, C, H, W) - # TODO: for atari image - return torch.clamp(rec, 0, 1) - # for cartpole obs - # return rec - - - @torch.no_grad() - def render(self): - assert self.latent_state.shape == (1, self.num_observations_tokens) - return self.render_batch()[0] - - @torch.no_grad() - def reset(self) -> torch.FloatTensor: - assert self.env is not None - obs = torchvision.transforms.functional.to_tensor(self.env.reset()).to(self.device).unsqueeze( - 0) # (1, C, H, W) in [0., 1.] 
- return self.reset_from_initial_observations(obs) - - - @torch.no_grad() - # @profile - def reset_from_initial_observations_v2(self, obs_act_dict: torch.FloatTensor) -> torch.FloatTensor: - if isinstance(obs_act_dict, dict): - observations = obs_act_dict['obs'] - buffer_action = obs_act_dict['action'] - else: - observations = obs_act_dict - buffer_action = None - obs_embeddings = self.tokenizer.encode_to_obs_embeddings(observations, should_preprocess=True) # (B, C, H, W) -> (B, K, E) - - outputs_wm = self.refresh_keys_values_with_initial_latent_state_for_init_infer_v2(obs_embeddings, buffer_action) - self.latent_state = obs_embeddings - - return outputs_wm, self.latent_state - - @torch.no_grad() - # @profile - def refresh_keys_values_with_initial_latent_state_for_init_infer_v2(self, latent_state: torch.LongTensor, buffer_action=None) -> torch.FloatTensor: - n, num_observations_tokens, _ = latent_state.shape - if n <= self.env_num: - # MCTS root节点: 需要准确的估计 value, policy_logits, 或许需要结合context的kv_cache进行更准确的估计,而不是当前的从0开始推理 - # self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - # outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) - # Compute the hash of latest_state - # latest_state = latent_state.detach().cpu().numpy() - # 假设 latest_state 是新的 latent_state,包含 ready_env_num 个环境的信息 - # ready_env_num = latest_state.shape[0] - # keys_values_wm_list = [] - # self.keys_values_wm_size_list = [] - # for i in range(ready_env_num): - # self.total_query_count += 1 - # state_single_env = latest_state[i] # 获取单个环境的 latent state - # hash_latest_state = hash(state_single_env) # 计算哈希值 - # matched_value = self.past_keys_values_cache.get(hash_latest_state) # 检索缓存值 - # if matched_value is not None: - # self.hit_count += 1 - # # 如果找到匹配的值,将其添加到列表中 - # keys_values_wm_list.append(copy.deepcopy(self.to_device_for_kvcache(matched_value, 'cuda'))) - # self.keys_values_wm_size_list.append(matched_value.size) - - # keys_values_wm_list.append(self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens)) - # self.keys_values_wm_size_list.append(0) - # else: - # # use zero - # keys_values_wm_list.append(self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens)) - # self.keys_values_wm_size_list.append(0) - # self.keys_values_wm_list = keys_values_wm_list - - # use zero - self.keys_values_wm_list = [self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) for i in range(n)] - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm_list, is_root=False, kvcache_independent=True) - self.keys_values_wm_size_list = [1 for i in range(n)] - - elif n == int(256): - # TODO: n=256 means train tokenizer, 不需要计算target value - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - # print('init inference: not find matched_value! 
reset!') - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) - elif n > self.env_num and n != int(256) and buffer_action is not None: - # train时计算target value - # TODO: n=256 means train tokenizer - # [192, 16, 64] -> [32, 6, 16, 64] - latent_state = latent_state.contiguous().view(buffer_action.shape[0], -1, num_observations_tokens, self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 - - # latent_state = latent_state.view(-1, self.config.max_blocks+1, num_observations_tokens) # (BL, K) - latent_state = latent_state[:, :-1, :] - # latent_state = latent_state.reshape(32*6, num_observations_tokens) # (BL, K) - buffer_action = torch.from_numpy(buffer_action).to(latent_state.device) - act_tokens = rearrange(buffer_action, 'b l -> b l 1') - - # # 选择每个样本的最后一步 - last_steps = act_tokens[:, -1:, :] # 这将选择最后一列并保持维度不变, 最后一步的target policy/value本身就没有用到 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - act_tokens = torch.cat((act_tokens, last_steps), dim=1) - - # print('init inference: unroll 5 steps!') 17*6=102 17*5=85 - obs_embeddings = latent_state - outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) - - # 选择每个样本的最后一步 - last_steps_value = outputs_wm.logits_value[:, -1:, :] # 这将选择最后一列并保持维度不变 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1) - - last_steps_policy = outputs_wm.logits_policy[:, -1:, :] # 这将选择最后一列并保持维度不变 - outputs_wm.logits_policy = torch.cat((outputs_wm.logits_policy, last_steps_policy), dim=1) - - # Reshape your tensors - # outputs_wm.logits_value.shape (30,21) = (B*6, 21) - outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e') - outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e') - - return outputs_wm - - @torch.no_grad() - # @profile - def reset_from_initial_observations(self, obs_act_dict: torch.FloatTensor) -> torch.FloatTensor: - if isinstance(obs_act_dict, dict): - # obs_act_dict = {'obs':obs, 'action':action_batch} - observations = obs_act_dict['obs'] - buffer_action = obs_act_dict['action'] - else: - observations = obs_act_dict - buffer_action = None - - # NOTE: should_preprocess=True is important - # latent_state = self.tokenizer.encode(observations, should_preprocess=True).tokens # (B, C, H, W) -> (B, K) - # _, num_observations_tokens = latent_state.shape - - obs_embeddings = self.tokenizer.encode_to_obs_embeddings(observations, should_preprocess=True) # (B, C, H, W) -> (B, K, E) - # num_observations_tokens = obs_embeddings.shape[1] - - # if self.num_observations_tokens is None: - # self._num_observations_tokens = num_observations_tokens - - # outputs_wm = self.refresh_keys_values_with_initial_latent_state_for_init_infer(latent_state, buffer_action) - outputs_wm = self.refresh_keys_values_with_initial_latent_state_for_init_infer(obs_embeddings, buffer_action) - - self.latent_state = obs_embeddings - - # return outputs_wm, self.decode_latent_state(), self.latent_state - return outputs_wm, self.latent_state - - - @torch.no_grad() - # @profile - def refresh_keys_values_with_initial_latent_state_for_init_infer(self, latent_state: torch.LongTensor, buffer_action=None) -> torch.FloatTensor: - n, num_observations_tokens, _ = latent_state.shape - if n <= self.env_num: - # MCTS root节点: 需要准确的估计 value, policy_logits 或许需要结合context的kv_cache进行更准确的估计,而不是当前的从0开始推理 - self.keys_values_wm = 
self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) # Note: is_root=False - elif n == int(256): - # TODO: n=256 means train tokenizer, 不需要计算target value - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - # print('init inference: not find matched_value! reset!') - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) # Note: is_root=False - elif n > self.env_num and n != int(256) and buffer_action is not None: - # TODO: n=256 means train tokenizer - # TODO: for n=32*6=192 means 通过unroll 5 steps,计算target value - # latent_state = latent_state.reshape(32, 6, num_observations_tokens) # (BL, K) - # latent_state = latent_state.view(-1, 6, num_observations_tokens) # (BL, K) - - # [192, 16] -> [32, 6, 16] - # latent_state = latent_state.view(buffer_action.shape[0], -1, num_observations_tokens) # (BL, K) for unroll_step=1 - - # [192, 16, 64] -> [32, 6, 16, 64] - latent_state = latent_state.contiguous().view(buffer_action.shape[0], -1, num_observations_tokens, self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 - - # latent_state = latent_state.view(-1, self.config.max_blocks+1, num_observations_tokens) # (BL, K) - latent_state = latent_state[:, :-1, :] - # latent_state = latent_state.reshape(32*6, num_observations_tokens) # (BL, K) - buffer_action = torch.from_numpy(buffer_action).to(latent_state.device) - act_tokens = rearrange(buffer_action, 'b l -> b l 1') - - # # 选择每个样本的最后一步 - last_steps = act_tokens[:, -1:, :] # 这将选择最后一列并保持维度不变, 最后一步的target policy/value本身就没有用到 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - act_tokens = torch.cat((act_tokens, last_steps), dim=1) - - # print('init inference: unroll 5 steps!') 17*6=102 17*5=85 - obs_embeddings = latent_state - outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) - - # 选择每个样本的最后一步 - last_steps_value = outputs_wm.logits_value[:, -1:, :] # 这将选择最后一列并保持维度不变 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1) - - last_steps_policy = outputs_wm.logits_policy[:, -1:, :] # 这将选择最后一列并保持维度不变 - outputs_wm.logits_policy = torch.cat((outputs_wm.logits_policy, last_steps_policy), dim=1) - - # Reshape your tensors - # outputs_wm.logits_value.shape (30,21) = (B*6, 21) - outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e') - outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e') - - - # return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends, logits_policy, logits_value) - return outputs_wm - - - @torch.no_grad() - # @profile - def refresh_keys_values_with_initial_latent_state(self, latent_state: torch.LongTensor, reset_indices=None) -> torch.FloatTensor: - n, num_observations_tokens, _ = latent_state.shape - assert num_observations_tokens == self.num_observations_tokens - # self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - if reset_indices is None: - self.keys_values_wm_list = [self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) for i in range(n)] - else: - for i in reset_indices: - self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) - 
outputs_wm = self.forward({'obs_embeddings': latent_state[i].unsqueeze(0).to(self.device)}, past_keys_values=self.keys_values_wm_single_env, is_root=False, kvcache_independent=False) - self.keys_values_wm_list[i] = self.keys_values_wm_single_env - self.keys_values_wm_size_list[i] = 1 - return None - - @torch.no_grad() - # @profile - def forward_initial_inference(self, obs_act_dict: torch.LongTensor, should_predict_next_obs: bool = True): - - if isinstance(obs_act_dict, dict): - # obs_act_dict = {'obs':obs, 'action':action_batch} - obs = obs_act_dict['obs'] - else: - obs = obs_act_dict - - # if obs.shape[0] < 8: - # print('debug') - outputs_wm, latent_state = self.reset_from_initial_observations_v2(obs_act_dict) # TODO - # outputs_wm, latent_state = self.reset_from_initial_observations(obs_act_dict) # 从零开始 - - return outputs_wm.output_sequence, latent_state, outputs_wm.logits_rewards, outputs_wm.logits_policy, outputs_wm.logits_value - - - """ - past-kv-dict-batch envnum8 latest multi-step - fix init infer - 把8个样本的self.keys_values_wm 看做一个整体来寻找 - - TODO:很多时候都是执行的refresh_keys_values_with_initial_latent_state,导致没有充分利用序列建模能力? - """ - - - @torch.no_grad() - # @profile - def forward_recurrent_inference(self, state_action_history, should_predict_next_obs: bool = True): - # 一般来讲,在一次 MCTS search中,我们需要维护H长度的context来使用transformer进行推理。 - # 由于在一次search里面。agent最多访问sim个不同的节点,因此我们只需维护一个 {(state:kv_cache)}的列表。 - # 但如果假设环境是MDP的话,然后根据当前的 latest_state s_t 在这个列表中查找即可 - # TODO: 但如果假设环境是非MDP的话,需要维护一个 {(rootstate_action_history:kv_cache)}的列表? - - # if self.total_query_count>0: - # self.hit_freq = self.hit_count/(self.total_query_count) - # print('hit_freq:', self.hit_freq) - # print('hit_count:', self.hit_count) - # print('total_query_count:', self.total_query_count) - - latest_state = state_action_history[-1][0] - - # 假设 latest_state 是新的 latent_state,包含 ready_env_num 个环境的信息 - ready_env_num = latest_state.shape[0] - keys_values_wm_list = [] - self.keys_values_wm_size_list = [] - for i in range(ready_env_num): - self.total_query_count += 1 - state_single_env = latest_state[i] # 获取单个环境的 latent state - hash_latest_state = hash(state_single_env) # 计算哈希值 - matched_value = self.past_keys_values_cache.get(hash_latest_state) # 检索缓存值 - if matched_value is not None: - self.hit_count += 1 - # 如果找到匹配的值,将其添加到列表中 - keys_values_wm_list.append(copy.deepcopy(self.to_device_for_kvcache(matched_value, 'cuda'))) - self.keys_values_wm_size_list.append(matched_value.size) - else: - # use zero - # keys_values_wm_list.append(self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens)) - # self.keys_values_wm_size_list.append(0) - - self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': torch.from_numpy(state_single_env).unsqueeze(0).to(self.device)}, past_keys_values=self.keys_values_wm_single_env, is_root=False) - keys_values_wm_list.append(self.keys_values_wm_single_env) - self.keys_values_wm_size_list.append(1) - - - self.keys_values_wm_list = keys_values_wm_list - - assert self.keys_values_wm is not None and self.num_observations_tokens is not None - - num_passes = 1 + self.num_observations_tokens if should_predict_next_obs else 1 - - output_sequence, latent_state = [], [] - - reset_indices = [index for index, value in enumerate(self.keys_values_wm_size_list) if value + num_passes > self.config.max_tokens] - self.refresh_keys_values_with_initial_latent_state(torch.tensor(latest_state, 
dtype=torch.float32).to(self.device), reset_indices) - - - action = state_action_history[-1][-1] - - token = action.clone().detach() if isinstance(action, torch.Tensor) else torch.tensor(action, dtype=torch.long) - token = token.reshape(-1, 1).to(self.device) # (B, 1) - - for k in range(num_passes): # assumption that there is only one action token. - # action_token obs_token, ..., obs_token 1+16 - - # obs is in token level - # act_token num_steps=1, prev_steps=16 - # obs_token_0 num_steps=1, prev_steps=17 - # obs_token_1 num_steps=1, prev_steps=18 - if k==0: - obs_embeddings_or_act_tokens = {'act_tokens': token} - else: - obs_embeddings_or_act_tokens = {'obs_embeddings': token} - # if token.shape[0] < 8: - # print('debug') - - # outputs_wm = self.forward(obs_embeddings_or_act_tokens, past_keys_values=self.keys_values_wm, is_root=False, kvcache_independent=True) - outputs_wm = self.forward(obs_embeddings_or_act_tokens, past_keys_values=self.keys_values_wm_list, is_root=False, kvcache_independent=True) - - # if k==0, action_token self.head_observations 1,...,0,1 - output_sequence.append(outputs_wm.output_sequence) - - if k == 0: - # if k==0, token is action_token outputs_wm.logits_rewards 是有值的 - # reward = Categorical(logits=outputs_wm.logits_rewards).sample().float().cpu().numpy().reshape(-1) - 1 # (B,) - # done = Categorical(logits=outputs_wm.logits_ends).sample().cpu().numpy().astype(bool).reshape(-1) # (B,) - reward = outputs_wm.logits_rewards # (B,) - - if k < self.num_observations_tokens: - # 一共产生16个obs_token,每次产生一个 - # TODO: sample or argmax - # token = Categorical(logits=outputs_wm.logits_observations).sample() - # Use argmax to select the most likely token - # token = outputs_wm.logits_observations.argmax(-1, keepdim=True) - token = outputs_wm.logits_observations - # if len(token.shape) != 2: - # token = token.squeeze(-1) # Ensure the token tensor shape is (B, 1) - if len(token.shape) != 3: - token = token.unsqueeze(1) # (8,1024) -> (8,1,1024) - latent_state.append(token) - - output_sequence = torch.cat(output_sequence, dim=1) # (B, 1 + K, E) - # Before updating self.latent_state, delete the old one to free memory - del self.latent_state - self.latent_state = torch.cat(latent_state, dim=1) # (B, K) - latent_state = self.latent_state - - # cache_key = hash(latent_state.detach().cpu().numpy()) - # # TODO: 在计算结束后,是否需要更新最新的缓存. 
是否需要deepcopy - # self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm, 'cpu')) - - for i in range(latent_state.shape[0]): # 遍历每个环境 - state_single_env = latent_state[i] # 获取单个环境的 latent state - cache_key = hash(state_single_env.detach().cpu().numpy()) # 计算哈希值 - # 复制单个环境对应的 keys_values_wm 并存储 - # keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) - # for layer in range(self.num_layers): - # # keys_values_wm_single_env[layer]._k_cache._cache = self.keys_values_wm[layer]._k_cache._cache[i].unsqueeze(0) # shape torch.Size([2, 100, 512]) - # # keys_values_wm_single_env[layer]._v_cache._cache = self.keys_values_wm[layer]._v_cache._cache[i].unsqueeze(0) - # keys_values_wm_single_env[layer].update(self.keys_values_wm[layer]._k_cache._cache[i].unsqueeze(0), self.keys_values_wm[layer]._v_cache._cache[i].unsqueeze(0)) - - # self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(keys_values_wm_single_env, 'cpu')) - - - # print([self.keys_values_wm_list[i].size for i in range(8)]) - self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_list[i], 'cpu')) - - # del keys_values_wm_single_env - - # outputs_wm.logits_policy, outputs_wm.logits_value - if len(self.past_keys_values_cache) > self.max_cache_size: - # TODO: lru_cache - _, popped_kv_cache = self.past_keys_values_cache.popitem(last=False) - del popped_kv_cache # 不要这一行 - - # Example usage: - # Assuming `past_keys_values_cache` is a populated instance of `KeysValues` - # and `num_layers` is the number of transformer layers - # cuda_memory_gb = self.calculate_cuda_memory_gb(self.past_keys_values_cache, num_layers=2) - # print(f'len(self.past_keys_values_cache): {len(self.past_keys_values_cache)}, Memory used by past_keys_values_cache: {cuda_memory_gb:.2f} GB') - - return outputs_wm.output_sequence, self.latent_state, reward, outputs_wm.logits_policy, outputs_wm.logits_value - - - def to_device_for_kvcache(self, keys_values: KeysValues, device: str) -> KeysValues: - """ - Transfer all KVCache objects within the KeysValues object to a certain device. - - Arguments: - - keys_values (KeysValues): The KeysValues object to be transferred. - - device (str): The device to transfer to. - - Returns: - - keys_values (KeysValues): The KeysValues object with its caches transferred to the specified device. 
- """ - # Check if CUDA is available and select the first available CUDA device - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - - for kv_cache in keys_values: - kv_cache._k_cache._cache = kv_cache._k_cache._cache.to(device) - kv_cache._v_cache._cache = kv_cache._v_cache._cache.to(device) - return keys_values - - - # 计算显存使用量的函数 - def calculate_cuda_memory_gb(self, past_keys_values_cache, num_layers: int): - total_memory_bytes = 0 - - # 遍历OrderedDict中所有的KeysValues实例 - for kv_instance in past_keys_values_cache.values(): - num_layers = len(kv_instance) # 获取层数 - for layer in range(num_layers): - kv_cache = kv_instance[layer] - k_shape = kv_cache._k_cache.shape # 获取keys缓存的形状 - v_shape = kv_cache._v_cache.shape # 获取values缓存的形状 - - # 计算元素个数并乘以每个元素的字节数 - k_memory = torch.prod(torch.tensor(k_shape)) * 4 - v_memory = torch.prod(torch.tensor(v_shape)) * 4 - - # 累加keys和values缓存的内存 - layer_memory = k_memory + v_memory - total_memory_bytes += layer_memory.item() # .item()确保转换为Python标准数字 - - # 将总内存从字节转换为吉字节 - total_memory_gb = total_memory_bytes / (1024 ** 3) - return total_memory_gb - - # @profile - def compute_loss(self, batch, target_tokenizer: Tokenizer=None, **kwargs: Any) -> LossWithIntermediateLosses: - - # if len(batch['observations'][0, 0].shape) == 3: - # # obs is a 3-dimensional image - # pass - - # NOTE: 这里是需要梯度的 - #with torch.no_grad(): # TODO: 非常重要 - obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False) # (B, C, H, W) -> (B, K, E) - obs_embeddings.register_hook(lambda grad: grad * 1/5) # TODO:只提供重建损失更新表征网络 - # obs_embeddings.register_hook(lambda grad: grad * 1) # TODO:只提供重建损失更新表征网络 - - # Assume that 'cont_embeddings' and 'original_images' are available from prior code - # Decode the embeddings to reconstruct the images - reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) - # Calculate the reconstruction loss - # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 4, 64, 64), reconstructed_images) # TODO: for stack=4 - latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # TODO: for stack=1 - perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # TODO: for stack=1 - - # latent_recon_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) - # perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) - - latent_kl_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) - - - act_tokens = rearrange(batch['actions'], 'b l -> b l 1') - - # TODO: 是否只用重建损失更新表征网络 非常重要 - outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) - # outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens)}, is_root=False) - - with torch.no_grad(): - traget_obs_embeddings = target_tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False) # (B, C, H, W) -> (B, K, E) - - labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(traget_obs_embeddings, batch['rewards'], - batch['ends'], - batch['mask_padding']) - # labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(obs_embeddings, batch['rewards'], - # batch['ends'], - # batch['mask_padding']) - - logits_observations = 
rearrange(outputs.logits_observations[:, :-1], 'b t o -> (b t) o') - # labels_observations = labels_observations.contiguous().view(-1, self.projection_input_dim) # TODO: - labels_observations = labels_observations.reshape(-1, self.projection_input_dim) # TODO: - - - loss_obs = torch.nn.functional.mse_loss(logits_observations, labels_observations.detach(), reduction='none').mean(-1) - mask_padding_expanded = batch['mask_padding'][:, 1:].contiguous().view(-1) # TODO: - # mask_padding_expanded = batch['mask_padding'][:, 1:].reshape(-1) - - # 应用mask到loss_obs - loss_obs = (loss_obs * mask_padding_expanded).mean(-1) - labels_policy, labels_value = self.compute_labels_world_model_value_policy(batch['target_value'], - batch['target_policy'], - batch['mask_padding']) - - loss_rewards = self.compute_cross_entropy_loss(outputs, labels_rewards, batch, element='rewards') - loss_policy = self.compute_cross_entropy_loss(outputs, labels_policy, batch, element='policy') - loss_value = self.compute_cross_entropy_loss(outputs, labels_value, batch, element='value') - - return LossWithIntermediateLosses(latent_recon_loss_weight=self.latent_recon_loss_weight, perceptual_loss_weight=self.perceptual_loss_weight, loss_obs=loss_obs, loss_rewards=loss_rewards, loss_value=loss_value, - loss_policy=loss_policy, latent_kl_loss=latent_kl_loss, latent_recon_loss=latent_recon_loss, perceptual_loss=perceptual_loss) - - def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): - # Assume outputs.logits_rewards and labels are your predictions and targets - # And mask_padding is a boolean tensor with True at positions to keep and False at positions to ignore - - if element == 'rewards': - logits = outputs.logits_rewards - elif element == 'policy': - logits = outputs.logits_policy - elif element == 'value': - logits = outputs.logits_value - - # Reshape your tensors - logits_rewards = rearrange(logits, 'b t e -> (b t) e') - labels = labels.reshape(-1, labels.shape[-1]) # Assuming labels originally has shape [b, t, reward_dim] - - # Reshape your mask. True means valid data. 
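For reference, the masked soft-target cross entropy computed by compute_cross_entropy_loss here can be written as a small standalone helper; this is only an illustrative sketch (helper name and shapes are assumptions, not part of the deleted file):

import torch
import torch.nn.functional as F

def masked_soft_target_ce(logits: torch.Tensor, target_dist: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # logits: (B*T, C) raw scores; target_dist: (B*T, C) soft labels (e.g. categorical reward/value supports);
    # mask: (B*T,) with 1 at valid positions and 0 at padding.
    ce = -(F.log_softmax(logits, dim=1) * target_dist).sum(dim=1)  # per-position cross entropy
    return (ce * mask.float()).mean()  # padded positions contribute zero before averaging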
- mask_padding = rearrange(batch['mask_padding'], 'b t -> (b t)') - - loss_rewards = -(torch.log_softmax(logits_rewards, dim=1) * labels).sum(1) - # loss_rewards = (loss_rewards * mask_padding.squeeze(-1).float()).mean() - loss_rewards = (loss_rewards * mask_padding).mean() - - - return loss_rewards - - def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, - mask_padding: torch.BoolTensor) -> Tuple[ - torch.Tensor, torch.Tensor, torch.Tensor]: - assert torch.all(ends.sum(dim=1) <= 1) # each sequence sample has at most 1 done - mask_fill = torch.logical_not(mask_padding) - labels_observations = obs_embeddings.contiguous().view(rewards.shape[0], -1, self.projection_input_dim)[:, 1:] # self.projection_input_dim - - - # labels_rewards = (rewards.sign() + 1).masked_fill(mask_fill, -100).long() # Rewards clipped to {-1, 0, 1} TODO(pu) - - mask_fill_rewards = mask_fill.unsqueeze(-1).expand_as(rewards) - labels_rewards = rewards.masked_fill(mask_fill_rewards, -100) - - labels_ends = ends.masked_fill(mask_fill, -100) - return labels_observations, labels_rewards.reshape(-1, self.support_size), labels_ends.reshape(-1) - - def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, target_policy: torch.Tensor, - mask_padding: torch.BoolTensor) -> Tuple[ - torch.Tensor, torch.Tensor]: - - mask_fill = torch.logical_not(mask_padding) - mask_fill_policy = mask_fill.unsqueeze(-1).expand_as(target_policy) - labels_policy = target_policy.masked_fill(mask_fill_policy, -100) - - mask_fill_value = mask_fill.unsqueeze(-1).expand_as(target_value) - labels_value = target_value.masked_fill(mask_fill_value, -100) - return labels_policy.reshape(-1, self.action_shape), labels_value.reshape(-1, self.support_size) # TODO(pu) - - - def negative_cosine_similarity(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: - """ - Overview: - consistency loss function: the negative cosine similarity. - Arguments: - - x1 (:obj:`torch.Tensor`): shape (batch_size, dim), e.g. (256, 512) - - x2 (:obj:`torch.Tensor`): shape (batch_size, dim), e.g. (256, 512) - Returns: - (x1 * x2).sum(dim=1) is the cosine similarity between vector x1 and x2. - The cosine similarity always belongs to the interval [-1, 1]. - For example, two proportional vectors have a cosine similarity of 1, - two orthogonal vectors have a similarity of 0, - and two opposite vectors have a similarity of -1. - -(x1 * x2).sum(dim=1) is consistency loss, i.e. the negative cosine similarity. 
- Reference: - https://en.wikipedia.org/wiki/Cosine_similarity - """ - x1 = F.normalize(x1, p=2., dim=-1, eps=1e-5) - x2 = F.normalize(x2, p=2., dim=-1, eps=1e-5) - return -(x1 * x2).sum(dim=1) - - - def render_img(self, obs: int, rec_img: int): - import torch - from PIL import Image - import matplotlib.pyplot as plt - - # 假设batch是一个字典,其中包含了observations键, - # 并且它的形状是torch.Size([B, N, C, H, W]) - # batch_observations = batch_for_gpt['observations'] - # batch_observations = batch['observations'] - batch_observations = obs.unsqueeze(0) - # batch_observations = rec_img.unsqueeze(0) - - # batch_observations = observations.unsqueeze(0) - # batch_observations = x.unsqueeze(0) - # batch_observations = reconstructions.unsqueeze(0) - - - - B, N, C, H, W = batch_observations.shape # 自动检测维度 - - # 分隔条的宽度(可以根据需要调整) - separator_width = 2 - - # 遍历每个样本 - for i in range(B): - # 提取当前样本中的所有帧 - frames = batch_observations[i] - - # 计算拼接图像的总宽度(包括分隔条) - total_width = N * W + (N - 1) * separator_width - - # 创建一个新的图像,其中包含分隔条 - concat_image = Image.new('RGB', (total_width, H), color='black') - - # 拼接每一帧及分隔条 - for j in range(N): - frame = frames[j].permute(1, 2, 0).cpu().numpy() # 转换为(H, W, C) - frame_image = Image.fromarray((frame * 255).astype('uint8'), 'RGB') - - # 计算当前帧在拼接图像中的位置 - x_position = j * (W + separator_width) - concat_image.paste(frame_image, (x_position, 0)) - - # 显示图像 - plt.imshow(concat_image) - plt.title(f'Sample {i+1}') - plt.axis('off') # 关闭坐标轴显示 - plt.show() - - # 保存图像到文件 - concat_image.save(f'sample_{i+1}.png') \ No newline at end of file diff --git a/lzero/model/gpt_models/world_model_forloop.py b/lzero/model/gpt_models/world_model_forloop.py deleted file mode 100644 index 25bd786ab..000000000 --- a/lzero/model/gpt_models/world_model_forloop.py +++ /dev/null @@ -1,1017 +0,0 @@ -import copy -from dataclasses import dataclass -import random -from typing import Any, Optional, Tuple -from typing import List, Optional, Union -import logging -# 设置日志记录级别为DEBUG -logging.getLogger().setLevel(logging.DEBUG) -from PIL import Image -from einops import rearrange -from einops import rearrange -import gym -from joblib import hash -import numpy as np -import torch -import torch -from torch.distributions.categorical import Categorical -import torch.nn as nn -import torch.nn.functional as F -import torchvision - -from .kv_caching import KeysValues -from .slicer import Embedder, Head, ActEmbedder -from .tokenizer import Tokenizer -from .transformer import Transformer, TransformerConfig -from .utils import LossWithIntermediateLosses, init_weights -from ding.torch_utils import to_device -# from memory_profiler import profile -from line_profiler import line_profiler -from line_profiler import line_profiler -import hashlib -# def quantize_state(state, num_buckets=1000): -def quantize_state(state, num_buckets=15): -# def quantize_state(state, num_buckets=10): - """ - 量化状态向量。 - 参数: - state: 要量化的状态向量。 - num_buckets: 量化的桶数。 - 返回: - 量化后的状态向量的哈希值。 - """ - # 使用np.digitize将状态向量的每个维度值映射到num_buckets个桶中 - quantized_state = np.digitize(state, bins=np.linspace(0, 1, num=num_buckets)) - # 使用更稳定的哈希函数 - quantized_state_bytes = quantized_state.tobytes() - hash_object = hashlib.sha256(quantized_state_bytes) - return hash_object.hexdigest() - -@dataclass -class WorldModelOutput: - output_sequence: torch.FloatTensor - logits_observations: torch.FloatTensor - logits_rewards: torch.FloatTensor - logits_ends: torch.FloatTensor - - logits_policy: torch.FloatTensor - logits_value: torch.FloatTensor - - -class WorldModel(nn.Module): - def 
__init__(self, obs_vocab_size: int, act_vocab_size: int, config: TransformerConfig, tokenizer, representation_network=None) -> None: - super().__init__() - - # config.max_tokens = int(2*50) # TODO - - self.tokenizer = tokenizer - self.obs_vocab_size, self.act_vocab_size = obs_vocab_size, act_vocab_size - self.config = config - - self.transformer = Transformer(config) - # self.num_observations_tokens = 16 - self.num_observations_tokens = config.tokens_per_block -1 - - self.latent_recon_loss_weight = config.latent_recon_loss_weight - self.perceptual_loss_weight = config.perceptual_loss_weight - - self.device = config.device - self.support_size = config.support_size - self.action_shape = config.action_shape - self.max_cache_size = config.max_cache_size - self.env_num = config.env_num - self.num_layers = config.num_layers - - - all_but_last_latent_state_pattern = torch.ones(config.tokens_per_block) - all_but_last_latent_state_pattern[-2] = 0 # 1,...,0,1 - - act_tokens_pattern = torch.zeros(self.config.tokens_per_block) # 17 - act_tokens_pattern[-1] = 1 # 0,...,0,1 - latent_state_pattern = 1 - act_tokens_pattern # 1,...,1,0 - - # current latent state's policy value - value_policy_tokens_pattern = torch.zeros(config.tokens_per_block) - value_policy_tokens_pattern[-2] = 1 # [0,...,1,0] - - # next latent state's policy value - # value_policy_tokens_pattern = torch.zeros(config.tokens_per_block) - # value_policy_tokens_pattern[-1] = 1 # [0,...,0,1] - - obs_per_embdding_dim=config.embed_dim - - self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim) - - - self.embedder = Embedder( - max_blocks=config.max_blocks, - block_masks=[act_tokens_pattern, latent_state_pattern], - embedding_tables=nn.ModuleList([nn.Embedding(act_vocab_size, config.embed_dim), nn.Embedding(obs_vocab_size, config.embed_dim)]) - ) - - # self.act_embedder = ActEmbedder( - # max_blocks=config.max_blocks, - # block_masks=[act_tokens_pattern], - # embedding_tables=nn.ModuleList([nn.Embedding(act_vocab_size, config.embed_dim)]) - # ) - - self.act_embedding_table = nn.Embedding(act_vocab_size, config.embed_dim) - - # self.head_observations = Head( - # max_blocks=config.max_blocks, - # block_mask=all_but_last_latent_state_pattern, # 1,...,0,1 # https://github.com/eloialonso/iris/issues/19 - # head_module=nn.Sequential( - # nn.Linear(config.embed_dim, config.embed_dim), - # nn.ReLU(), - # nn.Linear(config.embed_dim, obs_vocab_size) - # ) - # ) - self.obs_per_embdding_dim = config.embed_dim # 16*64=1024 - self.head_observations = Head( # TODO - max_blocks=config.max_blocks, - block_mask=all_but_last_latent_state_pattern, # 1,...,0,1 # https://github.com/eloialonso/iris/issues/19 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - # nn.BatchNorm1d(config.embed_dim), - # nn.ReLU(), - # nn.Linear(config.embed_dim, obs_vocab_size) - nn.LeakyReLU(negative_slope=0.01), # TODO: 2 - nn.Linear(config.embed_dim, self.obs_per_embdding_dim), - # nn.Tanh(), # TODO - nn.Sigmoid(), # 这里添加Sigmoid函数 TODO - ) - ) - - self.head_rewards = Head( - max_blocks=config.max_blocks, - block_mask=act_tokens_pattern, # 0,...,0,1 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - nn.ReLU(), - nn.Linear(config.embed_dim, self.support_size) - ) - ) - - self.head_ends = Head( - max_blocks=config.max_blocks, - block_mask=act_tokens_pattern, # 0,...,0,1 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - nn.ReLU(), - nn.Linear(config.embed_dim, 2) - ) - ) - - 
self.head_observations_for_root = Head( # TODO - max_blocks=config.max_blocks, - block_mask=latent_state_pattern, # 1,...,1,0 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - nn.BatchNorm1d(config.embed_dim), - nn.ReLU(), - # nn.Linear(config.embed_dim, obs_vocab_size) - nn.Linear(config.embed_dim, obs_per_embdding_dim) - ) - ) - self.head_policy = Head( - max_blocks=config.max_blocks, - block_mask=value_policy_tokens_pattern, # TODO: value_policy_tokens_pattern # [0,...,1,0] - head_module=nn.Sequential( # (8, 5, 128) - # nn.BatchNorm1d(config.embed_dim), # TODO: 1 - nn.Linear(config.embed_dim, config.embed_dim), - # nn.ReLU(), - nn.LeakyReLU(negative_slope=0.01), # TODO: 2 - nn.Linear(config.embed_dim, self.action_shape) # TODO(pu); action shape - ) - ) - self.head_value = Head( - max_blocks=config.max_blocks, - block_mask=value_policy_tokens_pattern, - head_module=nn.Sequential( - # nn.BatchNorm1d(config.embed_dim), # TODO: 1 - nn.Linear(config.embed_dim, config.embed_dim), - # nn.ReLU(), - nn.LeakyReLU(negative_slope=0.01), # TODO: 2 - nn.Linear(config.embed_dim, self.support_size) # TODO(pu): action shape - ) - ) - - self.apply(init_weights) - - last_linear_layer_init_zero = True # TODO: is beneficial for convergence speed. - if last_linear_layer_init_zero: - # TODO: policy init : 3 - # Locate the last linear layer and initialize its weights and biases to 0. - # for _, layer in enumerate(reversed(self.head_policy.head_module)): - # if isinstance(layer, nn.Linear): - # nn.init.zeros_(layer.weight) - # nn.init.zeros_(layer.bias) - # break - for _, layer in enumerate(reversed(self.head_value.head_module)): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - nn.init.zeros_(layer.bias) - break - for _, layer in enumerate(reversed(self.head_rewards.head_module)): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - nn.init.zeros_(layer.bias) - break - for _, layer in enumerate(reversed(self.head_observations.head_module)): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - nn.init.zeros_(layer.bias) - break - - - import collections - self.past_keys_values_cache = collections.OrderedDict() - self.past_policy_value_cache = collections.OrderedDict() - - # TODO: Transformer更新后应该清除缓存 - # NOTE - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=self.env_num, max_tokens=self.config.max_tokens) - - self.keys_values_wm_list = [] - self.keys_values_wm_size_list = [] - - - if self.num_observations_tokens==16: # k=16 - self.projection_input_dim = 128 - elif self.num_observations_tokens==1: # K=1 - self.projection_input_dim = self.obs_per_embdding_dim# for atari #TODO - # self.projection_input_dim = 256 # for cartpole - - - self.proj_hid = 1024 - self.proj_out = 1024 - self.pred_hid = 512 - self.pred_out = 1024 - activation = nn.ReLU() - self.projection = nn.Sequential( - nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, - nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, - nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out) - ) - self.prediction_head = nn.Sequential( - nn.Linear(self.proj_out, self.pred_hid), - nn.BatchNorm1d(self.pred_hid), - activation, - nn.Linear(self.pred_hid, self.pred_out), - ) - self.hit_count = 0 - self.total_query_count = 0 - self.root_hit_cnt = 0 - self.root_total_query_count = 0 - self.length3_context_cnt = 0 - self.length2_context_cnt = 0 - - - - - def __repr__(self) -> str: - 
return "world_model" - - # @profile - def forward(self, obs_embeddings_or_act_tokens, past_keys_values: Optional[KeysValues] = None, - is_root=False, kvcache_independent=False) -> WorldModelOutput: - - if kvcache_independent: - prev_steps = 0 if past_keys_values is None else [past_kv.size for past_kv in past_keys_values] - prev_steps = torch.tensor(prev_steps, device=self.device) - # 我们需要为每个样本生成一个序列的步骤indices,然后获取它们的位置嵌入 - # 首先扩展prev_steps至(num_steps, batch_size),这里num_steps=1 - # prev_steps = prev_steps.unsqueeze(0) - - else: - prev_steps = 0 if past_keys_values is None else past_keys_values.size - # print(f'prev_steps:{prev_steps}') - - if 'obs_embeddings' in obs_embeddings_or_act_tokens.keys(): - obs_embeddings = obs_embeddings_or_act_tokens['obs_embeddings'] - if len(obs_embeddings.shape)==2: - obs_embeddings = obs_embeddings.unsqueeze(1) - num_steps = obs_embeddings.size(1) # (B, T, E) - if kvcache_independent: - # 生成每个样本的步骤indices - # steps_indices = prev_steps + torch.arange(num_steps, device=obs_embeddings.device).unsqueeze(1) - steps_indices = prev_steps + torch.arange(num_steps, device=obs_embeddings.device) - - # 步骤indices需要被reshape成一维,以便于embedding操作 - # steps_indices = steps_indices.view(-1) - # 获取位置嵌入 - position_embeddings = self.pos_emb(steps_indices) - # 由于我们要将它们加到obs_embeddings上,需要将位置嵌入reshape回(batch_size, num_steps, embedding_dim) - position_embeddings = position_embeddings.view(-1, num_steps, position_embeddings.shape[-1]) - # 现在我们可以将位置嵌入加到obs_embeddings上了 - sequences = obs_embeddings + position_embeddings - else: - sequences = obs_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=obs_embeddings.device)) - elif 'act_tokens' in obs_embeddings_or_act_tokens.keys(): - act_tokens = obs_embeddings_or_act_tokens['act_tokens'] - num_steps = act_tokens.size(1) # (B, T) - act_embeddings = self.act_embedding_table(act_tokens) - - if kvcache_independent: - # 生成每个样本的步骤indices - # steps_indices = prev_steps + torch.arange(num_steps, device=act_embeddings.device).unsqueeze(1) - steps_indices = prev_steps + torch.arange(num_steps, device=act_embeddings.device) - # 步骤indices需要被reshape成一维,以便于embedding操作 - # steps_indices = steps_indices.view(-1) - # 获取位置嵌入 - position_embeddings = self.pos_emb(steps_indices) - # 由于我们要将它们加到obs_embeddings上,需要将位置嵌入reshape回(batch_size, num_steps, embedding_dim) - position_embeddings = position_embeddings.view(-1, num_steps, position_embeddings.shape[-1]) - # 现在我们可以将位置嵌入加到obs_embeddings上了 - sequences = act_embeddings + position_embeddings - else: - sequences = act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=act_tokens.device)) - else: - obs_embeddings_and_act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens'] - # obs_embeddings: (B, L, K=16, E), act_tokens: (B, L, 1) - obs_embeddings, act_tokens = obs_embeddings_and_act_tokens - if len(obs_embeddings.shape)==3: # for batch compute loss - obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, -1) - - num_steps = int(obs_embeddings.size(1)*(obs_embeddings.size(2)+1)) # L(k+1) - # assert num_steps <= self.config.max_tokens - # Rearrange observation embeddings from (B, L, K, E) to (B, L*K, E) - # obs_embeddings = rearrange(obs_embeddings, 'b l k e -> b (l k) e') - - # Generate action embeddings from action tokens - # (B, L, 1) -> (B, L, 1, E) - act_embeddings = self.act_embedding_table(act_tokens) - - # 已知obs_embeddings的维度为 (B, L, K, E), act_embeddings的维度为(B, L, 1, E) 
希望得到一个obs_act_embeddings向量的维度为 (B, L(K+1), E) - # 而且让得到的obs_act_embeddings的第2个维度的数据为:obs act, obs, act, ..., obs, act,即 L, 1, L,1 ... 这样的排列顺序。请给出高效的实现,用中文回答 - - B, L, K, E = obs_embeddings.size() - # _, _, _, _ = act_embeddings.size() - - # 初始化一个新的空tensor,用于存放最终的拼接结果 - obs_act_embeddings = torch.empty(B, L * (K + 1), E, device=obs_embeddings.device) - - # 对每一个序列长度L进行循环 - for i in range(L): - # 获取当前时刻的obs和act embeddings - obs = obs_embeddings[:, i, :, :] # Shape: (B, K, E) - act = act_embeddings[:, i, 0, :].unsqueeze(1) # Shape: (B, 1, E), 补充维度以便拼接 - - # 交替拼接obs和act - obs_act = torch.cat([obs, act], dim=1) # Shape: (B, K + 1, E) - - # 将结果填充到最终的tensor中 - obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act - - # 确保形状正确无误 - # assert obs_act_embeddings.shape == (B, L * (K + 1), E) - - # Add positional embeddings - sequences = obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=obs_embeddings.device)) - - - # print('transformer forward begin') 函数里面更新了update past_keys_values - if kvcache_independent: - x = [] - for k, past_kv in enumerate(past_keys_values): - x.append(self.transformer(sequences[k].unsqueeze(0), past_kv)) - # self.keys_values_wm_list[k] = past_kv # NOTE: todo - x = torch.cat(x, dim=0) - - # TODO: 在collect时,是一步一步的 obs act 传入的 - # prev_steps = prev_steps//1 - - else: - x = self.transformer(sequences, past_keys_values) - - # print('transformer forward done') - - - if is_root: - logits_observations = self.head_observations_for_root(x, num_steps=num_steps, prev_steps=prev_steps) - else: - # 1,...,0,1 https://github.com/eloialonso/iris/issues/19 - logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps) - - logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps) - logits_ends = self.head_ends(x, num_steps=num_steps, prev_steps=prev_steps) - # return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends) - - logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps) - logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps) - - # TODO: root reward value - return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends, logits_policy, logits_value) - - @torch.no_grad() - def render_batch(self) -> List[Image.Image]: - frames = self.decode_latent_state().detach().cpu() - frames = rearrange(frames, 'b c h w -> b h w c').mul(255).numpy().astype(np.uint8) - return [Image.fromarray(frame) for frame in frames] - - # only foe inference now, now is invalid - @torch.no_grad() - def decode_latent_state(self) -> List[Image.Image]: - embedded_tokens = self.tokenizer.embedding(self.latent_state) # (B, K, E) - z = rearrange(embedded_tokens, 'b (h w) e -> b e h w', h=int(np.sqrt(self.num_observations_tokens))) - rec = self.tokenizer.decode(z, should_postprocess=True) # (B, C, H, W) - # TODO: for atari image - return torch.clamp(rec, 0, 1) - # for cartpole obs - # return rec - - - @torch.no_grad() - def render(self): - assert self.latent_state.shape == (1, self.num_observations_tokens) - return self.render_batch()[0] - - @torch.no_grad() - def reset(self) -> torch.FloatTensor: - assert self.env is not None - obs = torchvision.transforms.functional.to_tensor(self.env.reset()).to(self.device).unsqueeze( - 0) # (1, C, H, W) in [0., 1.] 
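As a side note on the obs/act interleaving that forward() above builds with a per-step Python loop, the same (B, L*(K+1), E) layout can be produced without the loop. A minimal sketch under assumed shapes (the function name is illustrative, not from the deleted file):

import torch

def interleave_obs_act(obs_embeddings: torch.Tensor, act_embeddings: torch.Tensor) -> torch.Tensor:
    # obs_embeddings: (B, L, K, E); act_embeddings: (B, L, 1, E)
    B, L, K, E = obs_embeddings.shape
    obs_act = torch.cat([obs_embeddings, act_embeddings], dim=2)  # per step: K obs tokens followed by the act token
    return obs_act.view(B, L * (K + 1), E)  # obs_1..obs_K, act, obs_1..obs_K, act, ...

# e.g. interleave_obs_act(torch.randn(2, 3, 16, 64), torch.randn(2, 3, 1, 64)).shape == (2, 51, 64)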
- return self.reset_from_initial_observations(obs) - - - @torch.no_grad() - # @profile - def reset_from_initial_observations_v2(self, obs_act_dict: torch.FloatTensor) -> torch.FloatTensor: - if isinstance(obs_act_dict, dict): - observations = obs_act_dict['obs'] - buffer_action = obs_act_dict['action'] - else: - observations = obs_act_dict - buffer_action = None - obs_embeddings = self.tokenizer.encode_to_obs_embeddings(observations, should_preprocess=True) # (B, C, H, W) -> (B, K, E) - - outputs_wm = self.refresh_keys_values_with_initial_latent_state_for_init_infer_v2(obs_embeddings, buffer_action) - self.latent_state = obs_embeddings - - return outputs_wm, self.latent_state - - @torch.no_grad() - # @profile - def refresh_keys_values_with_initial_latent_state_for_init_infer_v2(self, latent_state: torch.LongTensor, buffer_action=None) -> torch.FloatTensor: - n, num_observations_tokens, _ = latent_state.shape - if n <= self.env_num: - # MCTS root节点: 需要准确的估计 value, policy_logits, 或许需要结合context的kv_cache进行更准确的估计,而不是当前的从0开始推理 - self.keys_values_wm = self.transformer.generate_empty_keys_values(n, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False, kvcache_independent=False) - self.keys_values_wm_size_list = [1 for i in range(n)] - - # 复制单个环境对应的 keys_values_wm 并存储 - self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) - for i in range(latent_state.size(0)): # 遍历每个环境 - state_single_env = latent_state[i] # 获取单个环境的 latent state - # cache_key = hash(state_single_env.detach().cpu().numpy()) # 计算哈希值 - quantized_state = state_single_env.detach().cpu().numpy() - cache_key = quantize_state(quantized_state) # 使用量化后的状态计算哈希值 - for layer in range(self.num_layers): - self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = self.keys_values_wm._keys_values[layer]._k_cache._cache[i].unsqueeze(0) # shape torch.Size([2, 100, 512]) - self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = self.keys_values_wm._keys_values[layer]._v_cache._cache[i].unsqueeze(0) - self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = self.keys_values_wm._keys_values[layer]._k_cache._size - self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = self.keys_values_wm._keys_values[layer]._v_cache._size - # keys_values_wm_single_env[layer].update(self.keys_values_wm[layer]._k_cache._cache[i].unsqueeze(0), self.keys_values_wm[layer]._v_cache._cache[i].unsqueeze(0)) - self.root_total_query_count += 1 - if cache_key not in self.past_keys_values_cache: - self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu')) - else: - self.root_hit_cnt += 1 - root_hit_ratio = self.root_hit_cnt / self.root_total_query_count - print('root_total_query_count:', self.root_total_query_count) - print(f'root_hit_ratio:{root_hit_ratio}') - print(f'root_hit find size {self.past_keys_values_cache[cache_key].size}') - if self.past_keys_values_cache[cache_key].size>1: - print(f'=='*20) - print(f'NOTE: root_hit find size > 1') - print(f'=='*20) - - - elif n == int(256): - # TODO: n=256 means train tokenizer, 不需要计算target value - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - # print('init inference: not find matched_value! 
reset!') - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) - elif n > self.env_num and n != int(256) and buffer_action is not None: - # train时计算target value - # TODO: n=256 means train tokenizer - # [192, 16, 64] -> [32, 6, 16, 64] - latent_state = latent_state.contiguous().view(buffer_action.shape[0], -1, num_observations_tokens, self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 - - # latent_state = latent_state.view(-1, self.config.max_blocks+1, num_observations_tokens) # (BL, K) - latent_state = latent_state[:, :-1, :] - # latent_state = latent_state.reshape(32*6, num_observations_tokens) # (BL, K) - buffer_action = torch.from_numpy(buffer_action).to(latent_state.device) - act_tokens = rearrange(buffer_action, 'b l -> b l 1') - - # # 选择每个样本的最后一步 - last_steps = act_tokens[:, -1:, :] # 这将选择最后一列并保持维度不变, 最后一步的target policy/value本身就没有用到 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - act_tokens = torch.cat((act_tokens, last_steps), dim=1) - - # print('init inference: unroll 5 steps!') 17*6=102 17*5=85 - obs_embeddings = latent_state - outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) - - # 选择每个样本的最后一步 - last_steps_value = outputs_wm.logits_value[:, -1:, :] # 这将选择最后一列并保持维度不变 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1) - - last_steps_policy = outputs_wm.logits_policy[:, -1:, :] # 这将选择最后一列并保持维度不变 - outputs_wm.logits_policy = torch.cat((outputs_wm.logits_policy, last_steps_policy), dim=1) - - # Reshape your tensors - # outputs_wm.logits_value.shape (30,21) = (B*6, 21) - outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e') - outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e') - - return outputs_wm - - @torch.no_grad() - # @profile - def reset_from_initial_observations(self, obs_act_dict: torch.FloatTensor) -> torch.FloatTensor: - if isinstance(obs_act_dict, dict): - # obs_act_dict = {'obs':obs, 'action':action_batch} - observations = obs_act_dict['obs'] - buffer_action = obs_act_dict['action'] - else: - observations = obs_act_dict - buffer_action = None - - obs_embeddings = self.tokenizer.encode_to_obs_embeddings(observations, should_preprocess=True) # (B, C, H, W) -> (B, K, E) - outputs_wm = self.refresh_keys_values_with_initial_latent_state_for_init_infer(obs_embeddings, buffer_action) - - self.latent_state = obs_embeddings - - # return outputs_wm, self.decode_latent_state(), self.latent_state - return outputs_wm, self.latent_state - - - @torch.no_grad() - # @profile - def refresh_keys_values_with_initial_latent_state_for_init_infer(self, latent_state: torch.LongTensor, buffer_action=None) -> torch.FloatTensor: - n, num_observations_tokens, _ = latent_state.shape - if n <= self.env_num: - # MCTS root节点: 需要准确的估计 value, policy_logits 或许需要结合context的kv_cache进行更准确的估计,而不是当前的从0开始推理 - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) # Note: is_root=False - elif n == int(256): - # TODO: n=256 means train tokenizer, 不需要计算target value - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - # print('init inference: not find matched_value! 
reset!') - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) # Note: is_root=False - elif n > self.env_num and n != int(256) and buffer_action is not None: - # TODO: n=256 means train tokenizer - # TODO: for n=32*6=192 means 通过unroll 5 steps,计算target value - # latent_state = latent_state.reshape(32, 6, num_observations_tokens) # (BL, K) - # latent_state = latent_state.view(-1, 6, num_observations_tokens) # (BL, K) - - # [192, 16, 64] -> [32, 6, 16, 64] - latent_state = latent_state.contiguous().view(buffer_action.shape[0], -1, num_observations_tokens, self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 - - # latent_state = latent_state.view(-1, self.config.max_blocks+1, num_observations_tokens) # (BL, K) - latent_state = latent_state[:, :-1, :] - # latent_state = latent_state.reshape(32*6, num_observations_tokens) # (BL, K) - buffer_action = torch.from_numpy(buffer_action).to(latent_state.device) - act_tokens = rearrange(buffer_action, 'b l -> b l 1') - - # # 选择每个样本的最后一步 - last_steps = act_tokens[:, -1:, :] # 这将选择最后一列并保持维度不变, 最后一步的target policy/value本身就没有用到 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - act_tokens = torch.cat((act_tokens, last_steps), dim=1) - - # print('init inference: unroll 5 steps!') 17*6=102 17*5=85 - obs_embeddings = latent_state - outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) - - # 选择每个样本的最后一步 - last_steps_value = outputs_wm.logits_value[:, -1:, :] # 这将选择最后一列并保持维度不变 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1) - - last_steps_policy = outputs_wm.logits_policy[:, -1:, :] # 这将选择最后一列并保持维度不变 - outputs_wm.logits_policy = torch.cat((outputs_wm.logits_policy, last_steps_policy), dim=1) - - # Reshape your tensors - # outputs_wm.logits_value.shape (30,21) = (B*6, 21) - outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e') - outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e') - - - # return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends, logits_policy, logits_value) - return outputs_wm - - - @torch.no_grad() - # @profile - def refresh_keys_values_with_initial_latent_state(self, latent_state: torch.LongTensor, reset_indices=None) -> torch.FloatTensor: - n, num_observations_tokens, _ = latent_state.shape - assert num_observations_tokens == self.num_observations_tokens - # self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - if reset_indices is None: - self.keys_values_wm_list = [self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) for i in range(n)] - else: - for i in reset_indices: - self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': latent_state[i].unsqueeze(0).to(self.device)}, past_keys_values=self.keys_values_wm_single_env, is_root=False, kvcache_independent=False) - self.keys_values_wm_list[i] = self.keys_values_wm_single_env - self.keys_values_wm_size_list[i] = 1 - return None - - @torch.no_grad() - # @profile - def forward_initial_inference(self, obs_act_dict: torch.LongTensor, should_predict_next_obs: bool = True): - - if isinstance(obs_act_dict, dict): - # obs_act_dict = {'obs':obs, 'action':action_batch} - obs = obs_act_dict['obs'] - else: - obs = obs_act_dict - - # if 
obs.shape[0] < 8: - # print('debug') - outputs_wm, latent_state = self.reset_from_initial_observations_v2(obs_act_dict) # TODO - # outputs_wm, latent_state = self.reset_from_initial_observations(obs_act_dict) # start from scratch - - return outputs_wm.output_sequence, latent_state, outputs_wm.logits_rewards, outputs_wm.logits_policy, outputs_wm.logits_value - - - """ - Assume env_num=8. - The kv_cache of each of the 8 environments is stored and looked up separately, all kept in a single dict; during recurrent_inference, - the transformer forward pass is likewise run per environment in a for loop, and the results are then assembled back into a batch. - """ - - @torch.no_grad() - # @profile - def forward_recurrent_inference(self, state_action_history, should_predict_next_obs: bool = True): - # In general, within one MCTS search we need to maintain a context of length H for transformer inference. - # Since the agent visits at most `sim` distinct nodes within one search, we only need to maintain a {state: kv_cache} mapping. - # If the environment is assumed to be an MDP, we can simply look up the current latest_state s_t in this mapping. - # TODO: if the environment is non-Markovian, should we instead maintain a {root_state_action_history: kv_cache} mapping? - - latest_state = state_action_history[-1][0] - - # Assume latest_state is the new latent_state, containing information for ready_env_num environments - ready_env_num = latest_state.shape[0] - keys_values_wm_list = [] - self.keys_values_wm_size_list = [] - for i in range(ready_env_num): - self.total_query_count += 1 - state_single_env = latest_state[i] # latent state of a single environment - # cache_key = hash(state_single_env.detach().cpu().numpy()) # compute the hash value - cache_key = quantize_state(state_single_env) # compute the cache key from the quantized state - matched_value = self.past_keys_values_cache.get(cache_key) # look up the cached value - if matched_value is not None: - self.hit_count += 1 - # if a matching value is found, add it to the list - keys_values_wm_list.append(copy.deepcopy(self.to_device_for_kvcache(matched_value, 'cuda'))) - self.keys_values_wm_size_list.append(matched_value.size) - else: - # use zero - self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': torch.from_numpy(state_single_env).unsqueeze(0).to(self.device)}, past_keys_values=self.keys_values_wm_single_env, is_root=False) - keys_values_wm_list.append(self.keys_values_wm_single_env) - self.keys_values_wm_size_list.append(1) - - - self.length3_context_cnt += len([x for x in self.keys_values_wm_size_list if x >= 5]) - self.length2_context_cnt += len([x for x in self.keys_values_wm_size_list if x >= 3]) - - if self.total_query_count>0 and self.total_query_count%10000==0: - # if self.total_query_count>0 and self.total_query_count%1==0: - self.hit_freq = self.hit_count/(self.total_query_count) - print('hit_freq:', self.hit_freq) - # print('hit_count:', self.hit_count) - print('total_query_count:', self.total_query_count) - print(self.keys_values_wm_size_list) - - length3_context_cnt_ratio = self.length3_context_cnt / self.total_query_count - print('>=3 node context_cnt_ratio:', length3_context_cnt_ratio) - print('>=3 node context_cnt:', self.length3_context_cnt) - - length2_context_cnt_ratio = self.length2_context_cnt / self.total_query_count - print('>=2 node context_cnt_ratio:', length2_context_cnt_ratio) - print('>=2 node context_cnt:', self.length2_context_cnt) - # print(self.keys_values_wm_size_list) - - - self.keys_values_wm_list = keys_values_wm_list - # assert self.keys_values_wm is not None and self.num_observations_tokens is not None - num_passes = 1 + self.num_observations_tokens if should_predict_next_obs else 1 - output_sequence, latent_state = [], [] - - # reset_indices = [index for index, value in enumerate(self.keys_values_wm_size_list) if value + num_passes > self.config.max_tokens] - # 
self.refresh_keys_values_with_initial_latent_state(torch.tensor(latest_state, dtype=torch.float32).to(self.device), reset_indices) - - action = state_action_history[-1][-1] - token = action.clone().detach() if isinstance(action, torch.Tensor) else torch.tensor(action, dtype=torch.long) - token = token.reshape(-1, 1).to(self.device) # (B, 1) - - for k in range(num_passes): # assumption that there is only one action token. - # action_token obs_token, ..., obs_token 1+16 - if k==0: - obs_embeddings_or_act_tokens = {'act_tokens': token} - else: - obs_embeddings_or_act_tokens = {'obs_embeddings': token} - - outputs_wm = self.forward(obs_embeddings_or_act_tokens, past_keys_values=self.keys_values_wm_list, is_root=False, kvcache_independent=True) - - # if k==0, action_token self.head_observations 1,...,0,1 - output_sequence.append(outputs_wm.output_sequence) - - if k == 0: - # if k==0, token is action_token outputs_wm.logits_rewards 是有值的 - # reward = Categorical(logits=outputs_wm.logits_rewards).sample().float().cpu().numpy().reshape(-1) - 1 # (B,) - # done = Categorical(logits=outputs_wm.logits_ends).sample().cpu().numpy().astype(bool).reshape(-1) # (B,) - reward = outputs_wm.logits_rewards # (B,) - - if k < self.num_observations_tokens: - # 一共产生16个obs_token,每次产生一个 - # TODO: sample or argmax - # token = Categorical(logits=outputs_wm.logits_observations).sample() - # Use argmax to select the most likely token - # token = outputs_wm.logits_observations.argmax(-1, keepdim=True) - token = outputs_wm.logits_observations - # if len(token.shape) != 2: - # token = token.squeeze(-1) # Ensure the token tensor shape is (B, 1) - if len(token.shape) != 3: - token = token.unsqueeze(1) # (8,1024) -> (8,1,1024) - latent_state.append(token) - - output_sequence = torch.cat(output_sequence, dim=1) # (B, 1 + K, E) - # Before updating self.latent_state, delete the old one to free memory - del self.latent_state - self.latent_state = torch.cat(latent_state, dim=1) # (B, K) - - # # TODO: 在计算结束后,是否需要更新最新的缓存. 
是否需要deepcopy - for i in range(self.latent_state.shape[0]): # 遍历每个环境 - state_single_env = self.latent_state[i] # 获取单个环境的 latent state - # cache_key = hash(state_single_env.detach().cpu().numpy()) # 计算哈希值 - quantized_state = state_single_env.detach().cpu().numpy() - cache_key = quantize_state(quantized_state) # 使用量化后的状态计算哈希值 - # 复制单个环境对应的 keys_values_wm 并存储 - # print([self.keys_values_wm_list[i].size for i in range(8)]) - # self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_list[i], 'cpu')) - - # 比较并存储大小较大的 cache - if cache_key in self.past_keys_values_cache: - existing_kvcache = self.past_keys_values_cache[cache_key] - # 判断现有cache和新cache中的size是否有不同 - if self.keys_values_wm_list[i].size > existing_kvcache.size and self.keys_values_wm_list[i].size < self.config.max_tokens-1: - # 具有size为self.config.max_tokens-1 不用存,因为使用达到最大size的kv_cache,而这会导致reset,从而一直是从零开始的 - self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_list[i], 'cpu')) - elif self.keys_values_wm_list[i].size < self.config.max_tokens-1: - # 具有size为self.config.max_tokens-1 不用存,因为使用达到最大size的kv_cache,而这会导致reset,从而一直是从零开始的 - self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_list[i], 'cpu')) - - # outputs_wm.logits_policy, outputs_wm.logits_value - if len(self.past_keys_values_cache) > self.max_cache_size: - # TODO: lru_cache - _, popped_kv_cache = self.past_keys_values_cache.popitem(last=False) - del popped_kv_cache # 不要这一行 - - # Example usage: - # Assuming `past_keys_values_cache` is a populated instance of `KeysValues` - # and `num_layers` is the number of transformer layers - # cuda_memory_gb = self.calculate_cuda_memory_gb(self.past_keys_values_cache, num_layers=2) - # print(f'len(self.past_keys_values_cache): {len(self.past_keys_values_cache)}, Memory used by past_keys_values_cache: {cuda_memory_gb:.2f} GB') - - return outputs_wm.output_sequence, self.latent_state, reward, outputs_wm.logits_policy, outputs_wm.logits_value - - - def to_device_for_kvcache(self, keys_values: KeysValues, device: str) -> KeysValues: - """ - Transfer all KVCache objects within the KeysValues object to a certain device. - - Arguments: - - keys_values (KeysValues): The KeysValues object to be transferred. - - device (str): The device to transfer to. - - Returns: - - keys_values (KeysValues): The KeysValues object with its caches transferred to the specified device. 
- """ - # Check if CUDA is available and select the first available CUDA device - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - - for kv_cache in keys_values: - kv_cache._k_cache._cache = kv_cache._k_cache._cache.to(device) - kv_cache._v_cache._cache = kv_cache._v_cache._cache.to(device) - return keys_values - - - # 计算显存使用量的函数 - def calculate_cuda_memory_gb(self, past_keys_values_cache, num_layers: int): - total_memory_bytes = 0 - - # 遍历OrderedDict中所有的KeysValues实例 - for kv_instance in past_keys_values_cache.values(): - num_layers = len(kv_instance) # 获取层数 - for layer in range(num_layers): - kv_cache = kv_instance[layer] - k_shape = kv_cache._k_cache.shape # 获取keys缓存的形状 - v_shape = kv_cache._v_cache.shape # 获取values缓存的形状 - - # 计算元素个数并乘以每个元素的字节数 - k_memory = torch.prod(torch.tensor(k_shape)) * 4 - v_memory = torch.prod(torch.tensor(v_shape)) * 4 - - # 累加keys和values缓存的内存 - layer_memory = k_memory + v_memory - total_memory_bytes += layer_memory.item() # .item()确保转换为Python标准数字 - - # 将总内存从字节转换为吉字节 - total_memory_gb = total_memory_bytes / (1024 ** 3) - return total_memory_gb - - # @profile - def compute_loss(self, batch, target_tokenizer: Tokenizer=None, **kwargs: Any) -> LossWithIntermediateLosses: - # NOTE: 这里是需要梯度的 - #with torch.no_grad(): # TODO: 非常重要 - obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False) # (B, C, H, W) -> (B, K, E) - obs_embeddings.register_hook(lambda grad: grad * 1/5) # TODO:只提供重建损失更新表征网络 - # obs_embeddings.register_hook(lambda grad: grad * 1) # TODO:只提供重建损失更新表征网络 - - # Assume that 'cont_embeddings' and 'original_images' are available from prior code - # Decode the embeddings to reconstruct the images - reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) - # Calculate the reconstruction loss - # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 4, 64, 64), reconstructed_images) # TODO: for stack=4 - # perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) # for stack=4 gray obs - - latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # TODO: for stack=1 - perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # TODO: for stack=1 - - # latent_recon_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) - # perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) - - latent_kl_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) - - - act_tokens = rearrange(batch['actions'], 'b l -> b l 1') - - # TODO: 是否只用重建损失更新表征网络 非常重要 - outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) - # outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens)}, is_root=False) - - with torch.no_grad(): - traget_obs_embeddings = target_tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False) # (B, C, H, W) -> (B, K, E) - - labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(traget_obs_embeddings, batch['rewards'], - batch['ends'], - batch['mask_padding']) - # labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(obs_embeddings, batch['rewards'], - # batch['ends'], - # 
batch['mask_padding']) - - logits_observations = rearrange(outputs.logits_observations[:, :-1], 'b t o -> (b t) o') - # labels_observations = labels_observations.contiguous().view(-1, self.projection_input_dim) # TODO: - labels_observations = labels_observations.reshape(-1, self.projection_input_dim) # TODO: - - - loss_obs = torch.nn.functional.mse_loss(logits_observations, labels_observations.detach(), reduction='none').mean(-1) - mask_padding_expanded = batch['mask_padding'][:, 1:].contiguous().view(-1) # TODO: - # mask_padding_expanded = batch['mask_padding'][:, 1:].reshape(-1) - - # 应用mask到loss_obs - loss_obs = (loss_obs * mask_padding_expanded).mean(-1) - labels_policy, labels_value = self.compute_labels_world_model_value_policy(batch['target_value'], - batch['target_policy'], - batch['mask_padding']) - - loss_rewards = self.compute_cross_entropy_loss(outputs, labels_rewards, batch, element='rewards') - loss_policy = self.compute_cross_entropy_loss(outputs, labels_policy, batch, element='policy') - loss_value = self.compute_cross_entropy_loss(outputs, labels_value, batch, element='value') - - return LossWithIntermediateLosses(latent_recon_loss_weight=self.latent_recon_loss_weight, perceptual_loss_weight=self.perceptual_loss_weight, loss_obs=loss_obs, loss_rewards=loss_rewards, loss_value=loss_value, - loss_policy=loss_policy, latent_kl_loss=latent_kl_loss, latent_recon_loss=latent_recon_loss, perceptual_loss=perceptual_loss) - - def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): - # Assume outputs.logits_rewards and labels are your predictions and targets - # And mask_padding is a boolean tensor with True at positions to keep and False at positions to ignore - - if element == 'rewards': - logits = outputs.logits_rewards - elif element == 'policy': - logits = outputs.logits_policy - elif element == 'value': - logits = outputs.logits_value - - # Reshape your tensors - logits_rewards = rearrange(logits, 'b t e -> (b t) e') - labels = labels.reshape(-1, labels.shape[-1]) # Assuming labels originally has shape [b, t, reward_dim] - - # Reshape your mask. True means valid data. 
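To make the quantized kv-cache keying used by world_model_forloop.py above concrete, here is a minimal self-contained sketch of the key computation and the bounded-size eviction; the bucket count, cache bound, and names are illustrative assumptions, not part of the deleted file:

import hashlib
from collections import OrderedDict

import numpy as np

def state_cache_key(state: np.ndarray, num_buckets: int = 15) -> str:
    # bucketize every dimension, then hash the bytes to get a stable lookup key
    quantized = np.digitize(state, bins=np.linspace(0, 1, num=num_buckets))
    return hashlib.sha256(quantized.tobytes()).hexdigest()

kv_cache_store: OrderedDict = OrderedDict()
MAX_CACHE_SIZE = 1000  # illustrative bound

def store_kv(state: np.ndarray, kv_cache) -> None:
    kv_cache_store[state_cache_key(state)] = kv_cache
    if len(kv_cache_store) > MAX_CACHE_SIZE:
        kv_cache_store.popitem(last=False)  # drop the oldest entry (FIFO; a true LRU would reorder on hits)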
- mask_padding = rearrange(batch['mask_padding'], 'b t -> (b t)') - - loss_rewards = -(torch.log_softmax(logits_rewards, dim=1) * labels).sum(1) - # loss_rewards = (loss_rewards * mask_padding.squeeze(-1).float()).mean() - loss_rewards = (loss_rewards * mask_padding).mean() - - - return loss_rewards - - def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, - mask_padding: torch.BoolTensor) -> Tuple[ - torch.Tensor, torch.Tensor, torch.Tensor]: - assert torch.all(ends.sum(dim=1) <= 1) # each sequence sample has at most 1 done - mask_fill = torch.logical_not(mask_padding) - labels_observations = obs_embeddings.contiguous().view(rewards.shape[0], -1, self.projection_input_dim)[:, 1:] # self.projection_input_dim - - - # labels_rewards = (rewards.sign() + 1).masked_fill(mask_fill, -100).long() # Rewards clipped to {-1, 0, 1} TODO(pu) - - mask_fill_rewards = mask_fill.unsqueeze(-1).expand_as(rewards) - labels_rewards = rewards.masked_fill(mask_fill_rewards, -100) - - labels_ends = ends.masked_fill(mask_fill, -100) - return labels_observations, labels_rewards.reshape(-1, self.support_size), labels_ends.reshape(-1) - - def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, target_policy: torch.Tensor, - mask_padding: torch.BoolTensor) -> Tuple[ - torch.Tensor, torch.Tensor]: - - mask_fill = torch.logical_not(mask_padding) - mask_fill_policy = mask_fill.unsqueeze(-1).expand_as(target_policy) - labels_policy = target_policy.masked_fill(mask_fill_policy, -100) - - mask_fill_value = mask_fill.unsqueeze(-1).expand_as(target_value) - labels_value = target_value.masked_fill(mask_fill_value, -100) - return labels_policy.reshape(-1, self.action_shape), labels_value.reshape(-1, self.support_size) # TODO(pu) - - - def negative_cosine_similarity(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: - """ - Overview: - consistency loss function: the negative cosine similarity. - Arguments: - - x1 (:obj:`torch.Tensor`): shape (batch_size, dim), e.g. (256, 512) - - x2 (:obj:`torch.Tensor`): shape (batch_size, dim), e.g. (256, 512) - Returns: - (x1 * x2).sum(dim=1) is the cosine similarity between vector x1 and x2. - The cosine similarity always belongs to the interval [-1, 1]. - For example, two proportional vectors have a cosine similarity of 1, - two orthogonal vectors have a similarity of 0, - and two opposite vectors have a similarity of -1. - -(x1 * x2).sum(dim=1) is consistency loss, i.e. the negative cosine similarity. 
- Reference: - https://en.wikipedia.org/wiki/Cosine_similarity - """ - x1 = F.normalize(x1, p=2., dim=-1, eps=1e-5) - x2 = F.normalize(x2, p=2., dim=-1, eps=1e-5) - return -(x1 * x2).sum(dim=1) - - - def render_img(self, obs: int, rec_img: int): - import torch - from PIL import Image - import matplotlib.pyplot as plt - - # 假设batch是一个字典,其中包含了observations键, - # 并且它的形状是torch.Size([B, N, C, H, W]) - # batch_observations = batch_for_gpt['observations'] - # batch_observations = batch['observations'] - batch_observations = obs.unsqueeze(0) - # batch_observations = rec_img.unsqueeze(0) - - # batch_observations = observations.unsqueeze(0) - # batch_observations = x.unsqueeze(0) - # batch_observations = reconstructions.unsqueeze(0) - - - - B, N, C, H, W = batch_observations.shape # 自动检测维度 - - # 分隔条的宽度(可以根据需要调整) - separator_width = 2 - - # 遍历每个样本 - for i in range(B): - # 提取当前样本中的所有帧 - frames = batch_observations[i] - - # 计算拼接图像的总宽度(包括分隔条) - total_width = N * W + (N - 1) * separator_width - - # 创建一个新的图像,其中包含分隔条 - concat_image = Image.new('RGB', (total_width, H), color='black') - - # 拼接每一帧及分隔条 - for j in range(N): - frame = frames[j].permute(1, 2, 0).cpu().numpy() # 转换为(H, W, C) - frame_image = Image.fromarray((frame * 255).astype('uint8'), 'RGB') - - # 计算当前帧在拼接图像中的位置 - x_position = j * (W + separator_width) - concat_image.paste(frame_image, (x_position, 0)) - - # 显示图像 - plt.imshow(concat_image) - plt.title(f'Sample {i+1}') - plt.axis('off') # 关闭坐标轴显示 - plt.show() - - # 保存图像到文件 - concat_image.save(f'sample_{i+1}.png') \ No newline at end of file diff --git a/lzero/model/gpt_models/world_model_pad_min_quantize_fixroot.py b/lzero/model/gpt_models/world_model_pad_min_quantize_fixroot.py deleted file mode 100644 index afc2e8a78..000000000 --- a/lzero/model/gpt_models/world_model_pad_min_quantize_fixroot.py +++ /dev/null @@ -1,1162 +0,0 @@ -import copy -from dataclasses import dataclass -import random -from typing import Any, Optional, Tuple -from typing import List, Optional, Union -import logging -# 设置日志记录级别为DEBUG -logging.getLogger().setLevel(logging.DEBUG) -from PIL import Image -from einops import rearrange -from einops import rearrange -import gym -from joblib import hash -import numpy as np -import torch -import torch -from torch.distributions.categorical import Categorical -import torch.nn as nn -import torch.nn.functional as F -import torchvision - -from .kv_caching import KeysValues -from .slicer import Embedder, Head, ActEmbedder -from .tokenizer import Tokenizer -from .transformer import Transformer, TransformerConfig -from .utils import LossWithIntermediateLosses, init_weights -from ding.torch_utils import to_device -# from memory_profiler import profile -from line_profiler import line_profiler -import hashlib -# def quantize_state(state, num_buckets=1000): -def quantize_state(state, num_buckets=15): -# def quantize_state(state, num_buckets=10): - """ - 量化状态向量。 - 参数: - state: 要量化的状态向量。 - num_buckets: 量化的桶数。 - 返回: - 量化后的状态向量的哈希值。 - """ - # 使用np.digitize将状态向量的每个维度值映射到num_buckets个桶中 - quantized_state = np.digitize(state, bins=np.linspace(0, 1, num=num_buckets)) - # 使用更稳定的哈希函数 - quantized_state_bytes = quantized_state.tobytes() - hash_object = hashlib.sha256(quantized_state_bytes) - return hash_object.hexdigest() - -@dataclass -class WorldModelOutput: - output_sequence: torch.FloatTensor - logits_observations: torch.FloatTensor - logits_rewards: torch.FloatTensor - logits_ends: torch.FloatTensor - - logits_policy: torch.FloatTensor - logits_value: torch.FloatTensor - - -class WorldModel(nn.Module): 
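# Usage sketch for quantize_state above: latent states that fall into the same
# buckets hash to the same key, which is what lets the KV-cache dict lookups in
# this class hit even when the float values differ slightly. The perturbation
# size and bucket count (15 buckets on [0, 1]) used here are illustrative only.
import numpy as np

state = np.random.rand(16)
nearby = state + 1e-3                    # small perturbation, usually the same buckets
far = np.clip(state + 0.5, 0.0, 1.0)     # large shift, different buckets

key_a = quantize_state(state)
key_b = quantize_state(nearby)
key_c = quantize_state(far)
# key_a == key_b in most cases (cache hit); key_a != key_c (cache miss)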
- def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: TransformerConfig, tokenizer, representation_network=None) -> None: - super().__init__() - - # config.max_tokens = int(2*50) # TODO - - self.tokenizer = tokenizer - self.obs_vocab_size, self.act_vocab_size = obs_vocab_size, act_vocab_size - self.config = config - - self.transformer = Transformer(config) - # self.num_observations_tokens = 16 - self.num_observations_tokens = config.tokens_per_block -1 - - self.latent_recon_loss_weight = config.latent_recon_loss_weight - self.perceptual_loss_weight = config.perceptual_loss_weight - - self.device = config.device - self.support_size = config.support_size - self.action_shape = config.action_shape - self.max_cache_size = config.max_cache_size - self.env_num = config.env_num - self.num_layers = config.num_layers - - - all_but_last_latent_state_pattern = torch.ones(config.tokens_per_block) - all_but_last_latent_state_pattern[-2] = 0 # 1,...,0,1 - - act_tokens_pattern = torch.zeros(self.config.tokens_per_block) # 17 - act_tokens_pattern[-1] = 1 # 0,...,0,1 - latent_state_pattern = 1 - act_tokens_pattern # 1,...,1,0 - - # current latent state's policy value - value_policy_tokens_pattern = torch.zeros(config.tokens_per_block) - value_policy_tokens_pattern[-2] = 1 # [0,...,1,0] - - # next latent state's policy value - # value_policy_tokens_pattern = torch.zeros(config.tokens_per_block) - # value_policy_tokens_pattern[-1] = 1 # [0,...,0,1] - - obs_per_embdding_dim=config.embed_dim - - self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim) - - - self.embedder = Embedder( - max_blocks=config.max_blocks, - block_masks=[act_tokens_pattern, latent_state_pattern], - embedding_tables=nn.ModuleList([nn.Embedding(act_vocab_size, config.embed_dim), nn.Embedding(obs_vocab_size, config.embed_dim)]) - ) - - # self.act_embedder = ActEmbedder( - # max_blocks=config.max_blocks, - # block_masks=[act_tokens_pattern], - # embedding_tables=nn.ModuleList([nn.Embedding(act_vocab_size, config.embed_dim)]) - # ) - - self.act_embedding_table = nn.Embedding(act_vocab_size, config.embed_dim) - - # self.head_observations = Head( - # max_blocks=config.max_blocks, - # block_mask=all_but_last_latent_state_pattern, # 1,...,0,1 # https://github.com/eloialonso/iris/issues/19 - # head_module=nn.Sequential( - # nn.Linear(config.embed_dim, config.embed_dim), - # nn.ReLU(), - # nn.Linear(config.embed_dim, obs_vocab_size) - # ) - # ) - self.obs_per_embdding_dim = config.embed_dim # 16*64=1024 - self.head_observations = Head( # TODO - max_blocks=config.max_blocks, - block_mask=all_but_last_latent_state_pattern, # 1,...,0,1 # https://github.com/eloialonso/iris/issues/19 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - # nn.BatchNorm1d(config.embed_dim), - # nn.ReLU(), - # nn.Linear(config.embed_dim, obs_vocab_size) - nn.LeakyReLU(negative_slope=0.01), # TODO: 2 - nn.Linear(config.embed_dim, self.obs_per_embdding_dim), - # nn.Tanh(), # TODO - nn.Sigmoid(), # 这里添加Sigmoid函数 TODO - ) - ) - - self.head_rewards = Head( - max_blocks=config.max_blocks, - block_mask=act_tokens_pattern, # 0,...,0,1 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - nn.ReLU(), - nn.Linear(config.embed_dim, self.support_size) - ) - ) - - self.head_ends = Head( - max_blocks=config.max_blocks, - block_mask=act_tokens_pattern, # 0,...,0,1 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - nn.ReLU(), - nn.Linear(config.embed_dim, 2) - ) - ) - - 
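# A small sanity-check sketch of the block-mask patterns constructed above,
# assuming tokens_per_block = 17 (16 observation tokens followed by 1 action
# token). Each pattern marks which positions inside a block a given head reads.
import torch

tokens_per_block = 17  # illustrative value

act_tokens_pattern = torch.zeros(tokens_per_block)
act_tokens_pattern[-1] = 1                       # only the action token
latent_state_pattern = 1 - act_tokens_pattern    # all observation tokens

all_but_last_latent_state_pattern = torch.ones(tokens_per_block)
all_but_last_latent_state_pattern[-2] = 0        # drop the last obs token, keep the action token

value_policy_tokens_pattern = torch.zeros(tokens_per_block)
value_policy_tokens_pattern[-2] = 1              # the last obs token of the current block

assert act_tokens_pattern.sum() == 1
assert latent_state_pattern.sum() == tokens_per_block - 1
assert all_but_last_latent_state_pattern.sum() == tokens_per_block - 1
assert value_policy_tokens_pattern.sum() == 1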
self.head_observations_for_root = Head( # TODO - max_blocks=config.max_blocks, - block_mask=latent_state_pattern, # 1,...,1,0 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - nn.BatchNorm1d(config.embed_dim), - nn.ReLU(), - # nn.Linear(config.embed_dim, obs_vocab_size) - nn.Linear(config.embed_dim, obs_per_embdding_dim) - ) - ) - self.head_policy = Head( - max_blocks=config.max_blocks, - block_mask=value_policy_tokens_pattern, # TODO: value_policy_tokens_pattern # [0,...,1,0] - head_module=nn.Sequential( # (8, 5, 128) - # nn.BatchNorm1d(config.embed_dim), # TODO: 1 - nn.Linear(config.embed_dim, config.embed_dim), - # nn.ReLU(), - nn.LeakyReLU(negative_slope=0.01), # TODO: 2 - nn.Linear(config.embed_dim, self.action_shape) # TODO(pu); action shape - ) - ) - self.head_value = Head( - max_blocks=config.max_blocks, - block_mask=value_policy_tokens_pattern, - head_module=nn.Sequential( - # nn.BatchNorm1d(config.embed_dim), # TODO: 1 - nn.Linear(config.embed_dim, config.embed_dim), - # nn.ReLU(), - nn.LeakyReLU(negative_slope=0.01), # TODO: 2 - nn.Linear(config.embed_dim, self.support_size) # TODO(pu): action shape - ) - ) - - self.apply(init_weights) - - last_linear_layer_init_zero = True # TODO: is beneficial for convergence speed. - if last_linear_layer_init_zero: - # TODO: policy init : 3 - # Locate the last linear layer and initialize its weights and biases to 0. - # for _, layer in enumerate(reversed(self.head_policy.head_module)): - # if isinstance(layer, nn.Linear): - # nn.init.zeros_(layer.weight) - # nn.init.zeros_(layer.bias) - # break - for _, layer in enumerate(reversed(self.head_value.head_module)): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - nn.init.zeros_(layer.bias) - break - for _, layer in enumerate(reversed(self.head_rewards.head_module)): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - nn.init.zeros_(layer.bias) - break - for _, layer in enumerate(reversed(self.head_observations.head_module)): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - nn.init.zeros_(layer.bias) - break - - - import collections - self.past_keys_values_cache = collections.OrderedDict() - self.past_policy_value_cache = collections.OrderedDict() - - # TODO: Transformer更新后应该清除缓存 - # NOTE - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=self.env_num, max_tokens=self.config.max_tokens) - - self.keys_values_wm_list = [] - self.keys_values_wm_size_list = [] - - - if self.num_observations_tokens==16: # k=16 - self.projection_input_dim = 128 - elif self.num_observations_tokens==1: # K=1 - self.projection_input_dim = self.obs_per_embdding_dim# for atari #TODO - # self.projection_input_dim = 256 # for cartpole - - - self.proj_hid = 1024 - self.proj_out = 1024 - self.pred_hid = 512 - self.pred_out = 1024 - activation = nn.ReLU() - self.projection = nn.Sequential( - nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, - nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, - nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out) - ) - self.prediction_head = nn.Sequential( - nn.Linear(self.proj_out, self.pred_hid), - nn.BatchNorm1d(self.pred_hid), - activation, - nn.Linear(self.pred_hid, self.pred_out), - ) - self.hit_count = 0 - self.total_query_count = 0 - self.length3_context_cnt = 0 - self.length2_context_cnt = 0 - self.root_hit_cnt = 0 - self.root_total_query_cnt = 0 - - - - def __repr__(self) -> str: - return 
"world_model" - - # @profile - def forward(self, obs_embeddings_or_act_tokens, past_keys_values: Optional[KeysValues] = None, - is_root=False, kvcache_independent=False) -> WorldModelOutput: - - if kvcache_independent: - prev_steps = 0 if past_keys_values is None else [past_kv.size for past_kv in past_keys_values] - prev_steps = torch.tensor(prev_steps, device=self.device) - # 我们需要为每个样本生成一个序列的步骤indices,然后获取它们的位置嵌入 - # 首先扩展prev_steps至(num_steps, batch_size),这里num_steps=1 - # prev_steps = prev_steps.unsqueeze(0) - - else: - prev_steps = 0 if past_keys_values is None else past_keys_values.size - # print(f'prev_steps:{prev_steps}') - - if 'obs_embeddings' in obs_embeddings_or_act_tokens.keys(): - obs_embeddings = obs_embeddings_or_act_tokens['obs_embeddings'] - if len(obs_embeddings.shape)==2: - obs_embeddings = obs_embeddings.unsqueeze(1) - num_steps = obs_embeddings.size(1) # (B, T, E) - if kvcache_independent: - # 生成每个样本的步骤indices - # steps_indices = prev_steps + torch.arange(num_steps, device=obs_embeddings.device).unsqueeze(1) - steps_indices = prev_steps + torch.arange(num_steps, device=obs_embeddings.device) - - # 步骤indices需要被reshape成一维,以便于embedding操作 - # steps_indices = steps_indices.view(-1) - # 获取位置嵌入 - position_embeddings = self.pos_emb(steps_indices) - # 由于我们要将它们加到obs_embeddings上,需要将位置嵌入reshape回(batch_size, num_steps, embedding_dim) - position_embeddings = position_embeddings.view(-1, num_steps, position_embeddings.shape[-1]) - # 现在我们可以将位置嵌入加到obs_embeddings上了 - sequences = obs_embeddings + position_embeddings - else: - sequences = obs_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=obs_embeddings.device)) - elif 'act_tokens' in obs_embeddings_or_act_tokens.keys(): - act_tokens = obs_embeddings_or_act_tokens['act_tokens'] - if len(act_tokens.shape)==3: - act_tokens = act_tokens.squeeze(1) - num_steps = act_tokens.size(1) # (B, T) - act_embeddings = self.act_embedding_table(act_tokens) - - if kvcache_independent: - # 生成每个样本的步骤indices - # steps_indices = prev_steps + torch.arange(num_steps, device=act_embeddings.device).unsqueeze(1) - steps_indices = prev_steps + torch.arange(num_steps, device=act_embeddings.device) - # 步骤indices需要被reshape成一维,以便于embedding操作 - # steps_indices = steps_indices.view(-1) - # 获取位置嵌入 - position_embeddings = self.pos_emb(steps_indices) - # 由于我们要将它们加到obs_embeddings上,需要将位置嵌入reshape回(batch_size, num_steps, embedding_dim) - position_embeddings = position_embeddings.view(-1, num_steps, position_embeddings.shape[-1]) - # 现在我们可以将位置嵌入加到obs_embeddings上了 - sequences = act_embeddings + position_embeddings - else: - sequences = act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=act_tokens.device)) - else: - obs_embeddings_and_act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens'] - # obs_embeddings: (B, L, K=16, E), act_tokens: (B, L, 1) - obs_embeddings, act_tokens = obs_embeddings_and_act_tokens - if len(obs_embeddings.shape)==3: # for batch compute loss - obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, -1) - - num_steps = int(obs_embeddings.size(1)*(obs_embeddings.size(2)+1)) # L(k+1) - # assert num_steps <= self.config.max_tokens - # Rearrange observation embeddings from (B, L, K, E) to (B, L*K, E) - # obs_embeddings = rearrange(obs_embeddings, 'b l k e -> b (l k) e') - - # Generate action embeddings from action tokens - # (B, L, 1) -> (B, L, 1, E) - act_embeddings = self.act_embedding_table(act_tokens) - - # 已知obs_embeddings的维度为 (B, L, K, E), 
act_embeddings的维度为(B, L, 1, E) 希望得到一个obs_act_embeddings向量的维度为 (B, L(K+1), E) - # 而且让得到的obs_act_embeddings的第2个维度的数据为:obs act, obs, act, ..., obs, act,即 L, 1, L,1 ... 这样的排列顺序。请给出高效的实现,用中文回答 - - B, L, K, E = obs_embeddings.size() - # _, _, _, _ = act_embeddings.size() - - # 初始化一个新的空tensor,用于存放最终的拼接结果 - obs_act_embeddings = torch.empty(B, L * (K + 1), E, device=obs_embeddings.device) - - # 对每一个序列长度L进行循环 - for i in range(L): - # 获取当前时刻的obs和act embeddings - obs = obs_embeddings[:, i, :, :] # Shape: (B, K, E) - act = act_embeddings[:, i, 0, :].unsqueeze(1) # Shape: (B, 1, E), 补充维度以便拼接 - - # 交替拼接obs和act - obs_act = torch.cat([obs, act], dim=1) # Shape: (B, K + 1, E) - - # 将结果填充到最终的tensor中 - obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act - - # 确保形状正确无误 - # assert obs_act_embeddings.shape == (B, L * (K + 1), E) - - # Add positional embeddings - sequences = obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=obs_embeddings.device)) - - - # print('transformer forward begin') 函数里面更新了update past_keys_values - if kvcache_independent: - x = [] - for k, past_kv in enumerate(past_keys_values): - x.append(self.transformer(sequences[k].unsqueeze(0), past_kv)) - # self.keys_values_wm_list[k] = past_kv # NOTE: todo - x = torch.cat(x, dim=0) - - # TODO: 在collect时,是一步一步的 obs act 传入的 - # prev_steps = prev_steps//1 - - else: - x = self.transformer(sequences, past_keys_values) - - # print('transformer forward done') - - - if is_root: - logits_observations = self.head_observations_for_root(x, num_steps=num_steps, prev_steps=prev_steps) - else: - # 1,...,0,1 https://github.com/eloialonso/iris/issues/19 - logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps) - - logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps) - logits_ends = self.head_ends(x, num_steps=num_steps, prev_steps=prev_steps) - # return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends) - - logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps) - logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps) - - # TODO: root reward value - return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends, logits_policy, logits_value) - - @torch.no_grad() - def render_batch(self) -> List[Image.Image]: - frames = self.decode_latent_state().detach().cpu() - frames = rearrange(frames, 'b c h w -> b h w c').mul(255).numpy().astype(np.uint8) - return [Image.fromarray(frame) for frame in frames] - - # only foe inference now, now is invalid - @torch.no_grad() - def decode_latent_state(self) -> List[Image.Image]: - embedded_tokens = self.tokenizer.embedding(self.latent_state) # (B, K, E) - z = rearrange(embedded_tokens, 'b (h w) e -> b e h w', h=int(np.sqrt(self.num_observations_tokens))) - rec = self.tokenizer.decode(z, should_postprocess=True) # (B, C, H, W) - # TODO: for atari image - return torch.clamp(rec, 0, 1) - # for cartpole obs - # return rec - - - @torch.no_grad() - def render(self): - assert self.latent_state.shape == (1, self.num_observations_tokens) - return self.render_batch()[0] - - @torch.no_grad() - def reset(self) -> torch.FloatTensor: - assert self.env is not None - obs = torchvision.transforms.functional.to_tensor(self.env.reset()).to(self.device).unsqueeze( - 0) # (1, C, H, W) in [0., 1.] 
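# The per-block interleaving loop above (obs tokens followed by the action token
# for every step) can also be written without a Python loop. A sketch of the
# equivalent vectorised form, with illustrative shapes:
import torch

B, L, K, E = 2, 3, 16, 8
obs_embeddings = torch.randn(B, L, K, E)
act_embeddings = torch.randn(B, L, 1, E)

# Concatenate along the token dimension of each block, then flatten the blocks:
# (B, L, K+1, E) -> (B, L*(K+1), E), preserving the obs..., act ordering per block.
obs_act_embeddings = torch.cat([obs_embeddings, act_embeddings], dim=2).reshape(B, L * (K + 1), E)

assert obs_act_embeddings.shape == (B, L * (K + 1), E)
# position K is the action token of the first block
assert torch.equal(obs_act_embeddings[:, K, :], act_embeddings[:, 0, 0, :])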
- return self.reset_from_initial_observations(obs) - - - @torch.no_grad() - # @profile - def reset_from_initial_observations_v2(self, obs_act_dict: torch.FloatTensor) -> torch.FloatTensor: - if isinstance(obs_act_dict, dict): - observations = obs_act_dict['obs'] - buffer_action = obs_act_dict['action'] - current_obs = obs_act_dict['current_obs'] - else: - observations = obs_act_dict - buffer_action = None - current_obs = None - obs_embeddings = self.tokenizer.encode_to_obs_embeddings(observations, should_preprocess=True) # (B, C, H, W) -> (B, K, E) - - outputs_wm = self.refresh_keys_values_with_initial_latent_state_for_init_infer_v2(obs_embeddings, buffer_action, current_obs) - self.latent_state = obs_embeddings - - return outputs_wm, self.latent_state - - @torch.no_grad() - # @profile - def refresh_keys_values_with_initial_latent_state_for_init_infer_v2(self, latent_state: torch.LongTensor, buffer_action=None, current_obs=None) -> torch.FloatTensor: - n, num_observations_tokens, _ = latent_state.shape - if n <= self.env_num: - if buffer_action is None: - # MCTS root节点: 需要准确的估计 value, policy_logits, 或许需要结合context的kv_cache进行更准确的估计,而不是当前的从0开始推理 - self.keys_values_wm = self.transformer.generate_empty_keys_values(n, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False, kvcache_independent=False) - self.keys_values_wm_size_list = [1 for i in range(n)] - - # 复制单个环境对应的 keys_values_wm 并存储 - self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) - for i in range(latent_state.size(0)): # 遍历每个环境 - state_single_env = latent_state[i] # 获取单个环境的 latent state - # cache_key = hash(state_single_env.detach().cpu().numpy()) # 计算哈希值 - quantized_state = state_single_env.detach().cpu().numpy() - cache_key = quantize_state(quantized_state) # 使用量化后的状态计算哈希值 - for layer in range(self.num_layers): - self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = self.keys_values_wm._keys_values[layer]._k_cache._cache[i].unsqueeze(0) # shape torch.Size([2, 100, 512]) - self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = self.keys_values_wm._keys_values[layer]._v_cache._cache[i].unsqueeze(0) - self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = self.keys_values_wm._keys_values[layer]._k_cache._size - self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = self.keys_values_wm._keys_values[layer]._v_cache._size - # keys_values_wm_single_env[layer].update(self.keys_values_wm[layer]._k_cache._cache[i].unsqueeze(0), self.keys_values_wm[layer]._v_cache._cache[i].unsqueeze(0)) - self.root_total_query_cnt += 1 - if cache_key not in self.past_keys_values_cache: - self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu')) - else: - self.root_hit_cnt += 1 - root_hit_ratio = self.root_hit_cnt / self.root_total_query_cnt - print('root_total_query_cnt:', self.root_total_query_cnt) - print(f'root_hit_ratio:{root_hit_ratio}') - print(f'root_hit_cnt:{self.root_hit_cnt}') - print(f'root_hit find size {self.past_keys_values_cache[cache_key].size}') - if self.past_keys_values_cache[cache_key].size>1: - print(f'=='*20) - print(f'NOTE: root_hit find size > 1') - print(f'=='*20) - elif current_obs is not None: - - # 假设 latest_state 是新的 latent_state,包含 ready_env_num 个环境的信息 - # ready_env_num = latent_state.shape[0] - ready_env_num = current_obs.shape[0] - 
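# A simplified sketch of the per-environment cache copy performed above: given
# batched key/value caches of shape (B, num_heads, T, head_dim) per layer, pull
# out environment i as a batch of one so it can be stored and looked up
# independently. The real KeysValues/KVCache classes in kv_caching.py carry
# extra state (e.g. the current size); this only illustrates the slicing pattern.
import torch

def slice_env_kv(batched_kv: list, env_idx: int) -> list:
    # batched_kv: one (k_cache, v_cache) tuple per transformer layer
    single_env_kv = []
    for k_cache, v_cache in batched_kv:
        single_env_kv.append((k_cache[env_idx:env_idx + 1].clone(),
                              v_cache[env_idx:env_idx + 1].clone()))
    return single_env_kv

num_layers, B, H, T, D = 2, 8, 4, 20, 64
batched = [(torch.randn(B, H, T, D), torch.randn(B, H, T, D)) for _ in range(num_layers)]
env0 = slice_env_kv(batched, 0)
assert env0[0][0].shape == (1, H, T, D)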
self.keys_values_wm_list = [] - self.keys_values_wm_size_list = [] - for i in range(ready_env_num): - state_single_env = latent_state[i] # 获取单个环境的 latent state - quantized_state = state_single_env.detach().cpu().numpy() - cache_key = quantize_state(quantized_state) # 使用量化后的状态计算哈希值 - matched_value = self.past_keys_values_cache.get(cache_key) # 检索缓存值 - self.root_total_query_cnt += 1 - if matched_value is not None: - # 如果找到匹配的值,将其添加到列表中 - self.root_hit_cnt += 1 - if self.root_total_query_cnt>0 and self.root_total_query_cnt%2000==0: - root_hit_ratio = self.root_hit_cnt / self.root_total_query_cnt - print('root_total_query_cnt:', self.root_total_query_cnt) - print(f'root_hit_ratio:{root_hit_ratio}') - print(f'root_hit find size {self.past_keys_values_cache[cache_key].size}') - if self.past_keys_values_cache[cache_key].size>=7: - print(f'=='*20) - print(f'NOTE: root_hit find size >= 7') - print(f'=='*20) - # 这里需要deepcopy因为在transformer的forward中会原地修改matched_value - self.keys_values_wm_list.append(copy.deepcopy(self.to_device_for_kvcache(matched_value, 'cuda'))) - self.keys_values_wm_size_list.append(matched_value.size) - else: - # use zero reset - self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) - # outputs_wm = self.forward({'obs_embeddings': torch.from_numpy(state_single_env).unsqueeze(0).to(self.device)}, past_keys_values=self.keys_values_wm_single_env, is_root=False) - outputs_wm = self.forward({'obs_embeddings': state_single_env.unsqueeze(0)}, past_keys_values=self.keys_values_wm_single_env, is_root=False) - self.keys_values_wm_list.append(self.keys_values_wm_single_env) - self.keys_values_wm_size_list.append(1) - min_size = min(self.keys_values_wm_size_list) - for layer in range(self.num_layers): - # 每层的k和v缓存列表 - kv_cache_k_list = [] - kv_cache_v_list = [] - - for idx, keys_values in enumerate(self.keys_values_wm_list): - # 获取当前层的k和v缓存 - k_cache = keys_values[layer]._k_cache._cache - v_cache = keys_values[layer]._v_cache._cache - - # 获取当前缓存的有效尺寸 - effective_size = self.keys_values_wm_size_list[idx] - # 计算需要截去的尺寸 - trim_size = effective_size - min_size if effective_size > min_size else 0 - - # 如果需要截去部分,则截去前面的(trim_size)步 - if trim_size > 0: - k_cache_trimmed = k_cache[:, :, trim_size:, :] - v_cache_trimmed = v_cache[:, :, trim_size:, :] - # 在第三维之后补零 - k_cache_padded = F.pad(k_cache_trimmed, (0, 0, trim_size, 0), "constant", 0) - v_cache_padded = F.pad(v_cache_trimmed, (0, 0, trim_size, 0), "constant", 0) - else: - k_cache_padded = k_cache - v_cache_padded = v_cache - - # 将处理过的缓存添加到列表中 - kv_cache_k_list.append(k_cache_padded) - kv_cache_v_list.append(v_cache_padded) - - # 使用torch.stack()将列表中的缓存堆叠起来,形成新的缓存 - self.keys_values_wm._keys_values[layer]._k_cache._cache = torch.stack(kv_cache_k_list, dim=0).squeeze(1) - self.keys_values_wm._keys_values[layer]._v_cache._cache = torch.stack(kv_cache_v_list, dim=0).squeeze(1) - - # 更新缓存的尺寸为min_size - self.keys_values_wm._keys_values[layer]._k_cache._size = min_size - self.keys_values_wm._keys_values[layer]._v_cache._size = min_size - - buffer_action = buffer_action[:ready_env_num] - buffer_action = torch.from_numpy(np.array(buffer_action)).to(latent_state.device) - act_tokens = buffer_action.unsqueeze(-1) - # outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (latent_state, act_tokens)}, past_keys_values=self.keys_values_wm, is_root=False) - outputs_wm = self.forward({'act_tokens': act_tokens}, past_keys_values=self.keys_values_wm, is_root=False) - - current_obs_embeddings = 
self.tokenizer.encode_to_obs_embeddings(current_obs, should_preprocess=True) # (B, C, H, W) -> (B, K, E) - - outputs_wm = self.forward({'obs_embeddings': current_obs_embeddings}, past_keys_values=self.keys_values_wm, is_root=False) - - # 复制单个环境对应的 keys_values_wm 并存储 - self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) - for i in range(current_obs_embeddings.size(0)): # 遍历每个环境 - state_single_env = current_obs_embeddings[i] # 获取单个环境的 latent state - # cache_key = hash(state_single_env.detach().cpu().numpy()) # 计算哈希值 - quantized_state = state_single_env.detach().cpu().numpy() - cache_key = quantize_state(quantized_state) # 使用量化后的状态计算哈希值 - for layer in range(self.num_layers): - self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = self.keys_values_wm._keys_values[layer]._k_cache._cache[i].unsqueeze(0) # shape torch.Size([2, 100, 512]) - self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = self.keys_values_wm._keys_values[layer]._v_cache._cache[i].unsqueeze(0) - self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = self.keys_values_wm._keys_values[layer]._k_cache._size - self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = self.keys_values_wm._keys_values[layer]._v_cache._size - # keys_values_wm_single_env[layer].update(self.keys_values_wm[layer]._k_cache._cache[i].unsqueeze(0), self.keys_values_wm[layer]._v_cache._cache[i].unsqueeze(0)) - # 比较并存储大小较大的 cache - if cache_key in self.past_keys_values_cache: - existing_kvcache = self.past_keys_values_cache[cache_key] - # 判断现有cache和新cache中的size是否有不同 - if self.keys_values_wm_single_env._keys_values[0]._k_cache._size > existing_kvcache.size and self.keys_values_wm_single_env._keys_values[0]._k_cache._size < self.config.max_tokens-1: - # 具有size为self.config.max_tokens-1 不用存,因为使用达到最大size的kv_cache,而这会导致reset,从而一直是从零开始的 - self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu')) - elif self.keys_values_wm_single_env._keys_values[0]._k_cache._size < self.config.max_tokens-1: - # 具有size为self.config.max_tokens-1 不用存,因为使用达到最大size的kv_cache,而这会导致reset,从而一直是从零开始的 - self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu')) - - - elif n == int(256): - # TODO: n=256 means train tokenizer, 不需要计算target value - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - # print('init inference: not find matched_value! 
reset!') - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) - elif n > self.env_num and n != int(256) and buffer_action is not None: - # train时计算target value - # TODO: n=256 means train tokenizer - # [192, 16, 64] -> [32, 6, 16, 64] - latent_state = latent_state.contiguous().view(buffer_action.shape[0], -1, num_observations_tokens, self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 - - # latent_state = latent_state.view(-1, self.config.max_blocks+1, num_observations_tokens) # (BL, K) - latent_state = latent_state[:, :-1, :] - # latent_state = latent_state.reshape(32*6, num_observations_tokens) # (BL, K) - buffer_action = torch.from_numpy(buffer_action).to(latent_state.device) - act_tokens = rearrange(buffer_action, 'b l -> b l 1') - - # # 选择每个样本的最后一步 - last_steps = act_tokens[:, -1:, :] # 这将选择最后一列并保持维度不变, 最后一步的target policy/value本身就没有用到 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - act_tokens = torch.cat((act_tokens, last_steps), dim=1) - - # print('init inference: unroll 5 steps!') 17*6=102 17*5=85 - obs_embeddings = latent_state - outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) - - # 选择每个样本的最后一步 - last_steps_value = outputs_wm.logits_value[:, -1:, :] # 这将选择最后一列并保持维度不变 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1) - - last_steps_policy = outputs_wm.logits_policy[:, -1:, :] # 这将选择最后一列并保持维度不变 - outputs_wm.logits_policy = torch.cat((outputs_wm.logits_policy, last_steps_policy), dim=1) - - # Reshape your tensors - # outputs_wm.logits_value.shape (30,21) = (B*6, 21) - outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e') - outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e') - - return outputs_wm - - @torch.no_grad() - # @profile - def reset_from_initial_observations(self, obs_act_dict: torch.FloatTensor) -> torch.FloatTensor: - if isinstance(obs_act_dict, dict): - # obs_act_dict = {'obs':obs, 'action':action_batch} - observations = obs_act_dict['obs'] - buffer_action = obs_act_dict['action'] - else: - observations = obs_act_dict - buffer_action = None - - obs_embeddings = self.tokenizer.encode_to_obs_embeddings(observations, should_preprocess=True) # (B, C, H, W) -> (B, K, E) - outputs_wm = self.refresh_keys_values_with_initial_latent_state_for_init_infer(obs_embeddings, buffer_action) - self.latent_state = obs_embeddings - - return outputs_wm, self.latent_state - - - @torch.no_grad() - # @profile - def refresh_keys_values_with_initial_latent_state_for_init_infer(self, latent_state: torch.LongTensor, buffer_action=None) -> torch.FloatTensor: - n, num_observations_tokens, _ = latent_state.shape - if n <= self.env_num: - # MCTS root节点: 需要准确的估计 value, policy_logits 或许需要结合context的kv_cache进行更准确的估计,而不是当前的从0开始推理 - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) # Note: is_root=False - elif n == int(256): - # TODO: n=256 means train tokenizer, 不需要计算target value - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - # print('init inference: not find matched_value! 
reset!') - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) # Note: is_root=False - elif n > self.env_num and n != int(256) and buffer_action is not None: - # TODO: n=256 means train tokenizer - # TODO: for n=32*6=192 means 通过unroll 5 steps,计算target value - # latent_state = latent_state.reshape(32, 6, num_observations_tokens) # (BL, K) - # latent_state = latent_state.view(-1, 6, num_observations_tokens) # (BL, K) - - # [192, 16] -> [32, 6, 16] - # latent_state = latent_state.view(buffer_action.shape[0], -1, num_observations_tokens) # (BL, K) for unroll_step=1 - - # [192, 16, 64] -> [32, 6, 16, 64] - latent_state = latent_state.contiguous().view(buffer_action.shape[0], -1, num_observations_tokens, self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 - - # latent_state = latent_state.view(-1, self.config.max_blocks+1, num_observations_tokens) # (BL, K) - latent_state = latent_state[:, :-1, :] - # latent_state = latent_state.reshape(32*6, num_observations_tokens) # (BL, K) - buffer_action = torch.from_numpy(buffer_action).to(latent_state.device) - act_tokens = rearrange(buffer_action, 'b l -> b l 1') - - # # 选择每个样本的最后一步 - last_steps = act_tokens[:, -1:, :] # 这将选择最后一列并保持维度不变, 最后一步的target policy/value本身就没有用到 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - act_tokens = torch.cat((act_tokens, last_steps), dim=1) - - # print('init inference: unroll 5 steps!') 17*6=102 17*5=85 - obs_embeddings = latent_state - outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) - - # 选择每个样本的最后一步 - last_steps_value = outputs_wm.logits_value[:, -1:, :] # 这将选择最后一列并保持维度不变 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1) - - last_steps_policy = outputs_wm.logits_policy[:, -1:, :] # 这将选择最后一列并保持维度不变 - outputs_wm.logits_policy = torch.cat((outputs_wm.logits_policy, last_steps_policy), dim=1) - - # Reshape your tensors - # outputs_wm.logits_value.shape (30,21) = (B*6, 21) - outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e') - outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e') - - - # return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends, logits_policy, logits_value) - return outputs_wm - - - @torch.no_grad() - # @profile - def refresh_keys_values_with_initial_latent_state(self, latent_state: torch.LongTensor, reset_indices=None) -> torch.FloatTensor: - n, num_observations_tokens, _ = latent_state.shape - assert num_observations_tokens == self.num_observations_tokens - # self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - if reset_indices is None: - self.keys_values_wm_list = [self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) for i in range(n)] - else: - for i in reset_indices: - self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': latent_state[i].unsqueeze(0).to(self.device)}, past_keys_values=self.keys_values_wm_single_env, is_root=False, kvcache_independent=False) - self.keys_values_wm_list[i] = self.keys_values_wm_single_env - self.keys_values_wm_size_list[i] = 1 - return None - - @torch.no_grad() - # @profile - def forward_initial_inference(self, obs_act_dict: torch.LongTensor, should_predict_next_obs: bool = True): - if 
isinstance(obs_act_dict, dict): - obs = obs_act_dict['obs'] - else: - obs = obs_act_dict - outputs_wm, latent_state = self.reset_from_initial_observations_v2(obs_act_dict) # TODO - # outputs_wm, latent_state = self.reset_from_initial_observations(obs_act_dict) # 从零开始 - - return outputs_wm.output_sequence, latent_state, outputs_wm.logits_rewards, outputs_wm.logits_policy, outputs_wm.logits_value - - """ - 假设env_num=8 - 8个环境的kv_cache单独存储与寻找,都存储在一个dict中,在recurrent_inference时, - 由于不同环境找到的kv_cache的size不同,先根据最大size对kv_cache在前部补零,然后组成batch_size的kv_cache - 其内部也是通过batch执行transformer forward的推理 - """ - - @torch.no_grad() - # @profile - def forward_recurrent_inference(self, state_action_history, should_predict_next_obs: bool = True): - # 一般来讲,在一次 MCTS search中,我们需要维护H长度的context来使用transformer进行推理。 - # 由于在一次search里面。agent最多访问sim个不同的节点,因此我们只需维护一个 {(state:kv_cache)}的列表。 - # 但如果假设环境是MDP的话,然后根据当前的 latest_state s_t 在这个列表中查找即可 - # TODO: 但如果假设环境是非MDP的话,需要维护一个 {(rootstate_action_history:kv_cache)}的列表? - - latest_state = state_action_history[-1][0] - - # 假设 latest_state 是新的 latent_state,包含 ready_env_num 个环境的信息 - ready_env_num = latest_state.shape[0] - self.keys_values_wm_list = [] - self.keys_values_wm_size_list = [] - for i in range(ready_env_num): - self.total_query_count += 1 - state_single_env = latest_state[i] # 获取单个环境的 latent state - cache_key = quantize_state(state_single_env) # 使用量化后的状态计算哈希值 - matched_value = self.past_keys_values_cache.get(cache_key) # 检索缓存值 - if matched_value is not None: - # 如果找到匹配的值,将其添加到列表中 - self.hit_count += 1 - # 这里需要deepcopy因为在transformer的forward中会原地修改matched_value - self.keys_values_wm_list.append(copy.deepcopy(self.to_device_for_kvcache(matched_value, 'cuda'))) - self.keys_values_wm_size_list.append(matched_value.size) - else: - # use zero reset - self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': torch.from_numpy(state_single_env).unsqueeze(0).to(self.device)}, past_keys_values=self.keys_values_wm_single_env, is_root=False) - self.keys_values_wm_list.append(self.keys_values_wm_single_env) - self.keys_values_wm_size_list.append(1) - - num_passes = 1 + self.num_observations_tokens if should_predict_next_obs else 1 - output_sequence, latent_state = [], [] - - - # reset_indices = [index for index, value in enumerate(self.keys_values_wm_size_list) if value + num_passes > self.config.max_tokens] - # self.refresh_keys_values_with_initial_latent_state(torch.tensor(latest_state, dtype=torch.float32).to(self.device), reset_indices) - - action = state_action_history[-1][-1] - token = action.clone().detach() if isinstance(action, torch.Tensor) else torch.tensor(action, dtype=torch.long) - token = token.reshape(-1, 1).to(self.device) # (B, 1) - - # print(self.keys_values_wm_size_list) - # 获取self.keys_values_wm_size_list的最小值min_size - min_size = min(self.keys_values_wm_size_list) - if min_size >= self.config.max_tokens - 5: - self.length3_context_cnt += len(self.keys_values_wm_size_list) - if min_size >= 3: - self.length2_context_cnt += len(self.keys_values_wm_size_list) - # if max(self.keys_values_wm_size_list) == 7: - # print('max(self.keys_values_wm_size_list) == 7') - # if self.total_query_count>0 and self.total_query_count%1==0: - if self.total_query_count>0 and self.total_query_count%10000==0: - self.hit_freq = self.hit_count/(self.total_query_count) - # print('hit_freq:', self.hit_freq) - # print('hit_count:', self.hit_count) - print('total_query_count:', 
self.total_query_count) - # 如果总查询次数大于0,计算并打印cnt的比率 - length3_context_cnt_ratio = self.length3_context_cnt / self.total_query_count - print('>=3 node context_cnt:', self.length3_context_cnt) - print('>=3 node context_cnt_ratio:', length3_context_cnt_ratio) - length2_context_cnt_ratio = self.length2_context_cnt / self.total_query_count - print('>=2 node context_cnt_ratio:', length2_context_cnt_ratio) - print('>=2 node context_cnt:', self.length2_context_cnt) - # print(self.keys_values_wm_size_list) - - for layer in range(self.num_layers): - # 每层的k和v缓存列表 - kv_cache_k_list = [] - kv_cache_v_list = [] - - for idx, keys_values in enumerate(self.keys_values_wm_list): - # 获取当前层的k和v缓存 - k_cache = keys_values[layer]._k_cache._cache - v_cache = keys_values[layer]._v_cache._cache - # 获取当前缓存的有效尺寸 - effective_size = self.keys_values_wm_size_list[idx] - # 计算需要截去的尺寸 - trim_size = effective_size - min_size if effective_size > min_size else 0 - # 如果需要截去部分,则截去前面的(trim_size)步 - if trim_size > 0: - k_cache_trimmed = k_cache[:, :, trim_size:, :] - v_cache_trimmed = v_cache[:, :, trim_size:, :] - # 在第三维之后补零 - k_cache_padded = F.pad(k_cache_trimmed, (0, 0, trim_size, 0), "constant", 0) - v_cache_padded = F.pad(v_cache_trimmed, (0, 0, trim_size, 0), "constant", 0) - else: - k_cache_padded = k_cache - v_cache_padded = v_cache - # 将处理过的缓存添加到列表中 - kv_cache_k_list.append(k_cache_padded) - kv_cache_v_list.append(v_cache_padded) - - # 使用torch.stack()将列表中的缓存堆叠起来,形成新的缓存 - self.keys_values_wm._keys_values[layer]._k_cache._cache = torch.stack(kv_cache_k_list, dim=0).squeeze(1) - self.keys_values_wm._keys_values[layer]._v_cache._cache = torch.stack(kv_cache_v_list, dim=0).squeeze(1) - - # 更新缓存的尺寸为min_size - self.keys_values_wm._keys_values[layer]._k_cache._size = min_size - self.keys_values_wm._keys_values[layer]._v_cache._size = min_size - # del self.keys_values_wm_list - - for k in range(num_passes): # assumption that there is only one action token. - # action_token obs_token, ..., obs_token 1+1 - if k==0: - obs_embeddings_or_act_tokens = {'act_tokens': token} - else: - obs_embeddings_or_act_tokens = {'obs_embeddings': token} - - outputs_wm = self.forward(obs_embeddings_or_act_tokens, past_keys_values=self.keys_values_wm, is_root=False, kvcache_independent=False) - # if k==0, action_token self.head_observations 1,...,0,1 - output_sequence.append(outputs_wm.output_sequence) - - if k == 0: - # if k==0, token is action_token outputs_wm.logits_rewards 是有值的 - reward = outputs_wm.logits_rewards # (B,) - - if k < self.num_observations_tokens: - # 一共产生16个obs_token,每次产生一个 - # TODO: sample or argmax - # token = Categorical(logits=outputs_wm.logits_observations).sample() - # Use argmax to select the most likely token - # token = outputs_wm.logits_observations.argmax(-1, keepdim=True) - token = outputs_wm.logits_observations - # if len(token.shape) != 2: - # token = token.squeeze(-1) # Ensure the token tensor shape is (B, 1) - if len(token.shape) != 3: - token = token.unsqueeze(1) # (8,1024) -> (8,1,1024) - latent_state.append(token) - - output_sequence = torch.cat(output_sequence, dim=1) # (B, 1 + K, E) - # Before updating self.latent_state, delete the old one to free memory - del self.latent_state - self.latent_state = torch.cat(latent_state, dim=1) # (B, K) - - # # TODO: 在计算结束后,是否需要更新最新的缓存. 
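# Sketch of the alignment performed above before batching per-environment
# caches: every cache is cut down to the smallest effective length in the batch
# by dropping its oldest steps, zero-padded back to a common shape at the front,
# and then stacked. Shapes are illustrative; the real caches live inside the
# KeysValues objects and also track an effective size.
import torch
import torch.nn.functional as F

def align_kv_caches(k_caches, sizes):
    # k_caches: list of (1, num_heads, T_max, head_dim) tensors
    # sizes: effective number of valid steps in each cache
    min_size = min(sizes)
    aligned = []
    for k_cache, size in zip(k_caches, sizes):
        trim = size - min_size
        if trim > 0:
            trimmed = k_cache[:, :, trim:, :]          # drop the oldest steps
            trimmed = F.pad(trimmed, (0, 0, trim, 0))  # zero-pad back at the front
        else:
            trimmed = k_cache
        aligned.append(trimmed)
    return torch.cat(aligned, dim=0), min_size         # (B, num_heads, T_max, head_dim)

caches = [torch.randn(1, 4, 10, 8) for _ in range(3)]
batched, min_size = align_kv_caches(caches, sizes=[5, 7, 6])
assert batched.shape == (3, 4, 10, 8) and min_size == 5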
是否需要deepcopy - for i in range(self.latent_state.size(0)): # 遍历每个环境 - state_single_env = self.latent_state[i] # 获取单个环境的 latent state - # cache_key = hash(state_single_env.detach().cpu().numpy()) # 计算哈希值 - quantized_state = state_single_env.detach().cpu().numpy() - cache_key = quantize_state(quantized_state) # 使用量化后的状态计算哈希值 - # 复制单个环境对应的 keys_values_wm 并存储 - for layer in range(self.num_layers): - self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = self.keys_values_wm._keys_values[layer]._k_cache._cache[i].unsqueeze(0) # shape torch.Size([2, 100, 512]) - self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = self.keys_values_wm._keys_values[layer]._v_cache._cache[i].unsqueeze(0) - self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = self.keys_values_wm._keys_values[layer]._k_cache._size - self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = self.keys_values_wm._keys_values[layer]._v_cache._size - # 比较并存储大小较大的 cache - if cache_key in self.past_keys_values_cache: - existing_kvcache = self.past_keys_values_cache[cache_key] - # 判断现有cache和新cache中的size是否有不同 - if self.keys_values_wm_single_env._keys_values[0]._k_cache._size > existing_kvcache.size and self.keys_values_wm_single_env._keys_values[0]._k_cache._size < self.config.max_tokens-1: - # 具有size为self.config.max_tokens-1 不用存,因为使用达到最大size的kv_cache,而这会导致reset,从而一直是从零开始的 - self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu')) - elif self.keys_values_wm_single_env._keys_values[0]._k_cache._size < self.config.max_tokens-1: - # 具有size为self.config.max_tokens-1 不用存,因为使用达到最大size的kv_cache,而这会导致reset,从而一直是从零开始的 - self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu')) - - # del self.keys_values_wm - # outputs_wm.logits_policy, outputs_wm.logits_value - if len(self.past_keys_values_cache) > self.max_cache_size: - # TODO: lru_cache - _, popped_kv_cache = self.past_keys_values_cache.popitem(last=False) - del popped_kv_cache # 不要这一行 - # print('len(self.past_keys_values_cache) > self.max_cache_size') - - # Example usage: - # Assuming `past_keys_values_cache` is a populated instance of `KeysValues` - # and `num_layers` is the number of transformer layers - # cuda_memory_gb = self.calculate_cuda_memory_gb(self.past_keys_values_cache, num_layers=2) - # print(f'len(self.past_keys_values_cache): {len(self.past_keys_values_cache)}, Memory used by past_keys_values_cache: {cuda_memory_gb:.2f} GB') - - return outputs_wm.output_sequence, self.latent_state, reward, outputs_wm.logits_policy, outputs_wm.logits_value - - - def to_device_for_kvcache(self, keys_values: KeysValues, device: str) -> KeysValues: - """ - Transfer all KVCache objects within the KeysValues object to a certain device. - - Arguments: - - keys_values (KeysValues): The KeysValues object to be transferred. - - device (str): The device to transfer to. - - Returns: - - keys_values (KeysValues): The KeysValues object with its caches transferred to the specified device. 
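# Sketch of the caching policy described above: entries are keyed by the
# quantized latent state, an existing entry is replaced only by a longer
# context (per the "store the larger cache" comments), contexts that already
# reached max_tokens - 1 are not stored (they would force a reset), and the
# oldest entry is evicted FIFO-style once the dict exceeds max_cache_size.
# The stored value and its size are placeholders for the real KV-cache objects.
from collections import OrderedDict

class LatentKVCacheStore:
    def __init__(self, max_cache_size: int, max_tokens: int):
        self.store = OrderedDict()
        self.max_cache_size = max_cache_size
        self.max_tokens = max_tokens

    def put(self, cache_key: str, kv_cache, size: int) -> None:
        if size >= self.max_tokens - 1:
            return  # full-length contexts are not cached
        existing = self.store.get(cache_key)
        if existing is None or size > existing[1]:
            self.store[cache_key] = (kv_cache, size)
        if len(self.store) > self.max_cache_size:
            self.store.popitem(last=False)  # drop the oldest entry

    def get(self, cache_key: str):
        entry = self.store.get(cache_key)
        return None if entry is None else entry[0]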
- """ - # Check if CUDA is available and select the first available CUDA device - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - - for kv_cache in keys_values: - kv_cache._k_cache._cache = kv_cache._k_cache._cache.to(device) - kv_cache._v_cache._cache = kv_cache._v_cache._cache.to(device) - return keys_values - - - # 计算显存使用量的函数 - def calculate_cuda_memory_gb(self, past_keys_values_cache, num_layers: int): - total_memory_bytes = 0 - - # 遍历OrderedDict中所有的KeysValues实例 - for kv_instance in past_keys_values_cache.values(): - num_layers = len(kv_instance) # 获取层数 - for layer in range(num_layers): - kv_cache = kv_instance[layer] - k_shape = kv_cache._k_cache.shape # 获取keys缓存的形状 - v_shape = kv_cache._v_cache.shape # 获取values缓存的形状 - - # 计算元素个数并乘以每个元素的字节数 - k_memory = torch.prod(torch.tensor(k_shape)) * 4 - v_memory = torch.prod(torch.tensor(v_shape)) * 4 - - # 累加keys和values缓存的内存 - layer_memory = k_memory + v_memory - total_memory_bytes += layer_memory.item() # .item()确保转换为Python标准数字 - - # 将总内存从字节转换为吉字节 - total_memory_gb = total_memory_bytes / (1024 ** 3) - return total_memory_gb - - # @profile - def compute_loss(self, batch, target_tokenizer: Tokenizer=None, **kwargs: Any) -> LossWithIntermediateLosses: - # NOTE: 这里是需要梯度的 - #with torch.no_grad(): # TODO: 非常重要 - obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False) # (B, C, H, W) -> (B, K, E) - obs_embeddings.register_hook(lambda grad: grad * 1/5) # TODO:只提供重建损失更新表征网络 - # obs_embeddings.register_hook(lambda grad: grad * 1) # TODO:只提供重建损失更新表征网络 - - # Assume that 'cont_embeddings' and 'original_images' are available from prior code - # Decode the embeddings to reconstruct the images - reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) - # Calculate the reconstruction loss - # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 4, 64, 64), reconstructed_images) # TODO: for stack=4 - # perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) # for stack=4 gray obs - - latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # TODO: for stack=1 - perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # TODO: for stack=1 - - # latent_recon_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) - # perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) - - latent_kl_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) - - - act_tokens = rearrange(batch['actions'], 'b l -> b l 1') - - # TODO: 是否只用重建损失更新表征网络 非常重要 - outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) - # outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens)}, is_root=False) - - with torch.no_grad(): - traget_obs_embeddings = target_tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False) # (B, C, H, W) -> (B, K, E) - - labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(traget_obs_embeddings, batch['rewards'], - batch['ends'], - batch['mask_padding']) - # labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(obs_embeddings, batch['rewards'], - # batch['ends'], - # 
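# Two notes on the code above. First, to_device_for_kvcache reassigns its
# `device` argument to the first available CUDA device, so the 'cpu' passed at
# its call sites is effectively ignored on CUDA machines. Second, compute_loss
# registers a hook on obs_embeddings that multiplies the incoming gradient by
# 1/5, shrinking how strongly the downstream world-model losses update the
# representation network. A minimal sketch of that hook pattern with a toy
# encoder; the 0.2 factor mirrors the code above, the modules are illustrative.
import torch
import torch.nn as nn

encoder = nn.Linear(4, 4)
head = nn.Linear(4, 1)

x = torch.randn(3, 4)
z = encoder(x)
z.register_hook(lambda grad: grad * 0.2)   # shrink gradients flowing back through z
loss = head(z).pow(2).mean()
loss.backward()
# encoder.weight.grad is now 1/5 of what it would be without the hook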
batch['mask_padding']) - - logits_observations = rearrange(outputs.logits_observations[:, :-1], 'b t o -> (b t) o') - # labels_observations = labels_observations.contiguous().view(-1, self.projection_input_dim) # TODO: - labels_observations = labels_observations.reshape(-1, self.projection_input_dim) # TODO: - - - loss_obs = torch.nn.functional.mse_loss(logits_observations, labels_observations.detach(), reduction='none').mean(-1) - mask_padding_expanded = batch['mask_padding'][:, 1:].contiguous().view(-1) # TODO: - # mask_padding_expanded = batch['mask_padding'][:, 1:].reshape(-1) - - # 应用mask到loss_obs - loss_obs = (loss_obs * mask_padding_expanded).mean(-1) - labels_policy, labels_value = self.compute_labels_world_model_value_policy(batch['target_value'], - batch['target_policy'], - batch['mask_padding']) - - loss_rewards = self.compute_cross_entropy_loss(outputs, labels_rewards, batch, element='rewards') - loss_policy = self.compute_cross_entropy_loss(outputs, labels_policy, batch, element='policy') - loss_value = self.compute_cross_entropy_loss(outputs, labels_value, batch, element='value') - - return LossWithIntermediateLosses(latent_recon_loss_weight=self.latent_recon_loss_weight, perceptual_loss_weight=self.perceptual_loss_weight, loss_obs=loss_obs, loss_rewards=loss_rewards, loss_value=loss_value, - loss_policy=loss_policy, latent_kl_loss=latent_kl_loss, latent_recon_loss=latent_recon_loss, perceptual_loss=perceptual_loss) - - def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): - # Assume outputs.logits_rewards and labels are your predictions and targets - # And mask_padding is a boolean tensor with True at positions to keep and False at positions to ignore - - if element == 'rewards': - logits = outputs.logits_rewards - elif element == 'policy': - logits = outputs.logits_policy - elif element == 'value': - logits = outputs.logits_value - - # Reshape your tensors - logits_rewards = rearrange(logits, 'b t e -> (b t) e') - labels = labels.reshape(-1, labels.shape[-1]) # Assuming labels originally has shape [b, t, reward_dim] - - # Reshape your mask. True means valid data. 
- mask_padding = rearrange(batch['mask_padding'], 'b t -> (b t)') - - loss_rewards = -(torch.log_softmax(logits_rewards, dim=1) * labels).sum(1) - # loss_rewards = (loss_rewards * mask_padding.squeeze(-1).float()).mean() - loss_rewards = (loss_rewards * mask_padding).mean() - - - return loss_rewards - - def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, - mask_padding: torch.BoolTensor) -> Tuple[ - torch.Tensor, torch.Tensor, torch.Tensor]: - assert torch.all(ends.sum(dim=1) <= 1) # each sequence sample has at most 1 done - mask_fill = torch.logical_not(mask_padding) - labels_observations = obs_embeddings.contiguous().view(rewards.shape[0], -1, self.projection_input_dim)[:, 1:] # self.projection_input_dim - - - # labels_rewards = (rewards.sign() + 1).masked_fill(mask_fill, -100).long() # Rewards clipped to {-1, 0, 1} TODO(pu) - - mask_fill_rewards = mask_fill.unsqueeze(-1).expand_as(rewards) - labels_rewards = rewards.masked_fill(mask_fill_rewards, -100) - - labels_ends = ends.masked_fill(mask_fill, -100) - return labels_observations, labels_rewards.reshape(-1, self.support_size), labels_ends.reshape(-1) - - def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, target_policy: torch.Tensor, - mask_padding: torch.BoolTensor) -> Tuple[ - torch.Tensor, torch.Tensor]: - - mask_fill = torch.logical_not(mask_padding) - mask_fill_policy = mask_fill.unsqueeze(-1).expand_as(target_policy) - labels_policy = target_policy.masked_fill(mask_fill_policy, -100) - - mask_fill_value = mask_fill.unsqueeze(-1).expand_as(target_value) - labels_value = target_value.masked_fill(mask_fill_value, -100) - return labels_policy.reshape(-1, self.action_shape), labels_value.reshape(-1, self.support_size) # TODO(pu) - - - def negative_cosine_similarity(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: - """ - Overview: - consistency loss function: the negative cosine similarity. - Arguments: - - x1 (:obj:`torch.Tensor`): shape (batch_size, dim), e.g. (256, 512) - - x2 (:obj:`torch.Tensor`): shape (batch_size, dim), e.g. (256, 512) - Returns: - (x1 * x2).sum(dim=1) is the cosine similarity between vector x1 and x2. - The cosine similarity always belongs to the interval [-1, 1]. - For example, two proportional vectors have a cosine similarity of 1, - two orthogonal vectors have a similarity of 0, - and two opposite vectors have a similarity of -1. - -(x1 * x2).sum(dim=1) is consistency loss, i.e. the negative cosine similarity. 
- Reference: - https://en.wikipedia.org/wiki/Cosine_similarity - """ - x1 = F.normalize(x1, p=2., dim=-1, eps=1e-5) - x2 = F.normalize(x2, p=2., dim=-1, eps=1e-5) - return -(x1 * x2).sum(dim=1) - - - def render_img(self, obs: int, rec_img: int): - import torch - from PIL import Image - import matplotlib.pyplot as plt - - # 假设batch是一个字典,其中包含了observations键, - # 并且它的形状是torch.Size([B, N, C, H, W]) - # batch_observations = batch_for_gpt['observations'] - # batch_observations = batch['observations'] - batch_observations = obs.unsqueeze(0) - # batch_observations = rec_img.unsqueeze(0) - - # batch_observations = observations.unsqueeze(0) - # batch_observations = x.unsqueeze(0) - # batch_observations = reconstructions.unsqueeze(0) - - - - B, N, C, H, W = batch_observations.shape # 自动检测维度 - - # 分隔条的宽度(可以根据需要调整) - separator_width = 2 - - # 遍历每个样本 - for i in range(B): - # 提取当前样本中的所有帧 - frames = batch_observations[i] - - # 计算拼接图像的总宽度(包括分隔条) - total_width = N * W + (N - 1) * separator_width - - # 创建一个新的图像,其中包含分隔条 - concat_image = Image.new('RGB', (total_width, H), color='black') - - # 拼接每一帧及分隔条 - for j in range(N): - frame = frames[j].permute(1, 2, 0).cpu().numpy() # 转换为(H, W, C) - frame_image = Image.fromarray((frame * 255).astype('uint8'), 'RGB') - - # 计算当前帧在拼接图像中的位置 - x_position = j * (W + separator_width) - concat_image.paste(frame_image, (x_position, 0)) - - # 显示图像 - plt.imshow(concat_image) - plt.title(f'Sample {i+1}') - plt.axis('off') # 关闭坐标轴显示 - plt.show() - - # 保存图像到文件 - concat_image.save(f'sample_{i+1}.png') \ No newline at end of file diff --git a/lzero/model/gpt_models/world_model_pad_min_quantize_fixroot_v2.py b/lzero/model/gpt_models/world_model_pad_min_quantize_fixroot_v2.py deleted file mode 100644 index 25d38b73d..000000000 --- a/lzero/model/gpt_models/world_model_pad_min_quantize_fixroot_v2.py +++ /dev/null @@ -1,1166 +0,0 @@ -import copy -from dataclasses import dataclass -import random -from typing import Any, Optional, Tuple -from typing import List, Optional, Union -import logging -# 设置日志记录级别为DEBUG -logging.getLogger().setLevel(logging.DEBUG) -from PIL import Image -from einops import rearrange -from einops import rearrange -import gym -from joblib import hash -import numpy as np -import torch -import torch -from torch.distributions.categorical import Categorical -import torch.nn as nn -import torch.nn.functional as F -import torchvision - -from .kv_caching import KeysValues -from .slicer import Embedder, Head, ActEmbedder -from .tokenizer import Tokenizer -from .transformer import Transformer, TransformerConfig -from .utils import LossWithIntermediateLosses, init_weights -from ding.torch_utils import to_device -# from memory_profiler import profile -from line_profiler import line_profiler -import hashlib -# def quantize_state(state, num_buckets=1000): -def quantize_state(state, num_buckets=15): -# def quantize_state(state, num_buckets=10): - """ - 量化状态向量。 - 参数: - state: 要量化的状态向量。 - num_buckets: 量化的桶数。 - 返回: - 量化后的状态向量的哈希值。 - """ - # 使用np.digitize将状态向量的每个维度值映射到num_buckets个桶中 - quantized_state = np.digitize(state, bins=np.linspace(0, 1, num=num_buckets)) - # 使用更稳定的哈希函数 - quantized_state_bytes = quantized_state.tobytes() - hash_object = hashlib.sha256(quantized_state_bytes) - return hash_object.hexdigest() - -@dataclass -class WorldModelOutput: - output_sequence: torch.FloatTensor - logits_observations: torch.FloatTensor - logits_rewards: torch.FloatTensor - logits_ends: torch.FloatTensor - - logits_policy: torch.FloatTensor - logits_value: torch.FloatTensor - - -class 
WorldModel(nn.Module): - def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: TransformerConfig, tokenizer, representation_network=None) -> None: - super().__init__() - - # config.max_tokens = int(2*50) # TODO - - self.tokenizer = tokenizer - self.obs_vocab_size, self.act_vocab_size = obs_vocab_size, act_vocab_size - self.config = config - - self.transformer = Transformer(config) - # self.num_observations_tokens = 16 - self.num_observations_tokens = config.tokens_per_block -1 - - self.latent_recon_loss_weight = config.latent_recon_loss_weight - self.perceptual_loss_weight = config.perceptual_loss_weight - - self.device = config.device - self.support_size = config.support_size - self.action_shape = config.action_shape - self.max_cache_size = config.max_cache_size - self.env_num = config.env_num - self.num_layers = config.num_layers - - - all_but_last_latent_state_pattern = torch.ones(config.tokens_per_block) - all_but_last_latent_state_pattern[-2] = 0 # 1,...,0,1 - - act_tokens_pattern = torch.zeros(self.config.tokens_per_block) # 17 - act_tokens_pattern[-1] = 1 # 0,...,0,1 - latent_state_pattern = 1 - act_tokens_pattern # 1,...,1,0 - - # current latent state's policy value - value_policy_tokens_pattern = torch.zeros(config.tokens_per_block) - value_policy_tokens_pattern[-2] = 1 # [0,...,1,0] - - # next latent state's policy value - # value_policy_tokens_pattern = torch.zeros(config.tokens_per_block) - # value_policy_tokens_pattern[-1] = 1 # [0,...,0,1] - - obs_per_embdding_dim=config.embed_dim - - self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim) - - - self.embedder = Embedder( - max_blocks=config.max_blocks, - block_masks=[act_tokens_pattern, latent_state_pattern], - embedding_tables=nn.ModuleList([nn.Embedding(act_vocab_size, config.embed_dim), nn.Embedding(obs_vocab_size, config.embed_dim)]) - ) - - self.act_embedding_table = nn.Embedding(act_vocab_size, config.embed_dim) - - self.obs_per_embdding_dim = config.embed_dim # 16*64=1024 - - self.head_rewards = Head( - max_blocks=config.max_blocks, - block_mask=act_tokens_pattern, # 0,...,0,1 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - nn.ReLU(), - nn.Linear(config.embed_dim, self.support_size) - ) - ) - - self.head_ends = Head( - max_blocks=config.max_blocks, - block_mask=act_tokens_pattern, # 0,...,0,1 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - nn.ReLU(), - nn.Linear(config.embed_dim, 2) - ) - ) - - self.head_observations_for_root = Head( # TODO - max_blocks=config.max_blocks, - block_mask=latent_state_pattern, # 1,...,1,0 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - nn.BatchNorm1d(config.embed_dim), - nn.ReLU(), - nn.Linear(config.embed_dim, obs_per_embdding_dim) - ) - ) - - ###### TODO: 2层的性能, LeakyReLU->GELU ###### - self.head_observations = Head( # TODO - max_blocks=config.max_blocks, - block_mask=all_but_last_latent_state_pattern, # 1,...,0,1 # https://github.com/eloialonso/iris/issues/19 - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - # nn.LeakyReLU(negative_slope=0.01), # TODO: 2 - nn.GELU(), - nn.Linear(config.embed_dim, self.obs_per_embdding_dim), - # nn.Tanh(), # TODO - nn.Sigmoid(), # 这里添加Sigmoid函数 TODO - ) - ) - self.head_policy = Head( - max_blocks=config.max_blocks, - block_mask=value_policy_tokens_pattern, # TODO: value_policy_tokens_pattern # [0,...,1,0] - head_module=nn.Sequential( # (8, 5, 128) - nn.Linear(config.embed_dim, config.embed_dim), - # 
nn.LeakyReLU(negative_slope=0.01), # TODO: 2 - nn.GELU(), - nn.Linear(config.embed_dim, self.action_shape) # TODO(pu); action shape - ) - ) - self.head_value = Head( - max_blocks=config.max_blocks, - block_mask=value_policy_tokens_pattern, - head_module=nn.Sequential( - nn.Linear(config.embed_dim, config.embed_dim), - # nn.LeakyReLU(negative_slope=0.01), # TODO: 2 - nn.GELU(), - nn.Linear(config.embed_dim, self.support_size) # TODO(pu): action shape - ) - ) - - ###### TODO: 单层的性能 ###### - # self.head_observations = Head( # TODO - # max_blocks=config.max_blocks, - # block_mask=all_but_last_latent_state_pattern, # 1,...,0,1 # https://github.com/eloialonso/iris/issues/19 - # head_module=nn.Sequential( - # nn.Linear(config.embed_dim, self.obs_per_embdding_dim), - # nn.Sigmoid(), # 这里添加Sigmoid函数 TODO - # ) - # ) - # self.head_policy = Head( - # max_blocks=config.max_blocks, - # block_mask=value_policy_tokens_pattern, # TODO: value_policy_tokens_pattern # [0,...,1,0] - # head_module=nn.Sequential( # (8, 5, 128) - # nn.Linear(config.embed_dim, self.action_shape) # TODO(pu); action shape - # ) - # ) - # self.head_value = Head( - # max_blocks=config.max_blocks, - # block_mask=value_policy_tokens_pattern, - # head_module=nn.Sequential( - # nn.Linear(config.embed_dim, self.support_size) # TODO(pu): action shape - # ) - # ) - - self.apply(init_weights) - - last_linear_layer_init_zero = True # TODO: is beneficial for convergence speed. - if last_linear_layer_init_zero: - for _, layer in enumerate(reversed(self.head_value.head_module)): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - if layer.bias is not None: - nn.init.zeros_(layer.bias) - break - for _, layer in enumerate(reversed(self.head_rewards.head_module)): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - if layer.bias is not None: - nn.init.zeros_(layer.bias) - break - for _, layer in enumerate(reversed(self.head_observations.head_module)): - if isinstance(layer, nn.Linear): - nn.init.zeros_(layer.weight) - # layer.weight.data.fill_(0.5) # TODO:bug - if layer.bias is not None: - nn.init.zeros_(layer.bias) - break - - - import collections - self.past_keys_values_cache = collections.OrderedDict() - self.past_policy_value_cache = collections.OrderedDict() - - # TODO: Transformer更新后应该清除缓存 - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=self.env_num, max_tokens=self.config.max_tokens) - - self.keys_values_wm_list = [] - self.keys_values_wm_size_list = [] - - - if self.num_observations_tokens==16: # k=16 - self.projection_input_dim = 128 - elif self.num_observations_tokens==1: # K=1 - self.projection_input_dim = self.obs_per_embdding_dim# for atari #TODO - # self.projection_input_dim = 256 # for cartpole - - - self.proj_hid = 1024 - self.proj_out = 1024 - self.pred_hid = 512 - self.pred_out = 1024 - activation = nn.ReLU() - self.projection = nn.Sequential( - nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, - nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, - nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out) - ) - self.prediction_head = nn.Sequential( - nn.Linear(self.proj_out, self.pred_hid), - nn.BatchNorm1d(self.pred_hid), - activation, - nn.Linear(self.pred_hid, self.pred_out), - ) - self.hit_count = 0 - self.total_query_count = 0 - self.length3_context_cnt = 0 - self.length2_context_cnt = 0 - self.root_hit_cnt = 0 - self.root_total_query_cnt = 0 - self.keys_values_wm_single_env = 
self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) - - - def __repr__(self) -> str: - return "world_model" - - # @profile - def forward(self, obs_embeddings_or_act_tokens, past_keys_values: Optional[KeysValues] = None, - is_root=False, kvcache_independent=False) -> WorldModelOutput: - - if kvcache_independent: - prev_steps = 0 if past_keys_values is None else [past_kv.size for past_kv in past_keys_values] - prev_steps = torch.tensor(prev_steps, device=self.device) - # 我们需要为每个样本生成一个序列的步骤indices,然后获取它们的位置嵌入 - # 首先扩展prev_steps至(num_steps, batch_size),这里num_steps=1 - # prev_steps = prev_steps.unsqueeze(0) - - else: - prev_steps = 0 if past_keys_values is None else past_keys_values.size - # print(f'prev_steps:{prev_steps}') - - if 'obs_embeddings' in obs_embeddings_or_act_tokens.keys(): - obs_embeddings = obs_embeddings_or_act_tokens['obs_embeddings'] - if len(obs_embeddings.shape)==2: - obs_embeddings = obs_embeddings.unsqueeze(1) - num_steps = obs_embeddings.size(1) # (B, T, E) - if kvcache_independent: - # 生成每个样本的步骤indices - # steps_indices = prev_steps + torch.arange(num_steps, device=obs_embeddings.device).unsqueeze(1) - steps_indices = prev_steps + torch.arange(num_steps, device=obs_embeddings.device) - - # 步骤indices需要被reshape成一维,以便于embedding操作 - # steps_indices = steps_indices.view(-1) - # 获取位置嵌入 - position_embeddings = self.pos_emb(steps_indices) - # 由于我们要将它们加到obs_embeddings上,需要将位置嵌入reshape回(batch_size, num_steps, embedding_dim) - position_embeddings = position_embeddings.view(-1, num_steps, position_embeddings.shape[-1]) - # 现在我们可以将位置嵌入加到obs_embeddings上了 - sequences = obs_embeddings + position_embeddings - else: - sequences = obs_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=obs_embeddings.device)) - elif 'act_tokens' in obs_embeddings_or_act_tokens.keys(): - act_tokens = obs_embeddings_or_act_tokens['act_tokens'] - if len(act_tokens.shape)==3: - act_tokens = act_tokens.squeeze(1) - num_steps = act_tokens.size(1) # (B, T) - act_embeddings = self.act_embedding_table(act_tokens) - - if kvcache_independent: - # 生成每个样本的步骤indices - # steps_indices = prev_steps + torch.arange(num_steps, device=act_embeddings.device).unsqueeze(1) - steps_indices = prev_steps + torch.arange(num_steps, device=act_embeddings.device) - # 步骤indices需要被reshape成一维,以便于embedding操作 - # steps_indices = steps_indices.view(-1) - # 获取位置嵌入 - position_embeddings = self.pos_emb(steps_indices) - # 由于我们要将它们加到obs_embeddings上,需要将位置嵌入reshape回(batch_size, num_steps, embedding_dim) - position_embeddings = position_embeddings.view(-1, num_steps, position_embeddings.shape[-1]) - # 现在我们可以将位置嵌入加到obs_embeddings上了 - sequences = act_embeddings + position_embeddings - else: - sequences = act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=act_tokens.device)) - else: - obs_embeddings_and_act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens'] - # obs_embeddings: (B, L, K=16, E), act_tokens: (B, L, 1) - obs_embeddings, act_tokens = obs_embeddings_and_act_tokens - if len(obs_embeddings.shape)==3: # for batch compute loss - obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, -1) - - num_steps = int(obs_embeddings.size(1)*(obs_embeddings.size(2)+1)) # L(k+1) - # assert num_steps <= self.config.max_tokens - # Rearrange observation embeddings from (B, L, K, E) to (B, L*K, E) - # obs_embeddings = rearrange(obs_embeddings, 'b l k e -> b (l k) e') - - # Generate action embeddings from action tokens 
- # (B, L, 1) -> (B, L, 1, E) - act_embeddings = self.act_embedding_table(act_tokens) - - # Given obs_embeddings of shape (B, L, K, E) and act_embeddings of shape (B, L, 1, E), we want an obs_act_embeddings tensor of shape (B, L*(K+1), E) - # whose second dimension is ordered as obs, act, obs, act, ..., obs, act, i.e. the K obs tokens followed by the act token for each of the L steps. - - B, L, K, E = obs_embeddings.size() - # _, _, _, _ = act_embeddings.size() - - # Allocate an empty tensor to hold the final interleaved result - obs_act_embeddings = torch.empty(B, L * (K + 1), E, device=obs_embeddings.device) - - # Loop over the L timesteps - for i in range(L): - # Get the obs and act embeddings for the current timestep - obs = obs_embeddings[:, i, :, :] # Shape: (B, K, E) - act = act_embeddings[:, i, 0, :].unsqueeze(1) # Shape: (B, 1, E); add a dimension for concatenation - - # Concatenate obs and act for this step - obs_act = torch.cat([obs, act], dim=1) # Shape: (B, K + 1, E) - - # Write the result into the final tensor - obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act - - # Sanity-check the shape - # assert obs_act_embeddings.shape == (B, L * (K + 1), E) - - # Add positional embeddings - sequences = obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=obs_embeddings.device)) - - - # print('transformer forward begin') # the transformer forward below updates past_keys_values in place - if kvcache_independent: - x = [] - for k, past_kv in enumerate(past_keys_values): - x.append(self.transformer(sequences[k].unsqueeze(0), past_kv)) - x = torch.cat(x, dim=0) - - # TODO: during collect, obs and act are fed in one step at a time - # prev_steps = prev_steps//1 - - else: - x = self.transformer(sequences, past_keys_values) - - # print('transformer forward done') - - - if is_root: - logits_observations = self.head_observations_for_root(x, num_steps=num_steps, prev_steps=prev_steps) - else: - # 1,...,0,1 https://github.com/eloialonso/iris/issues/19 - logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps) - - logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps) - logits_ends = self.head_ends(x, num_steps=num_steps, prev_steps=prev_steps) - # return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends) - - logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps) - logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps) - - # TODO: root reward value - return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends, logits_policy, logits_value) - - @torch.no_grad() - def render_batch(self) -> List[Image.Image]: - frames = self.decode_latent_state().detach().cpu() - frames = rearrange(frames, 'b c h w -> b h w c').mul(255).numpy().astype(np.uint8) - return [Image.fromarray(frame) for frame in frames] - - # only for inference now; currently unused - @torch.no_grad() - def decode_latent_state(self) -> List[Image.Image]: - embedded_tokens = self.tokenizer.embedding(self.latent_state) # (B, K, E) - z = rearrange(embedded_tokens, 'b (h w) e -> b e h w', h=int(np.sqrt(self.num_observations_tokens))) - rec = self.tokenizer.decode(z, should_postprocess=True) # (B, C, H, W) - # TODO: for atari image - return torch.clamp(rec, 0, 1) - # for cartpole obs - # return rec - - - @torch.no_grad() - def render(self): - assert self.latent_state.shape == (1, self.num_observations_tokens) - return self.render_batch()[0] - - @torch.no_grad() - def reset(self) -> torch.FloatTensor: - assert self.env is not None - obs = torchvision.transforms.functional.to_tensor(self.env.reset()).to(self.device).unsqueeze( - 0) # (1, C, H, W) in [0., 1.] 
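# A vectorized alternative to the interleaving loop in forward() above (illustrative sketch,
# not part of the original file; it assumes the same (B, L, K, E) obs_embeddings and
# (B, L, 1, E) act_embeddings used there):
import torch

def interleave_obs_act(obs_embeddings: torch.Tensor, act_embeddings: torch.Tensor) -> torch.Tensor:
    # Concatenating along the token axis and flattening (L, K+1) -> L*(K+1) yields
    # obs_0, ..., obs_{K-1}, act for every step, the same layout the explicit loop produces.
    B, L, K, E = obs_embeddings.shape
    obs_act = torch.cat([obs_embeddings, act_embeddings], dim=2)  # (B, L, K+1, E)
    return obs_act.reshape(B, L * (K + 1), E)                     # (B, L*(K+1), E)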
- return self.reset_from_initial_observations(obs) - - - @torch.no_grad() - # @profile - def reset_from_initial_observations_v2(self, obs_act_dict: torch.FloatTensor) -> torch.FloatTensor: - if isinstance(obs_act_dict, dict): - observations = obs_act_dict['obs'] - buffer_action = obs_act_dict['action'] - current_obs = obs_act_dict['current_obs'] - else: - observations = obs_act_dict - buffer_action = None - current_obs = None - obs_embeddings = self.tokenizer.encode_to_obs_embeddings(observations, should_preprocess=True) # (B, C, H, W) -> (B, K, E) - - if current_obs is not None: - current_obs_embeddings = self.tokenizer.encode_to_obs_embeddings(current_obs, should_preprocess=True) # (B, C, H, W) -> (B, K, E) - self.latent_state = current_obs_embeddings - outputs_wm = self.refresh_keys_values_with_initial_latent_state_for_init_infer_v2(obs_embeddings, buffer_action, current_obs_embeddings) - else: - self.latent_state = obs_embeddings - outputs_wm = self.refresh_keys_values_with_initial_latent_state_for_init_infer_v2(obs_embeddings, buffer_action, None) - - - return outputs_wm, self.latent_state - - @torch.no_grad() - # @profile - def refresh_keys_values_with_initial_latent_state_for_init_infer_v2(self, latent_state: torch.LongTensor, buffer_action=None, current_obs_embeddings=None) -> torch.FloatTensor: - n, num_observations_tokens, _ = latent_state.shape - if n <= self.env_num: - if buffer_action is None: - # MCTS root节点: 需要准确的估计 value, policy_logits, 或许需要结合context的kv_cache进行更准确的估计,而不是当前的从0开始推理 - self.keys_values_wm = self.transformer.generate_empty_keys_values(n, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False, kvcache_independent=False) - self.keys_values_wm_size_list = [1 for i in range(n)] - - # 复制单个环境对应的 keys_values_wm 并存储 - self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) - for i in range(latent_state.size(0)): # 遍历每个环境 - state_single_env = latent_state[i] # 获取单个环境的 latent state - # cache_key = hash(state_single_env.detach().cpu().numpy()) # 计算哈希值 - quantized_state = state_single_env.detach().cpu().numpy() - cache_key = quantize_state(quantized_state) # 使用量化后的状态计算哈希值 - for layer in range(self.num_layers): - self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = self.keys_values_wm._keys_values[layer]._k_cache._cache[i].unsqueeze(0) # shape torch.Size([2, 100, 512]) - self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = self.keys_values_wm._keys_values[layer]._v_cache._cache[i].unsqueeze(0) - self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = self.keys_values_wm._keys_values[layer]._k_cache._size - self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = self.keys_values_wm._keys_values[layer]._v_cache._size - # keys_values_wm_single_env[layer].update(self.keys_values_wm[layer]._k_cache._cache[i].unsqueeze(0), self.keys_values_wm[layer]._v_cache._cache[i].unsqueeze(0)) - self.root_total_query_cnt += 1 - if cache_key not in self.past_keys_values_cache: - self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu')) - else: - self.root_hit_cnt += 1 - root_hit_ratio = self.root_hit_cnt / self.root_total_query_cnt - print('root_total_query_cnt:', self.root_total_query_cnt) - print(f'root_hit_ratio:{root_hit_ratio}') - print(f'root_hit_cnt:{self.root_hit_cnt}') - print(f'root_hit find size 
{self.past_keys_values_cache[cache_key].size}') - if self.past_keys_values_cache[cache_key].size>1: - print(f'=='*20) - print(f'NOTE: root_hit find size > 1') - print(f'=='*20) - elif current_obs_embeddings is not None: - - if max(buffer_action) == -1: - # first step in one episode - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=8, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': current_obs_embeddings}, past_keys_values=self.keys_values_wm, is_root=False) - - # 复制单个环境对应的 keys_values_wm 并存储 - self.update_cache(current_obs_embeddings) - else: - # self.retrieve_or_generate_kvcache(latent_state, current_obs.shape[0]) - # 假设 latest_state 是新的 latent_state,包含 ready_env_num 个环境的信息 - # ready_env_num = latent_state.shape[0] - ready_env_num = current_obs_embeddings.shape[0] - self.keys_values_wm_list = [] - self.keys_values_wm_size_list = [] - for i in range(ready_env_num): - state_single_env = latent_state[i] # 获取单个环境的 latent state - quantized_state = state_single_env.detach().cpu().numpy() - cache_key = quantize_state(quantized_state) # 使用量化后的状态计算哈希值 - matched_value = self.past_keys_values_cache.get(cache_key) # 检索缓存值 - self.root_total_query_cnt += 1 - if matched_value is not None: - # 如果找到匹配的值,将其添加到列表中 - self.root_hit_cnt += 1 - if self.root_total_query_cnt>0 and self.root_total_query_cnt%1000==0: - root_hit_ratio = self.root_hit_cnt / self.root_total_query_cnt - print('root_total_query_cnt:', self.root_total_query_cnt) - print(f'root_hit_ratio:{root_hit_ratio}') - print(f'root_hit find size {self.past_keys_values_cache[cache_key].size}') - if self.past_keys_values_cache[cache_key].size>=7: - print(f'=='*20) - print(f'NOTE: root_hit find size >= 7') - print(f'=='*20) - # 这里需要deepcopy因为在transformer的forward中会原地修改matched_value - self.keys_values_wm_list.append(copy.deepcopy(self.to_device_for_kvcache(matched_value, 'cuda'))) - self.keys_values_wm_size_list.append(matched_value.size) - else: - # use zero reset - self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) - # outputs_wm = self.forward({'obs_embeddings': torch.from_numpy(state_single_env).unsqueeze(0).to(self.device)}, past_keys_values=self.keys_values_wm_single_env, is_root=False) - outputs_wm = self.forward({'obs_embeddings': state_single_env.unsqueeze(0)}, past_keys_values=self.keys_values_wm_single_env, is_root=False) - self.keys_values_wm_list.append(self.keys_values_wm_single_env) - self.keys_values_wm_size_list.append(1) - - # print(f'NOTE: root {self.keys_values_wm_size_list}') - # print(f'=='*20) - # 输入self.keys_values_wm_list,输出为self.keys_values_wm - self.trim_and_pad_kv_cache() - - buffer_action = buffer_action[:ready_env_num] - buffer_action = torch.from_numpy(np.array(buffer_action)).to(latent_state.device) - act_tokens = buffer_action.unsqueeze(-1) - # outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (latent_state, act_tokens)}, past_keys_values=self.keys_values_wm, is_root=False) - outputs_wm = self.forward({'act_tokens': act_tokens}, past_keys_values=self.keys_values_wm, is_root=False) - - outputs_wm = self.forward({'obs_embeddings': current_obs_embeddings}, past_keys_values=self.keys_values_wm, is_root=False) - - # 复制单个环境对应的 keys_values_wm 并存储 - self.update_cache(current_obs_embeddings) - - elif n == int(256): - # TODO: n=256 means train tokenizer, 不需要计算target value - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - # print('init inference: 
not find matched_value! reset!') - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) - elif n > self.env_num and n != int(256) and buffer_action is not None: - # train时计算target value - # TODO: n=256 means train tokenizer - # [192, 16, 64] -> [32, 6, 16, 64] - latent_state = latent_state.contiguous().view(buffer_action.shape[0], -1, num_observations_tokens, self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 - - # latent_state = latent_state.view(-1, self.config.max_blocks+1, num_observations_tokens) # (BL, K) - latent_state = latent_state[:, :-1, :] - # latent_state = latent_state.reshape(32*6, num_observations_tokens) # (BL, K) - buffer_action = torch.from_numpy(buffer_action).to(latent_state.device) - act_tokens = rearrange(buffer_action, 'b l -> b l 1') - - # # 选择每个样本的最后一步 - last_steps = act_tokens[:, -1:, :] # 这将选择最后一列并保持维度不变, 最后一步的target policy/value本身就没有用到 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - act_tokens = torch.cat((act_tokens, last_steps), dim=1) - - # print('init inference: unroll 5 steps!') 17*6=102 17*5=85 - obs_embeddings = latent_state - outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) - - # 选择每个样本的最后一步 - last_steps_value = outputs_wm.logits_value[:, -1:, :] # 这将选择最后一列并保持维度不变 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1) - - last_steps_policy = outputs_wm.logits_policy[:, -1:, :] # 这将选择最后一列并保持维度不变 - outputs_wm.logits_policy = torch.cat((outputs_wm.logits_policy, last_steps_policy), dim=1) - - # Reshape your tensors - # outputs_wm.logits_value.shape (30,21) = (B*6, 21) - outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e') - outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e') - - return outputs_wm - - @torch.no_grad() - # @profile - def reset_from_initial_observations(self, obs_act_dict: torch.FloatTensor) -> torch.FloatTensor: - if isinstance(obs_act_dict, dict): - # obs_act_dict = {'obs':obs, 'action':action_batch} - observations = obs_act_dict['obs'] - buffer_action = obs_act_dict['action'] - else: - observations = obs_act_dict - buffer_action = None - - obs_embeddings = self.tokenizer.encode_to_obs_embeddings(observations, should_preprocess=True) # (B, C, H, W) -> (B, K, E) - outputs_wm = self.refresh_keys_values_with_initial_latent_state_for_init_infer(obs_embeddings, buffer_action) - self.latent_state = obs_embeddings - - return outputs_wm, self.latent_state - - - @torch.no_grad() - # @profile - def refresh_keys_values_with_initial_latent_state_for_init_infer(self, latent_state: torch.LongTensor, buffer_action=None) -> torch.FloatTensor: - n, num_observations_tokens, _ = latent_state.shape - if n <= self.env_num: - # MCTS root节点: 需要准确的估计 value, policy_logits 或许需要结合context的kv_cache进行更准确的估计,而不是当前的从0开始推理 - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) # Note: is_root=False - elif n == int(256): - # TODO: n=256 means train tokenizer, 不需要计算target value - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - # print('init inference: not find matched_value! 
reset!') - outputs_wm = self.forward({'obs_embeddings': latent_state}, past_keys_values=self.keys_values_wm, is_root=False) # Note: is_root=False - elif n > self.env_num and n != int(256) and buffer_action is not None: - # TODO: n=256 means train tokenizer - # TODO: for n=32*6=192 means 通过unroll 5 steps,计算target value - # latent_state = latent_state.reshape(32, 6, num_observations_tokens) # (BL, K) - # latent_state = latent_state.view(-1, 6, num_observations_tokens) # (BL, K) - - # [192, 16] -> [32, 6, 16] - # latent_state = latent_state.view(buffer_action.shape[0], -1, num_observations_tokens) # (BL, K) for unroll_step=1 - - # [192, 16, 64] -> [32, 6, 16, 64] - latent_state = latent_state.contiguous().view(buffer_action.shape[0], -1, num_observations_tokens, self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 - - # latent_state = latent_state.view(-1, self.config.max_blocks+1, num_observations_tokens) # (BL, K) - latent_state = latent_state[:, :-1, :] - # latent_state = latent_state.reshape(32*6, num_observations_tokens) # (BL, K) - buffer_action = torch.from_numpy(buffer_action).to(latent_state.device) - act_tokens = rearrange(buffer_action, 'b l -> b l 1') - - # # 选择每个样本的最后一步 - last_steps = act_tokens[:, -1:, :] # 这将选择最后一列并保持维度不变, 最后一步的target policy/value本身就没有用到 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - act_tokens = torch.cat((act_tokens, last_steps), dim=1) - - # print('init inference: unroll 5 steps!') 17*6=102 17*5=85 - obs_embeddings = latent_state - outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) - - # 选择每个样本的最后一步 - last_steps_value = outputs_wm.logits_value[:, -1:, :] # 这将选择最后一列并保持维度不变 - # 使用torch.cat在第二个维度上连接原始act_tokens和last_steps - outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1) - - last_steps_policy = outputs_wm.logits_policy[:, -1:, :] # 这将选择最后一列并保持维度不变 - outputs_wm.logits_policy = torch.cat((outputs_wm.logits_policy, last_steps_policy), dim=1) - - # Reshape your tensors - # outputs_wm.logits_value.shape (30,21) = (B*6, 21) - outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e') - outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e') - - - # return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends, logits_policy, logits_value) - return outputs_wm - - - @torch.no_grad() - # @profile - def refresh_keys_values_with_initial_latent_state(self, latent_state: torch.LongTensor, reset_indices=None) -> torch.FloatTensor: - n, num_observations_tokens, _ = latent_state.shape - assert num_observations_tokens == self.num_observations_tokens - # self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, max_tokens=self.config.max_tokens) - if reset_indices is None: - self.keys_values_wm_list = [self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) for i in range(n)] - else: - for i in reset_indices: - self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) - outputs_wm = self.forward({'obs_embeddings': latent_state[i].unsqueeze(0).to(self.device)}, past_keys_values=self.keys_values_wm_single_env, is_root=False, kvcache_independent=False) - self.keys_values_wm_list[i] = self.keys_values_wm_single_env - self.keys_values_wm_size_list[i] = 1 - return None - - @torch.no_grad() - # @profile - def forward_initial_inference(self, obs_act_dict: torch.LongTensor, should_predict_next_obs: bool = True): - if 
isinstance(obs_act_dict, dict): - obs = obs_act_dict['obs'] - else: - obs = obs_act_dict - outputs_wm, latent_state = self.reset_from_initial_observations_v2(obs_act_dict) # the root node also has context - # outputs_wm, latent_state = self.reset_from_initial_observations(obs_act_dict) # start from scratch - - return outputs_wm.output_sequence, latent_state, outputs_wm.logits_rewards, outputs_wm.logits_policy, outputs_wm.logits_value - - """ - Suppose env_num=8. - The kv_cache of each of the 8 environments is stored and looked up separately, all in one dict. During recurrent_inference, - because the kv_caches found for different environments have different sizes, they are first zero-padded at the front according to the largest size and then assembled into a batch-sized kv_cache; - internally the transformer forward is still executed as one batch. - """ - - @torch.no_grad() - # @profile - def forward_recurrent_inference(self, state_action_history, should_predict_next_obs: bool = True): - # Generally, within one MCTS search we need to maintain a context of length H for transformer inference. - # Since the agent visits at most `sim` distinct nodes within one search, it suffices to maintain a single {state: kv_cache} mapping. - # If the environment is assumed to be an MDP, we can simply look up the current latest_state s_t in this mapping. - # TODO: if the environment is assumed to be non-Markovian, do we need to maintain a {(root_state, action_history): kv_cache} mapping instead? - - latest_state = state_action_history[-1][0] - - # Assume latest_state is the new latent_state, containing information for ready_env_num environments - ready_env_num = latest_state.shape[0] - self.keys_values_wm_list = [] - self.keys_values_wm_size_list = [] - self.retrieve_or_generate_kvcache(latest_state, ready_env_num) - - num_passes = 1 + self.num_observations_tokens if should_predict_next_obs else 1 - output_sequence, latent_state = [], [] - - # reset_indices = [index for index, value in enumerate(self.keys_values_wm_size_list) if value + num_passes > self.config.max_tokens] - # self.refresh_keys_values_with_initial_latent_state(torch.tensor(latest_state, dtype=torch.float32).to(self.device), reset_indices) - - action = state_action_history[-1][-1] - token = action.clone().detach() if isinstance(action, torch.Tensor) else torch.tensor(action, dtype=torch.long) - token = token.reshape(-1, 1).to(self.device) # (B, 1) - - # print(self.keys_values_wm_size_list) - # Get the minimum value min_size of self.keys_values_wm_size_list - min_size = min(self.keys_values_wm_size_list) - if min_size >= self.config.max_tokens - 5: - self.length3_context_cnt += len(self.keys_values_wm_size_list) - if min_size >= 3: - self.length2_context_cnt += len(self.keys_values_wm_size_list) - # if self.total_query_count>0 and self.total_query_count%1==0: - if self.total_query_count>0 and self.total_query_count%10000==0: - self.hit_freq = self.hit_count/(self.total_query_count) - # print('hit_freq:', self.hit_freq) - # print('hit_count:', self.hit_count) - print('total_query_count:', self.total_query_count) - # If the total query count is greater than 0, compute and print the context-count ratios - length3_context_cnt_ratio = self.length3_context_cnt / self.total_query_count - print('>=3 node context_cnt:', self.length3_context_cnt) - print('>=3 node context_cnt_ratio:', length3_context_cnt_ratio) - length2_context_cnt_ratio = self.length2_context_cnt / self.total_query_count - print('>=2 node context_cnt_ratio:', length2_context_cnt_ratio) - print('>=2 node context_cnt:', self.length2_context_cnt) - # print(self.keys_values_wm_size_list) - - # Input: self.keys_values_wm_list; output: self.keys_values_wm - self.trim_and_pad_kv_cache() - # print(f'NOTE: in search node {self.keys_values_wm_size_list}') - for k in range(num_passes): # assumption that there is only one action token. 
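# Usage sketch (not part of the original file) of quantize_state(), defined near the top of this
# file: nearby latent vectors are bucketised to the same key, so re-visits of numerically similar
# states during a search hit the same entry in self.past_keys_values_cache.
import numpy as np

def demo_quantize_state_keys() -> bool:
    s1 = np.array([0.101, 0.520, 0.930])
    s2 = np.array([0.105, 0.530, 0.930])  # small perturbation of s1
    # With the default num_buckets=15 used here, each coordinate falls into the same bucket,
    # so the two cache keys coincide and the second query is a cache hit.
    return quantize_state(s1) == quantize_state(s2)  # True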
- # action_token obs_token, ..., obs_token 1+1 - if k==0: - obs_embeddings_or_act_tokens = {'act_tokens': token} - else: - obs_embeddings_or_act_tokens = {'obs_embeddings': token} - - outputs_wm = self.forward(obs_embeddings_or_act_tokens, past_keys_values=self.keys_values_wm, is_root=False, kvcache_independent=False) - # if k==0, action_token self.head_observations 1,...,0,1 - output_sequence.append(outputs_wm.output_sequence) - - if k == 0: - # if k==0, token is action_token outputs_wm.logits_rewards 是有值的 - reward = outputs_wm.logits_rewards # (B,) - - if k < self.num_observations_tokens: - # 一共产生16个obs_token,每次产生一个 - # TODO: sample or argmax - # token = Categorical(logits=outputs_wm.logits_observations).sample() - # Use argmax to select the most likely token - # token = outputs_wm.logits_observations.argmax(-1, keepdim=True) - token = outputs_wm.logits_observations - # if len(token.shape) != 2: - # token = token.squeeze(-1) # Ensure the token tensor shape is (B, 1) - if len(token.shape) != 3: - token = token.unsqueeze(1) # (8,1024) -> (8,1,1024) - latent_state.append(token) - - output_sequence = torch.cat(output_sequence, dim=1) # (B, 1 + K, E) - # Before updating self.latent_state, delete the old one to free memory - del self.latent_state - self.latent_state = torch.cat(latent_state, dim=1) # (B, K) - - self.update_cache(self.latent_state) - # TODO: 在计算结束后,是否需要更新最新的缓存. 是否需要deepcopy - - # del self.keys_values_wm - if len(self.past_keys_values_cache) > self.max_cache_size: - # TODO: lru_cache - _, popped_kv_cache = self.past_keys_values_cache.popitem(last=False) - del popped_kv_cache # 不要这一行 - - # Example usage: - # Assuming `past_keys_values_cache` is a populated instance of `KeysValues` - # and `num_layers` is the number of transformer layers - # cuda_memory_gb = self.calculate_cuda_memory_gb(self.past_keys_values_cache, num_layers=2) - # print(f'len(self.past_keys_values_cache): {len(self.past_keys_values_cache)}, Memory used by past_keys_values_cache: {cuda_memory_gb:.2f} GB') - - return outputs_wm.output_sequence, self.latent_state, reward, outputs_wm.logits_policy, outputs_wm.logits_value - - def trim_and_pad_kv_cache(self): - """ - This method trims and pads the key and value caches of the attention mechanism - to a consistent size across all items in the batch, determined by the smallest cache size. 
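For example (sketch): with per-environment cache sizes [5, 3, 4], min_size is 3; the caches of
size 5 and 4 have their oldest 2 and 1 steps zeroed out (trimmed and re-padded at the front),
and all caches are then stacked into a single batched cache whose effective size is 3.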
- """ - - # Find the minimum size across all key-value sizes for padding/trimming - min_size = min(self.keys_values_wm_size_list) - - # Iterate over each layer of the transformer - for layer in range(self.num_layers): - # Initialize lists to hold the trimmed and padded k and v caches - kv_cache_k_list = [] - kv_cache_v_list = [] - - # Enumerate over the key-value pairs list - for idx, keys_values in enumerate(self.keys_values_wm_list): - # Retrieve the current layer's key and value caches - k_cache = keys_values[layer]._k_cache._cache - v_cache = keys_values[layer]._v_cache._cache - - # Get the effective size of the current cache - effective_size = self.keys_values_wm_size_list[idx] - # Calculate the size difference to trim - trim_size = effective_size - min_size if effective_size > min_size else 0 - - # If trimming is needed, remove 'trim_size' from the beginning of the cache - if trim_size > 0: - k_cache_trimmed = k_cache[:, :, trim_size:, :] - v_cache_trimmed = v_cache[:, :, trim_size:, :] - # Pad the trimmed cache with zeros on the third dimension - k_cache_padded = F.pad(k_cache_trimmed, (0, 0, trim_size, 0), "constant", 0) - v_cache_padded = F.pad(v_cache_trimmed, (0, 0, trim_size, 0), "constant", 0) - else: - k_cache_padded = k_cache - v_cache_padded = v_cache - - # Add the processed caches to the lists - kv_cache_k_list.append(k_cache_padded) - kv_cache_v_list.append(v_cache_padded) - - # Stack the caches along the new dimension, and remove the extra dimension with squeeze() - self.keys_values_wm._keys_values[layer]._k_cache._cache = torch.stack(kv_cache_k_list, dim=0).squeeze(1) - self.keys_values_wm._keys_values[layer]._v_cache._cache = torch.stack(kv_cache_v_list, dim=0).squeeze(1) - - # Update the cache size to the minimum size after trimming and padding - self.keys_values_wm._keys_values[layer]._k_cache._size = min_size - self.keys_values_wm._keys_values[layer]._v_cache._size = min_size - - def update_cache(self, latent_state): - for i in range(latent_state.size(0)): # Iterate over each environment - state_single_env = latent_state[i] # Get the latent state for a single environment - quantized_state = state_single_env.detach().cpu().numpy() # Detach and move the state to CPU - cache_key = quantize_state(quantized_state) # Quantize state and compute its hash value as cache key - - # Copy keys and values from the global cache to a single environment cache - for layer in range(self.num_layers): - if self.keys_values_wm._keys_values[layer]._k_cache._size < self.config.max_tokens - 1: - self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = self.keys_values_wm._keys_values[layer]._k_cache._cache[i].unsqueeze(0) # Shape torch.Size([2, 100, 512]) - self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = self.keys_values_wm._keys_values[layer]._v_cache._cache[i].unsqueeze(0) - self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = self.keys_values_wm._keys_values[layer]._k_cache._size - self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = self.keys_values_wm._keys_values[layer]._v_cache._size - elif self.keys_values_wm._keys_values[layer]._k_cache._size == self.config.max_tokens - 1: - # 裁剪和填充逻辑 - # 假设cache的维度是 [batch_size, num_heads, sequence_length, features] - k_cache_current = self.keys_values_wm._keys_values[layer]._k_cache._cache[i] - v_cache_current = self.keys_values_wm._keys_values[layer]._v_cache._cache[i] - - # 移除前2步并保留最近的max_tokens - 3步 - k_cache_trimmed = k_cache_current[:, 2:self.config.max_tokens - 1, :] - 
v_cache_trimmed = v_cache_current[:, 2:self.config.max_tokens - 1, :] - - # 沿第3维填充后2步 - padding_size = (0, 0, 0, 3) #F.pad的参数(0, 0, 0, 2)指定了在每个维度上的填充量。这些参数是按(左, 右, 上, 下)的顺序给出的,对于三维张量来说,分别对应于(维度2左侧, 维度2右侧, 维度1左侧, 维度1右侧)的填充。 - k_cache_padded = F.pad(k_cache_trimmed, padding_size, 'constant', 0) - v_cache_padded = F.pad(v_cache_trimmed, padding_size, 'constant', 0) - # 更新单环境cache - self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = k_cache_padded.unsqueeze(0) - self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = v_cache_padded.unsqueeze(0) - - self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = self.config.max_tokens - 3 - self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = self.config.max_tokens - 3 - - - - # Compare and store the larger cache - if cache_key in self.past_keys_values_cache: - existing_kvcache = self.past_keys_values_cache[cache_key] - # Check if there is a size difference between existing cache and new cache - if self.keys_values_wm_single_env._keys_values[0]._k_cache._size > existing_kvcache.size and self.keys_values_wm_single_env._keys_values[0]._k_cache._size < self.config.max_tokens - 1: - # Only store if size is less than max_tokens - 1 to avoid reset - self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu')) - elif self.keys_values_wm_single_env._keys_values[0]._k_cache._size < self.config.max_tokens - 1: - # Only store if size is less than max_tokens - 1 to avoid reset - self.past_keys_values_cache[cache_key] = copy.deepcopy(self.to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu')) - - def retrieve_or_generate_kvcache(self, latent_state, ready_env_num): - """ - This method iterates over the environments, retrieves a matching cache if available, - or generates a new one otherwise. It updates the lists with the keys_values caches and their sizes. - """ - for i in range(ready_env_num): - self.total_query_count += 1 - state_single_env = latent_state[i] # Get the latent state for a single environment - cache_key = quantize_state(state_single_env) # Compute the hash value using the quantized state - # Retrieve the cached value if it exists - matched_value = self.past_keys_values_cache.get(cache_key) - if matched_value is not None: - # If a matching value is found, add it to the list - self.hit_count += 1 - # Deepcopy is needed because the transformer's forward may modify matched_value in place - self.keys_values_wm_list.append(copy.deepcopy(self.to_device_for_kvcache(matched_value, self.device))) - self.keys_values_wm_size_list.append(matched_value.size) - else: - # If no match is found, use a zero reset - self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.config.max_tokens) - self.forward({'obs_embeddings': torch.from_numpy(state_single_env).unsqueeze(0).to(self.device)}, past_keys_values=self.keys_values_wm_single_env, is_root=False) - self.keys_values_wm_list.append(self.keys_values_wm_single_env) - self.keys_values_wm_size_list.append(1) - - def to_device_for_kvcache(self, keys_values: KeysValues, device: str) -> KeysValues: - """ - Transfer all KVCache objects within the KeysValues object to a certain device. - - Arguments: - - keys_values (KeysValues): The KeysValues object to be transferred. - - device (str): The device to transfer to. - - Returns: - - keys_values (KeysValues): The KeysValues object with its caches transferred to the specified device. 
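Note:
    In this version the ``device`` argument is effectively ignored; the body below re-selects
    'cuda' if it is available (and 'cpu' otherwise) before moving the key/value caches.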
- """ - # Check if CUDA is available and select the first available CUDA device - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - - for kv_cache in keys_values: - kv_cache._k_cache._cache = kv_cache._k_cache._cache.to(device) - kv_cache._v_cache._cache = kv_cache._v_cache._cache.to(device) - return keys_values - - - # 计算显存使用量的函数 - def calculate_cuda_memory_gb(self, past_keys_values_cache, num_layers: int): - total_memory_bytes = 0 - - # 遍历OrderedDict中所有的KeysValues实例 - for kv_instance in past_keys_values_cache.values(): - num_layers = len(kv_instance) # 获取层数 - for layer in range(num_layers): - kv_cache = kv_instance[layer] - k_shape = kv_cache._k_cache.shape # 获取keys缓存的形状 - v_shape = kv_cache._v_cache.shape # 获取values缓存的形状 - - # 计算元素个数并乘以每个元素的字节数 - k_memory = torch.prod(torch.tensor(k_shape)) * 4 - v_memory = torch.prod(torch.tensor(v_shape)) * 4 - - # 累加keys和values缓存的内存 - layer_memory = k_memory + v_memory - total_memory_bytes += layer_memory.item() # .item()确保转换为Python标准数字 - - # 将总内存从字节转换为吉字节 - total_memory_gb = total_memory_bytes / (1024 ** 3) - return total_memory_gb - - # @profile - def compute_loss(self, batch, target_tokenizer: Tokenizer=None, **kwargs: Any) -> LossWithIntermediateLosses: - # NOTE: 这里是需要梯度的 - #with torch.no_grad(): # TODO: 非常重要 - obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False) # (B, C, H, W) -> (B, K, E) - obs_embeddings.register_hook(lambda grad: grad * 1/5) # TODO:只提供重建损失更新表征网络 - # obs_embeddings.register_hook(lambda grad: grad * 1) # TODO:只提供重建损失更新表征网络 - - # Assume that 'cont_embeddings' and 'original_images' are available from prior code - # Decode the embeddings to reconstruct the images - reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) - # Calculate the reconstruction loss - # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 4, 64, 64), reconstructed_images) # TODO: for stack=4 - # perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) # for stack=4 gray obs - - latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # TODO: for stack=1 - perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # TODO: for stack=1 - - # latent_recon_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) - # perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) - - latent_kl_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) - - - act_tokens = rearrange(batch['actions'], 'b l -> b l 1') - - # TODO: 是否只用重建损失更新表征网络 非常重要 - outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, is_root=False) - # outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens)}, is_root=False) - - with torch.no_grad(): - traget_obs_embeddings = target_tokenizer.encode_to_obs_embeddings(batch['observations'], should_preprocess=False) # (B, C, H, W) -> (B, K, E) - - labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(traget_obs_embeddings, batch['rewards'], - batch['ends'], - batch['mask_padding']) - # labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(obs_embeddings, batch['rewards'], - # batch['ends'], - # 
batch['mask_padding']) - - logits_observations = rearrange(outputs.logits_observations[:, :-1], 'b t o -> (b t) o') - # labels_observations = labels_observations.contiguous().view(-1, self.projection_input_dim) # TODO: - labels_observations = labels_observations.reshape(-1, self.projection_input_dim) # TODO: - - - loss_obs = torch.nn.functional.mse_loss(logits_observations, labels_observations.detach(), reduction='none').mean(-1) - mask_padding_expanded = batch['mask_padding'][:, 1:].contiguous().view(-1) # TODO: - # mask_padding_expanded = batch['mask_padding'][:, 1:].reshape(-1) - - # 应用mask到loss_obs - loss_obs = (loss_obs * mask_padding_expanded).mean(-1) - labels_policy, labels_value = self.compute_labels_world_model_value_policy(batch['target_value'], - batch['target_policy'], - batch['mask_padding']) - - loss_rewards = self.compute_cross_entropy_loss(outputs, labels_rewards, batch, element='rewards') - loss_policy = self.compute_cross_entropy_loss(outputs, labels_policy, batch, element='policy') - loss_value = self.compute_cross_entropy_loss(outputs, labels_value, batch, element='value') - - return LossWithIntermediateLosses(latent_recon_loss_weight=self.latent_recon_loss_weight, perceptual_loss_weight=self.perceptual_loss_weight, loss_obs=loss_obs, loss_rewards=loss_rewards, loss_value=loss_value, - loss_policy=loss_policy, latent_kl_loss=latent_kl_loss, latent_recon_loss=latent_recon_loss, perceptual_loss=perceptual_loss) - - def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): - # Assume outputs.logits_rewards and labels are your predictions and targets - # And mask_padding is a boolean tensor with True at positions to keep and False at positions to ignore - - if element == 'rewards': - logits = outputs.logits_rewards - elif element == 'policy': - logits = outputs.logits_policy - elif element == 'value': - logits = outputs.logits_value - - # Reshape your tensors - logits_rewards = rearrange(logits, 'b t e -> (b t) e') - labels = labels.reshape(-1, labels.shape[-1]) # Assuming labels originally has shape [b, t, reward_dim] - - # Reshape your mask. True means valid data. 
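# The masked soft-target cross-entropy computed below can be summarised by this small helper
# (illustrative sketch, not part of the original file; shapes follow the reshaped tensors above):
import torch
import torch.nn.functional as F

def masked_soft_cross_entropy(logits: torch.Tensor, soft_labels: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # logits, soft_labels: (B*T, support_size); mask: (B*T,), 1 for valid steps and 0 for padding.
    loss = -(F.log_softmax(logits, dim=1) * soft_labels).sum(dim=1)
    return (loss * mask).mean()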
- mask_padding = rearrange(batch['mask_padding'], 'b t -> (b t)') - - loss_rewards = -(torch.log_softmax(logits_rewards, dim=1) * labels).sum(1) - # loss_rewards = (loss_rewards * mask_padding.squeeze(-1).float()).mean() - loss_rewards = (loss_rewards * mask_padding).mean() - - - return loss_rewards - - def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, - mask_padding: torch.BoolTensor) -> Tuple[ - torch.Tensor, torch.Tensor, torch.Tensor]: - assert torch.all(ends.sum(dim=1) <= 1) # each sequence sample has at most 1 done - mask_fill = torch.logical_not(mask_padding) - labels_observations = obs_embeddings.contiguous().view(rewards.shape[0], -1, self.projection_input_dim)[:, 1:] # self.projection_input_dim - - - # labels_rewards = (rewards.sign() + 1).masked_fill(mask_fill, -100).long() # Rewards clipped to {-1, 0, 1} TODO(pu) - - mask_fill_rewards = mask_fill.unsqueeze(-1).expand_as(rewards) - labels_rewards = rewards.masked_fill(mask_fill_rewards, -100) - - labels_ends = ends.masked_fill(mask_fill, -100) - return labels_observations, labels_rewards.reshape(-1, self.support_size), labels_ends.reshape(-1) - - def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, target_policy: torch.Tensor, - mask_padding: torch.BoolTensor) -> Tuple[ - torch.Tensor, torch.Tensor]: - - mask_fill = torch.logical_not(mask_padding) - mask_fill_policy = mask_fill.unsqueeze(-1).expand_as(target_policy) - labels_policy = target_policy.masked_fill(mask_fill_policy, -100) - - mask_fill_value = mask_fill.unsqueeze(-1).expand_as(target_value) - labels_value = target_value.masked_fill(mask_fill_value, -100) - return labels_policy.reshape(-1, self.action_shape), labels_value.reshape(-1, self.support_size) # TODO(pu) - - - def negative_cosine_similarity(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: - """ - Overview: - consistency loss function: the negative cosine similarity. - Arguments: - - x1 (:obj:`torch.Tensor`): shape (batch_size, dim), e.g. (256, 512) - - x2 (:obj:`torch.Tensor`): shape (batch_size, dim), e.g. (256, 512) - Returns: - (x1 * x2).sum(dim=1) is the cosine similarity between vector x1 and x2. - The cosine similarity always belongs to the interval [-1, 1]. - For example, two proportional vectors have a cosine similarity of 1, - two orthogonal vectors have a similarity of 0, - and two opposite vectors have a similarity of -1. - -(x1 * x2).sum(dim=1) is consistency loss, i.e. the negative cosine similarity. 
- Reference: - https://en.wikipedia.org/wiki/Cosine_similarity - """ - x1 = F.normalize(x1, p=2., dim=-1, eps=1e-5) - x2 = F.normalize(x2, p=2., dim=-1, eps=1e-5) - return -(x1 * x2).sum(dim=1) - - - def render_img(self, obs: int, rec_img: int): - import torch - from PIL import Image - import matplotlib.pyplot as plt - - # 假设batch是一个字典,其中包含了observations键, - # 并且它的形状是torch.Size([B, N, C, H, W]) - # batch_observations = batch_for_gpt['observations'] - # batch_observations = batch['observations'] - batch_observations = obs.unsqueeze(0) - # batch_observations = rec_img.unsqueeze(0) - - # batch_observations = observations.unsqueeze(0) - # batch_observations = x.unsqueeze(0) - # batch_observations = reconstructions.unsqueeze(0) - - - - B, N, C, H, W = batch_observations.shape # 自动检测维度 - - # 分隔条的宽度(可以根据需要调整) - separator_width = 2 - - # 遍历每个样本 - for i in range(B): - # 提取当前样本中的所有帧 - frames = batch_observations[i] - - # 计算拼接图像的总宽度(包括分隔条) - total_width = N * W + (N - 1) * separator_width - - # 创建一个新的图像,其中包含分隔条 - concat_image = Image.new('RGB', (total_width, H), color='black') - - # 拼接每一帧及分隔条 - for j in range(N): - frame = frames[j].permute(1, 2, 0).cpu().numpy() # 转换为(H, W, C) - frame_image = Image.fromarray((frame * 255).astype('uint8'), 'RGB') - - # 计算当前帧在拼接图像中的位置 - x_position = j * (W + separator_width) - concat_image.paste(frame_image, (x_position, 0)) - - # 显示图像 - plt.imshow(concat_image) - plt.title(f'Sample {i+1}') - plt.axis('off') # 关闭坐标轴显示 - plt.show() - - # 保存图像到文件 - concat_image.save(f'sample_{i+1}.png') \ No newline at end of file diff --git a/lzero/policy/muzero_gpt_1219.py b/lzero/policy/muzero_gpt_1219.py deleted file mode 100644 index 1126c4475..000000000 --- a/lzero/policy/muzero_gpt_1219.py +++ /dev/null @@ -1,1020 +0,0 @@ -import copy -from collections import defaultdict -from typing import List, Dict, Any, Tuple, Union - -import numpy as np -import torch -import torch.optim as optim -from ding.model import model_wrap -from ding.policy.base_policy import Policy -from ding.torch_utils import to_tensor -from ding.utils import POLICY_REGISTRY -from torch.distributions import Categorical -from torch.nn import L1Loss - -from lzero.mcts import MuZeroMCTSCtree as MCTSCtree -from lzero.mcts import MuZeroMCTSPtree as MCTSPtree -from lzero.model import ImageTransforms -from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ - DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, \ - prepare_obs, configure_optimizers - - -def configure_optimizer(model, learning_rate, weight_decay, exclude_submodules, *blacklist_module_names): - """Credits to https://github.com/karpathy/minGPT""" - # separate out all parameters to those that will and won't experience regularizing weight decay - decay = set() - no_decay = set() - whitelist_weight_modules = [torch.nn.Linear, torch.nn.Conv1d] - blacklist_weight_modules = [torch.nn.LayerNorm, torch.nn.Embedding] - - # Here, we make sure to exclude parameters from specified submodules when creating param_dict - param_dict = {} - for mn, m in model.named_modules(): - if any(mn.startswith(module_name) for module_name in exclude_submodules): - continue # skip parameters from excluded submodules - for pn, p in m.named_parameters(recurse=False): - fpn = f'{mn}.{pn}' if mn else pn # full param name - if not any(fpn.startswith(bl_module_name) for bl_module_name in blacklist_module_names): - param_dict[fpn] = p - if 'bias' in pn: - no_decay.add(fpn) - elif 
pn.endswith('weight') and isinstance(m, tuple(whitelist_weight_modules)): - decay.add(fpn) - elif pn.endswith('weight') and isinstance(m, tuple(blacklist_weight_modules)): - no_decay.add(fpn) - else: - decay.add(fpn) # Default behavior is to add to decay - - # Validate that we considered every parameter - inter_params = decay & no_decay - union_params = decay | no_decay - assert len(inter_params) == 0, f"parameters {str(inter_params)} made it into both decay/no_decay sets!" - assert len(param_dict.keys() - union_params) == 0, f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!" - - # Create the PyTorch optimizer object - optim_groups = [ - {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay}, - {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, - ] - optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate) - return optimizer - - - -@POLICY_REGISTRY.register('muzero_gpt') -class MuZeroGPTPolicy(Policy): - """ - Overview: - The policy class for MuZero. - """ - - # The default_config for MuZero policy. - config = dict( - model=dict( - # (str) The model type. For 1-dimensional vector obs, we use mlp model. For the image obs, we use conv model. - model_type='conv', # options={'mlp', 'conv'} - # (bool) If True, the action space of the environment is continuous, otherwise discrete. - continuous_action_space=False, - # (tuple) The stacked obs shape. - # observation_shape=(1, 96, 96), # if frame_stack_num=1 - observation_shape=(4, 96, 96), # if frame_stack_num=4 - # (bool) Whether to use the self-supervised learning loss. - self_supervised_learning_loss=False, - # (bool) Whether to use discrete support to represent categorical distribution for value/reward/value_prefix. - categorical_distribution=True, - # (int) The image channel in image observation. - image_channel=1, - # (int) The number of frames to stack together. - frame_stack_num=1, - # (int) The number of res blocks in MuZero model. - num_res_blocks=1, - # (int) The number of channels of hidden states in MuZero model. - num_channels=64, - # (int) The scale of supports used in categorical distribution. - # This variable is only effective when ``categorical_distribution=True``. - support_scale=300, - # (bool) whether to learn bias in the last linear layer in value and policy head. - bias=True, - # (str) The type of action encoding. Options are ['one_hot', 'not_one_hot']. Default to 'one_hot'. - discrete_action_encoding_type='one_hot', - # (bool) whether to use res connection in dynamics. - res_connection_in_dynamics=True, - # (str) The type of normalization in MuZero model. Options are ['BN', 'LN']. Default to 'LN'. - norm_type='BN', - ), - # ****** common ****** - # (bool) whether to use rnd model. - use_rnd_model=False, - # (bool) Whether to use multi-gpu training. - multi_gpu=False, - # (bool) Whether to enable the sampled-based algorithm (e.g. Sampled EfficientZero) - # this variable is used in ``collector``. - sampled_algo=False, - # (bool) Whether to enable the gumbel-based algorithm (e.g. Gumbel Muzero) - gumbel_algo=False, - # (bool) Whether to use C++ MCTS in policy. If False, use Python implementation. - mcts_ctree=True, - # (bool) Whether to use cuda for network. - cuda=True, - # (int) The number of environments used in collecting data. - collector_env_num=8, - # (int) The number of environments used in evaluating policy. - evaluator_env_num=3, - # (str) The type of environment. 
Options are ['not_board_games', 'board_games']. - env_type='not_board_games', - # (str) The type of battle mode. Options are ['play_with_bot_mode', 'self_play_mode']. - battle_mode='play_with_bot_mode', - # (bool) Whether to monitor extra statistics in tensorboard. - monitor_extra_statistics=True, - # (int) The transition number of one ``GameSegment``. - game_segment_length=200, - - # ****** observation ****** - # (bool) Whether to transform image to string to save memory. - transform2string=False, - # (bool) Whether to use gray scale image. - gray_scale=False, - # (bool) Whether to use data augmentation. - use_augmentation=False, - # (list) The style of augmentation. - augmentation=['shift', 'intensity'], - - # ******* learn ****** - # (bool) Whether to ignore the done flag in the training data. Typically, this value is set to False. - # However, for some environments with a fixed episode length, to ensure the accuracy of Q-value calculations, - # we should set it to True to avoid the influence of the done flag. - ignore_done=False, - # (int) How many updates(iterations) to train after collector's one collection. - # Bigger "update_per_collect" means bigger off-policy. - # collect data -> update policy-> collect data -> ... - # For different env, we have different episode_length, - # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor. - # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.model_update_ratio automatically. - update_per_collect=None, - # (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is not None. - model_update_ratio=0.1, - # (int) Minibatch size for one gradient descent. - batch_size=256, - # (str) Optimizer for training policy network. ['SGD', 'Adam'] - optim_type='SGD', - # (float) Learning rate for training policy network. Initial lr for manually decay schedule. - learning_rate=0.2, - # (int) Frequency of target network update. - target_update_freq=100, - # (int) Frequency of target network update. - target_update_freq_for_intrinsic_reward=1000, - # (float) Weight decay for training policy network. - weight_decay=1e-4, - # (float) One-order Momentum in optimizer, which stabilizes the training process (gradient direction). - momentum=0.9, - # (float) The maximum constraint value of gradient norm clipping. - grad_clip_value=10, - # (int) The number of episodes in each collecting stage. - n_episode=8, - # (int) the number of simulations in MCTS. - num_simulations=50, - # (float) Discount factor (gamma) for returns. - discount_factor=0.997, - # (int) The number of steps for calculating target q_value. - td_steps=5, - # (int) The number of unroll steps in dynamics network. - num_unroll_steps=5, - # (float) The weight of reward loss. - reward_loss_weight=1, - # (float) The weight of value loss. - value_loss_weight=0.25, - # (float) The weight of policy loss. - policy_loss_weight=1, - # (float) The weight of policy entropy loss. - policy_entropy_loss_weight=0, - # (float) The weight of ssl (self-supervised learning) loss. - ssl_loss_weight=0, - # (bool) Whether to use piecewise constant learning rate decay. - # i.e. lr: 0.2 -> 0.02 -> 0.002 - lr_piecewise_constant_decay=True, - # (int) The number of final training iterations to control lr decay, which is only used for manually decay. - threshold_training_steps_for_final_lr=int(5e4), - # (bool) Whether to use manually decayed temperature. 
- manual_temperature_decay=False, - # (int) The number of final training iterations to control temperature, which is only used for manually decay. - threshold_training_steps_for_final_temperature=int(1e5), - # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. - # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. - fixed_temperature_value=0.25, - # (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048. - use_ture_chance_label_in_chance_encoder=False, - - # ****** Priority ****** - # (bool) Whether to use priority when sampling training data from the buffer. - use_priority=True, - # (float) The degree of prioritization to use. A value of 0 means no prioritization, - # while a value of 1 means full prioritization. - priority_prob_alpha=0.6, - # (float) The degree of correction to use. A value of 0 means no correction, - # while a value of 1 means full correction. - priority_prob_beta=0.4, - - # ****** UCB ****** - # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of search tree. - root_dirichlet_alpha=0.3, - # (float) The noise weight at the root node of the search tree. - root_noise_weight=0.25, - - # ****** Explore by random collect ****** - # (int) The number of episodes to collect data randomly before training. - random_collect_episode_num=0, - - # ****** Explore by eps greedy ****** - eps=dict( - # (bool) Whether to use eps greedy exploration in collecting data. - eps_greedy_exploration_in_collect=False, - # (str) The type of decaying epsilon. Options are 'linear', 'exp'. - type='linear', - # (float) The start value of eps. - start=1., - # (float) The end value of eps. - end=0.05, - # (int) The decay steps from start to end eps. - decay=int(1e5), - ), - ) - - def default_model(self) -> Tuple[str, List[str]]: - """ - Overview: - Return this algorithm default model setting for demonstration. - Returns: - - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names. - - model_type (:obj:`str`): The model type used in this algorithm, which is registered in ModelRegistry. - - import_names (:obj:`List[str]`): The model class path list used in this algorithm. - .. note:: - The user can define and use customized network model but must obey the same interface definition indicated \ - by import_names path. For MuZero, ``lzero.model.muzero_gpt_model.MuZeroModel`` - """ - if self._cfg.model.model_type == "conv": - # return 'MuZeroModel', ['lzero.model.muzero_gpt_model'] - return 'MuZeroModelGPT', ['lzero.model.muzero_gpt_model'] - elif self._cfg.model.model_type == "mlp": - return 'MuZeroModelGPT', ['lzero.model.muzero_gpt_model_vector_obs'] - else: - raise ValueError("model type {} is not supported".format(self._cfg.model.model_type)) - - def _init_learn(self) -> None: - """ - Overview: - Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils. - """ - # assert self._cfg.optim_type in ['SGD', 'Adam', 'AdamW'], self._cfg.optim_type - # # NOTE: in board_games, for fixed lr 0.003, 'Adam' is better than 'SGD'. 
- # if self._cfg.optim_type == 'SGD': - # self._optimizer = optim.SGD( - # self._model.parameters(), - # lr=self._cfg.learning_rate, - # momentum=self._cfg.momentum, - # weight_decay=self._cfg.weight_decay, - # ) - # elif self._cfg.optim_type == 'Adam': - # self._optimizer = optim.Adam( - # self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay - # ) - # elif self._cfg.optim_type == 'AdamW': - # self._optimizer = configure_optimizers( - # model=self._model, - # weight_decay=self._cfg.weight_decay, - # learning_rate=self._cfg.learning_rate, - # device_type=self._cfg.device - # ) - - # self._optimizer_tokenizer = optim.Adam( - # self._model.tokenizer.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay - # ) - - # self._optimizer_tokenizer = optim.Adam( - # self._model.tokenizer.parameters(), lr=self._cfg.learning_rate # weight_decay=0 - # ) - - # # TODO: nanoGPT optimizer - # self._optimizer_world_model = configure_optimizer( - # model=self._model.world_model, - # learning_rate=self._cfg.learning_rate, - # weight_decay=self._cfg.weight_decay, - # # weight_decay=0.01, - # exclude_submodules=['tokenizer'] - # ) - - self._optimizer_tokenizer = optim.Adam( - self._model.tokenizer.parameters(), lr=1e-4 # weight_decay=0 - ) - - # TODO: nanoGPT optimizer - self._optimizer_world_model = configure_optimizer( - model=self._model.world_model, - learning_rate=3e-3, - weight_decay=self._cfg.weight_decay, - # weight_decay=0.01, - exclude_submodules=['tokenizer'] - ) - - # self._optimizer_world_model = configure_optimizers( - # model=self._model.world_model, - # weight_decay=self._cfg.weight_decay, - # learning_rate=self._cfg.learning_rate, - # device_type=self._cfg.device - # ) - - # if self._cfg.lr_piecewise_constant_decay: - # from torch.optim.lr_scheduler import LambdaLR - # max_step = self._cfg.threshold_training_steps_for_final_lr - # # NOTE: the 1, 0.1, 0.01 is the decay rate, not the lr. 
- # lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01) # noqa - # self.lr_scheduler = LambdaLR(self._optimizer, lr_lambda=lr_lambda) - - # use model_wrapper for specialized demands of different modes - self._target_model = copy.deepcopy(self._model) - - # TODO: torch 2.0 - self._model = torch.compile(self._model) - self._target_model = torch.compile(self._target_model) - - - self._target_model = model_wrap( - self._target_model, - wrapper_name='target', - update_type='assign', - update_kwargs={'freq': self._cfg.target_update_freq} - ) - self._learn_model = self._model - - - # TODO: only for debug - # for param in self._learn_model.tokenizer.parameters(): - # param.requires_grad = False - - if self._cfg.use_augmentation: - self.image_transforms = ImageTransforms( - self._cfg.augmentation, - image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) - ) - self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) - self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) - self.inverse_scalar_transform_handle = InverseScalarTransform( - self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution - ) - - if self._cfg.use_rnd_model: - if self._cfg.target_model_for_intrinsic_reward_update_type == 'assign': - self._target_model_for_intrinsic_reward = model_wrap( - self._target_model, - wrapper_name='target', - update_type='assign', - update_kwargs={'freq': self._cfg.target_update_freq_for_intrinsic_reward} - ) - elif self._cfg.target_model_for_intrinsic_reward_update_type == 'momentum': - self._target_model_for_intrinsic_reward = model_wrap( - self._target_model, - wrapper_name='target', - update_type='momentum', - update_kwargs={'theta': self._cfg.target_update_theta_for_intrinsic_reward} - ) - - def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]: - """ - Overview: - The forward function for learning policy in learn mode, which is the core of the learning process. - The data is sampled from replay buffer. - The loss is calculated by the loss function and the loss is backpropagated to update the model. - Arguments: - - data (:obj:`Tuple[torch.Tensor]`): The data sampled from replay buffer, which is a tuple of tensors. - The first tensor is the current_batch, the second tensor is the target_batch. - Returns: - - info_dict (:obj:`Dict[str, Union[float, int]]`): The information dict to be logged, which contains \ - current learning loss and learning statistics. - """ - # current_batch, target_batch, train_which_component_dict = data - if data[-1]['train_which_component'] == 'transformer': - return_loss_dict = self._forward_learn_transformer(data) - elif data[-1]['train_which_component'] == 'tokenizer': - return_loss_dict = self._forward_learn_tokenizer(data) - else: - raise ValueError('Unknown component type') - - return return_loss_dict - - - def _forward_learn_transformer(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]: - """ - Overview: - The forward function for learning policy in learn mode, which is the core of the learning process. - The data is sampled from replay buffer. - The loss is calculated by the loss function and the loss is backpropagated to update the model. - Arguments: - - data (:obj:`Tuple[torch.Tensor]`): The data sampled from replay buffer, which is a tuple of tensors.
- The first tensor is the current_batch, the second tensor is the target_batch. - Returns: - - info_dict (:obj:`Dict[str, Union[float, int]]`): The information dict to be logged, which contains \ - current learning loss and learning statistics. - """ - self._learn_model.train() - self._target_model.train() - self._learn_model.tokenizer.eval() - - if self._cfg.use_rnd_model: - self._target_model_for_intrinsic_reward.train() - - # current_batch, target_batch = data - current_batch, target_batch, train_which_component_dict = data - - - obs_batch_ori, action_batch, mask_batch, indices, weights, make_time = current_batch - target_reward, target_value, target_policy = target_batch - - obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) - - # do augmentations - if self._cfg.use_augmentation: - obs_batch = self.image_transforms.transform(obs_batch) - if self._cfg.model.self_supervised_learning_loss: - obs_target_batch = self.image_transforms.transform(obs_target_batch) - - # shape: (batch_size, num_unroll_steps, action_dim) - # NOTE: .long(), in discrete action space. - action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(-1).long() - data_list = [ - mask_batch, - target_reward.astype('float32'), - target_value.astype('float32'), target_policy, weights - ] - [mask_batch, target_reward, target_value, target_policy, - weights] = to_torch_float_tensor(data_list, self._cfg.device) - - target_reward = target_reward.view(self._cfg.batch_size, -1) - target_value = target_value.view(self._cfg.batch_size, -1) - - assert obs_batch.size(0) == self._cfg.batch_size == target_reward.size(0) - - # ``scalar_transform`` to transform the original value to the scaled value, - # i.e. h(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. - transformed_target_reward = scalar_transform(target_reward) - transformed_target_value = scalar_transform(target_value) - - # transform a scalar to its categorical_distribution. After this transformation, each scalar is - # represented as the linear combination of its two adjacent supports. 
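For reference, the two helpers named in the comments above behave roughly as follows: ``scalar_transform`` applies the invertible rescaling h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x from the linked paper, and ``phi_transform`` then spreads each rescaled scalar over its two neighbouring integer supports ("two-hot" targets). A minimal standalone sketch, assuming eps = 0.001 and an integer-spaced support; the exact implementation lives in lzero.policy and may differ in detail:

import torch

def h(x: torch.Tensor, eps: float = 0.001) -> torch.Tensor:
    # Invertible value rescaling (the idea behind ``scalar_transform``); eps = 0.001 is an assumed default.
    return torch.sign(x) * (torch.sqrt(torch.abs(x) + 1) - 1) + eps * x

def two_hot(x: torch.Tensor, support_min: int, support_max: int) -> torch.Tensor:
    # Spread each scalar over its two adjacent integer supports (the idea behind ``phi_transform``).
    x = x.clamp(support_min, support_max)
    low = x.floor()
    w_high = x - low                      # weight assigned to the upper neighbour
    size = support_max - support_min + 1
    low_idx = (low - support_min).long()
    high_idx = (low_idx + 1).clamp(max=size - 1)
    probs = torch.zeros(*x.shape, size)
    probs.scatter_(-1, low_idx.unsqueeze(-1), (1 - w_high).unsqueeze(-1))
    probs.scatter_add_(-1, high_idx.unsqueeze(-1), w_high.unsqueeze(-1))
    return probs

# Example: two_hot(torch.tensor([3.7]), -300, 300) puts weight 0.3 on support 3 and 0.7 on support 4.
# In the hunk below, the two-hot projection is applied to the h-scaled reward/value targets.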
- target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) - target_value_categorical = phi_transform(self.value_support, transformed_target_value) - - # compute_loss(self, batch: Batch, tokenizer: Tokenizer, ** kwargs: Any) - - batch_for_gpt = {} - # TODO: for cartpole self._cfg.model.observation_shape - if isinstance(self._cfg.model.observation_shape, int) or len(self._cfg.model.observation_shape)==1: - batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( self._cfg.batch_size, -1, self._cfg.model.observation_shape) # (B, T, O) or (B, T, C, H, W) - elif len(self._cfg.model.observation_shape)==3: - batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( self._cfg.batch_size, -1, *self._cfg.model.observation_shape) # (B, T, O) or (B, T, C, H, W) - - - batch_for_gpt['actions'] = action_batch.squeeze(-1) # (B, T-1, A) -> (B, T-1) - - batch_for_gpt['rewards'] = target_reward_categorical[:, :-1] # (B, T, R) -> (B, T-1, R) - - batch_for_gpt['mask_padding'] = mask_batch == 1.0 # (B, T) NOTE: 0 means invalid padding data - batch_for_gpt['mask_padding'] = batch_for_gpt['mask_padding'][:, :-1] # (B, T-1) TODO - - - batch_for_gpt['observations'] = batch_for_gpt['observations'][:, :-1] # (B, T-1, O) or (B, T-1, C, H, W) - batch_for_gpt['ends'] = torch.zeros(batch_for_gpt['mask_padding'].shape, dtype=torch.long, device=self._cfg.device) # (B, T-1) - - batch_for_gpt['target_value'] = target_value_categorical[:, :-1] # (B, T-1, V) - batch_for_gpt['target_policy'] = target_policy[:, :-1] # (B, T-1, A) - # NOTE: TODO: next latent state's policy value - # batch_for_gpt['target_value'] = target_value_categorical[:, 1:] # (B, T-1, V) - # batch_for_gpt['target_policy'] = target_policy[:, 1:] # (B, T-1, A) - - # self._learn_model.world_model.train() - - # get valid target_policy data - valid_target_policy = batch_for_gpt['target_policy'][batch_for_gpt['mask_padding']] - # compute entropy of each policy - target_policy_entropy = -torch.sum(valid_target_policy * torch.log(valid_target_policy + 1e-9), dim=-1) - # compute average entropy - average_target_policy_entropy = target_policy_entropy.mean().item() - # print(f'Average entropy: {average_entropy}') - - - # if train_which_component_dict['train_which_component'] == 'transformer': - # ============================================================== - # update world model - # ============================================================== - intermediate_losses = defaultdict(float) - losses = self._learn_model.world_model.compute_loss(batch_for_gpt, self._learn_model.tokenizer) - weighted_total_loss = losses.loss_total - for loss_name, loss_value in losses.intermediate_losses.items(): - intermediate_losses[f"{loss_name}"] = loss_value - # print(intermediate_losses) - obs_loss = intermediate_losses['loss_obs'] - reward_loss = intermediate_losses['loss_rewards'] - policy_loss = intermediate_losses['loss_policy'] - value_loss = intermediate_losses['loss_value'] - - # ============================================================== - # the core learn model update step. - # ============================================================== - """ - for name, parameter in self._learn_model.tokenizer.named_parameters(): - print(name) - """ - gradient_scale = 1 / self._cfg.num_unroll_steps - # TODO(pu): test the effect of gradient scale. 
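The hook registered on the next line rescales the gradient at ``weighted_total_loss`` by ``gradient_scale`` = 1 / num_unroll_steps, which effectively multiplies every parameter gradient produced by this loss by that factor; this appears intended to keep the update magnitude roughly independent of the unroll length, similar to the gradient scaling used in MuZero-style training. A toy standalone sketch of the mechanism (illustrative only, not the original code):

import torch

x = torch.ones(3, requires_grad=True)
loss = (2 * x).sum()
loss.register_hook(lambda grad: grad * 0.2)  # shrink the upstream gradient, as the gradient_scale hook does
loss.backward()
print(x.grad)  # tensor([0.4000, 0.4000, 0.4000]) rather than tensor([2., 2., 2.])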
- weighted_total_loss.register_hook(lambda grad: grad * gradient_scale) - self._optimizer_world_model.zero_grad() - weighted_total_loss.backward() - if self._cfg.multi_gpu: - self.sync_gradients(self._learn_model) - total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_( - self._learn_model.world_model.parameters(), self._cfg.grad_clip_value - ) - self._optimizer_world_model.step() - if self._cfg.lr_piecewise_constant_decay: - self.lr_scheduler.step() - - - # ============================================================== - # the core target model update step. - # ============================================================== - self._target_model.update(self._learn_model.state_dict()) - if self._cfg.use_rnd_model: - self._target_model_for_intrinsic_reward.update(self._learn_model.state_dict()) - - - return_loss_dict = { - 'collect_mcts_temperature': self._collect_mcts_temperature, - 'collect_epsilon': self.collect_epsilon, - 'cur_lr_world_model': self._optimizer_world_model.param_groups[0]['lr'], - 'cur_lr_tokenizer': self._optimizer_tokenizer.param_groups[0]['lr'], - - 'weighted_total_loss': weighted_total_loss.item(), - 'obs_loss': obs_loss, - 'policy_loss': policy_loss, - 'target_policy_entropy': average_target_policy_entropy, - # 'policy_entropy': - policy_entropy_loss.mean().item() / (self._cfg.num_unroll_steps + 1), - 'reward_loss': reward_loss, - 'value_loss': value_loss, - # 'consistency_loss': consistency_loss.mean().item() / self._cfg.num_unroll_steps, - - # ============================================================== - # priority related - # ============================================================== - # 'value_priority_orig': value_priority, - 'value_priority_orig': np.zeros(self._cfg.batch_size), # TODO - # 'value_priority': value_priority.mean().item(), - 'target_reward': target_reward.detach().cpu().numpy().mean().item(), - 'target_value': target_value.detach().cpu().numpy().mean().item(), - 'transformed_target_reward': transformed_target_reward.detach().cpu().numpy().mean().item(), - 'transformed_target_value': transformed_target_value.detach().cpu().numpy().mean().item(), - # 'predicted_rewards': predicted_rewards.detach().cpu().numpy().mean().item(), - # 'predicted_values': predicted_values.detach().cpu().numpy().mean().item(), - 'total_grad_norm_before_clip_wm': total_grad_norm_before_clip_wm.item(), - } - - return return_loss_dict - - - def _forward_learn_tokenizer(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]: - """ - Overview: - The forward function for learning policy in learn mode, which is the core of the learning process. - The data is sampled from replay buffer. - The loss is calculated by the loss function and the loss is backpropagated to update the model. - Arguments: - - data (:obj:`Tuple[torch.Tensor]`): The data sampled from replay buffer, which is a tuple of tensors. - The first tensor is the current_batch, the second tensor is the target_batch. - Returns: - - info_dict (:obj:`Dict[str, Union[float, int]]`): The information dict to be logged, which contains \ - current learning loss and learning statistics. 
- """ - self._learn_model.train() - self._target_model.train() - if self._cfg.use_rnd_model: - self._target_model_for_intrinsic_reward.train() - - # current_batch, target_batch = data - current_batch, target_batch, train_which_component_dict = data - - - obs_batch_ori, action_batch, mask_batch, indices, weights, make_time = current_batch - target_reward, target_value, target_policy = target_batch - - obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) - - # do augmentations - if self._cfg.use_augmentation: - obs_batch = self.image_transforms.transform(obs_batch) - if self._cfg.model.self_supervised_learning_loss: - obs_target_batch = self.image_transforms.transform(obs_target_batch) - - # shape: (batch_size, num_unroll_steps, action_dim) - # NOTE: .long(), in discrete action space. - # action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(-1).long() - # data_list = [ - # mask_batch, - # target_reward.astype('float32'), - # target_value.astype('float32'), target_policy, weights - # ] - - # [mask_batch, target_reward, target_value, target_policy, - # weights] = to_torch_float_tensor(data_list, self._cfg.device) - - # target_reward = target_reward.view(self._cfg.batch_size, -1) - # target_value = target_value.view(self._cfg.batch_size, -1) - - # assert obs_batch.size(0) == self._cfg.batch_size == target_reward.size(0) - - - batch_for_gpt = {} - # TODO: for cartpole self._cfg.model.observation_shape - if isinstance(self._cfg.model.observation_shape, int) or len(self._cfg.model.observation_shape)==1: - batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( self._cfg.batch_size, -1, self._cfg.model.observation_shape) # (B, T, O) or (B, T, C, H, W) - elif len(self._cfg.model.observation_shape)==3: - batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( self._cfg.batch_size, -1, *self._cfg.model.observation_shape) # (B, T, O) or (B, T, C, H, W) - - batch_for_gpt['observations'] = batch_for_gpt['observations'][:, :-1] # (B, T-1, O) or (B, T-1, C, H, W) - - # if train_which_component_dict['train_which_component'] == 'tokenizer': - - # ============================================================== - # update tokenizer - # ============================================================== - # TODO: train tokenlizer - self._learn_model.tokenizer.train() - - # for name, param in self._learn_model.tokenizer.named_parameters(): - # if param.requires_grad: - # print(name, param.shape) - - losses_tokenizer = self._learn_model.tokenizer.compute_loss(batch_for_gpt) - - self._optimizer_tokenizer.zero_grad() - - weighted_total_loss_tokenizer = losses_tokenizer.loss_total - weighted_total_loss_tokenizer.backward() - # losses_tokenizer.loss_total.backward() - total_grad_norm_before_clip_tokenizer = torch.nn.utils.clip_grad_norm_( - self._learn_model.tokenizer.parameters(), self._cfg.grad_clip_value - ) - self._optimizer_tokenizer.step() - - intermediate_losses_tokenizer= defaultdict(float) - for loss_name, loss_value in losses_tokenizer.intermediate_losses.items(): - intermediate_losses_tokenizer[f"{loss_name}"] = loss_value - # print(intermediate_losses) - commitment_loss= intermediate_losses_tokenizer['commitment_loss'] - reconstruction_loss = intermediate_losses_tokenizer['reconstruction_loss'] - perceptual_loss = intermediate_losses_tokenizer['perceptual_loss'] - - - # # ============================================================== - # # the core target model update step. 
- # # ============================================================== - # self._target_model.update(self._learn_model.state_dict()) - # if self._cfg.use_rnd_model: - # self._target_model_for_intrinsic_reward.update(self._learn_model.state_dict()) - - - return_loss_dict = { - 'collect_mcts_temperature': self._collect_mcts_temperature, - 'collect_epsilon': self.collect_epsilon, - 'cur_lr_world_model': self._optimizer_world_model.param_groups[0]['lr'], - 'cur_lr_tokenizer': self._optimizer_tokenizer.param_groups[0]['lr'], - - # 'weighted_total_loss': weighted_total_loss.item(), - # 'obs_loss': obs_loss, - # 'policy_loss': policy_loss, - # 'target_policy_entropy': average_target_policy_entropy, - # 'policy_entropy': - policy_entropy_loss.mean().item() / (self._cfg.num_unroll_steps + 1), - # 'reward_loss': reward_loss, - # 'value_loss': value_loss, - # 'consistency_loss': consistency_loss.mean().item() / self._cfg.num_unroll_steps, - - # ============================================================== - # priority related - # ============================================================== - # 'value_priority_orig': value_priority, - # 'value_priority_orig': np.zeros(self._cfg.batch_size), # TODO - # 'value_priority': value_priority.mean().item(), - # 'target_reward': target_reward.detach().cpu().numpy().mean().item(), - # 'target_value': target_value.detach().cpu().numpy().mean().item(), - # 'transformed_target_reward': transformed_target_reward.detach().cpu().numpy().mean().item(), - # 'transformed_target_value': transformed_target_value.detach().cpu().numpy().mean().item(), - # 'predicted_rewards': predicted_rewards.detach().cpu().numpy().mean().item(), - # 'predicted_values': predicted_values.detach().cpu().numpy().mean().item(), - 'total_grad_norm_before_clip_tokenizer': total_grad_norm_before_clip_tokenizer.item(), - 'commitment_loss':commitment_loss, - 'reconstruction_loss':reconstruction_loss, - 'perceptual_loss': perceptual_loss, - } - - return return_loss_dict - - def _init_collect(self) -> None: - """ - Overview: - Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. - """ - self._collect_model = self._model - - if self._cfg.mcts_ctree: - self._mcts_collect = MCTSCtree(self._cfg) - else: - self._mcts_collect = MCTSPtree(self._cfg) - self._collect_mcts_temperature = 1. - self.collect_epsilon = 0.0 - - def _forward_collect( - self, - data: torch.Tensor, - action_mask: list = None, - temperature: float = 1, - to_play: List = [-1], - epsilon: float = 0.25, - ready_env_id: np.array = None, - ) -> Dict: - """ - Overview: - The forward function for collecting data in collect mode. Use model to execute MCTS search. - Choosing the action through sampling during the collect mode. - Arguments: - - data (:obj:`torch.Tensor`): The input data, i.e. the observation. - - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. - - temperature (:obj:`float`): The temperature of the policy. - - to_play (:obj:`int`): The player to play. - - ready_env_id (:obj:`list`): The id of the env that is ready to collect. - Shape: - - data (:obj:`torch.Tensor`): - - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ - S is the number of stacked frames, H is the height of the image, W is the width of the image. - - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. 
- - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. - - temperature: :math:`(1, )`. - - to_play: :math:`(N, 1)`, where N is the number of collect_env. - - ready_env_id: None - Returns: - - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ - ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. - """ - self._collect_model.eval() - self._collect_mcts_temperature = temperature - self.collect_epsilon = epsilon - active_collect_env_num = data.shape[0] - # if active_collect_env_num == 1: - # print('debug') - with torch.no_grad(): - network_output = self._collect_model.initial_inference(data) - # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)} - latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) - - pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() - latent_state_roots = latent_state_roots.detach().cpu().numpy() - policy_logits = policy_logits.detach().cpu().numpy().tolist() - - legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)] - # the only difference between collect and eval is the dirichlet noise - noises = [ - np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) - ).astype(np.float32).tolist() for j in range(active_collect_env_num) - ] - if self._cfg.mcts_ctree: - # cpp mcts_tree - roots = MCTSCtree.roots(active_collect_env_num, legal_actions) - else: - # python mcts_tree - roots = MCTSPtree.roots(active_collect_env_num, legal_actions) - - roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) - self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play) - - # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` - roots_visit_count_distributions = roots.get_distributions() - roots_values = roots.get_values() # shape: {list: batch_size} - - data_id = [i for i in range(active_collect_env_num)] - output = {i: None for i in data_id} - - if ready_env_id is None: - ready_env_id = np.arange(active_collect_env_num) - - for i, env_id in enumerate(ready_env_id): - distributions, value = roots_visit_count_distributions[i], roots_values[i] - if self._cfg.eps.eps_greedy_exploration_in_collect: - # eps greedy collect - action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( - distributions, temperature=self._collect_mcts_temperature, deterministic=True - ) - action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - if np.random.rand() < self.collect_epsilon: - action = np.random.choice(legal_actions[i]) - else: - # normal collect - # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents - # the index within the legal action set, rather than the index in the entire action set. - action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( - distributions, temperature=self._collect_mcts_temperature, deterministic=False - ) - # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. 
- action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - output[env_id] = { - 'action': action, - 'visit_count_distributions': distributions, - 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'searched_value': value, - 'predicted_value': pred_values[i], - 'predicted_policy_logits': policy_logits[i], - } - - return output - - - def _init_eval(self) -> None: - """ - Overview: - Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils. - """ - self._eval_model = self._model - if self._cfg.mcts_ctree: - self._mcts_eval = MCTSCtree(self._cfg) - else: - self._mcts_eval = MCTSPtree(self._cfg) - - def _get_target_obs_index_in_step_k(self, step): - """ - Overview: - Get the begin index and end index of the target obs in step k. - Arguments: - - step (:obj:`int`): The current step k. - Returns: - - beg_index (:obj:`int`): The begin index of the target obs in step k. - - end_index (:obj:`int`): The end index of the target obs in step k. - Examples: - >>> self._cfg.model.model_type = 'conv' - >>> self._cfg.model.image_channel = 3 - >>> self._cfg.model.frame_stack_num = 4 - >>> self._get_target_obs_index_in_step_k(0) - >>> (0, 12) - """ - if self._cfg.model.model_type == 'conv': - beg_index = self._cfg.model.image_channel * step - end_index = self._cfg.model.image_channel * (step + self._cfg.model.frame_stack_num) - elif self._cfg.model.model_type == 'mlp': - beg_index = self._cfg.model.observation_shape * step - end_index = self._cfg.model.observation_shape * (step + self._cfg.model.frame_stack_num) - return beg_index, end_index - - def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id: np.array = None,) -> Dict: - """ - Overview: - The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. - Choosing the action with the highest value (argmax) rather than sampling during the eval mode. - Arguments: - - data (:obj:`torch.Tensor`): The input data, i.e. the observation. - - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. - - to_play (:obj:`int`): The player to play. - - ready_env_id (:obj:`list`): The id of the env that is ready to collect. - Shape: - - data (:obj:`torch.Tensor`): - - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ - S is the number of stacked frames, H is the height of the image, W is the width of the image. - - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. - - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. - - to_play: :math:`(N, 1)`, where N is the number of collect_env. - - ready_env_id: None - Returns: - - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ - ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. - """ - self._eval_model.eval() - active_eval_env_num = data.shape[0] - with torch.no_grad(): - # data shape [B, S x C, W, H], e.g. 
{Tensor:(B, 12, 96, 96)} - network_output = self._collect_model.initial_inference(data) - latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) - - if not self._eval_model.training: - # if not in training, obtain the scalars of the value/reward - pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) - latent_state_roots = latent_state_roots.detach().cpu().numpy() - policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) - - legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)] - if self._cfg.mcts_ctree: - # cpp mcts_tree - roots = MCTSCtree.roots(active_eval_env_num, legal_actions) - else: - # python mcts_tree - roots = MCTSPtree.roots(active_eval_env_num, legal_actions) - roots.prepare_no_noise(reward_roots, policy_logits, to_play) - self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play) - - # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` - roots_visit_count_distributions = roots.get_distributions() - roots_values = roots.get_values() # shape: {list: batch_size} - - data_id = [i for i in range(active_eval_env_num)] - output = {i: None for i in data_id} - - if ready_env_id is None: - ready_env_id = np.arange(active_eval_env_num) - - for i, env_id in enumerate(ready_env_id): - distributions, value = roots_visit_count_distributions[i], roots_values[i] - # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents - # the index within the legal action set, rather than the index in the entire action set. - # Setting deterministic=True implies choosing the action with the highest value (argmax) rather than - # sampling during the evaluation phase. - action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( - distributions, temperature=1, deterministic=True - ) - # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the - # entire action set. - action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - - output[env_id] = { - 'action': action, - 'visit_count_distributions': distributions, - 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'searched_value': value, - 'predicted_value': pred_values[i], - 'predicted_policy_logits': policy_logits[i], - } - - return output - - def _monitor_vars_learn(self) -> List[str]: - """ - Overview: - Register the variables to be monitored in learn mode. The registered variables will be logged in - tensorboard according to the return value ``_forward_learn``. - """ - return [ - 'collect_epsilon', - 'collect_mcts_temperature', - # 'cur_lr', - 'cur_lr_world_model', - 'cur_lr_tokenizer', - - 'weighted_total_loss', - # 'total_loss', - 'obs_loss', - 'policy_loss', - # 'policy_entropy', - 'target_policy_entropy', - 'reward_loss', - 'value_loss', - 'consistency_loss', - 'value_priority', - 'target_reward', - 'target_value', - # 'predicted_rewards', - # 'predicted_values', - # 'transformed_target_reward', - # 'transformed_target_value', - 'total_grad_norm_before_clip_tokenizer', - 'total_grad_norm_before_clip_wm', - # tokenizer - 'commitment_loss', - 'reconstruction_loss', - 'perceptual_loss', - ] - - def _state_dict_learn(self) -> Dict[str, Any]: - """ - Overview: - Return the state_dict of learn mode, usually including model, target_model and optimizer. 
- Returns: - - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring. - """ - return { - 'model': self._learn_model.state_dict(), - 'target_model': self._target_model.state_dict(), - 'optimizer_world_model': self._optimizer_world_model.state_dict(), - 'optimizer_tokenizer': self._optimizer_tokenizer.state_dict(), - - } - - def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: - """ - Overview: - Load the state_dict variable into policy learn mode. - Arguments: - - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. - """ - self._learn_model.load_state_dict(state_dict['model']) - self._target_model.load_state_dict(state_dict['target_model']) - self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model']) - self._optimizer_tokenizer.load_state_dict(state_dict['optimizer_tokenizer']) - - def _process_transition(self, obs, policy_output, timestep): - # be compatible with DI-engine Policy class - pass - - def _get_train_sample(self, data): - # be compatible with DI-engine Policy class - pass diff --git a/lzero/policy/muzero_gpt_bkp20240304.py b/lzero/policy/muzero_gpt_bkp20240304.py deleted file mode 100644 index 135bccf99..000000000 --- a/lzero/policy/muzero_gpt_bkp20240304.py +++ /dev/null @@ -1,1103 +0,0 @@ -import copy -from collections import defaultdict -from typing import List, Dict, Any, Tuple, Union - -import numpy as np -import torch -import torch.optim as optim -from ding.model import model_wrap -from ding.policy.base_policy import Policy -from ding.torch_utils import to_tensor -from ding.utils import POLICY_REGISTRY -from torch.distributions import Categorical -from torch.nn import L1Loss - -from lzero.mcts import MuZeroMCTSCtree as MCTSCtree -from lzero.mcts import MuZeroMCTSPtree as MCTSPtree -from lzero.model import ImageTransforms -from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ - DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, \ - prepare_obs, prepare_obs_for_gpt, configure_optimizers - - -def configure_optimizer(model, learning_rate, weight_decay, exclude_submodules, *blacklist_module_names): - """Credits to https://github.com/karpathy/minGPT""" - # separate out all parameters to those that will and won't experience regularizing weight decay - decay = set() - no_decay = set() - whitelist_weight_modules = [torch.nn.Linear, torch.nn.Conv1d] - blacklist_weight_modules = [torch.nn.LayerNorm, torch.nn.Embedding] - - # Here, we make sure to exclude parameters from specified submodules when creating param_dict - param_dict = {} - for mn, m in model.named_modules(): - # if any(mn.startswith(module_name) for module_name in exclude_submodules): - # continue # skip parameters from excluded submodules - for pn, p in m.named_parameters(recurse=False): - fpn = f'{mn}.{pn}' if mn else pn # full param name - if not any(fpn.startswith(bl_module_name) for bl_module_name in blacklist_module_names): - param_dict[fpn] = p - if 'bias' in pn: - no_decay.add(fpn) - elif pn.endswith('weight') and isinstance(m, tuple(whitelist_weight_modules)): - decay.add(fpn) - elif pn.endswith('weight') and isinstance(m, tuple(blacklist_weight_modules)): - no_decay.add(fpn) - else: - decay.add(fpn) # Default behavior is to add to decay - - # Validate that we considered every parameter - inter_params = decay & no_decay - union_params = decay | no_decay - assert len(inter_params) == 0, 
f"parameters {str(inter_params)} made it into both decay/no_decay sets!" - assert len(param_dict.keys() - union_params) == 0, f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!" - - # Create the PyTorch optimizer object - optim_groups = [ - {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay}, - {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, - ] - optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate) - return optimizer - - - -@POLICY_REGISTRY.register('muzero_gpt') -class MuZeroGPTPolicy(Policy): - """ - Overview: - The policy class for MuZero. - """ - - # The default_config for MuZero policy. - config = dict( - model=dict( - # (str) The model type. For 1-dimensional vector obs, we use mlp model. For the image obs, we use conv model. - model_type='conv', # options={'mlp', 'conv'} - # (bool) If True, the action space of the environment is continuous, otherwise discrete. - continuous_action_space=False, - # (tuple) The stacked obs shape. - # observation_shape=(1, 96, 96), # if frame_stack_num=1 - observation_shape=(4, 96, 96), # if frame_stack_num=4 - # (bool) Whether to use the self-supervised learning loss. - self_supervised_learning_loss=False, - # (bool) Whether to use discrete support to represent categorical distribution for value/reward/value_prefix. - categorical_distribution=True, - # (int) The image channel in image observation. - image_channel=1, - # (int) The number of frames to stack together. - frame_stack_num=1, - # (int) The number of res blocks in MuZero model. - num_res_blocks=1, - # (int) The number of channels of hidden states in MuZero model. - num_channels=64, - # (int) The scale of supports used in categorical distribution. - # This variable is only effective when ``categorical_distribution=True``. - support_scale=300, - # (bool) whether to learn bias in the last linear layer in value and policy head. - bias=True, - # (str) The type of action encoding. Options are ['one_hot', 'not_one_hot']. Default to 'one_hot'. - discrete_action_encoding_type='one_hot', - # (bool) whether to use res connection in dynamics. - res_connection_in_dynamics=True, - # (str) The type of normalization in MuZero model. Options are ['BN', 'LN']. Default to 'LN'. - norm_type='BN', - ), - # ****** common ****** - # (bool) whether to use rnd model. - use_rnd_model=False, - # (bool) Whether to use multi-gpu training. - multi_gpu=False, - # (bool) Whether to enable the sampled-based algorithm (e.g. Sampled EfficientZero) - # this variable is used in ``collector``. - sampled_algo=False, - # (bool) Whether to enable the gumbel-based algorithm (e.g. Gumbel Muzero) - gumbel_algo=False, - # (bool) Whether to use C++ MCTS in policy. If False, use Python implementation. - mcts_ctree=True, - # (bool) Whether to use cuda for network. - cuda=True, - # (int) The number of environments used in collecting data. - collector_env_num=8, - # (int) The number of environments used in evaluating policy. - evaluator_env_num=3, - # (str) The type of environment. Options are ['not_board_games', 'board_games']. - env_type='not_board_games', - # (str) The type of battle mode. Options are ['play_with_bot_mode', 'self_play_mode']. - battle_mode='play_with_bot_mode', - # (bool) Whether to monitor extra statistics in tensorboard. - monitor_extra_statistics=True, - # (int) The transition number of one ``GameSegment``. 
- game_segment_length=200, - - # ****** observation ****** - # (bool) Whether to transform image to string to save memory. - transform2string=False, - # (bool) Whether to use gray scale image. - gray_scale=False, - # (bool) Whether to use data augmentation. - use_augmentation=False, - # (list) The style of augmentation. - augmentation=['shift', 'intensity'], - - # ******* learn ****** - # (bool) Whether to ignore the done flag in the training data. Typically, this value is set to False. - # However, for some environments with a fixed episode length, to ensure the accuracy of Q-value calculations, - # we should set it to True to avoid the influence of the done flag. - ignore_done=False, - # (int) How many updates(iterations) to train after collector's one collection. - # Bigger "update_per_collect" means bigger off-policy. - # collect data -> update policy-> collect data -> ... - # For different env, we have different episode_length, - # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor. - # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.model_update_ratio automatically. - update_per_collect=None, - # (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is not None. - model_update_ratio=0.1, - # (int) Minibatch size for one gradient descent. - batch_size=256, - # (str) Optimizer for training policy network. ['SGD', 'Adam'] - optim_type='SGD', - # (float) Learning rate for training policy network. Initial lr for manually decay schedule. - learning_rate=0.2, - # (int) Frequency of target network update. - target_update_freq=100, - # (int) Frequency of target network update. - target_update_freq_for_intrinsic_reward=1000, - # (float) Weight decay for training policy network. - weight_decay=1e-4, - # (float) One-order Momentum in optimizer, which stabilizes the training process (gradient direction). - momentum=0.9, - # (float) The maximum constraint value of gradient norm clipping. - grad_clip_value=10, - # (int) The number of episodes in each collecting stage. - n_episode=8, - # (int) the number of simulations in MCTS. - num_simulations=50, - # (float) Discount factor (gamma) for returns. - discount_factor=0.997, - # (int) The number of steps for calculating target q_value. - td_steps=5, - # (int) The number of unroll steps in dynamics network. - num_unroll_steps=5, - # (float) The weight of reward loss. - reward_loss_weight=1, - # (float) The weight of value loss. - value_loss_weight=0.25, - # (float) The weight of policy loss. - policy_loss_weight=1, - # (float) The weight of policy entropy loss. - policy_entropy_loss_weight=0, - # (float) The weight of ssl (self-supervised learning) loss. - ssl_loss_weight=0, - # (bool) Whether to use piecewise constant learning rate decay. - # i.e. lr: 0.2 -> 0.02 -> 0.002 - lr_piecewise_constant_decay=True, - # (int) The number of final training iterations to control lr decay, which is only used for manually decay. - threshold_training_steps_for_final_lr=int(5e4), - # (bool) Whether to use manually decayed temperature. - manual_temperature_decay=False, - # (int) The number of final training iterations to control temperature, which is only used for manually decay. - threshold_training_steps_for_final_temperature=int(1e5), - # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. - # The larger the value, the more exploration. 
This value is only used when manual_temperature_decay=False. - fixed_temperature_value=0.25, - # (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048. - use_ture_chance_label_in_chance_encoder=False, - - # ****** Priority ****** - # (bool) Whether to use priority when sampling training data from the buffer. - use_priority=True, - # (float) The degree of prioritization to use. A value of 0 means no prioritization, - # while a value of 1 means full prioritization. - priority_prob_alpha=0.6, - # (float) The degree of correction to use. A value of 0 means no correction, - # while a value of 1 means full correction. - priority_prob_beta=0.4, - - # ****** UCB ****** - # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of search tree. - root_dirichlet_alpha=0.3, - # (float) The noise weight at the root node of the search tree. - root_noise_weight=0.25, - - # ****** Explore by random collect ****** - # (int) The number of episodes to collect data randomly before training. - random_collect_episode_num=0, - - # ****** Explore by eps greedy ****** - eps=dict( - # (bool) Whether to use eps greedy exploration in collecting data. - eps_greedy_exploration_in_collect=False, - # (str) The type of decaying epsilon. Options are 'linear', 'exp'. - type='linear', - # (float) The start value of eps. - start=1., - # (float) The end value of eps. - end=0.05, - # (int) The decay steps from start to end eps. - decay=int(1e5), - ), - ) - - def default_model(self) -> Tuple[str, List[str]]: - """ - Overview: - Return this algorithm default model setting for demonstration. - Returns: - - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names. - - model_type (:obj:`str`): The model type used in this algorithm, which is registered in ModelRegistry. - - import_names (:obj:`List[str]`): The model class path list used in this algorithm. - .. note:: - The user can define and use customized network model but must obey the same interface definition indicated \ - by import_names path. For MuZero, ``lzero.model.muzero_gpt_model.MuZeroModel`` - """ - if self._cfg.model.model_type == "conv": - # return 'MuZeroModel', ['lzero.model.muzero_gpt_model'] - return 'MuZeroModelGPT', ['lzero.model.muzero_gpt_model'] - elif self._cfg.model.model_type == "mlp": - return 'MuZeroModelGPT', ['lzero.model.muzero_gpt_model_vector_obs'] - else: - raise ValueError("model type {} is not supported".format(self._cfg.model.model_type)) - - def _init_learn(self) -> None: - """ - Overview: - Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils. 
- """ - self._optimizer_tokenizer = optim.Adam( - self._model.tokenizer.parameters(), lr=1e-4 # weight_decay=0 - ) - - # TODO: nanoGPT optimizer - # self._optimizer_world_model = configure_optimizer( - # model=self._model.world_model, - # learning_rate=3e-3, - # # learning_rate=1e-4, - # weight_decay=self._cfg.weight_decay, - # # weight_decay=0.01, - # exclude_submodules=['tokenizer'] - # ) - self._optimizer_world_model = configure_optimizer( - model=self._model.world_model, - # learning_rate=3e-3, - learning_rate=1e-4, # NOTE: TODO - weight_decay=self._cfg.weight_decay, - # weight_decay=0.01, - exclude_submodules=['none'] # NOTE - ) - - - # use model_wrapper for specialized demands of different modes - self._target_model = copy.deepcopy(self._model) - - # TODO: torch 2.0 - self._model = torch.compile(self._model) - self._target_model = torch.compile(self._target_model) - - - # self._target_model = model_wrap( - # self._target_model, - # wrapper_name='target', - # update_type='assign', - # update_kwargs={'freq': self._cfg.target_update_freq} - # ) - # TODO: soft target - self._target_model = model_wrap( - self._target_model, - wrapper_name='target', - update_type='momentum', - # update_kwargs={'theta': 0.005} - update_kwargs={'theta': 0.01} # MOCO:0.001, DDPG:0.005, TD-MPC:0.01 - ) - self._learn_model = self._model - - # TODO: only for debug - # for param in self._learn_model.tokenizer.parameters(): - # param.requires_grad = False - - if self._cfg.use_augmentation: - self.image_transforms = ImageTransforms( - self._cfg.augmentation, - image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) - ) - self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) - self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) - self.inverse_scalar_transform_handle = InverseScalarTransform( - self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution - ) - - if self._cfg.use_rnd_model: - if self._cfg.target_model_for_intrinsic_reward_update_type == 'assign': - self._target_model_for_intrinsic_reward = model_wrap( - self._target_model, - wrapper_name='target', - update_type='assign', - update_kwargs={'freq': self._cfg.target_update_freq_for_intrinsic_reward} - ) - elif self._cfg.target_model_for_intrinsic_reward_update_type == 'momentum': - self._target_model_for_intrinsic_reward = model_wrap( - self._target_model, - wrapper_name='target', - update_type='momentum', - update_kwargs={'theta': self._cfg.target_update_theta_for_intrinsic_reward} - ) - - def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]: - """ - Overview: - The forward function for learning policy in learn mode, which is the core of the learning process. - The data is sampled from replay buffer. - The loss is calculated by the loss function and the loss is backpropagated to update the model. - Arguments: - - data (:obj:`Tuple[torch.Tensor]`): The data sampled from replay buffer, which is a tuple of tensors. - The first tensor is the current_batch, the second tensor is the target_batch. - Returns: - - info_dict (:obj:`Dict[str, Union[float, int]]`): The information dict to be logged, which contains \ - current learning loss and learning statistics. 
- """ - # current_batch, target_batch, train_which_component_dict = data - if data[-1]['train_which_component'] == 'transformer': - return_loss_dict = self._forward_learn_transformer(data) - elif data[-1]['train_which_component'] == 'tokenizer': - return_loss_dict = self._forward_learn_tokenizer(data) - else: - ValueError('Unknown component type') - - return return_loss_dict - - def monitor_weights_and_grads(self, model): - for name, param in model.named_parameters(): - if param.requires_grad: - print(f"Layer: {name} | " - f"Weight mean: {param.data.mean():.4f} | " - f"Weight std: {param.data.std():.4f} | " - f"Grad mean: {param.grad.mean():.4f} | " - f"Grad std: {param.grad.std():.4f}") - - - def _forward_learn_transformer(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]: - """ - Overview: - The forward function for learning policy in learn mode, which is the core of the learning process. - The data is sampled from replay buffer. - The loss is calculated by the loss function and the loss is backpropagated to update the model. - Arguments: - - data (:obj:`Tuple[torch.Tensor]`): The data sampled from replay buffer, which is a tuple of tensors. - The first tensor is the current_batch, the second tensor is the target_batch. - Returns: - - info_dict (:obj:`Dict[str, Union[float, int]]`): The information dict to be logged, which contains \ - current learning loss and learning statistics. - """ - - self._learn_model.train() - self._target_model.train() - # self._learn_model.tokenizer.eval() # bug - self._learn_model.tokenizer.train() - - - if self._cfg.use_rnd_model: - self._target_model_for_intrinsic_reward.train() - - # current_batch, target_batch = data - current_batch, target_batch, train_which_component_dict = data - - - obs_batch_ori, action_batch, mask_batch, indices, weights, make_time = current_batch - target_reward, target_value, target_policy = target_batch - - if self._cfg.model.frame_stack_num == 4: - obs_batch, obs_target_batch = prepare_obs_for_gpt(obs_batch_ori, self._cfg) - else: - obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) - - - # do augmentations - if self._cfg.use_augmentation: - obs_batch = self.image_transforms.transform(obs_batch) - if self._cfg.model.self_supervised_learning_loss: - obs_target_batch = self.image_transforms.transform(obs_target_batch) - - # shape: (batch_size, num_unroll_steps, action_dim) - # NOTE: .long(), in discrete action space. - action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(-1).long() - data_list = [ - mask_batch, - target_reward.astype('float32'), - target_value.astype('float32'), target_policy, weights - ] - [mask_batch, target_reward, target_value, target_policy, - weights] = to_torch_float_tensor(data_list, self._cfg.device) - - target_reward = target_reward.view(self._cfg.batch_size, -1) - target_value = target_value.view(self._cfg.batch_size, -1) - - assert obs_batch.size(0) == self._cfg.batch_size == target_reward.size(0) - - # ``scalar_transform`` to transform the original value to the scaled value, - # i.e. h(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. - transformed_target_reward = scalar_transform(target_reward) - transformed_target_value = scalar_transform(target_value) - - # transform a scalar to its categorical_distribution. After this transformation, each scalar is - # represented as the linear combination of its two adjacent supports. 
- target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) - target_value_categorical = phi_transform(self.value_support, transformed_target_value) - - # compute_loss(self, batch: Batch, tokenizer: Tokenizer, ** kwargs: Any) - - batch_for_gpt = {} - # TODO: for cartpole self._cfg.model.observation_shape - if isinstance(self._cfg.model.observation_shape, int) or len(self._cfg.model.observation_shape)==1: - batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( self._cfg.batch_size, -1, self._cfg.model.observation_shape) # (B, T, O) or (B, T, C, H, W) - elif len(self._cfg.model.observation_shape)==3: - batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( self._cfg.batch_size, -1, *self._cfg.model.observation_shape) # (B, T, O) or (B, T, C, H, W) - - - batch_for_gpt['actions'] = action_batch.squeeze(-1) # (B, T-1, A) -> (B, T-1) - - batch_for_gpt['rewards'] = target_reward_categorical[:, :-1] # (B, T, R) -> (B, T-1, R) - - batch_for_gpt['mask_padding'] = mask_batch == 1.0 # (B, T) NOTE: 0 means invalid padding data - batch_for_gpt['mask_padding'] = batch_for_gpt['mask_padding'][:, :-1] # (B, T-1) TODO - - - batch_for_gpt['observations'] = batch_for_gpt['observations'][:, :-1] # (B, T-1, O) or (B, T-1, C, H, W) - batch_for_gpt['ends'] = torch.zeros(batch_for_gpt['mask_padding'].shape, dtype=torch.long, device=self._cfg.device) # (B, T-1) - - batch_for_gpt['target_value'] = target_value_categorical[:, :-1] # (B, T-1, V) - batch_for_gpt['target_policy'] = target_policy[:, :-1] # (B, T-1, A) - # NOTE: TODO: next latent state's policy value - # batch_for_gpt['target_value'] = target_value_categorical[:, 1:] # (B, T-1, V) - # batch_for_gpt['target_policy'] = target_policy[:, 1:] # (B, T-1, A) - - # self._learn_model.world_model.train() - - # get valid target_policy data - valid_target_policy = batch_for_gpt['target_policy'][batch_for_gpt['mask_padding']] - # compute entropy of each policy - target_policy_entropy = -torch.sum(valid_target_policy * torch.log(valid_target_policy + 1e-9), dim=-1) - # compute average entropy - average_target_policy_entropy = target_policy_entropy.mean().item() - # print(f'Average entropy: {average_entropy}') - - - # if train_which_component_dict['train_which_component'] == 'transformer': - # ============================================================== - # update world model - # ============================================================== - intermediate_losses = defaultdict(float) - # losses = self._learn_model.world_model.compute_loss(batch_for_gpt, self._learn_model.tokenizer) - losses = self._learn_model.world_model.compute_loss(batch_for_gpt, self._target_model.world_model.tokenizer) - - weighted_total_loss = losses.loss_total - for loss_name, loss_value in losses.intermediate_losses.items(): - intermediate_losses[f"{loss_name}"] = loss_value - # print(intermediate_losses) - obs_loss = intermediate_losses['loss_obs'] - reward_loss = intermediate_losses['loss_rewards'] - policy_loss = intermediate_losses['loss_policy'] - value_loss = intermediate_losses['loss_value'] - latent_kl_loss = intermediate_losses['latent_kl_loss'] - latent_recon_loss = intermediate_losses['latent_recon_loss'] - perceptual_loss = intermediate_losses['perceptual_loss'] - - - # ============================================================== - # the core learn model update step. 
- # ============================================================== - """ - for name, parameter in self._learn_model.tokenizer.named_parameters(): - print(name) - """ - gradient_scale = 1 / self._cfg.num_unroll_steps - # TODO(pu): test the effect of gradient scale. - weighted_total_loss.register_hook(lambda grad: grad * gradient_scale) - self._optimizer_world_model.zero_grad() - weighted_total_loss.backward() - - # Use this in the training loop to inspect weights and gradients: - # self.monitor_weights_and_grads(self._learn_model.tokenizer.representation_network) - # print('torch.cuda.memory_summary():', torch.cuda.memory_summary()) - - if self._cfg.multi_gpu: - self.sync_gradients(self._learn_model) - total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_( - self._learn_model.world_model.parameters(), self._cfg.grad_clip_value - ) - total_grad_norm_before_clip_rep_net = torch.nn.utils.clip_grad_norm_(self._learn_model.tokenizer.representation_network.parameters(), max_norm=1.0) - # print('total_grad_norm_before_clip_rep_net:', total_grad_norm_before_clip_rep_net) - - - self._optimizer_world_model.step() - if self._cfg.lr_piecewise_constant_decay: - self.lr_scheduler.step() - - - # ============================================================== - # the core target model update step. - # ============================================================== - self._target_model.update(self._learn_model.state_dict()) - if self._cfg.use_rnd_model: - self._target_model_for_intrinsic_reward.update(self._learn_model.state_dict()) - - - # Make sure all CUDA kernels have finished so that GPU memory usage is measured accurately. - torch.cuda.synchronize() - # Get the total amount of GPU memory currently allocated (in bytes). - current_memory_allocated = torch.cuda.memory_allocated() - # Get the maximum amount of GPU memory allocated so far during this run (in bytes). - max_memory_allocated = torch.cuda.max_memory_allocated() - - # Convert the memory usage from bytes to GB. - current_memory_allocated_gb = current_memory_allocated / (1024**3) - max_memory_allocated_gb = max_memory_allocated / (1024**3) - # Log the current and maximum GPU memory usage via SummaryWriter. - - - return_loss_dict = { - 'Current_GPU': current_memory_allocated_gb, - 'Max_GPU': max_memory_allocated_gb, - 'collect_mcts_temperature': self._collect_mcts_temperature, - 'collect_epsilon': self.collect_epsilon, - 'cur_lr_world_model': self._optimizer_world_model.param_groups[0]['lr'], - 'cur_lr_tokenizer': self._optimizer_tokenizer.param_groups[0]['lr'], - - 'weighted_total_loss': weighted_total_loss.item(), - 'obs_loss': obs_loss, - 'latent_kl_loss': latent_kl_loss, - 'latent_recon_loss':latent_recon_loss, - 'perceptual_loss':perceptual_loss, - 'policy_loss': policy_loss, - 'target_policy_entropy': average_target_policy_entropy, - # 'policy_entropy': - policy_entropy_loss.mean().item() / (self._cfg.num_unroll_steps + 1), - 'reward_loss': reward_loss, - 'value_loss': value_loss, - # 'consistency_loss': consistency_loss.mean().item() / self._cfg.num_unroll_steps, - - # ============================================================== - # priority related - # ============================================================== - # 'value_priority_orig': value_priority, - 'value_priority_orig': np.zeros(self._cfg.batch_size), # TODO - # 'value_priority': value_priority.mean().item(), - 'target_reward': target_reward.detach().cpu().numpy().mean().item(), - 'target_value': target_value.detach().cpu().numpy().mean().item(), - 'transformed_target_reward': transformed_target_reward.detach().cpu().numpy().mean().item(), - 'transformed_target_value': transformed_target_value.detach().cpu().numpy().mean().item(), - # 'predicted_rewards': predicted_rewards.detach().cpu().numpy().mean().item(), - # 'predicted_values':
predicted_values.detach().cpu().numpy().mean().item(), - 'total_grad_norm_before_clip_wm': total_grad_norm_before_clip_wm.item(), - 'total_grad_norm_before_clip_rep_net': total_grad_norm_before_clip_rep_net.item(), - } - - return return_loss_dict - - - def _forward_learn_tokenizer(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]: - """ - Overview: - The forward function for learning policy in learn mode, which is the core of the learning process. - The data is sampled from replay buffer. - The loss is calculated by the loss function and the loss is backpropagated to update the model. - Arguments: - - data (:obj:`Tuple[torch.Tensor]`): The data sampled from replay buffer, which is a tuple of tensors. - The first tensor is the current_batch, the second tensor is the target_batch. - Returns: - - info_dict (:obj:`Dict[str, Union[float, int]]`): The information dict to be logged, which contains \ - current learning loss and learning statistics. - """ - self._learn_model.train() - self._target_model.train() - if self._cfg.use_rnd_model: - self._target_model_for_intrinsic_reward.train() - - # current_batch, target_batch = data - current_batch, target_batch, train_which_component_dict = data - - - obs_batch_ori, action_batch, mask_batch, indices, weights, make_time = current_batch - target_reward, target_value, target_policy = target_batch - - obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) - - # do augmentations - if self._cfg.use_augmentation: - obs_batch = self.image_transforms.transform(obs_batch) - if self._cfg.model.self_supervised_learning_loss: - obs_target_batch = self.image_transforms.transform(obs_target_batch) - - # shape: (batch_size, num_unroll_steps, action_dim) - # NOTE: .long(), in discrete action space. 
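A minimal sketch of the gradient-scaling trick used in the world-model update above: a hook registered on the scalar loss multiplies its incoming gradient by 1 / num_unroll_steps before backpropagation, so the unrolled steps contribute a scaled gradient. The toy model and optimizer below are assumptions for illustration only:

import torch

num_unroll_steps = 5
model = torch.nn.Linear(8, 1)                                   # stand-in for the world model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

loss = model(torch.randn(16, 8)).pow(2).mean()                  # stand-in scalar loss
loss.register_hook(lambda grad: grad * (1 / num_unroll_steps))  # scale the gradient flowing out of the loss node

optimizer.zero_grad()
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)  # returns the pre-clip norm
optimizer.step()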
- # action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(-1).long() - # data_list = [ - # mask_batch, - # target_reward.astype('float32'), - # target_value.astype('float32'), target_policy, weights - # ] - - # [mask_batch, target_reward, target_value, target_policy, - # weights] = to_torch_float_tensor(data_list, self._cfg.device) - - # target_reward = target_reward.view(self._cfg.batch_size, -1) - # target_value = target_value.view(self._cfg.batch_size, -1) - - # assert obs_batch.size(0) == self._cfg.batch_size == target_reward.size(0) - - - batch_for_gpt = {} - # TODO: for cartpole self._cfg.model.observation_shape - if isinstance(self._cfg.model.observation_shape, int) or len(self._cfg.model.observation_shape)==1: - batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( self._cfg.batch_size, -1, self._cfg.model.observation_shape) # (B, T, O) or (B, T, C, H, W) - elif len(self._cfg.model.observation_shape)==3: - batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( self._cfg.batch_size, -1, *self._cfg.model.observation_shape) # (B, T, O) or (B, T, C, H, W) - - batch_for_gpt['observations'] = batch_for_gpt['observations'][:, :-1] # (B, T-1, O) or (B, T-1, C, H, W) - - # if train_which_component_dict['train_which_component'] == 'tokenizer': - - # ============================================================== - # update tokenizer - # ============================================================== - # TODO: train tokenlizer - self._learn_model.tokenizer.train() - - # for name, param in self._learn_model.tokenizer.named_parameters(): - # if param.requires_grad: - # print(name, param.shape) - - losses_tokenizer = self._learn_model.tokenizer.compute_loss(batch_for_gpt) - - self._optimizer_tokenizer.zero_grad() - - weighted_total_loss_tokenizer = losses_tokenizer.loss_total - weighted_total_loss_tokenizer.backward() - # losses_tokenizer.loss_total.backward() - - total_grad_norm_before_clip_tokenizer = torch.nn.utils.clip_grad_norm_( - self._learn_model.tokenizer.parameters(), self._cfg.grad_clip_value - ) - - - self._optimizer_tokenizer.step() - - intermediate_losses_tokenizer= defaultdict(float) - for loss_name, loss_value in losses_tokenizer.intermediate_losses.items(): - intermediate_losses_tokenizer[f"{loss_name}"] = loss_value - # print(intermediate_losses) - commitment_loss= intermediate_losses_tokenizer['commitment_loss'] - reconstruction_loss = intermediate_losses_tokenizer['reconstruction_loss'] - perceptual_loss = intermediate_losses_tokenizer['perceptual_loss'] - - - # # ============================================================== - # # the core target model update step. 
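The 'Current_GPU' and 'Max_GPU' entries logged in the world-model update above boil down to reading the CUDA allocator counters and converting bytes to GB. A minimal sketch, only meaningful when CUDA is available:

import torch

if torch.cuda.is_available():
    torch.cuda.synchronize()                                    # wait for all kernels so the counters are accurate
    current_gb = torch.cuda.memory_allocated() / (1024 ** 3)    # bytes -> GB
    peak_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
    print(f"current: {current_gb:.2f} GB, peak: {peak_gb:.2f} GB")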
- # # ============================================================== - # self._target_model.update(self._learn_model.state_dict()) - # if self._cfg.use_rnd_model: - # self._target_model_for_intrinsic_reward.update(self._learn_model.state_dict()) - - - return_loss_dict = { - 'collect_mcts_temperature': self._collect_mcts_temperature, - 'collect_epsilon': self.collect_epsilon, - 'cur_lr_world_model': self._optimizer_world_model.param_groups[0]['lr'], - 'cur_lr_tokenizer': self._optimizer_tokenizer.param_groups[0]['lr'], - - # 'weighted_total_loss': weighted_total_loss.item(), - # 'obs_loss': obs_loss, - # 'policy_loss': policy_loss, - # 'target_policy_entropy': average_target_policy_entropy, - # 'policy_entropy': - policy_entropy_loss.mean().item() / (self._cfg.num_unroll_steps + 1), - # 'reward_loss': reward_loss, - # 'value_loss': value_loss, - # 'consistency_loss': consistency_loss.mean().item() / self._cfg.num_unroll_steps, - - # ============================================================== - # priority related - # ============================================================== - # 'value_priority_orig': value_priority, - # 'value_priority_orig': np.zeros(self._cfg.batch_size), # TODO - # 'value_priority': value_priority.mean().item(), - # 'target_reward': target_reward.detach().cpu().numpy().mean().item(), - # 'target_value': target_value.detach().cpu().numpy().mean().item(), - # 'transformed_target_reward': transformed_target_reward.detach().cpu().numpy().mean().item(), - # 'transformed_target_value': transformed_target_value.detach().cpu().numpy().mean().item(), - # 'predicted_rewards': predicted_rewards.detach().cpu().numpy().mean().item(), - # 'predicted_values': predicted_values.detach().cpu().numpy().mean().item(), - 'total_grad_norm_before_clip_tokenizer': total_grad_norm_before_clip_tokenizer.item(), - 'commitment_loss':commitment_loss, - 'reconstruction_loss':reconstruction_loss, - 'perceptual_loss': perceptual_loss, - } - - return return_loss_dict - - def _init_collect(self) -> None: - """ - Overview: - Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. - """ - self._collect_model = self._model - - if self._cfg.mcts_ctree: - self._mcts_collect = MCTSCtree(self._cfg) - else: - self._mcts_collect = MCTSPtree(self._cfg) - self._collect_mcts_temperature = 1. - self.collect_epsilon = 0.0 - - def _forward_collect( - self, - data: torch.Tensor, - action_mask: list = None, - temperature: float = 1, - to_play: List = [-1], - epsilon: float = 0.25, - ready_env_id: np.array = None, - ) -> Dict: - """ - Overview: - The forward function for collecting data in collect mode. Use model to execute MCTS search. - Choosing the action through sampling during the collect mode. - Arguments: - - data (:obj:`torch.Tensor`): The input data, i.e. the observation. - - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. - - temperature (:obj:`float`): The temperature of the policy. - - to_play (:obj:`int`): The player to play. - - ready_env_id (:obj:`list`): The id of the env that is ready to collect. - Shape: - - data (:obj:`torch.Tensor`): - - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ - S is the number of stacked frames, H is the height of the image, W is the width of the image. - - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. 
- - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. - - temperature: :math:`(1, )`. - - to_play: :math:`(N, 1)`, where N is the number of collect_env. - - ready_env_id: None - Returns: - - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ - ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. - """ - self._collect_model.eval() - self._collect_mcts_temperature = temperature - self.collect_epsilon = epsilon - active_collect_env_num = data.shape[0] - # if active_collect_env_num == 1: - # print('debug') - with torch.no_grad(): - network_output = self._collect_model.initial_inference(data) - # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)} - latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) - - pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() - latent_state_roots = latent_state_roots.detach().cpu().numpy() - policy_logits = policy_logits.detach().cpu().numpy().tolist() - - legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)] - # the only difference between collect and eval is the dirichlet noise - noises = [ - np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) - ).astype(np.float32).tolist() for j in range(active_collect_env_num) - ] - if self._cfg.mcts_ctree: - # cpp mcts_tree - roots = MCTSCtree.roots(active_collect_env_num, legal_actions) - else: - # python mcts_tree - roots = MCTSPtree.roots(active_collect_env_num, legal_actions) - - roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) - self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play) - - # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` - roots_visit_count_distributions = roots.get_distributions() - roots_values = roots.get_values() # shape: {list: batch_size} - - data_id = [i for i in range(active_collect_env_num)] - output = {i: None for i in data_id} - - if ready_env_id is None: - ready_env_id = np.arange(active_collect_env_num) - - for i, env_id in enumerate(ready_env_id): - distributions, value = roots_visit_count_distributions[i], roots_values[i] - if self._cfg.eps.eps_greedy_exploration_in_collect: - # eps greedy collect - action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( - distributions, temperature=self._collect_mcts_temperature, deterministic=True - ) - action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - if np.random.rand() < self.collect_epsilon: - action = np.random.choice(legal_actions[i]) - else: - # normal collect - # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents - # the index within the legal action set, rather than the index in the entire action set. - action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( - distributions, temperature=self._collect_mcts_temperature, deterministic=False - ) - # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. 
- action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - output[env_id] = { - 'action': action, - 'visit_count_distributions': distributions, - 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'searched_value': value, - 'predicted_value': pred_values[i], - 'predicted_policy_logits': policy_logits[i], - } - - - return output - - - def _init_eval(self) -> None: - """ - Overview: - Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils. - """ - self._eval_model = self._model - if self._cfg.mcts_ctree: - self._mcts_eval = MCTSCtree(self._cfg) - else: - self._mcts_eval = MCTSPtree(self._cfg) - - def _get_target_obs_index_in_step_k(self, step): - """ - Overview: - Get the begin index and end index of the target obs in step k. - Arguments: - - step (:obj:`int`): The current step k. - Returns: - - beg_index (:obj:`int`): The begin index of the target obs in step k. - - end_index (:obj:`int`): The end index of the target obs in step k. - Examples: - >>> self._cfg.model.model_type = 'conv' - >>> self._cfg.model.image_channel = 3 - >>> self._cfg.model.frame_stack_num = 4 - >>> self._get_target_obs_index_in_step_k(0) - >>> (0, 12) - """ - if self._cfg.model.model_type == 'conv': - beg_index = self._cfg.model.image_channel * step - end_index = self._cfg.model.image_channel * (step + self._cfg.model.frame_stack_num) - elif self._cfg.model.model_type == 'mlp': - beg_index = self._cfg.model.observation_shape * step - end_index = self._cfg.model.observation_shape * (step + self._cfg.model.frame_stack_num) - return beg_index, end_index - - def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id: np.array = None,) -> Dict: - """ - Overview: - The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. - Choosing the action with the highest value (argmax) rather than sampling during the eval mode. - Arguments: - - data (:obj:`torch.Tensor`): The input data, i.e. the observation. - - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. - - to_play (:obj:`int`): The player to play. - - ready_env_id (:obj:`list`): The id of the env that is ready to collect. - Shape: - - data (:obj:`torch.Tensor`): - - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ - S is the number of stacked frames, H is the height of the image, W is the width of the image. - - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. - - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. - - to_play: :math:`(N, 1)`, where N is the number of collect_env. - - ready_env_id: None - Returns: - - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ - ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. - """ - self._eval_model.eval() - active_eval_env_num = data.shape[0] - with torch.no_grad(): - # data shape [B, S x C, W, H], e.g. 
{Tensor:(B, 12, 96, 96)} - # network_output = self._collect_model.initial_inference(data) - network_output = self._eval_model.initial_inference(data) - latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) - - if not self._eval_model.training: - # if not in training, obtain the scalars of the value/reward - pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) - latent_state_roots = latent_state_roots.detach().cpu().numpy() - policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) - - legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)] - if self._cfg.mcts_ctree: - # cpp mcts_tree - roots = MCTSCtree.roots(active_eval_env_num, legal_actions) - else: - # python mcts_tree - roots = MCTSPtree.roots(active_eval_env_num, legal_actions) - roots.prepare_no_noise(reward_roots, policy_logits, to_play) - self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play) - - # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` - roots_visit_count_distributions = roots.get_distributions() - roots_values = roots.get_values() # shape: {list: batch_size} - - data_id = [i for i in range(active_eval_env_num)] - output = {i: None for i in data_id} - - if ready_env_id is None: - ready_env_id = np.arange(active_eval_env_num) - - for i, env_id in enumerate(ready_env_id): - distributions, value = roots_visit_count_distributions[i], roots_values[i] - # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents - # the index within the legal action set, rather than the index in the entire action set. - # Setting deterministic=True implies choosing the action with the highest value (argmax) rather than - # sampling during the evaluation phase. - - action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( - distributions, temperature=1, deterministic=True - ) - # TODO: eval - # action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( - # distributions, temperature=self._collect_mcts_temperature, deterministic=False - # ) - # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the - # entire action set. - action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - - output[env_id] = { - 'action': action, - 'visit_count_distributions': distributions, - 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'searched_value': value, - 'predicted_value': pred_values[i], - 'predicted_policy_logits': policy_logits[i], - } - - return output - - def _monitor_vars_learn(self) -> List[str]: - """ - Overview: - Register the variables to be monitored in learn mode. The registered variables will be logged in - tensorboard according to the return value ``_forward_learn``. 
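As a minimal illustration of the two selection modes described above (temperature-based sampling during collection, argmax during evaluation), the sketch below uses a simplified stand-in rather than LZero's select_action, and then maps the chosen index back to the full action space through the action mask:

import numpy as np

def select_from_visit_counts(visit_counts, temperature=1.0, deterministic=False):
    # Pick an index within the legal-action set from MCTS visit counts.
    counts = np.asarray(visit_counts, dtype=np.float64)
    probs = counts ** (1.0 / temperature)
    probs /= probs.sum()
    return int(np.argmax(probs)) if deterministic else int(np.random.choice(len(probs), p=probs))

action_mask = np.array([1.0, 0.0, 1.0, 1.0])                     # 3 legal actions out of 4
idx_in_legal_set = select_from_visit_counts([30, 5, 15], deterministic=True)
action = np.where(action_mask == 1.0)[0][idx_in_legal_set]       # index in the full action space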
- """ - return [ - 'Current_GPU', - 'Max_GPU', - 'collect_epsilon', - 'collect_mcts_temperature', - # 'cur_lr', - 'cur_lr_world_model', - 'cur_lr_tokenizer', - - 'weighted_total_loss', - # 'total_loss', - 'obs_loss', - 'policy_loss', - 'latent_kl_loss', - 'latent_recon_loss', - # 'policy_entropy', - 'target_policy_entropy', - 'reward_loss', - 'value_loss', - 'consistency_loss', - 'value_priority', - 'target_reward', - 'target_value', - # 'predicted_rewards', - # 'predicted_values', - # 'transformed_target_reward', - # 'transformed_target_value', - 'total_grad_norm_before_clip_tokenizer', - 'total_grad_norm_before_clip_wm', - 'total_grad_norm_before_clip_rep_net', - # tokenizer - 'commitment_loss', - 'reconstruction_loss', - 'perceptual_loss', - ] - - def _state_dict_learn(self) -> Dict[str, Any]: - """ - Overview: - Return the state_dict of learn mode, usually including model, target_model and optimizer. - Returns: - - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring. - """ - return { - 'model': self._learn_model.state_dict(), - 'target_model': self._target_model.state_dict(), - 'optimizer_world_model': self._optimizer_world_model.state_dict(), - 'optimizer_tokenizer': self._optimizer_tokenizer.state_dict(), - } - - # TODO: - def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: - """ - Overview: - Load the state_dict variable into policy learn mode. - Arguments: - - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. - """ - self._learn_model.load_state_dict(state_dict['model']) - self._target_model.load_state_dict(state_dict['target_model']) - self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model']) - self._optimizer_tokenizer.load_state_dict(state_dict['optimizer_tokenizer']) - - # def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: - # """ - # Overview: - # Load the state_dict variable into policy learn mode, specifically loading only the - # representation network of the tokenizer within model and target_model. - # Arguments: - # - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. - # """ - # # Extract the relevant sub-state-dicts for representation_network from the state_dict - # # model_rep_network_state = state_dict['model']['tokenizer']['representation_network'] - # # target_model_rep_network_state = state_dict['target_model']['tokenizer']['representation_network'] - - # # # Load the state into the model's representation network - # # self._learn_model.tokenizer.representation_network.load_state_dict(model_rep_network_state) - # # self._target_model.tokenizer.representation_network.load_state_dict(target_model_rep_network_state) - - # # Assuming self._learn_model and self._target_model have a 'representation_network' submodule - # self._load_representation_network_state(state_dict['model'], self._learn_model.tokenizer.representation_network) - # self._load_representation_network_state(state_dict['target_model'], self._target_model.tokenizer.representation_network) - - - def _load_representation_network_state(self, state_dict, model_submodule): - """ - This function filters the state_dict to only include the state of the representation_network - and loads it into the given model submodule. 
- """ - from collections import OrderedDict - - # Filter the state_dict to only include keys that start with 'representation_network' - representation_network_keys = {k: v for k, v in state_dict.items() if k.startswith('representation_network')} - - # Load the state into the model's representation_network submodule - # model_submodule.load_state_dict(OrderedDict(representation_network_keys)) - - # 去掉键名前缀 - new_state_dict = OrderedDict() - for key, value in representation_network_keys.items(): - new_key = key.replace('representation_network.', '') # 去掉前缀 - new_state_dict[new_key] = value - - # # 如果模型在特定的设备上,确保状态字典也在那个设备上 - # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - # new_state_dict = {key: value.to(device) for key, value in new_state_dict.items()} - - # 尝试加载状态字典 - try: - # model_submodule.load_state_dict(new_state_dict) - # 使用 strict=False 参数忽略缺少的键 - model_submodule.load_state_dict(new_state_dict, strict=False) - except RuntimeError as e: - print("加载失败: ", e) - - - def _process_transition(self, obs, policy_output, timestep): - # be compatible with DI-engine Policy class - pass - - def _get_train_sample(self, data): - # be compatible with DI-engine Policy class - pass diff --git a/lzero/policy/muzero_gpt_bkp20240307.py b/lzero/policy/muzero_gpt_bkp20240307.py deleted file mode 100644 index fdd43f43c..000000000 --- a/lzero/policy/muzero_gpt_bkp20240307.py +++ /dev/null @@ -1,1152 +0,0 @@ -import copy -from collections import defaultdict -from typing import List, Dict, Any, Tuple, Union - -import numpy as np -import torch -import torch.optim as optim -from ding.model import model_wrap -from ding.policy.base_policy import Policy -from ding.torch_utils import to_tensor -from ding.utils import POLICY_REGISTRY -from torch.distributions import Categorical -from torch.nn import L1Loss -import inspect -from lzero.mcts import MuZeroMCTSCtree as MCTSCtree -from lzero.mcts import MuZeroMCTSPtree as MCTSPtree -from lzero.model import ImageTransforms -from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ - DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, \ - prepare_obs, prepare_obs_for_gpt - - -# def configure_optimizer(model, learning_rate, weight_decay, exclude_submodules, *blacklist_module_names): -# """Credits to https://github.com/karpathy/minGPT""" -# # separate out all parameters to those that will and won't experience regularizing weight decay -# decay = set() -# no_decay = set() -# whitelist_weight_modules = [torch.nn.Linear, torch.nn.Conv1d] -# blacklist_weight_modules = [torch.nn.LayerNorm, torch.nn.Embedding] - -# # Here, we make sure to exclude parameters from specified submodules when creating param_dict -# param_dict = {} -# for mn, m in model.named_modules(): -# if any(mn.startswith(module_name) for module_name in exclude_submodules): -# continue # skip parameters from excluded submodules -# for pn, p in m.named_parameters(recurse=False): -# fpn = f'{mn}.{pn}' if mn else pn # full param name -# if not any(fpn.startswith(bl_module_name) for bl_module_name in blacklist_module_names): -# param_dict[fpn] = p -# if 'bias' in pn: -# no_decay.add(fpn) -# elif pn.endswith('weight') and isinstance(m, tuple(whitelist_weight_modules)): -# decay.add(fpn) -# elif pn.endswith('weight') and isinstance(m, tuple(blacklist_weight_modules)): -# no_decay.add(fpn) -# else: -# decay.add(fpn) # Default behavior is to add to decay - -# # Validate that we considered every 
parameter -# inter_params = decay & no_decay -# union_params = decay | no_decay -# assert len(inter_params) == 0, f"parameters {str(inter_params)} made it into both decay/no_decay sets!" -# assert len(param_dict.keys() - union_params) == 0, f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!" - -# # Create the PyTorch optimizer object -# optim_groups = [ -# {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay}, -# {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, -# ] -# optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate) -# return optimizer - -def configure_optimizers(model, weight_decay, learning_rate, betas, device_type): - # start with all of the candidate parameters - param_dict = {pn: p for pn, p in model.named_parameters()} - # filter out those that do not require grad - param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} - # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. - # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. - decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] - nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] - optim_groups = [ - {'params': decay_params, 'weight_decay': weight_decay}, - {'params': nodecay_params, 'weight_decay': 0.0} - ] - num_decay_params = sum(p.numel() for p in decay_params) - num_nodecay_params = sum(p.numel() for p in nodecay_params) - print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") - print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") - # Create AdamW optimizer and use the fused version if it is available - fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters - use_fused = fused_available and device_type == 'cuda' - extra_args = dict(fused=True) if use_fused else dict() - optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args) - print(f"using fused AdamW: {use_fused}") - - return optimizer - -@POLICY_REGISTRY.register('muzero_gpt') -class MuZeroGPTPolicy(Policy): - """ - Overview: - The policy class for MuZero. - """ - - # The default_config for MuZero policy. - config = dict( - model=dict( - # (str) The model type. For 1-dimensional vector obs, we use mlp model. For the image obs, we use conv model. - model_type='conv', # options={'mlp', 'conv'} - # (bool) If True, the action space of the environment is continuous, otherwise discrete. - continuous_action_space=False, - # (tuple) The stacked obs shape. - # observation_shape=(1, 96, 96), # if frame_stack_num=1 - observation_shape=(4, 96, 96), # if frame_stack_num=4 - # (bool) Whether to use the self-supervised learning loss. - self_supervised_learning_loss=False, - # (bool) Whether to use discrete support to represent categorical distribution for value/reward/value_prefix. - categorical_distribution=True, - # (int) The image channel in image observation. - image_channel=1, - # (int) The number of frames to stack together. - frame_stack_num=1, - # (int) The number of res blocks in MuZero model. - num_res_blocks=1, - # (int) The number of channels of hidden states in MuZero model. - num_channels=64, - # (int) The scale of supports used in categorical distribution. - # This variable is only effective when ``categorical_distribution=True``. 
- support_scale=300, - # (bool) whether to learn bias in the last linear layer in value and policy head. - bias=True, - # (str) The type of action encoding. Options are ['one_hot', 'not_one_hot']. Default to 'one_hot'. - discrete_action_encoding_type='one_hot', - # (bool) whether to use res connection in dynamics. - res_connection_in_dynamics=True, - # (str) The type of normalization in MuZero model. Options are ['BN', 'LN']. Default to 'LN'. - norm_type='BN', - ), - # ****** common ****** - # (bool) whether to use rnd model. - use_rnd_model=False, - # (bool) Whether to use multi-gpu training. - multi_gpu=False, - # (bool) Whether to enable the sampled-based algorithm (e.g. Sampled EfficientZero) - # this variable is used in ``collector``. - sampled_algo=False, - # (bool) Whether to enable the gumbel-based algorithm (e.g. Gumbel Muzero) - gumbel_algo=False, - # (bool) Whether to use C++ MCTS in policy. If False, use Python implementation. - mcts_ctree=True, - # (bool) Whether to use cuda for network. - cuda=True, - # (int) The number of environments used in collecting data. - collector_env_num=8, - # (int) The number of environments used in evaluating policy. - evaluator_env_num=3, - # (str) The type of environment. Options are ['not_board_games', 'board_games']. - env_type='not_board_games', - # (str) The type of battle mode. Options are ['play_with_bot_mode', 'self_play_mode']. - battle_mode='play_with_bot_mode', - # (bool) Whether to monitor extra statistics in tensorboard. - monitor_extra_statistics=True, - # (int) The transition number of one ``GameSegment``. - game_segment_length=200, - - # ****** observation ****** - # (bool) Whether to transform image to string to save memory. - transform2string=False, - # (bool) Whether to use gray scale image. - gray_scale=False, - # (bool) Whether to use data augmentation. - use_augmentation=False, - # (list) The style of augmentation. - augmentation=['shift', 'intensity'], - - # ******* learn ****** - # (bool) Whether to ignore the done flag in the training data. Typically, this value is set to False. - # However, for some environments with a fixed episode length, to ensure the accuracy of Q-value calculations, - # we should set it to True to avoid the influence of the done flag. - ignore_done=False, - # (int) How many updates(iterations) to train after collector's one collection. - # Bigger "update_per_collect" means bigger off-policy. - # collect data -> update policy-> collect data -> ... - # For different env, we have different episode_length, - # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor. - # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.model_update_ratio automatically. - update_per_collect=None, - # (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is not None. - model_update_ratio=0.1, - # (int) Minibatch size for one gradient descent. - batch_size=256, - # (str) Optimizer for training policy network. ['SGD', 'Adam'] - optim_type='SGD', - # (float) Learning rate for training policy network. Initial lr for manually decay schedule. - learning_rate=0.2, - # (int) Frequency of target network update. - target_update_freq=100, - # (int) Frequency of target network update. - target_update_freq_for_intrinsic_reward=1000, - # (float) Weight decay for training policy network. 
- weight_decay=1e-4, - # (float) One-order Momentum in optimizer, which stabilizes the training process (gradient direction). - momentum=0.9, - # (float) The maximum constraint value of gradient norm clipping. - grad_clip_value=10, - # (int) The number of episodes in each collecting stage. - n_episode=8, - # (int) the number of simulations in MCTS. - num_simulations=50, - # (float) Discount factor (gamma) for returns. - discount_factor=0.997, - # (int) The number of steps for calculating target q_value. - td_steps=5, - # (int) The number of unroll steps in dynamics network. - num_unroll_steps=5, - # (float) The weight of reward loss. - reward_loss_weight=1, - # (float) The weight of value loss. - value_loss_weight=0.25, - # (float) The weight of policy loss. - policy_loss_weight=1, - # (float) The weight of policy entropy loss. - policy_entropy_loss_weight=0, - # (float) The weight of ssl (self-supervised learning) loss. - ssl_loss_weight=0, - # (bool) Whether to use piecewise constant learning rate decay. - # i.e. lr: 0.2 -> 0.02 -> 0.002 - lr_piecewise_constant_decay=True, - # (int) The number of final training iterations to control lr decay, which is only used for manually decay. - threshold_training_steps_for_final_lr=int(5e4), - # (bool) Whether to use manually decayed temperature. - manual_temperature_decay=False, - # (int) The number of final training iterations to control temperature, which is only used for manually decay. - threshold_training_steps_for_final_temperature=int(1e5), - # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. - # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. - fixed_temperature_value=0.25, - # (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048. - use_ture_chance_label_in_chance_encoder=False, - - # ****** Priority ****** - # (bool) Whether to use priority when sampling training data from the buffer. - use_priority=True, - # (float) The degree of prioritization to use. A value of 0 means no prioritization, - # while a value of 1 means full prioritization. - priority_prob_alpha=0.6, - # (float) The degree of correction to use. A value of 0 means no correction, - # while a value of 1 means full correction. - priority_prob_beta=0.4, - - # ****** UCB ****** - # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of search tree. - root_dirichlet_alpha=0.3, - # (float) The noise weight at the root node of the search tree. - root_noise_weight=0.25, - - # ****** Explore by random collect ****** - # (int) The number of episodes to collect data randomly before training. - random_collect_episode_num=0, - - # ****** Explore by eps greedy ****** - eps=dict( - # (bool) Whether to use eps greedy exploration in collecting data. - eps_greedy_exploration_in_collect=False, - # (str) The type of decaying epsilon. Options are 'linear', 'exp'. - type='linear', - # (float) The start value of eps. - start=1., - # (float) The end value of eps. - end=0.05, - # (int) The decay steps from start to end eps. - decay=int(1e5), - ), - ) - - def default_model(self) -> Tuple[str, List[str]]: - """ - Overview: - Return this algorithm default model setting for demonstration. - Returns: - - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names. - - model_type (:obj:`str`): The model type used in this algorithm, which is registered in ModelRegistry. 
- - import_names (:obj:`List[str]`): The model class path list used in this algorithm. - .. note:: - The user can define and use customized network model but must obey the same interface definition indicated \ - by import_names path. For MuZero, ``lzero.model.muzero_gpt_model.MuZeroModel`` - """ - if self._cfg.model.model_type == "conv": - # return 'MuZeroModel', ['lzero.model.muzero_gpt_model'] - return 'MuZeroModelGPT', ['lzero.model.muzero_gpt_model'] - elif self._cfg.model.model_type == "mlp": - return 'MuZeroModelGPT', ['lzero.model.muzero_gpt_model_vector_obs'] - else: - raise ValueError("model type {} is not supported".format(self._cfg.model.model_type)) - - def _init_learn(self) -> None: - """ - Overview: - Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils. - """ - self._optimizer_tokenizer = optim.Adam( - self._model.tokenizer.parameters(), lr=1e-4 # weight_decay=0 - ) - - # TODO: nanoGPT optimizer - # self._optimizer_world_model = configure_optimizer( - # model=self._model.world_model, - # learning_rate=3e-3, - # # learning_rate=1e-4, - # weight_decay=self._cfg.weight_decay, - # # weight_decay=0.01, - # exclude_submodules=['tokenizer'] - # ) - - # 验证没有问题的版本 - # self._optimizer_world_model = configure_optimizer( - # model=self._model.world_model, - # # learning_rate=3e-3, - # learning_rate=1e-4, # NOTE: TODO - # weight_decay=self._cfg.weight_decay, - # # weight_decay=0.01, - # exclude_submodules=['none'] # NOTE - # ) - - self._optimizer_world_model = configure_optimizers( - model=self._model.world_model, - learning_rate=1e-4, - weight_decay=self._cfg.weight_decay, - device_type=self._cfg.device, - betas=(0.9, 0.95), - ) - - # use model_wrapper for specialized demands of different modes - self._target_model = copy.deepcopy(self._model) - - # TODO: torch 2.0 - self._model = torch.compile(self._model) - self._target_model = torch.compile(self._target_model) - - - # self._target_model = model_wrap( - # self._target_model, - # wrapper_name='target', - # update_type='assign', - # update_kwargs={'freq': self._cfg.target_update_freq} - # ) - # TODO: soft target - self._target_model = model_wrap( - self._target_model, - wrapper_name='target', - update_type='momentum', - # update_kwargs={'theta': 0.005} - update_kwargs={'theta': 0.01} # MOCO:0.001, DDPG:0.005, TD-MPC:0.01 - ) - self._learn_model = self._model - - # TODO: only for debug - # for param in self._learn_model.tokenizer.parameters(): - # param.requires_grad = False - - if self._cfg.use_augmentation: - self.image_transforms = ImageTransforms( - self._cfg.augmentation, - image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) - ) - self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) - self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) - self.inverse_scalar_transform_handle = InverseScalarTransform( - self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution - ) - - if self._cfg.use_rnd_model: - if self._cfg.target_model_for_intrinsic_reward_update_type == 'assign': - self._target_model_for_intrinsic_reward = model_wrap( - self._target_model, - wrapper_name='target', - update_type='assign', - update_kwargs={'freq': self._cfg.target_update_freq_for_intrinsic_reward} - ) - elif self._cfg.target_model_for_intrinsic_reward_update_type == 'momentum': - self._target_model_for_intrinsic_reward = model_wrap( 
- self._target_model, - wrapper_name='target', - update_type='momentum', - update_kwargs={'theta': self._cfg.target_update_theta_for_intrinsic_reward} - ) - - def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]: - """ - Overview: - The forward function for learning policy in learn mode, which is the core of the learning process. - The data is sampled from replay buffer. - The loss is calculated by the loss function and the loss is backpropagated to update the model. - Arguments: - - data (:obj:`Tuple[torch.Tensor]`): The data sampled from replay buffer, which is a tuple of tensors. - The first tensor is the current_batch, the second tensor is the target_batch. - Returns: - - info_dict (:obj:`Dict[str, Union[float, int]]`): The information dict to be logged, which contains \ - current learning loss and learning statistics. - """ - # current_batch, target_batch, train_which_component_dict = data - if data[-1]['train_which_component'] == 'transformer': - return_loss_dict = self._forward_learn_transformer(data) - elif data[-1]['train_which_component'] == 'tokenizer': - return_loss_dict = self._forward_learn_tokenizer(data) - else: - ValueError('Unknown component type') - - return return_loss_dict - - def monitor_weights_and_grads(self, model): - for name, param in model.named_parameters(): - if param.requires_grad: - print(f"Layer: {name} | " - f"Weight mean: {param.data.mean():.4f} | " - f"Weight std: {param.data.std():.4f} | " - f"Grad mean: {param.grad.mean():.4f} | " - f"Grad std: {param.grad.std():.4f}") - - - def _forward_learn_transformer(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]: - """ - Overview: - The forward function for learning policy in learn mode, which is the core of the learning process. - The data is sampled from replay buffer. - The loss is calculated by the loss function and the loss is backpropagated to update the model. - Arguments: - - data (:obj:`Tuple[torch.Tensor]`): The data sampled from replay buffer, which is a tuple of tensors. - The first tensor is the current_batch, the second tensor is the target_batch. - Returns: - - info_dict (:obj:`Dict[str, Union[float, int]]`): The information dict to be logged, which contains \ - current learning loss and learning statistics. - """ - - self._learn_model.train() - self._target_model.train() - # self._learn_model.tokenizer.train() - # self._eval_model.world_model.transformer.train() # TODO - - - if self._cfg.use_rnd_model: - self._target_model_for_intrinsic_reward.train() - - # current_batch, target_batch = data - current_batch, target_batch, train_which_component_dict = data - - - obs_batch_ori, action_batch, mask_batch, indices, weights, make_time = current_batch - target_reward, target_value, target_policy = target_batch - - if self._cfg.model.frame_stack_num == 4: - obs_batch, obs_target_batch = prepare_obs_for_gpt(obs_batch_ori, self._cfg) - else: - obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) - - - # do augmentations - if self._cfg.use_augmentation: - obs_batch = self.image_transforms.transform(obs_batch) - if self._cfg.model.self_supervised_learning_loss: - obs_target_batch = self.image_transforms.transform(obs_target_batch) - - # shape: (batch_size, num_unroll_steps, action_dim) - # NOTE: .long(), in discrete action space. 
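The target model above is wrapped with update_type='momentum' and theta=0.01, i.e. a soft (exponential moving average) update of the target parameters toward the online parameters. The exact wrapper semantics belong to DI-engine's model_wrap; the sketch below only shows the standard EMA rule, assuming theta weights the online parameters:

import torch

@torch.no_grad()
def ema_update(target_model: torch.nn.Module, online_model: torch.nn.Module, theta: float = 0.01) -> None:
    # target <- (1 - theta) * target + theta * online, applied parameter-wise.
    for t_param, o_param in zip(target_model.parameters(), online_model.parameters()):
        t_param.mul_(1.0 - theta).add_(o_param, alpha=theta)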
- action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(-1).long() - data_list = [ - mask_batch, - target_reward.astype('float32'), - target_value.astype('float32'), target_policy, weights - ] - [mask_batch, target_reward, target_value, target_policy, - weights] = to_torch_float_tensor(data_list, self._cfg.device) - - target_reward = target_reward.view(self._cfg.batch_size, -1) - target_value = target_value.view(self._cfg.batch_size, -1) - - assert obs_batch.size(0) == self._cfg.batch_size == target_reward.size(0) - - # ``scalar_transform`` to transform the original value to the scaled value, - # i.e. h(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. - transformed_target_reward = scalar_transform(target_reward) - transformed_target_value = scalar_transform(target_value) - - # transform a scalar to its categorical_distribution. After this transformation, each scalar is - # represented as the linear combination of its two adjacent supports. - target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) - target_value_categorical = phi_transform(self.value_support, transformed_target_value) - - # compute_loss(self, batch: Batch, tokenizer: Tokenizer, ** kwargs: Any) - - batch_for_gpt = {} - # TODO: for cartpole self._cfg.model.observation_shape - if isinstance(self._cfg.model.observation_shape, int) or len(self._cfg.model.observation_shape)==1: - batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( self._cfg.batch_size, -1, self._cfg.model.observation_shape) # (B, T, O) or (B, T, C, H, W) - elif len(self._cfg.model.observation_shape)==3: - batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( self._cfg.batch_size, -1, *self._cfg.model.observation_shape) # (B, T, O) or (B, T, C, H, W) - - - batch_for_gpt['actions'] = action_batch.squeeze(-1) # (B, T-1, A) -> (B, T-1) - - batch_for_gpt['rewards'] = target_reward_categorical[:, :-1] # (B, T, R) -> (B, T-1, R) - - batch_for_gpt['mask_padding'] = mask_batch == 1.0 # (B, T) NOTE: 0 means invalid padding data - batch_for_gpt['mask_padding'] = batch_for_gpt['mask_padding'][:, :-1] # (B, T-1) TODO - - - batch_for_gpt['observations'] = batch_for_gpt['observations'][:, :-1] # (B, T-1, O) or (B, T-1, C, H, W) - batch_for_gpt['ends'] = torch.zeros(batch_for_gpt['mask_padding'].shape, dtype=torch.long, device=self._cfg.device) # (B, T-1) - - batch_for_gpt['target_value'] = target_value_categorical[:, :-1] # (B, T-1, V) - batch_for_gpt['target_policy'] = target_policy[:, :-1] # (B, T-1, A) - # NOTE: TODO: next latent state's policy value - # batch_for_gpt['target_value'] = target_value_categorical[:, 1:] # (B, T-1, V) - # batch_for_gpt['target_policy'] = target_policy[:, 1:] # (B, T-1, A) - - # self._learn_model.world_model.train() - - # get valid target_policy data - valid_target_policy = batch_for_gpt['target_policy'][batch_for_gpt['mask_padding']] - # compute entropy of each policy - target_policy_entropy = -torch.sum(valid_target_policy * torch.log(valid_target_policy + 1e-9), dim=-1) - # compute average entropy - average_target_policy_entropy = target_policy_entropy.mean().item() - # print(f'Average entropy: {average_entropy}') - - - # if train_which_component_dict['train_which_component'] == 'transformer': - # ============================================================== - # update world model - # ============================================================== - intermediate_losses = defaultdict(float) - # losses = 
self._learn_model.world_model.compute_loss(batch_for_gpt, self._learn_model.tokenizer) - losses = self._learn_model.world_model.compute_loss(batch_for_gpt, self._target_model.world_model.tokenizer) - - weighted_total_loss = losses.loss_total - for loss_name, loss_value in losses.intermediate_losses.items(): - intermediate_losses[f"{loss_name}"] = loss_value - # print(intermediate_losses) - obs_loss = intermediate_losses['loss_obs'] - reward_loss = intermediate_losses['loss_rewards'] - policy_loss = intermediate_losses['loss_policy'] - value_loss = intermediate_losses['loss_value'] - latent_kl_loss = intermediate_losses['latent_kl_loss'] - latent_recon_loss = intermediate_losses['latent_recon_loss'] - perceptual_loss = intermediate_losses['perceptual_loss'] - - - # ============================================================== - # the core learn model update step. - # ============================================================== - """ - for name, parameter in self._learn_model.tokenizer.named_parameters(): - print(name) - """ - gradient_scale = 1 / self._cfg.num_unroll_steps - # TODO(pu): test the effect of gradient scale. - weighted_total_loss.register_hook(lambda grad: grad * gradient_scale) - self._optimizer_world_model.zero_grad() - weighted_total_loss.backward() - - # 在训练循环中使用 - # self.monitor_weights_and_grads(self._learn_model.tokenizer.representation_network) - # print('torch.cuda.memory_summary():', torch.cuda.memory_summary()) - - if self._cfg.multi_gpu: - self.sync_gradients(self._learn_model) - total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_(self._learn_model.world_model.parameters(), self._cfg.grad_clip_value) - # TODO - # total_grad_norm_before_clip_rep_net = torch.nn.utils.clip_grad_norm_(self._learn_model.tokenizer.representation_network.parameters(), max_norm=1.0) - - # print('total_grad_norm_before_clip_rep_net:', total_grad_norm_before_clip_rep_net) - - - self._optimizer_world_model.step() - if self._cfg.lr_piecewise_constant_decay: - self.lr_scheduler.step() - - - # ============================================================== - # the core target model update step. 
- # ============================================================== - self._target_model.update(self._learn_model.state_dict()) - if self._cfg.use_rnd_model: - self._target_model_for_intrinsic_reward.update(self._learn_model.state_dict()) - - - # 确保所有的CUDA核心完成工作,以便准确统计显存使用情况 - torch.cuda.synchronize() - # 获取当前分配的显存总量(字节) - current_memory_allocated = torch.cuda.memory_allocated() - # 获取程序运行到目前为止分配过的最大显存量(字节) - max_memory_allocated = torch.cuda.max_memory_allocated() - - # 将显存使用量从字节转换为GB - current_memory_allocated_gb = current_memory_allocated / (1024**3) - max_memory_allocated_gb = max_memory_allocated / (1024**3) - # 使用SummaryWriter记录当前和最大显存使用量 - - - return_loss_dict = { - 'Current_GPU': current_memory_allocated_gb, - 'Max_GPU': max_memory_allocated_gb, - 'collect_mcts_temperature': self._collect_mcts_temperature, - 'collect_epsilon': self.collect_epsilon, - 'cur_lr_world_model': self._optimizer_world_model.param_groups[0]['lr'], - 'cur_lr_tokenizer': self._optimizer_tokenizer.param_groups[0]['lr'], - - 'weighted_total_loss': weighted_total_loss.item(), - 'obs_loss': obs_loss, - 'latent_kl_loss': latent_kl_loss, - 'latent_recon_loss':latent_recon_loss, - 'perceptual_loss':perceptual_loss, - 'policy_loss': policy_loss, - 'target_policy_entropy': average_target_policy_entropy, - # 'policy_entropy': - policy_entropy_loss.mean().item() / (self._cfg.num_unroll_steps + 1), - 'reward_loss': reward_loss, - 'value_loss': value_loss, - # 'consistency_loss': consistency_loss.mean().item() / self._cfg.num_unroll_steps, - - # ============================================================== - # priority related - # ============================================================== - # 'value_priority_orig': value_priority, - 'value_priority_orig': np.zeros(self._cfg.batch_size), # TODO - # 'value_priority': value_priority.mean().item(), - 'target_reward': target_reward.mean().item(), - 'target_value': target_value.mean().item(), - 'transformed_target_reward': transformed_target_reward.mean().item(), - 'transformed_target_value': transformed_target_value.mean().item(), - # 'predicted_rewards': predicted_rewards.detach().cpu().numpy().mean().item(), - # 'predicted_values': predicted_values.detach().cpu().numpy().mean().item(), - 'total_grad_norm_before_clip_wm': total_grad_norm_before_clip_wm.item(), - 'total_grad_norm_before_clip_rep_net': total_grad_norm_before_clip_rep_net.item(), - } - - return return_loss_dict - - - def _forward_learn_tokenizer(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]: - """ - Overview: - The forward function for learning policy in learn mode, which is the core of the learning process. - The data is sampled from replay buffer. - The loss is calculated by the loss function and the loss is backpropagated to update the model. - Arguments: - - data (:obj:`Tuple[torch.Tensor]`): The data sampled from replay buffer, which is a tuple of tensors. - The first tensor is the current_batch, the second tensor is the target_batch. - Returns: - - info_dict (:obj:`Dict[str, Union[float, int]]`): The information dict to be logged, which contains \ - current learning loss and learning statistics. 
- """ - self._learn_model.train() - self._target_model.train() - if self._cfg.use_rnd_model: - self._target_model_for_intrinsic_reward.train() - - # current_batch, target_batch = data - current_batch, target_batch, train_which_component_dict = data - - - obs_batch_ori, action_batch, mask_batch, indices, weights, make_time = current_batch - target_reward, target_value, target_policy = target_batch - - obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) - - # do augmentations - if self._cfg.use_augmentation: - obs_batch = self.image_transforms.transform(obs_batch) - if self._cfg.model.self_supervised_learning_loss: - obs_target_batch = self.image_transforms.transform(obs_target_batch) - - # shape: (batch_size, num_unroll_steps, action_dim) - # NOTE: .long(), in discrete action space. - # action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(-1).long() - # data_list = [ - # mask_batch, - # target_reward.astype('float32'), - # target_value.astype('float32'), target_policy, weights - # ] - - # [mask_batch, target_reward, target_value, target_policy, - # weights] = to_torch_float_tensor(data_list, self._cfg.device) - - # target_reward = target_reward.view(self._cfg.batch_size, -1) - # target_value = target_value.view(self._cfg.batch_size, -1) - - # assert obs_batch.size(0) == self._cfg.batch_size == target_reward.size(0) - - - batch_for_gpt = {} - # TODO: for cartpole self._cfg.model.observation_shape - if isinstance(self._cfg.model.observation_shape, int) or len(self._cfg.model.observation_shape)==1: - batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( self._cfg.batch_size, -1, self._cfg.model.observation_shape) # (B, T, O) or (B, T, C, H, W) - elif len(self._cfg.model.observation_shape)==3: - batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( self._cfg.batch_size, -1, *self._cfg.model.observation_shape) # (B, T, O) or (B, T, C, H, W) - - batch_for_gpt['observations'] = batch_for_gpt['observations'][:, :-1] # (B, T-1, O) or (B, T-1, C, H, W) - - # if train_which_component_dict['train_which_component'] == 'tokenizer': - - # ============================================================== - # update tokenizer - # ============================================================== - # TODO: train tokenlizer - self._learn_model.tokenizer.train() - - # for name, param in self._learn_model.tokenizer.named_parameters(): - # if param.requires_grad: - # print(name, param.shape) - - losses_tokenizer = self._learn_model.tokenizer.compute_loss(batch_for_gpt) - - self._optimizer_tokenizer.zero_grad() - - weighted_total_loss_tokenizer = losses_tokenizer.loss_total - weighted_total_loss_tokenizer.backward() - # losses_tokenizer.loss_total.backward() - - total_grad_norm_before_clip_tokenizer = torch.nn.utils.clip_grad_norm_( - self._learn_model.tokenizer.parameters(), self._cfg.grad_clip_value - ) - - - self._optimizer_tokenizer.step() - - intermediate_losses_tokenizer= defaultdict(float) - for loss_name, loss_value in losses_tokenizer.intermediate_losses.items(): - intermediate_losses_tokenizer[f"{loss_name}"] = loss_value - # print(intermediate_losses) - commitment_loss= intermediate_losses_tokenizer['commitment_loss'] - reconstruction_loss = intermediate_losses_tokenizer['reconstruction_loss'] - perceptual_loss = intermediate_losses_tokenizer['perceptual_loss'] - - - # # ============================================================== - # # the core target model update step. 
- # # ============================================================== - # self._target_model.update(self._learn_model.state_dict()) - # if self._cfg.use_rnd_model: - # self._target_model_for_intrinsic_reward.update(self._learn_model.state_dict()) - - - return_loss_dict = { - 'collect_mcts_temperature': self._collect_mcts_temperature, - 'collect_epsilon': self.collect_epsilon, - 'cur_lr_world_model': self._optimizer_world_model.param_groups[0]['lr'], - 'cur_lr_tokenizer': self._optimizer_tokenizer.param_groups[0]['lr'], - - # 'weighted_total_loss': weighted_total_loss.item(), - # 'obs_loss': obs_loss, - # 'policy_loss': policy_loss, - # 'target_policy_entropy': average_target_policy_entropy, - # 'policy_entropy': - policy_entropy_loss.mean().item() / (self._cfg.num_unroll_steps + 1), - # 'reward_loss': reward_loss, - # 'value_loss': value_loss, - # 'consistency_loss': consistency_loss.mean().item() / self._cfg.num_unroll_steps, - - # ============================================================== - # priority related - # ============================================================== - # 'value_priority_orig': value_priority, - # 'value_priority_orig': np.zeros(self._cfg.batch_size), # TODO - # 'value_priority': value_priority.mean().item(), - # 'target_reward': target_reward.detach().cpu().numpy().mean().item(), - # 'target_value': target_value.detach().cpu().numpy().mean().item(), - # 'transformed_target_reward': transformed_target_reward.detach().cpu().numpy().mean().item(), - # 'transformed_target_value': transformed_target_value.detach().cpu().numpy().mean().item(), - # 'predicted_rewards': predicted_rewards.detach().cpu().numpy().mean().item(), - # 'predicted_values': predicted_values.detach().cpu().numpy().mean().item(), - 'total_grad_norm_before_clip_tokenizer': total_grad_norm_before_clip_tokenizer.item(), - 'commitment_loss':commitment_loss, - 'reconstruction_loss':reconstruction_loss, - 'perceptual_loss': perceptual_loss, - } - - return return_loss_dict - - def _init_collect(self) -> None: - """ - Overview: - Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. - """ - self._collect_model = self._model - - if self._cfg.mcts_ctree: - self._mcts_collect = MCTSCtree(self._cfg) - else: - self._mcts_collect = MCTSPtree(self._cfg) - self._collect_mcts_temperature = 1. - self.collect_epsilon = 0.0 - self.last_batch_obs = torch.zeros([8,self._cfg.model.observation_shape[0],64,64]).to(self._cfg.device) - self.last_batch_action = [-1 for i in range(8)] - - def _forward_collect( - self, - data: torch.Tensor, - action_mask: list = None, - temperature: float = 1, - to_play: List = [-1], - epsilon: float = 0.25, - ready_env_id: np.array = None, - ) -> Dict: - """ - Overview: - The forward function for collecting data in collect mode. Use model to execute MCTS search. - Choosing the action through sampling during the collect mode. - Arguments: - - data (:obj:`torch.Tensor`): The input data, i.e. the observation. - - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. - - temperature (:obj:`float`): The temperature of the policy. - - to_play (:obj:`int`): The player to play. - - ready_env_id (:obj:`list`): The id of the env that is ready to collect. - Shape: - - data (:obj:`torch.Tensor`): - - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ - S is the number of stacked frames, H is the height of the image, W is the width of the image. 
-                - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size.
-            - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env.
-            - temperature: :math:`(1, )`.
-            - to_play: :math:`(N, 1)`, where N is the number of collect_env.
-            - ready_env_id: None
-        Returns:
-            - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \
-                ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``.
-        """
-        self._collect_model.eval()
-        self._collect_model.tokenizer.eval()  # TODO
-        self._collect_model.world_model.transformer.eval()  # TODO
-
-
-        self._collect_mcts_temperature = temperature
-        self.collect_epsilon = epsilon
-        active_collect_env_num = data.shape[0]
-        # if active_collect_env_num == 1:
-        #     print('debug')
-        with torch.no_grad():
-
-            network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, data)
-
-            # network_output = self._collect_model.initial_inference(data)
-            # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)}
-            latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output)
-
-            pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy()
-            latent_state_roots = latent_state_roots.detach().cpu().numpy()
-            policy_logits = policy_logits.detach().cpu().numpy().tolist()
-
-            legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)]
-            # the only difference between collect and eval is the dirichlet noise
-            noises = [
-                np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j]))
-                                    ).astype(np.float32).tolist() for j in range(active_collect_env_num)
-            ]
-            if self._cfg.mcts_ctree:
-                # cpp mcts_tree
-                roots = MCTSCtree.roots(active_collect_env_num, legal_actions)
-            else:
-                # python mcts_tree
-                roots = MCTSPtree.roots(active_collect_env_num, legal_actions)
-
-            roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play)
-            self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play)
-
-            # list of list, shape: ``{list: batch_size} -> {list: action_space_size}``
-            roots_visit_count_distributions = roots.get_distributions()
-            roots_values = roots.get_values()  # shape: {list: batch_size}
-
-            data_id = [i for i in range(active_collect_env_num)]
-            output = {i: None for i in data_id}
-
-            if ready_env_id is None:
-                ready_env_id = np.arange(active_collect_env_num)
-
-            batch_action = []
-            for i, env_id in enumerate(ready_env_id):
-                distributions, value = roots_visit_count_distributions[i], roots_values[i]
-                if self._cfg.eps.eps_greedy_exploration_in_collect:
-                    # eps greedy collect
-                    action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
-                        distributions, temperature=self._collect_mcts_temperature, deterministic=True
-                    )
-                    action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
-                    if np.random.rand() < self.collect_epsilon:
-                        action = np.random.choice(legal_actions[i])
-                else:
-                    # normal collect
-                    # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents
-                    # the index within the legal action set, rather than the index in the entire action set.
-                    action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
-                        distributions, temperature=self._collect_mcts_temperature, deterministic=False
-                    )
-                    # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set.
-                    action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
-                output[env_id] = {
-                    'action': action,
-                    'visit_count_distributions': distributions,
-                    'visit_count_distribution_entropy': visit_count_distribution_entropy,
-                    'searched_value': value,
-                    'predicted_value': pred_values[i],
-                    'predicted_policy_logits': policy_logits[i],
-                }
-                batch_action.append(action)
-
-        self.last_batch_obs = data
-        self.last_batch_action = batch_action
-
-        return output
-
-
-    def _init_eval(self) -> None:
-        """
-        Overview:
-            Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils.
-        """
-        self._eval_model = self._model
-        if self._cfg.mcts_ctree:
-            self._mcts_eval = MCTSCtree(self._cfg)
-        else:
-            self._mcts_eval = MCTSPtree(self._cfg)
-
-    def _get_target_obs_index_in_step_k(self, step):
-        """
-        Overview:
-            Get the begin index and end index of the target obs in step k.
-        Arguments:
-            - step (:obj:`int`): The current step k.
-        Returns:
-            - beg_index (:obj:`int`): The begin index of the target obs in step k.
-            - end_index (:obj:`int`): The end index of the target obs in step k.
-        Examples:
-            >>> self._cfg.model.model_type = 'conv'
-            >>> self._cfg.model.image_channel = 3
-            >>> self._cfg.model.frame_stack_num = 4
-            >>> self._get_target_obs_index_in_step_k(0)
-            (0, 12)
-        """
-        if self._cfg.model.model_type == 'conv':
-            beg_index = self._cfg.model.image_channel * step
-            end_index = self._cfg.model.image_channel * (step + self._cfg.model.frame_stack_num)
-        elif self._cfg.model.model_type == 'mlp':
-            beg_index = self._cfg.model.observation_shape * step
-            end_index = self._cfg.model.observation_shape * (step + self._cfg.model.frame_stack_num)
-        return beg_index, end_index
-
-    def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id: np.array = None,) -> Dict:
-        """
-        Overview:
-            The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search.
-            Choosing the action with the highest value (argmax) rather than sampling during the eval mode.
-        Arguments:
-            - data (:obj:`torch.Tensor`): The input data, i.e. the observation.
-            - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected.
-            - to_play (:obj:`int`): The player to play.
-            - ready_env_id (:obj:`list`): The id of the env that is ready to collect.
-        Shape:
-            - data (:obj:`torch.Tensor`):
-                - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \
-                    S is the number of stacked frames, H is the height of the image, W is the width of the image.
-                - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size.
-            - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env.
-            - to_play: :math:`(N, 1)`, where N is the number of collect_env.
-            - ready_env_id: None
-        Returns:
-            - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \
-                ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``.
- """ - self._eval_model.eval() - self._eval_model.tokenizer.eval() # TODO - self._eval_model.world_model.transformer.eval() # TODO - - active_eval_env_num = data.shape[0] - with torch.no_grad(): - # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)} - # network_output = self._collect_model.initial_inference(data) - network_output = self._eval_model.initial_inference(data) - latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) - - if not self._eval_model.training: - # if not in training, obtain the scalars of the value/reward - pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) - latent_state_roots = latent_state_roots.detach().cpu().numpy() - policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) - - legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)] - if self._cfg.mcts_ctree: - # cpp mcts_tree - roots = MCTSCtree.roots(active_eval_env_num, legal_actions) - else: - # python mcts_tree - roots = MCTSPtree.roots(active_eval_env_num, legal_actions) - roots.prepare_no_noise(reward_roots, policy_logits, to_play) - self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play) - - # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` - roots_visit_count_distributions = roots.get_distributions() - roots_values = roots.get_values() # shape: {list: batch_size} - - data_id = [i for i in range(active_eval_env_num)] - output = {i: None for i in data_id} - - if ready_env_id is None: - ready_env_id = np.arange(active_eval_env_num) - - for i, env_id in enumerate(ready_env_id): - distributions, value = roots_visit_count_distributions[i], roots_values[i] - # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents - # the index within the legal action set, rather than the index in the entire action set. - # Setting deterministic=True implies choosing the action with the highest value (argmax) rather than - # sampling during the evaluation phase. - - action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( - distributions, temperature=1, deterministic=True - ) - # TODO: eval - # action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( - # distributions, temperature=self._collect_mcts_temperature, deterministic=False - # ) - # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the - # entire action set. - action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - - output[env_id] = { - 'action': action, - 'visit_count_distributions': distributions, - 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'searched_value': value, - 'predicted_value': pred_values[i], - 'predicted_policy_logits': policy_logits[i], - } - - return output - - def _monitor_vars_learn(self) -> List[str]: - """ - Overview: - Register the variables to be monitored in learn mode. The registered variables will be logged in - tensorboard according to the return value ``_forward_learn``. 
- """ - return [ - 'Current_GPU', - 'Max_GPU', - 'collect_epsilon', - 'collect_mcts_temperature', - # 'cur_lr', - 'cur_lr_world_model', - 'cur_lr_tokenizer', - - 'weighted_total_loss', - # 'total_loss', - 'obs_loss', - 'policy_loss', - 'latent_kl_loss', - 'latent_recon_loss', - # 'policy_entropy', - 'target_policy_entropy', - 'reward_loss', - 'value_loss', - 'consistency_loss', - 'value_priority', - 'target_reward', - 'target_value', - # 'predicted_rewards', - # 'predicted_values', - # 'transformed_target_reward', - # 'transformed_target_value', - 'total_grad_norm_before_clip_tokenizer', - 'total_grad_norm_before_clip_wm', - 'total_grad_norm_before_clip_rep_net', - # tokenizer - 'commitment_loss', - 'reconstruction_loss', - 'perceptual_loss', - ] - - def _state_dict_learn(self) -> Dict[str, Any]: - """ - Overview: - Return the state_dict of learn mode, usually including model, target_model and optimizer. - Returns: - - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring. - """ - return { - 'model': self._learn_model.state_dict(), - 'target_model': self._target_model.state_dict(), - 'optimizer_world_model': self._optimizer_world_model.state_dict(), - 'optimizer_tokenizer': self._optimizer_tokenizer.state_dict(), - } - - # TODO: - def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: - """ - Overview: - Load the state_dict variable into policy learn mode. - Arguments: - - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. - """ - self._learn_model.load_state_dict(state_dict['model']) - self._target_model.load_state_dict(state_dict['target_model']) - self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model']) - self._optimizer_tokenizer.load_state_dict(state_dict['optimizer_tokenizer']) - - # def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: - # """ - # Overview: - # Load the state_dict variable into policy learn mode, specifically loading only the - # representation network of the tokenizer within model and target_model. - # Arguments: - # - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. - # """ - # # Extract the relevant sub-state-dicts for representation_network from the state_dict - # # model_rep_network_state = state_dict['model']['tokenizer']['representation_network'] - # # target_model_rep_network_state = state_dict['target_model']['tokenizer']['representation_network'] - - # # # Load the state into the model's representation network - # # self._learn_model.tokenizer.representation_network.load_state_dict(model_rep_network_state) - # # self._target_model.tokenizer.representation_network.load_state_dict(target_model_rep_network_state) - - # # Assuming self._learn_model and self._target_model have a 'representation_network' submodule - # self._load_representation_network_state(state_dict['model'], self._learn_model.tokenizer.representation_network) - # self._load_representation_network_state(state_dict['target_model'], self._target_model.tokenizer.representation_network) - - - def _load_representation_network_state(self, state_dict, model_submodule): - """ - This function filters the state_dict to only include the state of the representation_network - and loads it into the given model submodule. 
- """ - from collections import OrderedDict - - # Filter the state_dict to only include keys that start with 'representation_network' - representation_network_keys = {k: v for k, v in state_dict.items() if k.startswith('representation_network')} - - # Load the state into the model's representation_network submodule - # model_submodule.load_state_dict(OrderedDict(representation_network_keys)) - - # 去掉键名前缀 - new_state_dict = OrderedDict() - for key, value in representation_network_keys.items(): - new_key = key.replace('representation_network.', '') # 去掉前缀 - new_state_dict[new_key] = value - - # # 如果模型在特定的设备上,确保状态字典也在那个设备上 - # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - # new_state_dict = {key: value.to(device) for key, value in new_state_dict.items()} - - # 尝试加载状态字典 - try: - # model_submodule.load_state_dict(new_state_dict) - # 使用 strict=False 参数忽略缺少的键 - model_submodule.load_state_dict(new_state_dict, strict=False) - except RuntimeError as e: - print("加载失败: ", e) - - - def _process_transition(self, obs, policy_output, timestep): - # be compatible with DI-engine Policy class - pass - - def _get_train_sample(self, data): - # be compatible with DI-engine Policy class - pass diff --git a/zoo/atari/config/atari_xzero_config_stack1.py b/zoo/atari/config/atari_xzero_config_stack1.py index 209ec1f3c..2f6d0f873 100644 --- a/zoo/atari/config/atari_xzero_config_stack1.py +++ b/zoo/atari/config/atari_xzero_config_stack1.py @@ -1,13 +1,15 @@ from easydict import EasyDict import torch -torch.cuda.set_device(0) +torch.cuda.set_device(4) + +# ==== NOTE: 需要设置cfg_atari中的action_shape ===== # options={'PongNoFrameskip-v4', 'QbertNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SpaceInvadersNoFrameskip-v4', 'BreakoutNoFrameskip-v4', ...} -env_name = 'PongNoFrameskip-v4' +# env_name = 'PongNoFrameskip-v4' # env_name = 'MsPacmanNoFrameskip-v4' # env_name = 'QbertNoFrameskip-v4' # env_name = 'SeaquestNoFrameskip-v4' -# env_name = 'BreakoutNoFrameskip-v4' # collect_env_steps=5e3 +env_name = 'BreakoutNoFrameskip-v4' # collect_env_steps=5e3 # env_name = 'BoxingNoFrameskip-v4' # env_name = 'FrostbiteNoFrameskip-v4' @@ -34,7 +36,7 @@ collector_env_num = 8 n_episode = 8 evaluator_env_num = 3 -# update_per_collect = 1000 # for pong boxing +update_per_collect = 1000 # for pong boxing update_per_collect = None # for others model_update_ratio = 0.25 @@ -62,7 +64,7 @@ # atari env action space # game_buffer_muzero_gpt task_id # TODO: muzero_gpt_model.py world_model.py (3,64,64) - exp_name=f'data_xzero_atari_0316/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_mcts-kvbatch-pad-min-quantize15-lsd768-nh8_simnorm_latentw10_pew1e-4_latent-groupkl_fixed-act-emb_nogradscale_seed0', + exp_name=f'data_xzero_atari_0316/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_mcts-kvbatch-pad-min-quantize15-lsd768-nh8_simnorm_latentw10_pew1e-4_latent-groupkl_fixed-act-emb_nogradscale_seed0_after-merge-memory', # exp_name=f'data_xzero_0312/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_new-rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_mcts-kvbatch-pad-min-quantize15-lsd768-nh8_simnorm_latentw10_pew1e-4_latent-groupkl_nogradscale_seed0', # 
exp_name=f'data_xzero_0307/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_new-rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_mcts-kv-reset-5-kvbatch-pad-min-quantize15-lsd768-nh8_fixroot_simnorm_latentw10_pew1e-4_seed0', @@ -86,8 +88,8 @@ # collect_max_episode_steps=int(50), # eval_max_episode_steps=int(50), # TODO: run - # collect_max_episode_steps=int(5e3), # for breakout - collect_max_episode_steps=int(2e4), # for others + collect_max_episode_steps=int(5e3), # for breakout + # collect_max_episode_steps=int(2e4), # for others eval_max_episode_steps=int(1e4), # eval_max_episode_steps=int(108000), clip_rewards=True, diff --git a/zoo/atari/config/atari_xzero_config_stack1_debug.py b/zoo/atari/config/atari_xzero_config_stack1_debug.py index 780dc05bd..5a36262f2 100644 --- a/zoo/atari/config/atari_xzero_config_stack1_debug.py +++ b/zoo/atari/config/atari_xzero_config_stack1_debug.py @@ -11,7 +11,7 @@ # env_name = 'BoxingNoFrameskip-v4' # env_name = 'FrostbiteNoFrameskip-v4' -# NOTE: 需要设置cfg_atari中的action_shape +# ==== NOTE: 需要设置cfg_atari中的action_shape ===== if env_name == 'PongNoFrameskip-v4': action_space_size = 6 elif env_name == 'QbertNoFrameskip-v4':