diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 87fe89ecd..0d422c143 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -122,17 +122,17 @@ jobs: # Actual tests encoder-test: - 'fastvideo/v1/models/encoders/**' - - 'fastvideo/v1/models/loaders/**' + - 'fastvideo/v1/models/loader/**' - 'fastvideo/v1/tests/encoders/**' - *common-paths vae-test: - 'fastvideo/v1/models/vaes/**' - - 'fastvideo/v1/models/loaders/**' + - 'fastvideo/v1/models/loader/**' - 'fastvideo/v1/tests/vaes/**' - *common-paths transformer-test: - 'fastvideo/v1/models/dits/**' - - 'fastvideo/v1/models/loaders/**' + - 'fastvideo/v1/models/loader/**' - 'fastvideo/v1/tests/transformers/**' - 'fastvideo/v1/layers/**' - 'fastvideo/v1/attention/**' diff --git a/examples/inference/basic/basic.py b/examples/inference/basic/basic.py index 4161004f1..c97f59e56 100644 --- a/examples/inference/basic/basic.py +++ b/examples/inference/basic/basic.py @@ -10,7 +10,7 @@ def main(): # attempt to identify the optimal arguments. generator = VideoGenerator.from_pretrained( "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", - # if num_gpus > 1, FastVideo will automatically handle distributed setup + # FastVideo will automatically handle distributed setup num_gpus=2, use_fsdp_inference=True, use_cpu_offload=False diff --git a/fastvideo/v1/configs/models/base.py b/fastvideo/v1/configs/models/base.py index 84b0de57c..40eb9ad66 100644 --- a/fastvideo/v1/configs/models/base.py +++ b/fastvideo/v1/configs/models/base.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass, field, fields -from typing import Any, Dict +from typing import Any, Dict, List, Tuple from fastvideo.v1.logger import init_logger @@ -12,7 +12,9 @@ # 3. Any field in ArchConfig is fixed upon initialization, and should be hidden away from users @dataclass class ArchConfig: - pass + stacked_params_mapping: List[Tuple[str, str, str]] = field( + default_factory=list + ) # mapping from huggingface weight names to custom names @dataclass diff --git a/fastvideo/v1/configs/models/dits/stepvideo.py b/fastvideo/v1/configs/models/dits/stepvideo.py index abad243ec..78fc6b0b3 100644 --- a/fastvideo/v1/configs/models/dits/stepvideo.py +++ b/fastvideo/v1/configs/models/dits/stepvideo.py @@ -5,13 +5,11 @@ from fastvideo.v1.configs.models.dits.base import DiTArchConfig, DiTConfig -def is_blocks(n: str, m) -> bool: - return "blocks" in n and str.isdigit(n.split(".")[-1]) - - @dataclass class StepVideoArchConfig(DiTArchConfig): - _fsdp_shard_conditions: list = field(default_factory=lambda: [is_blocks]) + _fsdp_shard_conditions: list = field( + default_factory=lambda: + [lambda n, m: "transformer_blocks" in n and n.split(".")[-1].isdigit()]) _param_names_mapping: dict = field( default_factory=lambda: { diff --git a/fastvideo/v1/configs/models/encoders/base.py b/fastvideo/v1/configs/models/encoders/base.py index febbd23f2..d2e686add 100644 --- a/fastvideo/v1/configs/models/encoders/base.py +++ b/fastvideo/v1/configs/models/encoders/base.py @@ -32,8 +32,11 @@ class TextEncoderArchConfig(EncoderArchConfig): output_past: bool = True scalable_attention: bool = True tie_word_embeddings: bool = False - + stacked_params_mapping: List[Tuple[str, str, str]] = field( + default_factory=list + ) # mapping from huggingface weight names to custom names tokenizer_kwargs: Dict[str, Any] = field(default_factory=dict) + _fsdp_shard_conditions: list = field(default_factory=lambda: []) def __post_init__(self) -> None: self.tokenizer_kwargs = { 
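Note on the new stacked_params_mapping field: entries follow the vLLM-style convention of (fused param substring, checkpoint shard substring, shard id), and the per-model load_weights loops later in this diff consume them by renaming the checkpoint key and delegating to the fused parameter's weight_loader. A minimal sketch of that consumption (illustrative only; route_weight and the bare params_dict are not part of this PR):

    import torch

    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
    ]

    def route_weight(params_dict: dict, name: str, loaded_weight: torch.Tensor) -> None:
        """Copy one checkpoint tensor into the (possibly fused) model parameter."""
        for param_name, shard_name, shard_id in stacked_params_mapping:
            if shard_name not in name:
                continue
            param = params_dict[name.replace(shard_name, param_name)]
            # by convention the fused parameter carries a weight_loader that
            # writes this shard at the q/k/v offset of the concatenated tensor
            param.weight_loader(param, loaded_weight, shard_id)
            return
        params_dict[name].data.copy_(loaded_weight)  # not stacked: plain copy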
diff --git a/fastvideo/v1/configs/models/encoders/clip.py b/fastvideo/v1/configs/models/encoders/clip.py index 6e81d41e2..ab9340c02 100644 --- a/fastvideo/v1/configs/models/encoders/clip.py +++ b/fastvideo/v1/configs/models/encoders/clip.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass, field -from typing import Optional +from typing import List, Optional, Tuple from fastvideo.v1.configs.models.encoders.base import (ImageEncoderArchConfig, ImageEncoderConfig, @@ -8,6 +8,14 @@ TextEncoderConfig) +def _is_transformer_layer(n: str, m) -> bool: + return "layers" in n and str.isdigit(n.split(".")[-1]) + + +def _is_embeddings(n: str, m) -> bool: + return n.endswith("embeddings") + + @dataclass class CLIPTextArchConfig(TextEncoderArchConfig): vocab_size: int = 49408 @@ -27,6 +35,15 @@ class CLIPTextArchConfig(TextEncoderArchConfig): bos_token_id: int = 49406 eos_token_id: int = 49407 text_len: int = 77 + stacked_params_mapping: List[Tuple[str, str, + str]] = field(default_factory=lambda: [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ]) + _fsdp_shard_conditions: list = field( + default_factory=lambda: [_is_transformer_layer, _is_embeddings]) @dataclass @@ -45,6 +62,13 @@ class CLIPVisionArchConfig(ImageEncoderArchConfig): attention_dropout: float = 0.0 initializer_range: float = 0.02 initializer_factor: float = 1.0 + stacked_params_mapping: List[Tuple[str, str, + str]] = field(default_factory=lambda: [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ]) @dataclass diff --git a/fastvideo/v1/configs/models/encoders/llama.py b/fastvideo/v1/configs/models/encoders/llama.py index 1fde6e185..0901e98ae 100644 --- a/fastvideo/v1/configs/models/encoders/llama.py +++ b/fastvideo/v1/configs/models/encoders/llama.py @@ -1,11 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass, field -from typing import Optional +from typing import List, Optional, Tuple from fastvideo.v1.configs.models.encoders.base import (TextEncoderArchConfig, TextEncoderConfig) +def _is_transformer_layer(n: str, m) -> bool: + return "layers" in n and str.isdigit(n.split(".")[-1]) + + +def _is_embeddings(n: str, m) -> bool: + return n.endswith("embed_tokens") + + +def _is_final_norm(n: str, m) -> bool: + return n.endswith("norm") + + @dataclass class LlamaArchConfig(TextEncoderArchConfig): vocab_size: int = 32000 @@ -32,6 +44,18 @@ class LlamaArchConfig(TextEncoderArchConfig): head_dim: Optional[int] = None hidden_state_skip_layer: int = 2 text_len: int = 256 + stacked_params_mapping: List[Tuple[str, str, str]] = field( + default_factory=lambda: [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), # type: ignore + (".gate_up_proj", ".up_proj", 1), # type: ignore + ]) + _fsdp_shard_conditions: list = field( + default_factory=lambda: + [_is_transformer_layer, _is_embeddings, _is_final_norm]) @dataclass diff --git a/fastvideo/v1/configs/models/encoders/t5.py b/fastvideo/v1/configs/models/encoders/t5.py index 7ec4d4a1b..79e9c9ad0 100644 --- a/fastvideo/v1/configs/models/encoders/t5.py +++ b/fastvideo/v1/configs/models/encoders/t5.py @@ -1,11 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass, field -from typing import Optional +from typing import 
List, Optional, Tuple from fastvideo.v1.configs.models.encoders.base import (TextEncoderArchConfig, TextEncoderConfig) +def _is_transformer_layer(n: str, m) -> bool: + return "block" in n and str.isdigit(n.split(".")[-1]) + + +def _is_embeddings(n: str, m) -> bool: + return n.endswith("shared") + + +def _is_final_layernorm(n: str, m) -> bool: + return n.endswith("final_layer_norm") + + @dataclass class T5ArchConfig(TextEncoderArchConfig): vocab_size: int = 32128 @@ -29,6 +41,16 @@ class T5ArchConfig(TextEncoderArchConfig): eos_token_id: int = 1 classifier_dropout: float = 0.0 text_len: int = 512 + stacked_params_mapping: List[Tuple[str, str, + str]] = field(default_factory=lambda: [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q", "q"), + (".qkv_proj", ".k", "k"), + (".qkv_proj", ".v", "v"), + ]) + _fsdp_shard_conditions: list = field( + default_factory=lambda: + [_is_transformer_layer, _is_embeddings, _is_final_layernorm]) # Referenced from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/configuration_t5.py def __post_init__(self): diff --git a/fastvideo/v1/dataset/benchmarks/benchmark_parquet_dataset_iterable_style.py b/fastvideo/v1/dataset/benchmarks/benchmark_parquet_dataset_iterable_style.py index 57cf092ca..f808d2f09 100644 --- a/fastvideo/v1/dataset/benchmarks/benchmark_parquet_dataset_iterable_style.py +++ b/fastvideo/v1/dataset/benchmarks/benchmark_parquet_dataset_iterable_style.py @@ -11,7 +11,7 @@ build_parquet_iterable_style_dataloader) from fastvideo.v1.distributed import get_world_rank from fastvideo.v1.distributed.parallel_state import ( - cleanup_dist_env_and_memory, get_torch_device, + cleanup_dist_env_and_memory, get_local_torch_device, maybe_init_distributed_environment_and_model_parallel) from fastvideo.v1.logger import init_logger @@ -148,8 +148,8 @@ def main() -> None: break # Move data to device - latents = latents.to(get_torch_device()) - embeddings = embeddings.to(get_torch_device()) + latents = latents.to(get_local_torch_device()) + embeddings = embeddings.to(get_local_torch_device()) # Calculate actual batch size batch_size = latents.size(0) diff --git a/fastvideo/v1/dataset/benchmarks/benchmark_parquet_dataset_map_style.py b/fastvideo/v1/dataset/benchmarks/benchmark_parquet_dataset_map_style.py index a2614edda..7618471ea 100644 --- a/fastvideo/v1/dataset/benchmarks/benchmark_parquet_dataset_map_style.py +++ b/fastvideo/v1/dataset/benchmarks/benchmark_parquet_dataset_map_style.py @@ -13,7 +13,7 @@ build_parquet_map_style_dataloader) from fastvideo.v1.distributed import get_world_rank from fastvideo.v1.distributed.parallel_state import ( - cleanup_dist_env_and_memory, get_torch_device, + cleanup_dist_env_and_memory, get_local_torch_device, maybe_init_distributed_environment_and_model_parallel) from fastvideo.v1.logger import init_logger @@ -165,8 +165,8 @@ def main() -> None: break # Move data to device - latents = latents.to(get_torch_device()) - embeddings = embeddings.to(get_torch_device()) + latents = latents.to(get_local_torch_device()) + embeddings = embeddings.to(get_local_torch_device()) # Calculate actual batch size batch_size = latents.size(0) diff --git a/fastvideo/v1/distributed/__init__.py b/fastvideo/v1/distributed/__init__.py index 5c0a1af6e..7e96bafa1 100644 --- a/fastvideo/v1/distributed/__init__.py +++ b/fastvideo/v1/distributed/__init__.py @@ -3,10 +3,10 @@ from fastvideo.v1.distributed.communication_op import * from fastvideo.v1.distributed.parallel_state import ( cleanup_dist_env_and_memory, 
get_dp_group, get_dp_rank, get_dp_world_size, - get_sp_group, get_sp_parallel_rank, get_sp_world_size, get_torch_device, - get_tp_group, get_tp_rank, get_tp_world_size, get_world_group, - get_world_rank, get_world_size, init_distributed_environment, - initialize_model_parallel, + get_local_torch_device, get_sp_group, get_sp_parallel_rank, + get_sp_world_size, get_tp_group, get_tp_rank, get_tp_world_size, + get_world_group, get_world_rank, get_world_size, + init_distributed_environment, initialize_model_parallel, maybe_init_distributed_environment_and_model_parallel, model_parallel_is_initialized) from fastvideo.v1.distributed.utils import * @@ -40,5 +40,5 @@ "get_tp_world_size", # Get torch device - "get_torch_device", + "get_local_torch_device", ] diff --git a/fastvideo/v1/distributed/parallel_state.py b/fastvideo/v1/distributed/parallel_state.py index b15a9f6c0..9c6a9ff37 100644 --- a/fastvideo/v1/distributed/parallel_state.py +++ b/fastvideo/v1/distributed/parallel_state.py @@ -36,6 +36,7 @@ import torch import torch.distributed +import torch.distributed as dist from torch.distributed import Backend, ProcessGroup, ReduceOp import fastvideo.v1.envs as envs @@ -692,6 +693,7 @@ def destroy(self) -> None: _WORLD: Optional[GroupCoordinator] = None +_NODE: Optional[GroupCoordinator] = None def get_world_group() -> GroupCoordinator: @@ -699,6 +701,11 @@ def get_world_group() -> GroupCoordinator: return _WORLD +def get_node_group() -> GroupCoordinator: + assert _NODE is not None, ("node group is not initialized") + return _NODE + + def init_world_group(ranks: List[int], local_rank: int, backend: str) -> GroupCoordinator: return GroupCoordinator( @@ -710,6 +717,18 @@ def init_world_group(ranks: List[int], local_rank: int, ) +def init_node_group(local_rank: int, backend: str): + cpu_group = get_world_group().cpu_group + node_ranks = same_node_ranks(cpu_group) + node_size = len(node_ranks) + all_node_ranks = [ + list(range(i * node_size, (i + 1) * node_size)) + for i in range(dist.get_world_size() // node_size) + ] + global _NODE + _NODE = init_model_parallel_group(all_node_ranks, local_rank, backend) + + def init_model_parallel_group( group_ranks: List[List[int]], local_rank: int, @@ -782,6 +801,8 @@ def init_distributed_environment( else: assert _WORLD.world_size == torch.distributed.get_world_size(), ( "world group already initialized with a different world size") + # Init a group for each node + init_node_group(local_rank, backend) _SP: Optional[GroupCoordinator] = None @@ -904,7 +925,7 @@ def get_dp_rank() -> int: return get_dp_group().rank_in_group -def get_torch_device() -> torch.device: +def get_local_torch_device() -> torch.device: """Return the torch device for the current rank.""" return torch.device(f"cuda:{envs.LOCAL_RANK}") @@ -1021,17 +1042,22 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False): "torch._C._host_emptyCache() only available in Pytorch >=2.5") -def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup], - source_rank: int = 0) -> List[bool]: +def same_node_ranks(pg: Union[ProcessGroup, StatelessProcessGroup], + source_rank: int = 0) -> List[int]: """ - This is a collective operation that returns if each rank is in the same node + This is a collective operation that returns ranks that are in the same node as the source rank. It tests if processes are attached to the same memory system (shared access to shared memory). 
+ Args: + pg: the global process group to test + source_rank: the rank to test against + Returns: + A list of ranks that are in the same node as the source rank. """ if isinstance(pg, ProcessGroup): assert torch.distributed.get_backend( pg) != torch.distributed.Backend.NCCL, ( - "in_the_same_node_as should be tested with a non-NCCL group.") + "same_node_ranks should be tested with a non-NCCL group.") # local rank inside the group rank = torch.distributed.get_rank(group=pg) world_size = torch.distributed.get_world_size(group=pg) @@ -1103,7 +1129,7 @@ def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup], rank_data = pg.broadcast_obj(is_in_the_same_node, src=i) aggregated_data += rank_data - return [x == 1 for x in aggregated_data.tolist()] + return [i for i, x in enumerate(aggregated_data.tolist()) if x == 1] def initialize_tensor_parallel_group( diff --git a/fastvideo/v1/fastvideo_args.py b/fastvideo/v1/fastvideo_args.py index 3cdeecbe2..633935451 100644 --- a/fastvideo/v1/fastvideo_args.py +++ b/fastvideo/v1/fastvideo_args.py @@ -58,8 +58,10 @@ class FastVideoArgs: output_type: str = "pil" - use_cpu_offload: bool = True + use_cpu_offload: bool = True # For DiT use_fsdp_inference: bool = True + text_encoder_offload: bool = True + pin_cpu_memory: bool = True # STA (Sliding Tile Attention) parameters mask_strategy_file_path: Optional[str] = None @@ -208,7 +210,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--use-cpu-offload", action=StoreBoolean, help= - "Use CPU offload for model inference. Enable if run out of memory with FSDP.", + "Use CPU offload for DiT inference. Enable if you run out of memory with FSDP.", ) parser.add_argument( "--use-fsdp-inference", action=StoreBoolean, help= "Use FSDP for inference by sharding the model weights. Latency is very low due to prefetch--enable if run out of memory.", ) - + parser.add_argument( + "--text-encoder-offload", + action=StoreBoolean, + help= + "Use CPU offload for the text encoder. Enable if you run out of memory.", + ) + parser.add_argument( + "--pin-cpu-memory", + action=StoreBoolean, + help= + "Pin CPU memory for offloaded weights. Should be enabled in almost all cases; " + "disable only as a temporary workaround if it causes \"CUDA error: invalid argument\".", + ) parser.add_argument( "--disable-autocast", action=StoreBoolean, diff --git a/fastvideo/v1/layers/layernorm.py b/fastvideo/v1/layers/layernorm.py index 66032ecd0..1982611c0 100644 --- a/fastvideo/v1/layers/layernorm.py +++ b/fastvideo/v1/layers/layernorm.py @@ -6,6 +6,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torch.distributed.tensor import DTensor from fastvideo.v1.layers.custom_op import CustomOp @@ -70,7 +71,12 @@ def forward_native( x = x * torch.rsqrt(variance + self.variance_epsilon) x = x.to(orig_dtype) if self.has_weight: - x = x * self.weight + # TODO(wenxuan): When using CPU offload, FSDP has a bug that doesn't unwrap DTensor in final_layer_norm.
+ # Report this upstream. + if isinstance(self.weight, DTensor): + x = x * self.weight.to_local().to(x.device) + else: + x = x * self.weight if residual is None: return x else: diff --git a/fastvideo/v1/models/dits/stepvideo.py b/fastvideo/v1/models/dits/stepvideo.py index c70f1c090..d0ad9854a 100644 --- a/fastvideo/v1/models/dits/stepvideo.py +++ b/fastvideo/v1/models/dits/stepvideo.py @@ -455,10 +455,7 @@ def forward(self, class StepVideoModel(BaseDiT): # (Optional) Keep the same attribute for compatibility with splitting, etc. - _fsdp_shard_conditions = [ - lambda n, m: "transformer_blocks" in n and n.split(".")[-1].isdigit(), - # lambda n, m: "pos_embed" in n # If needed for the patch embedding. - ] + _fsdp_shard_conditions = StepVideoConfig()._fsdp_shard_conditions _param_names_mapping = StepVideoConfig()._param_names_mapping _reverse_param_names_mapping = StepVideoConfig( )._reverse_param_names_mapping diff --git a/fastvideo/v1/models/encoders/base.py b/fastvideo/v1/models/encoders/base.py index 4c7c45ec2..69b3a4846 100644 --- a/fastvideo/v1/models/encoders/base.py +++ b/fastvideo/v1/models/encoders/base.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod -from typing import Optional, Tuple +from typing import List, Optional, Tuple import torch from torch import nn @@ -12,6 +12,9 @@ class TextEncoder(nn.Module, ABC): + # plain class-level defaults; TextEncoder is an nn.Module, not a dataclass, so dataclasses.field() would not be resolved here + _fsdp_shard_conditions: list = [] + _stacked_params_mapping: List[Tuple[str, str, str]] = [] _supported_attention_backends: Tuple[ AttentionBackendEnum, ...] = TextEncoderConfig()._supported_attention_backends @@ -19,6 +22,8 @@ def __init__(self, config: TextEncoderConfig) -> None: super().__init__() self.config = config + self._fsdp_shard_conditions = config._fsdp_shard_conditions + self._stacked_params_mapping = config.arch_config.stacked_params_mapping if not self.supported_attention_backends: raise ValueError( f"Subclass {self.__class__.__name__} must define _supported_attention_backends" diff --git a/fastvideo/v1/models/encoders/clip.py b/fastvideo/v1/models/encoders/clip.py index ecbaba58d..8278e0f71 100644 --- a/fastvideo/v1/models/encoders/clip.py +++ b/fastvideo/v1/models/encoders/clip.py @@ -596,12 +596,7 @@ def device(self): # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986 def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] + params_dict = dict(self.named_parameters()) loaded_params: Set[str] = set() layer_count = len(self.vision_model.encoder.layers) @@ -620,7 +615,8 @@ def load_weights(self, weights: Iterable[Tuple[str, if layer_idx >= layer_count: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for (param_name, weight_name, + shard_id) in self.config.arch_config.stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) diff --git a/fastvideo/v1/models/encoders/llama.py b/fastvideo/v1/models/encoders/llama.py index ebf009bf1..2fa32780d 100644 --- a/fastvideo/v1/models/encoders/llama.py +++ b/fastvideo/v1/models/encoders/llama.py @@ -369,14 +369,7 @@ def forward( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id)
- (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - (".gate_up_proj", ".gate_proj", 0), - (".gate_up_proj", ".up_proj", 1), - ] + params_dict = dict(self.named_parameters()) loaded_params: Set[str] = set() for name, loaded_weight in weights: @@ -406,7 +399,7 @@ def load_weights(self, weights: Iterable[Tuple[str, continue else: name = kv_scale_name - for param_name, weight_name, shard_id in stacked_params_mapping: + for param_name, weight_name, shard_id in self.config.arch_config.stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) diff --git a/fastvideo/v1/models/encoders/t5.py b/fastvideo/v1/models/encoders/t5.py index a4ea46c40..8fa775c12 100644 --- a/fastvideo/v1/models/encoders/t5.py +++ b/fastvideo/v1/models/encoders/t5.py @@ -494,7 +494,7 @@ def forward( attention_mask=attention_mask, attn_metadata=attn_metadata, ) - hidden_states = self.final_layer_norm.forward_native(hidden_states) + hidden_states = self.final_layer_norm.forward(hidden_states) return hidden_states @@ -631,19 +631,13 @@ def forward( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".qkv_proj", ".q", "q"), - (".qkv_proj", ".k", "k"), - (".qkv_proj", ".v", "v"), - ] params_dict = dict(self.named_parameters()) loaded_params: Set[str] = set() for name, loaded_weight in weights: loaded = False if "decoder" in name or "lm_head" in name: continue - for param_name, weight_name, shard_id in stacked_params_mapping: + for param_name, weight_name, shard_id in self.config.arch_config.stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) diff --git a/fastvideo/v1/models/loader/component_loader.py b/fastvideo/v1/models/loader/component_loader.py index 270bc8387..41a40083e 100644 --- a/fastvideo/v1/models/loader/component_loader.py +++ b/fastvideo/v1/models/loader/component_loader.py @@ -10,17 +10,20 @@ from typing import Any, Generator, Iterable, List, Optional, Tuple, cast import torch +import torch.distributed as dist import torch.nn as nn from safetensors.torch import load_file as safetensors_load_file from transformers import AutoImageProcessor, AutoTokenizer from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from fastvideo.v1.configs.models import EncoderConfig -from fastvideo.v1.distributed import get_torch_device +from fastvideo.v1.distributed import get_local_torch_device from fastvideo.v1.fastvideo_args import FastVideoArgs from fastvideo.v1.logger import init_logger from fastvideo.v1.models.hf_transformer_utils import get_diffusers_config -from fastvideo.v1.models.loader.fsdp_load import maybe_load_fsdp_model +from fastvideo.v1.models.loader.fsdp_load import (init_device_mesh, + maybe_load_fsdp_model, + shard_model) from fastvideo.v1.models.loader.utils import set_default_torch_dtype from fastvideo.v1.models.loader.weight_utils import ( filter_duplicate_safetensors_files, filter_files_not_needed_for_inference, @@ -163,16 +166,19 @@ def _prepare_weights( return hf_folder, hf_weights_files, use_safetensors def _get_weights_iterator( - self, source: "Source" + self, + source: "Source", + to_cpu: bool = True ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Get an iterator for the model weights based on the load format.""" hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( source.model_or_path, source.fall_back_to_pt, 
source.allow_patterns_overrides) if use_safetensors: - weights_iterator = safetensors_weights_iterator(hf_weights_files) + weights_iterator = safetensors_weights_iterator( + hf_weights_files, to_cpu) else: - weights_iterator = pt_weights_iterator(hf_weights_files) + weights_iterator = pt_weights_iterator(hf_weights_files, to_cpu) if self.counter_before_loading_weights == 0.0: self.counter_before_loading_weights = time.perf_counter() @@ -181,10 +187,11 @@ def _get_weights_iterator( for (name, tensor) in weights_iterator) def _get_all_weights( - self, - model_config: Any, - model: nn.Module, - model_path: str, + self, + model_config: Any, + model: nn.Module, + model_path: str, + to_cpu: bool = True ) -> Generator[Tuple[str, torch.Tensor], None, None]: primary_weights = TextEncoderLoader.Source( model_path, @@ -193,14 +200,14 @@ def _get_all_weights( allow_patterns_overrides=getattr(model, "allow_patterns_overrides", None), ) - yield from self._get_weights_iterator(primary_weights) + yield from self._get_weights_iterator(primary_weights, to_cpu) secondary_weights = cast( Iterable[TextEncoderLoader.Source], getattr(model, "secondary_weights", ()), ) for source in secondary_weights: - yield from self._get_weights_iterator(source) + yield from self._get_weights_iterator(source, to_cpu) def load(self, model_path: str, architecture: str, fastvideo_args: FastVideoArgs): @@ -233,16 +240,22 @@ def load(self, model_path: str, architecture: str, encoder_precision = fastvideo_args.pipeline_config.text_encoder_precisions[ 1] - target_device = get_torch_device() + target_device = get_local_torch_device() # TODO(will): add support for other dtypes return self.load_model(model_path, encoder_config, target_device, - encoder_precision) + fastvideo_args, encoder_precision) def load_model(self, model_path: str, model_config: EncoderConfig, target_device: torch.device, + fastvideo_args: FastVideoArgs, dtype: str = "fp16"): + use_cpu_offload = fastvideo_args.text_encoder_offload and len( + getattr(model_config, "_fsdp_shard_conditions", [])) > 0 + + if fastvideo_args.text_encoder_offload: + target_device = torch.device("cpu") with set_default_torch_dtype(PRECISION_TO_TYPE[dtype]): with target_device: architectures = getattr(model_config, "architectures", []) @@ -251,12 +264,26 @@ def load_model(self, weights_to_load = {name for name, _ in model.named_parameters()} loaded_weights = model.load_weights( - self._get_all_weights(model_config, model, model_path)) + self._get_all_weights(model_config, model, model_path, + use_cpu_offload)) self.counter_after_loading_weights = time.perf_counter() logger.info( "Loading weights took %.2f seconds", self.counter_after_loading_weights - self.counter_before_loading_weights) + + if use_cpu_offload: + mesh = init_device_mesh( + "cuda", + mesh_shape=(1, dist.get_world_size()), + mesh_dim_names=("offload", "replicate"), + ) + shard_model(model, + cpu_offload=True, + reshard_after_forward=True, + mesh=mesh["offload"], + fsdp_shard_conditions=model._fsdp_shard_conditions, + pin_cpu_memory=fastvideo_args.pin_cpu_memory) # We only enable strict check for non-quantized models # that have loaded weights tracking currently. 
# if loaded_weights is not None: @@ -290,10 +317,10 @@ def load(self, model_path: str, architecture: str, encoder_config = fastvideo_args.pipeline_config.image_encoder_config encoder_config.update_model_arch(model_config) - target_device = get_torch_device() + target_device = get_local_torch_device() # TODO(will): add support for other dtypes return self.load_model( - model_path, encoder_config, target_device, + model_path, encoder_config, target_device, fastvideo_args, fastvideo_args.pipeline_config.image_encoder_precision) @@ -346,7 +373,7 @@ def load(self, model_path: str, architecture: str, with set_default_torch_dtype(PRECISION_TO_TYPE[ fastvideo_args.pipeline_config.vae_precision]): vae_cls, _ = ModelRegistry.resolve_model_cls(class_name) - vae = vae_cls(vae_config).to(get_torch_device()) + vae = vae_cls(vae_config).to(get_local_torch_device()) # Find all safetensors files safetensors_list = glob.glob( @@ -405,7 +432,7 @@ def load(self, model_path: str, architecture: str, "hf_config": hf_config }, weight_dir_list=safetensors_list, - device=get_torch_device(), + device=get_local_torch_device(), hsdp_replicate_dim=fastvideo_args.hsdp_replicate_dim, hsdp_shard_dim=fastvideo_args.hsdp_shard_dim, cpu_offload=fastvideo_args.use_cpu_offload, diff --git a/fastvideo/v1/models/loader/fsdp_load.py b/fastvideo/v1/models/loader/fsdp_load.py index a9c890f69..a8b91d5b9 100644 --- a/fastvideo/v1/models/loader/fsdp_load.py +++ b/fastvideo/v1/models/loader/fsdp_load.py @@ -69,6 +69,7 @@ def maybe_load_fsdp_model( fsdp_inference: bool = False, output_dtype: Optional[torch.dtype] = None, training_mode: bool = True, + pin_cpu_memory: bool = True, ) -> torch.nn.Module: """ Load the model with FSDP if is training, else load the model without FSDP. @@ -101,9 +102,12 @@ def maybe_load_fsdp_model( cpu_offload=cpu_offload, reshard_after_forward=True, mp_policy=mp_policy, - mesh=device_mesh) + mesh=device_mesh, + fsdp_shard_conditions=model._fsdp_shard_conditions, + pin_cpu_memory=pin_cpu_memory) - weight_iterator = safetensors_weights_iterator(weight_dir_list) + weight_iterator = safetensors_weights_iterator( + weight_dir_list, to_cpu=cpu_offload, async_broadcast=not cpu_offload) param_names_mapping_fn = get_param_names_mapping(model._param_names_mapping) load_model_from_full_model_state_dict( model, @@ -126,12 +130,13 @@ def maybe_load_fsdp_model( def shard_model( model, - *, cpu_offload: bool, reshard_after_forward: bool = True, - mp_policy: Optional[MixedPrecisionPolicy] = None, - dp_mesh: Optional[DeviceMesh] = None, + mp_policy: Optional[MixedPrecisionPolicy] = MixedPrecisionPolicy(), # noqa mesh: Optional[DeviceMesh] = None, + fsdp_shard_conditions: Optional[List[Callable[[str, nn.Module], + bool]]] = None, + pin_cpu_memory: bool = True, ) -> None: """ Utility to shard a model with FSDP using the PyTorch Distributed fully_shard API. @@ -150,19 +155,28 @@ def shard_model( reshard_after_forward (bool): Whether to reshard parameters and buffers after the forward pass. Setting this to True corresponds to the FULL_SHARD sharding strategy from FSDP1, while setting it to False corresponds to the SHARD_GRAD_OP sharding strategy. - dp_mesh (Optional[DeviceMesh]): Device mesh to use for FSDP sharding under multiple parallelism. + mesh (Optional[DeviceMesh]): Device mesh to use for FSDP sharding under multiple parallelism. Default to None. + fsdp_shard_conditions (Optional[List[Callable[[str, nn.Module], bool]]]): A list of functions to determine + which modules to shard with FSDP. 
Raises: ValueError: If no layer modules were sharded, indicating that no shard_condition was triggered. """ + if fsdp_shard_conditions is None or len(fsdp_shard_conditions) == 0: + logger.warning( + "The FSDP shard condition list is empty or None. No modules will be sharded in %s", + type(model).__name__) + return + fsdp_kwargs = { "reshard_after_forward": reshard_after_forward, "mesh": mesh, "mp_policy": mp_policy, } if cpu_offload: - fsdp_kwargs["offload_policy"] = CPUOffloadPolicy() + fsdp_kwargs["offload_policy"] = CPUOffloadPolicy( + pin_memory=pin_cpu_memory) # iterating in reverse to start with # lowest-level modules first @@ -172,7 +186,7 @@ for n, m in reversed(list(model.named_modules())): if any([ shard_condition(n, m) - for shard_condition in model._fsdp_shard_conditions + for shard_condition in fsdp_shard_conditions ]): fully_shard(m, **fsdp_kwargs) num_layers_sharded += 1 @@ -181,7 +195,6 @@ raise ValueError( "No layer modules were sharded. Please check if shard conditions are working as expected." ) - # Finally shard the entire model to account for any stragglers fully_shard(model, **fsdp_kwargs) @@ -224,6 +237,9 @@ def load_model_from_full_model_state_dict( to_merge_params: DefaultDict[str, Dict[Any, Any]] = defaultdict(dict) reverse_param_names_mapping = {} assert param_names_mapping is not None + + # materialize the iterator so any pending async broadcasts complete before the weights are consumed + full_sd_iterator = list(full_sd_iterator) # type: ignore for source_param_name, full_tensor in full_sd_iterator: target_param_name, merge_index, num_params_to_merge = param_names_mapping( source_param_name) diff --git a/fastvideo/v1/models/loader/weight_utils.py b/fastvideo/v1/models/loader/weight_utils.py index b939ab5c5..00e1b4972 100644 --- a/fastvideo/v1/models/loader/weight_utils.py +++ b/fastvideo/v1/models/loader/weight_utils.py @@ -11,9 +11,11 @@ import filelock import huggingface_hub.constants import torch +import torch.distributed as dist from safetensors.torch import safe_open from tqdm.auto import tqdm +from fastvideo.v1.distributed.parallel_state import get_node_group from fastvideo.v1.logger import init_logger logger = init_logger(__name__) @@ -118,36 +120,79 @@ def filter_files_not_needed_for_inference( def safetensors_weights_iterator( - hf_weights_files: List[str] + hf_weights_files: List[str], + to_cpu: bool = False, + async_broadcast: bool = False ) -> Generator[Tuple[str, torch.Tensor], None, None]: - """Iterate over the weights in the model safetensor files.""" - enable_tqdm = not torch.distributed.is_initialized( ) or torch.distributed.get_rank() == 0 + """Iterate over the weights in the model safetensor files. + Args: + hf_weights_files: List of safetensor files to load. + to_cpu: Whether to load the weights to CPU. If False, will load to the GPU device bound to the current process. + async_broadcast: Whether to overlap loading from disk with broadcasting to other ranks. If True, + the iterator must be fully consumed before any yielded weight is used. Only use if to_cpu is False.
+ """ + local_rank = get_node_group().rank + device = f"cuda:{local_rank}" if not to_cpu else "cpu" + enable_tqdm = not torch.distributed.is_initialized() or get_node_group( + ).rank == 0 + assert not (async_broadcast + and to_cpu), "Cannot broadcast weights when loading to CPU" + + handles = [] for st_file in tqdm( hf_weights_files, desc="Loading safetensors checkpoint shards", disable=not enable_tqdm, bar_format=_BAR_FORMAT, ): - with safe_open(st_file, framework="pt") as f: + with safe_open(st_file, framework="pt", device=device) as f: for name in f.keys(): # noqa: SIM118 - param = f.get_tensor(name) + if to_cpu: + param = f.get_tensor(name) + else: + if local_rank == 0: + param = f.get_tensor(name) + else: + shape = f.get_slice(name).get_shape() + param = torch.empty(shape, device=device) + # broadcast to local ranks + # TODO(Wenxuan): scatter instead of broadcast + if get_node_group().world_size > 1: + group = get_node_group().device_group + if async_broadcast: + handle = dist.broadcast(param, + src=dist.get_global_rank( + group, 0), + async_op=True, + group=group) + handles.append(handle) + else: + dist.broadcast(param, + src=dist.get_global_rank(group, 0), + group=group) yield name, param + if async_broadcast: + for handle in handles: + handle.wait() + def pt_weights_iterator( - hf_weights_files: List[str] + hf_weights_files: List[str], + to_cpu: bool = True # default to CPU for text encoder ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model bin/pt files.""" - enable_tqdm = not torch.distributed.is_initialized( - ) or torch.distributed.get_rank() == 0 + local_rank = get_node_group().rank + device = f"cuda:{local_rank}" if not to_cpu else "cpu" + enable_tqdm = not torch.distributed.is_initialized() or get_node_group( + ).rank == 0 for bin_file in tqdm( hf_weights_files, desc="Loading pt checkpoint shards", disable=not enable_tqdm, bar_format=_BAR_FORMAT, ): - state = torch.load(bin_file, map_location="cpu", weights_only=True) + state = torch.load(bin_file, map_location=device, weights_only=True) yield from state.items() del state diff --git a/fastvideo/v1/pipelines/preprocess/preprocess_pipeline_base.py b/fastvideo/v1/pipelines/preprocess/preprocess_pipeline_base.py index 5866c46ee..8db6f1d28 100644 --- a/fastvideo/v1/pipelines/preprocess/preprocess_pipeline_base.py +++ b/fastvideo/v1/pipelines/preprocess/preprocess_pipeline_base.py @@ -18,7 +18,7 @@ from fastvideo.v1.configs.sample import SamplingParam from fastvideo.v1.dataset import ValidationDataset, getdataset from fastvideo.v1.dataset.preprocessing_datasets import PreprocessBatch -from fastvideo.v1.distributed import get_torch_device +from fastvideo.v1.distributed import get_local_torch_device from fastvideo.v1.fastvideo_args import FastVideoArgs from fastvideo.v1.logger import init_logger from fastvideo.v1.pipelines.composed_pipeline_base import ComposedPipelineBase @@ -328,7 +328,8 @@ def preprocess_video_and_text(self, fastvideo_args: FastVideoArgs, args): # VAE with torch.autocast("cuda", dtype=torch.float32): latents = self.get_module("vae").encode( - valid_data["pixel_values"].to(get_torch_device())).mean + valid_data["pixel_values"].to( + get_local_torch_device())).mean # Get extra features if needed extra_features = self.get_extra_features( diff --git a/fastvideo/v1/pipelines/preprocess/preprocess_pipeline_i2v.py b/fastvideo/v1/pipelines/preprocess/preprocess_pipeline_i2v.py index 58ba09bef..286aa9dfd 100644 --- a/fastvideo/v1/pipelines/preprocess/preprocess_pipeline_i2v.py 
+++ b/fastvideo/v1/pipelines/preprocess/preprocess_pipeline_i2v.py @@ -13,7 +13,7 @@ from PIL import Image from fastvideo.v1.dataset.dataloader.schema import pyarrow_schema_i2v -from fastvideo.v1.distributed import get_torch_device +from fastvideo.v1.distributed import get_local_torch_device from fastvideo.v1.fastvideo_args import FastVideoArgs from fastvideo.v1.forward_context import set_forward_context from fastvideo.v1.models.vision_utils import (get_default_height_width, @@ -82,8 +82,8 @@ def get_extra_features(self, valid_data: Dict[str, Any], fastvideo_args: FastVideoArgs) -> Dict[str, Any]: # TODO(will): move these to cpu at some point - self.get_module("image_encoder").to(get_torch_device()) - self.get_module("vae").to(get_torch_device()) + self.get_module("image_encoder").to(get_local_torch_device()) + self.get_module("vae").to(get_local_torch_device()) features = {} """Get CLIP features from the first frame of each video.""" @@ -107,7 +107,7 @@ def get_extra_features(self, valid_data: Dict[str, Any], # Get CLIP features pixel_values = torch.cat( [img['pixel_values'] for img in processed_images], - dim=0).to(get_torch_device()) + dim=0).to(get_local_torch_device()) with torch.no_grad(): image_inputs = {'pixel_values': pixel_values} with set_forward_context(current_timestep=0, attn_metadata=None): @@ -129,8 +129,8 @@ def get_extra_features(self, valid_data: Dict[str, Any], height, width) ], dim=2) - video_condition = video_condition.to(device=get_torch_device(), - dtype=torch.float32) + video_condition = video_condition.to( + device=get_local_torch_device(), dtype=torch.float32) video_conditions.append(video_condition) video_conditions = torch.cat(video_conditions, dim=0) diff --git a/fastvideo/v1/pipelines/stages/decoding.py b/fastvideo/v1/pipelines/stages/decoding.py index ea75f7473..2043668b6 100644 --- a/fastvideo/v1/pipelines/stages/decoding.py +++ b/fastvideo/v1/pipelines/stages/decoding.py @@ -5,7 +5,7 @@ import torch -from fastvideo.v1.distributed import get_torch_device +from fastvideo.v1.distributed import get_local_torch_device from fastvideo.v1.fastvideo_args import FastVideoArgs from fastvideo.v1.logger import init_logger from fastvideo.v1.models.vaes.common import ParallelTiledVAE @@ -61,7 +61,7 @@ def forward( Returns: The batch with decoded outputs. 
""" - self.vae = self.vae.to(get_torch_device()) + self.vae = self.vae.to(get_local_torch_device()) latents = batch.latents # TODO(will): remove this once we add input/output validation for stages diff --git a/fastvideo/v1/pipelines/stages/denoising.py b/fastvideo/v1/pipelines/stages/denoising.py index 2070bf93e..6fc7b79c2 100644 --- a/fastvideo/v1/pipelines/stages/denoising.py +++ b/fastvideo/v1/pipelines/stages/denoising.py @@ -12,8 +12,9 @@ from fastvideo.v1.attention import get_attn_backend from fastvideo.v1.configs.pipelines.base import STA_Mode -from fastvideo.v1.distributed import (get_sp_parallel_rank, get_sp_world_size, - get_torch_device, get_world_group) +from fastvideo.v1.distributed import (get_local_torch_device, + get_sp_parallel_rank, get_sp_world_size, + get_world_group) from fastvideo.v1.distributed.communication_op import ( sequence_model_parallel_all_gather) from fastvideo.v1.fastvideo_args import FastVideoArgs @@ -192,7 +193,7 @@ def forward( [fastvideo_args.pipeline_config.embedded_cfg_scale] * latent_model_input.shape[0], dtype=torch.float32, - device=get_torch_device(), + device=get_local_torch_device(), ).to(target_dtype) * 1000.0 if fastvideo_args.pipeline_config.embedded_cfg_scale is not None else None) diff --git a/fastvideo/v1/pipelines/stages/encoding.py b/fastvideo/v1/pipelines/stages/encoding.py index 33bd76dca..410dc2aee 100644 --- a/fastvideo/v1/pipelines/stages/encoding.py +++ b/fastvideo/v1/pipelines/stages/encoding.py @@ -7,7 +7,7 @@ import PIL.Image import torch -from fastvideo.v1.distributed import get_torch_device +from fastvideo.v1.distributed import get_local_torch_device from fastvideo.v1.fastvideo_args import FastVideoArgs from fastvideo.v1.logger import init_logger from fastvideo.v1.models.vaes.common import ParallelTiledVAE @@ -49,7 +49,7 @@ def forward( Returns: The batch with encoded outputs. """ - self.vae = self.vae.to(get_torch_device()) + self.vae = self.vae.to(get_local_torch_device()) assert batch.height is not None assert batch.width is not None @@ -65,7 +65,8 @@ def forward( image, vae_scale_factor=self.vae.spatial_compression_ratio, height=batch.height, - width=batch.width).to(get_torch_device(), dtype=torch.float32) + width=batch.width).to(get_local_torch_device(), + dtype=torch.float32) image = image.unsqueeze(2) else: @@ -78,7 +79,7 @@ def forward( batch.num_frames - 1, batch.height, batch.width) ], dim=2) - video_condition = video_condition.to(device=get_torch_device(), + video_condition = video_condition.to(device=get_local_torch_device(), dtype=torch.float32) # Setup VAE precision diff --git a/fastvideo/v1/pipelines/stages/image_encoding.py b/fastvideo/v1/pipelines/stages/image_encoding.py index 27cd03605..1dd3f87ab 100644 --- a/fastvideo/v1/pipelines/stages/image_encoding.py +++ b/fastvideo/v1/pipelines/stages/image_encoding.py @@ -7,7 +7,7 @@ import torch -from fastvideo.v1.distributed import get_torch_device +from fastvideo.v1.distributed import get_local_torch_device from fastvideo.v1.fastvideo_args import FastVideoArgs from fastvideo.v1.forward_context import set_forward_context from fastvideo.v1.logger import init_logger @@ -55,12 +55,12 @@ def forward( The batch with encoded prompt embeddings. 
""" if fastvideo_args.use_cpu_offload: - self.image_encoder = self.image_encoder.to(get_torch_device()) + self.image_encoder = self.image_encoder.to(get_local_torch_device()) image = batch.pil_image image_inputs = self.image_processor( - images=image, return_tensors="pt").to(get_torch_device()) + images=image, return_tensors="pt").to(get_local_torch_device()) with set_forward_context(current_timestep=0, attn_metadata=None): outputs = self.image_encoder(**image_inputs) image_embeds = outputs.last_hidden_state diff --git a/fastvideo/v1/pipelines/stages/latent_preparation.py b/fastvideo/v1/pipelines/stages/latent_preparation.py index 2926a53bd..2142edc4e 100644 --- a/fastvideo/v1/pipelines/stages/latent_preparation.py +++ b/fastvideo/v1/pipelines/stages/latent_preparation.py @@ -5,7 +5,7 @@ from diffusers.utils.torch_utils import randn_tensor -from fastvideo.v1.distributed import get_torch_device +from fastvideo.v1.distributed import get_local_torch_device from fastvideo.v1.fastvideo_args import FastVideoArgs from fastvideo.v1.logger import init_logger from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch @@ -62,7 +62,7 @@ def forward( # Get required parameters dtype = batch.prompt_embeds[0].dtype - device = get_torch_device() + device = get_local_torch_device() generator = batch.generator latents = batch.latents num_frames = latent_num_frames if latent_num_frames is not None else batch.num_frames diff --git a/fastvideo/v1/pipelines/stages/text_encoding.py b/fastvideo/v1/pipelines/stages/text_encoding.py index 4bf4ef2f4..0e5e14125 100644 --- a/fastvideo/v1/pipelines/stages/text_encoding.py +++ b/fastvideo/v1/pipelines/stages/text_encoding.py @@ -5,9 +5,7 @@ This module contains implementations of prompt encoding stages for diffusion pipelines. 
""" -import torch - -from fastvideo.v1.distributed import get_torch_device +from fastvideo.v1.distributed import get_local_torch_device from fastvideo.v1.fastvideo_args import FastVideoArgs from fastvideo.v1.forward_context import set_forward_context from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch @@ -62,8 +60,6 @@ def forward( fastvideo_args.pipeline_config.text_encoder_configs, fastvideo_args.pipeline_config.preprocess_text_funcs, fastvideo_args.pipeline_config.postprocess_text_funcs): - if fastvideo_args.use_cpu_offload: - text_encoder = text_encoder.to(get_torch_device()) assert isinstance(batch.prompt, (str, list)) if isinstance(batch.prompt, str): @@ -71,8 +67,9 @@ def forward( texts = [] for prompt_str in batch.prompt: texts.append(preprocess_func(prompt_str)) - text_inputs = tokenizer( - texts, **encoder_config.tokenizer_kwargs).to(get_torch_device()) + text_inputs = tokenizer(texts, + **encoder_config.tokenizer_kwargs).to( + get_local_torch_device()) input_ids = text_inputs["input_ids"] attention_mask = text_inputs["attention_mask"] with set_forward_context(current_timestep=0, attn_metadata=None): @@ -91,8 +88,8 @@ def forward( assert isinstance(batch.negative_prompt, str) negative_text = preprocess_func(batch.negative_prompt) negative_text_inputs = tokenizer( - negative_text, - **encoder_config.tokenizer_kwargs).to(get_torch_device()) + negative_text, **encoder_config.tokenizer_kwargs).to( + get_local_torch_device()) negative_input_ids = negative_text_inputs["input_ids"] negative_attention_mask = negative_text_inputs["attention_mask"] with set_forward_context(current_timestep=0, @@ -110,10 +107,6 @@ def forward( batch.negative_attention_mask.append( negative_attention_mask) - if fastvideo_args.use_cpu_offload: - text_encoder.to('cpu') - torch.cuda.empty_cache() - return batch def verify_input(self, batch: ForwardBatch, diff --git a/fastvideo/v1/pipelines/stages/timestep_preparation.py b/fastvideo/v1/pipelines/stages/timestep_preparation.py index d30134a47..475a0ef31 100644 --- a/fastvideo/v1/pipelines/stages/timestep_preparation.py +++ b/fastvideo/v1/pipelines/stages/timestep_preparation.py @@ -7,7 +7,7 @@ import inspect -from fastvideo.v1.distributed import get_torch_device +from fastvideo.v1.distributed import get_local_torch_device from fastvideo.v1.fastvideo_args import FastVideoArgs from fastvideo.v1.logger import init_logger from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch @@ -45,7 +45,7 @@ def forward( The batch with prepared timesteps. """ scheduler = self.scheduler - device = get_torch_device() + device = get_local_torch_device() num_inference_steps = batch.num_inference_steps timesteps = batch.timesteps sigmas = batch.sigmas diff --git a/fastvideo/v1/pipelines/stepvideo/stepvideo_pipeline.py b/fastvideo/v1/pipelines/stepvideo/stepvideo_pipeline.py index ba3d10172..5f9a5ab1f 100644 --- a/fastvideo/v1/pipelines/stepvideo/stepvideo_pipeline.py +++ b/fastvideo/v1/pipelines/stepvideo/stepvideo_pipeline.py @@ -14,7 +14,7 @@ import torch from huggingface_hub import hf_hub_download -from fastvideo.v1.distributed import get_torch_device +from fastvideo.v1.distributed import get_local_torch_device from fastvideo.v1.fastvideo_args import FastVideoArgs from fastvideo.v1.logger import init_logger from fastvideo.v1.models.encoders.bert import HunyuanClip # type: ignore @@ -78,7 +78,7 @@ def initialize_pipeline(self, fastvideo_args: FastVideoArgs): """ Initialize the pipeline. 
""" - target_device = get_torch_device() + target_device = get_local_torch_device() llm_dir = os.path.join(self.model_path, "step_llm") clip_dir = os.path.join(self.model_path, "hunyuan_clip") text_enc = self.build_llm(llm_dir, target_device) diff --git a/fastvideo/v1/tests/encoders/test_clip_encoder.py b/fastvideo/v1/tests/encoders/test_clip_encoder.py index 9a65e87b4..5c185ae0c 100644 --- a/fastvideo/v1/tests/encoders/test_clip_encoder.py +++ b/fastvideo/v1/tests/encoders/test_clip_encoder.py @@ -6,7 +6,7 @@ import pytest import torch from transformers import AutoConfig - +import gc from fastvideo.models.hunyuan.text_encoder import (load_text_encoder, load_tokenizer) # from fastvideo.v1.models.hunyuan.text_encoder import load_text_encoder, load_tokenizer @@ -16,6 +16,8 @@ from fastvideo.v1.logger import init_logger from fastvideo.v1.utils import maybe_download_model from fastvideo.v1.configs.models.encoders import CLIPTextConfig +from torch.distributed.tensor import DTensor +from torch.testing import assert_close logger = init_logger(__name__) @@ -66,7 +68,6 @@ def test_clip_encoder(): # Load the HuggingFace implementation directly # model2 = CLIPTextModel(hf_config) # model2 = model2.to(torch.float16) - model2 = model2.to(device) model2.eval() # Sanity check weights between the two models @@ -78,19 +79,20 @@ def test_clip_encoder(): logger.info("Model1 has %d parameters", len(params1)) logger.info("Model2 has %d parameters", len(params2)) - # Compare a few key parameters - - # weight_diffs = [] - # for (name1, param1), (name2, param2) in zip( - # sorted(params1.items()), sorted(params2.items()) - # ): - # # if len(weight_diffs) < 5: # Just check a few parameters - # max_diff = torch.max(torch.abs(param1 - param2)).item() - # mean_diff = torch.mean(torch.abs(param1 - param2)).item() - # weight_diffs.append((name1, name2, max_diff, mean_diff)) - # logger.info(f"Parameter: {name1} vs {name2}") - # logger.info(f" Max diff: {max_diff}, Mean diff: {mean_diff}") - + for name1, param1 in sorted(params1.items()): + name2 = name1 + skip = False + for param_name, weight_name, shard_id in model2.config.arch_config.stacked_params_mapping: + if weight_name not in name1: + skip = True + # stacked params are more troublesome + if skip: + continue + param2 = params2[name2] + param2 = param2.to_local().to(device) if isinstance(param2, DTensor) else param2.to(device) + assert_close(param1, param2, atol=1e-4, rtol=1e-4) + gc.collect() + torch.cuda.empty_cache() # Load tokenizer tokenizer, _ = load_tokenizer(tokenizer_type="clipL", tokenizer_path=args.model_path, diff --git a/fastvideo/v1/tests/encoders/test_llama_encoder.py b/fastvideo/v1/tests/encoders/test_llama_encoder.py index 9848d8588..488f3eddc 100644 --- a/fastvideo/v1/tests/encoders/test_llama_encoder.py +++ b/fastvideo/v1/tests/encoders/test_llama_encoder.py @@ -5,7 +5,7 @@ import pytest import torch from transformers import AutoConfig - +import gc from fastvideo.models.hunyuan.text_encoder import (load_text_encoder, load_tokenizer) from fastvideo.v1.configs.pipelines import PipelineConfig @@ -15,7 +15,8 @@ from fastvideo.v1.models.loader.component_loader import TextEncoderLoader from fastvideo.v1.utils import maybe_download_model from fastvideo.v1.configs.models.encoders import LlamaConfig - +from torch.distributed.tensor import DTensor +from torch.testing import assert_close logger = init_logger(__name__) os.environ["MASTER_ADDR"] = "localhost" @@ -62,7 +63,6 @@ def test_llama_encoder(): # Convert to float16 and move to device # model2 = 
model2.to(torch.float16) - model2 = model2.to(device) model2.eval() # Sanity check weights between the two models @@ -77,34 +77,28 @@ def test_llama_encoder(): # Compare a few key parameters weight_diffs = [] # check if embed_tokens are the same - print(model1.embed_tokens.weight.shape, model2.embed_tokens.weight.shape) + device = model1.embed_tokens.weight.device assert torch.allclose(model1.embed_tokens.weight, - model2.embed_tokens.weight) + model2.embed_tokens.weight.to_local().to(device) if isinstance(model2.embed_tokens.weight, DTensor) else model2.embed_tokens.weight.to(device)) weights = [ "layers.{}.input_layernorm.weight", "layers.{}.post_attention_layernorm.weight" ] - # for (name1, param1), (name2, param2) in zip( - # sorted(params1.items()), sorted(params2.items()) - # ): - for layer_idx in range(hf_config.num_hidden_layers): - for w in weights: - name1 = w.format(layer_idx) - name2 = w.format(layer_idx) - p1 = params1[name1] - p2 = params2[name2] - # print(type(p2)) - if "gate_up" in name2: - # print("skipping gate_up") - continue - try: - # logger.info(f"Parameter: {name1} vs {name2}") - max_diff = torch.max(torch.abs(p1 - p2)).item() - mean_diff = torch.mean(torch.abs(p1 - p2)).item() - weight_diffs.append((name1, name2, max_diff, mean_diff)) - # logger.info(f" Max diff: {max_diff}, Mean diff: {mean_diff}") - except Exception as e: - logger.info("Error comparing %s and %s: %s", name1, name2, e) + + for name1, param1 in sorted(params1.items()): + name2 = name1 + skip = False + for param_name, weight_name, shard_id in model2.config.arch_config.stacked_params_mapping: + if weight_name in name1: + skip = True + # stacked params are fused at load time and need a separate comparison + if skip: + continue + param2 = params2[name2] + param2 = param2.to_local().to(device) if isinstance(param2, DTensor) else param2.to(device) + assert_close(param1, param2, atol=1e-4, rtol=1e-4) + gc.collect() + torch.cuda.empty_cache() tokenizer, _ = load_tokenizer(tokenizer_type="llm", tokenizer_path=TOKENIZER_PATH, diff --git a/fastvideo/v1/tests/encoders/test_t5_encoder.py b/fastvideo/v1/tests/encoders/test_t5_encoder.py index 9ff3c4c8a..8c9a616f9 100644 --- a/fastvideo/v1/tests/encoders/test_t5_encoder.py +++ b/fastvideo/v1/tests/encoders/test_t5_encoder.py @@ -4,6 +4,8 @@ import numpy as np import pytest import torch +from torch.distributed.tensor import DTensor +from torch.testing import assert_close from transformers import AutoConfig, AutoTokenizer, UMT5EncoderModel from fastvideo.v1.configs.pipelines import PipelineConfig @@ -41,13 +43,13 @@ def test_t5_encoder(): tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH) - args = FastVideoArgs(model_path=TEXT_ENCODER_PATH, pipeline_config=PipelineConfig(text_encoder_configs=(T5Config(),), text_encoder_precisions=(precision_str,))) + args = FastVideoArgs(model_path=TEXT_ENCODER_PATH, + pipeline_config=PipelineConfig(text_encoder_configs=(T5Config(),), + text_encoder_precisions=(precision_str,)), + pin_cpu_memory=False) loader = TextEncoderLoader() model2 = loader.load(TEXT_ENCODER_PATH, "", args) - - # Convert to float16 and move to device - # model2 = model2.to(precision) - model2 = model2.to(device) + model2 = model2.to(precision) model2.eval() # Sanity check weights between the two models @@ -64,23 +66,17 @@ def test_t5_encoder(): weights = ["encoder.block.{}.layer.0.layer_norm.weight", "encoder.block.{}.layer.0.SelfAttention.relative_attention_bias.weight", \ "encoder.block.{}.layer.0.SelfAttention.o.weight", "encoder.block.{}.layer.1.DenseReluDense.wi_0.weight",
"encoder.block.{}.layer.1.DenseReluDense.wi_1.weight",\ "encoder.block.{}.layer.1.DenseReluDense.wo.weight", \ - "encoder.block.{}.layer.1.layer_norm.weight", "encoder.final_layer_norm.weight", "shared.weight"] + "encoder.block.{}.layer.1.layer_norm.weight", "encoder.final_layer_norm.weight"] + for idx in range(hf_config.num_hidden_layers): for w in weights: name1 = w.format(idx) name2 = w.format(idx) p1 = params1[name1] p2 = params2[name2] - assert p1.dtype == p2.dtype - try: - logger.info("Parameter: %s vs %s", name1, name2) - max_diff = torch.max(torch.abs(p1 - p2)).item() - mean_diff = torch.mean(torch.abs(p1 - p2)).item() - weight_diffs.append((name1, name2, max_diff, mean_diff)) - logger.info(" Max diff: %s, Mean diff: %s", max_diff, - mean_diff) - except Exception as e: - logger.info("Error comparing %s and %s: %s", name1, name2, e) + p2 = (p2.to_local() if isinstance(p2, DTensor) else p2).to(p1) + assert_close(p1, p2, atol=1e-4, rtol=1e-4) + # Test with some sample prompts prompts = [ diff --git a/fastvideo/v1/tests/transformers/test_hunyuanvideo.py b/fastvideo/v1/tests/transformers/test_hunyuanvideo.py index 24e08cdb3..73eb39d47 100644 --- a/fastvideo/v1/tests/transformers/test_hunyuanvideo.py +++ b/fastvideo/v1/tests/transformers/test_hunyuanvideo.py @@ -80,7 +80,10 @@ def test_hunyuanvideo_distributed(): # Initialize with identical weights model = initialize_identical_weights(model, seed=42) - shard_model(model, cpu_offload=False, reshard_after_forward=True) + shard_model(model, cpu_offload=True, + reshard_after_forward=True, + fsdp_shard_conditions=model._fsdp_shard_conditions + ) for n, p in chain(model.named_parameters(), model.named_buffers()): if p.is_meta: raise RuntimeError( diff --git a/fastvideo/v1/training/training_pipeline.py b/fastvideo/v1/training/training_pipeline.py index 901487d61..a57c1d8bc 100644 --- a/fastvideo/v1/training/training_pipeline.py +++ b/fastvideo/v1/training/training_pipeline.py @@ -24,8 +24,9 @@ from fastvideo.v1.dataset import build_parquet_map_style_dataloader from fastvideo.v1.dataset.dataloader.schema import ( pyarrow_schema_t2v, pyarrow_schema_t2v_validation) -from fastvideo.v1.distributed import (cleanup_dist_env_and_memory, get_sp_group, - get_torch_device, get_world_group) +from fastvideo.v1.distributed import (cleanup_dist_env_and_memory, + get_local_torch_device, get_sp_group, + get_world_group) from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs from fastvideo.v1.forward_context import set_forward_context from fastvideo.v1.logger import init_logger @@ -65,8 +66,8 @@ def set_schemas(self) -> None: def initialize_training_pipeline(self, training_args: TrainingArgs): logger.info("Initializing training pipeline...") + self.device = get_local_torch_device() self.training_args = training_args - self.device = get_torch_device() world_group = get_world_group() self.world_size = world_group.world_size self.global_rank = world_group.rank @@ -176,12 +177,12 @@ def _get_next_batch(self, training_batch: TrainingBatch) -> TrainingBatch: encoder_attention_mask = batch['text_attention_mask'] infos = batch['info_list'] - training_batch.latents = latents.to(get_torch_device(), + training_batch.latents = latents.to(get_local_torch_device(), dtype=torch.bfloat16) training_batch.encoder_hidden_states = encoder_hidden_states.to( - get_torch_device(), dtype=torch.bfloat16) + get_local_torch_device(), dtype=torch.bfloat16) training_batch.encoder_attention_mask = encoder_attention_mask.to( - get_torch_device(), dtype=torch.bfloat16) + 
get_local_torch_device(), dtype=torch.bfloat16) training_batch.infos = infos return training_batch @@ -274,7 +275,7 @@ def _build_input_kwargs(self, "encoder_hidden_states": training_batch.encoder_hidden_states, "timestep": - training_batch.timesteps.to(get_torch_device(), + training_batch.timesteps.to(get_local_torch_device(), dtype=torch.bfloat16), "encoder_attention_mask": training_batch.encoder_attention_mask, @@ -551,8 +552,9 @@ def _prepare_validation_inputs( prompt_embeds = validation_batch['text_embedding'] prompt_attention_mask = validation_batch['text_attention_mask'] - prompt_embeds = prompt_embeds.to(get_torch_device()) - prompt_attention_mask = prompt_attention_mask.to(get_torch_device()) + prompt_embeds = prompt_embeds.to(get_local_torch_device()) + prompt_attention_mask = prompt_attention_mask.to( + get_local_torch_device()) # Calculate sizes latents_size = [(sampling_param.num_frames - 1) // 4 + 1, diff --git a/fastvideo/v1/training/wan_i2v_training_pipeline.py b/fastvideo/v1/training/wan_i2v_training_pipeline.py index b58cfa26a..1c5475e75 100644 --- a/fastvideo/v1/training/wan_i2v_training_pipeline.py +++ b/fastvideo/v1/training/wan_i2v_training_pipeline.py @@ -8,7 +8,7 @@ from fastvideo.v1.configs.sample import SamplingParam from fastvideo.v1.dataset.dataloader.schema import ( pyarrow_schema_i2v, pyarrow_schema_i2v_validation) -from fastvideo.v1.distributed import get_torch_device +from fastvideo.v1.distributed import get_local_torch_device from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs from fastvideo.v1.logger import init_logger from fastvideo.v1.models.schedulers.scheduling_flow_unipc_multistep import ( @@ -85,15 +85,17 @@ def _get_next_batch(self, training_batch: TrainingBatch) -> TrainingBatch: pil_image = batch['pil_image'] infos = batch['info_list'] - training_batch.latents = latents.to(get_torch_device(), + training_batch.latents = latents.to(get_local_torch_device(), dtype=torch.bfloat16) training_batch.encoder_hidden_states = encoder_hidden_states.to( - get_torch_device(), dtype=torch.bfloat16) + get_local_torch_device(), dtype=torch.bfloat16) training_batch.encoder_attention_mask = encoder_attention_mask.to( - get_torch_device(), dtype=torch.bfloat16) - training_batch.preprocessed_image = pil_image.to(get_torch_device()) - training_batch.image_embeds = clip_features.to(get_torch_device()) - training_batch.image_latents = image_latents.to(get_torch_device()) + get_local_torch_device(), dtype=torch.bfloat16) + training_batch.preprocessed_image = pil_image.to( + get_local_torch_device()) + training_batch.image_embeds = clip_features.to(get_local_torch_device()) + training_batch.image_latents = image_latents.to( + get_local_torch_device()) training_batch.infos = infos return training_batch @@ -112,8 +114,8 @@ def _prepare_dit_inputs(self, training_batch = super()._prepare_dit_inputs(training_batch) assert isinstance(training_batch.image_latents, torch.Tensor) - image_latents = training_batch.image_latents.to(get_torch_device(), - dtype=torch.bfloat16) + image_latents = training_batch.image_latents.to( + get_local_torch_device(), dtype=torch.bfloat16) training_batch.noisy_model_input = torch.cat( [training_batch.noisy_model_input, image_latents], dim=1) @@ -132,7 +134,8 @@ def _build_input_kwargs(self, # Image Embeds for conditioning image_embeds = training_batch.image_embeds assert torch.isnan(image_embeds).sum() == 0 - image_embeds = image_embeds.to(get_torch_device(), dtype=torch.bfloat16) + image_embeds = 
image_embeds.to(get_local_torch_device(), + dtype=torch.bfloat16) encoder_hidden_states_image = image_embeds # NOTE: noisy_model_input already contains concatenated image_latents from _prepare_dit_inputs @@ -142,7 +145,7 @@ def _build_input_kwargs(self, "encoder_hidden_states": training_batch.encoder_hidden_states, "timestep": - training_batch.timesteps.to(get_torch_device(), + training_batch.timesteps.to(get_local_torch_device(), dtype=torch.bfloat16), "encoder_attention_mask": training_batch.encoder_attention_mask, @@ -166,9 +169,9 @@ def _prepare_validation_inputs( infos = validation_batch['info_list'] prompt = infos[0]['prompt'] - prompt_embeds = embeddings.to(get_torch_device()) - prompt_attention_mask = masks.to(get_torch_device()) - clip_features = clip_features.to(get_torch_device()) + prompt_embeds = embeddings.to(get_local_torch_device()) + prompt_attention_mask = masks.to(get_local_torch_device()) + clip_features = clip_features.to(get_local_torch_device()) # Calculate sizes latents_size = [(sampling_param.num_frames - 1) // 4 + 1,
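
Two of the changes above combine into the node-local loading scheme: same_node_ranks now returns the rank list itself (e.g. [0, 1, 2, 3] rather than [True, True, ...]), init_node_group turns that into one process group per node, and safetensors_weights_iterator lets only the node leader touch the disk. The core pattern, reduced to a sketch (assumes an initialized distributed environment and a broadcast-capable per-node group; names here are illustrative):

    import torch
    import torch.distributed as dist

    def load_via_node_leader(shape, dtype, device, node_group, read_fn):
        """Node leader reads a tensor from disk; peers receive it by broadcast."""
        if dist.get_rank(group=node_group) == 0:
            tensor = read_fn().to(device)  # exactly one disk read per node
        else:
            tensor = torch.empty(shape, dtype=dtype, device=device)
        # translate group-rank 0 into its global rank, as the iterator above does
        dist.broadcast(tensor, src=dist.get_global_rank(node_group, 0),
                       group=node_group)
        return tensor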
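The text-encoder offload path in TextEncoderLoader.load_model uses FSDP2 in a deliberately degenerate way: the device mesh is (1, world_size) and shard_model runs over the size-1 "offload" dimension, so fully_shard never partitions the weights; it only attaches the CPUOffloadPolicy (optionally with pinned memory) that pages parameters between CPU and GPU around each forward. A standalone sketch of the same idea (the FSDP2 import path varies by torch release; torch.distributed.fsdp re-exports these names in newer versions):

    import torch.distributed as dist
    from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
    from torch.distributed.device_mesh import init_device_mesh

    def offload_only(model, pin_cpu_memory: bool = True):
        """Wrap a module with FSDP2 purely for CPU offload (no real sharding)."""
        mesh = init_device_mesh(
            "cuda",
            mesh_shape=(1, dist.get_world_size()),
            mesh_dim_names=("offload", "replicate"),
        )
        fully_shard(
            model,
            mesh=mesh["offload"],  # size-1 dim: parameters stay whole per rank
            offload_policy=CPUOffloadPolicy(pin_memory=pin_cpu_memory),
        )
        return model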
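From the user side the new knobs mirror the existing ones in examples/inference/basic/basic.py. Assuming VideoGenerator.from_pretrained forwards keyword arguments into FastVideoArgs the same way it does for use_cpu_offload and use_fsdp_inference (that forwarding is not shown in this diff), enabling the feature would look like:

    from fastvideo import VideoGenerator  # import path as used by the examples

    generator = VideoGenerator.from_pretrained(
        "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
        num_gpus=2,
        use_fsdp_inference=True,
        use_cpu_offload=False,       # keep the DiT on GPU
        text_encoder_offload=True,   # text-encoder weights live in CPU RAM
        pin_cpu_memory=True,         # pinned pages make host-to-device copies cheaper
    )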