Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions fastvideo/distill/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,27 @@
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils import BaseOutput, logging

from fastvideo.models.mochi_hf.pipeline_mochi import linear_quadratic_schedule

logger = logging.get_logger(__name__) # pylint: disable=invalid-name


# from: https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
if linear_steps is None:
linear_steps = num_steps // 2
linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
threshold_noise_step_diff = linear_steps - threshold_noise * num_steps
quadratic_steps = num_steps - linear_steps
quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2)
linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2)
const = quadratic_coef * (linear_steps**2)
quadratic_sigma_schedule = [
quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps)
]
sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule
sigma_schedule = [1.0 - x for x in sigma_schedule]
return sigma_schedule


@dataclass
class PCMFMSchedulerOutput(BaseOutput):
prev_sample: torch.FloatTensor
Expand Down Expand Up @@ -226,12 +242,12 @@ class EulerSolver:

def __init__(self, sigmas, timesteps=1000, euler_timesteps=50):
self.step_ratio = timesteps // euler_timesteps

self.euler_timesteps = (np.arange(1, euler_timesteps + 1) * self.step_ratio).round().astype(np.int64) - 1
self.euler_timesteps_prev = np.asarray([0] + self.euler_timesteps[:-1].tolist())
self.sigmas = sigmas[self.euler_timesteps]
self.sigmas_prev = np.asarray([sigmas[0]] +
sigmas[self.euler_timesteps[:-1]].tolist()) # either use sigma0 or 0

self.euler_timesteps = torch.from_numpy(self.euler_timesteps).long()
self.euler_timesteps_prev = torch.from_numpy(self.euler_timesteps_prev).long()
self.sigmas = torch.from_numpy(self.sigmas)
Expand Down
64 changes: 57 additions & 7 deletions fastvideo/v1/fastvideo_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,33 @@
import dataclasses
from contextlib import contextmanager
from dataclasses import field
from typing import Any, Dict, List, Optional
from enum import Enum
from typing import Any, Callable, List, Optional, Tuple, Dict

import torch
from fastvideo.v1.configs.models import DiTConfig, EncoderConfig, VAEConfig
from fastvideo.v1.configs.pipelines.base import PipelineConfig
from fastvideo.v1.logger import init_logger
from fastvideo.v1.utils import FlexibleArgumentParser, StoreBoolean

logger = init_logger(__name__)


class Mode(Enum):
"""Enumeration for FastVideo execution modes."""
INFERENCE = "inference"
TRAINING = "training"
DISTILL = "distill"


def preprocess_text(prompt: str) -> str:
return prompt


def postprocess_text(output: Any) -> Any:
raise NotImplementedError


def clean_cli_args(args: argparse.Namespace) -> Dict[str, Any]:
"""
Clean the arguments by removing the ones that not explicitly provided by the user.
Expand All @@ -40,7 +58,7 @@
# Distributed executor backend
distributed_executor_backend: str = "mp"

inference_mode: bool = True # if False == training mode
mode: Mode = Mode.INFERENCE

# HuggingFace specific parameters
trust_remote_code: bool = False
Expand Down Expand Up @@ -73,7 +91,15 @@

@property
def training_mode(self) -> bool:
return not self.inference_mode
return self.mode == Mode.TRAINING

@property
def distill_mode(self) -> bool:
return self.mode == Mode.DISTILL

@property
def inference_mode(self) -> bool:
return self.mode == Mode.INFERENCE

def __post_init__(self):
self.check_fastvideo_args()
Expand Down Expand Up @@ -103,10 +129,11 @@
)

parser.add_argument(
"--inference-mode",
action=StoreBoolean,
default=FastVideoArgs.inference_mode,
help="Whether to use inference mode",
"--mode",
type=str,
default=FastVideoArgs.mode.value,
choices=[mode.value for mode in Mode],
help="The mode to use",
)

# HuggingFace specific parameters
Expand Down Expand Up @@ -238,6 +265,16 @@
pipeline_config = PipelineConfig.from_kwargs(provided_args)
kwargs[attr] = pipeline_config
# Use getattr with default value from the dataclass for potentially missing attributes
elif attr == 'mode':
# Convert string mode to Mode enum
mode_value = getattr(args, attr, None)
if mode_value:

Check failure on line 271 in fastvideo/v1/fastvideo_args.py

View workflow job for this annotation

GitHub Actions / pre-commit / pre-commit

Incompatible types in assignment (expression has type "Mode", target has type "PipelineConfig") [assignment]
if isinstance(mode_value, Mode):
kwargs[attr] = mode_value

Check failure on line 273 in fastvideo/v1/fastvideo_args.py

View workflow job for this annotation

GitHub Actions / pre-commit / pre-commit

Incompatible types in assignment (expression has type "Mode", target has type "PipelineConfig") [assignment]
else:
kwargs[attr] = Mode(mode_value)

Check failure on line 275 in fastvideo/v1/fastvideo_args.py

View workflow job for this annotation

GitHub Actions / pre-commit / pre-commit

Incompatible types in assignment (expression has type "Mode", target has type "PipelineConfig") [assignment]
else:
kwargs[attr] = Mode.INFERENCE
else:
default_value = getattr(cls, attr, None)
value = getattr(args, attr, default_value)
Expand Down Expand Up @@ -419,6 +456,8 @@
pred_decay_type: str = ""
hunyuan_teacher_disable_cfg: bool = False

use_lora: bool = False

# master_weight_type
master_weight_type: str = ""

Expand All @@ -435,6 +474,17 @@
pipeline_config = PipelineConfig.from_kwargs(provided_args)
kwargs[attr] = pipeline_config
# Use getattr with default value from the dataclass for potentially missing attributes
elif attr == 'mode':
# Convert string mode to Mode enum
mode_value = getattr(args, attr, None)
if mode_value:

Check failure on line 480 in fastvideo/v1/fastvideo_args.py

View workflow job for this annotation

GitHub Actions / pre-commit / pre-commit

Incompatible types in assignment (expression has type "Mode", target has type "PipelineConfig") [assignment]
if isinstance(mode_value, Mode):
kwargs[attr] = mode_value

Check failure on line 482 in fastvideo/v1/fastvideo_args.py

View workflow job for this annotation

GitHub Actions / pre-commit / pre-commit

Incompatible types in assignment (expression has type "Mode", target has type "PipelineConfig") [assignment]
else:
kwargs[attr] = Mode(mode_value)
else:

Check failure on line 485 in fastvideo/v1/fastvideo_args.py

View workflow job for this annotation

GitHub Actions / pre-commit / pre-commit

Incompatible types in assignment (expression has type "Mode", target has type "PipelineConfig") [assignment]
kwargs[
attr] = Mode.TRAINING # Default to training for TrainingArgs
else:
default_value = getattr(cls, attr, None)
value = getattr(args, attr, default_value)
Expand Down
1 change: 1 addition & 0 deletions fastvideo/v1/models/loader/component_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,7 @@ class TransformerLoader(ComponentLoader):
def load(self, model_path: str, architecture: str,
fastvideo_args: FastVideoArgs):
"""Load the transformer based on the model path, architecture, and inference args."""
logger.info("Loading transformer from %s", model_path)
config = get_diffusers_config(model=model_path)
hf_config = deepcopy(config)
cls_name = config.pop("_class_name")
Expand Down
45 changes: 38 additions & 7 deletions fastvideo/v1/pipelines/composed_pipeline_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import argparse
import os
from abc import ABC, abstractmethod
from copy import deepcopy
from enum import Enum
from typing import Any, Dict, List, Optional, Union, cast

import torch
Expand Down Expand Up @@ -53,7 +55,15 @@ def __init__(self,
"""
self.fastvideo_args = fastvideo_args

self.model_path: str = model_path
if fastvideo_args.training_mode or fastvideo_args.distill_mode:
assert isinstance(fastvideo_args, TrainingArgs)
self.training_args = fastvideo_args
assert self.training_args is not None
else:
self.fastvideo_args = fastvideo_args
assert self.fastvideo_args is not None

self.model_path = model_path
self._stages: List[PipelineStage] = []
self._stage_name_mapping: Dict[str, PipelineStage] = {}

Expand All @@ -79,9 +89,15 @@ def __init__(self,
self.initialize_validation_pipeline(self.training_args)
self.initialize_training_pipeline(self.training_args)

if fastvideo_args.distill_mode:
assert self.training_args is not None
if self.training_args.log_validation:
self.initialize_validation_pipeline(self.training_args)
self.initialize_distillation_pipeline(self.training_args)

self.initialize_pipeline(fastvideo_args)

if not fastvideo_args.training_mode:
if fastvideo_args.inference_mode:
logger.info("Creating pipeline stages...")
self.create_pipeline_stages(fastvideo_args)

Expand All @@ -94,6 +110,10 @@ def initialize_validation_pipeline(self, training_args: TrainingArgs):
"if log_validation is True, the pipeline must implement this method"
)

def initialize_distillation_pipeline(self, training_args: TrainingArgs):
raise NotImplementedError(
"if distill_mode is True, the pipeline must implement this method")

@classmethod
def from_pretrained(cls,
model_path: str,
Expand All @@ -112,11 +132,18 @@ def from_pretrained(cls,
loaded_modules: Optional[Dict[str, torch.nn.Module]] = None,
If provided, loaded_modules will be used instead of loading from config/pretrained weights.
"""
if args is None or args.inference_mode:


# Handle both string mode and Mode enum values
mode_str: str | Enum = getattr(
args, 'mode', "inference") if args is not None else "inference"
if hasattr(mode_str, 'value'):
mode_str = mode_str.value
mode_str = str(mode_str)

if mode_str == "inference":
kwargs['model_path'] = model_path
fastvideo_args = FastVideoArgs.from_kwargs(kwargs)
else:
elif mode_str == "training" or mode_str == "distill":
assert args is not None, "args must be provided for training mode"
fastvideo_args = TrainingArgs.from_cli_args(args)
# TODO(will): fix this so that its not so ugly
Expand All @@ -125,15 +152,19 @@ def from_pretrained(cls,
setattr(fastvideo_args, key, value)

fastvideo_args.use_cpu_offload = False
# make sure we are in training mode
fastvideo_args.inference_mode = False
# make sure we are in training mode - note: inference_mode is read-only,
# so we don't set it directly here as it's determined by the mode
# we hijack the precision to be the master weight type so that the
# model is loaded with the correct precision. Subsequently we will
# use FSDP2's MixedPrecisionPolicy to set the precision for the
# fwd, bwd, and other operations' precision.
# fastvideo_args.precision = fastvideo_args.master_weight_type
assert fastvideo_args.pipeline_config.dit_precision == 'fp32', 'only fp32 is supported for training'
# assert fastvideo_args.precision == 'fp32', 'only fp32 is supported for training'
else:
raise ValueError(f"Invalid mode: {mode_str}")

fastvideo_args.check_fastvideo_args()

logger.info("fastvideo_args in from_pretrained: %s", fastvideo_args)

Expand Down
Loading
Loading