(fake*) FP8 training support #184

Merged
merged 24 commits into main from layerwise-fp8-upcasting
Jan 14, 2025

Changes from all commits (24 commits)
0668624
remove mixed_precision
a-r-r-o-w Jan 5, 2025
472d996
update
a-r-r-o-w Jan 5, 2025
14e7ed6
make style
a-r-r-o-w Jan 5, 2025
db44554
update
a-r-r-o-w Jan 6, 2025
93a82a0
Merge branch 'main' into layerwise-fp8-upcasting
a-r-r-o-w Jan 6, 2025
fe2ed4a
better defaults for experimenting
a-r-r-o-w Jan 7, 2025
15bf3cb
Merge branch 'main' into layerwise-fp8-upcasting
a-r-r-o-w Jan 7, 2025
7d96850
fix train continuation after validation error
a-r-r-o-w Jan 8, 2025
5bad4d8
Merge branch 'main' into layerwise-fp8-upcasting
a-r-r-o-w Jan 8, 2025
8735e69
Merge branch 'main' into layerwise-fp8-upcasting
a-r-r-o-w Jan 9, 2025
778969f
Merge branch 'main' into layerwise-fp8-upcasting
a-r-r-o-w Jan 11, 2025
b35ef44
Merge branch 'main' into layerwise-fp8-upcasting
a-r-r-o-w Jan 11, 2025
c3916c2
update READMEs
a-r-r-o-w Jan 13, 2025
59db964
Merge branch 'main' into layerwise-fp8-upcasting
a-r-r-o-w Jan 13, 2025
927f2bc
Merge branch 'main' into layerwise-fp8-upcasting
a-r-r-o-w Jan 13, 2025
67bd30f
remove granularity
a-r-r-o-w Jan 14, 2025
7c351cf
update hook implementation to latest diffusers
a-r-r-o-w Jan 14, 2025
15afe73
update
a-r-r-o-w Jan 14, 2025
ef5a274
Merge branch 'main' into layerwise-fp8-upcasting
a-r-r-o-w Jan 14, 2025
f0c45fb
update
a-r-r-o-w Jan 14, 2025
ca776f4
remove unused patches
a-r-r-o-w Jan 14, 2025
605e94a
remove mixed precision in tests
a-r-r-o-w Jan 14, 2025
b0705b6
add changes lost in merge conflict resolution
a-r-r-o-w Jan 14, 2025
42e707d
update README date
a-r-r-o-w Jan 14, 2025
12 changes: 6 additions & 6 deletions README.md
@@ -12,6 +12,7 @@ FineTrainers is a work-in-progress library to support (accessible) training of v

## News

- 🔥 **2024-01-15**: Support for naive FP8 weight-casting training added! This allows training HunyuanVideo in under 24 GB up to specific resolutions.
- 🔥 **2024-01-13**: Support for T2V full-finetuning added! Thanks to @ArEnSc for taking up the initiative!
- 🔥 **2024-01-03**: Support for T2V LoRA finetuning of [CogVideoX](https://huggingface.co/docs/diffusers/main/api/pipelines/cogvideox) added!
- 🔥 **2024-12-20**: Support for T2V LoRA finetuning of [Hunyuan Video](https://huggingface.co/docs/diffusers/main/api/pipelines/hunyuan_video) added! We would like to thank @SHYuanBest for his work on a training script [here](https://github.com/huggingface/diffusers/pull/10254).
@@ -83,7 +84,6 @@ diffusion_cmd="--flow_weighting_scheme logit_normal"
# Training arguments
training_cmd="--training_type lora \
--seed 42 \
--mixed_precision bf16 \
--batch_size 1 \
--train_steps 3000 \
--rank 128 \
@@ -140,14 +140,14 @@ For inference, refer [here](./docs/training/ltx_video.md#inference). For docs re

| **Model Name** | **Tasks** | **Min. LoRA VRAM<sup>*</sup>** | **Min. Full Finetuning VRAM<sup>^</sup>** |
|:------------------------------------------------:|:-------------:|:----------------------------------:|:---------------------------------------------:|
| [LTX-Video](./docs/training/ltx_video.md) | Text-to-Video | 11 GB | 21 GB |
| [HunyuanVideo](./docs/training/hunyuan_video.md) | Text-to-Video | 42 GB | OOM |
| [CogVideoX-5b](./docs/training/cogvideox.md) | Text-to-Video | 21 GB | 53 GB |
| [LTX-Video](./docs/training/ltx_video.md) | Text-to-Video | 5 GB | 21 GB |
| [HunyuanVideo](./docs/training/hunyuan_video.md) | Text-to-Video | 32 GB | OOM |
| [CogVideoX-5b](./docs/training/cogvideox.md) | Text-to-Video | 18 GB | 53 GB |

</div>

<sub><sup>*</sup>Noted for training-only, no validation, at resolution `49x512x768`, rank 128, with pre-computation, using fp8 weights & gradient checkpointing. Pre-computation of conditions and latents may require higher limits (but typically under 16 GB).</sub><br/>
<sub><sup>^</sup>Noted for training-only, no validation, at resolution `49x512x768`, with pre-computation, using bf16 weights & gradient checkpointing.</sub>
<sub><sup>*</sup>Noted for training-only, no validation, at resolution `49x512x768`, rank 128, with pre-computation, using **FP8** weights & gradient checkpointing. Pre-computation of conditions and latents may require higher limits (but typically under 16 GB).</sub><br/>
<sub><sup>^</sup>Noted for training-only, no validation, at resolution `49x512x768`, with pre-computation, using **BF16** weights & gradient checkpointing.</sub>

If you would like to use a custom dataset, refer to the dataset preparation guide [here](./docs/dataset/README.md).

2 changes: 1 addition & 1 deletion accelerate_configs/compiled_1.yaml
@@ -11,7 +11,7 @@ enable_cpu_affinity: false
gpu_ids: '3'
machine_rank: 0
main_training_function: main
mixed_precision: fp16
mixed_precision: bf16
num_machines: 1
num_processes: 1
rdzv_backend: static
2 changes: 1 addition & 1 deletion accelerate_configs/uncompiled_1.yaml
@@ -6,7 +6,7 @@ enable_cpu_affinity: false
gpu_ids: '3'
machine_rank: 0
main_training_function: main
mixed_precision: fp16
mixed_precision: bf16
num_machines: 1
num_processes: 1
rdzv_backend: static
7 changes: 6 additions & 1 deletion docs/training/cogvideox.md
@@ -37,7 +37,6 @@ dataloader_cmd="--dataloader_num_workers 4"
# Training arguments
training_cmd="--training_type lora \
--seed 42 \
--mixed_precision bf16 \
--batch_size 1 \
--precompute_conditions \
--train_steps 1000 \
@@ -88,6 +87,12 @@ echo -ne "-------------------- Finished executing script --------------------\n\

### LoRA

<!-- TODO(aryan): Update these numbers for 49x512x768 -->

> [!NOTE]
>
> The measurements below were taken in `torch.bfloat16` precision. Memory usage can be further reduced by passing `--layerwise_upcasting_modules transformer` to the training script. This will cast the model weights to `torch.float8_e4m3fn` or `torch.float8_e5m2`, which halves the memory required for model weights. Computation is performed in the dtype set by `--transformer_dtype` (which defaults to `bf16`).
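As a rough sanity check of the "halves the memory" claim, here is a small sketch comparing per-weight storage in `bf16` versus `float8` (the parameter count is a placeholder, not the actual transformer size):

```python
# Illustrative only: compares weight-storage footprints for a hypothetical parameter count.
import torch

num_params = 5_000_000_000  # placeholder; substitute the real transformer parameter count

bytes_bf16 = num_params * torch.finfo(torch.bfloat16).bits // 8
bytes_fp8 = num_params * torch.finfo(torch.float8_e4m3fn).bits // 8

print(f"bf16 weights: {bytes_bf16 / 2**30:.1f} GiB")  # ~9.3 GiB
print(f"fp8 weights:  {bytes_fp8 / 2**30:.1f} GiB")   # ~4.7 GiB, i.e. half
```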

LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x480x720` resolutions, **with precomputation**:

```
5 changes: 4 additions & 1 deletion docs/training/hunyuan_video.md
@@ -42,7 +42,6 @@ diffusion_cmd=""
# Training arguments
training_cmd="--training_type lora \
--seed 42 \
--mixed_precision bf16 \
--batch_size 1 \
--train_steps 500 \
--rank 128 \
@@ -91,6 +90,10 @@ echo -ne "-------------------- Finished executing script --------------------\n\

### LoRA

> [!NOTE]
>
> The measurements below were taken in `torch.bfloat16` precision. Memory usage can be further reduced by passing `--layerwise_upcasting_modules transformer` to the training script. This will cast the model weights to `torch.float8_e4m3fn` or `torch.float8_e5m2`, which halves the memory required for model weights. Computation is performed in the dtype set by `--transformer_dtype` (which defaults to `bf16`).

LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolutions, **without precomputation**:

```
5 changes: 4 additions & 1 deletion docs/training/ltx_video.md
@@ -41,7 +41,6 @@ diffusion_cmd="--flow_weighting_scheme logit_normal"
# Training arguments
training_cmd="--training_type lora \
--seed 42 \
--mixed_precision bf16 \
--batch_size 1 \
--train_steps 3000 \
--rank 128 \
@@ -90,6 +89,10 @@ echo -ne "-------------------- Finished executing script --------------------\n\

### LoRA

> [!NOTE]
>
> The measurements below were taken in `torch.bfloat16` precision. Memory usage can be further reduced by passing `--layerwise_upcasting_modules transformer` to the training script. This will cast the model weights to `torch.float8_e4m3fn` or `torch.float8_e5m2`, which halves the memory required for model weights. Computation is performed in the dtype set by `--transformer_dtype` (which defaults to `bf16`).

LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolution, **without precomputation**:

```
11 changes: 7 additions & 4 deletions docs/training/optimization.md
@@ -1,9 +1,12 @@
# Memory optimizations

To lower memory requirements during training:

- `--precompute_conditions`: this precomputes the conditions and latents, and loads them as required during training, which saves a significant amount of time and memory.
- `--gradient_checkpointing`: this saves memory by recomputing activations during the backward pass.
- `--layerwise_upcasting_modules transformer`: naively casts the model weights to `torch.float8_e4m3fn` or `torch.float8_e5m2`. This halves the memory requirement for model weights. Computation is performed in the dtype set by `--transformer_dtype` (which defaults to `bf16`).
- `--use_8bit_bnb`: this is only applicable to Adam and AdamW optimizers, and makes use of 8-bit precision to store optimizer states.
- Use a DeepSpeed config to launch training (refer to [`accelerate_configs/deepspeed.yaml`](./accelerate_configs/deepspeed.yaml) as an example).
- Pass `--precompute_conditions` when launching training.
- Pass `--gradient_checkpointing` when launching training.
- Pass `--use_8bit_bnb` when launching training. Note that this is only applicable to Adam and AdamW optimizers.
- Do not perform validation/testing. This saves a significant amount of memory, which can be used to focus solely on training if you're on smaller VRAM GPUs.

We will continue to add more features that help to reduce memory consumption.
76 changes: 58 additions & 18 deletions finetrainers/args.py
@@ -43,6 +43,14 @@ class Args:
Data type for the transformer model.
vae_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
Data type for the VAE model.
layerwise_upcasting_modules (`List[str]`, defaults to `[]`):
Modules that should have fp8 storage weights but higher precision computation. Choose between ['transformer'].
layerwise_upcasting_storage_dtype (`torch.dtype`, defaults to `float8_e4m3fn`):
Data type for the layerwise upcasting storage. Choose between ['float8_e4m3fn', 'float8_e5m2'].
layerwise_upcasting_skip_modules_pattern (`List[str]`, defaults to `["patch_embed", "pos_embed", "x_embedder", "context_embedder", "^proj_in$", "^proj_out$", "norm"]`):
Modules to skip for layerwise upcasting. Layers such as normalization and modulation, when cast to fp8 precision
naively (as done in layerwise upcasting), can lead to poorer training and inference quality. We skip these layers
by default, and recommend adding more layers to the default list based on the model architecture.

DATASET ARGUMENTS
-----------------
@@ -126,8 +134,6 @@ class Args:
Type of training to perform. Choose between ['lora'].
seed (`int`, defaults to `42`):
A seed for reproducible training.
mixed_precision (`str`, defaults to `None`):
Whether to use mixed precision. Choose between ['no', 'fp8', 'fp16', 'bf16'].
batch_size (`int`, defaults to `1`):
Per-device batch size.
train_epochs (`int`, defaults to `1`):
@@ -243,6 +249,18 @@ class Args:
text_encoder_3_dtype: torch.dtype = torch.bfloat16
transformer_dtype: torch.dtype = torch.bfloat16
vae_dtype: torch.dtype = torch.bfloat16
layerwise_upcasting_modules: List[str] = []
layerwise_upcasting_storage_dtype: torch.dtype = torch.float8_e4m3fn
layerwise_upcasting_skip_modules_pattern: List[str] = [
"patch_embed",
"pos_embed",
"x_embedder",
"context_embedder",
"time_embed",
"^proj_in$",
"^proj_out$",
"norm",
]
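The skip entries are treated here as regular expressions matched against module names (an assumption suggested by the anchored `^proj_in$` / `^proj_out$` entries, not something this diff states explicitly): bare substrings such as `norm` match any module whose name contains them, while anchored patterns only match a top-level module with exactly that name. A minimal sketch:

```python
# Sketch of how the skip patterns plausibly behave; not the actual finetrainers matching code.
import re

skip_patterns = [
    "patch_embed", "pos_embed", "x_embedder", "context_embedder",
    "time_embed", "^proj_in$", "^proj_out$", "norm",
]

def should_skip(module_name: str) -> bool:
    return any(re.search(pattern, module_name) for pattern in skip_patterns)

print(should_skip("transformer_blocks.0.norm1"))    # True  -> stays in higher precision
print(should_skip("proj_in"))                       # True  -> stays in higher precision
print(should_skip("transformer_blocks.0.proj_in"))  # False -> eligible for fp8 storage
print(should_skip("transformer_blocks.0.attn1"))    # False -> eligible for fp8 storage
```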

# Dataset arguments
data_root: str = None
@@ -277,9 +295,6 @@ class Args:
# Training arguments
training_type: str = None
seed: int = 42
mixed_precision: str = (
None # TODO: consider removing later https://github.com/a-r-r-o-w/finetrainers/pull/139#discussion_r1897438414
)
batch_size: int = 1
train_epochs: int = 1
train_steps: int = None
@@ -347,6 +362,9 @@ def to_dict(self) -> Dict[str, Any]:
"text_encoder_3_dtype": self.text_encoder_3_dtype,
"transformer_dtype": self.transformer_dtype,
"vae_dtype": self.vae_dtype,
"layerwise_upcasting_modules": self.layerwise_upcasting_modules,
"layerwise_upcasting_storage_dtype": self.layerwise_upcasting_storage_dtype,
"layerwise_upcasting_skip_modules_pattern": self.layerwise_upcasting_skip_modules_pattern,
},
"dataset_arguments": {
"data_root": self.data_root,
@@ -381,7 +399,6 @@ def to_dict(self) -> Dict[str, Any]:
"training_arguments": {
"training_type": self.training_type,
"seed": self.seed,
"mixed_precision": self.mixed_precision,
"batch_size": self.batch_size,
"train_epochs": self.train_epochs,
"train_steps": self.train_steps,
@@ -464,6 +481,7 @@ def parse_arguments() -> Args:


def validate_args(args: Args):
_validated_model_args(args)
_validate_training_args(args)
_validate_validation_args(args)

@@ -506,6 +524,28 @@ def _add_model_arguments(parser: argparse.ArgumentParser) -> None:
parser.add_argument("--text_encoder_3_dtype", type=str, default="bf16", help="Data type for the text encoder 3.")
parser.add_argument("--transformer_dtype", type=str, default="bf16", help="Data type for the transformer model.")
parser.add_argument("--vae_dtype", type=str, default="bf16", help="Data type for the VAE model.")
parser.add_argument(
"--layerwise_upcasting_modules",
type=str,
default=[],
nargs="+",
choices=["transformer"],
help="Modules that should have fp8 storage weights but higher precision computation.",
)
parser.add_argument(
"--layerwise_upcasting_storage_dtype",
type=str,
default="float8_e4m3fn",
choices=["float8_e4m3fn", "float8_e5m2"],
help="Data type for the layerwise upcasting storage.",
)
parser.add_argument(
"--layerwise_upcasting_skip_modules_pattern",
type=str,
default=["patch_embed", "pos_embed", "x_embedder", "context_embedder", "^proj_in$", "^proj_out$", "norm"],
nargs="+",
help="Modules to skip for layerwise upcasting.",
)
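For illustration, a standalone sketch of how the new flags parse and how the storage-dtype string maps to a torch dtype (mirroring the `add_argument` definitions above and the `_DTYPE_MAP` additions further down; not the full finetrainers parser):

```python
# Minimal, self-contained reproduction of the new CLI flags; not the full finetrainers parser.
import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument("--layerwise_upcasting_modules", type=str, default=[], nargs="+", choices=["transformer"])
parser.add_argument(
    "--layerwise_upcasting_storage_dtype",
    type=str,
    default="float8_e4m3fn",
    choices=["float8_e4m3fn", "float8_e5m2"],
)

args = parser.parse_args(["--layerwise_upcasting_modules", "transformer"])

# The string is later converted to a torch dtype, analogous to the `_DTYPE_MAP` lookup below.
storage_dtype = {"float8_e4m3fn": torch.float8_e4m3fn, "float8_e5m2": torch.float8_e5m2}[
    args.layerwise_upcasting_storage_dtype
]
print(args.layerwise_upcasting_modules, storage_dtype)  # ['transformer'] torch.float8_e4m3fn
```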


def _add_dataset_arguments(parser: argparse.ArgumentParser) -> None:
@@ -688,16 +728,6 @@ def _add_training_arguments(parser: argparse.ArgumentParser) -> None:
help="Type of training to perform. Choose between ['lora', 'full-finetune']",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--mixed_precision",
type=str,
default="no",
choices=["no", "fp8", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Defaults to the value of accelerate config of the current system or the "
"flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--batch_size",
type=int,
@@ -979,8 +1009,9 @@ def _add_helper_arguments(parser: argparse.ArgumentParser) -> None:
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
"float8_e4m3fn": torch.float8_e4m3fn,
"float8_e5m2": torch.float8_e5m2,
}
_INVERSE_DTYPE_MAP = {v: k for k, v in _DTYPE_MAP.items()}


def _map_to_args_type(args: Dict[str, Any]) -> Args:
@@ -997,6 +1028,9 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
result_args.text_encoder_3_dtype = _DTYPE_MAP[args.text_encoder_3_dtype]
result_args.transformer_dtype = _DTYPE_MAP[args.transformer_dtype]
result_args.vae_dtype = _DTYPE_MAP[args.vae_dtype]
result_args.layerwise_upcasting_modules = args.layerwise_upcasting_modules
result_args.layerwise_upcasting_storage_dtype = _DTYPE_MAP[args.layerwise_upcasting_storage_dtype]
result_args.layerwise_upcasting_skip_modules_pattern = args.layerwise_upcasting_skip_modules_pattern

# Dataset arguments
if args.data_root is None and args.dataset_file is None:
@@ -1034,7 +1068,6 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
# Training arguments
result_args.training_type = args.training_type
result_args.seed = args.seed
result_args.mixed_precision = args.mixed_precision
result_args.batch_size = args.batch_size
result_args.train_epochs = args.train_epochs
result_args.train_steps = args.train_steps
@@ -1117,6 +1150,13 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
return result_args


def _validated_model_args(args: Args):
if args.training_type == "full-finetune":
assert (
"transformer" not in args.layerwise_upcasting_modules
), "Layerwise upcasting is not supported for full-finetune training"


def _validate_training_args(args: Args):
if args.training_type == "lora":
assert args.rank is not None, "Rank is required for LoRA training"
8 changes: 7 additions & 1 deletion finetrainers/cogvideox/lora.py
@@ -65,6 +65,7 @@ def initialize_pipeline(
enable_slicing: bool = False,
enable_tiling: bool = False,
enable_model_cpu_offload: bool = False,
is_training: bool = False,
**kwargs,
) -> CogVideoXPipeline:
component_name_pairs = [
@@ -81,9 +82,14 @@

pipe = CogVideoXPipeline.from_pretrained(model_id, **components, revision=revision, cache_dir=cache_dir)
pipe.text_encoder = pipe.text_encoder.to(dtype=text_encoder_dtype)
pipe.transformer = pipe.transformer.to(dtype=transformer_dtype)
pipe.vae = pipe.vae.to(dtype=vae_dtype)

# The transformer should already be in the correct dtype when training, so we don't need to cast it here.
# If we cast here while the fp8 layerwise upcasting hooks are attached, it leads to an error during
# the DDP optimizer step.
if not is_training:
pipe.transformer = pipe.transformer.to(dtype=transformer_dtype)

if enable_slicing:
pipe.vae.enable_slicing()
if enable_tiling:
1 change: 1 addition & 0 deletions finetrainers/hooks/__init__.py
@@ -0,0 +1 @@
from .layerwise_upcasting import apply_layerwise_upcasting
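For reference, a simplified sketch of the idea behind `apply_layerwise_upcasting`: weights rest in an fp8 storage dtype and are temporarily upcast to the compute dtype around each forward call via pre/post-forward hooks. This is an illustration under that assumption, not the actual finetrainers/diffusers implementation:

```python
# Simplified sketch of layerwise upcasting (NOT the actual finetrainers/diffusers hook code).
import re
import torch
import torch.nn as nn


def apply_layerwise_upcasting_sketch(
    model: nn.Module,
    storage_dtype: torch.dtype = torch.float8_e4m3fn,
    compute_dtype: torch.dtype = torch.bfloat16,
    skip_modules_pattern=("patch_embed", "pos_embed", "^proj_in$", "^proj_out$", "norm"),
) -> nn.Module:
    for name, module in model.named_modules():
        # Only cast weight-holding leaf modules; skipped layers stay in the compute dtype.
        if not isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)):
            continue
        if any(re.search(pattern, name) for pattern in skip_modules_pattern):
            continue

        module.to(storage_dtype)  # fp8 storage: half the memory of bf16 weights

        def upcast(mod, args):
            mod.to(compute_dtype)  # upcast just before the forward pass

        def downcast(mod, args, output):
            mod.to(storage_dtype)  # downcast right after, so weights rest in fp8

        module.register_forward_pre_hook(upcast)
        module.register_forward_hook(downcast)
    return model


# Usage sketch: weights rest in fp8, while the forward pass itself runs in bf16.
model = nn.Sequential(nn.Linear(64, 128), nn.GELU(), nn.Linear(128, 64)).to(torch.bfloat16)
apply_layerwise_upcasting_sketch(model)
with torch.no_grad():
    out = model(torch.randn(2, 64, dtype=torch.bfloat16))
print(out.dtype, model[0].weight.dtype)  # torch.bfloat16 torch.float8_e4m3fn
```

The skip patterns exposed via `--layerwise_upcasting_skip_modules_pattern` exist precisely to keep layers such as normalization and embeddings out of this fp8 round trip, as noted in the `Args` docstring above.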