Commit cf9be17

Merge pull request #129 from a-r-r-o-w/condition-precomputation
Precomputation of conditions and latents
2 parents (223add1 + 2858346) · commit cf9be17

10 files changed: +697 −124 lines changed


README.md

Lines changed: 116 additions & 3 deletions
@@ -143,6 +143,60 @@ video = pipe("<my-awesome-prompt>").frames[0]
 export_to_video(video, "output.mp4", fps=8)
 ```
 
+### Memory Usage
+
+LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolution, **without precomputation**:
+
+```
+Training configuration: {
+    "trainable parameters": 117440512,
+    "total samples": 69,
+    "train epochs": 1,
+    "train steps": 10,
+    "batches per device": 1,
+    "total batches observed per epoch": 69,
+    "train batch size": 1,
+    "gradient accumulation steps": 1
+}
+```
+
+| stage | memory_allocated | max_memory_reserved |
+|:-----------------------:|:----------------:|:-------------------:|
+| before training start | 13.486 | 13.879 |
+| before validation start | 14.146 | 17.623 |
+| after validation end | 14.146 | 17.623 |
+| after epoch 1 | 14.146 | 17.623 |
+| after training end | 4.461 | 17.623 |
+
+Note: requires about `18` GB of VRAM without precomputation.
+
+LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolution, **with precomputation**:
+
+```
+Training configuration: {
+    "trainable parameters": 117440512,
+    "total samples": 1,
+    "train epochs": 10,
+    "train steps": 10,
+    "batches per device": 1,
+    "total batches observed per epoch": 1,
+    "train batch size": 1,
+    "gradient accumulation steps": 1
+}
+```
+
+| stage | memory_allocated | max_memory_reserved |
+|:-----------------------------:|:----------------:|:-------------------:|
+| after precomputing conditions | 8.88 | 8.920 |
+| after precomputing latents | 9.684 | 11.613 |
+| before training start | 3.809 | 10.010 |
+| after epoch 1 | 4.26 | 10.916 |
+| before validation start | 4.26 | 10.916 |
+| after validation end | 13.924 | 17.262 |
+| after training end | 4.26 | 14.314 |
+
+Note: requires about `17.5` GB of VRAM with precomputation. If validation is not performed, the memory usage is reduced to `11` GB.
+
 </details>
 
 <details>
@@ -169,8 +223,7 @@ OUTPUT_DIR="/path/to/models/hunyuan-video/hunyuan-video-loras/hunyuan-video_caki
 
 # Model arguments
 model_cmd="--model_name hunyuan_video \
-  --pretrained_model_name_or_path tencent/HunyuanVideo
-  --revision refs/pr/18"
+  --pretrained_model_name_or_path hunyuanvideo-community/HunyuanVideo"
 
 # Dataset arguments
 dataset_cmd="--data_root $DATA_ROOT \
@@ -252,7 +305,7 @@ import torch
 from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
 from diffusers.utils import export_to_video
 
-model_id = "tencent/HunyuanVideo"
+model_id = "hunyuanvideo-community/HunyuanVideo"
 transformer = HunyuanVideoTransformer3DModel.from_pretrained(
     model_id, subfolder="transformer", torch_dtype=torch.bfloat16
 )
@@ -272,10 +325,70 @@ output = pipe(
 export_to_video(output, "output.mp4", fps=15)
 ```
 
+### Memory Usage
+
+LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolution, **without precomputation**:
+
+```
+Training configuration: {
+    "trainable parameters": 163577856,
+    "total samples": 69,
+    "train epochs": 1,
+    "train steps": 10,
+    "batches per device": 1,
+    "total batches observed per epoch": 69,
+    "train batch size": 1,
+    "gradient accumulation steps": 1
+}
+```
+
+| stage | memory_allocated | max_memory_reserved |
+|:-----------------------:|:----------------:|:-------------------:|
+| before training start | 38.889 | 39.020 |
+| before validation start | 39.747 | 56.266 |
+| after validation end | 39.748 | 58.385 |
+| after epoch 1 | 39.748 | 40.910 |
+| after training end | 25.288 | 40.910 |
+
+Note: requires about `59` GB of VRAM without precomputation.
+
+LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolution, **with precomputation**:
+
+```
+Training configuration: {
+    "trainable parameters": 163577856,
+    "total samples": 1,
+    "train epochs": 10,
+    "train steps": 10,
+    "batches per device": 1,
+    "total batches observed per epoch": 1,
+    "train batch size": 1,
+    "gradient accumulation steps": 1
+}
+```
+
+| stage | memory_allocated | max_memory_reserved |
+|:-----------------------------:|:----------------:|:-------------------:|
+| after precomputing conditions | 14.232 | 14.461 |
+| after precomputing latents | 14.717 | 17.244 |
+| before training start | 24.195 | 26.039 |
+| after epoch 1 | 24.83 | 42.387 |
+| before validation start | 24.842 | 42.387 |
+| after validation end | 39.558 | 46.947 |
+| after training end | 24.842 | 41.039 |
+
+Note: requires about `47` GB of VRAM with precomputation. If validation is not performed, the memory usage is reduced to about `42` GB.
+
 </details>
 
 If you would like to use a custom dataset, refer to the dataset preparation guide [here](./assets/dataset.md).
 
+> [!NOTE]
+> To lower memory requirements:
+> - Pass `--precompute_conditions` when launching training.
+> - Pass `--gradient_checkpointing` when launching training.
+> - Do not perform validation/testing. This saves a significant amount of memory, which can be used to focus solely on training if you're on smaller VRAM GPUs.
+
 ## Memory requirements
 
 <table align="center">
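
The hunk above shows only the lines of the README's inference snippet that change around `model_id`. For orientation, a completed version of that snippet might look like the sketch below; it assumes the standard diffusers `HunyuanVideoPipeline` API, and the LoRA path, prompt, and sampling settings are placeholders rather than values taken from this commit.

```python
import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video

# Updated checkpoint id introduced by this commit; the old `--revision refs/pr/18` pin is gone.
model_id = "hunyuanvideo-community/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)

pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)
pipe.load_lora_weights("/path/to/saved/lora", adapter_name="hunyuanvideo-lora")  # placeholder path
pipe.vae.enable_tiling()  # tile the VAE decode to keep memory in check
pipe.to("cuda")

output = pipe(
    prompt="A cat walks on the grass, realistic style.",  # placeholder prompt
    height=320,
    width=512,
    num_frames=61,
    num_inference_steps=30,
).frames[0]
export_to_video(output, "output.mp4", fps=15)
```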

finetrainers/args.py

Lines changed: 39 additions & 0 deletions
@@ -1,6 +1,8 @@
 import argparse
 from typing import Any, Dict, List, Optional, Tuple
 
+import torch
+
 from .constants import DEFAULT_IMAGE_RESOLUTION_BUCKETS, DEFAULT_VIDEO_RESOLUTION_BUCKETS
 
 
@@ -20,6 +22,11 @@ class Args:
     revision: Optional[str] = None
     variant: Optional[str] = None
     cache_dir: Optional[str] = None
+    text_encoder_dtype: torch.dtype = torch.bfloat16
+    text_encoder_2_dtype: torch.dtype = torch.bfloat16
+    text_encoder_3_dtype: torch.dtype = torch.bfloat16
+    transformer_dtype: torch.dtype = torch.bfloat16
+    vae_dtype: torch.dtype = torch.bfloat16
 
     # Dataset arguments
     data_root: str = None
@@ -32,6 +39,7 @@ class Args:
     video_reshape_mode: Optional[str] = None
     caption_dropout_p: float = 0.00
     caption_dropout_technique: str = "empty"
+    precompute_conditions: bool = False
 
     # Dataloader arguments
     dataloader_num_workers: int = 0
@@ -113,6 +121,11 @@ def to_dict(self) -> Dict[str, Any]:
                 "revision": self.revision,
                 "variant": self.variant,
                 "cache_dir": self.cache_dir,
+                "text_encoder_dtype": self.text_encoder_dtype,
+                "text_encoder_2_dtype": self.text_encoder_2_dtype,
+                "text_encoder_3_dtype": self.text_encoder_3_dtype,
+                "transformer_dtype": self.transformer_dtype,
+                "vae_dtype": self.vae_dtype,
             },
             "dataset_arguments": {
                 "data_root": self.data_root,
@@ -124,6 +137,8 @@ def to_dict(self) -> Dict[str, Any]:
                 "video_resolution_buckets": self.video_resolution_buckets,
                 "video_reshape_mode": self.video_reshape_mode,
                 "caption_dropout_p": self.caption_dropout_p,
+                "caption_dropout_technique": self.caption_dropout_technique,
+                "precompute_conditions": self.precompute_conditions,
             },
             "dataloader_arguments": {
                 "dataloader_num_workers": self.dataloader_num_workers,
@@ -234,6 +249,11 @@ def _add_model_arguments(parser: argparse.ArgumentParser) -> None:
         default=None,
         help="The directory where the downloaded models and datasets will be stored.",
     )
+    parser.add_argument("--text_encoder_dtype", type=str, default="bf16", help="Data type for the text encoder.")
+    parser.add_argument("--text_encoder_2_dtype", type=str, default="bf16", help="Data type for the text encoder 2.")
+    parser.add_argument("--text_encoder_3_dtype", type=str, default="bf16", help="Data type for the text encoder 3.")
+    parser.add_argument("--transformer_dtype", type=str, default="bf16", help="Data type for the transformer model.")
+    parser.add_argument("--vae_dtype", type=str, default="bf16", help="Data type for the VAE model.")
 
 
 def _add_dataset_arguments(parser: argparse.ArgumentParser) -> None:
@@ -317,6 +337,11 @@ def parse_video_resolution_bucket(resolution_bucket: str) -> Tuple[int, int, int
         choices=["empty", "zero"],
         help="Technique to use for caption dropout.",
     )
+    parser.add_argument(
+        "--precompute_conditions",
+        action="store_true",
+        help="Whether or not to precompute the conditionings for the model.",
+    )
 
 
 def _add_dataloader_arguments(parser: argparse.ArgumentParser) -> None:
@@ -645,6 +670,13 @@ def _add_miscellaneous_arguments(parser: argparse.ArgumentParser) -> None:
     )
 
 
+_DTYPE_MAP = {
+    "bf16": torch.bfloat16,
+    "fp16": torch.float16,
+    "fp32": torch.float32,
+}
+
+
 def _map_to_args_type(args: Dict[str, Any]) -> Args:
     result_args = Args()
 
@@ -654,6 +686,11 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
     result_args.revision = args.revision
     result_args.variant = args.variant
     result_args.cache_dir = args.cache_dir
+    result_args.text_encoder_dtype = _DTYPE_MAP[args.text_encoder_dtype]
+    result_args.text_encoder_2_dtype = _DTYPE_MAP[args.text_encoder_2_dtype]
+    result_args.text_encoder_3_dtype = _DTYPE_MAP[args.text_encoder_3_dtype]
+    result_args.transformer_dtype = _DTYPE_MAP[args.transformer_dtype]
+    result_args.vae_dtype = _DTYPE_MAP[args.vae_dtype]
 
     # Dataset arguments
     if args.data_root is None and args.dataset_file is None:
@@ -668,6 +705,8 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
     result_args.video_resolution_buckets = args.video_resolution_buckets or DEFAULT_VIDEO_RESOLUTION_BUCKETS
     result_args.video_reshape_mode = args.video_reshape_mode
     result_args.caption_dropout_p = args.caption_dropout_p
+    result_args.caption_dropout_technique = args.caption_dropout_technique
+    result_args.precompute_conditions = args.precompute_conditions
 
     # Dataloader arguments
     result_args.dataloader_num_workers = args.dataloader_num_workers
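
The new `--*_dtype` flags arrive as plain strings and are converted to `torch.dtype` objects through `_DTYPE_MAP` inside `_map_to_args_type`. Below is a minimal standalone sketch of that round trip; the parser is a stripped-down stand-in for `_add_model_arguments`, and the sample command line is hypothetical.

```python
import argparse

import torch

# Same mapping the diff adds to finetrainers/args.py.
_DTYPE_MAP = {
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
    "fp32": torch.float32,
}

parser = argparse.ArgumentParser()
parser.add_argument("--text_encoder_dtype", type=str, default="bf16", help="Data type for the text encoder.")
parser.add_argument("--vae_dtype", type=str, default="bf16", help="Data type for the VAE model.")

# Hypothetical invocation: keep the text encoder in bf16 but run the VAE in fp32.
args = parser.parse_args(["--vae_dtype", "fp32"])

text_encoder_dtype = _DTYPE_MAP[args.text_encoder_dtype]  # torch.bfloat16
vae_dtype = _DTYPE_MAP[args.vae_dtype]  # torch.float32
print(text_encoder_dtype, vae_dtype)
```

Converting once at parse time presumably lets the rest of the trainer hand `args.vae_dtype` and friends straight to model-loading code without re-interpreting strings.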

finetrainers/constants.py

Lines changed: 4 additions & 0 deletions
@@ -19,6 +19,10 @@
 
 FINETRAINERS_LOG_LEVEL = os.environ.get("FINETRAINERS_LOG_LEVEL", "INFO")
 
+PRECOMPUTED_DIR_NAME = "precomputed"
+PRECOMPUTED_CONDITIONS_DIR_NAME = "conditions"
+PRECOMPUTED_LATENTS_DIR_NAME = "latents"
+
 MODEL_DESCRIPTION = r"""
 \# {model_id} {training_type} finetune
 
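These three constants fix where precomputed artifacts live relative to the dataset root. A small sketch of the layout they imply, assuming the repository is importable as the `finetrainers` package (the root path and file names are hypothetical):

```python
from pathlib import Path

from finetrainers.constants import (
    PRECOMPUTED_CONDITIONS_DIR_NAME,
    PRECOMPUTED_DIR_NAME,
    PRECOMPUTED_LATENTS_DIR_NAME,
)

data_root = Path("/path/to/dataset")  # hypothetical --data_root

# Layout implied by the constants:
#   /path/to/dataset/precomputed/conditions/<sample>.pt
#   /path/to/dataset/precomputed/latents/<sample>.pt
conditions_dir = data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_CONDITIONS_DIR_NAME
latents_dir = data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_LATENTS_DIR_NAME
print(conditions_dir)
print(latents_dir)
```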

finetrainers/dataset.py

Lines changed: 30 additions & 0 deletions
@@ -1,3 +1,4 @@
+import os
 import random
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple
@@ -19,6 +20,9 @@
 
 decord.bridge.set_bridge("torch")
 
+from .constants import PRECOMPUTED_DIR_NAME, PRECOMPUTED_CONDITIONS_DIR_NAME, PRECOMPUTED_LATENTS_DIR_NAME
+
+
 logger = get_logger(__name__)
 
 
@@ -257,6 +261,32 @@ def _find_nearest_resolution(self, height, width):
         return nearest_res[1], nearest_res[2]
 
 
+class PrecomputedDataset(Dataset):
+    def __init__(self, data_root: str) -> None:
+        super().__init__()
+
+        self.data_root = Path(data_root)
+
+        self.latents_path = self.data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_LATENTS_DIR_NAME
+        self.conditions_path = self.data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_CONDITIONS_DIR_NAME
+
+        self.latent_conditions = sorted(os.listdir(self.latents_path))
+        self.text_conditions = sorted(os.listdir(self.conditions_path))
+
+        assert len(self.latent_conditions) == len(self.text_conditions), "Number of captions and videos do not match"
+
+    def __len__(self) -> int:
+        return len(self.latent_conditions)
+
+    def __getitem__(self, index: int) -> Dict[str, Any]:
+        conditions = {}
+        latent_path = self.latents_path / self.latent_conditions[index]
+        condition_path = self.conditions_path / self.text_conditions[index]
+        conditions["latent_conditions"] = torch.load(latent_path, map_location="cpu", weights_only=True)
+        conditions["text_conditions"] = torch.load(condition_path, map_location="cpu", weights_only=True)
+        return conditions
+
+
 class BucketSampler(Sampler):
     r"""
     PyTorch Sampler that groups 3D data by height, width and frames.
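
`PrecomputedDataset` pairs the i-th latent file with the i-th condition file after sorting each directory listing, so whatever writes the precomputed tensors has to keep file names aligned across the two directories. Below is a hedged end-to-end sketch: it writes two dummy samples with `torch.save` and reads them back through a `DataLoader`. It assumes `finetrainers` and its dependencies are importable, and the payload keys `latents` and `prompt_embeds` are illustrative stand-ins for whatever the precomputation step actually stores.

```python
from pathlib import Path

import torch
from torch.utils.data import DataLoader

from finetrainers.constants import (
    PRECOMPUTED_CONDITIONS_DIR_NAME,
    PRECOMPUTED_DIR_NAME,
    PRECOMPUTED_LATENTS_DIR_NAME,
)
from finetrainers.dataset import PrecomputedDataset

data_root = Path("/tmp/finetrainers-demo")  # hypothetical dataset root
latents_dir = data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_LATENTS_DIR_NAME
conditions_dir = data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_CONDITIONS_DIR_NAME
latents_dir.mkdir(parents=True, exist_ok=True)
conditions_dir.mkdir(parents=True, exist_ok=True)

# One file per sample in each directory; identical names keep the sorted lists aligned.
for name in ("video-000.pt", "video-001.pt"):
    torch.save({"latents": torch.randn(16, 13, 32, 48)}, latents_dir / name)
    torch.save({"prompt_embeds": torch.randn(226, 4096)}, conditions_dir / name)

dataset = PrecomputedDataset(str(data_root))
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

for batch in dataloader:
    latents = batch["latent_conditions"]["latents"]  # shape: (1, 16, 13, 32, 48)
    prompt_embeds = batch["text_conditions"]["prompt_embeds"]  # shape: (1, 226, 4096)
    print(latents.shape, prompt_embeds.shape)
```

Loading with `weights_only=True`, as `__getitem__` does, restricts the files to plain tensor containers, which is exactly what a dict of precomputed tensors is.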
