diff --git a/run_nerf.py b/run_nerf.py index bc270be86..57fbc79a8 100644 --- a/run_nerf.py +++ b/run_nerf.py @@ -791,12 +791,14 @@ def train(): # Rest is logging if i%args.i_weights==0: path = os.path.join(basedir, expname, '{:06d}.tar'.format(i)) - torch.save({ + ckpt = { 'global_step': global_step, 'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(), - 'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict(), 'optimizer_state_dict': optimizer.state_dict(), - }, path) + } + if render_kwargs_train['network_fine'] is not None: + ckpt.update({'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict()}) + torch.save(ckpt, path) print('Saved checkpoints at', path) if i%args.i_video==0 and i > 0: