fix(pu): fix memory_eval.py

puyuan1996 committed Mar 21, 2024
1 parent eb850d2 commit bf26548
Showing 12 changed files with 97 additions and 61 deletions.
11 changes: 9 additions & 2 deletions lzero/entry/eval_muzero.py
@@ -13,7 +13,7 @@
from ding.utils import set_pkg_seed
from ding.worker import BaseLearner
from lzero.worker import MuZeroEvaluator

+from lzero.entry.train_muzero_gpt import initialize_zeros_batch

def eval_muzero(
input_cfg: Tuple[dict, dict],
@@ -38,7 +38,7 @@ def eval_muzero(
- policy (:obj:`Policy`): Converged policy.
"""
cfg, create_cfg = input_cfg
-    assert create_cfg.policy.type in ['efficientzero', 'muzero', 'stochastic_muzero', 'gumbel_muzero', 'sampled_efficientzero'], \
+    assert create_cfg.policy.type in ['efficientzero', 'muzero', 'muzero_gpt', 'stochastic_muzero', 'gumbel_muzero', 'sampled_efficientzero'], \
"LightZero now only support the following algo.: 'efficientzero', 'muzero', 'stochastic_muzero', 'gumbel_muzero', 'sampled_efficientzero'"

if cfg.policy.cuda and torch.cuda.is_available():
@@ -85,6 +85,13 @@ def eval_muzero(
# Learner's before_run hook.
learner.call_hook('before_run')

+    policy.last_batch_obs = initialize_zeros_batch(
+        cfg.policy.model.observation_shape,
+        len(evaluator_env_cfg),
+        cfg.policy.device
+    )
+    policy.last_batch_action = [-1 for _ in range(len(evaluator_env_cfg))]

while True:
# ==============================================================
# eval trained model
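Note: the fix above resets policy.last_batch_obs and policy.last_batch_action to the evaluator's batch size before the eval loop starts, so the GPT world model's recurrent inference does not begin from stale collector-sized buffers. A minimal sketch of what the imported initialize_zeros_batch plausibly does, inferred from this call site (see lzero/entry/train_muzero_gpt.py for the actual definition):

from typing import List, Union

import torch

def initialize_zeros_batch(observation_shape: Union[int, List[int]], batch_size: int, device: str) -> torch.Tensor:
    # e.g. (batch_size, 25) for vector obs, (batch_size, 3, 64, 64) for image obs
    if isinstance(observation_shape, (list, tuple)):
        shape = [batch_size, *observation_shape]
    else:
        shape = [batch_size, observation_shape]
    return torch.zeros(shape, device=device)

The action buffer is reset in the same spirit: a -1 per evaluator env marks "no previous action" before the first step.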
4 changes: 2 additions & 2 deletions lzero/mcts/tree_search/mcts_ctree.py
@@ -157,9 +157,9 @@ def search(
MCTS stage 3: Backup
At the end of the simulation, the statistics along the trajectory are updated.
"""
-        network_output = model.recurrent_inference(latent_states, last_actions)  # for classic muzero
+        # network_output = model.recurrent_inference(latent_states, last_actions)  # for classic muzero
        # network_output = model.recurrent_inference(last_actions) # TODO: for muzero_gpt latent_states is not used in the model.
-        # network_output = model.recurrent_inference(state_action_history) # TODO: latent_states is not used in the model.
+        network_output = model.recurrent_inference(state_action_history) # TODO: latent_states is not used in the model.

network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state)
network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits)
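Note: the search step now feeds the transformer world model the accumulated state-action history rather than only the latest (latent_state, action) pair, which is what a KV-cached GPT dynamics model consumes. A hedged sketch of the calling pattern (variable names hypothetical):

# classic MuZero: one-step dynamics from the latest latent state
# network_output = model.recurrent_inference(latent_states, last_actions)

# GPT world model: condition on the whole trajectory so far; a separate
# latent_states argument is unnecessary because the history already encodes it
state_action_history.append((latent_states, last_actions))
network_output = model.recurrent_inference(state_action_history)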
10 changes: 5 additions & 5 deletions lzero/model/gpt_models/cfg_atari.py
@@ -42,16 +42,16 @@

'attention': 'causal',

-    # 'num_layers': 2, # TODO:for atari debug
-    'num_layers': 4, # TODO:for atari debug
+    'num_layers': 2, # TODO:for atari debug
+    # 'num_layers': 4, # TODO:for atari debug
# 'num_layers': 6, # TODO:for atari debug
# 'num_layers': 12, # TODO:for atari debug
'num_heads': 8,

'embed_pdrop': 0.1,
'resid_pdrop': 0.1,
'attn_pdrop': 0.1,
"device": 'cuda:4',
"device": 'cuda:7',
# "device": 'cpu',
'support_size': 601,

@@ -79,8 +79,8 @@
# 'policy_entropy_weight': 0,
'policy_entropy_weight': 1e-4,

-    # 'predict_latent_loss_type': 'group_kl',
-    'predict_latent_loss_type': 'mse',
+    'predict_latent_loss_type': 'group_kl',
+    # 'predict_latent_loss_type': 'mse',
'obs_type': 'image', # 'vector', 'image'

}
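Note: flipping predict_latent_loss_type from 'mse' back to 'group_kl' matches the simnorm-style latent used in these configs (group_size: 8 in cfg_memory below): the latent vector is split into groups, each group is softmax-normalized, and the prediction loss becomes a per-group KL divergence rather than an elementwise MSE. A rough sketch under those assumptions (not necessarily the repository's exact implementation):

import torch
import torch.nn.functional as F

def group_kl_loss(pred: torch.Tensor, target: torch.Tensor, group_size: int = 8) -> torch.Tensor:
    # pred/target: (batch, latent_dim), with latent_dim divisible by group_size
    b, d = pred.shape
    log_p = F.log_softmax(pred.view(b, d // group_size, group_size), dim=-1)
    q = F.softmax(target.view(b, d // group_size, group_size), dim=-1)
    # KL(q || p), summed within each group, averaged over groups and batch
    return F.kl_div(log_p, q, reduction='none').sum(-1).mean()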
26 changes: 17 additions & 9 deletions lzero/model/gpt_models/cfg_memory.py
@@ -11,17 +11,25 @@
'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16],
'out_ch': 3, 'dropout': 0.0}} # TODO:for atari debug
cfg['world_model'] = {
-    # 'tokens_per_block': 2, # memory_length = 2
+    'tokens_per_block': 2,

# 'max_blocks': 32,
# "max_tokens": 2 * 32, # TODO: horizon
# "max_tokens": 2 * 32, # memory_length = 2

+    'max_blocks': 60, # memory_length = 30
+    "max_tokens": 2 * 60,

+    # 'max_blocks': 80, # memory_length = 50
+    # "max_tokens": 2 * 80,

+    # 'max_blocks': 130, # memory_length = 100
+    # "max_tokens": 2 * 130,

-    # 'tokens_per_block': 2, # memory_length = 30
-    # 'max_blocks': 60,
-    # "max_tokens": 2 * 60, # TODO: horizon
+    # 'max_blocks': 280, # memory_length = 250
+    # "max_tokens": 2 * 280,

-    'tokens_per_block': 2, # memory_length = 50
-    'max_blocks': 80,
-    "max_tokens": 2 * 80, # TODO: horizon
+    # 'max_blocks': 530, # memory_length = 250
+    # "max_tokens": 2 * 530,

'embed_dim': 64, # TODO:for memory # same as <Transformer shine in RL> paper
'group_size': 8, # NOTE
@@ -34,7 +42,7 @@
'embed_pdrop': 0.1,
'resid_pdrop': 0.1,
'attn_pdrop': 0.1,
"device": 'cuda:6',
"device": 'cuda:0',
'support_size': 21,
'action_shape': 4, # NOTE:for memory
'max_cache_size': 5000,
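Note: the variants above all obey max_tokens = tokens_per_block * max_blocks (each block carries one observation token and one action token) and max_blocks = memory_length + 30 (the memory corridor plus the fixed explore/reward phases of visual_match and key_to_door); by that rule the 'max_blocks': 530 variant corresponds to memory_length = 500, so its inline comment saying 250 is presumably a typo. A hypothetical helper that derives the sizes instead of hard-coding them:

def memory_world_model_sizes(memory_length: int, tokens_per_block: int = 2) -> dict:
    # one (obs, action) token pair per block; 30 steps of explore+reward overhead
    max_blocks = memory_length + 30
    return {
        'tokens_per_block': tokens_per_block,
        'max_blocks': max_blocks,
        'max_tokens': tokens_per_block * max_blocks,
    }

# memory_world_model_sizes(30) -> {'tokens_per_block': 2, 'max_blocks': 60, 'max_tokens': 120}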
17 changes: 9 additions & 8 deletions lzero/worker/muzero_collector.py
@@ -522,14 +522,15 @@ def collect(self,

eps_steps_lst[env_id] += 1

-                # if eps_steps_lst[env_id] % 200 == 0:
-                #     self._policy.get_attribute('collect_model').world_model.past_keys_values_cache.clear()
-                #     self._policy.get_attribute('collect_model').world_model.keys_values_wm_list.clear()  # TODO: only applies to recurrent_inference() batch_pad
-                #     # self._policy._learn_model.world_model.past_keys_values_cache.clear()  # very important
-                #     # del self._policy.get_attribute('collect_model').world_model.keys_values_wm
-                #     torch.cuda.empty_cache()  # TODO: NOTE
-                #     print('collector: collect_model clear()')
-                #     print(f'eps_steps_lst[{env_id}]:{eps_steps_lst[env_id]}')
+                if eps_steps_lst[env_id] % 200 == 0:  # TODO: NOTE
+                # if eps_steps_lst[env_id] % 20 == 0:
+                    self._policy.get_attribute('collect_model').world_model.past_keys_values_cache.clear()
+                    self._policy.get_attribute('collect_model').world_model.keys_values_wm_list.clear()  # TODO: only applies to recurrent_inference() batch_pad
+                    # self._policy._learn_model.world_model.past_keys_values_cache.clear()  # very important
+                    # del self._policy.get_attribute('collect_model').world_model.keys_values_wm
+                    torch.cuda.empty_cache()
+                    print('collector: collect_model clear()')
+                    print(f'eps_steps_lst[{env_id}]:{eps_steps_lst[env_id]}')

total_transitions += 1

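Note: this re-enables periodic cleanup of the collect model's KV cache (the evaluator below gets the same treatment): every 200 environment steps the cached keys/values are dropped and the CUDA caching allocator is flushed, trading an occasional cache rebuild for bounded GPU memory during long rollouts. The pattern, distilled (hypothetical helper name):

import torch

CLEAR_INTERVAL = 200  # steps between cache flushes; tune to the GPU memory budget

def maybe_clear_world_model_cache(world_model, step: int) -> None:
    if step > 0 and step % CLEAR_INTERVAL == 0:
        world_model.past_keys_values_cache.clear()  # per-episode KV entries
        world_model.keys_values_wm_list.clear()     # batch-padded KV list
        torch.cuda.empty_cache()                    # release cached allocator blocks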
13 changes: 7 additions & 6 deletions lzero/worker/muzero_evaluator.py
@@ -323,12 +323,13 @@ def eval(
obs, reward, done, info = t.obs, t.reward, t.done, t.info

eps_steps_lst[env_id] += 1
-            # if eps_steps_lst[env_id] % 200 == 0:
-            #     self._policy.get_attribute('eval_model').world_model.past_keys_values_cache.clear()
-            #     self._policy.get_attribute('eval_model').world_model.keys_values_wm_list.clear()  # TODO: only applies to recurrent_inference() batch_pad
-            #     torch.cuda.empty_cache()  # TODO: NOTE
-            #     print('evaluator: eval_model clear()')
-            #     print(f'eps_steps_lst[{env_id}]:{eps_steps_lst[env_id]}')
+            if eps_steps_lst[env_id] % 200 == 0:  # TODO: NOTE
+            # if eps_steps_lst[env_id] % 20 == 0:
+                self._policy.get_attribute('eval_model').world_model.past_keys_values_cache.clear()
+                self._policy.get_attribute('eval_model').world_model.keys_values_wm_list.clear()  # TODO: only applies to recurrent_inference() batch_pad
+                torch.cuda.empty_cache()
+                print('evaluator: eval_model clear()')
+                print(f'eps_steps_lst[{env_id}]:{eps_steps_lst[env_id]}')


game_segments[env_id].append(
19 changes: 11 additions & 8 deletions zoo/atari/config/atari_xzero_config_stack1.py
@@ -1,6 +1,6 @@
from easydict import EasyDict
import torch
-torch.cuda.set_device(4)
+torch.cuda.set_device(7)

# ==== NOTE: action_shape must be set in cfg_atari =====

@@ -46,14 +46,17 @@

num_simulations = 50

-max_env_step = int(1.5e6)
+# max_env_step = int(1.5e6)
+max_env_step = int(2e6)

reanalyze_ratio = 0.
# reanalyze_ratio = 0.05 # TODO

batch_size = 64
num_unroll_steps = 5
+# num_unroll_steps = 10

+threshold_training_steps_for_final_temperature = int(2e5) # train_iter 100k 1->0.5->0.25
# eps_greedy_exploration_in_collect = True # for breakout
eps_greedy_exploration_in_collect = False
# ==============================================================
@@ -68,7 +71,7 @@
# atari env action space
# game_buffer_muzero_gpt task_id
# TODO: muzero_gpt_model.py world_model.py (3,64,64)
-exp_name=f'data_xzero_atari_0318/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_mcts-kvbatch-pad-min-quantize15-lsd768-nh8_simnorm_latentw10_pew1e-4_latent-mse_no-aug_no-priority_nlayer4_seed0',
+exp_name=f'data_xzero_atari_0321/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_mcts-kvbatch-pad-min-quantize15-lsd768-nh8_simnorm_latentw10_pew1e-4_latent-groupkl_use-aug_no-priority_nlayer2_temp-final-steps-{threshold_training_steps_for_final_temperature}_seed0',

# exp_name=f'data_xzero_atari_0316/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_mcts-kvbatch-pad-min-quantize15-lsd768-nh8_simnorm_latentw10_pew1e-4_latent-mse_learned-act-emb_nogradscale_seed0_after-merge-memory_priority',
# exp_name=f'data_xzero_atari_0316/{env_name[:-14]}_xzero_envnum{collector_env_num}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_stack1_mcts-kvbatch-pad-min-quantize15-lsd768-nh8_simnorm_latentw10_pew1e-4_latent-mse_learned-act-emb_nogradscale_seed0_after-merge-memory_useaug',
@@ -150,8 +153,8 @@
),
use_priority=False,
# use_priority=True, # NOTE
-        use_augmentation=False, # NOTE
-        # use_augmentation=True, # NOTE: only for image-based atari
+        # use_augmentation=False, # NOTE
+        use_augmentation=True, # NOTE: only for image-based atari
cuda=True,
env_type='not_board_games',
game_segment_length=400,
@@ -174,8 +177,8 @@
# lr_piecewise_constant_decay=True,
# learning_rate=0.2,

-        # manual_temperature_decay=True,
-        # threshold_training_steps_for_final_temperature=int(5e4), # 100k 1->0.5->0.25
+        manual_temperature_decay=True,
+        threshold_training_steps_for_final_temperature=threshold_training_steps_for_final_temperature,

optim_type='Adam',
lr_piecewise_constant_decay=False,
@@ -186,7 +189,7 @@
reanalyze_ratio=reanalyze_ratio,
n_episode=n_episode,
# eval_freq=int(9e9),
-        eval_freq=int(1e4),
+        eval_freq=int(2e4),
replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions.
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
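Note: the config now enables manual_temperature_decay with a 2e5-iteration horizon; per the inline comment, the visit-count temperature steps down 1 -> 0.5 -> 0.25 as training progresses. A sketch of that piecewise schedule (hedged; the exact breakpoints live in the LightZero policy code):

def visit_count_temperature(trained_steps: int, threshold: int = int(2e5)) -> float:
    # exploration temperature for turning MCTS visit counts into an action distribution
    if trained_steps < 0.5 * threshold:
        return 1.0
    elif trained_steps < threshold:
        return 0.5
    return 0.25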
25 changes: 17 additions & 8 deletions zoo/memory/config/memory_muzero_config.py
@@ -1,9 +1,11 @@
from easydict import EasyDict
import torch
-torch.cuda.set_device(7)
+torch.cuda.set_device(4)

env_id = 'visual_match' # The name of the environment, options: 'visual_match', 'key_to_door'
-memory_length = 2
+# env_id = 'key_to_door' # The name of the environment, options: 'visual_match', 'key_to_door'
+
+memory_length = 250
# to_test [2, 30, 50, 100]
# hard [250, 500, 750, 1000]

@@ -15,14 +17,14 @@
seed = 0
collector_env_num = 8
n_episode = 8
-evaluator_env_num = 3
+evaluator_env_num = 10
num_simulations = 50
update_per_collect = None # for others
model_update_ratio = 0.25
batch_size = 256
reanalyze_ratio = 0
td_steps = 5
-num_unroll_steps = 30+memory_length
+game_segment_length = 30+memory_length

# debug
# collector_env_num = 1
@@ -42,9 +44,9 @@

memory_muzero_config = dict(
# mcts_ctree.py muzero_collector muzero_evaluator
-    exp_name=f'data_memory/{env_id}_memlen-{memory_length}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_'
+    exp_name=f'data_memory_{env_id}/memlen-{memory_length}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_'
             f'collect-eps-{eps_greedy_exploration_in_collect}_temp-final-steps-{threshold_training_steps_for_final_temperature}'
-             f'_pelw{policy_entropy_loss_weight}_seed{seed}',
+             f'_pelw{policy_entropy_loss_weight}_seed{seed}_evalnum{evaluator_env_num}',
env=dict(
stop_value=int(1e6),
env_id=env_id,
@@ -60,6 +62,13 @@
manager=dict(shared_memory=False, ),
),
policy=dict(
+        learner=dict(
+            hook=dict(
+                log_show_after_iter=200,
+                save_ckpt_after_iter=100000,  # TODO: default:10000
+                save_ckpt_after_run=True,
+            ),
+        ),
model=dict(
observation_shape=25,
action_space_size=4,
@@ -79,7 +88,7 @@
threshold_training_steps_for_final_temperature=threshold_training_steps_for_final_temperature,
cuda=True,
env_type='not_board_games',
-        game_segment_length=num_unroll_steps,
+        game_segment_length=game_segment_length,
update_per_collect=update_per_collect,
batch_size=batch_size,
optim_type='Adam',
@@ -89,7 +98,7 @@
num_simulations=num_simulations,
reanalyze_ratio=reanalyze_ratio,
n_episode=n_episode,
-        eval_freq=int(2e3),
+        eval_freq=int(5e3),
replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions.
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
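Note: with memory_length = 250, game_segment_length = 30 + memory_length = 280 steps, so a single game segment spans an entire episode (explore phase, 250-step memory corridor, reward phase); this mirrors the max_blocks = 280 variant in cfg_memory.py above.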
13 changes: 8 additions & 5 deletions zoo/memory/config/memory_xzero_config.py
@@ -1,9 +1,11 @@
from easydict import EasyDict
import torch
-torch.cuda.set_device(6)
+torch.cuda.set_device(0)

env_id = 'visual_match' # The name of the environment, options: 'visual_match', 'key_to_door'
-memory_length = 50
+# env_id = 'key_to_door' # The name of the environment, options: 'visual_match', 'key_to_door'
+
+memory_length = 30
# to_test [2, 30, 50, 100]
# hard [250, 500, 750, 1000]

@@ -18,7 +20,7 @@
seed = 0
collector_env_num = 8
n_episode = 8
-evaluator_env_num = 3
+evaluator_env_num = 8
num_simulations = 50
update_per_collect = None # for others
model_update_ratio = 0.25
@@ -48,9 +50,9 @@

memory_xzero_config = dict(
# mcts_ctree.py muzero_collector muzero_evaluator
-    exp_name=f'data_memory/{env_id}_memlen-{memory_length}_xzero_H{num_unroll_steps}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}'
+    exp_name=f'data_memory_{env_id}_eval/memlen-{memory_length}_xzero_H{num_unroll_steps}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}'
             f'_collect-eps-{eps_greedy_exploration_in_collect}_temp-final-steps-{threshold_training_steps_for_final_temperature}'
-             f'_pelw1e-4_quan15_mse_emd64_seed{seed}',
+             f'_pelw1e-4_quan15_mse_emd64_seed{seed}_eval{evaluator_env_num}',
env=dict(
stop_value=int(1e6),
env_id=env_id,
@@ -75,6 +77,7 @@
),

model_path=None,
+        # model_path='/mnt/afs/niuyazhe/code/LightZero/data_memory_visual_match/memlen-2_xzero_H32_ns50_upcNone-mur0.25_rr0_H32_bs64_collect-eps-True_temp-final-steps-500000_pelw1e-4_quan15_mse_emd64_seed0_240320_190454/ckpt/ckpt_best.pth.tar',
transformer_start_after_envsteps=int(0),
update_per_collect_transformer=update_per_collect,
update_per_collect_tokenizer=update_per_collect,
4 changes: 2 additions & 2 deletions zoo/memory/config/memory_xzero_config_debug.py
@@ -7,7 +7,7 @@
memory_length = 2 # to_test [2, 50, 100, 250, 500, 750, 1000]


-max_env_step = int(10e6)
+max_env_step = int(5e6)
# ==== NOTE: action_shape must be set in cfg_memory =====
# ==== NOTE: policy_entropy_weight must be set in cfg_memory =====

@@ -17,7 +17,7 @@
seed = 0
collector_env_num = 8
n_episode = 8
-evaluator_env_num = 3
+evaluator_env_num = 8
num_simulations = 50
update_per_collect = None # for others
model_update_ratio = 0.25
(diffs for the remaining 2 changed files were not loaded)
