diff --git a/moai/engine/run_callback.py b/moai/engine/run_callback.py index f22ea1d..6737e20 100644 --- a/moai/engine/run_callback.py +++ b/moai/engine/run_callback.py @@ -129,7 +129,7 @@ def on_train_batch_end( Note: The value ``outputs["loss"]`` here will be the normalized value w.r.t ``accumulate_grad_batches`` of the loss returned from ``training_step``. """ - if C._MOAI_LOSSES_ in outputs: + if C._MOAI_LOSSES_ in outputs and 'total' in outputs[C._MOAI_LOSSES_]: if losses := toolz.merge( outputs[f"{C._MOAI_LOSSES_}.weighted"], {"total": outputs[f"{C._MOAI_LOSSES_}.total"]},