polish(pu): polish memory_env eval config
puyuan1996 committed Mar 22, 2024
1 parent 735f44d commit c3031b0
Showing 8 changed files with 46 additions and 27 deletions.
10 changes: 5 additions & 5 deletions lzero/model/gpt_models/cfg_memory.py
@@ -13,8 +13,8 @@
 cfg['world_model'] = {
 'tokens_per_block': 2,

-# 'max_blocks': 32,
-# "max_tokens": 2 * 32, # memory_length = 2
+'max_blocks': 32,
+"max_tokens": 2 * 32, # memory_length = 2

 # 'max_blocks': 60, # memory_length = 30
 # "max_tokens": 2 * 60,
@@ -28,8 +28,8 @@
 # 'max_blocks': 280, # memory_length = 250
 # "max_tokens": 2 * 280,

-'max_blocks': 530, # memory_length = 500
-"max_tokens": 2 * 530,
+# 'max_blocks': 530, # memory_length = 500
+# "max_tokens": 2 * 530,

 # 'max_blocks': 780, # memory_length = 750
 # "max_tokens": 2 * 780,
@@ -48,7 +48,7 @@
 'embed_pdrop': 0.1,
 'resid_pdrop': 0.1,
 'attn_pdrop': 0.1,
-"device": 'cuda:3',
+"device": 'cuda:0',
 'support_size': 21,
 'action_shape': 4, # NOTE: for memory
 'max_cache_size': 5000,
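Note on the values above: every variant in this config follows the same arithmetic. With tokens_per_block = 2, max_tokens is always 2 * max_blocks, and max_blocks equals memory_length + 30 (15 explore frames plus, presumably, 15 reward frames of the memory env). A minimal sketch of that assumed relationship, not code from the repository:

# Sketch only: how the commented variants in cfg_memory.py relate to memory_length.
# Assumption: max_blocks = memory_length + 30 and max_tokens = tokens_per_block * max_blocks.
def memory_world_model_sizes(memory_length: int, tokens_per_block: int = 2) -> dict:
    max_blocks = memory_length + 30
    return {
        'tokens_per_block': tokens_per_block,
        'max_blocks': max_blocks,
        'max_tokens': tokens_per_block * max_blocks,
    }

print(memory_world_model_sizes(2))    # {'tokens_per_block': 2, 'max_blocks': 32, 'max_tokens': 64} - the variant enabled here
print(memory_world_model_sizes(500))  # {'tokens_per_block': 2, 'max_blocks': 530, 'max_tokens': 1060} - the variant commented out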
2 changes: 2 additions & 0 deletions lzero/policy/muzero_gpt.py
@@ -760,6 +760,8 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1

 for i, env_id in enumerate(ready_env_id):
 distributions, value = roots_visit_count_distributions[i], roots_values[i]
+print("roots_visit_count_distributions:", distributions, "root_value:", value)  # TODO

 # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents
 # the index within the legal action set, rather than the index in the entire action set.
 # Setting deterministic=True implies choosing the action with the highest value (argmax) rather than
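The added print exposes the root visit-count distribution and root value during evaluation. For orientation, an illustrative sketch of how such a distribution is usually mapped to an action (this is not LightZero's select_action, just the standard pattern the surrounding comments describe):

import numpy as np

# Illustrative only: greedy (deterministic) vs. temperature-based sampling over MCTS visit counts.
def pick_action(visit_counts, temperature: float = 1.0, deterministic: bool = True) -> int:
    counts = np.asarray(visit_counts, dtype=np.float64)
    if deterministic:
        # Evaluation: choose the most-visited legal action (argmax).
        return int(np.argmax(counts))
    # Collection: sample from a temperature-softened visit-count distribution.
    probs = counts ** (1.0 / temperature)
    probs /= probs.sum()
    return int(np.random.choice(len(counts), p=probs))

print(pick_action([3, 40, 5, 2]))  # -> 1, the index of the most-visited action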
4 changes: 3 additions & 1 deletion zoo/memory/config/memory_muzero_config.py
@@ -35,7 +35,9 @@
 # batch_size = 2

 policy_entropy_loss_weight = 1e-4
-threshold_training_steps_for_final_temperature = int(5e5)
+# threshold_training_steps_for_final_temperature = int(5e5)
+threshold_training_steps_for_final_temperature = int(1e5)
+
 # eps_greedy_exploration_in_collect = False
 eps_greedy_exploration_in_collect = True
 # ==============================================================
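This lowers threshold_training_steps_for_final_temperature from 5e5 to 1e5; together with manual_temperature_decay=True it controls how quickly the visit-count sampling temperature anneals during collection. A rough sketch of the kind of piecewise schedule MuZero-style code uses (illustrative; the exact breakpoints in LightZero may differ):

# Assumed piecewise schedule: 1.0 -> 0.5 -> 0.25 over fractions of the threshold.
def visit_count_temperature(trained_steps: int, threshold_steps: int = int(1e5)) -> float:
    if trained_steps < 0.5 * threshold_steps:
        return 1.0
    elif trained_steps < 0.75 * threshold_steps:
        return 0.5
    return 0.25

# Lowering the threshold from 5e5 to 1e5 makes the same schedule reach the final temperature 5x earlier.
print(visit_count_temperature(60_000))                            # 0.5 with the new 1e5 threshold
print(visit_count_temperature(60_000, threshold_steps=int(5e5)))  # 1.0 with the old 5e5 threshold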
4 changes: 3 additions & 1 deletion zoo/memory/config/memory_xzero_config.py
@@ -42,7 +42,9 @@
 # update_per_collect = 2
 # batch_size = 2

-threshold_training_steps_for_final_temperature = int(5e5)
+# threshold_training_steps_for_final_temperature = int(5e5)
+threshold_training_steps_for_final_temperature = int(1e5)
+
 # eps_greedy_exploration_in_collect = False
 eps_greedy_exploration_in_collect = True
 # ==============================================================
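The same temperature change is applied here, and eps_greedy_exploration_in_collect stays enabled. The debug config below passes an eps schedule with decay=int(2e5); a hedged sketch of a linear epsilon decay of that shape (the start/end values are placeholders for illustration, only the decay horizon appears in this diff):

# Assumed linear epsilon decay for collection; start/end are illustrative placeholders.
def epsilon(collect_steps: int, start: float = 1.0, end: float = 0.05, decay: int = int(2e5)) -> float:
    frac = min(collect_steps / decay, 1.0)
    return start + frac * (end - start)

print(epsilon(0))         # 1.0 at the beginning of training
print(epsilon(int(2e5)))  # 0.05 once the decay horizon is reached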
31 changes: 21 additions & 10 deletions zoo/memory/config/memory_xzero_config_debug.py
@@ -2,9 +2,12 @@
 import torch
 torch.cuda.set_device(0)

-env_id = 'visual_match' # The name of the environment, options: 'visual_match', 'key_to_door'
-# memory_length = 30
-memory_length = 2 # to_test [2, 50, 100, 250, 500, 750, 1000]
+# env_id = 'visual_match' # The name of the environment, options: 'visual_match', 'key_to_door'
+env_id = 'key_to_door' # The name of the environment, options: 'visual_match', 'key_to_door'
+
+memory_length = 2
+# to_test [2, 30, 50, 100]
+# hard [250, 500, 750, 1000]


 max_env_step = int(5e6)
@@ -25,6 +28,7 @@
 batch_size = 64
 # num_unroll_steps = 5
 num_unroll_steps = 30+memory_length
+game_segment_length=30+memory_length # TODO:


 reanalyze_ratio = 0
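The newly added game_segment_length mirrors num_unroll_steps: both are 30 + memory_length, i.e. the full episode length of the memory env (15 explore frames, memory_length distractor frames, and presumably 15 reward frames, per the max_frames dict further down). A small sketch of that assumed bookkeeping:

# Assumed episode-length bookkeeping behind num_unroll_steps and game_segment_length.
EXPLORE_FRAMES = 15  # max_frames['explore'] in this config
REWARD_FRAMES = 15   # assumed reward phase length (not visible in this hunk)

def episode_length(memory_length: int) -> int:
    return EXPLORE_FRAMES + memory_length + REWARD_FRAMES

memory_length = 2
num_unroll_steps = episode_length(memory_length)     # 32: unroll across the whole episode
game_segment_length = episode_length(memory_length)  # 32: one game segment spans one episode
print(num_unroll_steps, game_segment_length)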
@@ -38,21 +42,27 @@
 # update_per_collect = 2
 # batch_size = 2

-threshold_training_steps_for_final_temperature = int(5e5)
+# threshold_training_steps_for_final_temperature = int(5e5)
+threshold_training_steps_for_final_temperature = int(1e5)
+
 # eps_greedy_exploration_in_collect = False
 eps_greedy_exploration_in_collect = True
 # ==============================================================
 # end of the most frequently changed config specified by the user
 # ==============================================================

 memory_xzero_config = dict(
-exp_name=f'data_memory_debug/{env_id}_memlen-{memory_length}_xzero_H{num_unroll_steps}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}'
-f'collect-eps-{eps_greedy_exploration_in_collect}_temp-final-steps-{threshold_training_steps_for_final_temperature}'
-f'_pelw1e-4_quan15_mse_seed{seed}',
+# mcts_ctree.py muzero_collector muzero_evaluator
+exp_name=f'data_memory_{env_id}_fixscale_debug/{env_id}_memlen-{memory_length}_xzero_H{num_unroll_steps}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_bs{batch_size}'
+f'_collect-eps-{eps_greedy_exploration_in_collect}_temp-final-steps-{threshold_training_steps_for_final_temperature}'
+f'_pelw1e-4_quan15_mse_emd64_seed{seed}_eval{evaluator_env_num}_clearper20-notcache',
 env=dict(
 stop_value=int(1e6),
 env_id=env_id,
 flate_observation=True, # Whether to flatten the observation
+# obs_max_scale=107, # Maximum value of the observation, for key_to_door
+# obs_max_scale=101, # Maximum value of the observation, for visual_match
+obs_max_scale=100,
 max_frames={
 "explore": 15,
 "distractor": memory_length,
@@ -72,7 +82,9 @@
 ),
 ),

-model_path=None,
+# model_path=None,
+model_path='/mnt/afs/niuyazhe/code/LightZero/data_memory_key_to_door_fixscale/key_to_door_memlen-2_xzero_H32_ns50_upcNone-mur0.25_rr0_bs64_collect-eps-True_temp-final-steps-500000_pelw1e-4_quan15_mse_emd64_seed0_eval8_clearper20-notcache/ckpt/ckpt_best.pth.tar',
+# model_path='/mnt/afs/niuyazhe/code/LightZero/data_memory_visual_match/memlen-2_xzero_H32_ns50_upcNone-mur0.25_rr0_H32_bs64_collect-eps-True_temp-final-steps-500000_pelw1e-4_quan15_mse_emd64_seed0_240320_190454/ckpt/ckpt_best.pth.tar',
 transformer_start_after_envsteps=int(0),
 update_per_collect_transformer=update_per_collect,
 update_per_collect_tokenizer=update_per_collect,
@@ -93,14 +105,13 @@
 decay=int(2e5), # NOTE: TODO
 ),
 use_priority=False,
-# use_priority=True, # NOTE
 use_augmentation=False, # NOTE
 td_steps=td_steps,
 manual_temperature_decay=True,
 threshold_training_steps_for_final_temperature=threshold_training_steps_for_final_temperature,
 cuda=True,
 env_type='not_board_games',
-game_segment_length=num_unroll_steps, # TODO:
+game_segment_length=game_segment_length, # TODO:
 update_per_collect=update_per_collect,
 batch_size=batch_size,
 lr_piecewise_constant_decay=False,
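obs_max_scale is new here and is also passed in test_render.py below; the commented 107 and 101 are described as the maximum observation values for key_to_door and visual_match. The hedged assumption is that the env divides the raw observation by this value when scale_observation is enabled; a minimal sketch of that assumed normalization (not the env's actual code):

import numpy as np

# Assumed use of obs_max_scale: scale raw integer observations into roughly [0, 1].
def scale_obs(observation: np.ndarray, obs_max_scale: float = 100.0) -> np.ndarray:
    return observation.astype(np.float32) / obs_max_scale

raw = np.array([[0, 7, 100]], dtype=np.int64)
print(scale_obs(raw, obs_max_scale=100))  # -> [[0.   0.07 1.  ]]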
11 changes: 6 additions & 5 deletions zoo/memory/entry/memory_eval.py
@@ -1,4 +1,4 @@
-from zoo.memory.config.memory_xzero_config import main_config, create_config
+from zoo.memory.config.memory_xzero_config_debug import main_config, create_config
 from lzero.entry import eval_muzero
 import numpy as np

@@ -25,7 +25,8 @@
 """
 # model_path = './ckpt/ckpt_best.pth.tar'
 # model_path = None
-model_path='/mnt/afs/niuyazhe/code/LightZero/data_memory_visual_match/memlen-30_xzero_H60_ns50_upcNone-mur0.25_rr0_H60_bs64_collect-eps-True_temp-final-steps-500000_pelw1e-4_quan15_mse_emd64_seed0/ckpt/ckpt_best.pth.tar'
+# model_path='/mnt/afs/niuyazhe/code/LightZero/data_memory_visual_match_fixscale/memlen-2_muzero_ns50_upcNone_rr0_collect-eps-True_temp-final-steps-500000_pelw0.0001_seed0_evalnum10/ckpt/iteration_120000.pth.tar'
+model_path='/mnt/afs/niuyazhe/code/LightZero/data_memory_key_to_door_fixscale/key_to_door_memlen-2_xzero_H32_ns50_upcNone-mur0.25_rr0_bs64_collect-eps-True_temp-final-steps-500000_pelw1e-4_quan15_mse_emd64_seed0_eval8_clearper20-notcache/ckpt/ckpt_best.pth.tar'


 # Initialize a list with a single seed for the experiment
@@ -35,10 +36,10 @@
 num_episodes_each_seed = 1

 # Specify the number of environments for the evaluator to use
-main_config.env.evaluator_env_num = 8
+main_config.env.evaluator_env_num = 1

 # Set the number of episodes for the evaluator to run
-main_config.env.n_evaluator_episode = 8
+main_config.env.n_evaluator_episode = 1

 # The total number of test episodes is the product of the number of episodes per seed and the number of seeds
 total_test_episodes = num_episodes_each_seed * len(seeds)
@@ -49,7 +50,7 @@
 # Enable saving of replay as a gif, specify the path to save the replay gif
 # main_config.env.save_replay_gif = True
 # main_config.env.replay_path_gif = './video'
-main_config.env.save_replay=True
+main_config.env.save_replay = True

 # Initialize lists to store the mean and total returns for each seed
 returns_mean_seeds = []
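For context, this entry script evaluates a single checkpoint over a list of seeds via eval_muzero. A condensed sketch of that loop (argument names for eval_muzero follow my reading of lzero.entry and should be checked against the real signature):

import numpy as np
from lzero.entry import eval_muzero
from zoo.memory.config.memory_xzero_config_debug import main_config, create_config

seeds = [0]                 # as in the script, a single seed
num_episodes_each_seed = 1
model_path = None           # or the ckpt_best.pth.tar path set above

returns_mean_seeds, returns_seeds = [], []
for seed in seeds:
    # eval_muzero runs the evaluator for one seed and returns (mean return, per-episode returns).
    returns_mean, returns = eval_muzero(
        [main_config, create_config],
        seed=seed,
        num_episodes_each_seed=num_episodes_each_seed,
        print_seed_details=False,
        model_path=model_path,
    )
    returns_mean_seeds.append(returns_mean)
    returns_seeds.append(returns)

print('mean return per seed:', np.array(returns_mean_seeds))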
4 changes: 2 additions & 2 deletions zoo/memory/envs/memory_lightzero_env.py
@@ -172,7 +172,7 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep:
 info['eval_episode_return'] = info['success']
 print(f'episode seed:{self._seed} done! self.episode_reward_list is: {self.episode_reward_list}')

-# print(f"Action: {action}, Reward: {reward}, Observation: {observation}, Done: {done}, Info: {info}")
+print(f"Step: {self._current_step}, Action: {action}, Reward: {reward}, Observation: {observation}, Done: {done}, Info: {info}")
 observation = to_ndarray(observation, dtype=np.float32)
 reward = to_ndarray([reward])
 action_mask = np.ones(self.action_space.n, 'int8')
@@ -201,7 +201,7 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep:
 gif_dir = os.path.join(os.path.dirname(__file__), 'replay')
 os.makedirs(gif_dir, exist_ok=True)
 timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
-gif_file = os.path.join(gif_dir, f'episode_{self._current_step}_{timestamp}.gif')
+gif_file = os.path.join(gif_dir, f'episode_seed{self._seed}_len{self._current_step}_{timestamp}.gif')
 self._gif_images[0].save(gif_file, save_all=True, append_images=self._gif_images[1:], duration=100, loop=0)
 print(f'saved replay to {gif_file}')
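The replay GIF filename now encodes the episode seed and length in addition to a timestamp. A standalone sketch of the same PIL-based pattern (assuming, as above, a list of PIL Images collected during the episode):

import os
from datetime import datetime
from PIL import Image

# Standalone sketch of the replay-saving pattern used in memory_lightzero_env.py.
def save_replay_gif(frames, seed: int, episode_len: int, gif_dir: str = 'replay') -> str:
    os.makedirs(gif_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    gif_file = os.path.join(gif_dir, f'episode_seed{seed}_len{episode_len}_{timestamp}.gif')
    # 100 ms per frame, loop forever, append the remaining frames after the first one.
    frames[0].save(gif_file, save_all=True, append_images=frames[1:], duration=100, loop=0)
    return gif_file

frames = [Image.new('RGB', (5, 5)) for _ in range(3)]  # three tiny blank frames as a demo
print(save_replay_gif(frames, seed=0, episode_len=3))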
7 changes: 4 additions & 3 deletions zoo/memory/envs/test_render.py
@@ -37,6 +37,7 @@ def main():
 render=args.render,
 scale_observation=True,
 flate_observation=False, # Whether to flatten the observation
+obs_max_scale=100, # Maximum value of the observation
 )

 for i in range(args.num_episodes):
@@ -54,7 +55,8 @@
 done = timestep.done
 info = timestep.info
 episode_return += reward
-print(f"Action: {action}, Reward: {reward}, Done: {done}, Info: {info}")
+# print(f"Action: {action}, Reward: {reward}, Observation: {obs}, Done: {done}, Info: {info}")
+# print(f"Observation max: {obs['observation'].max()}, min: {obs['observation'].min()}, mean: {obs['observation'].mean()}")

 print(f"Episode {i} finished with return: {episode_return}")
@@ -75,5 +77,4 @@

 if __name__ == '__main__':
 main()
-# python test_render.py --save_replay --render --mode human
-# python test_render.py --save_replay --render --mode random
+# python test_render.py
