@@ -170,17 +170,16 @@ def distill_one_step(
         noisy_model_input, model_pred, indices, multiphase)
 
     # Get teacher model prediction
-    with torch.no_grad():
-        with torch.autocast("cuda", dtype=torch.bfloat16):
-            with set_forward_context(current_timestep=timesteps,
-                                     attn_metadata=None):
-                cond_teacher_output = teacher_transformer(
-                    noisy_model_input,
-                    encoder_hidden_states,
-                    timesteps,
-                    encoder_attention_mask,
-                    return_dict=False,
-                )[0].float()
+    with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
+        with set_forward_context(current_timestep=timesteps,
+                                 attn_metadata=None):
+            cond_teacher_output = teacher_transformer(
+                noisy_model_input,
+                encoder_hidden_states,
+                timesteps,
+                encoder_attention_mask,
+                return_dict=False,
+            )[0].float()
 
     if not_apply_cfg_solver:
         uncond_teacher_output = cond_teacher_output
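The refactor in this hunk relies on Python's multi-item with statement, which is equivalent to nesting: the managers are entered left to right and exited in reverse order, so collapsing no_grad() and autocast() onto one line removes an indentation level without changing behavior. A minimal sketch of the equivalence (the cm helper below is hypothetical, used only to make the enter/exit order visible):

import contextlib

@contextlib.contextmanager
def cm(name):
    # Hypothetical context manager that just reports enter/exit order.
    print("enter", name)
    try:
        yield
    finally:
        print("exit", name)

# Nested form: enters a then b, exits b then a.
with cm("a"):
    with cm("b"):
        pass

# Flattened form: identical enter/exit order, one indentation level less.
with cm("a"), cm("b"):
    pass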
@@ -319,25 +318,23 @@ def forward(
                             self.training_args.sp_size *
                             self.training_args.train_sp_batch_size)
         logger.info("***** Running distillation training *****")
-        logger.info(f"  Resume training from step {init_steps}")
-        logger.info(
-            f"  Instantaneous batch size per device = {self.training_args.train_batch_size}"
-        )
+        logger.info("  Resume training from step %d", init_steps)
+        logger.info("  Instantaneous batch size per device = %d",
+                    self.training_args.train_batch_size)
         logger.info(
-            f"  Total train batch size (w. data & sequence parallel, accumulation) = {total_batch_size}"
-        )
-        logger.info(
-            f"  Gradient Accumulation steps = {self.training_args.gradient_accumulation_steps}"
-        )
-        logger.info(
-            f"  Total optimization steps = {self.training_args.max_train_steps}"
-        )
-        logger.info(
-            f"  Total training parameters per FSDP shard = {sum(p.numel() for p in self.transformer.parameters() if p.requires_grad) / 1e9} B"
-        )
+            "  Total train batch size (w. data & sequence parallel, accumulation) = %d",
+            total_batch_size)
+        logger.info("  Gradient Accumulation steps = %d",
+                    self.training_args.gradient_accumulation_steps)
+        logger.info("  Total optimization steps = %d",
+                    self.training_args.max_train_steps)
         logger.info(
-            f"  Master weight dtype: {self.transformer.parameters().__next__().dtype}"
-        )
+            "  Total training parameters per FSDP shard = %.2f B",
+            sum(p.numel()
+                for p in self.transformer.parameters() if p.requires_grad) /
+            1e9)
+        logger.info("  Master weight dtype: %s",
+                    self.transformer.parameters().__next__().dtype)
 
         # Potentially load in the weights and states from a previous save
         if self.training_args.resume_from_checkpoint:
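The logging rewrite in this hunk switches from eager f-strings to lazy %-style arguments. With an f-string the message, including any expensive argument such as the parameter-count sum above, is built before logger.info is even called; with %-style arguments the logging module formats the message only if the record is actually emitted. A self-contained sketch of the difference, assuming a logger whose INFO level is disabled (the Expensive class is hypothetical and only makes the formatting cost observable):

import logging

logging.basicConfig(level=logging.WARNING)  # INFO records are filtered out
logger = logging.getLogger(__name__)

class Expensive:
    # Hypothetical object whose string conversion is costly.
    def __str__(self):
        print("__str__ was called")
        return "expensive value"

obj = Expensive()

# Eager: the f-string (and __str__) runs even though the record is dropped.
logger.info(f"value: {obj}")

# Lazy: the argument stays unformatted; since INFO is disabled, __str__ never runs.
logger.info("value: %s", obj)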