From 915597bca4a20e44d35ba35afe51f28f9d5d65d9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com>
Date: Mon, 26 Feb 2024 16:29:42 +0800
Subject: [PATCH] fix(pu): fix cfg.policy.learn.learner.hook.save_ckpt_after_iter when eval_offline is True

---
 lzero/entry/train_muzero.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/lzero/entry/train_muzero.py b/lzero/entry/train_muzero.py
index dbed64015..2a623540f 100644
--- a/lzero/entry/train_muzero.py
+++ b/lzero/entry/train_muzero.py
@@ -47,7 +47,8 @@ def train_muzero(
     """
     cfg, create_cfg = input_cfg
 
-    assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'stochastic_muzero'], \
+    assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero',
+                                      'stochastic_muzero'], \
         "train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero'"
 
     if create_cfg.policy.type == 'muzero':
@@ -77,6 +78,9 @@ def train_muzero(
     evaluator_env.seed(cfg.seed, dynamic_seed=False)
     set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
 
+    if cfg.policy.eval_offline:
+        cfg.policy.learn.learner.hook.save_ckpt_after_iter = cfg.policy.eval_freq
+
     policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])
 
     # load pretrained model
@@ -117,7 +121,7 @@ def train_muzero(
     # ==============================================================
     # Learner's before_run hook.
     learner.call_hook('before_run')
-    
+
     if cfg.policy.update_per_collect is not None:
         update_per_collect = cfg.policy.update_per_collect
 
@@ -129,7 +133,6 @@ def train_muzero(
     if cfg.policy.eval_offline:
         eval_train_iter_list = []
         eval_train_envstep_list = []
-        cfg.policy.learn.learner.hook.save_ckpt_after_iter = cfg.policy.eval_freq
 
     while True:
         log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)
@@ -206,7 +209,8 @@ def train_muzero(
                     # load the ckpt of pretrained model
                     policy.learn_mode.load_state_dict(torch.load(ckpt_path, map_location=cfg.policy.device))
                     stop, reward = evaluator.eval(learner.save_checkpoint, train_iter, collector_envstep)
-                    logging.info(f'eval offline at train_iter: {train_iter}, collector_envstep: {collector_envstep}, reward: {reward}')
+                    logging.info(
+                        f'eval offline at train_iter: {train_iter}, collector_envstep: {collector_envstep}, reward: {reward}')
                 logging.info(f'eval offline finished!')
             break
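
Why the assignment moves: when eval_offline is True, offline evaluation needs a checkpoint saved every eval_freq training iterations, but the old code patched cfg.policy.learn.learner.hook.save_ckpt_after_iter only after the policy and learner had already been created, presumably after the save-checkpoint hook had already read its interval from the cfg, so the override never took effect. The toy sketch below illustrates that ordering issue; ToyLearner, hook_cfg, and should_save_ckpt are hypothetical stand-ins, not LightZero or DI-engine APIs.

# Toy sketch of the ordering bug fixed by this patch. All names here are
# hypothetical stand-ins, not LightZero/DI-engine APIs: the real learner is
# assumed to read save_ckpt_after_iter from the cfg once, when it (and its
# save-checkpoint hook) is constructed.
from types import SimpleNamespace


class ToyLearner:
    def __init__(self, hook_cfg: SimpleNamespace) -> None:
        # The interval is copied out of the cfg exactly once, at build time.
        self.save_ckpt_after_iter = hook_cfg.save_ckpt_after_iter

    def should_save_ckpt(self, train_iter: int) -> bool:
        return train_iter > 0 and train_iter % self.save_ckpt_after_iter == 0


eval_freq = 200  # plays the role of cfg.policy.eval_freq

# Old order: learner built first, cfg patched afterwards -> override is ignored.
hook_cfg = SimpleNamespace(save_ckpt_after_iter=10_000)  # toy default interval
learner = ToyLearner(hook_cfg)
hook_cfg.save_ckpt_after_iter = eval_freq
print(learner.should_save_ckpt(eval_freq))   # False: stale interval still in use

# New order (this patch): cfg patched before the learner is built -> override sticks.
hook_cfg = SimpleNamespace(save_ckpt_after_iter=10_000)
hook_cfg.save_ckpt_after_iter = eval_freq
learner = ToyLearner(hook_cfg)
print(learner.should_save_ckpt(eval_freq))   # True: a ckpt exists for every eval point

Whatever the default checkpoint interval is, applying the eval_freq override before the learner is constructed is what makes an intermediate checkpoint exist for each point that the eval_offline branch later reloads with torch.load.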