diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index c6ca734c..62fe6bcc 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -730,12 +730,12 @@ def train_step_logs(
         validation_elapsed_time_per_iteration_ms = (self.validation_step_time - self.training_step_time) * 1000
         validation_tokens_per_sec = (
             validation_total_samples * self.sequence_length / (validation_elapsed_time_per_iteration_ms / 1000)
-        )  # tokens_per_sec is calculated using sequence_length
+        )
         validation_model_tflops, validation_hardware_tflops = self.unwrapped_model.get_flops_per_sec(
             iteration_time_in_sec=validation_elapsed_time_per_iteration_ms / 1000,
             sequence_length=self.sequence_length,
-            global_batch_size=validation_total_samples,  # TODO: regarding the global batch size, I would set it to 1 and multiply by the number of batches
+            global_batch_size=validation_total_samples,
         )

         if dist.get_rank(self.parallel_context.world_pg) in self.logger_ranks:
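For reference, a minimal standalone sketch of the throughput arithmetic kept by this hunk: validation tokens per second is the total number of validation samples times the sequence length, divided by the elapsed time in seconds. The helper name validation_throughput and the example numbers below are illustrative only and are not part of nanotron.

# Sketch of the tokens-per-second calculation shown in the hunk above.
# The function name and example values are hypothetical, not nanotron APIs.
def validation_throughput(total_samples: int, sequence_length: int, elapsed_ms: float) -> float:
    """Tokens processed per second during a validation pass."""
    elapsed_s = elapsed_ms / 1000
    return total_samples * sequence_length / elapsed_s

# 512 validation samples of 2048 tokens finishing in 4000 ms:
# 512 * 2048 / 4.0 = 262144.0 tokens/sec
print(validation_throughput(512, 2048, 4000.0))

In the hunk itself, validation_total_samples is kept as the global_batch_size argument to get_flops_per_sec, so the FLOPs estimate is derived from the same sample count and sequence length as the tokens-per-second figure.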