fix(pu): fix cfg.policy.learn.learner.hook.save_ckpt_after_iter when eval_offline is True

puyuan1996 committed Feb 26, 2024 · 1 parent 81af4c1 · commit 915597b
Showing 1 changed file with 8 additions and 4 deletions.
lzero/entry/train_muzero.py
@@ -47,7 +47,8 @@ def train_muzero(
     """

     cfg, create_cfg = input_cfg
-    assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'stochastic_muzero'], \
+    assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero',
+                                      'stochastic_muzero'], \
         "train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero'"

     if create_cfg.policy.type == 'muzero':
@@ -77,6 +78,9 @@ def train_muzero(
     evaluator_env.seed(cfg.seed, dynamic_seed=False)
     set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

+    if cfg.policy.eval_offline:
+        cfg.policy.learn.learner.hook.save_ckpt_after_iter = cfg.policy.eval_freq
+
     policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])

     # load pretrained model
@@ -117,7 +121,7 @@ def train_muzero(
     # ==============================================================
     # Learner's before_run hook.
     learner.call_hook('before_run')
-
+
     if cfg.policy.update_per_collect is not None:
         update_per_collect = cfg.policy.update_per_collect

@@ -129,7 +133,6 @@ def train_muzero(
     if cfg.policy.eval_offline:
         eval_train_iter_list = []
         eval_train_envstep_list = []
-        cfg.policy.learn.learner.hook.save_ckpt_after_iter = cfg.policy.eval_freq

     while True:
         log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)
@@ -206,7 +209,8 @@ def train_muzero(
                 # load the ckpt of pretrained model
                 policy.learn_mode.load_state_dict(torch.load(ckpt_path, map_location=cfg.policy.device))
                 stop, reward = evaluator.eval(learner.save_checkpoint, train_iter, collector_envstep)
-                logging.info(f'eval offline at train_iter: {train_iter}, collector_envstep: {collector_envstep}, reward: {reward}')
+                logging.info(
+                    f'eval offline at train_iter: {train_iter}, collector_envstep: {collector_envstep}, reward: {reward}')
                 logging.info(f'eval offline finished!')
                 break
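The point of moving the assignment: the old code overrode save_ckpt_after_iter only after the policy and learner already existed, inside the eval_offline setup block. If, as is assumed here, DI-engine's BaseLearner captures its checkpoint-hook settings from the config once at construction time, that late override never takes effect and checkpoints keep the default interval instead of being written every eval_freq iterations. Below is a minimal, self-contained sketch of that ordering pitfall; Learner and make_cfg are hypothetical stand-ins, not the real DI-engine API.

    from types import SimpleNamespace

    class Learner:
        """Stand-in for a learner whose checkpoint hook reads the config once, at construction."""
        def __init__(self, hook_cfg):
            # The save interval is captured here; later config edits are invisible.
            self.save_ckpt_after_iter = hook_cfg.save_ckpt_after_iter

    def make_cfg():
        return SimpleNamespace(eval_freq=100,
                               hook=SimpleNamespace(save_ckpt_after_iter=10000))

    # Old order: learner first, override second; the override is lost.
    cfg = make_cfg()
    learner = Learner(cfg.hook)
    cfg.hook.save_ckpt_after_iter = cfg.eval_freq
    assert learner.save_ckpt_after_iter == 10000

    # New order (this commit): override the config before creating the learner,
    # so a checkpoint is saved at every evaluation interval.
    cfg = make_cfg()
    cfg.hook.save_ckpt_after_iter = cfg.eval_freq
    learner = Learner(cfg.hook)
    assert learner.save_ckpt_after_iter == 100

With checkpoints saved every eval_freq iterations, the offline-evaluation loop at the end of training can load the checkpoint recorded for each train_iter and replay the evaluations, as the last hunk of the diff shows.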
