From a190628970234db336edd759a977b54887aba489 Mon Sep 17 00:00:00 2001
From: bobrenjc93
Date: Tue, 9 Sep 2025 20:56:26 -0700
Subject: [PATCH] small nit fixes

[ghstack-poisoned]
---
 torchtitan/train.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchtitan/train.py b/torchtitan/train.py
index e38446a398..8183b8e6df 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -451,7 +451,7 @@ def forward_backward_step(
         with self.maybe_enable_amp:
             pred = model_parts[0](inputs, eos_id=self.tokenizer.eos_id)
             loss = self.loss_fn(pred, labels)
-            # need to free to before bwd to avoid peaking memory
+            # need to free pred before bwd to avoid peaking memory
             del pred
         loss.backward()

@@ -471,7 +471,7 @@ def train_step(
         accumulated_losses = []
         # If data runs out during gradient accumulation, that
         # entire step will not be executed.
-        for microbatch in range(self.gradient_accumulation_steps):
+        for _microbatch in range(self.gradient_accumulation_steps):
             input_dict, labels = next(data_iterator)
             loss = self.forward_backward_step(input_dict, labels)
             accumulated_losses.append(loss.detach())
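
Note (not part of the patch): a minimal, hypothetical sketch of the pattern the first hunk's comment describes, namely dropping the last Python reference to the model output before calling backward() so that memory can be reclaimed during the backward pass instead of being held alongside the gradients. The toy model, shapes, and loss function below are illustrative assumptions, not torchtitan code.

    import torch
    import torch.nn as nn

    model = nn.Linear(16, 8)
    inputs = torch.randn(4, 16)
    labels = torch.randint(0, 8, (4,))
    loss_fn = nn.CrossEntropyLoss()

    pred = model(inputs)
    loss = loss_fn(pred, labels)
    # Drop the user-held reference to pred before backward; autograd keeps
    # whatever it saved for the backward pass, and releasing this reference
    # lets that memory be freed as backward proceeds, lowering peak usage.
    del pred
    loss.backward()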