Commit 697ce19

precommit format
1 parent f797e79 commit 697ce19

2 files changed, +26 -29 lines changed


fastvideo/v1/training/distillation_pipeline.py

Lines changed: 1 addition & 1 deletion

@@ -245,7 +245,7 @@ def log_validation(self, transformer, fastvideo_args, global_step):
         videos = []
         captions = []
         for _, embeddings, masks, infos in validation_dataloader:
-            logger.info(f"infos: {infos}")
+            logger.info("infos: %s", infos)
             caption = infos['caption']
             captions.append(caption)
             prompt_embeds = embeddings.to(fastvideo_args.device)
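The only change here swaps an f-string for lazy %-style formatting: the format string and `infos` are passed separately, so interpolation (and the `repr` of a potentially large dict) only happens if the INFO record is actually emitted. A minimal standalone sketch of the difference, using stdlib `logging` rather than FastVideo's own logger setup:

    import logging

    logging.basicConfig(level=logging.WARNING)  # INFO is disabled here
    logger = logging.getLogger(__name__)

    infos = {"caption": "a red fox", "fps": 16}  # hypothetical payload

    # f-string: the dict is rendered into the message immediately,
    # even though the record is then discarded.
    logger.info(f"infos: {infos}")

    # %-style: formatting is deferred until the logger knows the record
    # will be emitted, so disabled levels skip the work entirely.
    logger.info("infos: %s", infos)

Linters such as pylint flag the f-string form (logging-fstring-interpolation) for exactly this reason.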

fastvideo/v1/training/wan_distillation_pipeline.py

Lines changed: 25 additions & 28 deletions
@@ -170,17 +170,16 @@ def distill_one_step(
             noisy_model_input, model_pred, indices, multiphase)

         # Get teacher model prediction
-        with torch.no_grad():
-            with torch.autocast("cuda", dtype=torch.bfloat16):
-                with set_forward_context(current_timestep=timesteps,
-                                         attn_metadata=None):
-                    cond_teacher_output = teacher_transformer(
-                        noisy_model_input,
-                        encoder_hidden_states,
-                        timesteps,
-                        encoder_attention_mask,
-                        return_dict=False,
-                    )[0].float()
+        with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
+            with set_forward_context(current_timestep=timesteps,
+                                     attn_metadata=None):
+                cond_teacher_output = teacher_transformer(
+                    noisy_model_input,
+                    encoder_hidden_states,
+                    timesteps,
+                    encoder_attention_mask,
+                    return_dict=False,
+                )[0].float()

         if not_apply_cfg_solver:
             uncond_teacher_output = cond_teacher_output
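This hunk only flattens the context managers: `torch.no_grad()` and `torch.autocast(...)` move onto a single `with` line, which behaves identically but removes one indentation level. A small sketch of the equivalence, assuming a CUDA device is available (the linear layer and input below are toy stand-ins, not the teacher transformer):

    import torch
    import torch.nn as nn

    model = nn.Linear(8, 8).cuda()
    x = torch.randn(2, 8, device="cuda")

    # Nested form (before): two separate with-blocks.
    with torch.no_grad():
        with torch.autocast("cuda", dtype=torch.bfloat16):
            y_nested = model(x)

    # Flattened form (after): both context managers on one line,
    # entered left to right and exited in reverse order.
    with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
        y_flat = model(x)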
@@ -319,25 +318,23 @@ def forward(
             self.training_args.sp_size *
             self.training_args.train_sp_batch_size)
         logger.info("***** Running distillation training *****")
-        logger.info(f" Resume training from step {init_steps}")
-        logger.info(
-            f" Instantaneous batch size per device = {self.training_args.train_batch_size}"
-        )
+        logger.info(" Resume training from step %d", init_steps)
+        logger.info(" Instantaneous batch size per device = %d",
+                    self.training_args.train_batch_size)
         logger.info(
-            f" Total train batch size (w. data & sequence parallel, accumulation) = {total_batch_size}"
-        )
-        logger.info(
-            f" Gradient Accumulation steps = {self.training_args.gradient_accumulation_steps}"
-        )
-        logger.info(
-            f" Total optimization steps = {self.training_args.max_train_steps}"
-        )
-        logger.info(
-            f" Total training parameters per FSDP shard = {sum(p.numel() for p in self.transformer.parameters() if p.requires_grad) / 1e9} B"
-        )
+            " Total train batch size (w. data & sequence parallel, accumulation) = %d",
+            total_batch_size)
+        logger.info(" Gradient Accumulation steps = %d",
+                    self.training_args.gradient_accumulation_steps)
+        logger.info(" Total optimization steps = %d",
+                    self.training_args.max_train_steps)
         logger.info(
-            f" Master weight dtype: {self.transformer.parameters().__next__().dtype}"
-        )
+            " Total training parameters per FSDP shard = %.2f B",
+            sum(p.numel()
+                for p in self.transformer.parameters() if p.requires_grad) /
+            1e9)
+        logger.info(" Master weight dtype: %s",
+                    self.transformer.parameters().__next__().dtype)

         # Potentially load in the weights and states from a previous save
         if self.training_args.resume_from_checkpoint:
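The last two messages show the same conversion with computed arguments: the trainable-parameter count (in billions, formatted with %.2f) and the master weight dtype are passed as logging args instead of being interpolated in f-strings. A hedged sketch with a toy module standing in for `self.transformer`:

    import logging

    import torch.nn as nn

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # Toy stand-in for the FSDP-sharded transformer.
    transformer = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64))

    trainable = sum(p.numel()
                    for p in transformer.parameters() if p.requires_grad)
    logger.info(" Total training parameters per FSDP shard = %.2f B",
                trainable / 1e9)

    # next(...) is the idiomatic spelling of parameters().__next__():
    # it grabs the first parameter and reports its dtype.
    logger.info(" Master weight dtype: %s",
                next(transformer.parameters()).dtype)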
