From d03365f6d617de11404c642873ea2eaf7e7ec6c0 Mon Sep 17 00:00:00 2001 From: Peiyuan Zhang Date: Wed, 11 Jun 2025 03:05:29 +0000 Subject: [PATCH 01/14] move files --- .../v1/pipelines/{ => preprocess}/preprocess_pipeline_base.py | 0 .../{data_preprocess => v1/pipelines/preprocess}/v1_preprocess.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename fastvideo/v1/pipelines/{ => preprocess}/preprocess_pipeline_base.py (100%) rename fastvideo/{data_preprocess => v1/pipelines/preprocess}/v1_preprocess.py (100%) diff --git a/fastvideo/v1/pipelines/preprocess_pipeline_base.py b/fastvideo/v1/pipelines/preprocess/preprocess_pipeline_base.py similarity index 100% rename from fastvideo/v1/pipelines/preprocess_pipeline_base.py rename to fastvideo/v1/pipelines/preprocess/preprocess_pipeline_base.py diff --git a/fastvideo/data_preprocess/v1_preprocess.py b/fastvideo/v1/pipelines/preprocess/v1_preprocess.py similarity index 100% rename from fastvideo/data_preprocess/v1_preprocess.py rename to fastvideo/v1/pipelines/preprocess/v1_preprocess.py From 12c530ec00d5446ba65fbed4687668de3ab8607b Mon Sep 17 00:00:00 2001 From: Jinzhe Pan Date: Mon, 26 May 2025 17:07:50 +0800 Subject: [PATCH 02/14] [WIP][Feat] distillation --- fastvideo/v1/fastvideo_args.py | 21 +- .../v1/pipelines/composed_pipeline_base.py | 20 +- .../v1/pipelines/distillation_pipeline.py | 766 ++++++++++++++++ fastvideo/v1/pipelines/training_pipeline.py | 820 ++++++++++++++++++ scripts/distill/distill_v1.sh | 47 + scripts/finetune/finetune_v1.sh | 2 +- 6 files changed, 1665 insertions(+), 11 deletions(-) create mode 100644 fastvideo/v1/pipelines/distillation_pipeline.py create mode 100644 fastvideo/v1/pipelines/training_pipeline.py create mode 100644 scripts/distill/distill_v1.sh diff --git a/fastvideo/v1/fastvideo_args.py b/fastvideo/v1/fastvideo_args.py index c763d0ddb..d83ce5642 100644 --- a/fastvideo/v1/fastvideo_args.py +++ b/fastvideo/v1/fastvideo_args.py @@ -34,7 +34,7 @@ class FastVideoArgs: # Distributed executor backend distributed_executor_backend: str = "mp" - inference_mode: bool = True # if False == training mode + mode: str = "inference" # Options: "inference", "training", "distill" # HuggingFace specific parameters trust_remote_code: bool = False @@ -111,7 +111,15 @@ class FastVideoArgs: @property def training_mode(self) -> bool: - return not self.inference_mode + return self.mode == "training" + + @property + def distill_mode(self) -> bool: + return self.mode == "distill" + + @property + def inference_mode(self) -> bool: + return self.mode == "inference" def __post_init__(self): self.check_fastvideo_args() @@ -146,10 +154,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: ) parser.add_argument( - "--inference-mode", - action=StoreBoolean, - default=FastVideoArgs.inference_mode, - help="Whether to use inference mode", + "--mode", + type=str, + default=FastVideoArgs.mode, + choices=["inference", "training", "distill"], + help="The mode to use", ) # HuggingFace specific parameters diff --git a/fastvideo/v1/pipelines/composed_pipeline_base.py b/fastvideo/v1/pipelines/composed_pipeline_base.py index 25fe3bc4b..3d3698e87 100644 --- a/fastvideo/v1/pipelines/composed_pipeline_base.py +++ b/fastvideo/v1/pipelines/composed_pipeline_base.py @@ -94,6 +94,12 @@ def __init__(self, self.initialize_validation_pipeline(self.training_args) self.initialize_training_pipeline(self.training_args) + if fastvideo_args.distill_mode: + self.initialize_distillation_pipeline(fastvideo_args) + + if 
fastvideo_args.log_validation: + self.initialize_validation_pipeline(fastvideo_args) + self.initialize_pipeline(fastvideo_args) if not fastvideo_args.training_mode: @@ -109,6 +115,10 @@ def initialize_validation_pipeline(self, training_args: TrainingArgs): "if log_validation is True, the pipeline must implement this method" ) + def initialize_distillation_pipeline(self, fastvideo_args: FastVideoArgs): + raise NotImplementedError( + "if distill_mode is True, the pipeline must implement this method") + @classmethod def from_pretrained(cls, model_path: str, @@ -148,9 +158,11 @@ def from_pretrained(cls, config_args = shallow_asdict(config) config_args.update(kwargs) - if args is None or args.inference_mode: - fastvideo_args = FastVideoArgs(model_path=model_path, **config_args) - + if args.mode == "inference": + fastvideo_args = FastVideoArgs(model_path=model_path, + device_str=device or "cuda" if + torch.cuda.is_available() else "cpu", + **config_args) fastvideo_args.model_path = model_path for key, value in config_args.items(): setattr(fastvideo_args, key, value) @@ -164,7 +176,7 @@ def from_pretrained(cls, fastvideo_args.use_cpu_offload = False # make sure we are in training mode - fastvideo_args.inference_mode = False + fastvideo_args.mode = args.mode # we hijack the precision to be the master weight type so that the # model is loaded with the correct precision. Subsequently we will # use FSDP2's MixedPrecisionPolicy to set the precision for the diff --git a/fastvideo/v1/pipelines/distillation_pipeline.py b/fastvideo/v1/pipelines/distillation_pipeline.py new file mode 100644 index 000000000..d3e5537a1 --- /dev/null +++ b/fastvideo/v1/pipelines/distillation_pipeline.py @@ -0,0 +1,766 @@ +import gc +import os +import sys +import time +from abc import ABC, abstractmethod +from collections import deque +from copy import deepcopy + +import imageio +import numpy as np +import torch +import torchvision +import wandb +from diffusers.optimization import get_scheduler +from einops import rearrange +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import ShardingStrategy +from torchdata.stateful_dataloader import StatefulDataLoader +from tqdm.auto import tqdm + +from fastvideo.distill.solver import EulerSolver, extract_into_tensor +from fastvideo.models.mochi_hf.mochi_latents_utils import normalize_dit_input +from fastvideo.utils.checkpoint import save_checkpoint_v1 +from fastvideo.v1.configs.sample import SamplingParam +from fastvideo.v1.dataset.parquet_datasets import ParquetVideoTextDataset +from fastvideo.v1.distributed import cleanup_dist_env_and_memory, get_sp_group +from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs +from fastvideo.v1.forward_context import set_forward_context +from fastvideo.v1.logger import init_logger +from fastvideo.v1.pipelines import ComposedPipelineBase +from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch +from fastvideo.v1.pipelines.training_utils import ( + _clip_grad_norm_while_handling_failing_dtensor_cases) +from fastvideo.v1.pipelines.wan.wan_pipeline import WanValidationPipeline + +logger = init_logger(__name__) + + +# from: https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77 +def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None): + if linear_steps is None: + linear_steps = num_steps // 2 + linear_sigma_schedule = [ + i * threshold_noise / linear_steps for i in range(linear_steps) + ] + 
threshold_noise_step_diff = linear_steps - threshold_noise * num_steps + quadratic_steps = num_steps - linear_steps + quadratic_coef = threshold_noise_step_diff / (linear_steps * + quadratic_steps**2) + linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / ( + quadratic_steps**2) + const = quadratic_coef * (linear_steps**2) + quadratic_sigma_schedule = [ + quadratic_coef * (i**2) + linear_coef * i + const + for i in range(linear_steps, num_steps) + ] + sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + sigma_schedule = [1.0 - x for x in sigma_schedule] + return sigma_schedule + + +def reshard_fsdp(model): + """Reshard FSDP model for EMA updates.""" + for m in FSDP.fsdp_modules(model): + if m._has_params and m.sharding_strategy is not ShardingStrategy.NO_SHARD: + torch.distributed.fsdp._runtime_utils._reshard(m, m._handle, True) + + +def get_norm(model_pred, norms, gradient_accumulation_steps): + """Calculate and aggregate model prediction norms.""" + fro_norm = ( + torch.linalg.matrix_norm(model_pred, ord="fro") / # codespell:ignore + gradient_accumulation_steps) + largest_singular_value = (torch.linalg.matrix_norm(model_pred, ord=2) / + gradient_accumulation_steps) + absolute_mean = torch.mean( + torch.abs(model_pred)) / gradient_accumulation_steps + absolute_max = torch.max( + torch.abs(model_pred)) / gradient_accumulation_steps + + sp_group = get_sp_group() + sp_group.all_reduce(fro_norm, op=torch.distributed.ReduceOp.AVG) + sp_group.all_reduce(largest_singular_value, + op=torch.distributed.ReduceOp.AVG) + sp_group.all_reduce(absolute_mean, op=torch.distributed.ReduceOp.AVG) + + norms["fro"] += torch.mean(fro_norm).item() # codespell:ignore + norms["largest singular value"] += torch.mean(largest_singular_value).item() + norms["absolute mean"] += absolute_mean.item() + norms["absolute max"] += absolute_max.item() + + +class DistillationPipeline(ComposedPipelineBase, ABC): + """ + A pipeline for distillation training. All distillation pipelines should inherit from this class. 
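+    Subclasses are expected to implement initialize_validation_pipeline() and distill_one_step(), which are declared as abstract methods below.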
+ """ + _required_config_modules = ["scheduler", "transformer"] + + def initialize_distillation_pipeline(self, fastvideo_args: TrainingArgs): + logger.info("Initializing distillation pipeline...") + self.device = fastvideo_args.device + self.sp_group = get_sp_group() + self.world_size = self.sp_group.world_size + self.rank = self.sp_group.rank + self.local_rank = self.sp_group.local_rank + + # Initialize student model + self.transformer = self.get_module("transformer") + assert self.transformer is not None + + self.transformer.requires_grad_(True) + self.transformer.train() + + # Initialize teacher model + self.teacher_transformer = deepcopy(self.transformer) + self.teacher_transformer.requires_grad_(False) + + # Initialize EMA model if needed + if fastvideo_args.use_ema: + self.ema_transformer = deepcopy(self.transformer) + self.ema_transformer.requires_grad_(False) + else: + self.ema_transformer = None + + args = fastvideo_args + noise_scheduler = self.get_module("scheduler") + assert noise_scheduler is not None + + # Initialize solver for distillation + if args.scheduler_type == "pcm_linear_quadratic": + linear_steps = int(noise_scheduler.config.num_train_timesteps * + args.linear_range) + sigmas = linear_quadratic_schedule( + noise_scheduler.config.num_train_timesteps, + args.linear_quadratic_threshold, + linear_steps, + ) + sigmas = torch.tensor(sigmas).to(dtype=torch.float32) + else: + sigmas = noise_scheduler.sigmas + + self.solver = EulerSolver( + sigmas.numpy()[::-1], + noise_scheduler.config.num_train_timesteps, + euler_timesteps=args.num_euler_timesteps, + ) + self.solver.to(self.device) + + # Setup optimizer + params_to_optimize = self.transformer.parameters() + params_to_optimize = list( + filter(lambda p: p.requires_grad, params_to_optimize)) + + optimizer = torch.optim.AdamW( + params_to_optimize, + lr=args.learning_rate, + betas=(0.9, 0.999), + weight_decay=args.weight_decay, + eps=1e-8, + ) + + init_steps = 0 + logger.info("optimizer: %s", optimizer) + + # Setup lr scheduler + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * self.world_size, + num_training_steps=args.max_train_steps * self.world_size, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + last_epoch=init_steps - 1, + ) + + # Setup dataset + train_dataset = ParquetVideoTextDataset( + args.data_path, + batch_size=args.train_batch_size, + rank=self.rank, + world_size=self.world_size, + cfg_rate=args.cfg, + num_latent_t=args.num_latent_t) + + train_dataloader = StatefulDataLoader( + train_dataset, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + prefetch_factor=2, + shuffle=False, + pin_memory=True, + drop_last=True) + + self.lr_scheduler = lr_scheduler + self.train_dataset = train_dataset + self.train_dataloader = train_dataloader + self.init_steps = init_steps + self.optimizer = optimizer + self.noise_scheduler = noise_scheduler + + # Get unconditional embeddings + self.uncond_prompt_embed = train_dataset.uncond_prompt_embed + self.uncond_prompt_mask = train_dataset.uncond_prompt_mask + + if self.rank <= 0: + project = args.tracker_project_name or "fastvideo" + wandb.init(project=project, config=args) + + @abstractmethod + def initialize_validation_pipeline(self, fastvideo_args: FastVideoArgs): + raise NotImplementedError( + "Distillation pipelines must implement this method") + + @abstractmethod + def distill_one_step(self, transformer, model_type, teacher_transformer, + ema_transformer, optimizer, 
lr_scheduler, loader_iter, + noise_scheduler, solver, noise_random_generator, + gradient_accumulation_steps, sp_size, max_grad_norm, + uncond_prompt_embed, uncond_prompt_mask, + num_euler_timesteps, multiphase, not_apply_cfg_solver, + distill_cfg, ema_decay, pred_decay_weight, + pred_decay_type, hunyuan_teacher_disable_cfg): + """ + Distill one step of the model. + """ + raise NotImplementedError( + "Distillation pipeline must implement this method") + + def log_validation(self, transformer, fastvideo_args, global_step): + """Log validation results during training.""" + fastvideo_args.inference_mode = True + fastvideo_args.use_cpu_offload = False + if not fastvideo_args.log_validation: + return + if self.validation_pipeline is None: + raise ValueError("Validation pipeline is not set") + + # Create sampling parameters if not provided + sampling_param = SamplingParam.from_pretrained( + fastvideo_args.model_path) + + # Prepare validation prompts + validation_dataset = ParquetVideoTextDataset( + fastvideo_args.validation_prompt_dir, + batch_size=1, + rank=0, + world_size=1, + cfg_rate=0, + num_latent_t=fastvideo_args.num_latent_t) + + validation_dataloader = StatefulDataLoader(validation_dataset, + batch_size=1, + num_workers=1, + prefetch_factor=2, + shuffle=False, + pin_memory=True, + drop_last=False) + + transformer.requires_grad_(False) + for p in transformer.parameters(): + p.requires_grad = False + transformer.eval() + + # Add the transformer to the validation pipeline + self.validation_pipeline.add_module("transformer", transformer) + self.validation_pipeline.latent_preparation_stage.transformer = transformer + self.validation_pipeline.denoising_stage.transformer = transformer + + # Process validation prompts + videos = [] + captions = [] + for _, embeddings, masks, infos in validation_dataloader: + logger.info(f"infos: {infos}") + caption = infos['caption'] + captions.append(caption) + prompt_embeds = embeddings.to(fastvideo_args.device) + prompt_attention_mask = masks.to(fastvideo_args.device) + + # Calculate sizes + latents_size = [(sampling_param.num_frames - 1) // 4 + 1, + sampling_param.height // 8, + sampling_param.width // 8] + n_tokens = latents_size[0] * latents_size[1] * latents_size[2] + + # Prepare batch for validation + batch = ForwardBatch( + data_type="video", + latents=None, + prompt_embeds=[prompt_embeds], + prompt_attention_mask=[prompt_attention_mask], + height=fastvideo_args.num_height, + width=fastvideo_args.num_width, + num_frames=fastvideo_args.num_frames, + num_inference_steps=10, + guidance_scale=1, + n_tokens=n_tokens, + do_classifier_free_guidance=False, + eta=0.0, + extra={}, + ) + + # Run validation inference + with torch.inference_mode(): + output_batch = self.validation_pipeline.forward( + batch, fastvideo_args) + samples = output_batch.output + + # Process outputs + video = rearrange(samples, "b c t h w -> t b c h w") + frames = [] + for x in video: + x = torchvision.utils.make_grid(x, nrow=6) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + frames.append((x * 255).numpy().astype(np.uint8)) + videos.append(frames) + + # Log validation results + if self.rank == 0: + video_filenames = [] + video_captions = [] + for i, video in enumerate(videos): + caption = captions[i] + filename = os.path.join( + fastvideo_args.output_dir, + f"validation_step_{global_step}_video_{i}.mp4") + imageio.mimsave(filename, video, fps=sampling_param.fps) + video_filenames.append(filename) + video_captions.append(caption) + + logs = { + "validation_videos": [ + 
wandb.Video(filename, + caption=caption) for filename, caption in zip( + video_filenames, video_captions) + ] + } + wandb.log(logs, step=global_step) + + # Re-enable gradients for training + transformer.requires_grad_(True) + transformer.train() + + gc.collect() + torch.cuda.empty_cache() + + +class WanDistillationPipeline(DistillationPipeline): + """ + A distillation pipeline for Wan. + """ + _required_config_modules = ["scheduler", "transformer"] + + def create_pipeline_stages(self, fastvideo_args: FastVideoArgs): + pass + + def create_training_stages(self, fastvideo_args: FastVideoArgs): + pass + + def initialize_validation_pipeline(self, fastvideo_args: FastVideoArgs): + logger.info("Initializing validation pipeline...") + args_copy = deepcopy(fastvideo_args) + + args_copy.mode = "inference" + args_copy.vae_config.load_encoder = False + validation_pipeline = WanValidationPipeline.from_pretrained( + fastvideo_args.model_path, args=args_copy) + + self.validation_pipeline = validation_pipeline + + def distill_one_step( + self, + transformer, + model_type, + teacher_transformer, + ema_transformer, + optimizer, + lr_scheduler, + loader_iter, + noise_scheduler, + solver, + noise_random_generator, + gradient_accumulation_steps, + sp_size, + max_grad_norm, + uncond_prompt_embed, + uncond_prompt_mask, + num_euler_timesteps, + multiphase, + not_apply_cfg_solver, + distill_cfg, + ema_decay, + pred_decay_weight, + pred_decay_type, + hunyuan_teacher_disable_cfg, + ): + """Perform one step of distillation training.""" + total_loss = 0.0 + optimizer.zero_grad() + model_pred_norm = { + "fro": 0.0, # codespell:ignore + "largest singular value": 0.0, + "absolute mean": 0.0, + "absolute max": 0.0, + } + + for _ in range(gradient_accumulation_steps): + ( + latents, + encoder_hidden_states, + encoder_attention_mask, + infos, + ) = next(loader_iter) + + latents = latents.to(self.device, dtype=torch.bfloat16) + encoder_hidden_states = encoder_hidden_states.to( + self.device, dtype=torch.bfloat16) + + model_input = normalize_dit_input(model_type, latents) + noise = torch.randn_like(model_input) + bsz = model_input.shape[0] + index = torch.randint(0, + num_euler_timesteps, (bsz, ), + device=model_input.device).long() + if sp_size > 1: + self.sp_group.broadcast(index, src=0) + + # Add noise according to flow matching + sigmas = extract_into_tensor(solver.sigmas, index, + model_input.shape) + sigmas_prev = extract_into_tensor(solver.sigmas_prev, index, + model_input.shape) + + timesteps = (sigmas * + noise_scheduler.config.num_train_timesteps).view(-1) + timesteps_prev = ( + sigmas_prev * + noise_scheduler.config.num_train_timesteps).view(-1) + noisy_model_input = sigmas * noise + (1.0 - sigmas) * model_input + + # Get student model prediction + with torch.autocast("cuda", dtype=torch.bfloat16): + input_kwargs = { + "hidden_states": noisy_model_input, + "encoder_hidden_states": encoder_hidden_states, + "timestep": timesteps, + "encoder_attention_mask": encoder_attention_mask, + "return_dict": False, + } + if hunyuan_teacher_disable_cfg: + input_kwargs["guidance"] = torch.tensor( + [1000.0], + device=noisy_model_input.device, + dtype=torch.bfloat16) + + with set_forward_context(current_timestep=timesteps, + attn_metadata=None): + model_pred = transformer(**input_kwargs)[0] + + # Apply multi-phase prediction + model_pred, end_index = solver.euler_style_multiphase_pred( + noisy_model_input, model_pred, index, multiphase) + + # Get teacher model guidance + with torch.no_grad(): + w = distill_cfg + with 
torch.autocast("cuda", dtype=torch.bfloat16): + with set_forward_context(current_timestep=timesteps, + attn_metadata=None): + cond_teacher_output = teacher_transformer( + noisy_model_input, + encoder_hidden_states, + timesteps, + encoder_attention_mask, + return_dict=False, + )[0].float() + + if not_apply_cfg_solver: + uncond_teacher_output = cond_teacher_output + else: + # Get teacher model prediction on unconditional embedding + with torch.autocast("cuda", dtype=torch.bfloat16): + with set_forward_context(current_timestep=timesteps, + attn_metadata=None): + uncond_teacher_output = teacher_transformer( + noisy_model_input, + uncond_prompt_embed.unsqueeze(0).expand( + bsz, -1, -1), + timesteps, + uncond_prompt_mask.unsqueeze(0).expand(bsz, -1), + return_dict=False, + )[0].float() + + teacher_output = uncond_teacher_output + w * ( + cond_teacher_output - uncond_teacher_output) + x_prev = solver.euler_step(noisy_model_input, teacher_output, + index) + + # Get target prediction + with torch.no_grad(): + with torch.autocast("cuda", dtype=torch.bfloat16): + if ema_transformer is not None: + with set_forward_context( + current_timestep=timesteps_prev, + attn_metadata=None): + target_pred = ema_transformer( + x_prev.float(), + encoder_hidden_states, + timesteps_prev, + encoder_attention_mask, + return_dict=False, + )[0] + else: + with set_forward_context( + current_timestep=timesteps_prev, + attn_metadata=None): + target_pred = transformer( + x_prev.float(), + encoder_hidden_states, + timesteps_prev, + encoder_attention_mask, + return_dict=False, + )[0] + + target, end_index = solver.euler_style_multiphase_pred( + x_prev, target_pred, index, multiphase, True) + + # Calculate loss + huber_c = 0.001 + loss = (torch.mean( + torch.sqrt((model_pred.float() - target.float())**2 + + huber_c**2) - huber_c) / gradient_accumulation_steps) + + if pred_decay_weight > 0: + if pred_decay_type == "l1": + pred_decay_loss = ( + torch.mean(torch.sqrt(model_pred.float()**2)) * + pred_decay_weight / gradient_accumulation_steps) + loss += pred_decay_loss + elif pred_decay_type == "l2": + pred_decay_loss = (torch.mean(model_pred.float()**2) * + pred_decay_weight / + gradient_accumulation_steps) + loss += pred_decay_loss + else: + raise NotImplementedError( + "pred_decay_type is not implemented") + + # Calculate model prediction norms + get_norm(model_pred.detach().float(), model_pred_norm, + gradient_accumulation_steps) + loss.backward() + + avg_loss = loss.detach().clone() + self.sp_group.all_reduce(avg_loss, + op=torch.distributed.ReduceOp.AVG) + total_loss += avg_loss.item() + + # Update EMA + if ema_transformer is not None: + reshard_fsdp(ema_transformer) + for p_averaged, p_model in zip(ema_transformer.parameters(), + transformer.parameters()): + with torch.no_grad(): + p_averaged.copy_( + torch.lerp(p_averaged.detach(), p_model.detach(), + 1 - ema_decay)) + + # Gradient clipping and optimization step + model_parts = [transformer] + grad_norm = _clip_grad_norm_while_handling_failing_dtensor_cases( + [p for m in model_parts for p in m.parameters()], + max_grad_norm, + foreach=None, + ) + + optimizer.step() + lr_scheduler.step() + + return total_loss, grad_norm.item(), model_pred_norm + + def forward( + self, + batch: ForwardBatch, + fastvideo_args: TrainingArgs, + ): + args = fastvideo_args + train_dataloader = self.train_dataloader + init_steps = self.init_steps + lr_scheduler = self.lr_scheduler + optimizer = self.optimizer + noise_scheduler = self.noise_scheduler + solver = self.solver + noise_random_generator 
= None + uncond_prompt_embed = self.uncond_prompt_embed + uncond_prompt_mask = self.uncond_prompt_mask + + # Train! + total_batch_size = (self.world_size * args.gradient_accumulation_steps / + args.sp_size * args.train_sp_batch_size) + logger.info("***** Running distillation training *****") + logger.info(f" Resume training from step {init_steps}") + logger.info( + f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info( + f" Total train batch size (w. data & sequence parallel, accumulation) = {total_batch_size}" + ) + logger.info( + f" Gradient Accumulation steps = {args.gradient_accumulation_steps}" + ) + logger.info(f" Total optimization steps = {args.max_train_steps}") + logger.info( + f" Total training parameters per FSDP shard = {sum(p.numel() for p in self.transformer.parameters() if p.requires_grad) / 1e9} B" + ) + logger.info( + f" Master weight dtype: {self.transformer.parameters().__next__().dtype}" + ) + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + raise NotImplementedError( + "resume_from_checkpoint is not supported now.") + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=init_steps, + desc="Steps", + disable=self.local_rank > 0, + ) + + loader_iter = iter(train_dataloader) + step_times = deque(maxlen=100) + + # Skip steps if resuming + for i in range(init_steps): + next(loader_iter) + + def get_num_phases(multi_phased_distill_schedule, step): + # step-phase,step-phase + multi_phases = multi_phased_distill_schedule.split(",") + phase = multi_phases[-1].split("-")[-1] + for step_phases in multi_phases: + phase_step, phase = step_phases.split("-") + if step <= int(phase_step): + return int(phase) + return int(phase) + + for step in range(init_steps + 1, args.max_train_steps + 1): + start_time = time.perf_counter() + + assert args.multi_phased_distill_schedule is not None + num_phases = get_num_phases(args.multi_phased_distill_schedule, + step) + + loss, grad_norm, pred_norm = self.distill_one_step( + self.transformer, + "wan", # model_type + self.teacher_transformer, + self.ema_transformer, + optimizer, + lr_scheduler, + loader_iter, + noise_scheduler, + solver, + noise_random_generator, + args.gradient_accumulation_steps, + args.sp_size, + args.max_grad_norm, + uncond_prompt_embed, + uncond_prompt_mask, + args.num_euler_timesteps, + num_phases, + args.not_apply_cfg_solver, + args.distill_cfg, + args.ema_decay, + args.pred_decay_weight, + args.pred_decay_type, + args.hunyuan_teacher_disable_cfg, + ) + + step_time = time.perf_counter() - start_time + step_times.append(step_time) + avg_step_time = sum(step_times) / len(step_times) + + progress_bar.set_postfix({ + "loss": f"{loss:.4f}", + "step_time": f"{step_time:.2f}s", + "grad_norm": grad_norm, + "phases": num_phases, + }) + progress_bar.update(1) + + if self.rank <= 0: + wandb.log( + { + "train_loss": + loss, + "learning_rate": + lr_scheduler.get_last_lr()[0], + "step_time": + step_time, + "avg_step_time": + avg_step_time, + "grad_norm": + grad_norm, + "pred_fro_norm": + pred_norm["fro"], # codespell:ignore + "pred_largest_singular_value": + pred_norm["largest singular value"], + "pred_absolute_mean": + pred_norm["absolute mean"], + "pred_absolute_max": + pred_norm["absolute max"], + "phases": + num_phases, + }, + step=step, + ) + + if step % args.checkpointing_steps == 0: + if args.use_lora: + raise NotImplementedError("LoRA is not supported now") + else: + if args.use_ema: + save_checkpoint_v1(self.ema_transformer, 
self.rank, + args.output_dir, step) + else: + save_checkpoint_v1(self.transformer, self.rank, + args.output_dir, step) + self.sp_group.barrier() + + if args.log_validation and step % args.validation_steps == 0: + self.log_validation(self.transformer, args, step) + + # Final checkpoint + if args.use_lora: + raise NotImplementedError("LoRA is not supported now") + else: + save_checkpoint_v1(self.transformer, self.rank, args.output_dir, + args.max_train_steps) + + if get_sp_group(): + cleanup_dist_env_and_memory() + + +def main(args): + logger.info("Starting distillation pipeline...") + + pipeline = WanDistillationPipeline.from_pretrained( + args.pretrained_model_name_or_path, args=args) + + args = pipeline.fastvideo_args + pipeline.forward(None, args) + logger.info("Distillation pipeline done") + + +if __name__ == "__main__": + argv = sys.argv + from fastvideo.v1.fastvideo_args import TrainingArgs + from fastvideo.v1.utils import FlexibleArgumentParser + parser = FlexibleArgumentParser() + parser = TrainingArgs.add_cli_args(parser) + parser = FastVideoArgs.add_cli_args(parser) + args = parser.parse_args() + args.use_cpu_offload = False + print(args) + main(args) diff --git a/fastvideo/v1/pipelines/training_pipeline.py b/fastvideo/v1/pipelines/training_pipeline.py new file mode 100644 index 000000000..a8bffa368 --- /dev/null +++ b/fastvideo/v1/pipelines/training_pipeline.py @@ -0,0 +1,820 @@ +import gc +import os +import sys +import time +import traceback +from abc import ABC, abstractmethod +from collections import deque +from copy import deepcopy + +import imageio +import numpy as np +import torch +import torchvision +# import torch.distributed as dist +import wandb +from diffusers.optimization import get_scheduler +from einops import rearrange +from torchdata.stateful_dataloader import StatefulDataLoader +from tqdm.auto import tqdm + +from fastvideo.models.mochi_hf.mochi_latents_utils import normalize_dit_input +from fastvideo.utils.checkpoint import save_checkpoint_v1 +from fastvideo.v1.configs.sample import SamplingParam +from fastvideo.v1.dataset.parquet_datasets import ParquetVideoTextDataset +from fastvideo.v1.distributed import cleanup_dist_env_and_memory, get_sp_group +from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs +from fastvideo.v1.forward_context import set_forward_context +from fastvideo.v1.logger import init_logger +from fastvideo.v1.pipelines import ComposedPipelineBase +from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch +from fastvideo.v1.pipelines.training_utils import ( + _clip_grad_norm_while_handling_failing_dtensor_cases, + compute_density_for_timestep_sampling, get_sigmas) +from fastvideo.v1.pipelines.wan.wan_pipeline import WanValidationPipeline + +logger = init_logger(__name__) + +# Manual gradient checking flag - set to True to enable gradient verification +ENABLE_GRADIENT_CHECK = False +GRADIENT_CHECK_DTYPE = torch.bfloat16 + + +class TrainingPipeline(ComposedPipelineBase, ABC): + """ + A pipeline for training a model. All training pipelines should inherit from this class. + All reusable components and code should be implemented in this class. 
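+    Subclasses are expected to implement initialize_validation_pipeline() and train_one_step(), which are declared as abstract methods below.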
+ """ + _required_config_modules = ["scheduler", "transformer"] + + def initialize_training_pipeline(self, fastvideo_args: TrainingArgs): + logger.info("Initializing training pipeline...") + self.device = fastvideo_args.device + self.sp_group = get_sp_group() + self.world_size = self.sp_group.world_size + self.rank = self.sp_group.rank + self.local_rank = self.sp_group.local_rank + self.transformer = self.get_module("transformer") + assert self.transformer is not None + + self.transformer.requires_grad_(True) + self.transformer.train() + + args = fastvideo_args + + noise_scheduler = self.modules["scheduler"] + params_to_optimize = self.transformer.parameters() + params_to_optimize = list( + filter(lambda p: p.requires_grad, params_to_optimize)) + + optimizer = torch.optim.AdamW( + params_to_optimize, + lr=args.learning_rate, + betas=(0.9, 0.999), + weight_decay=args.weight_decay, + eps=1e-8, + ) + + init_steps = 0 + logger.info("optimizer: %s", optimizer) + + # todo add lr scheduler + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * self.world_size, + num_training_steps=args.max_train_steps * self.world_size, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + last_epoch=init_steps - 1, + ) + + train_dataset = ParquetVideoTextDataset( + args.data_path, + batch_size=args.train_batch_size, + rank=self.rank, + world_size=self.world_size, + cfg_rate=args.cfg, + num_latent_t=args.num_latent_t) + + train_dataloader = StatefulDataLoader( + train_dataset, + batch_size=args.train_batch_size, + num_workers=args. + dataloader_num_workers, # Reduce number of workers to avoid memory issues + prefetch_factor=2, + shuffle=False, + pin_memory=True, + drop_last=True) + + self.lr_scheduler = lr_scheduler + self.train_dataset = train_dataset + self.train_dataloader = train_dataloader + self.init_steps = init_steps + self.optimizer = optimizer + self.noise_scheduler = noise_scheduler + # self.noise_random_generator = noise_random_generator + + # num_update_steps_per_epoch = math.ceil( + # len(train_dataloader) / args.gradient_accumulation_steps * + # args.sp_size / args.train_sp_batch_size) + # args.num_train_epochs = math.ceil(args.max_train_steps / + # num_update_steps_per_epoch) + + if self.rank <= 0: + project = args.tracker_project_name or "fastvideo" + wandb.init(project=project, config=args) + + @abstractmethod + def initialize_validation_pipeline(self, fastvideo_args: FastVideoArgs): + raise NotImplementedError( + "Training pipelines must implement this method") + + @abstractmethod + def train_one_step(self, transformer, model_type, optimizer, lr_scheduler, + loader, noise_scheduler, noise_random_generator, + gradient_accumulation_steps, sp_size, + precondition_outputs, max_grad_norm, weighting_scheme, + logit_mean, logit_std, mode_scale): + """ + Train one step of the model. 
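+        Returns a tuple of (total_loss, grad_norm) for the step.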
+ """ + raise NotImplementedError( + "Training pipeline must implement this method") + + def log_validation(self, transformer, fastvideo_args, global_step): + fastvideo_args.inference_mode = True + fastvideo_args.use_cpu_offload = False + if not fastvideo_args.log_validation: + return + if self.validation_pipeline is None: + raise ValueError("Validation pipeline is not set") + + # Create sampling parameters if not provided + sampling_param = SamplingParam.from_pretrained( + fastvideo_args.model_path) + + # Prepare validation prompts + print('fastvideo_args.validation_prompt_dir', + fastvideo_args.validation_prompt_dir) + validation_dataset = ParquetVideoTextDataset( + fastvideo_args.validation_prompt_dir, + batch_size=1, + rank=0, + world_size=1, + cfg_rate=0, + num_latent_t=args.num_latent_t) + + validation_dataloader = StatefulDataLoader( + validation_dataset, + batch_size=1, + num_workers=1, # Reduce number of workers to avoid memory issues + prefetch_factor=2, + shuffle=False, + pin_memory=True, + drop_last=False) + + transformer.requires_grad_(False) + for p in transformer.parameters(): + p.requires_grad = False + transformer.eval() + + # Add the transformer to the validation pipeline + self.validation_pipeline.add_module("transformer", transformer) + self.validation_pipeline.latent_preparation_stage.transformer = transformer + self.validation_pipeline.denoising_stage.transformer = transformer + + # Process each validation prompt + videos = [] + captions = [] + for _, embeddings, masks, infos in validation_dataloader: + logger.info(f"infos: {infos}") + caption = infos['caption'] + captions.append(caption) + prompt_embeds = embeddings.to(fastvideo_args.device) + prompt_attention_mask = masks.to(fastvideo_args.device) + + # Calculate sizes + latents_size = [(sampling_param.num_frames - 1) // 4 + 1, + sampling_param.height // 8, + sampling_param.width // 8] + n_tokens = latents_size[0] * latents_size[1] * latents_size[2] + + # Prepare batch for validation + # print('shape of embeddings', prompt_embeds.shape) + batch = ForwardBatch( + # **shallow_asdict(sampling_param), + data_type="video", + latents=None, + # seed=sampling_param.seed, + # data_type="video", + prompt_embeds=[prompt_embeds], + prompt_attention_mask=[prompt_attention_mask], + # make sure we use the same height, width, and num_frames as the training pipeline + height=args.num_height, + width=args.num_width, + num_frames=args.num_frames, + # num_inference_steps=fastvideo_args.validation_sampling_steps, + num_inference_steps=10, + # guidance_scale=fastvideo_args.validation_guidance_scale, + guidance_scale=1, + n_tokens=n_tokens, + do_classifier_free_guidance=False, + eta=0.0, + extra={}, + ) + + # Run validation inference + with torch.inference_mode(): + output_batch = self.validation_pipeline.forward( + batch, fastvideo_args) + samples = output_batch.output + + # Process outputs + video = rearrange(samples, "b c t h w -> t b c h w") + frames = [] + for x in video: + x = torchvision.utils.make_grid(x, nrow=6) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + frames.append((x * 255).numpy().astype(np.uint8)) + videos.append(frames) + + # Log validation results + rank = int(os.environ.get("RANK", 0)) + + if rank == 0: + video_filenames = [] + video_captions = [] + for i, video in enumerate(videos): + caption = captions[i] + filename = os.path.join( + fastvideo_args.output_dir, + f"validation_step_{global_step}_video_{i}.mp4") + imageio.mimsave(filename, video, fps=sampling_param.fps) + video_filenames.append(filename) + 
video_captions.append( + caption) # Store the caption for each video + + logs = { + "validation_videos": [ + wandb.Video(filename, + caption=caption) for filename, caption in zip( + video_filenames, video_captions) + ] + } + wandb.log(logs, step=global_step) + + # Re-enable gradients for training + transformer.requires_grad_(True) + transformer.train() + + gc.collect() + torch.cuda.empty_cache() + + def gradient_check_parameters(self, + transformer, + latents, + encoder_hidden_states, + encoder_attention_mask, + timesteps, + target, + eps=5e-2, + max_params_to_check=2000): + """ + Verify gradients using finite differences for FSDP models with GRADIENT_CHECK_DTYPE. + Uses standard tolerances for GRADIENT_CHECK_DTYPE precision. + """ + # Move all inputs to CPU and clear GPU memory + inputs_cpu = { + 'latents': latents.cpu(), + 'encoder_hidden_states': encoder_hidden_states.cpu(), + 'encoder_attention_mask': encoder_attention_mask.cpu(), + 'timesteps': timesteps.cpu(), + 'target': target.cpu() + } + del latents, encoder_hidden_states, encoder_attention_mask, timesteps, target + torch.cuda.empty_cache() + + def compute_loss(): + # Move inputs to GPU, compute loss, cleanup + inputs_gpu = { + k: + v.to(self.fastvideo_args.device, + dtype=GRADIENT_CHECK_DTYPE + if k != 'encoder_attention_mask' else None) + for k, v in inputs_cpu.items() + } + + # Use GRADIENT_CHECK_DTYPE for more accurate gradient checking + # with torch.autocast(enabled=False, device_type="cuda"): + with torch.autocast("cuda", dtype=GRADIENT_CHECK_DTYPE): + with set_forward_context( + current_timestep=inputs_gpu['timesteps'], + attn_metadata=None): + model_pred = transformer( + hidden_states=inputs_gpu['latents'], + encoder_hidden_states=inputs_gpu[ + 'encoder_hidden_states'], + timestep=inputs_gpu['timesteps'], + encoder_attention_mask=inputs_gpu[ + 'encoder_attention_mask'], + return_dict=False)[0] + + if self.fastvideo_args.precondition_outputs: + sigmas = get_sigmas(self.noise_scheduler, + inputs_gpu['latents'].device, + inputs_gpu['timesteps'], + n_dim=inputs_gpu['latents'].ndim, + dtype=inputs_gpu['latents'].dtype) + model_pred = inputs_gpu['latents'] - model_pred * sigmas + target_adjusted = inputs_gpu['target'] + else: + target_adjusted = inputs_gpu['target'] + + loss = torch.mean((model_pred - target_adjusted)**2) + + # Cleanup and return + loss_cpu = loss.cpu() + del inputs_gpu, model_pred, target_adjusted + if 'sigmas' in locals(): del sigmas + torch.cuda.empty_cache() + return loss_cpu.to(self.fastvideo_args.device) + + try: + # Get analytical gradients + transformer.zero_grad() + analytical_loss = compute_loss() + analytical_loss.backward() + + # Check gradients for selected parameters + absolute_errors = [] + param_count = 0 + + for name, param in transformer.named_parameters(): + if not (param.requires_grad and param.grad is not None + and param_count < max_params_to_check + and param.grad.abs().max() > 5e-4): + continue + + # Get local parameter and gradient tensors + local_param = param._local_tensor if hasattr( + param, '_local_tensor') else param + local_grad = param.grad._local_tensor if hasattr( + param.grad, '_local_tensor') else param.grad + + # Find first significant gradient element + flat_param = local_param.data.view(-1) + flat_grad = local_grad.view(-1) + check_idx = next((i for i in range(min(10, flat_param.numel())) + if abs(flat_grad[i]) > 1e-4), 0) + + # Store original values + orig_value = flat_param[check_idx].item() + analytical_grad = flat_grad[check_idx].item() + + # Compute numerical 
gradient + for delta in [eps, -eps]: + with torch.no_grad(): + flat_param[check_idx] = orig_value + delta + loss = compute_loss() + if delta > 0: loss_plus = loss.item() + else: loss_minus = loss.item() + + # Restore parameter and compute error + with torch.no_grad(): + flat_param[check_idx] = orig_value + + numerical_grad = (loss_plus - loss_minus) / (2 * eps) + abs_error = abs(analytical_grad - numerical_grad) + rel_error = abs_error / max(abs(analytical_grad), + abs(numerical_grad), 1e-3) + absolute_errors.append(abs_error) + + logger.info( + f"{name}[{check_idx}]: analytical={analytical_grad:.6f}, " + f"numerical={numerical_grad:.6f}, abs_error={abs_error:.2e}, rel_error={rel_error:.2%}" + ) + + # param_count += 1 + + # Compute and log statistics + if absolute_errors: + min_err, max_err, mean_err = min(absolute_errors), max( + absolute_errors + ), sum(absolute_errors) / len(absolute_errors) + logger.info( + f"Gradient check stats: min={min_err:.2e}, max={max_err:.2e}, mean={mean_err:.2e}" + ) + + if self.rank <= 0: + wandb.log({ + "grad_check/min_abs_error": + min_err, + "grad_check/max_abs_error": + max_err, + "grad_check/mean_abs_error": + mean_err, + "grad_check/analytical_loss": + analytical_loss.item(), + }) + return max_err + + return float('inf') + + except Exception as e: + logger.error(f"Gradient check failed: {e}") + traceback.print_exc() + return float('inf') + + def setup_gradient_check(self, args, loader_iter, noise_scheduler, + noise_random_generator): + """ + Setup and perform gradient check on a fresh batch. + Args: + args: Training arguments + loader_iter: Data loader iterator + noise_scheduler: Noise scheduler for diffusion + noise_random_generator: Random number generator for noise + Returns: + float or None: Maximum gradient error or None if check is disabled/fails + """ + if not ENABLE_GRADIENT_CHECK: + return None + + try: + # Get a fresh batch and process it exactly like train_one_step + check_latents, check_encoder_hidden_states, check_encoder_attention_mask, check_infos = next( + loader_iter) + + # Process exactly like in train_one_step but use GRADIENT_CHECK_DTYPE + check_latents = check_latents.to(self.fastvideo_args.device, + dtype=GRADIENT_CHECK_DTYPE) + check_encoder_hidden_states = check_encoder_hidden_states.to( + self.fastvideo_args.device, dtype=GRADIENT_CHECK_DTYPE) + check_latents = normalize_dit_input("wan", check_latents) + batch_size = check_latents.shape[0] + check_noise = torch.randn_like(check_latents) + + check_u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=batch_size, + generator=noise_random_generator, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + check_indices = (check_u * + noise_scheduler.config.num_train_timesteps).long() + check_timesteps = noise_scheduler.timesteps[check_indices].to( + device=check_latents.device) + + check_sigmas = get_sigmas( + noise_scheduler, + check_latents.device, + check_timesteps, + n_dim=check_latents.ndim, + dtype=check_latents.dtype, + ) + check_noisy_model_input = ( + 1.0 - check_sigmas) * check_latents + check_sigmas * check_noise + + # Compute target exactly like train_one_step + if args.precondition_outputs: + check_target = check_latents + else: + check_target = check_noise - check_latents + + # Perform gradient check with the exact same inputs as training + max_grad_error = self.gradient_check_parameters( + transformer=self.transformer, + latents= + check_noisy_model_input, # Use noisy input like in 
training + encoder_hidden_states=check_encoder_hidden_states, + encoder_attention_mask=check_encoder_attention_mask, + timesteps=check_timesteps, + target=check_target, + max_params_to_check=100 # Check more parameters + ) + + if max_grad_error > 5e-2: + logger.error( + f"❌ Large gradient error detected: {max_grad_error:.2e}") + else: + logger.info( + f"✅ Gradient check passed: max error {max_grad_error:.2e}") + + return max_grad_error + + except Exception as e: + logger.error(f"Gradient check setup failed: {e}") + traceback.print_exc() + return None + + +class WanTrainingPipeline(TrainingPipeline): + """ + A training pipeline for Wan. + """ + _required_config_modules = ["scheduler", "transformer"] + + def create_pipeline_stages(self, fastvideo_args: FastVideoArgs): + pass + + def create_training_stages(self, fastvideo_args: FastVideoArgs): + pass + + def initialize_validation_pipeline(self, fastvideo_args: FastVideoArgs): + logger.info("Initializing validation pipeline...") + args_copy = deepcopy(fastvideo_args) + + args_copy.mode = "inference" + args_copy.vae_config.load_encoder = False + validation_pipeline = WanValidationPipeline.from_pretrained( + args.model_path, args=args_copy) + + self.validation_pipeline = validation_pipeline + + def train_one_step( + self, + transformer, + model_type, + optimizer, + lr_scheduler, + loader_iter, + noise_scheduler, + noise_random_generator, + gradient_accumulation_steps, + sp_size, + precondition_outputs, + max_grad_norm, + weighting_scheme, + logit_mean, + logit_std, + mode_scale, + ): + self.modules["transformer"].requires_grad_(True) + self.modules["transformer"].train() + + total_loss = 0.0 + optimizer.zero_grad() + for _ in range(gradient_accumulation_steps): + ( + latents, + encoder_hidden_states, + encoder_attention_mask, + infos, + ) = next(loader_iter) + latents = latents.to(self.fastvideo_args.device, + dtype=torch.bfloat16) + encoder_hidden_states = encoder_hidden_states.to( + self.fastvideo_args.device, dtype=torch.bfloat16) + latents = normalize_dit_input(model_type, latents) + batch_size = latents.shape[0] + noise = torch.randn_like(latents) + u = compute_density_for_timestep_sampling( + weighting_scheme=weighting_scheme, + batch_size=batch_size, + generator=noise_random_generator, + logit_mean=logit_mean, + logit_std=logit_std, + mode_scale=mode_scale, + ) + indices = (u * noise_scheduler.config.num_train_timesteps).long() + timesteps = noise_scheduler.timesteps[indices].to( + device=latents.device) + if sp_size > 1: + # Make sure that the timesteps are the same across all sp processes. 
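+                # Each rank in a sequence-parallel group holds a shard of the same sample, so all ranks must denoise at the same timestep; broadcast from rank 0 of the group.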
+ sp_group = get_sp_group() + sp_group.broadcast(timesteps, src=0) + sigmas = get_sigmas( + noise_scheduler, + latents.device, + timesteps, + n_dim=latents.ndim, + dtype=latents.dtype, + ) + noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise + print('device before forward ', + next(transformer.named_parameters())[1].device) + with torch.autocast("cuda", dtype=torch.bfloat16): + input_kwargs = { + "hidden_states": noisy_model_input, + "encoder_hidden_states": encoder_hidden_states, + "timestep": timesteps, + "encoder_attention_mask": encoder_attention_mask, # B, L + "return_dict": False, + } + if 'hunyuan' in model_type: + input_kwargs["guidance"] = torch.tensor( + [1000.0], + device=noisy_model_input.device, + dtype=torch.bfloat16) + with set_forward_context(current_timestep=timesteps, + attn_metadata=None): + model_pred = transformer(**input_kwargs)[0] + + if precondition_outputs: + model_pred = noisy_model_input - model_pred * sigmas + if precondition_outputs: + target = latents + else: + target = noise - latents + + loss = (torch.mean((model_pred.float() - target.float())**2) / + gradient_accumulation_steps) + print('device before backwardin context', + next(transformer.named_parameters())[1].device) + + print('device before backward out context', + next(transformer.named_parameters())[1].device) + loss.backward() + print('device after backward out context', + next(transformer.named_parameters())[1].device) + + avg_loss = loss.detach().clone() + sp_group = get_sp_group() + sp_group.all_reduce(avg_loss, op=torch.distributed.ReduceOp.AVG) + total_loss += avg_loss.item() + + model_parts = [self.transformer] + grad_norm = _clip_grad_norm_while_handling_failing_dtensor_cases( + [p for m in model_parts for p in m.parameters()], + max_grad_norm, + foreach=None, + ) + + optimizer.step() + print('device after optimizer step', + next(transformer.named_parameters())[1].device) + lr_scheduler.step() + print('device after scheduler step', + next(transformer.named_parameters())[1].device) + return total_loss, grad_norm.item() + + def forward( + self, + batch: ForwardBatch, + fastvideo_args: FastVideoArgs, + ): + args = fastvideo_args + self.fastvideo_args = args + train_dataloader = self.train_dataloader + init_steps = self.init_steps + lr_scheduler = self.lr_scheduler + optimizer = self.optimizer + noise_scheduler = self.noise_scheduler + noise_random_generator = None + + from diffusers import FlowMatchEulerDiscreteScheduler + noise_scheduler = FlowMatchEulerDiscreteScheduler() + + # Train! + total_batch_size = (self.world_size * args.gradient_accumulation_steps / + args.sp_size * args.train_sp_batch_size) + logger.info("***** Running training *****") + # logger.info(f" Num examples = {len(train_dataset)}") + # logger.info(f" Dataloader size = {len(train_dataloader)}") + # logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Resume training from step {init_steps}") + logger.info( + f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info( + f" Total train batch size (w. 
data & sequence parallel, accumulation) = {total_batch_size}" + ) + logger.info( + f" Gradient Accumulation steps = {args.gradient_accumulation_steps}" + ) + logger.info(f" Total optimization steps = {args.max_train_steps}") + logger.info( + f" Total training parameters per FSDP shard = {sum(p.numel() for p in self.transformer.parameters() if p.requires_grad) / 1e9} B" + ) + # print dtype + logger.info( + f" Master weight dtype: {self.transformer.parameters().__next__().dtype}" + ) + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + assert NotImplementedError( + "resume_from_checkpoint is not supported now.") + # TODO + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=init_steps, + desc="Steps", + # Only show the progress bar once on each machine. + disable=self.local_rank > 0, + ) + + loader_iter = iter(train_dataloader) + + step_times = deque(maxlen=100) + + # todo future + for i in range(init_steps): + next(loader_iter) + # get gpu memory usage + gpu_memory_usage = torch.cuda.memory_allocated() / 1024**2 + logger.info( + f"GPU memory usage before train_one_step: {gpu_memory_usage} MB") + + for step in range(init_steps + 1, args.max_train_steps + 1): + start_time = time.perf_counter() + + loss, grad_norm = self.train_one_step( + self.transformer, + # args.model_type, + "wan", + optimizer, + lr_scheduler, + loader_iter, + noise_scheduler, + noise_random_generator, + args.gradient_accumulation_steps, + args.sp_size, + args.precondition_outputs, + args.max_grad_norm, + args.weighting_scheme, + args.logit_mean, + args.logit_std, + args.mode_scale, + ) + gpu_memory_usage = torch.cuda.memory_allocated() / 1024**2 + logger.info( + f"GPU memory usage after train_one_step: {gpu_memory_usage} MB") + + step_time = time.perf_counter() - start_time + step_times.append(step_time) + avg_step_time = sum(step_times) / len(step_times) + + # Manual gradient checking - only at first step + if step == 1 and ENABLE_GRADIENT_CHECK: + logger.info(f"Performing gradient check at step {step}") + self.setup_gradient_check(args, loader_iter, noise_scheduler, + noise_random_generator) + + progress_bar.set_postfix({ + "loss": f"{loss:.4f}", + "step_time": f"{step_time:.2f}s", + "grad_norm": grad_norm, + }) + progress_bar.update(1) + if self.rank <= 0: + wandb.log( + { + "train_loss": loss, + "learning_rate": lr_scheduler.get_last_lr()[0], + "step_time": step_time, + "avg_step_time": avg_step_time, + "grad_norm": grad_norm, + }, + step=step, + ) + if step % args.checkpointing_steps == 0: + if args.use_lora: + raise NotImplementedError("LoRA is not supported now") + # Save LoRA weights + # save_lora_checkpoint(transformer, optimizer, rank, + # args.output_dir, step, pipe) + else: + # Your existing checkpoint saving code + save_checkpoint_v1(self.transformer, self.rank, + args.output_dir, step) + self.transformer.train() + self.sp_group.barrier() + if args.log_validation and step % args.validation_steps == 0: + self.log_validation(self.transformer, args, step) + + if args.use_lora: + raise NotImplementedError("LoRA is not supported now") + # save_lora_checkpoint(transformer, optimizer, rank, args.output_dir, args.max_train_steps, pipe) + else: + save_checkpoint_v1(self.transformer, self.rank, args.output_dir, + args.max_train_steps) + + if get_sp_group(): + cleanup_dist_env_and_memory() + + +def main(args): + logger.info("Starting training pipeline...") + + pipeline = WanTrainingPipeline.from_pretrained( + args.pretrained_model_name_or_path, 
args=args) + args = pipeline.fastvideo_args + pipeline.forward(None, args) + logger.info("Training pipeline done") + + +if __name__ == "__main__": + argv = sys.argv + from fastvideo.v1.fastvideo_args import TrainingArgs + from fastvideo.v1.utils import FlexibleArgumentParser + parser = FlexibleArgumentParser() + parser = TrainingArgs.add_cli_args(parser) + parser = FastVideoArgs.add_cli_args(parser) + args = parser.parse_args() + args.use_cpu_offload = False + print(args) + main(args) diff --git a/scripts/distill/distill_v1.sh b/scripts/distill/distill_v1.sh new file mode 100644 index 000000000..328d84c9b --- /dev/null +++ b/scripts/distill/distill_v1.sh @@ -0,0 +1,47 @@ +export WANDB_BASE_URL="https://api.wandb.ai" +export WANDB_MODE=online + +DATA_DIR=data/HD-Mixkit-Finetune-Wan/combined_parquet_dataset +VALIDATION_DIR=data/HD-Mixkit-Finetune-Wan/validation_parquet_dataset +num_gpus=1 +# IP=[MASTER NODE IP] + +# If you do not have 32 GPUs and to fit in memory, you can: 1. increase sp_size. 2. reduce num_latent_t + # --gradient_checkpointing\ + # --pretrained_model_name_or_path hunyuanvideo-community/HunyuanVideo \ + # --pretrained_model_name_or_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ +torchrun --nnodes 1 --nproc_per_node $num_gpus\ + fastvideo/v1/pipelines/distillation_pipeline.py\ + --mode distill\ + --pretrained_model_name_or_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ + --cache_dir "/home/test/.cache"\ + --data_path "$DATA_DIR"\ + --validation_prompt_dir "$VALIDATION_DIR"\ + --train_batch_size=1 \ + --num_latent_t 1 \ + --sp_size $num_gpus \ + --train_sp_batch_size 1\ + --dataloader_num_workers $num_gpus\ + --gradient_accumulation_steps=1\ + --max_train_steps=320\ + --learning_rate=1e-6\ + --mixed_precision="bf16"\ + --checkpointing_steps=64\ + --validation_steps 20\ + --validation_sampling_steps "2,4,8" \ + --checkpoints_total_limit 3\ + --allow_tf32\ + --ema_start_step 0\ + --cfg 0.0\ + --log_validation\ + --output_dir="$DATA_DIR/outputs/hy_phase1_shift17_bs_16_HD"\ + --tracker_project_name Hunyuan_Distill \ + --num_height 720 \ + --num_width 1280 \ + --num_frames 125 \ + --shift 17 \ + --validation_guidance_scale "1.0" \ + --num_euler_timesteps 50 \ + --multi_phased_distill_schedule "4000-1" \ + --not_apply_cfg_solver \ + --master_weight_type "bf16" \ No newline at end of file diff --git a/scripts/finetune/finetune_v1.sh b/scripts/finetune/finetune_v1.sh index b5fb9ec58..811f13b27 100644 --- a/scripts/finetune/finetune_v1.sh +++ b/scripts/finetune/finetune_v1.sh @@ -12,7 +12,7 @@ NUM_GPUS=4 torchrun --nnodes 1 --nproc_per_node $NUM_GPUS\ fastvideo/v1/training/wan_training_pipeline.py\ --model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ - --inference_mode False\ + --mode training\ --pretrained_model_name_or_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ --cache_dir "/home/ray/.cache"\ --data_path "$DATA_DIR"\ From e1a27431f103892da4b37c09f1045ad700683407 Mon Sep 17 00:00:00 2001 From: Jinzhe Pan Date: Mon, 2 Jun 2025 02:10:36 +0800 Subject: [PATCH 03/14] [WIP][Feat] distill run in single gpu --- fastvideo/distill/solver.py | 20 +- fastvideo/v1/dataset/parquet_datasets.py | 471 ++++++++++ fastvideo/v1/fastvideo_args.py | 2 + .../v1/models/loader/component_loader.py | 1 + .../v1/pipelines/composed_pipeline_base.py | 11 +- fastvideo/v1/pipelines/training_pipeline.py | 820 ------------------ .../distillation_pipeline.py | 250 +++--- scripts/distill/distill_v1.sh | 12 +- 8 files changed, 657 insertions(+), 930 deletions(-) create mode 100644 fastvideo/v1/dataset/parquet_datasets.py delete mode 100644 
fastvideo/v1/pipelines/training_pipeline.py rename fastvideo/v1/{pipelines => training}/distillation_pipeline.py (75%) mode change 100644 => 100755 scripts/distill/distill_v1.sh diff --git a/fastvideo/distill/solver.py b/fastvideo/distill/solver.py index 7f8fec847..d7c89b04a 100644 --- a/fastvideo/distill/solver.py +++ b/fastvideo/distill/solver.py @@ -7,11 +7,27 @@ from diffusers.schedulers.scheduling_utils import SchedulerMixin from diffusers.utils import BaseOutput, logging -from fastvideo.models.mochi_hf.pipeline_mochi import linear_quadratic_schedule - logger = logging.get_logger(__name__) # pylint: disable=invalid-name +# from: https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77 +def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None): + if linear_steps is None: + linear_steps = num_steps // 2 + linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)] + threshold_noise_step_diff = linear_steps - threshold_noise * num_steps + quadratic_steps = num_steps - linear_steps + quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2) + linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2) + const = quadratic_coef * (linear_steps**2) + quadratic_sigma_schedule = [ + quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps) + ] + sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + sigma_schedule = [1.0 - x for x in sigma_schedule] + return sigma_schedule + + @dataclass class PCMFMSchedulerOutput(BaseOutput): prev_sample: torch.FloatTensor diff --git a/fastvideo/v1/dataset/parquet_datasets.py b/fastvideo/v1/dataset/parquet_datasets.py new file mode 100644 index 000000000..a19db5038 --- /dev/null +++ b/fastvideo/v1/dataset/parquet_datasets.py @@ -0,0 +1,471 @@ +import argparse +import json +import os +import random +import time +from collections import defaultdict +from typing import Any, Dict, List + +import numpy as np +import pyarrow.parquet as pq +import torch +import tqdm +from einops import rearrange +from torch import distributed as dist +from torch.utils.data import Dataset +from torchdata.stateful_dataloader import StatefulDataLoader + +from fastvideo.v1.distributed import (get_sp_group, get_sp_parallel_rank, + get_sp_world_size, get_world_rank, + get_world_size) +from fastvideo.v1.logger import init_logger + +logger = init_logger(__name__) + + +class ParquetVideoTextDataset(Dataset): + """Efficient loader for video-text data from a directory of Parquet files.""" + + def __init__(self, + path: str, + batch_size, + cfg_rate: float = 0.0, + num_latent_t: int = 2, + seed: int = 0, + validation: bool = False): + super().__init__() + self.path = str(path) + self.batch_size = batch_size + self.global_rank = get_world_rank() + self.rank_in_sp_group = get_sp_parallel_rank() + self.sp_group = get_sp_group() + self.sp_world_size = get_sp_world_size() + self.world_size = get_world_size() + self.cfg_rate = cfg_rate + self.num_latent_t = num_latent_t + self.local_indices = None + self.validation = validation + + # Negative prompt caching + self.neg_metadata = None + self.cached_neg_prompt: Dict[str, Any] | None = None + + self.plan_output_dir = os.path.join( + self.path, + f"data_plan_world_size_{self.world_size}_sp_size_{self.sp_world_size}.json" + ) + + # group_ranks: a list of lists + # len(group_ranks) = self.world_size + # len(group_ranks[i]) = self.sp_world_size + # 
group_ranks[i] represents the ranks of the SP group for the i-th GPU + # For example, if self.world_size = 4, self.sp_world_size = 2, then + # group_ranks = [[0, 1], [0, 1], [2, 3], [2, 3]] + sp_group_ranks = get_sp_group().ranks + group_ranks: List[List] = [[] for _ in range(self.world_size)] + dist.all_gather_object(group_ranks, sp_group_ranks) + + if self.global_rank == 0: + # If a plan already exists, then skip creating a new plan + # This will be useful when resume training + if os.path.exists(self.plan_output_dir): + logger.info("Using existing plan from %s", self.plan_output_dir) + else: + logger.info("Creating new plan for %s", self.plan_output_dir) + metadatas = [] + for root, _, files in os.walk(self.path): + for file in sorted(files): + if file.endswith('.parquet'): + file_path = os.path.join(root, file) + num_rows = pq.ParquetFile( + file_path).metadata.num_rows + for row_idx in range(num_rows): + metadatas.append((file_path, row_idx)) + + # the negative prompt is always the first row in the first + # parquet file + if validation: + self.neg_metadata = metadatas[0] + metadatas = metadatas[1:] + + # Generate the plan that distribute rows among workers + random.seed(seed) + random.shuffle(metadatas) + + # Get all sp groups + # e.g. if num_gpus = 4, sp_size = 2 + # group_ranks = [(0, 1), (0, 1), (2, 3), (2, 3)] + # We will assign the same batches of data to ranks in the same sp group, and we'll assign different batches to ranks in different sp groups + # e.g. plan = {0: [row 1, row 4], 1: [row 1, row 4], 2: [row 2, row 3], 3: [row 2, row 3]} + group_ranks_list: List[Any] = list( + set(tuple(r) for r in group_ranks)) + num_sp_groups = len(group_ranks_list) + plan = defaultdict(list) + for idx, metadata in enumerate(metadatas): + sp_group_idx = idx % num_sp_groups + for global_rank in group_ranks_list[sp_group_idx]: + plan[global_rank].append(metadata) + + if validation: + assert self.neg_metadata is not None + plan["negative_prompt"] = [self.neg_metadata] + with open(self.plan_output_dir, "w") as f: + json.dump(plan, f) + else: + pass + dist.barrier() + if validation: + with open(self.plan_output_dir) as f: + plan = json.load(f) + self.neg_metadata = plan["negative_prompt"][0] + + self.uncond_prompt_embed = torch.zeros(512, 4096).to(torch.float32) + self.uncond_prompt_mask = torch.zeros(1, 512).bool() + + def _load_and_cache_negative_prompt(self) -> None: + """Load and cache the negative prompt. Only rank 0 in each SP group should call this.""" + if not self.validation or self.neg_metadata is None: + return + + if self.cached_neg_prompt is not None: + return + + # Only rank 0 in each SP group should read the negative prompt + try: + file_path, row_idx = self.neg_metadata + parquet_file = pq.ParquetFile(file_path) + + # Since negative prompt is always the first row (row_idx = 0), + # it's always in the first row group + row_group_index = 0 + local_index = row_idx # This will be 0 for the negative prompt + + row_group = parquet_file.read_row_group(row_group_index).to_pydict() + row_dict = {k: v[local_index] for k, v in row_group.items()} + del row_group + + # Process the negative prompt row + self.cached_neg_prompt = self._process_row(row_dict) + + except Exception as e: + logger.error("Failed to load negative prompt: %s", e) + self.cached_neg_prompt = None + + def get_validation_negative_prompt( + self + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, Any]]: + """ + Get the negative prompt for validation. 
+ This method ensures the negative prompt is loaded and cached properly. + Returns the processed negative prompt data (latents, embeddings, masks, info). + """ + if not self.validation: + raise ValueError( + "get_validation_negative_prompt() can only be called in validation mode" + ) + + # Load and cache if needed (only rank 0 in SP group will actually load) + if self.cached_neg_prompt is None: + self._load_and_cache_negative_prompt() + + if self.cached_neg_prompt is None: + raise RuntimeError( + f"Rank {self.global_rank} (SP rank {self.rank_in_sp_group}): Could not retrieve negative prompt data" + ) + + # Extract the components + lat, emb, mask, info = (self.cached_neg_prompt["latents"], + self.cached_neg_prompt["embeddings"], + self.cached_neg_prompt["masks"], + self.cached_neg_prompt["info"]) + + # Apply the same processing as in __getitem__ + if lat.numel() == 0: # Validation parquet + return lat, emb, mask, info + else: + lat = lat[:, -self.num_latent_t:] + if self.sp_world_size > 1: + lat = rearrange(lat, + "t (n s) h w -> t n s h w", + n=self.sp_world_size).contiguous() + lat = lat[:, self.rank_in_sp_group, :, :, :] + return lat, emb, mask, info + + + def __len__(self): + if self.local_indices is None: + try: + with open(self.plan_output_dir) as f: + plan = json.load(f) + self.local_indices = plan[str(self.global_rank)] + except Exception as err: + raise Exception( + "The data plan hasn't been created yet") from err + assert self.local_indices is not None + return len(self.local_indices) + + def __getitem__(self, idx): + if self.local_indices is None: + try: + with open(self.plan_output_dir) as f: + plan = json.load(f) + self.local_indices = plan[self.global_rank] + except Exception as err: + raise Exception( + "The data plan hasn't been created yet") from err + assert self.local_indices is not None + file_path, row_idx = self.local_indices[idx] + parquet_file = pq.ParquetFile(file_path) + + # Calculate the row group to read into memory and the local idx + # This way we can avoid reading in the entire parquet file + cumulative = 0 + for i in range(parquet_file.num_row_groups): + num_rows = parquet_file.metadata.row_group(i).num_rows + if cumulative + num_rows > row_idx: + row_group_index = i + local_index = row_idx - cumulative + break + cumulative += num_rows + + row_group = parquet_file.read_row_group(row_group_index).to_pydict() + row_dict = {k: v[local_index] for k, v in row_group.items()} + del row_group + + processed = self._process_row(row_dict) + lat, emb, mask, info = processed["latents"], processed[ + "embeddings"], processed["masks"], processed["info"] + if lat.numel() == 0: # Validation parquet + return lat, emb, mask, info + else: + lat = lat[:, -self.num_latent_t:] + if self.sp_world_size > 1: + lat = rearrange(lat, + "t (n s) h w -> t n s h w", + n=self.sp_world_size).contiguous() + lat = lat[:, self.rank_in_sp_group, :, :, :] + return lat, emb, mask, info + + def _process_row(self, row) -> Dict[str, Any]: + """Process a PyArrow batch into tensors.""" + + vae_latent_bytes = row["vae_latent_bytes"] + vae_latent_shape = row["vae_latent_shape"] + text_embedding_bytes = row["text_embedding_bytes"] + text_embedding_shape = row["text_embedding_shape"] + text_attention_mask_bytes = row["text_attention_mask_bytes"] + text_attention_mask_shape = row["text_attention_mask_shape"] + + # Process latent + if not vae_latent_shape: # No VAE latent is stored. 
Split is validation + lat = np.array([]) + else: + lat = np.frombuffer(vae_latent_bytes, + dtype=np.float32).reshape(vae_latent_shape) + # Make array writable + lat = np.copy(lat) + + if random.random() < self.cfg_rate: + emb = np.zeros((512, 4096), dtype=np.float32) + else: + emb = np.frombuffer(text_embedding_bytes, + dtype=np.float32).reshape(text_embedding_shape) + # Make array writable + emb = np.copy(emb) + if emb.shape[0] < 512: + padded_emb = np.zeros((512, emb.shape[1]), dtype=np.float32) + padded_emb[:emb.shape[0], :] = emb + emb = padded_emb + elif emb.shape[0] > 512: + emb = emb[:512, :] + + # Process mask + if len(text_attention_mask_bytes) > 0 and len( + text_attention_mask_shape) > 0: + msk = np.frombuffer(text_attention_mask_bytes, + dtype=np.uint8).astype(np.bool_) + msk = msk.reshape(1, -1) + # Make array writable + msk = np.copy(msk) + if msk.shape[1] < 512: + padded_msk = np.zeros((1, 512), dtype=np.bool_) + padded_msk[:, :msk.shape[1]] = msk + msk = padded_msk + elif msk.shape[1] > 512: + msk = msk[:, :512] + else: + msk = np.ones((1, 512), dtype=np.bool_) + + # Collect metadata + info = { + "width": row["width"], + "height": row["height"], + "num_frames": row["num_frames"], + "duration_sec": row["duration_sec"], + "fps": row["fps"], + "file_name": row["file_name"], + "caption": row["caption"], + } + + return { + "latents": torch.from_numpy(lat), + "embeddings": torch.from_numpy(emb), + "masks": torch.from_numpy(msk), + "info": info + } + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description='Benchmark Parquet dataset loading speed') + parser.add_argument('--path', + type=str, + default="your/dataset/path", + help='Path to Parquet dataset') + parser.add_argument('--batch_size', + type=int, + default=4, + help='Batch size for DataLoader') + parser.add_argument('--num_batches', + type=int, + default=100, + help='Number of batches to benchmark') + parser.add_argument('--vae_debug', action="store_true") + args = parser.parse_args() + + # Initialize distributed training + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + rank = int(os.environ.get("RANK", 0)) + + # Initialize CUDA device first + if torch.cuda.is_available(): + torch.cuda.set_device(local_rank) + device = torch.device(f"cuda:{local_rank}") + else: + device = torch.device("cpu") + + # Initialize distributed training + if world_size > 1: + dist.init_process_group(backend="nccl", + init_method="env://", + world_size=world_size, + rank=rank) + print( + f"Initialized process: rank={rank}, local_rank={local_rank}, world_size={world_size}, device={device}" + ) + + # Create dataset + dataset = ParquetVideoTextDataset( + args.path, + batch_size=args.batch_size, + ) + + # Create DataLoader with proper settings + dataloader = StatefulDataLoader( + dataset, + batch_size=args.batch_size, + num_workers=1, # Reduce number of workers to avoid memory issues + prefetch_factor=2, + shuffle=False, + pin_memory=True, + drop_last=True) + + # Example of how to load dataloader state + # if os.path.exists("/workspace/FastVideo/dataloader_state.pt"): + # dataloader_state = torch.load("/workspace/FastVideo/dataloader_state.pt") + # dataloader.load_state_dict(dataloader_state[rank]) + + # Warm-up with synchronization + if rank == 0: + print("Warming up...") + for i, (latents, embeddings, masks, infos) in enumerate(dataloader): + # Example of how to save dataloader state + # if i == 30: + # dist.barrier() + # local_data = {rank: dataloader.state_dict()} + # 
gathered_data = [None] * world_size + # dist.all_gather_object(gathered_data, local_data) + # if rank == 0: + # global_state_dict = {} + # for d in gathered_data: + # global_state_dict.update(d) + # torch.save(global_state_dict, "dataloader_state.pt") + assert torch.sum(masks[0]).item() == torch.count_nonzero( + embeddings[0]).item() // 4096 + if args.vae_debug: + from diffusers.utils import export_to_video + from diffusers.video_processor import VideoProcessor + + from fastvideo.v1.configs.models.vaes import WanVAEConfig + from fastvideo.v1.fastvideo_args import FastVideoArgs + from fastvideo.v1.models.loader.component_loader import VAELoader + VAE_PATH = "/workspace/data/Wan-AI/Wan2.1-T2V-1.3B-Diffusers/vae" + fastvideo_args = FastVideoArgs( + model_path=VAE_PATH, + vae_config=WanVAEConfig(load_encoder=False), + vae_precision="fp32") + fastvideo_args.device = device + vae_loader = VAELoader() + vae = vae_loader.load(model_path=VAE_PATH, + architecture="", + fastvideo_args=fastvideo_args) + + videoprocessor = VideoProcessor(vae_scale_factor=8) + + with torch.inference_mode(): + video = vae.decode(latents[0].unsqueeze(0).to(device)) + video = videoprocessor.postprocess_video(video) + video_path = os.path.join("/workspace/FastVideo/debug_videos", + infos["caption"][0][:50] + ".mp4") + export_to_video(video[0], video_path, fps=16) + + # Move data to device + # latents = latents.to(device) + # embeddings = embeddings.to(device) + + if world_size > 1: + dist.barrier() + + # Benchmark + if rank == 0: + print(f"Benchmarking with batch_size={args.batch_size}") + start_time = time.time() + total_samples = 0 + for i, (latents, embeddings, masks, + infos) in enumerate(tqdm.tqdm(dataloader, total=args.num_batches)): + if i >= args.num_batches: + break + + # Move data to device + latents = latents.to(device) + embeddings = embeddings.to(device) + + # Calculate actual batch size + batch_size = latents.size(0) + total_samples += batch_size + + # Print progress only from rank 0 + if rank == 0 and (i + 1) % 10 == 0: + elapsed = time.time() - start_time + samples_per_sec = total_samples / elapsed + print( + f"Batch {i+1}/{args.num_batches}, Speed: {samples_per_sec:.2f} samples/sec" + ) + + # Final statistics + if world_size > 1: + dist.barrier() + + if rank == 0: + elapsed = time.time() - start_time + samples_per_sec = total_samples / elapsed + + print("\nBenchmark Results:") + print(f"Total time: {elapsed:.2f} seconds") + print(f"Total samples: {total_samples}") + print(f"Average speed: {samples_per_sec:.2f} samples/sec") + print(f"Time per batch: {elapsed/args.num_batches*1000:.2f} ms") + + if world_size > 1: + dist.destroy_process_group() diff --git a/fastvideo/v1/fastvideo_args.py b/fastvideo/v1/fastvideo_args.py index d83ce5642..f010f3285 100644 --- a/fastvideo/v1/fastvideo_args.py +++ b/fastvideo/v1/fastvideo_args.py @@ -582,6 +582,8 @@ class TrainingArgs(FastVideoArgs): pred_decay_type: str = "" hunyuan_teacher_disable_cfg: bool = False + use_lora: bool = False + # master_weight_type master_weight_type: str = "" diff --git a/fastvideo/v1/models/loader/component_loader.py b/fastvideo/v1/models/loader/component_loader.py index 110b55875..c8908420b 100644 --- a/fastvideo/v1/models/loader/component_loader.py +++ b/fastvideo/v1/models/loader/component_loader.py @@ -367,6 +367,7 @@ class TransformerLoader(ComponentLoader): def load(self, model_path: str, architecture: str, fastvideo_args: FastVideoArgs): """Load the transformer based on the model path, architecture, and inference args.""" + 
print(f"Loading transformer from {model_path}") config = get_diffusers_config(model=model_path) hf_config = deepcopy(config) cls_name = config.pop("_class_name") diff --git a/fastvideo/v1/pipelines/composed_pipeline_base.py b/fastvideo/v1/pipelines/composed_pipeline_base.py index 3d3698e87..29190352e 100644 --- a/fastvideo/v1/pipelines/composed_pipeline_base.py +++ b/fastvideo/v1/pipelines/composed_pipeline_base.py @@ -55,7 +55,7 @@ def __init__(self, use. The pipeline should be stateless and not hold any batch state. """ - if fastvideo_args.training_mode: + if fastvideo_args.training_mode or fastvideo_args.distill_mode: assert isinstance(fastvideo_args, TrainingArgs) self.training_args = fastvideo_args assert self.training_args is not None @@ -94,11 +94,12 @@ def __init__(self, self.initialize_validation_pipeline(self.training_args) self.initialize_training_pipeline(self.training_args) + # TODO(jinzhe): discuss this if fastvideo_args.distill_mode: - self.initialize_distillation_pipeline(fastvideo_args) - - if fastvideo_args.log_validation: - self.initialize_validation_pipeline(fastvideo_args) + assert self.training_args is not None + if self.training_args.log_validation: + self.initialize_validation_pipeline(self.training_args) + self.initialize_distillation_pipeline(self.training_args) self.initialize_pipeline(fastvideo_args) diff --git a/fastvideo/v1/pipelines/training_pipeline.py b/fastvideo/v1/pipelines/training_pipeline.py deleted file mode 100644 index a8bffa368..000000000 --- a/fastvideo/v1/pipelines/training_pipeline.py +++ /dev/null @@ -1,820 +0,0 @@ -import gc -import os -import sys -import time -import traceback -from abc import ABC, abstractmethod -from collections import deque -from copy import deepcopy - -import imageio -import numpy as np -import torch -import torchvision -# import torch.distributed as dist -import wandb -from diffusers.optimization import get_scheduler -from einops import rearrange -from torchdata.stateful_dataloader import StatefulDataLoader -from tqdm.auto import tqdm - -from fastvideo.models.mochi_hf.mochi_latents_utils import normalize_dit_input -from fastvideo.utils.checkpoint import save_checkpoint_v1 -from fastvideo.v1.configs.sample import SamplingParam -from fastvideo.v1.dataset.parquet_datasets import ParquetVideoTextDataset -from fastvideo.v1.distributed import cleanup_dist_env_and_memory, get_sp_group -from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs -from fastvideo.v1.forward_context import set_forward_context -from fastvideo.v1.logger import init_logger -from fastvideo.v1.pipelines import ComposedPipelineBase -from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch -from fastvideo.v1.pipelines.training_utils import ( - _clip_grad_norm_while_handling_failing_dtensor_cases, - compute_density_for_timestep_sampling, get_sigmas) -from fastvideo.v1.pipelines.wan.wan_pipeline import WanValidationPipeline - -logger = init_logger(__name__) - -# Manual gradient checking flag - set to True to enable gradient verification -ENABLE_GRADIENT_CHECK = False -GRADIENT_CHECK_DTYPE = torch.bfloat16 - - -class TrainingPipeline(ComposedPipelineBase, ABC): - """ - A pipeline for training a model. All training pipelines should inherit from this class. - All reusable components and code should be implemented in this class. 
- """ - _required_config_modules = ["scheduler", "transformer"] - - def initialize_training_pipeline(self, fastvideo_args: TrainingArgs): - logger.info("Initializing training pipeline...") - self.device = fastvideo_args.device - self.sp_group = get_sp_group() - self.world_size = self.sp_group.world_size - self.rank = self.sp_group.rank - self.local_rank = self.sp_group.local_rank - self.transformer = self.get_module("transformer") - assert self.transformer is not None - - self.transformer.requires_grad_(True) - self.transformer.train() - - args = fastvideo_args - - noise_scheduler = self.modules["scheduler"] - params_to_optimize = self.transformer.parameters() - params_to_optimize = list( - filter(lambda p: p.requires_grad, params_to_optimize)) - - optimizer = torch.optim.AdamW( - params_to_optimize, - lr=args.learning_rate, - betas=(0.9, 0.999), - weight_decay=args.weight_decay, - eps=1e-8, - ) - - init_steps = 0 - logger.info("optimizer: %s", optimizer) - - # todo add lr scheduler - lr_scheduler = get_scheduler( - args.lr_scheduler, - optimizer=optimizer, - num_warmup_steps=args.lr_warmup_steps * self.world_size, - num_training_steps=args.max_train_steps * self.world_size, - num_cycles=args.lr_num_cycles, - power=args.lr_power, - last_epoch=init_steps - 1, - ) - - train_dataset = ParquetVideoTextDataset( - args.data_path, - batch_size=args.train_batch_size, - rank=self.rank, - world_size=self.world_size, - cfg_rate=args.cfg, - num_latent_t=args.num_latent_t) - - train_dataloader = StatefulDataLoader( - train_dataset, - batch_size=args.train_batch_size, - num_workers=args. - dataloader_num_workers, # Reduce number of workers to avoid memory issues - prefetch_factor=2, - shuffle=False, - pin_memory=True, - drop_last=True) - - self.lr_scheduler = lr_scheduler - self.train_dataset = train_dataset - self.train_dataloader = train_dataloader - self.init_steps = init_steps - self.optimizer = optimizer - self.noise_scheduler = noise_scheduler - # self.noise_random_generator = noise_random_generator - - # num_update_steps_per_epoch = math.ceil( - # len(train_dataloader) / args.gradient_accumulation_steps * - # args.sp_size / args.train_sp_batch_size) - # args.num_train_epochs = math.ceil(args.max_train_steps / - # num_update_steps_per_epoch) - - if self.rank <= 0: - project = args.tracker_project_name or "fastvideo" - wandb.init(project=project, config=args) - - @abstractmethod - def initialize_validation_pipeline(self, fastvideo_args: FastVideoArgs): - raise NotImplementedError( - "Training pipelines must implement this method") - - @abstractmethod - def train_one_step(self, transformer, model_type, optimizer, lr_scheduler, - loader, noise_scheduler, noise_random_generator, - gradient_accumulation_steps, sp_size, - precondition_outputs, max_grad_norm, weighting_scheme, - logit_mean, logit_std, mode_scale): - """ - Train one step of the model. 
- """ - raise NotImplementedError( - "Training pipeline must implement this method") - - def log_validation(self, transformer, fastvideo_args, global_step): - fastvideo_args.inference_mode = True - fastvideo_args.use_cpu_offload = False - if not fastvideo_args.log_validation: - return - if self.validation_pipeline is None: - raise ValueError("Validation pipeline is not set") - - # Create sampling parameters if not provided - sampling_param = SamplingParam.from_pretrained( - fastvideo_args.model_path) - - # Prepare validation prompts - print('fastvideo_args.validation_prompt_dir', - fastvideo_args.validation_prompt_dir) - validation_dataset = ParquetVideoTextDataset( - fastvideo_args.validation_prompt_dir, - batch_size=1, - rank=0, - world_size=1, - cfg_rate=0, - num_latent_t=args.num_latent_t) - - validation_dataloader = StatefulDataLoader( - validation_dataset, - batch_size=1, - num_workers=1, # Reduce number of workers to avoid memory issues - prefetch_factor=2, - shuffle=False, - pin_memory=True, - drop_last=False) - - transformer.requires_grad_(False) - for p in transformer.parameters(): - p.requires_grad = False - transformer.eval() - - # Add the transformer to the validation pipeline - self.validation_pipeline.add_module("transformer", transformer) - self.validation_pipeline.latent_preparation_stage.transformer = transformer - self.validation_pipeline.denoising_stage.transformer = transformer - - # Process each validation prompt - videos = [] - captions = [] - for _, embeddings, masks, infos in validation_dataloader: - logger.info(f"infos: {infos}") - caption = infos['caption'] - captions.append(caption) - prompt_embeds = embeddings.to(fastvideo_args.device) - prompt_attention_mask = masks.to(fastvideo_args.device) - - # Calculate sizes - latents_size = [(sampling_param.num_frames - 1) // 4 + 1, - sampling_param.height // 8, - sampling_param.width // 8] - n_tokens = latents_size[0] * latents_size[1] * latents_size[2] - - # Prepare batch for validation - # print('shape of embeddings', prompt_embeds.shape) - batch = ForwardBatch( - # **shallow_asdict(sampling_param), - data_type="video", - latents=None, - # seed=sampling_param.seed, - # data_type="video", - prompt_embeds=[prompt_embeds], - prompt_attention_mask=[prompt_attention_mask], - # make sure we use the same height, width, and num_frames as the training pipeline - height=args.num_height, - width=args.num_width, - num_frames=args.num_frames, - # num_inference_steps=fastvideo_args.validation_sampling_steps, - num_inference_steps=10, - # guidance_scale=fastvideo_args.validation_guidance_scale, - guidance_scale=1, - n_tokens=n_tokens, - do_classifier_free_guidance=False, - eta=0.0, - extra={}, - ) - - # Run validation inference - with torch.inference_mode(): - output_batch = self.validation_pipeline.forward( - batch, fastvideo_args) - samples = output_batch.output - - # Process outputs - video = rearrange(samples, "b c t h w -> t b c h w") - frames = [] - for x in video: - x = torchvision.utils.make_grid(x, nrow=6) - x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) - frames.append((x * 255).numpy().astype(np.uint8)) - videos.append(frames) - - # Log validation results - rank = int(os.environ.get("RANK", 0)) - - if rank == 0: - video_filenames = [] - video_captions = [] - for i, video in enumerate(videos): - caption = captions[i] - filename = os.path.join( - fastvideo_args.output_dir, - f"validation_step_{global_step}_video_{i}.mp4") - imageio.mimsave(filename, video, fps=sampling_param.fps) - video_filenames.append(filename) - 
video_captions.append( - caption) # Store the caption for each video - - logs = { - "validation_videos": [ - wandb.Video(filename, - caption=caption) for filename, caption in zip( - video_filenames, video_captions) - ] - } - wandb.log(logs, step=global_step) - - # Re-enable gradients for training - transformer.requires_grad_(True) - transformer.train() - - gc.collect() - torch.cuda.empty_cache() - - def gradient_check_parameters(self, - transformer, - latents, - encoder_hidden_states, - encoder_attention_mask, - timesteps, - target, - eps=5e-2, - max_params_to_check=2000): - """ - Verify gradients using finite differences for FSDP models with GRADIENT_CHECK_DTYPE. - Uses standard tolerances for GRADIENT_CHECK_DTYPE precision. - """ - # Move all inputs to CPU and clear GPU memory - inputs_cpu = { - 'latents': latents.cpu(), - 'encoder_hidden_states': encoder_hidden_states.cpu(), - 'encoder_attention_mask': encoder_attention_mask.cpu(), - 'timesteps': timesteps.cpu(), - 'target': target.cpu() - } - del latents, encoder_hidden_states, encoder_attention_mask, timesteps, target - torch.cuda.empty_cache() - - def compute_loss(): - # Move inputs to GPU, compute loss, cleanup - inputs_gpu = { - k: - v.to(self.fastvideo_args.device, - dtype=GRADIENT_CHECK_DTYPE - if k != 'encoder_attention_mask' else None) - for k, v in inputs_cpu.items() - } - - # Use GRADIENT_CHECK_DTYPE for more accurate gradient checking - # with torch.autocast(enabled=False, device_type="cuda"): - with torch.autocast("cuda", dtype=GRADIENT_CHECK_DTYPE): - with set_forward_context( - current_timestep=inputs_gpu['timesteps'], - attn_metadata=None): - model_pred = transformer( - hidden_states=inputs_gpu['latents'], - encoder_hidden_states=inputs_gpu[ - 'encoder_hidden_states'], - timestep=inputs_gpu['timesteps'], - encoder_attention_mask=inputs_gpu[ - 'encoder_attention_mask'], - return_dict=False)[0] - - if self.fastvideo_args.precondition_outputs: - sigmas = get_sigmas(self.noise_scheduler, - inputs_gpu['latents'].device, - inputs_gpu['timesteps'], - n_dim=inputs_gpu['latents'].ndim, - dtype=inputs_gpu['latents'].dtype) - model_pred = inputs_gpu['latents'] - model_pred * sigmas - target_adjusted = inputs_gpu['target'] - else: - target_adjusted = inputs_gpu['target'] - - loss = torch.mean((model_pred - target_adjusted)**2) - - # Cleanup and return - loss_cpu = loss.cpu() - del inputs_gpu, model_pred, target_adjusted - if 'sigmas' in locals(): del sigmas - torch.cuda.empty_cache() - return loss_cpu.to(self.fastvideo_args.device) - - try: - # Get analytical gradients - transformer.zero_grad() - analytical_loss = compute_loss() - analytical_loss.backward() - - # Check gradients for selected parameters - absolute_errors = [] - param_count = 0 - - for name, param in transformer.named_parameters(): - if not (param.requires_grad and param.grad is not None - and param_count < max_params_to_check - and param.grad.abs().max() > 5e-4): - continue - - # Get local parameter and gradient tensors - local_param = param._local_tensor if hasattr( - param, '_local_tensor') else param - local_grad = param.grad._local_tensor if hasattr( - param.grad, '_local_tensor') else param.grad - - # Find first significant gradient element - flat_param = local_param.data.view(-1) - flat_grad = local_grad.view(-1) - check_idx = next((i for i in range(min(10, flat_param.numel())) - if abs(flat_grad[i]) > 1e-4), 0) - - # Store original values - orig_value = flat_param[check_idx].item() - analytical_grad = flat_grad[check_idx].item() - - # Compute numerical 
gradient - for delta in [eps, -eps]: - with torch.no_grad(): - flat_param[check_idx] = orig_value + delta - loss = compute_loss() - if delta > 0: loss_plus = loss.item() - else: loss_minus = loss.item() - - # Restore parameter and compute error - with torch.no_grad(): - flat_param[check_idx] = orig_value - - numerical_grad = (loss_plus - loss_minus) / (2 * eps) - abs_error = abs(analytical_grad - numerical_grad) - rel_error = abs_error / max(abs(analytical_grad), - abs(numerical_grad), 1e-3) - absolute_errors.append(abs_error) - - logger.info( - f"{name}[{check_idx}]: analytical={analytical_grad:.6f}, " - f"numerical={numerical_grad:.6f}, abs_error={abs_error:.2e}, rel_error={rel_error:.2%}" - ) - - # param_count += 1 - - # Compute and log statistics - if absolute_errors: - min_err, max_err, mean_err = min(absolute_errors), max( - absolute_errors - ), sum(absolute_errors) / len(absolute_errors) - logger.info( - f"Gradient check stats: min={min_err:.2e}, max={max_err:.2e}, mean={mean_err:.2e}" - ) - - if self.rank <= 0: - wandb.log({ - "grad_check/min_abs_error": - min_err, - "grad_check/max_abs_error": - max_err, - "grad_check/mean_abs_error": - mean_err, - "grad_check/analytical_loss": - analytical_loss.item(), - }) - return max_err - - return float('inf') - - except Exception as e: - logger.error(f"Gradient check failed: {e}") - traceback.print_exc() - return float('inf') - - def setup_gradient_check(self, args, loader_iter, noise_scheduler, - noise_random_generator): - """ - Setup and perform gradient check on a fresh batch. - Args: - args: Training arguments - loader_iter: Data loader iterator - noise_scheduler: Noise scheduler for diffusion - noise_random_generator: Random number generator for noise - Returns: - float or None: Maximum gradient error or None if check is disabled/fails - """ - if not ENABLE_GRADIENT_CHECK: - return None - - try: - # Get a fresh batch and process it exactly like train_one_step - check_latents, check_encoder_hidden_states, check_encoder_attention_mask, check_infos = next( - loader_iter) - - # Process exactly like in train_one_step but use GRADIENT_CHECK_DTYPE - check_latents = check_latents.to(self.fastvideo_args.device, - dtype=GRADIENT_CHECK_DTYPE) - check_encoder_hidden_states = check_encoder_hidden_states.to( - self.fastvideo_args.device, dtype=GRADIENT_CHECK_DTYPE) - check_latents = normalize_dit_input("wan", check_latents) - batch_size = check_latents.shape[0] - check_noise = torch.randn_like(check_latents) - - check_u = compute_density_for_timestep_sampling( - weighting_scheme=args.weighting_scheme, - batch_size=batch_size, - generator=noise_random_generator, - logit_mean=args.logit_mean, - logit_std=args.logit_std, - mode_scale=args.mode_scale, - ) - check_indices = (check_u * - noise_scheduler.config.num_train_timesteps).long() - check_timesteps = noise_scheduler.timesteps[check_indices].to( - device=check_latents.device) - - check_sigmas = get_sigmas( - noise_scheduler, - check_latents.device, - check_timesteps, - n_dim=check_latents.ndim, - dtype=check_latents.dtype, - ) - check_noisy_model_input = ( - 1.0 - check_sigmas) * check_latents + check_sigmas * check_noise - - # Compute target exactly like train_one_step - if args.precondition_outputs: - check_target = check_latents - else: - check_target = check_noise - check_latents - - # Perform gradient check with the exact same inputs as training - max_grad_error = self.gradient_check_parameters( - transformer=self.transformer, - latents= - check_noisy_model_input, # Use noisy input like in 
training - encoder_hidden_states=check_encoder_hidden_states, - encoder_attention_mask=check_encoder_attention_mask, - timesteps=check_timesteps, - target=check_target, - max_params_to_check=100 # Check more parameters - ) - - if max_grad_error > 5e-2: - logger.error( - f"❌ Large gradient error detected: {max_grad_error:.2e}") - else: - logger.info( - f"✅ Gradient check passed: max error {max_grad_error:.2e}") - - return max_grad_error - - except Exception as e: - logger.error(f"Gradient check setup failed: {e}") - traceback.print_exc() - return None - - -class WanTrainingPipeline(TrainingPipeline): - """ - A training pipeline for Wan. - """ - _required_config_modules = ["scheduler", "transformer"] - - def create_pipeline_stages(self, fastvideo_args: FastVideoArgs): - pass - - def create_training_stages(self, fastvideo_args: FastVideoArgs): - pass - - def initialize_validation_pipeline(self, fastvideo_args: FastVideoArgs): - logger.info("Initializing validation pipeline...") - args_copy = deepcopy(fastvideo_args) - - args_copy.mode = "inference" - args_copy.vae_config.load_encoder = False - validation_pipeline = WanValidationPipeline.from_pretrained( - args.model_path, args=args_copy) - - self.validation_pipeline = validation_pipeline - - def train_one_step( - self, - transformer, - model_type, - optimizer, - lr_scheduler, - loader_iter, - noise_scheduler, - noise_random_generator, - gradient_accumulation_steps, - sp_size, - precondition_outputs, - max_grad_norm, - weighting_scheme, - logit_mean, - logit_std, - mode_scale, - ): - self.modules["transformer"].requires_grad_(True) - self.modules["transformer"].train() - - total_loss = 0.0 - optimizer.zero_grad() - for _ in range(gradient_accumulation_steps): - ( - latents, - encoder_hidden_states, - encoder_attention_mask, - infos, - ) = next(loader_iter) - latents = latents.to(self.fastvideo_args.device, - dtype=torch.bfloat16) - encoder_hidden_states = encoder_hidden_states.to( - self.fastvideo_args.device, dtype=torch.bfloat16) - latents = normalize_dit_input(model_type, latents) - batch_size = latents.shape[0] - noise = torch.randn_like(latents) - u = compute_density_for_timestep_sampling( - weighting_scheme=weighting_scheme, - batch_size=batch_size, - generator=noise_random_generator, - logit_mean=logit_mean, - logit_std=logit_std, - mode_scale=mode_scale, - ) - indices = (u * noise_scheduler.config.num_train_timesteps).long() - timesteps = noise_scheduler.timesteps[indices].to( - device=latents.device) - if sp_size > 1: - # Make sure that the timesteps are the same across all sp processes. 
- sp_group = get_sp_group() - sp_group.broadcast(timesteps, src=0) - sigmas = get_sigmas( - noise_scheduler, - latents.device, - timesteps, - n_dim=latents.ndim, - dtype=latents.dtype, - ) - noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise - print('device before forward ', - next(transformer.named_parameters())[1].device) - with torch.autocast("cuda", dtype=torch.bfloat16): - input_kwargs = { - "hidden_states": noisy_model_input, - "encoder_hidden_states": encoder_hidden_states, - "timestep": timesteps, - "encoder_attention_mask": encoder_attention_mask, # B, L - "return_dict": False, - } - if 'hunyuan' in model_type: - input_kwargs["guidance"] = torch.tensor( - [1000.0], - device=noisy_model_input.device, - dtype=torch.bfloat16) - with set_forward_context(current_timestep=timesteps, - attn_metadata=None): - model_pred = transformer(**input_kwargs)[0] - - if precondition_outputs: - model_pred = noisy_model_input - model_pred * sigmas - if precondition_outputs: - target = latents - else: - target = noise - latents - - loss = (torch.mean((model_pred.float() - target.float())**2) / - gradient_accumulation_steps) - print('device before backwardin context', - next(transformer.named_parameters())[1].device) - - print('device before backward out context', - next(transformer.named_parameters())[1].device) - loss.backward() - print('device after backward out context', - next(transformer.named_parameters())[1].device) - - avg_loss = loss.detach().clone() - sp_group = get_sp_group() - sp_group.all_reduce(avg_loss, op=torch.distributed.ReduceOp.AVG) - total_loss += avg_loss.item() - - model_parts = [self.transformer] - grad_norm = _clip_grad_norm_while_handling_failing_dtensor_cases( - [p for m in model_parts for p in m.parameters()], - max_grad_norm, - foreach=None, - ) - - optimizer.step() - print('device after optimizer step', - next(transformer.named_parameters())[1].device) - lr_scheduler.step() - print('device after scheduler step', - next(transformer.named_parameters())[1].device) - return total_loss, grad_norm.item() - - def forward( - self, - batch: ForwardBatch, - fastvideo_args: FastVideoArgs, - ): - args = fastvideo_args - self.fastvideo_args = args - train_dataloader = self.train_dataloader - init_steps = self.init_steps - lr_scheduler = self.lr_scheduler - optimizer = self.optimizer - noise_scheduler = self.noise_scheduler - noise_random_generator = None - - from diffusers import FlowMatchEulerDiscreteScheduler - noise_scheduler = FlowMatchEulerDiscreteScheduler() - - # Train! - total_batch_size = (self.world_size * args.gradient_accumulation_steps / - args.sp_size * args.train_sp_batch_size) - logger.info("***** Running training *****") - # logger.info(f" Num examples = {len(train_dataset)}") - # logger.info(f" Dataloader size = {len(train_dataloader)}") - # logger.info(f" Num Epochs = {args.num_train_epochs}") - logger.info(f" Resume training from step {init_steps}") - logger.info( - f" Instantaneous batch size per device = {args.train_batch_size}") - logger.info( - f" Total train batch size (w. 
data & sequence parallel, accumulation) = {total_batch_size}" - ) - logger.info( - f" Gradient Accumulation steps = {args.gradient_accumulation_steps}" - ) - logger.info(f" Total optimization steps = {args.max_train_steps}") - logger.info( - f" Total training parameters per FSDP shard = {sum(p.numel() for p in self.transformer.parameters() if p.requires_grad) / 1e9} B" - ) - # print dtype - logger.info( - f" Master weight dtype: {self.transformer.parameters().__next__().dtype}" - ) - - # Potentially load in the weights and states from a previous save - if args.resume_from_checkpoint: - assert NotImplementedError( - "resume_from_checkpoint is not supported now.") - # TODO - - progress_bar = tqdm( - range(0, args.max_train_steps), - initial=init_steps, - desc="Steps", - # Only show the progress bar once on each machine. - disable=self.local_rank > 0, - ) - - loader_iter = iter(train_dataloader) - - step_times = deque(maxlen=100) - - # todo future - for i in range(init_steps): - next(loader_iter) - # get gpu memory usage - gpu_memory_usage = torch.cuda.memory_allocated() / 1024**2 - logger.info( - f"GPU memory usage before train_one_step: {gpu_memory_usage} MB") - - for step in range(init_steps + 1, args.max_train_steps + 1): - start_time = time.perf_counter() - - loss, grad_norm = self.train_one_step( - self.transformer, - # args.model_type, - "wan", - optimizer, - lr_scheduler, - loader_iter, - noise_scheduler, - noise_random_generator, - args.gradient_accumulation_steps, - args.sp_size, - args.precondition_outputs, - args.max_grad_norm, - args.weighting_scheme, - args.logit_mean, - args.logit_std, - args.mode_scale, - ) - gpu_memory_usage = torch.cuda.memory_allocated() / 1024**2 - logger.info( - f"GPU memory usage after train_one_step: {gpu_memory_usage} MB") - - step_time = time.perf_counter() - start_time - step_times.append(step_time) - avg_step_time = sum(step_times) / len(step_times) - - # Manual gradient checking - only at first step - if step == 1 and ENABLE_GRADIENT_CHECK: - logger.info(f"Performing gradient check at step {step}") - self.setup_gradient_check(args, loader_iter, noise_scheduler, - noise_random_generator) - - progress_bar.set_postfix({ - "loss": f"{loss:.4f}", - "step_time": f"{step_time:.2f}s", - "grad_norm": grad_norm, - }) - progress_bar.update(1) - if self.rank <= 0: - wandb.log( - { - "train_loss": loss, - "learning_rate": lr_scheduler.get_last_lr()[0], - "step_time": step_time, - "avg_step_time": avg_step_time, - "grad_norm": grad_norm, - }, - step=step, - ) - if step % args.checkpointing_steps == 0: - if args.use_lora: - raise NotImplementedError("LoRA is not supported now") - # Save LoRA weights - # save_lora_checkpoint(transformer, optimizer, rank, - # args.output_dir, step, pipe) - else: - # Your existing checkpoint saving code - save_checkpoint_v1(self.transformer, self.rank, - args.output_dir, step) - self.transformer.train() - self.sp_group.barrier() - if args.log_validation and step % args.validation_steps == 0: - self.log_validation(self.transformer, args, step) - - if args.use_lora: - raise NotImplementedError("LoRA is not supported now") - # save_lora_checkpoint(transformer, optimizer, rank, args.output_dir, args.max_train_steps, pipe) - else: - save_checkpoint_v1(self.transformer, self.rank, args.output_dir, - args.max_train_steps) - - if get_sp_group(): - cleanup_dist_env_and_memory() - - -def main(args): - logger.info("Starting training pipeline...") - - pipeline = WanTrainingPipeline.from_pretrained( - args.pretrained_model_name_or_path, 
args=args) - args = pipeline.fastvideo_args - pipeline.forward(None, args) - logger.info("Training pipeline done") - - -if __name__ == "__main__": - argv = sys.argv - from fastvideo.v1.fastvideo_args import TrainingArgs - from fastvideo.v1.utils import FlexibleArgumentParser - parser = FlexibleArgumentParser() - parser = TrainingArgs.add_cli_args(parser) - parser = FastVideoArgs.add_cli_args(parser) - args = parser.parse_args() - args.use_cpu_offload = False - print(args) - main(args) diff --git a/fastvideo/v1/pipelines/distillation_pipeline.py b/fastvideo/v1/training/distillation_pipeline.py similarity index 75% rename from fastvideo/v1/pipelines/distillation_pipeline.py rename to fastvideo/v1/training/distillation_pipeline.py index d3e5537a1..ee5e6791d 100644 --- a/fastvideo/v1/pipelines/distillation_pipeline.py +++ b/fastvideo/v1/training/distillation_pipeline.py @@ -20,7 +20,6 @@ from fastvideo.distill.solver import EulerSolver, extract_into_tensor from fastvideo.models.mochi_hf.mochi_latents_utils import normalize_dit_input -from fastvideo.utils.checkpoint import save_checkpoint_v1 from fastvideo.v1.configs.sample import SamplingParam from fastvideo.v1.dataset.parquet_datasets import ParquetVideoTextDataset from fastvideo.v1.distributed import cleanup_dist_env_and_memory, get_sp_group @@ -29,8 +28,10 @@ from fastvideo.v1.logger import init_logger from fastvideo.v1.pipelines import ComposedPipelineBase from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch -from fastvideo.v1.pipelines.training_utils import ( - _clip_grad_norm_while_handling_failing_dtensor_cases) +from fastvideo.v1.training.training_utils import ( + clip_grad_norm_while_handling_failing_dtensor_cases, + compute_density_for_timestep_sampling, get_sigmas, normalize_dit_input, + save_checkpoint) from fastvideo.v1.pipelines.wan.wan_pipeline import WanValidationPipeline logger = init_logger(__name__) @@ -111,28 +112,37 @@ def initialize_distillation_pipeline(self, fastvideo_args: TrainingArgs): self.transformer.requires_grad_(True) self.transformer.train() - # Initialize teacher model - self.teacher_transformer = deepcopy(self.transformer) + # Initialize teacher model without deepcopy to avoid FSDP issues + logger.info("Creating teacher model...") + from fastvideo.v1.models.loader.component_loader import TransformerLoader + teacher_loader = TransformerLoader() + transformer_path = os.path.join(self.model_path, "transformer") + self.teacher_transformer = teacher_loader.load(transformer_path, "", fastvideo_args) self.teacher_transformer.requires_grad_(False) + self.teacher_transformer.eval() + logger.info("Teacher model initialized") # Initialize EMA model if needed if fastvideo_args.use_ema: - self.ema_transformer = deepcopy(self.transformer) + logger.info("Creating EMA model...") + ema_loader = TransformerLoader() + self.ema_transformer = ema_loader.load(transformer_path, "", fastvideo_args) self.ema_transformer.requires_grad_(False) + self.ema_transformer.eval() + logger.info("EMA model initialized") else: self.ema_transformer = None - args = fastvideo_args noise_scheduler = self.get_module("scheduler") assert noise_scheduler is not None # Initialize solver for distillation - if args.scheduler_type == "pcm_linear_quadratic": + if fastvideo_args.scheduler_type == "pcm_linear_quadratic": linear_steps = int(noise_scheduler.config.num_train_timesteps * - args.linear_range) + fastvideo_args.linear_range) sigmas = linear_quadratic_schedule( noise_scheduler.config.num_train_timesteps, - 
args.linear_quadratic_threshold, + fastvideo_args.linear_quadratic_threshold, linear_steps, ) sigmas = torch.tensor(sigmas).to(dtype=torch.float32) @@ -142,7 +152,7 @@ def initialize_distillation_pipeline(self, fastvideo_args: TrainingArgs): self.solver = EulerSolver( sigmas.numpy()[::-1], noise_scheduler.config.num_train_timesteps, - euler_timesteps=args.num_euler_timesteps, + euler_timesteps=fastvideo_args.num_euler_timesteps, ) self.solver.to(self.device) @@ -153,9 +163,9 @@ def initialize_distillation_pipeline(self, fastvideo_args: TrainingArgs): optimizer = torch.optim.AdamW( params_to_optimize, - lr=args.learning_rate, + lr=fastvideo_args.learning_rate, betas=(0.9, 0.999), - weight_decay=args.weight_decay, + weight_decay=fastvideo_args.weight_decay, eps=1e-8, ) @@ -164,28 +174,28 @@ def initialize_distillation_pipeline(self, fastvideo_args: TrainingArgs): # Setup lr scheduler lr_scheduler = get_scheduler( - args.lr_scheduler, + fastvideo_args.lr_scheduler, optimizer=optimizer, - num_warmup_steps=args.lr_warmup_steps * self.world_size, - num_training_steps=args.max_train_steps * self.world_size, - num_cycles=args.lr_num_cycles, - power=args.lr_power, + num_warmup_steps=fastvideo_args.lr_warmup_steps * self.world_size, + num_training_steps=fastvideo_args.max_train_steps * self.world_size, + num_cycles=fastvideo_args.lr_num_cycles, + power=fastvideo_args.lr_power, last_epoch=init_steps - 1, ) # Setup dataset train_dataset = ParquetVideoTextDataset( - args.data_path, - batch_size=args.train_batch_size, + fastvideo_args.data_path, + batch_size=fastvideo_args.train_batch_size, rank=self.rank, world_size=self.world_size, - cfg_rate=args.cfg, - num_latent_t=args.num_latent_t) + cfg_rate=fastvideo_args.cfg, + num_latent_t=fastvideo_args.num_latent_t) train_dataloader = StatefulDataLoader( train_dataset, - batch_size=args.train_batch_size, - num_workers=args.dataloader_num_workers, + batch_size=fastvideo_args.train_batch_size, + num_workers=fastvideo_args.dataloader_num_workers, prefetch_factor=2, shuffle=False, pin_memory=True, @@ -199,12 +209,14 @@ def initialize_distillation_pipeline(self, fastvideo_args: TrainingArgs): self.noise_scheduler = noise_scheduler # Get unconditional embeddings - self.uncond_prompt_embed = train_dataset.uncond_prompt_embed - self.uncond_prompt_mask = train_dataset.uncond_prompt_mask + self.uncond_prompt_embed = torch.zeros(512, 4096).to(torch.float32) + self.uncond_prompt_mask = torch.zeros(1, 512).bool() + # self.uncond_prompt_embed = train_dataset.uncond_prompt_embed + # self.uncond_prompt_mask = train_dataset.uncond_prompt_mask if self.rank <= 0: - project = args.tracker_project_name or "fastvideo" - wandb.init(project=project, config=args) + project = fastvideo_args.tracker_project_name or "fastvideo" + wandb.init(project=project, config=fastvideo_args) @abstractmethod def initialize_validation_pipeline(self, fastvideo_args: FastVideoArgs): @@ -228,7 +240,7 @@ def distill_one_step(self, transformer, model_type, teacher_transformer, def log_validation(self, transformer, fastvideo_args, global_step): """Log validation results during training.""" - fastvideo_args.inference_mode = True + fastvideo_args.mode = "inference" fastvideo_args.use_cpu_offload = False if not fastvideo_args.log_validation: return @@ -337,6 +349,7 @@ def log_validation(self, transformer, fastvideo_args, global_step): wandb.log(logs, step=global_step) # Re-enable gradients for training + fastvideo_args.mode = "distill" transformer.requires_grad_(True) transformer.train() @@ -392,6 +405,10 @@ 
def distill_one_step( pred_decay_weight, pred_decay_type, hunyuan_teacher_disable_cfg, + weighting_scheme, + logit_mean, + logit_std, + mode_scale, ): """Perform one step of distillation training.""" total_loss = 0.0 @@ -415,27 +432,54 @@ def distill_one_step( encoder_hidden_states = encoder_hidden_states.to( self.device, dtype=torch.bfloat16) - model_input = normalize_dit_input(model_type, latents) - noise = torch.randn_like(model_input) - bsz = model_input.shape[0] - index = torch.randint(0, - num_euler_timesteps, (bsz, ), - device=model_input.device).long() + latents = normalize_dit_input(model_type, latents) + batch_size = latents.shape[0] + noise = torch.randn_like(latents) + # u = compute_density_for_timestep_sampling( + # weighting_scheme=weighting_scheme, + # batch_size=batch_size, + # generator=noise_random_generator, + # logit_mean=logit_mean, + # logit_std=logit_std, + # mode_scale=mode_scale, + # ) + # indices = (u * noise_scheduler.config.num_train_timesteps).long() + # timesteps = noise_scheduler.timesteps[indices].to( + # device=latents.device) + # indices = indices.to(latents.device) + # if sp_size > 1: + # # Make sure that the timesteps are the same across all sp processes. + # sp_group = get_sp_group() + # sp_group.broadcast(timesteps, src=0) + # sigmas = get_sigmas( + # noise_scheduler, + # latents.device, + # timesteps, + # n_dim=latents.ndim, + # dtype=latents.dtype, + # ) + # noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise + + indices = torch.randint(0, + num_euler_timesteps, (batch_size, ), + device=latents.device).long() + if sp_size > 1: - self.sp_group.broadcast(index, src=0) + self.sp_group.broadcast(indices, src=0) # Add noise according to flow matching - sigmas = extract_into_tensor(solver.sigmas, index, - model_input.shape) - sigmas_prev = extract_into_tensor(solver.sigmas_prev, index, - model_input.shape) + sigmas = extract_into_tensor(solver.sigmas, indices, + latents.shape) + sigmas_prev = extract_into_tensor(solver.sigmas_prev, indices, + latents.shape) timesteps = (sigmas * noise_scheduler.config.num_train_timesteps).view(-1) timesteps_prev = ( sigmas_prev * noise_scheduler.config.num_train_timesteps).view(-1) - noisy_model_input = sigmas * noise + (1.0 - sigmas) * model_input + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + noisy_model_input = noisy_model_input.to(torch.bfloat16) # Get student model prediction with torch.autocast("cuda", dtype=torch.bfloat16): @@ -458,11 +502,10 @@ def distill_one_step( # Apply multi-phase prediction model_pred, end_index = solver.euler_style_multiphase_pred( - noisy_model_input, model_pred, index, multiphase) + noisy_model_input, model_pred, indices, multiphase) - # Get teacher model guidance + # Get teacher model prediction with torch.no_grad(): - w = distill_cfg with torch.autocast("cuda", dtype=torch.bfloat16): with set_forward_context(current_timestep=timesteps, attn_metadata=None): @@ -479,21 +522,21 @@ def distill_one_step( else: # Get teacher model prediction on unconditional embedding with torch.autocast("cuda", dtype=torch.bfloat16): + input_kwargs = { + "hidden_states": noisy_model_input, + "encoder_hidden_states": uncond_prompt_embed.unsqueeze(0).expand( + batch_size, -1, -1), + "timestep": timesteps, + "encoder_attention_mask": uncond_prompt_mask.unsqueeze(0).expand(batch_size, -1), + "return_dict": False, + } with set_forward_context(current_timestep=timesteps, attn_metadata=None): - uncond_teacher_output = teacher_transformer( - noisy_model_input, - 
uncond_prompt_embed.unsqueeze(0).expand( - bsz, -1, -1), - timesteps, - uncond_prompt_mask.unsqueeze(0).expand(bsz, -1), - return_dict=False, - )[0].float() - - teacher_output = uncond_teacher_output + w * ( + uncond_teacher_output = teacher_transformer(**input_kwargs)[0] + teacher_output = uncond_teacher_output + distill_cfg * ( cond_teacher_output - uncond_teacher_output) x_prev = solver.euler_step(noisy_model_input, teacher_output, - index) + indices).to(torch.bfloat16) # Get target prediction with torch.no_grad(): @@ -503,7 +546,7 @@ def distill_one_step( current_timestep=timesteps_prev, attn_metadata=None): target_pred = ema_transformer( - x_prev.float(), + x_prev, encoder_hidden_states, timesteps_prev, encoder_attention_mask, @@ -514,7 +557,7 @@ def distill_one_step( current_timestep=timesteps_prev, attn_metadata=None): target_pred = transformer( - x_prev.float(), + x_prev, encoder_hidden_states, timesteps_prev, encoder_attention_mask, @@ -522,7 +565,7 @@ def distill_one_step( )[0] target, end_index = solver.euler_style_multiphase_pred( - x_prev, target_pred, index, multiphase, True) + x_prev, target_pred, indices, multiphase, True) # Calculate loss huber_c = 0.001 @@ -566,24 +609,28 @@ def distill_one_step( 1 - ema_decay)) # Gradient clipping and optimization step - model_parts = [transformer] - grad_norm = _clip_grad_norm_while_handling_failing_dtensor_cases( - [p for m in model_parts for p in m.parameters()], - max_grad_norm, - foreach=None, - ) + if max_grad_norm is not None: + model_parts = [transformer] + grad_norm = clip_grad_norm_while_handling_failing_dtensor_cases( + [p for m in model_parts for p in m.parameters()], + max_grad_norm, + foreach=None, + ) + grad_norm = grad_norm.item() if grad_norm is not None else 0.0 + else: + grad_norm = 0.0 optimizer.step() lr_scheduler.step() - return total_loss, grad_norm.item(), model_pred_norm + return total_loss, grad_norm, model_pred_norm def forward( self, batch: ForwardBatch, fastvideo_args: TrainingArgs, ): - args = fastvideo_args + assert self.training_args is not None train_dataloader = self.train_dataloader init_steps = self.init_steps lr_scheduler = self.lr_scheduler @@ -595,19 +642,19 @@ def forward( uncond_prompt_mask = self.uncond_prompt_mask # Train! - total_batch_size = (self.world_size * args.gradient_accumulation_steps / - args.sp_size * args.train_sp_batch_size) + total_batch_size = (self.world_size * self.training_args.gradient_accumulation_steps / + self.training_args.sp_size * self.training_args.train_sp_batch_size) logger.info("***** Running distillation training *****") logger.info(f" Resume training from step {init_steps}") logger.info( - f" Instantaneous batch size per device = {args.train_batch_size}") + f" Instantaneous batch size per device = {self.training_args.train_batch_size}") logger.info( f" Total train batch size (w. 
data & sequence parallel, accumulation) = {total_batch_size}" ) logger.info( - f" Gradient Accumulation steps = {args.gradient_accumulation_steps}" + f" Gradient Accumulation steps = {self.training_args.gradient_accumulation_steps}" ) - logger.info(f" Total optimization steps = {args.max_train_steps}") + logger.info(f" Total optimization steps = {self.training_args.max_train_steps}") logger.info( f" Total training parameters per FSDP shard = {sum(p.numel() for p in self.transformer.parameters() if p.requires_grad) / 1e9} B" ) @@ -616,12 +663,12 @@ def forward( ) # Potentially load in the weights and states from a previous save - if args.resume_from_checkpoint: + if self.training_args.resume_from_checkpoint: raise NotImplementedError( "resume_from_checkpoint is not supported now.") progress_bar = tqdm( - range(0, args.max_train_steps), + range(0, self.training_args.max_train_steps), initial=init_steps, desc="Steps", disable=self.local_rank > 0, @@ -644,13 +691,12 @@ def get_num_phases(multi_phased_distill_schedule, step): return int(phase) return int(phase) - for step in range(init_steps + 1, args.max_train_steps + 1): + for step in range(init_steps + 1, self.training_args.max_train_steps + 1): start_time = time.perf_counter() - assert args.multi_phased_distill_schedule is not None - num_phases = get_num_phases(args.multi_phased_distill_schedule, + assert self.training_args.multi_phased_distill_schedule is not None + num_phases = get_num_phases(self.training_args.multi_phased_distill_schedule, step) - loss, grad_norm, pred_norm = self.distill_one_step( self.transformer, "wan", # model_type @@ -662,19 +708,23 @@ def get_num_phases(multi_phased_distill_schedule, step): noise_scheduler, solver, noise_random_generator, - args.gradient_accumulation_steps, - args.sp_size, - args.max_grad_norm, + self.training_args.gradient_accumulation_steps, + self.training_args.sp_size, + self.training_args.max_grad_norm, uncond_prompt_embed, uncond_prompt_mask, - args.num_euler_timesteps, + self.training_args.num_euler_timesteps, num_phases, - args.not_apply_cfg_solver, - args.distill_cfg, - args.ema_decay, - args.pred_decay_weight, - args.pred_decay_type, - args.hunyuan_teacher_disable_cfg, + self.training_args.not_apply_cfg_solver, + self.training_args.distill_cfg, + self.training_args.ema_decay, + self.training_args.pred_decay_weight, + self.training_args.pred_decay_type, + self.training_args.hunyuan_teacher_disable_cfg, + self.training_args.weighting_scheme, + self.training_args.logit_mean, + self.training_args.logit_std, + self.training_args.mode_scale, ) step_time = time.perf_counter() - start_time @@ -716,27 +766,27 @@ def get_num_phases(multi_phased_distill_schedule, step): step=step, ) - if step % args.checkpointing_steps == 0: - if args.use_lora: + if step % self.training_args.checkpointing_steps == 0: + if self.training_args.use_lora: raise NotImplementedError("LoRA is not supported now") else: - if args.use_ema: - save_checkpoint_v1(self.ema_transformer, self.rank, - args.output_dir, step) + if self.training_args.use_ema: + save_checkpoint(self.ema_transformer, self.rank, + self.training_args.output_dir, step) else: - save_checkpoint_v1(self.transformer, self.rank, - args.output_dir, step) + save_checkpoint(self.transformer, self.rank, + self.training_args.output_dir, step) self.sp_group.barrier() - if args.log_validation and step % args.validation_steps == 0: - self.log_validation(self.transformer, args, step) + if self.training_args.log_validation and step % self.training_args.validation_steps == 
0: + self.log_validation(self.transformer, self.training_args, step) # Final checkpoint - if args.use_lora: + if self.training_args.use_lora: raise NotImplementedError("LoRA is not supported now") else: - save_checkpoint_v1(self.transformer, self.rank, args.output_dir, - args.max_train_steps) + save_checkpoint(self.transformer, self.rank, self.training_args.output_dir, + self.training_args.max_train_steps) if get_sp_group(): cleanup_dist_env_and_memory() @@ -748,7 +798,7 @@ def main(args): pipeline = WanDistillationPipeline.from_pretrained( args.pretrained_model_name_or_path, args=args) - args = pipeline.fastvideo_args + args = pipeline.training_args pipeline.forward(None, args) logger.info("Distillation pipeline done") diff --git a/scripts/distill/distill_v1.sh b/scripts/distill/distill_v1.sh old mode 100644 new mode 100755 index 328d84c9b..152dda9e2 --- a/scripts/distill/distill_v1.sh +++ b/scripts/distill/distill_v1.sh @@ -11,7 +11,8 @@ num_gpus=1 # --pretrained_model_name_or_path hunyuanvideo-community/HunyuanVideo \ # --pretrained_model_name_or_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ torchrun --nnodes 1 --nproc_per_node $num_gpus\ - fastvideo/v1/pipelines/distillation_pipeline.py\ + fastvideo/v1/training/distillation_pipeline.py\ + --model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ --mode distill\ --pretrained_model_name_or_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ --cache_dir "/home/test/.cache"\ @@ -27,7 +28,7 @@ torchrun --nnodes 1 --nproc_per_node $num_gpus\ --learning_rate=1e-6\ --mixed_precision="bf16"\ --checkpointing_steps=64\ - --validation_steps 20\ + --validation_steps 80\ --validation_sampling_steps "2,4,8" \ --checkpoints_total_limit 3\ --allow_tf32\ @@ -44,4 +45,9 @@ torchrun --nnodes 1 --nproc_per_node $num_gpus\ --num_euler_timesteps 50 \ --multi_phased_distill_schedule "4000-1" \ --not_apply_cfg_solver \ - --master_weight_type "bf16" \ No newline at end of file + --weight_decay 0.01 \ + --master_weight_type "fp32" \ + --distill_cfg 3.0 \ + --pred_decay_weight 0.0 \ + --max_grad_norm 1.0 + # --master_weight_type "bf16" \ No newline at end of file From 6ec9d1d66bcc22a77df48bc03764481d4e9cd429 Mon Sep 17 00:00:00 2001 From: Jinzhe Pan Date: Wed, 4 Jun 2025 00:45:12 +0800 Subject: [PATCH 04/14] [WIP][Fix] order of sigmas, dataloader issue to debug --- fastvideo/distill/solver.py | 5 +++++ fastvideo/v1/training/distillation_pipeline.py | 16 ++++++++++++++-- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/fastvideo/distill/solver.py b/fastvideo/distill/solver.py index d7c89b04a..1082bc4ac 100644 --- a/fastvideo/distill/solver.py +++ b/fastvideo/distill/solver.py @@ -242,11 +242,16 @@ class EulerSolver: def __init__(self, sigmas, timesteps=1000, euler_timesteps=50): self.step_ratio = timesteps // euler_timesteps + self.euler_timesteps = (np.arange(1, euler_timesteps + 1) * self.step_ratio).round().astype(np.int64) - 1 self.euler_timesteps_prev = np.asarray([0] + self.euler_timesteps[:-1].tolist()) self.sigmas = sigmas[self.euler_timesteps] self.sigmas_prev = np.asarray([sigmas[0]] + sigmas[self.euler_timesteps[:-1]].tolist()) # either use sigma0 or 0 + print(f"sigmas: {sigmas}") + print(f"euler_timesteps: {self.euler_timesteps}") + print(f"sigmas: {self.sigmas}") + print(f"sigmas_prev: {self.sigmas_prev}") self.euler_timesteps = torch.from_numpy(self.euler_timesteps).long() self.euler_timesteps_prev = torch.from_numpy(self.euler_timesteps_prev).long() diff --git a/fastvideo/v1/training/distillation_pipeline.py b/fastvideo/v1/training/distillation_pipeline.py index 
ee5e6791d..eacf14549 100644 --- a/fastvideo/v1/training/distillation_pipeline.py +++ b/fastvideo/v1/training/distillation_pipeline.py @@ -150,7 +150,7 @@ def initialize_distillation_pipeline(self, fastvideo_args: TrainingArgs): sigmas = noise_scheduler.sigmas self.solver = EulerSolver( - sigmas.numpy()[::-1], + sigmas.numpy(), noise_scheduler.config.num_train_timesteps, euler_timesteps=fastvideo_args.num_euler_timesteps, ) @@ -324,7 +324,19 @@ def log_validation(self, transformer, fastvideo_args, global_step): x = torchvision.utils.make_grid(x, nrow=6) x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) frames.append((x * 255).numpy().astype(np.uint8)) - videos.append(frames) + # videos.append(frames) + videos = [frames] + + video_filenames = [] + video_captions = [] + for i, video in enumerate(videos): + caption = captions[i] + filename = os.path.join( + fastvideo_args.output_dir, + f"validation_step_{global_step}_video_{i}.mp4") + imageio.mimsave(filename, video, fps=sampling_param.fps) + video_filenames.append(filename) + video_captions.append(caption) # Log validation results if self.rank == 0: From 6fa53cbf136793892d994f6e59abb009ee45ac7b Mon Sep 17 00:00:00 2001 From: Jinzhe Pan Date: Thu, 5 Jun 2025 18:58:15 +0800 Subject: [PATCH 05/14] [Fix] dataloader issue, distill for single gpu ready --- fastvideo/v1/dataset/parquet_datasets.py | 1 + .../v1/training/distillation_pipeline.py | 498 +----------------- .../v1/training/wan_distillation_pipeline.py | 485 +++++++++++++++++ scripts/distill/distill_v1.sh | 42 +- 4 files changed, 509 insertions(+), 517 deletions(-) create mode 100644 fastvideo/v1/training/wan_distillation_pipeline.py diff --git a/fastvideo/v1/dataset/parquet_datasets.py b/fastvideo/v1/dataset/parquet_datasets.py index a19db5038..772512fd3 100644 --- a/fastvideo/v1/dataset/parquet_datasets.py +++ b/fastvideo/v1/dataset/parquet_datasets.py @@ -119,6 +119,7 @@ def __init__(self, plan = json.load(f) self.neg_metadata = plan["negative_prompt"][0] + # Add unconditional embeddings for distillation (like in LatentDataset) self.uncond_prompt_embed = torch.zeros(512, 4096).to(torch.float32) self.uncond_prompt_mask = torch.zeros(1, 512).bool() diff --git a/fastvideo/v1/training/distillation_pipeline.py b/fastvideo/v1/training/distillation_pipeline.py index eacf14549..fe08911b0 100644 --- a/fastvideo/v1/training/distillation_pipeline.py +++ b/fastvideo/v1/training/distillation_pipeline.py @@ -1,10 +1,6 @@ import gc import os -import sys -import time from abc import ABC, abstractmethod -from collections import deque -from copy import deepcopy import imageio import numpy as np @@ -16,23 +12,15 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import ShardingStrategy from torchdata.stateful_dataloader import StatefulDataLoader -from tqdm.auto import tqdm -from fastvideo.distill.solver import EulerSolver, extract_into_tensor -from fastvideo.models.mochi_hf.mochi_latents_utils import normalize_dit_input +from fastvideo.distill.solver import EulerSolver from fastvideo.v1.configs.sample import SamplingParam from fastvideo.v1.dataset.parquet_datasets import ParquetVideoTextDataset -from fastvideo.v1.distributed import cleanup_dist_env_and_memory, get_sp_group +from fastvideo.v1.distributed import get_sp_group from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs -from fastvideo.v1.forward_context import set_forward_context from fastvideo.v1.logger import init_logger from fastvideo.v1.pipelines import ComposedPipelineBase from 
fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch -from fastvideo.v1.training.training_utils import ( - clip_grad_norm_while_handling_failing_dtensor_cases, - compute_density_for_timestep_sampling, get_sigmas, normalize_dit_input, - save_checkpoint) -from fastvideo.v1.pipelines.wan.wan_pipeline import WanValidationPipeline logger = init_logger(__name__) @@ -66,31 +54,6 @@ def reshard_fsdp(model): if m._has_params and m.sharding_strategy is not ShardingStrategy.NO_SHARD: torch.distributed.fsdp._runtime_utils._reshard(m, m._handle, True) - -def get_norm(model_pred, norms, gradient_accumulation_steps): - """Calculate and aggregate model prediction norms.""" - fro_norm = ( - torch.linalg.matrix_norm(model_pred, ord="fro") / # codespell:ignore - gradient_accumulation_steps) - largest_singular_value = (torch.linalg.matrix_norm(model_pred, ord=2) / - gradient_accumulation_steps) - absolute_mean = torch.mean( - torch.abs(model_pred)) / gradient_accumulation_steps - absolute_max = torch.max( - torch.abs(model_pred)) / gradient_accumulation_steps - - sp_group = get_sp_group() - sp_group.all_reduce(fro_norm, op=torch.distributed.ReduceOp.AVG) - sp_group.all_reduce(largest_singular_value, - op=torch.distributed.ReduceOp.AVG) - sp_group.all_reduce(absolute_mean, op=torch.distributed.ReduceOp.AVG) - - norms["fro"] += torch.mean(fro_norm).item() # codespell:ignore - norms["largest singular value"] += torch.mean(largest_singular_value).item() - norms["absolute mean"] += absolute_mean.item() - norms["absolute max"] += absolute_max.item() - - class DistillationPipeline(ComposedPipelineBase, ABC): """ A pipeline for distillation training. All distillation pipelines should inherit from this class. @@ -369,460 +332,3 @@ def log_validation(self, transformer, fastvideo_args, global_step): torch.cuda.empty_cache() -class WanDistillationPipeline(DistillationPipeline): - """ - A distillation pipeline for Wan. 
- """ - _required_config_modules = ["scheduler", "transformer"] - - def create_pipeline_stages(self, fastvideo_args: FastVideoArgs): - pass - - def create_training_stages(self, fastvideo_args: FastVideoArgs): - pass - - def initialize_validation_pipeline(self, fastvideo_args: FastVideoArgs): - logger.info("Initializing validation pipeline...") - args_copy = deepcopy(fastvideo_args) - - args_copy.mode = "inference" - args_copy.vae_config.load_encoder = False - validation_pipeline = WanValidationPipeline.from_pretrained( - fastvideo_args.model_path, args=args_copy) - - self.validation_pipeline = validation_pipeline - - def distill_one_step( - self, - transformer, - model_type, - teacher_transformer, - ema_transformer, - optimizer, - lr_scheduler, - loader_iter, - noise_scheduler, - solver, - noise_random_generator, - gradient_accumulation_steps, - sp_size, - max_grad_norm, - uncond_prompt_embed, - uncond_prompt_mask, - num_euler_timesteps, - multiphase, - not_apply_cfg_solver, - distill_cfg, - ema_decay, - pred_decay_weight, - pred_decay_type, - hunyuan_teacher_disable_cfg, - weighting_scheme, - logit_mean, - logit_std, - mode_scale, - ): - """Perform one step of distillation training.""" - total_loss = 0.0 - optimizer.zero_grad() - model_pred_norm = { - "fro": 0.0, # codespell:ignore - "largest singular value": 0.0, - "absolute mean": 0.0, - "absolute max": 0.0, - } - - for _ in range(gradient_accumulation_steps): - ( - latents, - encoder_hidden_states, - encoder_attention_mask, - infos, - ) = next(loader_iter) - - latents = latents.to(self.device, dtype=torch.bfloat16) - encoder_hidden_states = encoder_hidden_states.to( - self.device, dtype=torch.bfloat16) - - latents = normalize_dit_input(model_type, latents) - batch_size = latents.shape[0] - noise = torch.randn_like(latents) - # u = compute_density_for_timestep_sampling( - # weighting_scheme=weighting_scheme, - # batch_size=batch_size, - # generator=noise_random_generator, - # logit_mean=logit_mean, - # logit_std=logit_std, - # mode_scale=mode_scale, - # ) - # indices = (u * noise_scheduler.config.num_train_timesteps).long() - # timesteps = noise_scheduler.timesteps[indices].to( - # device=latents.device) - # indices = indices.to(latents.device) - # if sp_size > 1: - # # Make sure that the timesteps are the same across all sp processes. 
- # sp_group = get_sp_group() - # sp_group.broadcast(timesteps, src=0) - # sigmas = get_sigmas( - # noise_scheduler, - # latents.device, - # timesteps, - # n_dim=latents.ndim, - # dtype=latents.dtype, - # ) - # noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise - - indices = torch.randint(0, - num_euler_timesteps, (batch_size, ), - device=latents.device).long() - - if sp_size > 1: - self.sp_group.broadcast(indices, src=0) - - # Add noise according to flow matching - sigmas = extract_into_tensor(solver.sigmas, indices, - latents.shape) - sigmas_prev = extract_into_tensor(solver.sigmas_prev, indices, - latents.shape) - - timesteps = (sigmas * - noise_scheduler.config.num_train_timesteps).view(-1) - timesteps_prev = ( - sigmas_prev * - noise_scheduler.config.num_train_timesteps).view(-1) - noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents - noisy_model_input = noisy_model_input.to(torch.bfloat16) - - # Get student model prediction - with torch.autocast("cuda", dtype=torch.bfloat16): - input_kwargs = { - "hidden_states": noisy_model_input, - "encoder_hidden_states": encoder_hidden_states, - "timestep": timesteps, - "encoder_attention_mask": encoder_attention_mask, - "return_dict": False, - } - if hunyuan_teacher_disable_cfg: - input_kwargs["guidance"] = torch.tensor( - [1000.0], - device=noisy_model_input.device, - dtype=torch.bfloat16) - - with set_forward_context(current_timestep=timesteps, - attn_metadata=None): - model_pred = transformer(**input_kwargs)[0] - - # Apply multi-phase prediction - model_pred, end_index = solver.euler_style_multiphase_pred( - noisy_model_input, model_pred, indices, multiphase) - - # Get teacher model prediction - with torch.no_grad(): - with torch.autocast("cuda", dtype=torch.bfloat16): - with set_forward_context(current_timestep=timesteps, - attn_metadata=None): - cond_teacher_output = teacher_transformer( - noisy_model_input, - encoder_hidden_states, - timesteps, - encoder_attention_mask, - return_dict=False, - )[0].float() - - if not_apply_cfg_solver: - uncond_teacher_output = cond_teacher_output - else: - # Get teacher model prediction on unconditional embedding - with torch.autocast("cuda", dtype=torch.bfloat16): - input_kwargs = { - "hidden_states": noisy_model_input, - "encoder_hidden_states": uncond_prompt_embed.unsqueeze(0).expand( - batch_size, -1, -1), - "timestep": timesteps, - "encoder_attention_mask": uncond_prompt_mask.unsqueeze(0).expand(batch_size, -1), - "return_dict": False, - } - with set_forward_context(current_timestep=timesteps, - attn_metadata=None): - uncond_teacher_output = teacher_transformer(**input_kwargs)[0] - teacher_output = uncond_teacher_output + distill_cfg * ( - cond_teacher_output - uncond_teacher_output) - x_prev = solver.euler_step(noisy_model_input, teacher_output, - indices).to(torch.bfloat16) - - # Get target prediction - with torch.no_grad(): - with torch.autocast("cuda", dtype=torch.bfloat16): - if ema_transformer is not None: - with set_forward_context( - current_timestep=timesteps_prev, - attn_metadata=None): - target_pred = ema_transformer( - x_prev, - encoder_hidden_states, - timesteps_prev, - encoder_attention_mask, - return_dict=False, - )[0] - else: - with set_forward_context( - current_timestep=timesteps_prev, - attn_metadata=None): - target_pred = transformer( - x_prev, - encoder_hidden_states, - timesteps_prev, - encoder_attention_mask, - return_dict=False, - )[0] - - target, end_index = solver.euler_style_multiphase_pred( - x_prev, target_pred, indices, multiphase, True) - - # 
Calculate loss - huber_c = 0.001 - loss = (torch.mean( - torch.sqrt((model_pred.float() - target.float())**2 + - huber_c**2) - huber_c) / gradient_accumulation_steps) - - if pred_decay_weight > 0: - if pred_decay_type == "l1": - pred_decay_loss = ( - torch.mean(torch.sqrt(model_pred.float()**2)) * - pred_decay_weight / gradient_accumulation_steps) - loss += pred_decay_loss - elif pred_decay_type == "l2": - pred_decay_loss = (torch.mean(model_pred.float()**2) * - pred_decay_weight / - gradient_accumulation_steps) - loss += pred_decay_loss - else: - raise NotImplementedError( - "pred_decay_type is not implemented") - - # Calculate model prediction norms - get_norm(model_pred.detach().float(), model_pred_norm, - gradient_accumulation_steps) - loss.backward() - - avg_loss = loss.detach().clone() - self.sp_group.all_reduce(avg_loss, - op=torch.distributed.ReduceOp.AVG) - total_loss += avg_loss.item() - - # Update EMA - if ema_transformer is not None: - reshard_fsdp(ema_transformer) - for p_averaged, p_model in zip(ema_transformer.parameters(), - transformer.parameters()): - with torch.no_grad(): - p_averaged.copy_( - torch.lerp(p_averaged.detach(), p_model.detach(), - 1 - ema_decay)) - - # Gradient clipping and optimization step - if max_grad_norm is not None: - model_parts = [transformer] - grad_norm = clip_grad_norm_while_handling_failing_dtensor_cases( - [p for m in model_parts for p in m.parameters()], - max_grad_norm, - foreach=None, - ) - grad_norm = grad_norm.item() if grad_norm is not None else 0.0 - else: - grad_norm = 0.0 - - optimizer.step() - lr_scheduler.step() - - return total_loss, grad_norm, model_pred_norm - - def forward( - self, - batch: ForwardBatch, - fastvideo_args: TrainingArgs, - ): - assert self.training_args is not None - train_dataloader = self.train_dataloader - init_steps = self.init_steps - lr_scheduler = self.lr_scheduler - optimizer = self.optimizer - noise_scheduler = self.noise_scheduler - solver = self.solver - noise_random_generator = None - uncond_prompt_embed = self.uncond_prompt_embed - uncond_prompt_mask = self.uncond_prompt_mask - - # Train! - total_batch_size = (self.world_size * self.training_args.gradient_accumulation_steps / - self.training_args.sp_size * self.training_args.train_sp_batch_size) - logger.info("***** Running distillation training *****") - logger.info(f" Resume training from step {init_steps}") - logger.info( - f" Instantaneous batch size per device = {self.training_args.train_batch_size}") - logger.info( - f" Total train batch size (w. 
data & sequence parallel, accumulation) = {total_batch_size}" - ) - logger.info( - f" Gradient Accumulation steps = {self.training_args.gradient_accumulation_steps}" - ) - logger.info(f" Total optimization steps = {self.training_args.max_train_steps}") - logger.info( - f" Total training parameters per FSDP shard = {sum(p.numel() for p in self.transformer.parameters() if p.requires_grad) / 1e9} B" - ) - logger.info( - f" Master weight dtype: {self.transformer.parameters().__next__().dtype}" - ) - - # Potentially load in the weights and states from a previous save - if self.training_args.resume_from_checkpoint: - raise NotImplementedError( - "resume_from_checkpoint is not supported now.") - - progress_bar = tqdm( - range(0, self.training_args.max_train_steps), - initial=init_steps, - desc="Steps", - disable=self.local_rank > 0, - ) - - loader_iter = iter(train_dataloader) - step_times = deque(maxlen=100) - - # Skip steps if resuming - for i in range(init_steps): - next(loader_iter) - - def get_num_phases(multi_phased_distill_schedule, step): - # step-phase,step-phase - multi_phases = multi_phased_distill_schedule.split(",") - phase = multi_phases[-1].split("-")[-1] - for step_phases in multi_phases: - phase_step, phase = step_phases.split("-") - if step <= int(phase_step): - return int(phase) - return int(phase) - - for step in range(init_steps + 1, self.training_args.max_train_steps + 1): - start_time = time.perf_counter() - - assert self.training_args.multi_phased_distill_schedule is not None - num_phases = get_num_phases(self.training_args.multi_phased_distill_schedule, - step) - loss, grad_norm, pred_norm = self.distill_one_step( - self.transformer, - "wan", # model_type - self.teacher_transformer, - self.ema_transformer, - optimizer, - lr_scheduler, - loader_iter, - noise_scheduler, - solver, - noise_random_generator, - self.training_args.gradient_accumulation_steps, - self.training_args.sp_size, - self.training_args.max_grad_norm, - uncond_prompt_embed, - uncond_prompt_mask, - self.training_args.num_euler_timesteps, - num_phases, - self.training_args.not_apply_cfg_solver, - self.training_args.distill_cfg, - self.training_args.ema_decay, - self.training_args.pred_decay_weight, - self.training_args.pred_decay_type, - self.training_args.hunyuan_teacher_disable_cfg, - self.training_args.weighting_scheme, - self.training_args.logit_mean, - self.training_args.logit_std, - self.training_args.mode_scale, - ) - - step_time = time.perf_counter() - start_time - step_times.append(step_time) - avg_step_time = sum(step_times) / len(step_times) - - progress_bar.set_postfix({ - "loss": f"{loss:.4f}", - "step_time": f"{step_time:.2f}s", - "grad_norm": grad_norm, - "phases": num_phases, - }) - progress_bar.update(1) - - if self.rank <= 0: - wandb.log( - { - "train_loss": - loss, - "learning_rate": - lr_scheduler.get_last_lr()[0], - "step_time": - step_time, - "avg_step_time": - avg_step_time, - "grad_norm": - grad_norm, - "pred_fro_norm": - pred_norm["fro"], # codespell:ignore - "pred_largest_singular_value": - pred_norm["largest singular value"], - "pred_absolute_mean": - pred_norm["absolute mean"], - "pred_absolute_max": - pred_norm["absolute max"], - "phases": - num_phases, - }, - step=step, - ) - - if step % self.training_args.checkpointing_steps == 0: - if self.training_args.use_lora: - raise NotImplementedError("LoRA is not supported now") - else: - if self.training_args.use_ema: - save_checkpoint(self.ema_transformer, self.rank, - self.training_args.output_dir, step) - else: - 
save_checkpoint(self.transformer, self.rank, - self.training_args.output_dir, step) - self.sp_group.barrier() - - if self.training_args.log_validation and step % self.training_args.validation_steps == 0: - self.log_validation(self.transformer, self.training_args, step) - - # Final checkpoint - if self.training_args.use_lora: - raise NotImplementedError("LoRA is not supported now") - else: - save_checkpoint(self.transformer, self.rank, self.training_args.output_dir, - self.training_args.max_train_steps) - - if get_sp_group(): - cleanup_dist_env_and_memory() - - -def main(args): - logger.info("Starting distillation pipeline...") - - pipeline = WanDistillationPipeline.from_pretrained( - args.pretrained_model_name_or_path, args=args) - - args = pipeline.training_args - pipeline.forward(None, args) - logger.info("Distillation pipeline done") - - -if __name__ == "__main__": - argv = sys.argv - from fastvideo.v1.fastvideo_args import TrainingArgs - from fastvideo.v1.utils import FlexibleArgumentParser - parser = FlexibleArgumentParser() - parser = TrainingArgs.add_cli_args(parser) - parser = FastVideoArgs.add_cli_args(parser) - args = parser.parse_args() - args.use_cpu_offload = False - print(args) - main(args) diff --git a/fastvideo/v1/training/wan_distillation_pipeline.py b/fastvideo/v1/training/wan_distillation_pipeline.py new file mode 100644 index 000000000..1dca8f931 --- /dev/null +++ b/fastvideo/v1/training/wan_distillation_pipeline.py @@ -0,0 +1,485 @@ +import sys +import time +from collections import deque +from copy import deepcopy + +import torch +import wandb +from tqdm.auto import tqdm + +from fastvideo.distill.solver import extract_into_tensor +from fastvideo.v1.distributed import cleanup_dist_env_and_memory, get_sp_group +from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs +from fastvideo.v1.forward_context import set_forward_context +from fastvideo.v1.logger import init_logger +from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch +from fastvideo.v1.training.training_utils import ( + clip_grad_norm_while_handling_failing_dtensor_cases, + save_checkpoint, normalize_dit_input) +from fastvideo.v1.pipelines.wan.wan_pipeline import WanValidationPipeline +from fastvideo.v1.training.distillation_pipeline import DistillationPipeline, reshard_fsdp + +logger = init_logger(__name__) + +def get_norm(model_pred, norms, gradient_accumulation_steps): + """Calculate and aggregate model prediction norms.""" + fro_norm = ( + torch.linalg.matrix_norm(model_pred, ord="fro") / # codespell:ignore + gradient_accumulation_steps) + largest_singular_value = (torch.linalg.matrix_norm(model_pred, ord=2) / + gradient_accumulation_steps) + absolute_mean = torch.mean( + torch.abs(model_pred)) / gradient_accumulation_steps + absolute_max = torch.max( + torch.abs(model_pred)) / gradient_accumulation_steps + + sp_group = get_sp_group() + sp_group.all_reduce(fro_norm, op=torch.distributed.ReduceOp.AVG) + sp_group.all_reduce(largest_singular_value, + op=torch.distributed.ReduceOp.AVG) + sp_group.all_reduce(absolute_mean, op=torch.distributed.ReduceOp.AVG) + + norms["fro"] += torch.mean(fro_norm).item() # codespell:ignore + norms["largest singular value"] += torch.mean(largest_singular_value).item() + norms["absolute mean"] += absolute_mean.item() + norms["absolute max"] += absolute_max.item() + +class WanDistillationPipeline(DistillationPipeline): + """ + A distillation pipeline for Wan. 
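
The distill_one_step added below scores the student's multi-phase prediction against the solver target with a pseudo-Huber loss, sqrt((pred - target)^2 + c^2) - c, which behaves like a scaled L2 penalty for residuals much smaller than c and like L1 for large ones. A self-contained sketch with the same c = 1e-3; the tensor shapes are illustrative only:

import torch


def pseudo_huber_loss(pred: torch.Tensor, target: torch.Tensor, c: float = 1e-3) -> torch.Tensor:
    # Mean over all elements of sqrt((pred - target)^2 + c^2) - c.
    return (torch.sqrt((pred.float() - target.float()) ** 2 + c ** 2) - c).mean()


# Illustrative usage with made-up latent shapes (batch, channels, frames, height, width).
pred = torch.randn(1, 16, 4, 8, 8)
target = torch.randn(1, 16, 4, 8, 8)
loss = pseudo_huber_loss(pred, target) / 4  # divided by gradient_accumulation_steps, as in distill_one_step
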
+ """ + _required_config_modules = ["scheduler", "transformer"] + + def create_pipeline_stages(self, fastvideo_args: FastVideoArgs): + pass + + def create_training_stages(self, fastvideo_args: FastVideoArgs): + pass + + def initialize_validation_pipeline(self, fastvideo_args: FastVideoArgs): + logger.info("Initializing validation pipeline...") + args_copy = deepcopy(fastvideo_args) + + args_copy.mode = "inference" + args_copy.vae_config.load_encoder = False + validation_pipeline = WanValidationPipeline.from_pretrained( + fastvideo_args.model_path, args=args_copy) + + self.validation_pipeline = validation_pipeline + + def distill_one_step( + self, + transformer, + model_type, + teacher_transformer, + ema_transformer, + optimizer, + lr_scheduler, + loader_iter, + noise_scheduler, + solver, + noise_random_generator, + gradient_accumulation_steps, + sp_size, + max_grad_norm, + uncond_prompt_embed, + uncond_prompt_mask, + num_euler_timesteps, + multiphase, + not_apply_cfg_solver, + distill_cfg, + ema_decay, + pred_decay_weight, + pred_decay_type, + hunyuan_teacher_disable_cfg, + weighting_scheme, + logit_mean, + logit_std, + mode_scale, + ): + """Perform one step of distillation training.""" + total_loss = 0.0 + optimizer.zero_grad() + model_pred_norm = { + "fro": 0.0, # codespell:ignore + "largest singular value": 0.0, + "absolute mean": 0.0, + "absolute max": 0.0, + } + + for _ in range(gradient_accumulation_steps): + ( + latents, + encoder_hidden_states, + encoder_attention_mask, + infos, + ) = next(loader_iter) + + latents = latents.to(self.device, dtype=torch.bfloat16) + encoder_hidden_states = encoder_hidden_states.to( + self.device, dtype=torch.bfloat16) + + latents = normalize_dit_input(model_type, latents) + batch_size = latents.shape[0] + noise = torch.randn_like(latents) + + indices = torch.randint(0, + num_euler_timesteps, (batch_size, ), + device=latents.device).long() + + if sp_size > 1: + self.sp_group.broadcast(indices, src=0) + + # Add noise according to flow matching + sigmas = extract_into_tensor(solver.sigmas, indices, + latents.shape) + sigmas_prev = extract_into_tensor(solver.sigmas_prev, indices, + latents.shape) + + timesteps = (sigmas * + noise_scheduler.config.num_train_timesteps).view(-1) + timesteps_prev = ( + sigmas_prev * + noise_scheduler.config.num_train_timesteps).view(-1) + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents + noisy_model_input = noisy_model_input.to(torch.bfloat16) + + # Get student model prediction + with torch.autocast("cuda", dtype=torch.bfloat16): + input_kwargs = { + "hidden_states": noisy_model_input, + "encoder_hidden_states": encoder_hidden_states, + "timestep": timesteps, + "encoder_attention_mask": encoder_attention_mask, + "return_dict": False, + } + if hunyuan_teacher_disable_cfg: + input_kwargs["guidance"] = torch.tensor( + [1000.0], + device=noisy_model_input.device, + dtype=torch.bfloat16) + + with set_forward_context(current_timestep=timesteps, + attn_metadata=None): + model_pred = transformer(**input_kwargs)[0] + + # Apply multi-phase prediction + model_pred, end_index = solver.euler_style_multiphase_pred( + noisy_model_input, model_pred, indices, multiphase) + + # Get teacher model prediction + with torch.no_grad(): + with torch.autocast("cuda", dtype=torch.bfloat16): + with set_forward_context(current_timestep=timesteps, + attn_metadata=None): + cond_teacher_output = teacher_transformer( + noisy_model_input, + encoder_hidden_states, + timesteps, + encoder_attention_mask, + return_dict=False, + )[0].float() + + if 
not_apply_cfg_solver: + uncond_teacher_output = cond_teacher_output + else: + # Get teacher model prediction on unconditional embedding + with torch.autocast("cuda", dtype=torch.bfloat16): + input_kwargs = { + "hidden_states": noisy_model_input, + "encoder_hidden_states": uncond_prompt_embed.unsqueeze(0).expand( + batch_size, -1, -1), + "timestep": timesteps, + "encoder_attention_mask": uncond_prompt_mask.unsqueeze(0).expand(batch_size, -1), + "return_dict": False, + } + with set_forward_context(current_timestep=timesteps, + attn_metadata=None): + uncond_teacher_output = teacher_transformer(**input_kwargs)[0] + teacher_output = uncond_teacher_output + distill_cfg * ( + cond_teacher_output - uncond_teacher_output) + x_prev = solver.euler_step(noisy_model_input, teacher_output, + indices).to(torch.bfloat16) + + # Get target prediction + with torch.no_grad(): + with torch.autocast("cuda", dtype=torch.bfloat16): + if ema_transformer is not None: + with set_forward_context( + current_timestep=timesteps_prev, + attn_metadata=None): + target_pred = ema_transformer( + x_prev, + encoder_hidden_states, + timesteps_prev, + encoder_attention_mask, + return_dict=False, + )[0] + else: + with set_forward_context( + current_timestep=timesteps_prev, + attn_metadata=None): + target_pred = transformer( + x_prev, + encoder_hidden_states, + timesteps_prev, + encoder_attention_mask, + return_dict=False, + )[0] + + target, end_index = solver.euler_style_multiphase_pred( + x_prev, target_pred, indices, multiphase, True) + + # Calculate loss + huber_c = 0.001 + loss = (torch.mean( + torch.sqrt((model_pred.float() - target.float())**2 + + huber_c**2) - huber_c) / gradient_accumulation_steps) + + if pred_decay_weight > 0: + if pred_decay_type == "l1": + pred_decay_loss = ( + torch.mean(torch.sqrt(model_pred.float()**2)) * + pred_decay_weight / gradient_accumulation_steps) + loss += pred_decay_loss + elif pred_decay_type == "l2": + pred_decay_loss = (torch.mean(model_pred.float()**2) * + pred_decay_weight / + gradient_accumulation_steps) + loss += pred_decay_loss + else: + raise NotImplementedError( + "pred_decay_type is not implemented") + + # Calculate model prediction norms + get_norm(model_pred.detach().float(), model_pred_norm, + gradient_accumulation_steps) + loss.backward() + + avg_loss = loss.detach().clone() + self.sp_group.all_reduce(avg_loss, + op=torch.distributed.ReduceOp.AVG) + total_loss += avg_loss.item() + + # Update EMA + if ema_transformer is not None: + reshard_fsdp(ema_transformer) + for p_averaged, p_model in zip(ema_transformer.parameters(), + transformer.parameters()): + with torch.no_grad(): + p_averaged.copy_( + torch.lerp(p_averaged.detach(), p_model.detach(), + 1 - ema_decay)) + + # Gradient clipping and optimization step + if max_grad_norm is not None: + model_parts = [transformer] + grad_norm = clip_grad_norm_while_handling_failing_dtensor_cases( + [p for m in model_parts for p in m.parameters()], + max_grad_norm, + foreach=None, + ) + grad_norm = grad_norm.item() if grad_norm is not None else 0.0 + else: + grad_norm = 0.0 + + optimizer.step() + lr_scheduler.step() + + return total_loss, grad_norm, model_pred_norm + + def forward( + self, + batch: ForwardBatch, + fastvideo_args: TrainingArgs, + ): + assert self.training_args is not None + train_dataloader = self.train_dataloader + init_steps = self.init_steps + lr_scheduler = self.lr_scheduler + optimizer = self.optimizer + noise_scheduler = self.noise_scheduler + solver = self.solver + noise_random_generator = None + uncond_prompt_embed 
= self.uncond_prompt_embed + uncond_prompt_mask = self.uncond_prompt_mask + + # Train! + total_batch_size = (self.world_size * self.training_args.gradient_accumulation_steps / + self.training_args.sp_size * self.training_args.train_sp_batch_size) + logger.info("***** Running distillation training *****") + logger.info(f" Resume training from step {init_steps}") + logger.info( + f" Instantaneous batch size per device = {self.training_args.train_batch_size}") + logger.info( + f" Total train batch size (w. data & sequence parallel, accumulation) = {total_batch_size}" + ) + logger.info( + f" Gradient Accumulation steps = {self.training_args.gradient_accumulation_steps}" + ) + logger.info(f" Total optimization steps = {self.training_args.max_train_steps}") + logger.info( + f" Total training parameters per FSDP shard = {sum(p.numel() for p in self.transformer.parameters() if p.requires_grad) / 1e9} B" + ) + logger.info( + f" Master weight dtype: {self.transformer.parameters().__next__().dtype}" + ) + + # Potentially load in the weights and states from a previous save + if self.training_args.resume_from_checkpoint: + raise NotImplementedError( + "resume_from_checkpoint is not supported now.") + + progress_bar = tqdm( + range(0, self.training_args.max_train_steps), + initial=init_steps, + desc="Steps", + disable=self.local_rank > 0, + ) + + loader_iter = iter(train_dataloader) + step_times = deque(maxlen=100) + + # Skip steps if resuming + for i in range(init_steps): + next(loader_iter) + + def get_num_phases(multi_phased_distill_schedule, step): + # step-phase,step-phase + multi_phases = multi_phased_distill_schedule.split(",") + phase = multi_phases[-1].split("-")[-1] + for step_phases in multi_phases: + phase_step, phase = step_phases.split("-") + if step <= int(phase_step): + return int(phase) + return int(phase) + + for step in range(init_steps + 1, self.training_args.max_train_steps + 1): + start_time = time.perf_counter() + + assert self.training_args.multi_phased_distill_schedule is not None + num_phases = get_num_phases(self.training_args.multi_phased_distill_schedule, + step) + try: + loss, grad_norm, pred_norm = self.distill_one_step( + self.transformer, + "wan", # model_type + self.teacher_transformer, + self.ema_transformer, + optimizer, + lr_scheduler, + loader_iter, + noise_scheduler, + solver, + noise_random_generator, + self.training_args.gradient_accumulation_steps, + self.training_args.sp_size, + self.training_args.max_grad_norm, + uncond_prompt_embed, + uncond_prompt_mask, + self.training_args.num_euler_timesteps, + num_phases, + self.training_args.not_apply_cfg_solver, + self.training_args.distill_cfg, + self.training_args.ema_decay, + self.training_args.pred_decay_weight, + self.training_args.pred_decay_type, + self.training_args.hunyuan_teacher_disable_cfg, + self.training_args.weighting_scheme, + self.training_args.logit_mean, + self.training_args.logit_std, + self.training_args.mode_scale, + ) + + step_time = time.perf_counter() - start_time + step_times.append(step_time) + avg_step_time = sum(step_times) / len(step_times) + + progress_bar.set_postfix({ + "loss": f"{loss:.4f}", + "step_time": f"{step_time:.2f}s", + "grad_norm": grad_norm, + "phases": num_phases, + }) + progress_bar.update(1) + except StopIteration: + loader_iter = iter(train_dataloader) + step -= 1 + continue + + + if self.rank <= 0: + wandb.log( + { + "train_loss": + loss, + "learning_rate": + lr_scheduler.get_last_lr()[0], + "step_time": + step_time, + "avg_step_time": + avg_step_time, + "grad_norm": + 
grad_norm, + "pred_fro_norm": + pred_norm["fro"], # codespell:ignore + "pred_largest_singular_value": + pred_norm["largest singular value"], + "pred_absolute_mean": + pred_norm["absolute mean"], + "pred_absolute_max": + pred_norm["absolute max"], + "phases": + num_phases, + }, + step=step, + ) + + if step % self.training_args.checkpointing_steps == 0: + if self.training_args.use_lora: + raise NotImplementedError("LoRA is not supported now") + else: + if self.training_args.use_ema: + save_checkpoint(self.ema_transformer, self.rank, + self.training_args.output_dir, step) + else: + save_checkpoint(self.transformer, self.rank, + self.training_args.output_dir, step) + self.sp_group.barrier() + + if self.training_args.log_validation and step % self.training_args.validation_steps == 0: + self.log_validation(self.transformer, self.training_args, step) + + # Final checkpoint + if self.training_args.use_lora: + raise NotImplementedError("LoRA is not supported now") + else: + save_checkpoint(self.transformer, self.rank, self.training_args.output_dir, + self.training_args.max_train_steps) + + if get_sp_group(): + cleanup_dist_env_and_memory() + + +def main(args): + logger.info("Starting distillation pipeline...") + + pipeline = WanDistillationPipeline.from_pretrained( + args.pretrained_model_name_or_path, args=args) + + args = pipeline.training_args + pipeline.forward(None, args) + logger.info("Distillation pipeline done") + + +if __name__ == "__main__": + argv = sys.argv + from fastvideo.v1.fastvideo_args import TrainingArgs + from fastvideo.v1.utils import FlexibleArgumentParser + parser = FlexibleArgumentParser() + parser = TrainingArgs.add_cli_args(parser) + parser = FastVideoArgs.add_cli_args(parser) + args = parser.parse_args() + args.use_cpu_offload = False + print(args) + main(args) diff --git a/scripts/distill/distill_v1.sh b/scripts/distill/distill_v1.sh index 152dda9e2..bfcb1af35 100755 --- a/scripts/distill/distill_v1.sh +++ b/scripts/distill/distill_v1.sh @@ -11,35 +11,35 @@ num_gpus=1 # --pretrained_model_name_or_path hunyuanvideo-community/HunyuanVideo \ # --pretrained_model_name_or_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ torchrun --nnodes 1 --nproc_per_node $num_gpus\ - fastvideo/v1/training/distillation_pipeline.py\ + fastvideo/v1/training/wan_distillation_pipeline.py \ --model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ - --mode distill\ + --mode distill \ --pretrained_model_name_or_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \ - --cache_dir "/home/test/.cache"\ - --data_path "$DATA_DIR"\ - --validation_prompt_dir "$VALIDATION_DIR"\ + --cache_dir "/home/test/.cache" \ + --data_path "$DATA_DIR" \ + --validation_prompt_dir "$VALIDATION_DIR" \ --train_batch_size=1 \ - --num_latent_t 1 \ + --num_latent_t 4 \ --sp_size $num_gpus \ - --train_sp_batch_size 1\ - --dataloader_num_workers $num_gpus\ - --gradient_accumulation_steps=1\ - --max_train_steps=320\ - --learning_rate=1e-6\ - --mixed_precision="bf16"\ - --checkpointing_steps=64\ - --validation_steps 80\ + --train_sp_batch_size 1 \ + --dataloader_num_workers $num_gpus \ + --gradient_accumulation_steps=1 \ + --max_train_steps=540 \ + --learning_rate=1e-6 \ + --mixed_precision="bf16" \ + --checkpointing_steps=64 \ + --validation_steps 180 \ --validation_sampling_steps "2,4,8" \ - --checkpoints_total_limit 3\ - --allow_tf32\ - --ema_start_step 0\ - --cfg 0.0\ - --log_validation\ - --output_dir="$DATA_DIR/outputs/hy_phase1_shift17_bs_16_HD"\ + --checkpoints_total_limit 3 \ + --allow_tf32 \ + --ema_start_step 0 \ + --cfg 0.0 \ + --log_validation \ + 
--output_dir="$DATA_DIR/outputs/hy_phase1_shift17_bs_16_HD" \ --tracker_project_name Hunyuan_Distill \ --num_height 720 \ --num_width 1280 \ - --num_frames 125 \ + --num_frames 81 \ --shift 17 \ --validation_guidance_scale "1.0" \ --num_euler_timesteps 50 \ From 3a831309dfbd873133dd6b2c2df45b0cced1b680 Mon Sep 17 00:00:00 2001 From: Jinzhe Pan Date: Sat, 7 Jun 2025 17:56:28 +0800 Subject: [PATCH 06/14] [Fix] adapt to main --- fastvideo/v1/fastvideo_args.py | 63 +++++++++++++++---- .../v1/pipelines/composed_pipeline_base.py | 27 ++++---- .../v1/training/distillation_pipeline.py | 10 +-- .../v1/training/wan_distillation_pipeline.py | 4 +- scripts/distill/distill_v1.sh | 2 + 5 files changed, 73 insertions(+), 33 deletions(-) diff --git a/fastvideo/v1/fastvideo_args.py b/fastvideo/v1/fastvideo_args.py index f010f3285..d8de97c2a 100644 --- a/fastvideo/v1/fastvideo_args.py +++ b/fastvideo/v1/fastvideo_args.py @@ -6,8 +6,11 @@ import dataclasses from contextlib import contextmanager from dataclasses import field +from enum import Enum from typing import Any, Callable, List, Optional, Tuple +import torch + from fastvideo.v1.configs.models import DiTConfig, EncoderConfig, VAEConfig from fastvideo.v1.logger import init_logger from fastvideo.v1.utils import FlexibleArgumentParser, StoreBoolean @@ -15,6 +18,13 @@ logger = init_logger(__name__) +class Mode(Enum): + """Enumeration for FastVideo execution modes.""" + INFERENCE = "inference" + TRAINING = "training" + DISTILL = "distill" + + def preprocess_text(prompt: str) -> str: return prompt @@ -34,7 +44,7 @@ class FastVideoArgs: # Distributed executor backend distributed_executor_backend: str = "mp" - mode: str = "inference" # Options: "inference", "training", "distill" + mode: Mode = Mode.INFERENCE # HuggingFace specific parameters trust_remote_code: bool = False @@ -111,15 +121,15 @@ class FastVideoArgs: @property def training_mode(self) -> bool: - return self.mode == "training" + return self.mode == Mode.TRAINING @property def distill_mode(self) -> bool: - return self.mode == "distill" + return self.mode == Mode.DISTILL @property def inference_mode(self) -> bool: - return self.mode == "inference" + return self.mode == Mode.INFERENCE def __post_init__(self): self.check_fastvideo_args() @@ -156,8 +166,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument( "--mode", type=str, - default=FastVideoArgs.mode, - choices=["inference", "training", "distill"], + default=FastVideoArgs.mode.value, + choices=[mode.value for mode in Mode], help="The mode to use", ) @@ -371,9 +381,16 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: @classmethod def from_cli_args(cls, args: argparse.Namespace) -> "FastVideoArgs": - args.tp_size = args.tensor_parallel_size - args.sp_size = args.sequence_parallel_size - args.flow_shift = getattr(args, "shift", args.flow_shift) + assert getattr(args, 'model_path', None) is not None, "model_path must be set in args" + # Handle attribute mapping with safe getattr + if hasattr(args, 'tensor_parallel_size'): + args.tp_size = args.tensor_parallel_size + if hasattr(args, 'sequence_parallel_size'): + args.sp_size = args.sequence_parallel_size + if hasattr(args, 'shift'): + args.flow_shift = args.shift + elif hasattr(args, 'flow_shift'): + args.flow_shift = args.flow_shift # Get all fields from the dataclass attrs = [attr.name for attr in dataclasses.fields(cls)] @@ -388,6 +405,18 @@ def from_cli_args(cls, args: argparse.Namespace) -> "FastVideoArgs": kwargs[attr] = 
args.sequence_parallel_size elif attr == 'flow_shift' and hasattr(args, 'shift'): kwargs[attr] = args.shift + elif attr == 'mode': + # Convert string mode to Mode enum + mode_value = getattr(args, attr, None) + if mode_value: + if isinstance(mode_value, Mode): + kwargs[attr] = mode_value + else: + kwargs[attr] = Mode(mode_value) + else: + kwargs[attr] = Mode.INFERENCE + elif attr == 'device_str': + kwargs[attr] = getattr(args, 'device', None) or "cuda" if torch.cuda.is_available() else "cpu" # Use getattr with default value from the dataclass for potentially missing attributes else: default_value = getattr(cls, attr, None) @@ -587,9 +616,6 @@ class TrainingArgs(FastVideoArgs): # master_weight_type master_weight_type: str = "" - # For fast checking in LoRA pipeline - training_mode: bool = True - @classmethod def from_cli_args(cls, args: argparse.Namespace) -> "TrainingArgs": # Get all fields from the dataclass @@ -605,6 +631,19 @@ def from_cli_args(cls, args: argparse.Namespace) -> "TrainingArgs": kwargs[attr] = args.sequence_parallel_size elif attr == 'flow_shift' and hasattr(args, 'shift'): kwargs[attr] = args.shift + elif attr == 'mode': + # Convert string mode to Mode enum + mode_value = getattr(args, attr, None) + if mode_value: + if isinstance(mode_value, Mode): + kwargs[attr] = mode_value + else: + kwargs[attr] = Mode(mode_value) + else: + kwargs[attr] = Mode.TRAINING # Default to training for TrainingArgs + elif attr == 'device_str': + kwargs[attr] = getattr(args, 'device', None) or "cuda" if torch.cuda.is_available() else "cpu" + # Use getattr with default value from the dataclass for potentially missing attributes else: default_value = getattr(cls, attr, None) if getattr(args, attr, default_value) is not None: diff --git a/fastvideo/v1/pipelines/composed_pipeline_base.py b/fastvideo/v1/pipelines/composed_pipeline_base.py index 29190352e..4465b7366 100644 --- a/fastvideo/v1/pipelines/composed_pipeline_base.py +++ b/fastvideo/v1/pipelines/composed_pipeline_base.py @@ -18,6 +18,10 @@ from fastvideo.v1.distributed import ( maybe_init_distributed_environment_and_model_parallel) from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs +from fastvideo.v1.distributed import (init_distributed_environment, + initialize_model_parallel, + model_parallel_is_initialized) +from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs, Mode from fastvideo.v1.logger import init_logger from fastvideo.v1.models.loader.component_loader import PipelineComponentLoader from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch @@ -94,7 +98,6 @@ def __init__(self, self.initialize_validation_pipeline(self.training_args) self.initialize_training_pipeline(self.training_args) - # TODO(jinzhe): discuss this if fastvideo_args.distill_mode: assert self.training_args is not None if self.training_args.log_validation: @@ -159,32 +162,32 @@ def from_pretrained(cls, config_args = shallow_asdict(config) config_args.update(kwargs) - if args.mode == "inference": - fastvideo_args = FastVideoArgs(model_path=model_path, - device_str=device or "cuda" if - torch.cuda.is_available() else "cpu", - **config_args) - fastvideo_args.model_path = model_path + args.model_path = model_path + # Handle both string mode and Mode enum values + mode_str = args.mode if isinstance(args.mode, str) else args.mode.value + + if mode_str == "inference": + fastvideo_args = FastVideoArgs.from_cli_args(args) for key, value in config_args.items(): setattr(fastvideo_args, key, value) - else: + + elif mode_str == "training" 
or mode_str == "distill": assert args is not None, "args must be provided for training mode" fastvideo_args = TrainingArgs.from_cli_args(args) - # TODO(will): fix this so that its not so ugly - fastvideo_args.model_path = model_path for key, value in config_args.items(): setattr(fastvideo_args, key, value) fastvideo_args.use_cpu_offload = False # make sure we are in training mode - fastvideo_args.mode = args.mode # we hijack the precision to be the master weight type so that the # model is loaded with the correct precision. Subsequently we will # use FSDP2's MixedPrecisionPolicy to set the precision for the # fwd, bwd, and other operations' precision. # fastvideo_args.precision = fastvideo_args.master_weight_type assert fastvideo_args.master_weight_type == 'fp32', 'only fp32 is supported for training' - # assert fastvideo_args.precision == 'fp32', 'only fp32 is supported for training' + else: + raise ValueError(f"Invalid mode: {mode_str}") + logger.info("fastvideo_args in from_pretrained: %s", fastvideo_args) diff --git a/fastvideo/v1/training/distillation_pipeline.py b/fastvideo/v1/training/distillation_pipeline.py index fe08911b0..030699db5 100644 --- a/fastvideo/v1/training/distillation_pipeline.py +++ b/fastvideo/v1/training/distillation_pipeline.py @@ -17,7 +17,7 @@ from fastvideo.v1.configs.sample import SamplingParam from fastvideo.v1.dataset.parquet_datasets import ParquetVideoTextDataset from fastvideo.v1.distributed import get_sp_group -from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs +from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs, Mode from fastvideo.v1.logger import init_logger from fastvideo.v1.pipelines import ComposedPipelineBase from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch @@ -150,8 +150,6 @@ def initialize_distillation_pipeline(self, fastvideo_args: TrainingArgs): train_dataset = ParquetVideoTextDataset( fastvideo_args.data_path, batch_size=fastvideo_args.train_batch_size, - rank=self.rank, - world_size=self.world_size, cfg_rate=fastvideo_args.cfg, num_latent_t=fastvideo_args.num_latent_t) @@ -203,7 +201,7 @@ def distill_one_step(self, transformer, model_type, teacher_transformer, def log_validation(self, transformer, fastvideo_args, global_step): """Log validation results during training.""" - fastvideo_args.mode = "inference" + fastvideo_args.mode = Mode.INFERENCE fastvideo_args.use_cpu_offload = False if not fastvideo_args.log_validation: return @@ -218,8 +216,6 @@ def log_validation(self, transformer, fastvideo_args, global_step): validation_dataset = ParquetVideoTextDataset( fastvideo_args.validation_prompt_dir, batch_size=1, - rank=0, - world_size=1, cfg_rate=0, num_latent_t=fastvideo_args.num_latent_t) @@ -324,7 +320,7 @@ def log_validation(self, transformer, fastvideo_args, global_step): wandb.log(logs, step=global_step) # Re-enable gradients for training - fastvideo_args.mode = "distill" + fastvideo_args.mode = Mode.DISTILL transformer.requires_grad_(True) transformer.train() diff --git a/fastvideo/v1/training/wan_distillation_pipeline.py b/fastvideo/v1/training/wan_distillation_pipeline.py index 1dca8f931..e9bda39a2 100644 --- a/fastvideo/v1/training/wan_distillation_pipeline.py +++ b/fastvideo/v1/training/wan_distillation_pipeline.py @@ -9,7 +9,7 @@ from fastvideo.distill.solver import extract_into_tensor from fastvideo.v1.distributed import cleanup_dist_env_and_memory, get_sp_group -from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs +from fastvideo.v1.fastvideo_args import 
FastVideoArgs, Mode, TrainingArgs from fastvideo.v1.forward_context import set_forward_context from fastvideo.v1.logger import init_logger from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch @@ -60,7 +60,7 @@ def initialize_validation_pipeline(self, fastvideo_args: FastVideoArgs): logger.info("Initializing validation pipeline...") args_copy = deepcopy(fastvideo_args) - args_copy.mode = "inference" + args_copy.mode = Mode.INFERENCE args_copy.vae_config.load_encoder = False validation_pipeline = WanValidationPipeline.from_pretrained( fastvideo_args.model_path, args=args_copy) diff --git a/scripts/distill/distill_v1.sh b/scripts/distill/distill_v1.sh index bfcb1af35..37c93f8d5 100755 --- a/scripts/distill/distill_v1.sh +++ b/scripts/distill/distill_v1.sh @@ -21,6 +21,8 @@ torchrun --nnodes 1 --nproc_per_node $num_gpus\ --train_batch_size=1 \ --num_latent_t 4 \ --sp_size $num_gpus \ + --dp_size $num_gpus \ + --dp_shards $num_gpus \ --train_sp_batch_size 1 \ --dataloader_num_workers $num_gpus \ --gradient_accumulation_steps=1 \ From 2dea92803d6568d0c9015a1e42b5682a1b8fa721 Mon Sep 17 00:00:00 2001 From: Jinzhe Pan Date: Sat, 7 Jun 2025 18:12:15 +0800 Subject: [PATCH 07/14] [Fmt] fmt code --- fastvideo/v1/dataset/parquet_datasets.py | 1 - fastvideo/v1/fastvideo_args.py | 14 ++-- .../v1/pipelines/composed_pipeline_base.py | 5 +- .../v1/training/distillation_pipeline.py | 16 +++-- .../v1/training/wan_distillation_pipeline.py | 71 +++++++++++-------- 5 files changed, 65 insertions(+), 42 deletions(-) diff --git a/fastvideo/v1/dataset/parquet_datasets.py b/fastvideo/v1/dataset/parquet_datasets.py index 772512fd3..0732ef4ae 100644 --- a/fastvideo/v1/dataset/parquet_datasets.py +++ b/fastvideo/v1/dataset/parquet_datasets.py @@ -192,7 +192,6 @@ def get_validation_negative_prompt( lat = lat[:, self.rank_in_sp_group, :, :, :] return lat, emb, mask, info - def __len__(self): if self.local_indices is None: try: diff --git a/fastvideo/v1/fastvideo_args.py b/fastvideo/v1/fastvideo_args.py index d8de97c2a..dac791a56 100644 --- a/fastvideo/v1/fastvideo_args.py +++ b/fastvideo/v1/fastvideo_args.py @@ -381,7 +381,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: @classmethod def from_cli_args(cls, args: argparse.Namespace) -> "FastVideoArgs": - assert getattr(args, 'model_path', None) is not None, "model_path must be set in args" + assert getattr(args, 'model_path', + None) is not None, "model_path must be set in args" # Handle attribute mapping with safe getattr if hasattr(args, 'tensor_parallel_size'): args.tp_size = args.tensor_parallel_size @@ -416,7 +417,9 @@ def from_cli_args(cls, args: argparse.Namespace) -> "FastVideoArgs": else: kwargs[attr] = Mode.INFERENCE elif attr == 'device_str': - kwargs[attr] = getattr(args, 'device', None) or "cuda" if torch.cuda.is_available() else "cpu" + kwargs[attr] = getattr( + args, 'device', + None) or "cuda" if torch.cuda.is_available() else "cpu" # Use getattr with default value from the dataclass for potentially missing attributes else: default_value = getattr(cls, attr, None) @@ -640,9 +643,12 @@ def from_cli_args(cls, args: argparse.Namespace) -> "TrainingArgs": else: kwargs[attr] = Mode(mode_value) else: - kwargs[attr] = Mode.TRAINING # Default to training for TrainingArgs + kwargs[ + attr] = Mode.TRAINING # Default to training for TrainingArgs elif attr == 'device_str': - kwargs[attr] = getattr(args, 'device', None) or "cuda" if torch.cuda.is_available() else "cpu" + kwargs[attr] = getattr( + args, 'device', + None) or 
"cuda" if torch.cuda.is_available() else "cpu" # Use getattr with default value from the dataclass for potentially missing attributes else: default_value = getattr(cls, attr, None) diff --git a/fastvideo/v1/pipelines/composed_pipeline_base.py b/fastvideo/v1/pipelines/composed_pipeline_base.py index 4465b7366..e9c5fc11a 100644 --- a/fastvideo/v1/pipelines/composed_pipeline_base.py +++ b/fastvideo/v1/pipelines/composed_pipeline_base.py @@ -21,7 +21,7 @@ from fastvideo.v1.distributed import (init_distributed_environment, initialize_model_parallel, model_parallel_is_initialized) -from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs, Mode +from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs from fastvideo.v1.logger import init_logger from fastvideo.v1.models.loader.component_loader import PipelineComponentLoader from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch @@ -165,7 +165,7 @@ def from_pretrained(cls, args.model_path = model_path # Handle both string mode and Mode enum values mode_str = args.mode if isinstance(args.mode, str) else args.mode.value - + if mode_str == "inference": fastvideo_args = FastVideoArgs.from_cli_args(args) for key, value in config_args.items(): @@ -188,6 +188,7 @@ def from_pretrained(cls, else: raise ValueError(f"Invalid mode: {mode_str}") + fastvideo_args.check_fastvideo_args() logger.info("fastvideo_args in from_pretrained: %s", fastvideo_args) diff --git a/fastvideo/v1/training/distillation_pipeline.py b/fastvideo/v1/training/distillation_pipeline.py index 030699db5..c07a22570 100644 --- a/fastvideo/v1/training/distillation_pipeline.py +++ b/fastvideo/v1/training/distillation_pipeline.py @@ -6,18 +6,18 @@ import numpy as np import torch import torchvision -import wandb from diffusers.optimization import get_scheduler from einops import rearrange from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import ShardingStrategy from torchdata.stateful_dataloader import StatefulDataLoader +import wandb from fastvideo.distill.solver import EulerSolver from fastvideo.v1.configs.sample import SamplingParam from fastvideo.v1.dataset.parquet_datasets import ParquetVideoTextDataset from fastvideo.v1.distributed import get_sp_group -from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs, Mode +from fastvideo.v1.fastvideo_args import FastVideoArgs, Mode, TrainingArgs from fastvideo.v1.logger import init_logger from fastvideo.v1.pipelines import ComposedPipelineBase from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch @@ -54,6 +54,7 @@ def reshard_fsdp(model): if m._has_params and m.sharding_strategy is not ShardingStrategy.NO_SHARD: torch.distributed.fsdp._runtime_utils._reshard(m, m._handle, True) + class DistillationPipeline(ComposedPipelineBase, ABC): """ A pipeline for distillation training. All distillation pipelines should inherit from this class. 
@@ -77,10 +78,12 @@ def initialize_distillation_pipeline(self, fastvideo_args: TrainingArgs): # Initialize teacher model without deepcopy to avoid FSDP issues logger.info("Creating teacher model...") - from fastvideo.v1.models.loader.component_loader import TransformerLoader + from fastvideo.v1.models.loader.component_loader import ( + TransformerLoader) teacher_loader = TransformerLoader() transformer_path = os.path.join(self.model_path, "transformer") - self.teacher_transformer = teacher_loader.load(transformer_path, "", fastvideo_args) + self.teacher_transformer = teacher_loader.load(transformer_path, "", + fastvideo_args) self.teacher_transformer.requires_grad_(False) self.teacher_transformer.eval() logger.info("Teacher model initialized") @@ -89,7 +92,8 @@ def initialize_distillation_pipeline(self, fastvideo_args: TrainingArgs): if fastvideo_args.use_ema: logger.info("Creating EMA model...") ema_loader = TransformerLoader() - self.ema_transformer = ema_loader.load(transformer_path, "", fastvideo_args) + self.ema_transformer = ema_loader.load(transformer_path, "", + fastvideo_args) self.ema_transformer.requires_grad_(False) self.ema_transformer.eval() logger.info("EMA model initialized") @@ -326,5 +330,3 @@ def log_validation(self, transformer, fastvideo_args, global_step): gc.collect() torch.cuda.empty_cache() - - diff --git a/fastvideo/v1/training/wan_distillation_pipeline.py b/fastvideo/v1/training/wan_distillation_pipeline.py index e9bda39a2..8a1199087 100644 --- a/fastvideo/v1/training/wan_distillation_pipeline.py +++ b/fastvideo/v1/training/wan_distillation_pipeline.py @@ -4,23 +4,25 @@ from copy import deepcopy import torch -import wandb from tqdm.auto import tqdm +import wandb from fastvideo.distill.solver import extract_into_tensor from fastvideo.v1.distributed import cleanup_dist_env_and_memory, get_sp_group from fastvideo.v1.fastvideo_args import FastVideoArgs, Mode, TrainingArgs from fastvideo.v1.forward_context import set_forward_context from fastvideo.v1.logger import init_logger from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch -from fastvideo.v1.training.training_utils import ( - clip_grad_norm_while_handling_failing_dtensor_cases, - save_checkpoint, normalize_dit_input) from fastvideo.v1.pipelines.wan.wan_pipeline import WanValidationPipeline -from fastvideo.v1.training.distillation_pipeline import DistillationPipeline, reshard_fsdp +from fastvideo.v1.training.distillation_pipeline import (DistillationPipeline, + reshard_fsdp) +from fastvideo.v1.training.training_utils import ( + clip_grad_norm_while_handling_failing_dtensor_cases, normalize_dit_input, + save_checkpoint) logger = init_logger(__name__) + def get_norm(model_pred, norms, gradient_accumulation_steps): """Calculate and aggregate model prediction norms.""" fro_norm = ( @@ -44,6 +46,7 @@ def get_norm(model_pred, norms, gradient_accumulation_steps): norms["absolute mean"] += absolute_mean.item() norms["absolute max"] += absolute_max.item() + class WanDistillationPipeline(DistillationPipeline): """ A distillation pipeline for Wan. 
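
The EMA transformer loaded above is refreshed after every optimizer step with torch.lerp(p_ema, p, 1 - ema_decay), which is the usual exponential moving average p_ema = ema_decay * p_ema + (1 - ema_decay) * p. A small sketch of that update outside the FSDP resharding shown in the patch, with a made-up decay value:

import torch


@torch.no_grad()
def update_ema(ema_model: torch.nn.Module, model: torch.nn.Module, ema_decay: float) -> None:
    # p_ema <- ema_decay * p_ema + (1 - ema_decay) * p, parameter by parameter.
    for p_ema, p in zip(ema_model.parameters(), model.parameters()):
        p_ema.copy_(torch.lerp(p_ema.detach(), p.detach(), 1.0 - ema_decay))


# Illustrative usage on toy modules; 0.995 is an arbitrary example decay.
student = torch.nn.Linear(8, 8)
ema = torch.nn.Linear(8, 8)
ema.load_state_dict(student.state_dict())
update_ema(ema, student, ema_decay=0.995)
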
@@ -124,15 +127,14 @@ def distill_one_step( noise = torch.randn_like(latents) indices = torch.randint(0, - num_euler_timesteps, (batch_size, ), - device=latents.device).long() + num_euler_timesteps, (batch_size, ), + device=latents.device).long() if sp_size > 1: self.sp_group.broadcast(indices, src=0) # Add noise according to flow matching - sigmas = extract_into_tensor(solver.sigmas, indices, - latents.shape) + sigmas = extract_into_tensor(solver.sigmas, indices, latents.shape) sigmas_prev = extract_into_tensor(solver.sigmas_prev, indices, latents.shape) @@ -186,16 +188,23 @@ def distill_one_step( # Get teacher model prediction on unconditional embedding with torch.autocast("cuda", dtype=torch.bfloat16): input_kwargs = { - "hidden_states": noisy_model_input, - "encoder_hidden_states": uncond_prompt_embed.unsqueeze(0).expand( - batch_size, -1, -1), - "timestep": timesteps, - "encoder_attention_mask": uncond_prompt_mask.unsqueeze(0).expand(batch_size, -1), - "return_dict": False, + "hidden_states": + noisy_model_input, + "encoder_hidden_states": + uncond_prompt_embed.unsqueeze(0).expand( + batch_size, -1, -1), + "timestep": + timesteps, + "encoder_attention_mask": + uncond_prompt_mask.unsqueeze(0).expand( + batch_size, -1), + "return_dict": + False, } with set_forward_context(current_timestep=timesteps, attn_metadata=None): - uncond_teacher_output = teacher_transformer(**input_kwargs)[0] + uncond_teacher_output = teacher_transformer( + **input_kwargs)[0] teacher_output = uncond_teacher_output + distill_cfg * ( cond_teacher_output - uncond_teacher_output) x_prev = solver.euler_step(noisy_model_input, teacher_output, @@ -305,19 +314,24 @@ def forward( uncond_prompt_mask = self.uncond_prompt_mask # Train! - total_batch_size = (self.world_size * self.training_args.gradient_accumulation_steps / - self.training_args.sp_size * self.training_args.train_sp_batch_size) + total_batch_size = (self.world_size * + self.training_args.gradient_accumulation_steps / + self.training_args.sp_size * + self.training_args.train_sp_batch_size) logger.info("***** Running distillation training *****") logger.info(f" Resume training from step {init_steps}") logger.info( - f" Instantaneous batch size per device = {self.training_args.train_batch_size}") + f" Instantaneous batch size per device = {self.training_args.train_batch_size}" + ) logger.info( f" Total train batch size (w. 
data & sequence parallel, accumulation) = {total_batch_size}" ) logger.info( f" Gradient Accumulation steps = {self.training_args.gradient_accumulation_steps}" ) - logger.info(f" Total optimization steps = {self.training_args.max_train_steps}") + logger.info( + f" Total optimization steps = {self.training_args.max_train_steps}" + ) logger.info( f" Total training parameters per FSDP shard = {sum(p.numel() for p in self.transformer.parameters() if p.requires_grad) / 1e9} B" ) @@ -354,12 +368,13 @@ def get_num_phases(multi_phased_distill_schedule, step): return int(phase) return int(phase) - for step in range(init_steps + 1, self.training_args.max_train_steps + 1): + for step in range(init_steps + 1, + self.training_args.max_train_steps + 1): start_time = time.perf_counter() assert self.training_args.multi_phased_distill_schedule is not None - num_phases = get_num_phases(self.training_args.multi_phased_distill_schedule, - step) + num_phases = get_num_phases( + self.training_args.multi_phased_distill_schedule, step) try: loss, grad_norm, pred_norm = self.distill_one_step( self.transformer, @@ -407,7 +422,6 @@ def get_num_phases(multi_phased_distill_schedule, step): step -= 1 continue - if self.rank <= 0: wandb.log( { @@ -441,10 +455,10 @@ def get_num_phases(multi_phased_distill_schedule, step): else: if self.training_args.use_ema: save_checkpoint(self.ema_transformer, self.rank, - self.training_args.output_dir, step) + self.training_args.output_dir, step) else: save_checkpoint(self.transformer, self.rank, - self.training_args.output_dir, step) + self.training_args.output_dir, step) self.sp_group.barrier() if self.training_args.log_validation and step % self.training_args.validation_steps == 0: @@ -454,8 +468,9 @@ def get_num_phases(multi_phased_distill_schedule, step): if self.training_args.use_lora: raise NotImplementedError("LoRA is not supported now") else: - save_checkpoint(self.transformer, self.rank, self.training_args.output_dir, - self.training_args.max_train_steps) + save_checkpoint(self.transformer, self.rank, + self.training_args.output_dir, + self.training_args.max_train_steps) if get_sp_group(): cleanup_dist_env_and_memory() From a87c7d70ccbe7e5291e1af8c82d9bb3400b6d31e Mon Sep 17 00:00:00 2001 From: Jinzhe Pan Date: Tue, 10 Jun 2025 19:41:52 +0800 Subject: [PATCH 08/14] pre-commit --- .../v1/pipelines/composed_pipeline_base.py | 23 +++-- .../v1/training/distillation_pipeline.py | 32 +++---- fastvideo/v1/training/training_pipeline.py | 4 +- .../v1/training/wan_distillation_pipeline.py | 84 +++++++++---------- .../v1/training/wan_training_pipeline.py | 6 +- 5 files changed, 78 insertions(+), 71 deletions(-) diff --git a/fastvideo/v1/pipelines/composed_pipeline_base.py b/fastvideo/v1/pipelines/composed_pipeline_base.py index e9c5fc11a..e2cb961aa 100644 --- a/fastvideo/v1/pipelines/composed_pipeline_base.py +++ b/fastvideo/v1/pipelines/composed_pipeline_base.py @@ -9,6 +9,7 @@ import os from abc import ABC, abstractmethod from copy import deepcopy +from enum import Enum from typing import Any, Dict, List, Optional, Union, cast import torch @@ -106,7 +107,7 @@ def __init__(self, self.initialize_pipeline(fastvideo_args) - if not fastvideo_args.training_mode: + if fastvideo_args.inference_mode: logger.info("Creating pipeline stages...") self.create_pipeline_stages(fastvideo_args) @@ -119,7 +120,7 @@ def initialize_validation_pipeline(self, training_args: TrainingArgs): "if log_validation is True, the pipeline must implement this method" ) - def initialize_distillation_pipeline(self, 
fastvideo_args: FastVideoArgs): + def initialize_distillation_pipeline(self, training_args: TrainingArgs): raise NotImplementedError( "if distill_mode is True, the pipeline must implement this method") @@ -162,29 +163,37 @@ def from_pretrained(cls, config_args = shallow_asdict(config) config_args.update(kwargs) - args.model_path = model_path # Handle both string mode and Mode enum values - mode_str = args.mode if isinstance(args.mode, str) else args.mode.value + mode_str: str | Enum = getattr( + args, 'mode', "inference") if args is not None else "inference" + if hasattr(mode_str, 'value'): + mode_str = mode_str.value + mode_str = str(mode_str) if mode_str == "inference": - fastvideo_args = FastVideoArgs.from_cli_args(args) + fastvideo_args = FastVideoArgs(model_path=model_path, **config_args) + + fastvideo_args.model_path = model_path for key, value in config_args.items(): setattr(fastvideo_args, key, value) - elif mode_str == "training" or mode_str == "distill": assert args is not None, "args must be provided for training mode" fastvideo_args = TrainingArgs.from_cli_args(args) + # TODO(will): fix this so that its not so ugly + fastvideo_args.model_path = model_path for key, value in config_args.items(): setattr(fastvideo_args, key, value) fastvideo_args.use_cpu_offload = False - # make sure we are in training mode + # make sure we are in training mode - note: inference_mode is read-only, + # so we don't set it directly here as it's determined by the mode # we hijack the precision to be the master weight type so that the # model is loaded with the correct precision. Subsequently we will # use FSDP2's MixedPrecisionPolicy to set the precision for the # fwd, bwd, and other operations' precision. # fastvideo_args.precision = fastvideo_args.master_weight_type assert fastvideo_args.master_weight_type == 'fp32', 'only fp32 is supported for training' + # assert fastvideo_args.precision == 'fp32', 'only fp32 is supported for training' else: raise ValueError(f"Invalid mode: {mode_str}") diff --git a/fastvideo/v1/training/distillation_pipeline.py b/fastvideo/v1/training/distillation_pipeline.py index c07a22570..5e230cf53 100644 --- a/fastvideo/v1/training/distillation_pipeline.py +++ b/fastvideo/v1/training/distillation_pipeline.py @@ -1,6 +1,7 @@ import gc import os from abc import ABC, abstractmethod +from typing import List, Optional import imageio import numpy as np @@ -26,7 +27,10 @@ # from: https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77 -def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None): +def linear_quadratic_schedule( + num_steps: int, + threshold_noise: float, + linear_steps: Optional[int] = None) -> List[float]: if linear_steps is None: linear_steps = num_steps // 2 linear_sigma_schedule = [ @@ -48,7 +52,7 @@ def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None): return sigma_schedule -def reshard_fsdp(model): +def reshard_fsdp(model: torch.nn.Module) -> None: """Reshard FSDP model for EMA updates.""" for m in FSDP.fsdp_modules(model): if m._has_params and m.sharding_strategy is not ShardingStrategy.NO_SHARD: @@ -60,6 +64,7 @@ class DistillationPipeline(ComposedPipelineBase, ABC): A pipeline for distillation training. All distillation pipelines should inherit from this class. 
""" _required_config_modules = ["scheduler", "transformer"] + validation_pipeline: ComposedPipelineBase def initialize_distillation_pipeline(self, fastvideo_args: TrainingArgs): logger.info("Initializing distillation pipeline...") @@ -104,6 +109,7 @@ def initialize_distillation_pipeline(self, fastvideo_args: TrainingArgs): assert noise_scheduler is not None # Initialize solver for distillation + sigmas: torch.Tensor | List[float] = [] if fastvideo_args.scheduler_type == "pcm_linear_quadratic": linear_steps = int(noise_scheduler.config.num_train_timesteps * fastvideo_args.linear_range) @@ -112,10 +118,12 @@ def initialize_distillation_pipeline(self, fastvideo_args: TrainingArgs): fastvideo_args.linear_quadratic_threshold, linear_steps, ) - sigmas = torch.tensor(sigmas).to(dtype=torch.float32) else: sigmas = noise_scheduler.sigmas + if isinstance(sigmas, list): + sigmas = torch.tensor(sigmas).to(dtype=torch.float32) + self.solver = EulerSolver( sigmas.numpy(), noise_scheduler.config.num_train_timesteps, @@ -203,7 +211,8 @@ def distill_one_step(self, transformer, model_type, teacher_transformer, raise NotImplementedError( "Distillation pipeline must implement this method") - def log_validation(self, transformer, fastvideo_args, global_step): + @torch.no_grad() + def _log_validation(self, transformer, fastvideo_args, global_step): """Log validation results during training.""" fastvideo_args.mode = Mode.INFERENCE fastvideo_args.use_cpu_offload = False @@ -220,8 +229,9 @@ def log_validation(self, transformer, fastvideo_args, global_step): validation_dataset = ParquetVideoTextDataset( fastvideo_args.validation_prompt_dir, batch_size=1, - cfg_rate=0, - num_latent_t=fastvideo_args.num_latent_t) + cfg_rate=fastvideo_args.cfg, + num_latent_t=fastvideo_args.num_latent_t, + validation=True) validation_dataloader = StatefulDataLoader(validation_dataset, batch_size=1, @@ -231,21 +241,13 @@ def log_validation(self, transformer, fastvideo_args, global_step): pin_memory=True, drop_last=False) - transformer.requires_grad_(False) - for p in transformer.parameters(): - p.requires_grad = False transformer.eval() - # Add the transformer to the validation pipeline - self.validation_pipeline.add_module("transformer", transformer) - self.validation_pipeline.latent_preparation_stage.transformer = transformer - self.validation_pipeline.denoising_stage.transformer = transformer - # Process validation prompts videos = [] captions = [] for _, embeddings, masks, infos in validation_dataloader: - logger.info(f"infos: {infos}") + logger.info("infos: %s", infos) caption = infos['caption'] captions.append(caption) prompt_embeds = embeddings.to(fastvideo_args.device) diff --git a/fastvideo/v1/training/training_pipeline.py b/fastvideo/v1/training/training_pipeline.py index eca1ebc8a..a9cd37db1 100644 --- a/fastvideo/v1/training/training_pipeline.py +++ b/fastvideo/v1/training/training_pipeline.py @@ -19,7 +19,7 @@ from fastvideo.v1.dataset import build_parquet_map_style_dataloader from fastvideo.v1.distributed import (get_sp_group, get_torch_device, get_world_group) -from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs +from fastvideo.v1.fastvideo_args import FastVideoArgs, Mode, TrainingArgs from fastvideo.v1.forward_context import set_forward_context from fastvideo.v1.logger import init_logger from fastvideo.v1.pipelines import ComposedPipelineBase @@ -143,7 +143,7 @@ def train_one_step(self, transformer, model_type, optimizer, lr_scheduler, @torch.no_grad() def _log_validation(self, transformer, 
training_args, global_step) -> None: assert training_args is not None - training_args.inference_mode = True + training_args.mode = Mode.INFERENCE training_args.use_cpu_offload = False if not training_args.log_validation: return diff --git a/fastvideo/v1/training/wan_distillation_pipeline.py b/fastvideo/v1/training/wan_distillation_pipeline.py index 8a1199087..53358782d 100644 --- a/fastvideo/v1/training/wan_distillation_pipeline.py +++ b/fastvideo/v1/training/wan_distillation_pipeline.py @@ -2,6 +2,7 @@ import time from collections import deque from copy import deepcopy +from typing import Dict import torch from tqdm.auto import tqdm @@ -23,7 +24,8 @@ logger = init_logger(__name__) -def get_norm(model_pred, norms, gradient_accumulation_steps): +def get_norm(model_pred: torch.Tensor, norms: Dict[str, float], + gradient_accumulation_steps: int) -> None: """Calculate and aggregate model prediction norms.""" fro_norm = ( torch.linalg.matrix_norm(model_pred, ord="fro") / # codespell:ignore @@ -66,7 +68,10 @@ def initialize_validation_pipeline(self, fastvideo_args: FastVideoArgs): args_copy.mode = Mode.INFERENCE args_copy.vae_config.load_encoder = False validation_pipeline = WanValidationPipeline.from_pretrained( - fastvideo_args.model_path, args=args_copy) + fastvideo_args.model_path, + args=None, + mode=Mode.INFERENCE, + loaded_modules={"transformer": self.get_module("transformer")}) self.validation_pipeline = validation_pipeline @@ -95,11 +100,7 @@ def distill_one_step( pred_decay_weight, pred_decay_type, hunyuan_teacher_disable_cfg, - weighting_scheme, - logit_mean, - logit_std, - mode_scale, - ): + ) -> tuple[float, float, Dict[str, float]]: """Perform one step of distillation training.""" total_loss = 0.0 optimizer.zero_grad() @@ -170,17 +171,16 @@ def distill_one_step( noisy_model_input, model_pred, indices, multiphase) # Get teacher model prediction - with torch.no_grad(): - with torch.autocast("cuda", dtype=torch.bfloat16): - with set_forward_context(current_timestep=timesteps, - attn_metadata=None): - cond_teacher_output = teacher_transformer( - noisy_model_input, - encoder_hidden_states, - timesteps, - encoder_attention_mask, - return_dict=False, - )[0].float() + with torch.no_grad(), torch.autocast( + "cuda", dtype=torch.bfloat16), set_forward_context( + current_timestep=timesteps, attn_metadata=None): + cond_teacher_output = teacher_transformer( + noisy_model_input, + encoder_hidden_states, + timesteps, + encoder_attention_mask, + return_dict=False, + )[0].float() if not_apply_cfg_solver: uncond_teacher_output = cond_teacher_output @@ -313,31 +313,30 @@ def forward( uncond_prompt_embed = self.uncond_prompt_embed uncond_prompt_mask = self.uncond_prompt_mask - # Train! + assert self.training_args.sp_size is not None + assert self.training_args.gradient_accumulation_steps is not None total_batch_size = (self.world_size * self.training_args.gradient_accumulation_steps / self.training_args.sp_size * self.training_args.train_sp_batch_size) logger.info("***** Running distillation training *****") - logger.info(f" Resume training from step {init_steps}") - logger.info( - f" Instantaneous batch size per device = {self.training_args.train_batch_size}" - ) + logger.info(" Resume training from step %s", init_steps) + logger.info(" Instantaneous batch size per device = %s", + self.training_args.train_batch_size) logger.info( - f" Total train batch size (w. data & sequence parallel, accumulation) = {total_batch_size}" - ) + " Total train batch size (w. 
data & sequence parallel, accumulation) = %s", + total_batch_size) + logger.info(" Gradient Accumulation steps = %s", + self.training_args.gradient_accumulation_steps) + logger.info(" Total optimization steps = %s", + self.training_args.max_train_steps) logger.info( - f" Gradient Accumulation steps = {self.training_args.gradient_accumulation_steps}" - ) - logger.info( - f" Total optimization steps = {self.training_args.max_train_steps}" - ) - logger.info( - f" Total training parameters per FSDP shard = {sum(p.numel() for p in self.transformer.parameters() if p.requires_grad) / 1e9} B" - ) - logger.info( - f" Master weight dtype: {self.transformer.parameters().__next__().dtype}" - ) + " Total training parameters per FSDP shard = %s B", + sum(p.numel() + for p in self.transformer.parameters() if p.requires_grad) / + 1e9) + logger.info(" Master weight dtype: %s", + self.transformer.parameters().__next__().dtype) # Potentially load in the weights and states from a previous save if self.training_args.resume_from_checkpoint: @@ -352,13 +351,14 @@ def forward( ) loader_iter = iter(train_dataloader) - step_times = deque(maxlen=100) + step_times: deque[float] = deque(maxlen=100) # Skip steps if resuming for i in range(init_steps): next(loader_iter) - def get_num_phases(multi_phased_distill_schedule, step): + def get_num_phases(multi_phased_distill_schedule: str, + step: int) -> int: # step-phase,step-phase multi_phases = multi_phased_distill_schedule.split(",") phase = multi_phases[-1].split("-")[-1] @@ -400,10 +400,6 @@ def get_num_phases(multi_phased_distill_schedule, step): self.training_args.pred_decay_weight, self.training_args.pred_decay_type, self.training_args.hunyuan_teacher_disable_cfg, - self.training_args.weighting_scheme, - self.training_args.logit_mean, - self.training_args.logit_std, - self.training_args.mode_scale, ) step_time = time.perf_counter() - start_time @@ -462,7 +458,7 @@ def get_num_phases(multi_phased_distill_schedule, step): self.sp_group.barrier() if self.training_args.log_validation and step % self.training_args.validation_steps == 0: - self.log_validation(self.transformer, self.training_args, step) + self._log_validation(self.transformer, self.training_args, step) # Final checkpoint if self.training_args.use_lora: @@ -476,7 +472,7 @@ def get_num_phases(multi_phased_distill_schedule, step): cleanup_dist_env_and_memory() -def main(args): +def main(args) -> None: logger.info("Starting distillation pipeline...") pipeline = WanDistillationPipeline.from_pretrained( diff --git a/fastvideo/v1/training/wan_training_pipeline.py b/fastvideo/v1/training/wan_training_pipeline.py index a879dd2d6..6712fc4b9 100644 --- a/fastvideo/v1/training/wan_training_pipeline.py +++ b/fastvideo/v1/training/wan_training_pipeline.py @@ -12,7 +12,7 @@ from fastvideo.v1.distributed import (cleanup_dist_env_and_memory, get_sp_group, get_torch_device, get_world_group) -from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs +from fastvideo.v1.fastvideo_args import FastVideoArgs, Mode, TrainingArgs from fastvideo.v1.forward_context import set_forward_context from fastvideo.v1.logger import init_logger from fastvideo.v1.models.schedulers.scheduling_flow_unipc_multistep import ( @@ -53,12 +53,12 @@ def initialize_validation_pipeline(self, training_args: TrainingArgs): logger.info("Initializing validation pipeline...") args_copy = deepcopy(training_args) - args_copy.inference_mode = True + args_copy.mode = Mode.INFERENCE args_copy.vae_config.load_encoder = False validation_pipeline = 
WanValidationPipeline.from_pretrained( training_args.model_path, args=None, - inference_mode=True, + mode=Mode.INFERENCE, loaded_modules={"transformer": self.get_module("transformer")}, tp_size=training_args.tp_size, sp_size=training_args.sp_size, From c997709f1f96b226c641891261d5afab35eef386 Mon Sep 17 00:00:00 2001 From: Jinzhe Pan Date: Wed, 11 Jun 2025 11:10:38 +0800 Subject: [PATCH 09/14] fix --- fastvideo/distill/solver.py | 5 ----- fastvideo/v1/models/loader/component_loader.py | 2 +- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/fastvideo/distill/solver.py b/fastvideo/distill/solver.py index 1082bc4ac..07c0a9765 100644 --- a/fastvideo/distill/solver.py +++ b/fastvideo/distill/solver.py @@ -248,11 +248,6 @@ def __init__(self, sigmas, timesteps=1000, euler_timesteps=50): self.sigmas = sigmas[self.euler_timesteps] self.sigmas_prev = np.asarray([sigmas[0]] + sigmas[self.euler_timesteps[:-1]].tolist()) # either use sigma0 or 0 - print(f"sigmas: {sigmas}") - print(f"euler_timesteps: {self.euler_timesteps}") - print(f"sigmas: {self.sigmas}") - print(f"sigmas_prev: {self.sigmas_prev}") - self.euler_timesteps = torch.from_numpy(self.euler_timesteps).long() self.euler_timesteps_prev = torch.from_numpy(self.euler_timesteps_prev).long() self.sigmas = torch.from_numpy(self.sigmas) diff --git a/fastvideo/v1/models/loader/component_loader.py b/fastvideo/v1/models/loader/component_loader.py index c8908420b..8d50a0a42 100644 --- a/fastvideo/v1/models/loader/component_loader.py +++ b/fastvideo/v1/models/loader/component_loader.py @@ -367,7 +367,7 @@ class TransformerLoader(ComponentLoader): def load(self, model_path: str, architecture: str, fastvideo_args: FastVideoArgs): """Load the transformer based on the model path, architecture, and inference args.""" - print(f"Loading transformer from {model_path}") + logger.info("Loading transformer from %s", model_path) config = get_diffusers_config(model=model_path) hf_config = deepcopy(config) cls_name = config.pop("_class_name") From 44dbbde6a9b4020b77cc407665b0e334bdf90881 Mon Sep 17 00:00:00 2001 From: Jinzhe Pan Date: Wed, 11 Jun 2025 11:52:08 +0800 Subject: [PATCH 10/14] pre-commit --- fastvideo/v1/dataset/parquet_datasets.py | 1 - fastvideo/v1/pipelines/composed_pipeline_base.py | 4 ---- fastvideo/v1/training/distillation_pipeline.py | 3 ++- 3 files changed, 2 insertions(+), 6 deletions(-) diff --git a/fastvideo/v1/dataset/parquet_datasets.py b/fastvideo/v1/dataset/parquet_datasets.py index 0732ef4ae..ffc89bcd0 100644 --- a/fastvideo/v1/dataset/parquet_datasets.py +++ b/fastvideo/v1/dataset/parquet_datasets.py @@ -405,7 +405,6 @@ def _process_row(self, row) -> Dict[str, Any]: model_path=VAE_PATH, vae_config=WanVAEConfig(load_encoder=False), vae_precision="fp32") - fastvideo_args.device = device vae_loader = VAELoader() vae = vae_loader.load(model_path=VAE_PATH, architecture="", diff --git a/fastvideo/v1/pipelines/composed_pipeline_base.py b/fastvideo/v1/pipelines/composed_pipeline_base.py index e2cb961aa..bd12282ef 100644 --- a/fastvideo/v1/pipelines/composed_pipeline_base.py +++ b/fastvideo/v1/pipelines/composed_pipeline_base.py @@ -19,10 +19,6 @@ from fastvideo.v1.distributed import ( maybe_init_distributed_environment_and_model_parallel) from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs -from fastvideo.v1.distributed import (init_distributed_environment, - initialize_model_parallel, - model_parallel_is_initialized) -from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs from fastvideo.v1.logger import 
init_logger from fastvideo.v1.models.loader.component_loader import PipelineComponentLoader from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch diff --git a/fastvideo/v1/training/distillation_pipeline.py b/fastvideo/v1/training/distillation_pipeline.py index 5e230cf53..c995a36db 100644 --- a/fastvideo/v1/training/distillation_pipeline.py +++ b/fastvideo/v1/training/distillation_pipeline.py @@ -18,6 +18,7 @@ from fastvideo.v1.configs.sample import SamplingParam from fastvideo.v1.dataset.parquet_datasets import ParquetVideoTextDataset from fastvideo.v1.distributed import get_sp_group +from fastvideo.v1.distributed.parallel_state import get_torch_device from fastvideo.v1.fastvideo_args import FastVideoArgs, Mode, TrainingArgs from fastvideo.v1.logger import init_logger from fastvideo.v1.pipelines import ComposedPipelineBase @@ -68,7 +69,7 @@ class DistillationPipeline(ComposedPipelineBase, ABC): def initialize_distillation_pipeline(self, fastvideo_args: TrainingArgs): logger.info("Initializing distillation pipeline...") - self.device = fastvideo_args.device + self.device = get_torch_device() self.sp_group = get_sp_group() self.world_size = self.sp_group.world_size self.rank = self.sp_group.rank From 18314c09b922b66bad83480c42aaeca1b7681b3d Mon Sep 17 00:00:00 2001 From: Jinzhe Pan Date: Wed, 11 Jun 2025 12:00:27 +0800 Subject: [PATCH 11/14] pre-commit --- fastvideo/v1/training/distillation_pipeline.py | 2 +- fastvideo/v1/training/wan_distillation_pipeline.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fastvideo/v1/training/distillation_pipeline.py b/fastvideo/v1/training/distillation_pipeline.py index c995a36db..4c3e4e07f 100644 --- a/fastvideo/v1/training/distillation_pipeline.py +++ b/fastvideo/v1/training/distillation_pipeline.py @@ -7,13 +7,13 @@ import numpy as np import torch import torchvision +import wandb from diffusers.optimization import get_scheduler from einops import rearrange from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import ShardingStrategy from torchdata.stateful_dataloader import StatefulDataLoader -import wandb from fastvideo.distill.solver import EulerSolver from fastvideo.v1.configs.sample import SamplingParam from fastvideo.v1.dataset.parquet_datasets import ParquetVideoTextDataset diff --git a/fastvideo/v1/training/wan_distillation_pipeline.py b/fastvideo/v1/training/wan_distillation_pipeline.py index 53358782d..2dbc96891 100644 --- a/fastvideo/v1/training/wan_distillation_pipeline.py +++ b/fastvideo/v1/training/wan_distillation_pipeline.py @@ -5,9 +5,9 @@ from typing import Dict import torch +import wandb from tqdm.auto import tqdm -import wandb from fastvideo.distill.solver import extract_into_tensor from fastvideo.v1.distributed import cleanup_dist_env_and_memory, get_sp_group from fastvideo.v1.fastvideo_args import FastVideoArgs, Mode, TrainingArgs From 50b0d8b0c2c125c123741abc3ed8b46c7f4f290a Mon Sep 17 00:00:00 2001 From: Peiyuan Zhang Date: Thu, 12 Jun 2025 20:19:47 +0000 Subject: [PATCH 12/14] precommit --- .../v1/pipelines/preprocess/v1_preprocess.py | 54 +++++++++---------- .../v1/training/distillation_pipeline.py | 2 +- .../v1/training/wan_distillation_pipeline.py | 2 +- 3 files changed, 29 insertions(+), 29 deletions(-) diff --git a/fastvideo/v1/pipelines/preprocess/v1_preprocess.py b/fastvideo/v1/pipelines/preprocess/v1_preprocess.py index dfad00e11..fe798befd 100644 --- a/fastvideo/v1/pipelines/preprocess/v1_preprocess.py +++ 
b/fastvideo/v1/pipelines/preprocess/v1_preprocess.py @@ -1,24 +1,24 @@ import argparse -import json -import os - -import torch -import torch.distributed as dist +from fastvideo import PipelineConfig +from fastvideo.v1.configs.models.vaes import WanVAEConfig +from fastvideo.v1.distributed import ( + get_world_size, maybe_init_distributed_environment_and_model_parallel) +from fastvideo.v1.fastvideo_args import FastVideoArgs from fastvideo.v1.logger import init_logger +from fastvideo.v1.pipelines.preprocess.preprocess_pipeline_i2v import ( + PreprocessPipeline_I2V) +from fastvideo.v1.pipelines.preprocess.preprocess_pipeline_t2v import ( + PreprocessPipeline_T2V) from fastvideo.v1.utils import maybe_download_model, shallow_asdict -from fastvideo.v1.distributed import maybe_init_distributed_environment_and_model_parallel, get_world_size -from fastvideo.v1.fastvideo_args import FastVideoArgs -from fastvideo.v1.configs.models.vaes import WanVAEConfig -from fastvideo import PipelineConfig -from fastvideo.v1.pipelines.preprocess.preprocess_pipeline_i2v import PreprocessPipeline_I2V -from fastvideo.v1.pipelines.preprocess.preprocess_pipeline_t2v import PreprocessPipeline_T2V logger = init_logger(__name__) -def main(args): + +def main(args) -> None: args.model_path = maybe_download_model(args.model_path) - maybe_init_distributed_environment_and_model_parallel(args.tp_size, args.sp_size) + maybe_init_distributed_environment_and_model_parallel( + args.tp_size, args.sp_size) pipeline_config = PipelineConfig.from_pretrained(args.model_path) kwargs = { @@ -65,18 +65,15 @@ def main(args): default=8, help="Batch size (per device) for the training dataloader.", ) - parser.add_argument( - "--samples_per_file", - type=int, - default=64 - ) - parser.add_argument( - "--flush_frequency", - type=int, - default=256, - help="how often to save to parquet files" - ) - parser.add_argument("--num_latent_t", type=int, default=28, help="Number of latent timesteps.") + parser.add_argument("--samples_per_file", type=int, default=64) + parser.add_argument("--flush_frequency", + type=int, + default=256, + help="how often to save to parquet files") + parser.add_argument("--num_latent_t", + type=int, + default=28, + help="Number of latent timesteps.") parser.add_argument("--max_height", type=int, default=480) parser.add_argument("--max_width", type=int, default=848) parser.add_argument("--video_length_tolerance_range", type=int, default=2.0) @@ -90,14 +87,17 @@ def main(args): parser.add_argument("--speed_factor", type=float, default=1.0) parser.add_argument("--drop_short_ratio", type=float, default=1.0) # text encoder & vae & diffusion model - parser.add_argument("--text_encoder_name", type=str, default="google/t5-v1_1-xxl") + parser.add_argument("--text_encoder_name", + type=str, + default="google/t5-v1_1-xxl") parser.add_argument("--cache_dir", type=str, default="./cache_dir") parser.add_argument("--cfg", type=float, default=0.0) parser.add_argument( "--output_dir", type=str, default=None, - help="The output directory where the model predictions and checkpoints will be written.", + help= + "The output directory where the model predictions and checkpoints will be written.", ) args = parser.parse_args() diff --git a/fastvideo/v1/training/distillation_pipeline.py b/fastvideo/v1/training/distillation_pipeline.py index 4c3e4e07f..c995a36db 100644 --- a/fastvideo/v1/training/distillation_pipeline.py +++ b/fastvideo/v1/training/distillation_pipeline.py @@ -7,13 +7,13 @@ import numpy as np import torch import torchvision -import 
wandb from diffusers.optimization import get_scheduler from einops import rearrange from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import ShardingStrategy from torchdata.stateful_dataloader import StatefulDataLoader +import wandb from fastvideo.distill.solver import EulerSolver from fastvideo.v1.configs.sample import SamplingParam from fastvideo.v1.dataset.parquet_datasets import ParquetVideoTextDataset diff --git a/fastvideo/v1/training/wan_distillation_pipeline.py b/fastvideo/v1/training/wan_distillation_pipeline.py index 2dbc96891..53358782d 100644 --- a/fastvideo/v1/training/wan_distillation_pipeline.py +++ b/fastvideo/v1/training/wan_distillation_pipeline.py @@ -5,9 +5,9 @@ from typing import Dict import torch -import wandb from tqdm.auto import tqdm +import wandb from fastvideo.distill.solver import extract_into_tensor from fastvideo.v1.distributed import cleanup_dist_env_and_memory, get_sp_group from fastvideo.v1.fastvideo_args import FastVideoArgs, Mode, TrainingArgs From 017e35d6701d8d963081210d3fd25463a74b42f6 Mon Sep 17 00:00:00 2001 From: Peiyuan Zhang Date: Thu, 12 Jun 2025 20:20:23 +0000 Subject: [PATCH 13/14] remove parquet_dataset --- fastvideo/v1/dataset/parquet_datasets.py | 470 ----------------------- 1 file changed, 470 deletions(-) delete mode 100644 fastvideo/v1/dataset/parquet_datasets.py diff --git a/fastvideo/v1/dataset/parquet_datasets.py b/fastvideo/v1/dataset/parquet_datasets.py deleted file mode 100644 index ffc89bcd0..000000000 --- a/fastvideo/v1/dataset/parquet_datasets.py +++ /dev/null @@ -1,470 +0,0 @@ -import argparse -import json -import os -import random -import time -from collections import defaultdict -from typing import Any, Dict, List - -import numpy as np -import pyarrow.parquet as pq -import torch -import tqdm -from einops import rearrange -from torch import distributed as dist -from torch.utils.data import Dataset -from torchdata.stateful_dataloader import StatefulDataLoader - -from fastvideo.v1.distributed import (get_sp_group, get_sp_parallel_rank, - get_sp_world_size, get_world_rank, - get_world_size) -from fastvideo.v1.logger import init_logger - -logger = init_logger(__name__) - - -class ParquetVideoTextDataset(Dataset): - """Efficient loader for video-text data from a directory of Parquet files.""" - - def __init__(self, - path: str, - batch_size, - cfg_rate: float = 0.0, - num_latent_t: int = 2, - seed: int = 0, - validation: bool = False): - super().__init__() - self.path = str(path) - self.batch_size = batch_size - self.global_rank = get_world_rank() - self.rank_in_sp_group = get_sp_parallel_rank() - self.sp_group = get_sp_group() - self.sp_world_size = get_sp_world_size() - self.world_size = get_world_size() - self.cfg_rate = cfg_rate - self.num_latent_t = num_latent_t - self.local_indices = None - self.validation = validation - - # Negative prompt caching - self.neg_metadata = None - self.cached_neg_prompt: Dict[str, Any] | None = None - - self.plan_output_dir = os.path.join( - self.path, - f"data_plan_world_size_{self.world_size}_sp_size_{self.sp_world_size}.json" - ) - - # group_ranks: a list of lists - # len(group_ranks) = self.world_size - # len(group_ranks[i]) = self.sp_world_size - # group_ranks[i] represents the ranks of the SP group for the i-th GPU - # For example, if self.world_size = 4, self.sp_world_size = 2, then - # group_ranks = [[0, 1], [0, 1], [2, 3], [2, 3]] - sp_group_ranks = get_sp_group().ranks - group_ranks: List[List] = [[] for _ in range(self.world_size)] - 
dist.all_gather_object(group_ranks, sp_group_ranks) - - if self.global_rank == 0: - # If a plan already exists, then skip creating a new plan - # This will be useful when resume training - if os.path.exists(self.plan_output_dir): - logger.info("Using existing plan from %s", self.plan_output_dir) - else: - logger.info("Creating new plan for %s", self.plan_output_dir) - metadatas = [] - for root, _, files in os.walk(self.path): - for file in sorted(files): - if file.endswith('.parquet'): - file_path = os.path.join(root, file) - num_rows = pq.ParquetFile( - file_path).metadata.num_rows - for row_idx in range(num_rows): - metadatas.append((file_path, row_idx)) - - # the negative prompt is always the first row in the first - # parquet file - if validation: - self.neg_metadata = metadatas[0] - metadatas = metadatas[1:] - - # Generate the plan that distribute rows among workers - random.seed(seed) - random.shuffle(metadatas) - - # Get all sp groups - # e.g. if num_gpus = 4, sp_size = 2 - # group_ranks = [(0, 1), (0, 1), (2, 3), (2, 3)] - # We will assign the same batches of data to ranks in the same sp group, and we'll assign different batches to ranks in different sp groups - # e.g. plan = {0: [row 1, row 4], 1: [row 1, row 4], 2: [row 2, row 3], 3: [row 2, row 3]} - group_ranks_list: List[Any] = list( - set(tuple(r) for r in group_ranks)) - num_sp_groups = len(group_ranks_list) - plan = defaultdict(list) - for idx, metadata in enumerate(metadatas): - sp_group_idx = idx % num_sp_groups - for global_rank in group_ranks_list[sp_group_idx]: - plan[global_rank].append(metadata) - - if validation: - assert self.neg_metadata is not None - plan["negative_prompt"] = [self.neg_metadata] - with open(self.plan_output_dir, "w") as f: - json.dump(plan, f) - else: - pass - dist.barrier() - if validation: - with open(self.plan_output_dir) as f: - plan = json.load(f) - self.neg_metadata = plan["negative_prompt"][0] - - # Add unconditional embeddings for distillation (like in LatentDataset) - self.uncond_prompt_embed = torch.zeros(512, 4096).to(torch.float32) - self.uncond_prompt_mask = torch.zeros(1, 512).bool() - - def _load_and_cache_negative_prompt(self) -> None: - """Load and cache the negative prompt. Only rank 0 in each SP group should call this.""" - if not self.validation or self.neg_metadata is None: - return - - if self.cached_neg_prompt is not None: - return - - # Only rank 0 in each SP group should read the negative prompt - try: - file_path, row_idx = self.neg_metadata - parquet_file = pq.ParquetFile(file_path) - - # Since negative prompt is always the first row (row_idx = 0), - # it's always in the first row group - row_group_index = 0 - local_index = row_idx # This will be 0 for the negative prompt - - row_group = parquet_file.read_row_group(row_group_index).to_pydict() - row_dict = {k: v[local_index] for k, v in row_group.items()} - del row_group - - # Process the negative prompt row - self.cached_neg_prompt = self._process_row(row_dict) - - except Exception as e: - logger.error("Failed to load negative prompt: %s", e) - self.cached_neg_prompt = None - - def get_validation_negative_prompt( - self - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, Any]]: - """ - Get the negative prompt for validation. - This method ensures the negative prompt is loaded and cached properly. - Returns the processed negative prompt data (latents, embeddings, masks, info). 
- """ - if not self.validation: - raise ValueError( - "get_validation_negative_prompt() can only be called in validation mode" - ) - - # Load and cache if needed (only rank 0 in SP group will actually load) - if self.cached_neg_prompt is None: - self._load_and_cache_negative_prompt() - - if self.cached_neg_prompt is None: - raise RuntimeError( - f"Rank {self.global_rank} (SP rank {self.rank_in_sp_group}): Could not retrieve negative prompt data" - ) - - # Extract the components - lat, emb, mask, info = (self.cached_neg_prompt["latents"], - self.cached_neg_prompt["embeddings"], - self.cached_neg_prompt["masks"], - self.cached_neg_prompt["info"]) - - # Apply the same processing as in __getitem__ - if lat.numel() == 0: # Validation parquet - return lat, emb, mask, info - else: - lat = lat[:, -self.num_latent_t:] - if self.sp_world_size > 1: - lat = rearrange(lat, - "t (n s) h w -> t n s h w", - n=self.sp_world_size).contiguous() - lat = lat[:, self.rank_in_sp_group, :, :, :] - return lat, emb, mask, info - - def __len__(self): - if self.local_indices is None: - try: - with open(self.plan_output_dir) as f: - plan = json.load(f) - self.local_indices = plan[str(self.global_rank)] - except Exception as err: - raise Exception( - "The data plan hasn't been created yet") from err - assert self.local_indices is not None - return len(self.local_indices) - - def __getitem__(self, idx): - if self.local_indices is None: - try: - with open(self.plan_output_dir) as f: - plan = json.load(f) - self.local_indices = plan[self.global_rank] - except Exception as err: - raise Exception( - "The data plan hasn't been created yet") from err - assert self.local_indices is not None - file_path, row_idx = self.local_indices[idx] - parquet_file = pq.ParquetFile(file_path) - - # Calculate the row group to read into memory and the local idx - # This way we can avoid reading in the entire parquet file - cumulative = 0 - for i in range(parquet_file.num_row_groups): - num_rows = parquet_file.metadata.row_group(i).num_rows - if cumulative + num_rows > row_idx: - row_group_index = i - local_index = row_idx - cumulative - break - cumulative += num_rows - - row_group = parquet_file.read_row_group(row_group_index).to_pydict() - row_dict = {k: v[local_index] for k, v in row_group.items()} - del row_group - - processed = self._process_row(row_dict) - lat, emb, mask, info = processed["latents"], processed[ - "embeddings"], processed["masks"], processed["info"] - if lat.numel() == 0: # Validation parquet - return lat, emb, mask, info - else: - lat = lat[:, -self.num_latent_t:] - if self.sp_world_size > 1: - lat = rearrange(lat, - "t (n s) h w -> t n s h w", - n=self.sp_world_size).contiguous() - lat = lat[:, self.rank_in_sp_group, :, :, :] - return lat, emb, mask, info - - def _process_row(self, row) -> Dict[str, Any]: - """Process a PyArrow batch into tensors.""" - - vae_latent_bytes = row["vae_latent_bytes"] - vae_latent_shape = row["vae_latent_shape"] - text_embedding_bytes = row["text_embedding_bytes"] - text_embedding_shape = row["text_embedding_shape"] - text_attention_mask_bytes = row["text_attention_mask_bytes"] - text_attention_mask_shape = row["text_attention_mask_shape"] - - # Process latent - if not vae_latent_shape: # No VAE latent is stored. 
Split is validation - lat = np.array([]) - else: - lat = np.frombuffer(vae_latent_bytes, - dtype=np.float32).reshape(vae_latent_shape) - # Make array writable - lat = np.copy(lat) - - if random.random() < self.cfg_rate: - emb = np.zeros((512, 4096), dtype=np.float32) - else: - emb = np.frombuffer(text_embedding_bytes, - dtype=np.float32).reshape(text_embedding_shape) - # Make array writable - emb = np.copy(emb) - if emb.shape[0] < 512: - padded_emb = np.zeros((512, emb.shape[1]), dtype=np.float32) - padded_emb[:emb.shape[0], :] = emb - emb = padded_emb - elif emb.shape[0] > 512: - emb = emb[:512, :] - - # Process mask - if len(text_attention_mask_bytes) > 0 and len( - text_attention_mask_shape) > 0: - msk = np.frombuffer(text_attention_mask_bytes, - dtype=np.uint8).astype(np.bool_) - msk = msk.reshape(1, -1) - # Make array writable - msk = np.copy(msk) - if msk.shape[1] < 512: - padded_msk = np.zeros((1, 512), dtype=np.bool_) - padded_msk[:, :msk.shape[1]] = msk - msk = padded_msk - elif msk.shape[1] > 512: - msk = msk[:, :512] - else: - msk = np.ones((1, 512), dtype=np.bool_) - - # Collect metadata - info = { - "width": row["width"], - "height": row["height"], - "num_frames": row["num_frames"], - "duration_sec": row["duration_sec"], - "fps": row["fps"], - "file_name": row["file_name"], - "caption": row["caption"], - } - - return { - "latents": torch.from_numpy(lat), - "embeddings": torch.from_numpy(emb), - "masks": torch.from_numpy(msk), - "info": info - } - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description='Benchmark Parquet dataset loading speed') - parser.add_argument('--path', - type=str, - default="your/dataset/path", - help='Path to Parquet dataset') - parser.add_argument('--batch_size', - type=int, - default=4, - help='Batch size for DataLoader') - parser.add_argument('--num_batches', - type=int, - default=100, - help='Number of batches to benchmark') - parser.add_argument('--vae_debug', action="store_true") - args = parser.parse_args() - - # Initialize distributed training - local_rank = int(os.environ.get("LOCAL_RANK", 0)) - world_size = int(os.environ.get("WORLD_SIZE", 1)) - rank = int(os.environ.get("RANK", 0)) - - # Initialize CUDA device first - if torch.cuda.is_available(): - torch.cuda.set_device(local_rank) - device = torch.device(f"cuda:{local_rank}") - else: - device = torch.device("cpu") - - # Initialize distributed training - if world_size > 1: - dist.init_process_group(backend="nccl", - init_method="env://", - world_size=world_size, - rank=rank) - print( - f"Initialized process: rank={rank}, local_rank={local_rank}, world_size={world_size}, device={device}" - ) - - # Create dataset - dataset = ParquetVideoTextDataset( - args.path, - batch_size=args.batch_size, - ) - - # Create DataLoader with proper settings - dataloader = StatefulDataLoader( - dataset, - batch_size=args.batch_size, - num_workers=1, # Reduce number of workers to avoid memory issues - prefetch_factor=2, - shuffle=False, - pin_memory=True, - drop_last=True) - - # Example of how to load dataloader state - # if os.path.exists("/workspace/FastVideo/dataloader_state.pt"): - # dataloader_state = torch.load("/workspace/FastVideo/dataloader_state.pt") - # dataloader.load_state_dict(dataloader_state[rank]) - - # Warm-up with synchronization - if rank == 0: - print("Warming up...") - for i, (latents, embeddings, masks, infos) in enumerate(dataloader): - # Example of how to save dataloader state - # if i == 30: - # dist.barrier() - # local_data = {rank: dataloader.state_dict()} - # 
gathered_data = [None] * world_size - # dist.all_gather_object(gathered_data, local_data) - # if rank == 0: - # global_state_dict = {} - # for d in gathered_data: - # global_state_dict.update(d) - # torch.save(global_state_dict, "dataloader_state.pt") - assert torch.sum(masks[0]).item() == torch.count_nonzero( - embeddings[0]).item() // 4096 - if args.vae_debug: - from diffusers.utils import export_to_video - from diffusers.video_processor import VideoProcessor - - from fastvideo.v1.configs.models.vaes import WanVAEConfig - from fastvideo.v1.fastvideo_args import FastVideoArgs - from fastvideo.v1.models.loader.component_loader import VAELoader - VAE_PATH = "/workspace/data/Wan-AI/Wan2.1-T2V-1.3B-Diffusers/vae" - fastvideo_args = FastVideoArgs( - model_path=VAE_PATH, - vae_config=WanVAEConfig(load_encoder=False), - vae_precision="fp32") - vae_loader = VAELoader() - vae = vae_loader.load(model_path=VAE_PATH, - architecture="", - fastvideo_args=fastvideo_args) - - videoprocessor = VideoProcessor(vae_scale_factor=8) - - with torch.inference_mode(): - video = vae.decode(latents[0].unsqueeze(0).to(device)) - video = videoprocessor.postprocess_video(video) - video_path = os.path.join("/workspace/FastVideo/debug_videos", - infos["caption"][0][:50] + ".mp4") - export_to_video(video[0], video_path, fps=16) - - # Move data to device - # latents = latents.to(device) - # embeddings = embeddings.to(device) - - if world_size > 1: - dist.barrier() - - # Benchmark - if rank == 0: - print(f"Benchmarking with batch_size={args.batch_size}") - start_time = time.time() - total_samples = 0 - for i, (latents, embeddings, masks, - infos) in enumerate(tqdm.tqdm(dataloader, total=args.num_batches)): - if i >= args.num_batches: - break - - # Move data to device - latents = latents.to(device) - embeddings = embeddings.to(device) - - # Calculate actual batch size - batch_size = latents.size(0) - total_samples += batch_size - - # Print progress only from rank 0 - if rank == 0 and (i + 1) % 10 == 0: - elapsed = time.time() - start_time - samples_per_sec = total_samples / elapsed - print( - f"Batch {i+1}/{args.num_batches}, Speed: {samples_per_sec:.2f} samples/sec" - ) - - # Final statistics - if world_size > 1: - dist.barrier() - - if rank == 0: - elapsed = time.time() - start_time - samples_per_sec = total_samples / elapsed - - print("\nBenchmark Results:") - print(f"Total time: {elapsed:.2f} seconds") - print(f"Total samples: {total_samples}") - print(f"Average speed: {samples_per_sec:.2f} samples/sec") - print(f"Time per batch: {elapsed/args.num_batches*1000:.2f} ms") - - if world_size > 1: - dist.destroy_process_group() From f76a3032046249129de1ef0a8d89677fc449ac14 Mon Sep 17 00:00:00 2001 From: Peiyuan Zhang Date: Thu, 12 Jun 2025 20:30:32 +0000 Subject: [PATCH 14/14] remove device str --- fastvideo/v1/fastvideo_args.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/fastvideo/v1/fastvideo_args.py b/fastvideo/v1/fastvideo_args.py index 239fe05c5..ca81822ce 100644 --- a/fastvideo/v1/fastvideo_args.py +++ b/fastvideo/v1/fastvideo_args.py @@ -416,10 +416,6 @@ def from_cli_args(cls, args: argparse.Namespace) -> "FastVideoArgs": kwargs[attr] = Mode(mode_value) else: kwargs[attr] = Mode.INFERENCE - elif attr == 'device_str': - kwargs[attr] = getattr( - args, 'device', - None) or "cuda" if torch.cuda.is_available() else "cpu" # Use getattr with default value from the dataclass for potentially missing attributes else: default_value = getattr(cls, attr, None) @@ -646,11 +642,6 @@ def from_cli_args(cls, args: 
argparse.Namespace) -> "TrainingArgs": else: kwargs[ attr] = Mode.TRAINING # Default to training for TrainingArgs - elif attr == 'device_str': - kwargs[attr] = getattr( - args, 'device', - None) or "cuda" if torch.cuda.is_available() else "cpu" - # Use getattr with default value from the dataclass for potentially missing attributes else: default_value = getattr(cls, attr, None) if getattr(args, attr, default_value) is not None:
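
For reference, a minimal sketch of the mode handling that the hunks above converge on. This is illustrative only and not code from the series: the Mode enum below is a local stand-in listing just the values named in these diffs, and resolve_mode restates the coercion performed in ComposedPipelineBase.from_pretrained before dispatching to FastVideoArgs or TrainingArgs.

from enum import Enum
from types import SimpleNamespace


class Mode(str, Enum):
    # Stand-in for the enum imported as fastvideo.v1.fastvideo_args.Mode above;
    # only the values that appear in these hunks are listed here.
    INFERENCE = "inference"
    TRAINING = "training"
    DISTILL = "distill"


def resolve_mode(args) -> str:
    # Mirrors the normalization in from_pretrained: accept either a plain string
    # or an enum member on args.mode, and fall back to "inference" when no args
    # object is provided.
    mode = getattr(args, "mode", "inference") if args is not None else "inference"
    if hasattr(mode, "value"):
        mode = mode.value
    return str(mode)


# "inference" keeps FastVideoArgs; "training" and "distill" are routed to
# TrainingArgs, matching the branch in composed_pipeline_base.py.
assert resolve_mode(None) == "inference"
assert resolve_mode(SimpleNamespace(mode=Mode.DISTILL)) == "distill"
assert resolve_mode(SimpleNamespace(mode="training")) == "training"

Normalizing to a plain string at this boundary lets callers pass either form from the CLI or from a copied args object, while from_cli_args can keep defaulting TrainingArgs to Mode.TRAINING as shown in the final hunk.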