
Commit

polish(pu): use learned-act-embeddings, use latent mse loss, clear-per-200 in evaluator
puyuan1996 committed Mar 17, 2024
1 parent f15c85a commit 670cfa3
Showing 7 changed files with 74 additions and 37 deletions.
37 changes: 31 additions & 6 deletions lzero/entry/train_muzero_gpt.py
@@ -21,6 +21,17 @@
from .utils import random_collect
import torch.nn as nn

def initialize_zeros_batch(observation_shape, batch_size, device):
"""Initialize a zeros tensor for batch observations based on the shape."""
if isinstance(observation_shape, list):
shape = [batch_size, *observation_shape]
elif isinstance(observation_shape, int):
shape = [batch_size, observation_shape]
else:
raise TypeError("observation_shape must be either an int or a list")

return torch.zeros(shape).to(device)


def train_muzero_gpt(
input_cfg: Tuple[dict, dict],
@@ -133,8 +144,12 @@ def train_muzero_gpt(
num_unroll_steps = copy.deepcopy(replay_buffer._cfg.num_unroll_steps)
collect_cnt = -1

# policy.last_batch_obs = torch.zeros([len(evaluator_env_cfg), cfg.policy.model.observation_shape[0], 64, 64]).to(cfg.policy.device)
policy.last_batch_obs = torch.zeros([len(evaluator_env_cfg), cfg.policy.model.observation_shape]).to(cfg.policy.device)
# Usage
policy.last_batch_obs = initialize_zeros_batch(
cfg.policy.model.observation_shape,
len(evaluator_env_cfg),
cfg.policy.device
)
policy.last_batch_action = [-1 for _ in range(len(evaluator_env_cfg))]
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)

@@ -165,15 +180,21 @@ def train_muzero_gpt(

# Evaluate policy performance.
if evaluator.should_eval(learner.train_iter):
# policy.last_batch_obs = torch.zeros([len(evaluator_env_cfg), cfg.policy.model.observation_shape[0], 64, 64]).to(cfg.policy.device)
policy.last_batch_obs = torch.zeros([len(evaluator_env_cfg), cfg.policy.model.observation_shape]).to(cfg.policy.device)
policy.last_batch_obs = initialize_zeros_batch(
cfg.policy.model.observation_shape,
len(evaluator_env_cfg),
cfg.policy.device
)
policy.last_batch_action = [-1 for _ in range(len(evaluator_env_cfg))]
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
if stop:
break

policy.last_batch_obs = torch.zeros([len(collector_env_cfg), cfg.policy.model.observation_shape]).to(cfg.policy.device)
# policy.last_batch_obs = torch.zeros([len(collector_env_cfg), cfg.policy.model.observation_shape[0], 64, 64]).to(cfg.policy.device)
policy.last_batch_obs = initialize_zeros_batch(
cfg.policy.model.observation_shape,
len(collector_env_cfg),
cfg.policy.device
)
policy.last_batch_action = [-1 for _ in range(len(collector_env_cfg))]
# Collect data by default config n_sample/n_episode.
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
@@ -228,6 +249,10 @@ def train_muzero_gpt(
policy._collect_model.world_model.past_keys_values_cache.clear() # very important
policy._collect_model.world_model.keys_values_wm_list.clear() # TODO: only applicable to recurrent_inference() batch_pad

# policy._eval_model.world_model.past_keys_values_cache.clear() # very important
# policy._eval_model.world_model.keys_values_wm_list.clear() # TODO: only applicable to recurrent_inference() batch_pad


torch.cuda.empty_cache() # TODO: NOTE


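For reference, a minimal standalone sketch of how the new initialize_zeros_batch helper behaves for the two supported observation_shape types; the concrete shapes and the CPU device below are illustrative, not taken from this commit's configs.

import torch

def initialize_zeros_batch(observation_shape, batch_size, device):
    """Initialize a zeros tensor for batch observations based on the shape."""
    if isinstance(observation_shape, list):
        # Image-like observations, e.g. [3, 64, 64] -> [batch_size, 3, 64, 64].
        shape = [batch_size, *observation_shape]
    elif isinstance(observation_shape, int):
        # Flat vector observations, e.g. 8 -> [batch_size, 8].
        shape = [batch_size, observation_shape]
    else:
        raise TypeError("observation_shape must be either an int or a list")
    return torch.zeros(shape).to(device)

print(initialize_zeros_batch(8, 3, 'cpu').shape)            # torch.Size([3, 8])
print(initialize_zeros_batch([3, 64, 64], 3, 'cpu').shape)  # torch.Size([3, 3, 64, 64])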
20 changes: 8 additions & 12 deletions lzero/model/gpt_models/cfg_atari.py
@@ -64,25 +64,20 @@
# 'embed_dim':256, # TODO:for atari
# 'embed_dim':1024, # TODO:for atari
'embed_dim':768, # TODO:for atari
'group_size': 8, # NOTE

'attention': 'causal',
# 'num_layers': 10,# TODO:for atari
# 'num_layers': 2, # TODO:for atari debug
# 'num_heads': 4,
# 'num_layers': 1, # TODO:for atari debug
# 'num_heads': 1,

'num_layers': 2, # TODO:for atari debug
# 'num_heads': 4,
# 'num_layers': 2, # TODO:for atari debug
# 'num_layers': 6, # TODO:for atari debug
'num_layers': 12, # TODO:for atari debug
'num_heads': 8,


'embed_pdrop': 0.1,
'resid_pdrop': 0.1,
'attn_pdrop': 0.1,
"device": 'cuda:0',
"device": 'cuda:2',
# "device": 'cpu',
# 'support_size': 21,
'support_size': 601,

# 'action_shape': 18,# TODO:for multi-task
@@ -109,8 +104,9 @@
# 'policy_entropy_weight': 0,
'policy_entropy_weight': 1e-4,

'predict_latent_loss_type': 'group_kl', # 'mse'
# 'predict_latent_loss_type': 'mse', # 'mse'
# 'predict_latent_loss_type': 'group_kl',
'predict_latent_loss_type': 'mse',
'obs_type': 'image', # 'vector', 'image'

}
from easydict import EasyDict
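As a usage note, the fields toggled above (group_size, num_layers, device, predict_latent_loss_type) are read via attribute access once the dict is wrapped in EasyDict, as the trailing import suggests; a minimal sketch under that assumption, using only the values shown in this hunk (the variable name world_model_cfg is illustrative).

import torch
from easydict import EasyDict

world_model_cfg = EasyDict({
    'embed_dim': 768,
    'group_size': 8,
    'num_layers': 12,
    'num_heads': 8,
    'predict_latent_loss_type': 'mse',  # this commit switches from 'group_kl' to 'mse'
    'obs_type': 'image',
    'device': 'cuda:2' if torch.cuda.is_available() else 'cpu',
})

# Attribute access, as used when the config is consumed by the world model.
assert world_model_cfg.predict_latent_loss_type in ('mse', 'group_kl')
num_groups = world_model_cfg.embed_dim // world_model_cfg.group_size  # 768 / 8 = 96 groups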
11 changes: 5 additions & 6 deletions lzero/model/gpt_models/world_model.py
@@ -127,8 +127,8 @@ def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: Transformer
)

self.act_embedding_table = nn.Embedding(act_vocab_size, config.embed_dim)
# NOTE: For discrete actions, using a fixed act_embedding is more efficient; note that self.act_embedding_table.weight must not be all-zero initialized ####
self.act_embedding_table.weight.requires_grad = False
# NOTE: For discrete actions, using a fixed act_embedding may be more efficient early in training; note that self.act_embedding_table.weight must not be all-zero initialized ####
# self.act_embedding_table.weight.requires_grad = False

self.obs_per_embdding_dim = config.embed_dim # 16*64=1024

@@ -396,9 +396,9 @@ def refresh_keys_values_with_initial_latent_state_for_init_infer(self, latent_st
print('root_total_query_cnt:', self.root_total_query_cnt)
print(f'root_hit_ratio:{root_hit_ratio}')
print(f'root_hit find size {self.past_keys_values_cache[cache_key].size}')
if self.past_keys_values_cache[cache_key].size >= 7:
if self.past_keys_values_cache[cache_key].size >= self.config.max_tokens - 5:
print(f'==' * 20)
print(f'NOTE: root_hit find size >= 7')
print(f'NOTE: root_hit find size >= self.config.max_tokens - 5')
print(f'==' * 20)
# A deepcopy is needed here because matched_value is modified in place in the transformer's forward pass.
self.keys_values_wm_list.append(copy.deepcopy(self.to_device_for_kvcache(matched_value, 'cuda')))
@@ -744,12 +744,11 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer=None, **kwargs: Any) -
# Compute the prediction loss for the observations. Two options are provided here: MSE and Group KL.
if self.predict_latent_loss_type == 'mse':
# MSE loss: compare logits and labels directly.
loss_obs = torch.nn.functional.mse_loss(logits_observations, labels_observations.detach(), reduction='none').mean(-1)
loss_obs = torch.nn.functional.mse_loss(logits_observations, labels_observations, reduction='none').mean(-1) # labels_observations.detach() is redundant because it is computed earlier under with torch.no_grad()
elif self.predict_latent_loss_type == 'group_kl':
# Group KL loss: split the features into groups, then compute the KL divergence within each group.
batch_size, num_features = logits_observations.shape


logits_reshaped = logits_observations.reshape(batch_size, self.num_groups, self.group_size)
labels_reshaped = labels_observations.reshape(batch_size, self.num_groups, self.group_size)

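To make the two branches of compute_loss easier to compare, here is a self-contained sketch of the MSE and group-KL latent losses on random tensors. The actual KL computation is not shown in this hunk, so the group-KL branch below (log-softmax of the predictions against softmaxed targets via F.kl_div) is an assumption, not the repository's implementation.

import torch
import torch.nn.functional as F

batch_size, embed_dim, group_size = 4, 768, 8
num_groups = embed_dim // group_size

logits_observations = torch.randn(batch_size, embed_dim)
with torch.no_grad():
    # Targets are produced under no_grad, which is why .detach() is redundant above.
    labels_observations = torch.randn(batch_size, embed_dim)

# 'mse': elementwise MSE, averaged over the feature dimension -> one loss per sample.
loss_obs_mse = F.mse_loss(logits_observations, labels_observations, reduction='none').mean(-1)

# 'group_kl' (assumed form): reshape into groups and compute a per-group KL divergence.
logits_reshaped = logits_observations.reshape(batch_size, num_groups, group_size)
labels_reshaped = labels_observations.reshape(batch_size, num_groups, group_size)
loss_obs_group_kl = F.kl_div(
    logits_reshaped.log_softmax(-1),   # log-probabilities of the predictions
    labels_reshaped.softmax(-1),       # probabilities of the targets
    reduction='none',
).sum(-1).mean(-1)                     # sum over group_size, average over groups

print(loss_obs_mse.shape, loss_obs_group_kl.shape)  # both torch.Size([4])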
4 changes: 2 additions & 2 deletions lzero/model/muzero_gpt_model.py
@@ -120,8 +120,8 @@ def __init__(
from .gpt_models.tokenizer.tokenizer import Tokenizer
from .gpt_models.tokenizer.nets import Encoder, Decoder
# from .gpt_models.cfg_cartpole import cfg
from .gpt_models.cfg_memory import cfg
# from .gpt_models.cfg_atari import cfg
# from .gpt_models.cfg_memory import cfg # NOTE: TODO
from .gpt_models.cfg_atari import cfg
if cfg.world_model.obs_type == 'vector':
self.representation_network = RepresentationNetworkMLP(
observation_shape,
14 changes: 13 additions & 1 deletion lzero/worker/muzero_evaluator.py
@@ -257,7 +257,7 @@ def eval(

ready_env_id = set()
remain_episode = n_episode

eps_steps_lst = np.zeros(env_nums)
with self._timer:
while not eval_monitor.is_finished():
# Get current ready env obs.
@@ -322,6 +322,15 @@
for env_id, t in timesteps.items():
obs, reward, done, info = t.obs, t.reward, t.done, t.info

eps_steps_lst[env_id] += 1
if eps_steps_lst[env_id] % 200 == 0:
self._policy.get_attribute('eval_model').world_model.past_keys_values_cache.clear()
self._policy.get_attribute('eval_model').world_model.keys_values_wm_list.clear() # TODO: only applicable to recurrent_inference() batch_pad
torch.cuda.empty_cache() # TODO: NOTE
print('evaluator: eval_model clear()')
print(f'eps_steps_lst[{env_id}]:{eps_steps_lst[env_id]}')


game_segments[env_id].append(
actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id],
to_play_dict[env_id]
@@ -394,6 +403,9 @@ def eval(
]
)


eps_steps_lst[env_id] = 0

# Env reset is done by env_manager automatically.
self._policy.reset([env_id])
# TODO(pu): subprocess mode, when n_episode > self._env_num, occasionally the ready_env_id=()
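The evaluator change above clears the eval model's world-model caches every 200 environment steps per env and resets the counter when an episode ends. A minimal standalone sketch of that pattern, with a dummy object in place of the real world model (the class and function names here are illustrative):

import numpy as np

CLEAR_INTERVAL = 200  # matches the `% 200 == 0` check added in this commit

class DummyWorldModel:
    """Stand-in for the eval model's world model and its KV caches."""
    def __init__(self):
        self.past_keys_values_cache = {}
        self.keys_values_wm_list = []

env_nums = 3
eps_steps_lst = np.zeros(env_nums)
world_model = DummyWorldModel()

def on_env_step(env_id: int) -> None:
    # Count steps per env and periodically drop caches to bound memory growth.
    eps_steps_lst[env_id] += 1
    if eps_steps_lst[env_id] % CLEAR_INTERVAL == 0:
        world_model.past_keys_values_cache.clear()
        world_model.keys_values_wm_list.clear()
        # In the evaluator, torch.cuda.empty_cache() is also called at this point.

def on_episode_done(env_id: int) -> None:
    # Reset the per-env step counter when the episode ends, as the evaluator does.
    eps_steps_lst[env_id] = 0

for _ in range(450):
    on_env_step(0)   # caches are cleared at steps 200 and 400
on_episode_done(0)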
21 changes: 13 additions & 8 deletions zoo/atari/config/atari_xzero_config_stack1.py
@@ -1,15 +1,15 @@
from easydict import EasyDict
import torch
torch.cuda.set_device(4)
torch.cuda.set_device(2)

# ==== NOTE: action_shape in cfg_atari needs to be set =====

# options={'PongNoFrameskip-v4', 'QbertNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SpaceInvadersNoFrameskip-v4', 'BreakoutNoFrameskip-v4', ...}
# env_name = 'PongNoFrameskip-v4'
env_name = 'PongNoFrameskip-v4'
# env_name = 'MsPacmanNoFrameskip-v4'
# env_name = 'QbertNoFrameskip-v4'
# env_name = 'SeaquestNoFrameskip-v4'
env_name = 'BreakoutNoFrameskip-v4' # collect_env_steps=5e3
# env_name = 'BreakoutNoFrameskip-v4' # collect_env_steps=5e3
# env_name = 'BoxingNoFrameskip-v4'
# env_name = 'FrostbiteNoFrameskip-v4'

@@ -36,10 +36,12 @@
collector_env_num = 8
n_episode = 8
evaluator_env_num = 3
update_per_collect = 1000 # for pong boxing
# update_per_collect = 1000 # for pong boxing
update_per_collect = None # for others

model_update_ratio = 0.25
# model_update_ratio = 0.25 # for others
model_update_ratio = 0.125 # for pong boxing

num_simulations = 50

max_env_step = int(2e6)
@@ -64,7 +66,8 @@
# atari env action space
# game_buffer_muzero_gpt task_id
# TODO: muzero_gpt_model.py world_model.py (3,64,64)
exp_name=f'data_xzero_atari_0316/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_mcts-kvbatch-pad-min-quantize15-lsd768-nh8_simnorm_latentw10_pew1e-4_latent-groupkl_fixed-act-emb_nogradscale_seed0_after-merge-memory',
exp_name=f'data_xzero_atari_0316/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_mcts-kvbatch-pad-min-quantize15-lsd768-nh8_simnorm_latentw10_pew1e-4_latent-mse_learned-act-emb_nogradscale_seed0_after-merge-memory_nlayer12',
# exp_name=f'data_xzero_atari_0316/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_mcts-kvbatch-pad-min-quantize15-lsd768-nh8_simnorm_latentw10_pew1e-4_latent-mse_learned-act-emb_nogradscale_seed0_after-merge-memory_useaug',

# exp_name=f'data_xzero_0312/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_new-rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_mcts-kvbatch-pad-min-quantize15-lsd768-nh8_simnorm_latentw10_pew1e-4_latent-groupkl_nogradscale_seed0',
# exp_name=f'data_xzero_0307/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_new-rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_mcts-kv-reset-5-kvbatch-pad-min-quantize15-lsd768-nh8_fixroot_simnorm_latentw10_pew1e-4_seed0',
@@ -88,8 +91,8 @@
# collect_max_episode_steps=int(50),
# eval_max_episode_steps=int(50),
# TODO: run
collect_max_episode_steps=int(5e3), # for breakout
# collect_max_episode_steps=int(2e4), # for others
# collect_max_episode_steps=int(5e3), # for breakout
collect_max_episode_steps=int(2e4), # for others
eval_max_episode_steps=int(1e4),
# eval_max_episode_steps=int(108000),
clip_rewards=True,
@@ -156,6 +159,8 @@
decay=int(1e4), # 10k
),
use_augmentation=False, # NOTE
# use_augmentation=True, # NOTE: only for image-based atari

update_per_collect=update_per_collect,
model_update_ratio = model_update_ratio,
batch_size=batch_size,
4 changes: 2 additions & 2 deletions zoo/atari/config/atari_xzero_config_stack1_debug.py
@@ -66,7 +66,7 @@
eps_greedy_exploration_in_collect = False

# TODO: debug
num_simulations = 5
num_simulations = 2
update_per_collect = 5
batch_size = 2
# ==============================================================
@@ -183,7 +183,7 @@
),
# TODO: NOTE
# use_augmentation=True,
use_augmentation=False,
use_augmentation=True, # NOTE
update_per_collect=update_per_collect,
model_update_ratio = model_update_ratio,
batch_size=batch_size,
