diff --git a/fastvideo/distill/solver.py b/fastvideo/distill/solver.py index 7f8fec847..07c0a9765 100644 --- a/fastvideo/distill/solver.py +++ b/fastvideo/distill/solver.py @@ -7,11 +7,27 @@ from diffusers.schedulers.scheduling_utils import SchedulerMixin from diffusers.utils import BaseOutput, logging -from fastvideo.models.mochi_hf.pipeline_mochi import linear_quadratic_schedule - logger = logging.get_logger(__name__) # pylint: disable=invalid-name +# from: https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77 +def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None): + if linear_steps is None: + linear_steps = num_steps // 2 + linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)] + threshold_noise_step_diff = linear_steps - threshold_noise * num_steps + quadratic_steps = num_steps - linear_steps + quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2) + linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2) + const = quadratic_coef * (linear_steps**2) + quadratic_sigma_schedule = [ + quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps) + ] + sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + sigma_schedule = [1.0 - x for x in sigma_schedule] + return sigma_schedule + + @dataclass class PCMFMSchedulerOutput(BaseOutput): prev_sample: torch.FloatTensor @@ -226,12 +242,12 @@ class EulerSolver: def __init__(self, sigmas, timesteps=1000, euler_timesteps=50): self.step_ratio = timesteps // euler_timesteps + self.euler_timesteps = (np.arange(1, euler_timesteps + 1) * self.step_ratio).round().astype(np.int64) - 1 self.euler_timesteps_prev = np.asarray([0] + self.euler_timesteps[:-1].tolist()) self.sigmas = sigmas[self.euler_timesteps] self.sigmas_prev = np.asarray([sigmas[0]] + sigmas[self.euler_timesteps[:-1]].tolist()) # either use sigma0 or 0 - self.euler_timesteps = torch.from_numpy(self.euler_timesteps).long() self.euler_timesteps_prev = torch.from_numpy(self.euler_timesteps_prev).long() self.sigmas = torch.from_numpy(self.sigmas) diff --git a/fastvideo/v1/fastvideo_args.py b/fastvideo/v1/fastvideo_args.py index be491dbbd..722ef1d3c 100644 --- a/fastvideo/v1/fastvideo_args.py +++ b/fastvideo/v1/fastvideo_args.py @@ -6,8 +6,11 @@ import dataclasses from contextlib import contextmanager from dataclasses import field -from typing import Any, Dict, List, Optional +from enum import Enum +from typing import Any, Callable, List, Optional, Tuple, Dict +import torch +from fastvideo.v1.configs.models import DiTConfig, EncoderConfig, VAEConfig from fastvideo.v1.configs.pipelines.base import PipelineConfig from fastvideo.v1.logger import init_logger from fastvideo.v1.utils import FlexibleArgumentParser, StoreBoolean @@ -15,6 +18,21 @@ logger = init_logger(__name__) +class Mode(Enum): + """Enumeration for FastVideo execution modes.""" + INFERENCE = "inference" + TRAINING = "training" + DISTILL = "distill" + + +def preprocess_text(prompt: str) -> str: + return prompt + + +def postprocess_text(output: Any) -> Any: + raise NotImplementedError + + def clean_cli_args(args: argparse.Namespace) -> Dict[str, Any]: """ Clean the arguments by removing the ones that not explicitly provided by the user. 
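Editorial aside (not part of the patch): the Mode enum added above replaces the old inference_mode boolean, and from_cli_args converts the CLI string back into the enum via Mode(mode_value). A minimal sketch of that round trip, using the same enum values as the patch; the helper name demo_mode_roundtrip is illustrative only:

from enum import Enum

class Mode(Enum):
    INFERENCE = "inference"
    TRAINING = "training"
    DISTILL = "distill"

def demo_mode_roundtrip(cli_value: str) -> Mode:
    mode = Mode(cli_value)          # raises ValueError for an unrecognized string
    assert mode.value == cli_value  # .value is what gets written back to the CLI default
    return mode

print(demo_mode_roundtrip("distill"))   # Mode.DISTILL
print(demo_mode_roundtrip("training"))  # Mode.TRAINING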
@@ -40,7 +58,7 @@ class FastVideoArgs: # Distributed executor backend distributed_executor_backend: str = "mp" - inference_mode: bool = True # if False == training mode + mode: Mode = Mode.INFERENCE # HuggingFace specific parameters trust_remote_code: bool = False @@ -73,7 +91,15 @@ class FastVideoArgs: @property def training_mode(self) -> bool: - return not self.inference_mode + return self.mode == Mode.TRAINING + + @property + def distill_mode(self) -> bool: + return self.mode == Mode.DISTILL + + @property + def inference_mode(self) -> bool: + return self.mode == Mode.INFERENCE def __post_init__(self): self.check_fastvideo_args() @@ -103,10 +129,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: ) parser.add_argument( - "--inference-mode", - action=StoreBoolean, - default=FastVideoArgs.inference_mode, - help="Whether to use inference mode", + "--mode", + type=str, + default=FastVideoArgs.mode.value, + choices=[mode.value for mode in Mode], + help="The mode to use", ) # HuggingFace specific parameters @@ -238,6 +265,16 @@ def from_cli_args(cls, args: argparse.Namespace) -> "FastVideoArgs": pipeline_config = PipelineConfig.from_kwargs(provided_args) kwargs[attr] = pipeline_config # Use getattr with default value from the dataclass for potentially missing attributes + elif attr == 'mode': + # Convert string mode to Mode enum + mode_value = getattr(args, attr, None) + if mode_value: + if isinstance(mode_value, Mode): + kwargs[attr] = mode_value + else: + kwargs[attr] = Mode(mode_value) + else: + kwargs[attr] = Mode.INFERENCE else: default_value = getattr(cls, attr, None) value = getattr(args, attr, default_value) @@ -419,6 +456,8 @@ class TrainingArgs(FastVideoArgs): pred_decay_type: str = "" hunyuan_teacher_disable_cfg: bool = False + use_lora: bool = False + # master_weight_type master_weight_type: str = "" @@ -435,6 +474,17 @@ def from_cli_args(cls, args: argparse.Namespace) -> "TrainingArgs": pipeline_config = PipelineConfig.from_kwargs(provided_args) kwargs[attr] = pipeline_config # Use getattr with default value from the dataclass for potentially missing attributes + elif attr == 'mode': + # Convert string mode to Mode enum + mode_value = getattr(args, attr, None) + if mode_value: + if isinstance(mode_value, Mode): + kwargs[attr] = mode_value + else: + kwargs[attr] = Mode(mode_value) + else: + kwargs[ + attr] = Mode.TRAINING # Default to training for TrainingArgs else: default_value = getattr(cls, attr, None) value = getattr(args, attr, default_value) diff --git a/fastvideo/v1/models/loader/component_loader.py b/fastvideo/v1/models/loader/component_loader.py index 270bc8387..57d3c24f1 100644 --- a/fastvideo/v1/models/loader/component_loader.py +++ b/fastvideo/v1/models/loader/component_loader.py @@ -368,6 +368,7 @@ class TransformerLoader(ComponentLoader): def load(self, model_path: str, architecture: str, fastvideo_args: FastVideoArgs): """Load the transformer based on the model path, architecture, and inference args.""" + logger.info("Loading transformer from %s", model_path) config = get_diffusers_config(model=model_path) hf_config = deepcopy(config) cls_name = config.pop("_class_name") diff --git a/fastvideo/v1/pipelines/composed_pipeline_base.py b/fastvideo/v1/pipelines/composed_pipeline_base.py index e9997b1b8..8fa95d73e 100644 --- a/fastvideo/v1/pipelines/composed_pipeline_base.py +++ b/fastvideo/v1/pipelines/composed_pipeline_base.py @@ -8,6 +8,8 @@ import argparse import os from abc import ABC, abstractmethod +from copy import deepcopy +from enum 
import Enum from typing import Any, Dict, List, Optional, Union, cast import torch @@ -53,7 +55,15 @@ def __init__(self, """ self.fastvideo_args = fastvideo_args - self.model_path: str = model_path + if fastvideo_args.training_mode or fastvideo_args.distill_mode: + assert isinstance(fastvideo_args, TrainingArgs) + self.training_args = fastvideo_args + assert self.training_args is not None + else: + self.fastvideo_args = fastvideo_args + assert self.fastvideo_args is not None + + self.model_path = model_path self._stages: List[PipelineStage] = [] self._stage_name_mapping: Dict[str, PipelineStage] = {} @@ -79,9 +89,15 @@ def __init__(self, self.initialize_validation_pipeline(self.training_args) self.initialize_training_pipeline(self.training_args) + if fastvideo_args.distill_mode: + assert self.training_args is not None + if self.training_args.log_validation: + self.initialize_validation_pipeline(self.training_args) + self.initialize_distillation_pipeline(self.training_args) + self.initialize_pipeline(fastvideo_args) - if not fastvideo_args.training_mode: + if fastvideo_args.inference_mode: logger.info("Creating pipeline stages...") self.create_pipeline_stages(fastvideo_args) @@ -94,6 +110,10 @@ def initialize_validation_pipeline(self, training_args: TrainingArgs): "if log_validation is True, the pipeline must implement this method" ) + def initialize_distillation_pipeline(self, training_args: TrainingArgs): + raise NotImplementedError( + "if distill_mode is True, the pipeline must implement this method") + @classmethod def from_pretrained(cls, model_path: str, @@ -112,11 +132,18 @@ def from_pretrained(cls, loaded_modules: Optional[Dict[str, torch.nn.Module]] = None, If provided, loaded_modules will be used instead of loading from config/pretrained weights. """ - if args is None or args.inference_mode: - + + # Handle both string mode and Mode enum values + mode_str: str | Enum = getattr( + args, 'mode', "inference") if args is not None else "inference" + if hasattr(mode_str, 'value'): + mode_str = mode_str.value + mode_str = str(mode_str) + + if mode_str == "inference": kwargs['model_path'] = model_path fastvideo_args = FastVideoArgs.from_kwargs(kwargs) - else: + elif mode_str == "training" or mode_str == "distill": assert args is not None, "args must be provided for training mode" fastvideo_args = TrainingArgs.from_cli_args(args) # TODO(will): fix this so that its not so ugly @@ -125,8 +152,8 @@ def from_pretrained(cls, setattr(fastvideo_args, key, value) fastvideo_args.use_cpu_offload = False - # make sure we are in training mode - fastvideo_args.inference_mode = False + # make sure we are in training mode - note: inference_mode is read-only, + # so we don't set it directly here as it's determined by the mode # we hijack the precision to be the master weight type so that the # model is loaded with the correct precision. 
Subsequently we will # use FSDP2's MixedPrecisionPolicy to set the precision for the @@ -134,6 +161,10 @@ def from_pretrained(cls, # fastvideo_args.precision = fastvideo_args.master_weight_type assert fastvideo_args.pipeline_config.dit_precision == 'fp32', 'only fp32 is supported for training' # assert fastvideo_args.precision == 'fp32', 'only fp32 is supported for training' + else: + raise ValueError(f"Invalid mode: {mode_str}") + + fastvideo_args.check_fastvideo_args() logger.info("fastvideo_args in from_pretrained: %s", fastvideo_args) diff --git a/fastvideo/v1/training/distillation_pipeline.py b/fastvideo/v1/training/distillation_pipeline.py new file mode 100644 index 000000000..c995a36db --- /dev/null +++ b/fastvideo/v1/training/distillation_pipeline.py @@ -0,0 +1,335 @@ +import gc +import os +from abc import ABC, abstractmethod +from typing import List, Optional + +import imageio +import numpy as np +import torch +import torchvision +from diffusers.optimization import get_scheduler +from einops import rearrange +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import ShardingStrategy +from torchdata.stateful_dataloader import StatefulDataLoader + +import wandb +from fastvideo.distill.solver import EulerSolver +from fastvideo.v1.configs.sample import SamplingParam +from fastvideo.v1.dataset.parquet_datasets import ParquetVideoTextDataset +from fastvideo.v1.distributed import get_sp_group +from fastvideo.v1.distributed.parallel_state import get_torch_device +from fastvideo.v1.fastvideo_args import FastVideoArgs, Mode, TrainingArgs +from fastvideo.v1.logger import init_logger +from fastvideo.v1.pipelines import ComposedPipelineBase +from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch + +logger = init_logger(__name__) + + +# from: https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77 +def linear_quadratic_schedule( + num_steps: int, + threshold_noise: float, + linear_steps: Optional[int] = None) -> List[float]: + if linear_steps is None: + linear_steps = num_steps // 2 + linear_sigma_schedule = [ + i * threshold_noise / linear_steps for i in range(linear_steps) + ] + threshold_noise_step_diff = linear_steps - threshold_noise * num_steps + quadratic_steps = num_steps - linear_steps + quadratic_coef = threshold_noise_step_diff / (linear_steps * + quadratic_steps**2) + linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / ( + quadratic_steps**2) + const = quadratic_coef * (linear_steps**2) + quadratic_sigma_schedule = [ + quadratic_coef * (i**2) + linear_coef * i + const + for i in range(linear_steps, num_steps) + ] + sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + sigma_schedule = [1.0 - x for x in sigma_schedule] + return sigma_schedule + + +def reshard_fsdp(model: torch.nn.Module) -> None: + """Reshard FSDP model for EMA updates.""" + for m in FSDP.fsdp_modules(model): + if m._has_params and m.sharding_strategy is not ShardingStrategy.NO_SHARD: + torch.distributed.fsdp._runtime_utils._reshard(m, m._handle, True) + + +class DistillationPipeline(ComposedPipelineBase, ABC): + """ + A pipeline for distillation training. All distillation pipelines should inherit from this class. 
+ """ + _required_config_modules = ["scheduler", "transformer"] + validation_pipeline: ComposedPipelineBase + + def initialize_distillation_pipeline(self, fastvideo_args: TrainingArgs): + logger.info("Initializing distillation pipeline...") + self.device = get_torch_device() + self.sp_group = get_sp_group() + self.world_size = self.sp_group.world_size + self.rank = self.sp_group.rank + self.local_rank = self.sp_group.local_rank + + # Initialize student model + self.transformer = self.get_module("transformer") + assert self.transformer is not None + + self.transformer.requires_grad_(True) + self.transformer.train() + + # Initialize teacher model without deepcopy to avoid FSDP issues + logger.info("Creating teacher model...") + from fastvideo.v1.models.loader.component_loader import ( + TransformerLoader) + teacher_loader = TransformerLoader() + transformer_path = os.path.join(self.model_path, "transformer") + self.teacher_transformer = teacher_loader.load(transformer_path, "", + fastvideo_args) + self.teacher_transformer.requires_grad_(False) + self.teacher_transformer.eval() + logger.info("Teacher model initialized") + + # Initialize EMA model if needed + if fastvideo_args.use_ema: + logger.info("Creating EMA model...") + ema_loader = TransformerLoader() + self.ema_transformer = ema_loader.load(transformer_path, "", + fastvideo_args) + self.ema_transformer.requires_grad_(False) + self.ema_transformer.eval() + logger.info("EMA model initialized") + else: + self.ema_transformer = None + + noise_scheduler = self.get_module("scheduler") + assert noise_scheduler is not None + + # Initialize solver for distillation + sigmas: torch.Tensor | List[float] = [] + if fastvideo_args.scheduler_type == "pcm_linear_quadratic": + linear_steps = int(noise_scheduler.config.num_train_timesteps * + fastvideo_args.linear_range) + sigmas = linear_quadratic_schedule( + noise_scheduler.config.num_train_timesteps, + fastvideo_args.linear_quadratic_threshold, + linear_steps, + ) + else: + sigmas = noise_scheduler.sigmas + + if isinstance(sigmas, list): + sigmas = torch.tensor(sigmas).to(dtype=torch.float32) + + self.solver = EulerSolver( + sigmas.numpy(), + noise_scheduler.config.num_train_timesteps, + euler_timesteps=fastvideo_args.num_euler_timesteps, + ) + self.solver.to(self.device) + + # Setup optimizer + params_to_optimize = self.transformer.parameters() + params_to_optimize = list( + filter(lambda p: p.requires_grad, params_to_optimize)) + + optimizer = torch.optim.AdamW( + params_to_optimize, + lr=fastvideo_args.learning_rate, + betas=(0.9, 0.999), + weight_decay=fastvideo_args.weight_decay, + eps=1e-8, + ) + + init_steps = 0 + logger.info("optimizer: %s", optimizer) + + # Setup lr scheduler + lr_scheduler = get_scheduler( + fastvideo_args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=fastvideo_args.lr_warmup_steps * self.world_size, + num_training_steps=fastvideo_args.max_train_steps * self.world_size, + num_cycles=fastvideo_args.lr_num_cycles, + power=fastvideo_args.lr_power, + last_epoch=init_steps - 1, + ) + + # Setup dataset + train_dataset = ParquetVideoTextDataset( + fastvideo_args.data_path, + batch_size=fastvideo_args.train_batch_size, + cfg_rate=fastvideo_args.cfg, + num_latent_t=fastvideo_args.num_latent_t) + + train_dataloader = StatefulDataLoader( + train_dataset, + batch_size=fastvideo_args.train_batch_size, + num_workers=fastvideo_args.dataloader_num_workers, + prefetch_factor=2, + shuffle=False, + pin_memory=True, + drop_last=True) + + self.lr_scheduler = lr_scheduler + 
self.train_dataset = train_dataset + self.train_dataloader = train_dataloader + self.init_steps = init_steps + self.optimizer = optimizer + self.noise_scheduler = noise_scheduler + + # Get unconditional embeddings + self.uncond_prompt_embed = torch.zeros(512, 4096).to(torch.float32) + self.uncond_prompt_mask = torch.zeros(1, 512).bool() + # self.uncond_prompt_embed = train_dataset.uncond_prompt_embed + # self.uncond_prompt_mask = train_dataset.uncond_prompt_mask + + if self.rank <= 0: + project = fastvideo_args.tracker_project_name or "fastvideo" + wandb.init(project=project, config=fastvideo_args) + + @abstractmethod + def initialize_validation_pipeline(self, fastvideo_args: FastVideoArgs): + raise NotImplementedError( + "Distillation pipelines must implement this method") + + @abstractmethod + def distill_one_step(self, transformer, model_type, teacher_transformer, + ema_transformer, optimizer, lr_scheduler, loader_iter, + noise_scheduler, solver, noise_random_generator, + gradient_accumulation_steps, sp_size, max_grad_norm, + uncond_prompt_embed, uncond_prompt_mask, + num_euler_timesteps, multiphase, not_apply_cfg_solver, + distill_cfg, ema_decay, pred_decay_weight, + pred_decay_type, hunyuan_teacher_disable_cfg): + """ + Distill one step of the model. + """ + raise NotImplementedError( + "Distillation pipeline must implement this method") + + @torch.no_grad() + def _log_validation(self, transformer, fastvideo_args, global_step): + """Log validation results during training.""" + fastvideo_args.mode = Mode.INFERENCE + fastvideo_args.use_cpu_offload = False + if not fastvideo_args.log_validation: + return + if self.validation_pipeline is None: + raise ValueError("Validation pipeline is not set") + + # Create sampling parameters if not provided + sampling_param = SamplingParam.from_pretrained( + fastvideo_args.model_path) + + # Prepare validation prompts + validation_dataset = ParquetVideoTextDataset( + fastvideo_args.validation_prompt_dir, + batch_size=1, + cfg_rate=fastvideo_args.cfg, + num_latent_t=fastvideo_args.num_latent_t, + validation=True) + + validation_dataloader = StatefulDataLoader(validation_dataset, + batch_size=1, + num_workers=1, + prefetch_factor=2, + shuffle=False, + pin_memory=True, + drop_last=False) + + transformer.eval() + + # Process validation prompts + videos = [] + captions = [] + for _, embeddings, masks, infos in validation_dataloader: + logger.info("infos: %s", infos) + caption = infos['caption'] + captions.append(caption) + prompt_embeds = embeddings.to(fastvideo_args.device) + prompt_attention_mask = masks.to(fastvideo_args.device) + + # Calculate sizes + latents_size = [(sampling_param.num_frames - 1) // 4 + 1, + sampling_param.height // 8, + sampling_param.width // 8] + n_tokens = latents_size[0] * latents_size[1] * latents_size[2] + + # Prepare batch for validation + batch = ForwardBatch( + data_type="video", + latents=None, + prompt_embeds=[prompt_embeds], + prompt_attention_mask=[prompt_attention_mask], + height=fastvideo_args.num_height, + width=fastvideo_args.num_width, + num_frames=fastvideo_args.num_frames, + num_inference_steps=10, + guidance_scale=1, + n_tokens=n_tokens, + do_classifier_free_guidance=False, + eta=0.0, + extra={}, + ) + + # Run validation inference + with torch.inference_mode(): + output_batch = self.validation_pipeline.forward( + batch, fastvideo_args) + samples = output_batch.output + + # Process outputs + video = rearrange(samples, "b c t h w -> t b c h w") + frames = [] + for x in video: + x = torchvision.utils.make_grid(x, 
nrow=6) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + frames.append((x * 255).numpy().astype(np.uint8)) + # videos.append(frames) + videos = [frames] + + video_filenames = [] + video_captions = [] + for i, video in enumerate(videos): + caption = captions[i] + filename = os.path.join( + fastvideo_args.output_dir, + f"validation_step_{global_step}_video_{i}.mp4") + imageio.mimsave(filename, video, fps=sampling_param.fps) + video_filenames.append(filename) + video_captions.append(caption) + + # Log validation results + if self.rank == 0: + video_filenames = [] + video_captions = [] + for i, video in enumerate(videos): + caption = captions[i] + filename = os.path.join( + fastvideo_args.output_dir, + f"validation_step_{global_step}_video_{i}.mp4") + imageio.mimsave(filename, video, fps=sampling_param.fps) + video_filenames.append(filename) + video_captions.append(caption) + + logs = { + "validation_videos": [ + wandb.Video(filename, + caption=caption) for filename, caption in zip( + video_filenames, video_captions) + ] + } + wandb.log(logs, step=global_step) + + # Re-enable gradients for training + fastvideo_args.mode = Mode.DISTILL + transformer.requires_grad_(True) + transformer.train() + + gc.collect() + torch.cuda.empty_cache() diff --git a/fastvideo/v1/training/training_pipeline.py b/fastvideo/v1/training/training_pipeline.py index 681d587b1..01f5ad69b 100644 --- a/fastvideo/v1/training/training_pipeline.py +++ b/fastvideo/v1/training/training_pipeline.py @@ -19,7 +19,7 @@ from fastvideo.v1.dataset import build_parquet_map_style_dataloader from fastvideo.v1.distributed import (get_sp_group, get_torch_device, get_world_group) -from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs +from fastvideo.v1.fastvideo_args import FastVideoArgs, Mode, TrainingArgs from fastvideo.v1.forward_context import set_forward_context from fastvideo.v1.logger import init_logger from fastvideo.v1.pipelines import ComposedPipelineBase @@ -146,7 +146,7 @@ def train_one_step(self, transformer, model_type, optimizer, lr_scheduler, @torch.no_grad() def _log_validation(self, transformer, training_args, global_step) -> None: assert training_args is not None - training_args.inference_mode = True + training_args.mode = Mode.INFERENCE training_args.use_cpu_offload = False if not training_args.log_validation: return diff --git a/fastvideo/v1/training/wan_distillation_pipeline.py b/fastvideo/v1/training/wan_distillation_pipeline.py new file mode 100644 index 000000000..53358782d --- /dev/null +++ b/fastvideo/v1/training/wan_distillation_pipeline.py @@ -0,0 +1,496 @@ +import sys +import time +from collections import deque +from copy import deepcopy +from typing import Dict + +import torch +from tqdm.auto import tqdm + +import wandb +from fastvideo.distill.solver import extract_into_tensor +from fastvideo.v1.distributed import cleanup_dist_env_and_memory, get_sp_group +from fastvideo.v1.fastvideo_args import FastVideoArgs, Mode, TrainingArgs +from fastvideo.v1.forward_context import set_forward_context +from fastvideo.v1.logger import init_logger +from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch +from fastvideo.v1.pipelines.wan.wan_pipeline import WanValidationPipeline +from fastvideo.v1.training.distillation_pipeline import (DistillationPipeline, + reshard_fsdp) +from fastvideo.v1.training.training_utils import ( + clip_grad_norm_while_handling_failing_dtensor_cases, normalize_dit_input, + save_checkpoint) + +logger = init_logger(__name__) + + +def get_norm(model_pred: torch.Tensor, 
norms: Dict[str, float], + gradient_accumulation_steps: int) -> None: + """Calculate and aggregate model prediction norms.""" + fro_norm = ( + torch.linalg.matrix_norm(model_pred, ord="fro") / # codespell:ignore + gradient_accumulation_steps) + largest_singular_value = (torch.linalg.matrix_norm(model_pred, ord=2) / + gradient_accumulation_steps) + absolute_mean = torch.mean( + torch.abs(model_pred)) / gradient_accumulation_steps + absolute_max = torch.max( + torch.abs(model_pred)) / gradient_accumulation_steps + + sp_group = get_sp_group() + sp_group.all_reduce(fro_norm, op=torch.distributed.ReduceOp.AVG) + sp_group.all_reduce(largest_singular_value, + op=torch.distributed.ReduceOp.AVG) + sp_group.all_reduce(absolute_mean, op=torch.distributed.ReduceOp.AVG) + + norms["fro"] += torch.mean(fro_norm).item() # codespell:ignore + norms["largest singular value"] += torch.mean(largest_singular_value).item() + norms["absolute mean"] += absolute_mean.item() + norms["absolute max"] += absolute_max.item() + + +class WanDistillationPipeline(DistillationPipeline): + """ + A distillation pipeline for Wan. + """ + _required_config_modules = ["scheduler", "transformer"] + + def create_pipeline_stages(self, fastvideo_args: FastVideoArgs): + pass + + def create_training_stages(self, fastvideo_args: FastVideoArgs): + pass + + def initialize_validation_pipeline(self, fastvideo_args: FastVideoArgs): + logger.info("Initializing validation pipeline...") + args_copy = deepcopy(fastvideo_args) + + args_copy.mode = Mode.INFERENCE + args_copy.vae_config.load_encoder = False + validation_pipeline = WanValidationPipeline.from_pretrained( + fastvideo_args.model_path, + args=None, + mode=Mode.INFERENCE, + loaded_modules={"transformer": self.get_module("transformer")}) + + self.validation_pipeline = validation_pipeline + + def distill_one_step( + self, + transformer, + model_type, + teacher_transformer, + ema_transformer, + optimizer, + lr_scheduler, + loader_iter, + noise_scheduler, + solver, + noise_random_generator, + gradient_accumulation_steps, + sp_size, + max_grad_norm, + uncond_prompt_embed, + uncond_prompt_mask, + num_euler_timesteps, + multiphase, + not_apply_cfg_solver, + distill_cfg, + ema_decay, + pred_decay_weight, + pred_decay_type, + hunyuan_teacher_disable_cfg, + ) -> tuple[float, float, Dict[str, float]]: + """Perform one step of distillation training.""" + total_loss = 0.0 + optimizer.zero_grad() + model_pred_norm = { + "fro": 0.0, # codespell:ignore + "largest singular value": 0.0, + "absolute mean": 0.0, + "absolute max": 0.0, + } + + for _ in range(gradient_accumulation_steps): + ( + latents, + encoder_hidden_states, + encoder_attention_mask, + infos, + ) = next(loader_iter) + + latents = latents.to(self.device, dtype=torch.bfloat16) + encoder_hidden_states = encoder_hidden_states.to( + self.device, dtype=torch.bfloat16) + + latents = normalize_dit_input(model_type, latents) + batch_size = latents.shape[0] + noise = torch.randn_like(latents) + + indices = torch.randint(0, + num_euler_timesteps, (batch_size, ), + device=latents.device).long() + + if sp_size > 1: + self.sp_group.broadcast(indices, src=0) + + # Add noise according to flow matching + sigmas = extract_into_tensor(solver.sigmas, indices, latents.shape) + sigmas_prev = extract_into_tensor(solver.sigmas_prev, indices, + latents.shape) + + timesteps = (sigmas * + noise_scheduler.config.num_train_timesteps).view(-1) + timesteps_prev = ( + sigmas_prev * + noise_scheduler.config.num_train_timesteps).view(-1) + noisy_model_input = sigmas * noise 
+ (1.0 - sigmas) * latents + noisy_model_input = noisy_model_input.to(torch.bfloat16) + + # Get student model prediction + with torch.autocast("cuda", dtype=torch.bfloat16): + input_kwargs = { + "hidden_states": noisy_model_input, + "encoder_hidden_states": encoder_hidden_states, + "timestep": timesteps, + "encoder_attention_mask": encoder_attention_mask, + "return_dict": False, + } + if hunyuan_teacher_disable_cfg: + input_kwargs["guidance"] = torch.tensor( + [1000.0], + device=noisy_model_input.device, + dtype=torch.bfloat16) + + with set_forward_context(current_timestep=timesteps, + attn_metadata=None): + model_pred = transformer(**input_kwargs)[0] + + # Apply multi-phase prediction + model_pred, end_index = solver.euler_style_multiphase_pred( + noisy_model_input, model_pred, indices, multiphase) + + # Get teacher model prediction + with torch.no_grad(), torch.autocast( + "cuda", dtype=torch.bfloat16), set_forward_context( + current_timestep=timesteps, attn_metadata=None): + cond_teacher_output = teacher_transformer( + noisy_model_input, + encoder_hidden_states, + timesteps, + encoder_attention_mask, + return_dict=False, + )[0].float() + + if not_apply_cfg_solver: + uncond_teacher_output = cond_teacher_output + else: + # Get teacher model prediction on unconditional embedding + with torch.autocast("cuda", dtype=torch.bfloat16): + input_kwargs = { + "hidden_states": + noisy_model_input, + "encoder_hidden_states": + uncond_prompt_embed.unsqueeze(0).expand( + batch_size, -1, -1), + "timestep": + timesteps, + "encoder_attention_mask": + uncond_prompt_mask.unsqueeze(0).expand( + batch_size, -1), + "return_dict": + False, + } + with set_forward_context(current_timestep=timesteps, + attn_metadata=None): + uncond_teacher_output = teacher_transformer( + **input_kwargs)[0] + teacher_output = uncond_teacher_output + distill_cfg * ( + cond_teacher_output - uncond_teacher_output) + x_prev = solver.euler_step(noisy_model_input, teacher_output, + indices).to(torch.bfloat16) + + # Get target prediction + with torch.no_grad(): + with torch.autocast("cuda", dtype=torch.bfloat16): + if ema_transformer is not None: + with set_forward_context( + current_timestep=timesteps_prev, + attn_metadata=None): + target_pred = ema_transformer( + x_prev, + encoder_hidden_states, + timesteps_prev, + encoder_attention_mask, + return_dict=False, + )[0] + else: + with set_forward_context( + current_timestep=timesteps_prev, + attn_metadata=None): + target_pred = transformer( + x_prev, + encoder_hidden_states, + timesteps_prev, + encoder_attention_mask, + return_dict=False, + )[0] + + target, end_index = solver.euler_style_multiphase_pred( + x_prev, target_pred, indices, multiphase, True) + + # Calculate loss + huber_c = 0.001 + loss = (torch.mean( + torch.sqrt((model_pred.float() - target.float())**2 + + huber_c**2) - huber_c) / gradient_accumulation_steps) + + if pred_decay_weight > 0: + if pred_decay_type == "l1": + pred_decay_loss = ( + torch.mean(torch.sqrt(model_pred.float()**2)) * + pred_decay_weight / gradient_accumulation_steps) + loss += pred_decay_loss + elif pred_decay_type == "l2": + pred_decay_loss = (torch.mean(model_pred.float()**2) * + pred_decay_weight / + gradient_accumulation_steps) + loss += pred_decay_loss + else: + raise NotImplementedError( + "pred_decay_type is not implemented") + + # Calculate model prediction norms + get_norm(model_pred.detach().float(), model_pred_norm, + gradient_accumulation_steps) + loss.backward() + + avg_loss = loss.detach().clone() + self.sp_group.all_reduce(avg_loss, + 
op=torch.distributed.ReduceOp.AVG) + total_loss += avg_loss.item() + + # Update EMA + if ema_transformer is not None: + reshard_fsdp(ema_transformer) + for p_averaged, p_model in zip(ema_transformer.parameters(), + transformer.parameters()): + with torch.no_grad(): + p_averaged.copy_( + torch.lerp(p_averaged.detach(), p_model.detach(), + 1 - ema_decay)) + + # Gradient clipping and optimization step + if max_grad_norm is not None: + model_parts = [transformer] + grad_norm = clip_grad_norm_while_handling_failing_dtensor_cases( + [p for m in model_parts for p in m.parameters()], + max_grad_norm, + foreach=None, + ) + grad_norm = grad_norm.item() if grad_norm is not None else 0.0 + else: + grad_norm = 0.0 + + optimizer.step() + lr_scheduler.step() + + return total_loss, grad_norm, model_pred_norm + + def forward( + self, + batch: ForwardBatch, + fastvideo_args: TrainingArgs, + ): + assert self.training_args is not None + train_dataloader = self.train_dataloader + init_steps = self.init_steps + lr_scheduler = self.lr_scheduler + optimizer = self.optimizer + noise_scheduler = self.noise_scheduler + solver = self.solver + noise_random_generator = None + uncond_prompt_embed = self.uncond_prompt_embed + uncond_prompt_mask = self.uncond_prompt_mask + + assert self.training_args.sp_size is not None + assert self.training_args.gradient_accumulation_steps is not None + total_batch_size = (self.world_size * + self.training_args.gradient_accumulation_steps / + self.training_args.sp_size * + self.training_args.train_sp_batch_size) + logger.info("***** Running distillation training *****") + logger.info(" Resume training from step %s", init_steps) + logger.info(" Instantaneous batch size per device = %s", + self.training_args.train_batch_size) + logger.info( + " Total train batch size (w. 
data & sequence parallel, accumulation) = %s", + total_batch_size) + logger.info(" Gradient Accumulation steps = %s", + self.training_args.gradient_accumulation_steps) + logger.info(" Total optimization steps = %s", + self.training_args.max_train_steps) + logger.info( + " Total training parameters per FSDP shard = %s B", + sum(p.numel() + for p in self.transformer.parameters() if p.requires_grad) / + 1e9) + logger.info(" Master weight dtype: %s", + self.transformer.parameters().__next__().dtype) + + # Potentially load in the weights and states from a previous save + if self.training_args.resume_from_checkpoint: + raise NotImplementedError( + "resume_from_checkpoint is not supported now.") + + progress_bar = tqdm( + range(0, self.training_args.max_train_steps), + initial=init_steps, + desc="Steps", + disable=self.local_rank > 0, + ) + + loader_iter = iter(train_dataloader) + step_times: deque[float] = deque(maxlen=100) + + # Skip steps if resuming + for i in range(init_steps): + next(loader_iter) + + def get_num_phases(multi_phased_distill_schedule: str, + step: int) -> int: + # step-phase,step-phase + multi_phases = multi_phased_distill_schedule.split(",") + phase = multi_phases[-1].split("-")[-1] + for step_phases in multi_phases: + phase_step, phase = step_phases.split("-") + if step <= int(phase_step): + return int(phase) + return int(phase) + + for step in range(init_steps + 1, + self.training_args.max_train_steps + 1): + start_time = time.perf_counter() + + assert self.training_args.multi_phased_distill_schedule is not None + num_phases = get_num_phases( + self.training_args.multi_phased_distill_schedule, step) + try: + loss, grad_norm, pred_norm = self.distill_one_step( + self.transformer, + "wan", # model_type + self.teacher_transformer, + self.ema_transformer, + optimizer, + lr_scheduler, + loader_iter, + noise_scheduler, + solver, + noise_random_generator, + self.training_args.gradient_accumulation_steps, + self.training_args.sp_size, + self.training_args.max_grad_norm, + uncond_prompt_embed, + uncond_prompt_mask, + self.training_args.num_euler_timesteps, + num_phases, + self.training_args.not_apply_cfg_solver, + self.training_args.distill_cfg, + self.training_args.ema_decay, + self.training_args.pred_decay_weight, + self.training_args.pred_decay_type, + self.training_args.hunyuan_teacher_disable_cfg, + ) + + step_time = time.perf_counter() - start_time + step_times.append(step_time) + avg_step_time = sum(step_times) / len(step_times) + + progress_bar.set_postfix({ + "loss": f"{loss:.4f}", + "step_time": f"{step_time:.2f}s", + "grad_norm": grad_norm, + "phases": num_phases, + }) + progress_bar.update(1) + except StopIteration: + loader_iter = iter(train_dataloader) + step -= 1 + continue + + if self.rank <= 0: + wandb.log( + { + "train_loss": + loss, + "learning_rate": + lr_scheduler.get_last_lr()[0], + "step_time": + step_time, + "avg_step_time": + avg_step_time, + "grad_norm": + grad_norm, + "pred_fro_norm": + pred_norm["fro"], # codespell:ignore + "pred_largest_singular_value": + pred_norm["largest singular value"], + "pred_absolute_mean": + pred_norm["absolute mean"], + "pred_absolute_max": + pred_norm["absolute max"], + "phases": + num_phases, + }, + step=step, + ) + + if step % self.training_args.checkpointing_steps == 0: + if self.training_args.use_lora: + raise NotImplementedError("LoRA is not supported now") + else: + if self.training_args.use_ema: + save_checkpoint(self.ema_transformer, self.rank, + self.training_args.output_dir, step) + else: + 
save_checkpoint(self.transformer, self.rank, + self.training_args.output_dir, step) + self.sp_group.barrier() + + if self.training_args.log_validation and step % self.training_args.validation_steps == 0: + self._log_validation(self.transformer, self.training_args, step) + + # Final checkpoint + if self.training_args.use_lora: + raise NotImplementedError("LoRA is not supported now") + else: + save_checkpoint(self.transformer, self.rank, + self.training_args.output_dir, + self.training_args.max_train_steps) + + if get_sp_group(): + cleanup_dist_env_and_memory() + + +def main(args) -> None: + logger.info("Starting distillation pipeline...") + + pipeline = WanDistillationPipeline.from_pretrained( + args.pretrained_model_name_or_path, args=args) + + args = pipeline.training_args + pipeline.forward(None, args) + logger.info("Distillation pipeline done") + + +if __name__ == "__main__": + argv = sys.argv + from fastvideo.v1.fastvideo_args import TrainingArgs + from fastvideo.v1.utils import FlexibleArgumentParser + parser = FlexibleArgumentParser() + parser = TrainingArgs.add_cli_args(parser) + parser = FastVideoArgs.add_cli_args(parser) + args = parser.parse_args() + args.use_cpu_offload = False + print(args) + main(args) diff --git a/fastvideo/v1/training/wan_training_pipeline.py b/fastvideo/v1/training/wan_training_pipeline.py index e616af167..cdd84bd18 100644 --- a/fastvideo/v1/training/wan_training_pipeline.py +++ b/fastvideo/v1/training/wan_training_pipeline.py @@ -12,7 +12,7 @@ from fastvideo.v1.distributed import (cleanup_dist_env_and_memory, get_sp_group, get_torch_device, get_world_group) -from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs +from fastvideo.v1.fastvideo_args import FastVideoArgs, Mode, TrainingArgs from fastvideo.v1.forward_context import set_forward_context from fastvideo.v1.logger import init_logger from fastvideo.v1.models.schedulers.scheduling_flow_unipc_multistep import ( @@ -53,12 +53,12 @@ def initialize_validation_pipeline(self, training_args: TrainingArgs): logger.info("Initializing validation pipeline...") args_copy = deepcopy(training_args) - args_copy.inference_mode = True + args_copy.mode = Mode.INFERENCE args_copy.pipeline_config.vae_config.load_encoder = False validation_pipeline = WanValidationPipeline.from_pretrained( training_args.model_path, args=None, - inference_mode=True, + mode=Mode.INFERENCE, loaded_modules={"transformer": self.get_module("transformer")}, tp_size=training_args.tp_size, sp_size=training_args.sp_size, diff --git a/scripts/distill/distill_v1.sh b/scripts/distill/distill_v1.sh new file mode 100755 index 000000000..37c93f8d5 --- /dev/null +++ b/scripts/distill/distill_v1.sh @@ -0,0 +1,55 @@ +export WANDB_BASE_URL="https://api.wandb.ai" +export WANDB_MODE=online + +DATA_DIR=data/HD-Mixkit-Finetune-Wan/combined_parquet_dataset +VALIDATION_DIR=data/HD-Mixkit-Finetune-Wan/validation_parquet_dataset +num_gpus=1 +# IP=[MASTER NODE IP] + +# If you do not have 32 GPUs and to fit in memory, you can: 1. increase sp_size. 2. 
reduce num_latent_t
+    # --gradient_checkpointing\
+    # --pretrained_model_name_or_path hunyuanvideo-community/HunyuanVideo \
+    # --pretrained_model_name_or_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
+torchrun --nnodes 1 --nproc_per_node $num_gpus\
+    fastvideo/v1/training/wan_distillation_pipeline.py \
+    --model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
+    --mode distill \
+    --pretrained_model_name_or_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
+    --cache_dir "/home/test/.cache" \
+    --data_path "$DATA_DIR" \
+    --validation_prompt_dir "$VALIDATION_DIR" \
+    --train_batch_size=1 \
+    --num_latent_t 4 \
+    --sp_size $num_gpus \
+    --dp_size $num_gpus \
+    --dp_shards $num_gpus \
+    --train_sp_batch_size 1 \
+    --dataloader_num_workers $num_gpus \
+    --gradient_accumulation_steps=1 \
+    --max_train_steps=540 \
+    --learning_rate=1e-6 \
+    --mixed_precision="bf16" \
+    --checkpointing_steps=64 \
+    --validation_steps 180 \
+    --validation_sampling_steps "2,4,8" \
+    --checkpoints_total_limit 3 \
+    --allow_tf32 \
+    --ema_start_step 0 \
+    --cfg 0.0 \
+    --log_validation \
+    --output_dir="$DATA_DIR/outputs/hy_phase1_shift17_bs_16_HD" \
+    --tracker_project_name Hunyuan_Distill \
+    --num_height 720 \
+    --num_width 1280 \
+    --num_frames 81 \
+    --shift 17 \
+    --validation_guidance_scale "1.0" \
+    --num_euler_timesteps 50 \
+    --multi_phased_distill_schedule "4000-1" \
+    --not_apply_cfg_solver \
+    --weight_decay 0.01 \
+    --master_weight_type "fp32" \
+    --distill_cfg 3.0 \
+    --pred_decay_weight 0.0 \
+    --max_grad_norm 1.0
+    # --master_weight_type "bf16"
\ No newline at end of file
diff --git a/scripts/finetune/finetune_v1.sh b/scripts/finetune/finetune_v1.sh
index 317f625c9..13d13a359 100644
--- a/scripts/finetune/finetune_v1.sh
+++ b/scripts/finetune/finetune_v1.sh
@@ -12,7 +12,7 @@ NUM_GPUS=4
 torchrun --nnodes 1 --nproc_per_node $NUM_GPUS\
     fastvideo/v1/training/wan_training_pipeline.py\
     --model_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
-    --inference_mode False\
+    --mode training\
     --pretrained_model_name_or_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
     --data_path "$DATA_DIR"\
     --validation_prompt_dir "$VALIDATION_DIR"\
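To close, a brief orientation on how the pieces in this patch connect (an editorial sketch, not part of the diff): the distillation pipeline builds a full sigma schedule, either the scheduler's own sigmas or linear_quadratic_schedule when scheduler_type is pcm_linear_quadratic, and hands it to EulerSolver, which subsamples num_euler_timesteps entries (matching --num_euler_timesteps 50 in the script above). The numbers below are illustrative only, not taken from any shipped config:

import numpy as np

from fastvideo.distill.solver import EulerSolver, linear_quadratic_schedule

num_train_timesteps = 1000  # illustrative; the pipeline reads this from noise_scheduler.config
threshold_noise = 0.025     # illustrative stand-in for linear_quadratic_threshold
num_euler_timesteps = 50    # matches --num_euler_timesteps in distill_v1.sh

# Full 1000-entry sigma schedule: linear ramp then quadratic tail,
# flipped so values run from 1.0 down toward 0.
sigma_schedule = np.array(
    linear_quadratic_schedule(num_train_timesteps, threshold_noise),
    dtype=np.float32)

# EulerSolver keeps every step_ratio-th timestep and the matching sigma / sigma_prev pairs,
# which distill_one_step later gathers with extract_into_tensor.
solver = EulerSolver(sigma_schedule, num_train_timesteps,
                     euler_timesteps=num_euler_timesteps)
print(len(solver.sigmas), len(solver.sigmas_prev))  # 50 50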