Commit 1c13700

polish(pu): polish world model, use latent-groupkl-loss, no latent grad_scale, fixed-act-embedding

puyuan1996 committed Mar 16, 2024
1 parent a7d5d21 commit 1c13700
Showing 10 changed files with 1,430 additions and 744 deletions.
14 changes: 6 additions & 8 deletions lzero/entry/train_muzero_gpt.py
@@ -156,14 +156,20 @@ def train_muzero_gpt(
else:
collect_kwargs['epsilon'] = 0.0

# policy.last_batch_obs = torch.zeros([len(evaluator_env_cfg), cfg.policy.model.observation_shape[0], 64, 64]).to(cfg.policy.device)
# policy.last_batch_action = [-1 for _ in range(len(evaluator_env_cfg))]
# stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)

# Evaluate policy performance.
if evaluator.should_eval(learner.train_iter):
policy.last_batch_obs = torch.zeros([len(evaluator_env_cfg), cfg.policy.model.observation_shape[0], 64, 64]).to(cfg.policy.device)
policy.last_batch_action = [-1 for _ in range(len(evaluator_env_cfg))]
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
if stop:
break

policy.last_batch_obs = torch.zeros([len(collector_env_cfg), cfg.policy.model.observation_shape[0], 64, 64]).to(cfg.policy.device)
policy.last_batch_action = [-1 for _ in range(len(collector_env_cfg))]
# Collect data by default config n_sample/n_episode.
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
if cfg.policy.update_per_collect is None:
@@ -214,19 +220,11 @@ def train_muzero_gpt(
policy._target_model.world_model.keys_values_wm_list.clear() # TODO: only applicable to recurrent_inference() batch_pad
print('sample target_model past_keys_values_cache.clear()')

# del policy._learn_model.world_model.keys_values_wm

# TODO: for the batch world model, to improve KV cache reuse, we could avoid resetting:
# policy._learn_model.world_model.past_keys_values_cache.clear() # very important
# policy._eval_model.world_model.past_keys_values_cache.clear() # very important

policy._collect_model.world_model.past_keys_values_cache.clear() # very important
policy._collect_model.world_model.keys_values_wm_list.clear() # TODO: only applicable to recurrent_inference() batch_pad

torch.cuda.empty_cache() # TODO: NOTE

policy.last_batch_obs = torch.zeros([len(collector_env_cfg), cfg.policy.model.observation_shape[0], 64, 64]).to(cfg.policy.device)
policy.last_batch_action = [-1 for _ in range(len(collector_env_cfg))]

# if collector.envstep > 0:
# # TODO: only for debug
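For context, the hunk above inlines the same reset pattern in two places: zeroing policy.last_batch_obs / last_batch_action before evaluation and before collection, and clearing the world model's KV caches after each collection step. A minimal sketch of that pattern as a helper, assuming hypothetical names (reset_inference_state, obs_channels) that are not part of the repository:

import torch

def reset_inference_state(policy, env_num: int, obs_channels: int, device: str) -> None:
    # Fresh dummy observations/actions so the next rollout does not start from stale inputs.
    policy.last_batch_obs = torch.zeros([env_num, obs_channels, 64, 64]).to(device)
    policy.last_batch_action = [-1 for _ in range(env_num)]
    # Drop the transformer KV caches so stale keys/values are not reused across iterations.
    policy._collect_model.world_model.past_keys_values_cache.clear()
    policy._collect_model.world_model.keys_values_wm_list.clear()
    torch.cuda.empty_cache()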
43 changes: 0 additions & 43 deletions lzero/model/common.py
@@ -139,25 +139,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

return output

# EZ original
# def renormalize(inputs: torch.Tensor, first_dim: int = 1) -> torch.Tensor:
# """
# Overview:
# Normalize the input data using the max-min-normalization.
# Arguments:
# - inputs (:obj:`torch.Tensor`): The input data needs to be normalized.
# - first_dim (:obj:`int`): The first dimension of flattening the input data.
# Returns:
# - output (:obj:`torch.Tensor`): The normalized data.
# """
# if first_dim < 0:
# first_dim = len(inputs.shape) + first_dim
# flat_input = inputs.view(*inputs.shape[:first_dim], -1)
# max_val = torch.max(flat_input, first_dim, keepdim=True).values
# min_val = torch.min(flat_input, first_dim, keepdim=True).values
# flat_input = (flat_input - min_val) / (max_val - min_val)

# return flat_input.view(*input.shape)

def renormalize_min_max(x): # min-max
# x is a 2D tensor of shape (batch_size, num_features)
@@ -171,30 +152,6 @@ def renormalize_min_max(x): # min-max

return x_scaled

# def renormalize(x): # z-score
# # x is a 2D tensor of shape (batch_size, num_features)
# # Compute the mean and standard deviation for each feature across the batch
# mean = torch.mean(x, dim=0, keepdim=True)
# std = torch.std(x, dim=0, keepdim=True)

# # Apply z-score normalization
# x_normalized = (x - mean) / (std + 1e-8) # Add a small epsilon to avoid division by zero

# return x_normalized

# def renormalize(x): # robust scaling
# # x is a 2D tensor of shape (batch_size, num_features)
# # Compute the 1st and 3rd quartile
# q1 = torch.quantile(x, 0.25, dim=0, keepdim=True)
# q3 = torch.quantile(x, 0.75, dim=0, keepdim=True)

# # Compute the interquartile range (IQR)
# iqr = q3 - q1

# # Apply robust scaling
# x_scaled = (x - q1) / (iqr + 1e-8) # Again, add epsilon to avoid division by zero

# return x_scaled

class SimNorm(nn.Module):
"""
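The body of renormalize_min_max is collapsed in this view. A hedged sketch of a min-max variant, assuming the same batch-wise (dim=0) convention as the z-score and robust-scaling alternatives deleted above; the actual implementation in common.py may differ:

import torch

def renormalize_min_max_sketch(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # x: (batch_size, num_features); scale each feature into [0, 1] across the batch.
    min_val = torch.min(x, dim=0, keepdim=True).values
    max_val = torch.max(x, dim=0, keepdim=True).values
    return (x - min_val) / (max_val - min_val + eps)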
16 changes: 10 additions & 6 deletions lzero/model/gpt_models/cfg_atari.py
@@ -81,17 +81,17 @@
'embed_pdrop': 0.1,
'resid_pdrop': 0.1,
'attn_pdrop': 0.1,
"device": 'cuda:0',
"device": 'cuda:3',
# "device": 'cpu',
# 'support_size': 21,
'support_size': 601,

# 'action_shape': 18,# TODO:for multi-task

# 'action_shape': 18,# TODO:for Seaquest boxing
# 'action_shape': 18,# TODO:for Seaquest boxing Frostbite
# 'action_shape': 9,# TODO:for mspacman
# 'action_shape': 4,# TODO:for breakout
'action_shape': 6,# TODO:for pong qbert
'action_shape': 4,# TODO:for breakout
# 'action_shape': 6,# TODO:for pong qbert

'max_cache_size':5000,
# 'max_cache_size':50000,
@@ -106,8 +106,12 @@

# 'latent_recon_loss_weight':0.,
# 'perceptual_loss_weight':0.,
'policy_entropy_weight': 0,
# 'policy_entropy_weight': 1e-4,

# 'policy_entropy_weight': 0,
'policy_entropy_weight': 1e-4,

'predict_latent_loss_type': 'group_kl', # 'mse'
# 'predict_latent_loss_type': 'mse', # 'mse'

}
from easydict import EasyDict
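The new 'predict_latent_loss_type': 'group_kl' option selects a group-wise KL objective for the latent prediction loss in place of MSE. A minimal sketch of such a loss, assuming the latent is split into fixed-size softmax groups; the group size, tensor shapes, and the stop-gradient on the target are assumptions, not read from this diff:

import torch
import torch.nn.functional as F

def group_kl_latent_loss(pred_latent: torch.Tensor, target_latent: torch.Tensor, group_size: int = 8) -> torch.Tensor:
    # pred_latent, target_latent: (batch, embed_dim), with embed_dim divisible by group_size.
    b, d = pred_latent.shape
    pred = pred_latent.view(b, d // group_size, group_size)
    target = target_latent.detach().view(b, d // group_size, group_size)  # no gradient through the target
    # Treat each group as a small categorical distribution and compare with KL(target || pred).
    log_p = F.log_softmax(pred, dim=-1)
    q = F.softmax(target, dim=-1)
    return F.kl_div(log_p, q, reduction='none').sum(dim=-1).mean()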
44 changes: 0 additions & 44 deletions lzero/model/gpt_models/utils.py
@@ -11,44 +11,6 @@
from .episode import Episode


# def configure_optimizer(model, learning_rate, weight_decay, *blacklist_module_names):
# """Credits to https://github.com/karpathy/minGPT"""
# # separate out all parameters to those that will and won't experience regularizing weight decay
# decay = set()
# no_decay = set()
# whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv1d)
# blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
# for mn, m in model.named_modules():
# for pn, p in m.named_parameters():
# fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
# if any([fpn.startswith(module_name) for module_name in blacklist_module_names]):
# no_decay.add(fpn)
# elif 'bias' in pn:
# # all biases will not be decayed
# no_decay.add(fpn)
# elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
# # weights of whitelist modules will be weight decayed
# decay.add(fpn)
# elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
# # weights of blacklist modules will NOT be weight decayed
# no_decay.add(fpn)

# # validate that we considered every parameter
# param_dict = {pn: p for pn, p in model.named_parameters()}
# inter_params = decay & no_decay
# union_params = decay | no_decay
# assert len(inter_params) == 0, f"parameters {str(inter_params)} made it into both decay/no_decay sets!"
# assert len(
# param_dict.keys() - union_params) == 0, f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!"

# # create the pytorch optimizer object
# optim_groups = [
# {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay},
# {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
# ]
# optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate)
# return optimizer

from lzero.model.common import RepresentationNetwork

def init_weights(module):
@@ -114,7 +76,6 @@ def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, **kwarg
# # self.ends_loss_weight = 1.
# self.ends_loss_weight = 0.

# self.obs_loss_weight = 0.1
self.obs_loss_weight = 10
# self.obs_loss_weight = 2

@@ -125,9 +86,6 @@ def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, **kwarg
# self.ends_loss_weight = 1.
self.ends_loss_weight = 0.

# self.latent_kl_loss_weight = 0.1 # for lunarlander
self.latent_kl_loss_weight = 0. # for lunarlander

self.latent_recon_loss_weight = latent_recon_loss_weight
self.perceptual_loss_weight = perceptual_loss_weight
# self.latent_recon_loss_weight = 0.1
@@ -146,8 +104,6 @@ def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, **kwarg
self.loss_total += self.value_loss_weight * v
elif k == 'loss_ends':
self.loss_total += self.ends_loss_weight * v
elif k == 'latent_kl_loss':
self.loss_total += self.latent_kl_loss_weight * v
elif k == 'latent_recon_loss':
self.loss_total += self.latent_recon_loss_weight * v
elif k == 'perceptual_loss':
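After this hunk, latent_kl_loss no longer contributes to loss_total; only the weighted branches shown above remain. A condensed sketch of the accumulation pattern (loss names and weight values are taken from the visible snippets; everything else is an assumption):

import torch

def weighted_total_loss(intermediate_losses: dict, weights: dict) -> torch.Tensor:
    # Each intermediate loss is scaled by its weight and summed; terms with weight 0
    # (e.g. loss_ends, and latent_kl_loss after this commit) effectively drop out.
    total = torch.tensor(0.0)
    for name, value in intermediate_losses.items():
        total = total + weights.get(name, 0.0) * value
    return total

# Example weights matching the snippets above: {'loss_obs': 10, 'loss_ends': 0.0,
# 'latent_recon_loss': 0.0, 'perceptual_loss': 0.0}; reward/value/policy weights sit elsewhere in the file.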
(Diffs for the remaining 6 changed files were not loaded in this view.)
