Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Full Finetuning for LTX possibily extended to other models. #192

Open
wants to merge 19 commits into
base: main
Choose a base branch
from

Conversation

ArEnSc
Copy link
Contributor

@ArEnSc ArEnSc commented Jan 7, 2025

This pull request introduces the following changes:
Added a --supervised-finetuning flag "sft" for each model to control fine-tuning individually.
Refactored To-Do items related to:
Managing the list of models.
Enabling gradients in the model components.

Todo:
Run SFT for a few runs on one model to ensure stability

@@ -321,3 +321,18 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int
"forward_pass": forward_pass,
"validation": validation,
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't know if you two wanted to enable this through swapping out prepare parameters I thought it might have been over engineering to do that so I just made a copy of this config.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is perfect and the intended usage. This will eventually be refactored out into model specs to add some syntactic sugar and make the code easier to follow

@@ -455,7 +455,7 @@ def _add_training_arguments(parser: argparse.ArgumentParser) -> None:
"--training_type",
type=str,
default=None,
help="Type of training to perform. Choose between ['lora']",
help="Type of training to perform. Choose between ['lora','finetune']",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added this here should it be called full_finetune?

@ArEnSc
Copy link
Contributor Author

ArEnSc commented Jan 7, 2025

@a-r-r-o-w
@sayakpaul
Let me know what you think and what I could fix. I suspect you two had a plan of some kind this is my best guess at how you would want this to progress

Copy link
Owner

@a-r-r-o-w a-r-r-o-w left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for working on this! The PR is in good shape :)

Usually, when I add something, I try to make sure it is well tested with some basic settings that demonstrates that the training works as expected. With this, I also try to accompany it with one long training run (for loras, it is a 10000 step run).

Would like to do something similar here for all models added so far. Since CogVideoX has been supported in the legacy scripts, we know that SFT for that worked well, so we can skip it. But would definitely like to test LTX and Hunyuan before merging (I'm sure it works as expected but would like to do so anyway as documentation/proof that it works). That may take some time, but I hope that's okay with you.

Nothing much needs to be done - just a single 10000-50000 step run maybe with sensible default settings that we can recommend in the README. Happy to help with both LTX and Hunyuan, but if you have any logs from LTX training, it will be valuable to share your experiment results so far

Thanks again, and great work 🤗

Would maybe run make style too to auto-format the code changes to expected style

@@ -455,7 +455,7 @@ def _add_training_arguments(parser: argparse.ArgumentParser) -> None:
"--training_type",
type=str,
default=None,
help="Type of training to perform. Choose between ['lora']",
help="Type of training to perform. Choose between ['lora','full_finetune']",
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's call this sft in the same spirit as what CogVideoX folks use in their training scripts.

Eventually, I would like things like lora, lokr, loha, etc. to basically use the same interface for lora-type training. And sft (supervised finetuning) to mean no additional parameters and to train just the transformer/unet

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might I chime in and say "sft" sounds a bit incomplete to me. Because we can do LoRA for SFTs, too.

So, "full-finetune" is a more complete term, technically which will be less-confusing long-term.

Copy link
Owner

@a-r-r-o-w a-r-r-o-w Jan 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, that sounds good too. Since I'm a bit occupied on the FP8 PR, @sayakpaul would you like to take up a full finetuning run of maybe LTX?

Or Cog/Hunyuan as well. Up to you which model you'd like to try if you're going to try

In parallel, I can try any remaining model when I get some time. Need to curate a good dataset of something nice so we could release the finetuned checkpoints as well. LMK your thoughts

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leave that with me.

Let's settle on a common set of hyperparameters we want to use throughout this experiment?

training steps: 10000
batch size: 1, grad accum: 4 grad checkpointing
precompute_conditions

What else?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • lr 5e-6
  • lr schedule as constant_with_warmup
  • warmup steps as 2000 (20% of whatever we choose as training steps)
  • adamw weight_decay 0.01
  • resolutions will have to be decided based on dataset
  • validation every 500 steps with atleast 8 prompts

on 8xH100, this would effectively be a batch size of 32, so would maybe use 1024 to whatever higher limit you prefer videos. That will give us 32 batches per epoch, and ~312 epochs I think.

40000 training steps seems good overall (10000 actual steps but 4 gradient accumulation) but can do lower if we don't want to hijack the cluster for a long time.

LTX and Cog will be fast to train upto decent resolution, but it is Hunyuan I'm worried about. We don't have to do Cog as such because I had verified it back when we were using the legacy scripts

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will tackle LTX first then.

@@ -321,3 +321,18 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int
"forward_pass": forward_pass,
"validation": validation,
}

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is perfect and the intended usage. This will eventually be refactored out into model specs to add some syntactic sugar and make the code easier to follow

finetrainers/models.py Outdated Show resolved Hide resolved
@@ -455,7 +455,7 @@ def _add_training_arguments(parser: argparse.ArgumentParser) -> None:
"--training_type",
type=str,
default=None,
help="Type of training to perform. Choose between ['lora']",
help="Type of training to perform. Choose between ['lora','sft']",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the back and fourth :(

As discussed with @a-r-r-o-w, let's switch SFT back to "full-finetune".

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

resolved~

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like it's still SFT?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ArEnSc cc

@@ -455,7 +455,7 @@ def _add_training_arguments(parser: argparse.ArgumentParser) -> None:
"--training_type",
type=str,
default=None,
help="Type of training to perform. Choose between ['lora']",
help="Type of training to perform. Choose between ['lora','sft']",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can happen in another PR but we could also provide some info to the users know that if rank and other LoRA related arguments are provided alongside training_type=="sft", they will be ignored.

WDYT?

@sayakpaul
Copy link
Collaborator

sayakpaul commented Jan 8, 2025

Seems like there are some existing problems.

My training command:

export NCCL_P2P_DISABLE=1
export TORCH_NCCL_ENABLE_MONITORING=0
export FINETRAINERS_LOG_LEVEL=DEBUG

GPU_IDS="0,1,2,3,4,5,6,7"

DATA_ROOT="video-dataset-disney"
CAPTION_COLUMN="prompt.txt"
VIDEO_COLUMN="videos.txt"
OUTPUT_DIR="ltxv_disney"

ID_TOKEN="BW_STYLE"

# Model arguments
model_cmd="--model_name ltx_video \
  --pretrained_model_name_or_path Lightricks/LTX-Video"

# Dataset arguments
dataset_cmd="--data_root $DATA_ROOT \
  --video_column $VIDEO_COLUMN \
  --caption_column $CAPTION_COLUMN \
  --id_token $ID_TOKEN \
  --video_resolution_buckets 49x512x768 \
  --caption_dropout_p 0.05"

# Dataloader arguments
dataloader_cmd="--dataloader_num_workers 0"

# Diffusion arguments
diffusion_cmd="--flow_resolution_shifting"

# Training arguments
training_cmd="--training_type sft \
  --seed 42 \
  --mixed_precision bf16 \
  --precompute_conditions \
  --batch_size 1 \
  --train_steps 10000 \
  --gradient_accumulation_steps 4 \
  --gradient_checkpointing \
  --checkpointing_steps 2000 \
  --checkpointing_limit 2 \
  --enable_slicing \
  --enable_tiling"

# Optimizer arguments
optimizer_cmd="--optimizer adamw \
  --lr 5e-6 \
  --lr_scheduler constant_with_warmup \
  --lr_warmup_steps 2000 \
  --lr_num_cycles 1 \
  --weight_decay 1e-2"

# Validation arguments
validation_cmd="--validation_prompts \"$ID_TOKEN A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions.@@@49x512x768:::$ID_TOKEN A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage@@@49x512x768\"
   --validation_steps 500 \
   --num_validation_videos 1"

# Miscellaneous arguments
miscellaneous_cmd="--tracker_name finetrainers-ltxv \
  --output_dir $OUTPUT_DIR \
  --nccl_timeout 1800 \
  --report_to wandb"

cmd="accelerate launch --config_file accelerate_configs/deepspeed.yaml --gpu_ids $GPU_IDS train.py \
  $model_cmd \
  $dataset_cmd \
  $dataloader_cmd \
  $diffusion_cmd \
  $training_cmd \
  $optimizer_cmd \
  $miscellaneous_cmd"

echo "Running command: $cmd"
eval $cmd
echo -ne "-------------------- Finished executing script --------------------\n\n"

When the script is run for the first time, it OOMs when trying to move the VAE to accelerator.device after computing the conditions.

When the script is relaunched it errors out with:

NFO:finetrainers:Initializing models
INFO:finetrainers:Initializing precomputations
WARNING:finetrainers:Number of precomputed conditions (69) does not match number of precomputed latents (9).Cleaning up precomputed directories and re-running precomputation.
ERROR:finetrainers:An error occurred during training: [Errno 2] No such file or directory: 'video-dataset-disney/ltx_video_Lightricks-LTX-Video_precomputed/conditions/conditions-5-3.pt'
ERROR:finetrainers:An error occurred during training: [Errno 2] No such file or directory: 'video-dataset-disney/ltx_video_Lightricks-LTX-Video_precomputed/conditions/conditions-5-3.pt'
ERROR:finetrainers:An error occurred during training: [Errno 2] No such file or directory: 'video-dataset-disney/ltx_video_Lightricks-LTX-Video_precomputed/conditions/conditions-5-3.pt'
ERROR:finetrainers:An error occurred during training: [Errno 2] No such file or directory: 'video-dataset-disney/ltx_video_Lightricks-LTX-Video_precomputed/conditions/conditions-0-7.pt'
ERROR:finetrainers:An error occurred during training: [Errno 2] No such file or directory: 'video-dataset-disney/ltx_video_Lightricks-LTX-Video_precomputed/conditions/conditions-0-7.pt'
ERROR:finetrainers:An error occurred during training: [Errno 2] No such file or directory: 'video-dataset-disney/ltx_video_Lightricks-LTX-Video_precomputed/conditions/conditions-0-7.pt'
ERROR:finetrainers:An error occurred during training: [Errno 2] No such file or directory: 'video-dataset-disney/ltx_video_Lightricks-LTX-Video_precomputed/conditions/conditions-0-7.pt'
ERROR:finetrainers:Traceback (most recent call last):
  File "/fsx/sayak/collabs/finetrainers/train.py", line 28, in main
    trainer.prepare_precomputations()
  File "/fsx/sayak/collabs/finetrainers/finetrainers/trainer.py", line 251, in prepare_precomputations
    should_precompute = should_perform_precomputation(precomputation_dir)
  File "/fsx/sayak/collabs/finetrainers/finetrainers/utils/data_utils.py", line 27, in should_perform_precomputation
    file.unlink()
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/pathlib.py", line 1206, in unlink
    self._accessor.unlink(self)
FileNotFoundError: [Errno 2] No such file or directory: 'video-dataset-disney/ltx_video_Lightricks-LTX-Video_precomputed/conditions/conditions-5-3.pt'

I am not sure if this is expected.

Some additional updates I had to perform:

patch
diff --git a/accelerate_configs/deepspeed.yaml b/accelerate_configs/deepspeed.yaml
index 62db0b4..67378fb 100644
--- a/accelerate_configs/deepspeed.yaml
+++ b/accelerate_configs/deepspeed.yaml
@@ -14,7 +14,7 @@ machine_rank: 0
 main_training_function: main
 mixed_precision: bf16
 num_machines: 1
-num_processes: 2
+num_processes: 8
 rdzv_backend: static
 same_network: true
 tpu_env: []
diff --git a/finetrainers/cogvideox/__init__.py b/finetrainers/cogvideox/__init__.py
index 6a3f826..62b4d15 100644
--- a/finetrainers/cogvideox/__init__.py
+++ b/finetrainers/cogvideox/__init__.py
@@ -1 +1 @@
-from .cogvideox_lora import COGVIDEOX_T2V_LORA_CONFIG
+from .cogvideox_lora import COGVIDEOX_T2V_LORA_CONFIG, COGVIDEOX_T2V_SFT_CONFIG
diff --git a/finetrainers/hunyuan_video/__init__.py b/finetrainers/hunyuan_video/__init__.py
index f4e780d..ab2d1a9 100644
--- a/finetrainers/hunyuan_video/__init__.py
+++ b/finetrainers/hunyuan_video/__init__.py
@@ -1 +1 @@
-from .hunyuan_video_lora import HUNYUAN_VIDEO_T2V_LORA_CONFIG
+from .hunyuan_video_lora import HUNYUAN_VIDEO_T2V_LORA_CONFIG, HUNYUAN_VIDEO_T2V_SFT_CONFIG
diff --git a/finetrainers/ltx_video/__init__.py b/finetrainers/ltx_video/__init__.py
index b583686..b58d14a 100644
--- a/finetrainers/ltx_video/__init__.py
+++ b/finetrainers/ltx_video/__init__.py
@@ -1 +1 @@
-from .ltx_video_lora import LTX_VIDEO_T2V_LORA_CONFIG
+from .ltx_video_lora import LTX_VIDEO_T2V_LORA_CONFIG, LTX_VIDEO_T2V_SFT_CONFIG
diff --git a/finetrainers/trainer.py b/finetrainers/trainer.py
index d843387..929d647 100644
--- a/finetrainers/trainer.py
+++ b/finetrainers/trainer.py
@@ -326,6 +326,7 @@ class Trainer:
         # Precompute latents
         latent_components = self.model_config["load_latent_models"](**self._get_load_components_kwargs())
         self._set_components(latent_components)
+        print("Moving for computing latents.")
         self._move_components_to_device()
 
         self._disable_grad_for_components(components=[self.vae])
@@ -1123,13 +1124,15 @@ class Trainer:
                 self.state.repo_id = create_repo(token=self.args.hub_token, repo_id=repo_id, exist_ok=True).repo_id
 
     def _move_components_to_device(self):
+        print(f"{self.text_encoder is None=}, {self.vae is None=}, {self.transformer is None=}")
         if self.text_encoder is not None:
             self.text_encoder = self.text_encoder.to(self.state.accelerator.device)
         if self.text_encoder_2 is not None:
             self.text_encoder_2 = self.text_encoder_2.to(self.state.accelerator.device)
         if self.text_encoder_3 is not None:
             self.text_encoder_3 = self.text_encoder_3.to(self.state.accelerator.device)
-        if self.transformer is not None:
+        if self.transformer is not None and self.args.training_type != "sft":
+            # For SFT the `self.transformer` should be prepped by the `accelerator`
             self.transformer = self.transformer.to(self.state.accelerator.device)
         if self.unet is not None:
             self.unet = self.unet.to(self.state.accelerator.device)

@a-r-r-o-w
Copy link
Owner

Seems like a race condition. The cleaning up part should only done from the main process, but it's called from all process atm I think. This was never a problem for me because I am using the same precomputed conditions from the beginning. Here, I see that you had 9 precomputations already performed before the training instead of 69, presumably from a run you cancelled. The quickest fix would be to just delete the existing precomputed folder and starting again, and we can follow-up with a distributed fix for this case in separate PR

@a-r-r-o-w
Copy link
Owner

a-r-r-o-w commented Jan 8, 2025

-        if self.transformer is not None:
+        if self.transformer is not None and self.args.training_type != "sft":
+            # For SFT the `self.transformer` should be prepped by the `accelerator`

I'm not sure I understand this change. The accelerator only handles wrapping the module with DDP, but device transfer still needs to be done ourselves irrespective of training type, no?

@sayakpaul
Copy link
Collaborator

I'm not sure I understand this change. The accelerator only handles wrapping the module with DDP, but device transfer still needs to be done ourselves irrespective of training type, no?

Actually accelerate will take care of it. Here's a reference: https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_flux.py.

but coming to the OOM problem, I think there was some infra issues on my end. Training seems have started https://wandb.ai/sayakpaul/finetrainers-ltxv/runs/ob6cgxf4.

Note the changes I had to make to:

  • Allow resuming and saving checkpoints.
  • Allow loading of precomputed latents.
patch
diff --git a/finetrainers/cogvideox/__init__.py b/finetrainers/cogvideox/__init__.py
index 6a3f826..62b4d15 100644
--- a/finetrainers/cogvideox/__init__.py
+++ b/finetrainers/cogvideox/__init__.py
@@ -1 +1 @@
-from .cogvideox_lora import COGVIDEOX_T2V_LORA_CONFIG
+from .cogvideox_lora import COGVIDEOX_T2V_LORA_CONFIG, COGVIDEOX_T2V_SFT_CONFIG
diff --git a/finetrainers/dataset.py b/finetrainers/dataset.py
index 6054e49..e83d2d7 100644
--- a/finetrainers/dataset.py
+++ b/finetrainers/dataset.py
@@ -353,13 +353,13 @@ class ImageOrVideoDatasetWithResizeAndRectangleCrop(ImageOrVideoDataset):
 
 
 class PrecomputedDataset(Dataset):
-    def __init__(self, data_root: str) -> None:
+    def __init__(self, data_root: str, model_name: str = None, cleaned_model_id: str = None) -> None:
         super().__init__()
 
         self.data_root = Path(data_root)
-
-        self.latents_path = self.data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_LATENTS_DIR_NAME
-        self.conditions_path = self.data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_CONDITIONS_DIR_NAME
+        precomputation_dir = self.data_root / f"{model_name}_{cleaned_model_id}_{PRECOMPUTED_DIR_NAME}"
+        self.latents_path = precomputation_dir / PRECOMPUTED_LATENTS_DIR_NAME
+        self.conditions_path = precomputation_dir / PRECOMPUTED_CONDITIONS_DIR_NAME
 
         self.latent_conditions = sorted(os.listdir(self.latents_path))
         self.text_conditions = sorted(os.listdir(self.conditions_path))
diff --git a/finetrainers/hunyuan_video/__init__.py b/finetrainers/hunyuan_video/__init__.py
index f4e780d..ab2d1a9 100644
--- a/finetrainers/hunyuan_video/__init__.py
+++ b/finetrainers/hunyuan_video/__init__.py
@@ -1 +1 @@
-from .hunyuan_video_lora import HUNYUAN_VIDEO_T2V_LORA_CONFIG
+from .hunyuan_video_lora import HUNYUAN_VIDEO_T2V_LORA_CONFIG, HUNYUAN_VIDEO_T2V_SFT_CONFIG
diff --git a/finetrainers/ltx_video/__init__.py b/finetrainers/ltx_video/__init__.py
index b583686..b58d14a 100644
--- a/finetrainers/ltx_video/__init__.py
+++ b/finetrainers/ltx_video/__init__.py
@@ -1 +1 @@
-from .ltx_video_lora import LTX_VIDEO_T2V_LORA_CONFIG
+from .ltx_video_lora import LTX_VIDEO_T2V_LORA_CONFIG, LTX_VIDEO_T2V_SFT_CONFIG
diff --git a/finetrainers/trainer.py b/finetrainers/trainer.py
index d843387..377b3b2 100644
--- a/finetrainers/trainer.py
+++ b/finetrainers/trainer.py
@@ -252,7 +252,7 @@ class Trainer:
         if not should_precompute:
             logger.info("Precomputed conditions and latents found. Loading precomputed data.")
             self.dataloader = torch.utils.data.DataLoader(
-                PrecomputedDataset(self.args.data_root),
+                PrecomputedDataset(self.args.data_root, model_name=self.args.model_name, cleaned_model_id=cleaned_model_id),
                 batch_size=self.args.batch_size,
                 shuffle=True,
                 collate_fn=collate_fn,
@@ -433,6 +433,8 @@ class Trainer:
                 target_modules=self.args.target_modules,
             )
             self.transformer.add_adapter(transformer_lora_config)
+        else:
+            transformer_lora_config = None
 
         # Enable TF32 for faster training on Ampere GPUs: https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
         if self.args.allow_tf32 and torch.cuda.is_available():
@@ -452,7 +454,8 @@ class Trainer:
                         type(unwrap_model(self.state.accelerator, self.transformer)),
                     ):
                         model = unwrap_model(self.state.accelerator, model)
-                        transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+                        if self.args.training_type == "lora":
+                            transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                     else:
                         raise ValueError(f"Unexpected save model: {model.__class__}")
 
@@ -460,10 +463,14 @@ class Trainer:
                     if weights:
                         weights.pop()
 
-                self.model_config["pipeline_cls"].save_lora_weights(
-                    output_dir,
-                    transformer_lora_layers=transformer_lora_layers_to_save,
-                )
+                # TODO: refactor if needed.
+                if self.args.training_type == "lora":
+                    self.model_config["pipeline_cls"].save_lora_weights(
+                        output_dir,
+                        transformer_lora_layers=transformer_lora_layers_to_save,
+                    )
+                else:
+                    model.save_pretrained(os.path.join(output_dir, "transformer"))
 
         def load_model_hook(models, input_dir):
             if not self.state.accelerator.distributed_type == DistributedType.DEEPSPEED:
@@ -479,31 +486,37 @@ class Trainer:
                             f"Unexpected save model: {unwrap_model(self.state.accelerator, model).__class__}"
                         )
             else:
-                transformer_ = unwrap_model(self.state.accelerator, self.transformer).__class__.from_pretrained(
-                    self.args.pretrained_model_name_or_path, subfolder="transformer"
-                )
+                transformer_cls_ = unwrap_model(self.state.accelerator, self.transformer).__class__
+                if self.args.training_type == "lora":
+                    transformer_ = transformer_cls_.from_pretrained(
+                        self.args.pretrained_model_name_or_path, subfolder="transformer"
+                    )
+                else:
+                    transformer_ = transformer_cls_.from_pretrained(os.path.join(input_dir, "transformer"))
+            
+            if self.args.training_type == "lora":
                 transformer_.add_adapter(transformer_lora_config)
 
-            lora_state_dict = self.model_config["pipeline_cls"].lora_state_dict(input_dir)
-            transformer_state_dict = {
-                f'{k.replace("transformer.", "")}': v
-                for k, v in lora_state_dict.items()
-                if k.startswith("transformer.")
-            }
-            incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
-            if incompatible_keys is not None:
-                # check only for unexpected keys
-                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
-                if unexpected_keys:
-                    logger.warning(
-                        f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                        f" {unexpected_keys}. "
-                    )
+                lora_state_dict = self.model_config["pipeline_cls"].lora_state_dict(input_dir)
+                transformer_state_dict = {
+                    f'{k.replace("transformer.", "")}': v
+                    for k, v in lora_state_dict.items()
+                    if k.startswith("transformer.")
+                }
+                incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
+                if incompatible_keys is not None:
+                    # check only for unexpected keys
+                    unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+                    if unexpected_keys:
+                        logger.warning(
+                            f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                            f" {unexpected_keys}. "
+                        )
 
             # Make sure the trainable params are in float32. This is again needed since the base models
             # are in `weight_dtype`. More details:
             # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
-            if self.args.mixed_precision == "fp16":
+            if self.args.mixed_precision == "fp16" and self.args.training_type == "lora":
                 # only upcast trainable parameters (LoRA) into fp32
                 cast_training_params([transformer_])

Of course the patch is a temporary workaround to kickstart the training.

@sayakpaul
Copy link
Collaborator

If the changes look good, I will take a node for some time and start the run with exact command from #192 (comment)

@sayakpaul
Copy link
Collaborator

Had to account for a few more things:

Patch
diff --git a/finetrainers/cogvideox/__init__.py b/finetrainers/cogvideox/__init__.py
index 6a3f826..62b4d15 100644
--- a/finetrainers/cogvideox/__init__.py
+++ b/finetrainers/cogvideox/__init__.py
@@ -1 +1 @@
-from .cogvideox_lora import COGVIDEOX_T2V_LORA_CONFIG
+from .cogvideox_lora import COGVIDEOX_T2V_LORA_CONFIG, COGVIDEOX_T2V_SFT_CONFIG
diff --git a/finetrainers/dataset.py b/finetrainers/dataset.py
index 6054e49..e83d2d7 100644
--- a/finetrainers/dataset.py
+++ b/finetrainers/dataset.py
@@ -353,13 +353,13 @@ class ImageOrVideoDatasetWithResizeAndRectangleCrop(ImageOrVideoDataset):
 
 
 class PrecomputedDataset(Dataset):
-    def __init__(self, data_root: str) -> None:
+    def __init__(self, data_root: str, model_name: str = None, cleaned_model_id: str = None) -> None:
         super().__init__()
 
         self.data_root = Path(data_root)
-
-        self.latents_path = self.data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_LATENTS_DIR_NAME
-        self.conditions_path = self.data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_CONDITIONS_DIR_NAME
+        precomputation_dir = self.data_root / f"{model_name}_{cleaned_model_id}_{PRECOMPUTED_DIR_NAME}"
+        self.latents_path = precomputation_dir / PRECOMPUTED_LATENTS_DIR_NAME
+        self.conditions_path = precomputation_dir / PRECOMPUTED_CONDITIONS_DIR_NAME
 
         self.latent_conditions = sorted(os.listdir(self.latents_path))
         self.text_conditions = sorted(os.listdir(self.conditions_path))
diff --git a/finetrainers/hunyuan_video/__init__.py b/finetrainers/hunyuan_video/__init__.py
index f4e780d..ab2d1a9 100644
--- a/finetrainers/hunyuan_video/__init__.py
+++ b/finetrainers/hunyuan_video/__init__.py
@@ -1 +1 @@
-from .hunyuan_video_lora import HUNYUAN_VIDEO_T2V_LORA_CONFIG
+from .hunyuan_video_lora import HUNYUAN_VIDEO_T2V_LORA_CONFIG, HUNYUAN_VIDEO_T2V_SFT_CONFIG
diff --git a/finetrainers/ltx_video/__init__.py b/finetrainers/ltx_video/__init__.py
index b583686..b58d14a 100644
--- a/finetrainers/ltx_video/__init__.py
+++ b/finetrainers/ltx_video/__init__.py
@@ -1 +1 @@
-from .ltx_video_lora import LTX_VIDEO_T2V_LORA_CONFIG
+from .ltx_video_lora import LTX_VIDEO_T2V_LORA_CONFIG, LTX_VIDEO_T2V_SFT_CONFIG
diff --git a/finetrainers/ltx_video/ltx_video_lora.py b/finetrainers/ltx_video/ltx_video_lora.py
index 77ec50e..26ab8b1 100644
--- a/finetrainers/ltx_video/ltx_video_lora.py
+++ b/finetrainers/ltx_video/ltx_video_lora.py
@@ -40,13 +40,14 @@ def load_latent_models(
 
 def load_diffusion_models(
     model_id: str = "Lightricks/LTX-Video",
+    subfolder: str = "transformer",
     transformer_dtype: torch.dtype = torch.bfloat16,
     revision: Optional[str] = None,
     cache_dir: Optional[str] = None,
     **kwargs,
 ) -> Dict[str, nn.Module]:
     transformer = LTXVideoTransformer3DModel.from_pretrained(
-        model_id, subfolder="transformer", torch_dtype=transformer_dtype, revision=revision, cache_dir=cache_dir
+        model_id, subfolder=subfolder, torch_dtype=transformer_dtype, revision=revision, cache_dir=cache_dir
     )
     scheduler = FlowMatchEulerDiscreteScheduler()
     return {"transformer": transformer, "scheduler": scheduler}
diff --git a/finetrainers/trainer.py b/finetrainers/trainer.py
index d843387..e4d99df 100644
--- a/finetrainers/trainer.py
+++ b/finetrainers/trainer.py
@@ -252,7 +252,7 @@ class Trainer:
         if not should_precompute:
             logger.info("Precomputed conditions and latents found. Loading precomputed data.")
             self.dataloader = torch.utils.data.DataLoader(
-                PrecomputedDataset(self.args.data_root),
+                PrecomputedDataset(self.args.data_root, model_name=self.args.model_name, cleaned_model_id=cleaned_model_id),
                 batch_size=self.args.batch_size,
                 shuffle=True,
                 collate_fn=collate_fn,
@@ -433,6 +433,8 @@ class Trainer:
                 target_modules=self.args.target_modules,
             )
             self.transformer.add_adapter(transformer_lora_config)
+        else:
+            transformer_lora_config = None
 
         # Enable TF32 for faster training on Ampere GPUs: https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
         if self.args.allow_tf32 and torch.cuda.is_available():
@@ -452,7 +454,8 @@ class Trainer:
                         type(unwrap_model(self.state.accelerator, self.transformer)),
                     ):
                         model = unwrap_model(self.state.accelerator, model)
-                        transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+                        if self.args.training_type == "lora":
+                            transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                     else:
                         raise ValueError(f"Unexpected save model: {model.__class__}")
 
@@ -460,10 +463,14 @@ class Trainer:
                     if weights:
                         weights.pop()
 
-                self.model_config["pipeline_cls"].save_lora_weights(
-                    output_dir,
-                    transformer_lora_layers=transformer_lora_layers_to_save,
-                )
+                # TODO: refactor if needed.
+                if self.args.training_type == "lora":
+                    self.model_config["pipeline_cls"].save_lora_weights(
+                        output_dir,
+                        transformer_lora_layers=transformer_lora_layers_to_save,
+                    )
+                else:
+                    model.save_pretrained(os.path.join(output_dir, "transformer"))
 
         def load_model_hook(models, input_dir):
             if not self.state.accelerator.distributed_type == DistributedType.DEEPSPEED:
@@ -479,31 +486,37 @@ class Trainer:
                             f"Unexpected save model: {unwrap_model(self.state.accelerator, model).__class__}"
                         )
             else:
-                transformer_ = unwrap_model(self.state.accelerator, self.transformer).__class__.from_pretrained(
-                    self.args.pretrained_model_name_or_path, subfolder="transformer"
-                )
+                transformer_cls_ = unwrap_model(self.state.accelerator, self.transformer).__class__
+                if self.args.training_type == "lora":
+                    transformer_ = transformer_cls_.from_pretrained(
+                        self.args.pretrained_model_name_or_path, subfolder="transformer"
+                    )
+                else:
+                    transformer_ = transformer_cls_.from_pretrained(os.path.join(input_dir, "transformer"))
+            
+            if self.args.training_type == "lora":
                 transformer_.add_adapter(transformer_lora_config)
 
-            lora_state_dict = self.model_config["pipeline_cls"].lora_state_dict(input_dir)
-            transformer_state_dict = {
-                f'{k.replace("transformer.", "")}': v
-                for k, v in lora_state_dict.items()
-                if k.startswith("transformer.")
-            }
-            incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
-            if incompatible_keys is not None:
-                # check only for unexpected keys
-                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
-                if unexpected_keys:
-                    logger.warning(
-                        f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                        f" {unexpected_keys}. "
-                    )
+                lora_state_dict = self.model_config["pipeline_cls"].lora_state_dict(input_dir)
+                transformer_state_dict = {
+                    f'{k.replace("transformer.", "")}': v
+                    for k, v in lora_state_dict.items()
+                    if k.startswith("transformer.")
+                }
+                incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
+                if incompatible_keys is not None:
+                    # check only for unexpected keys
+                    unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+                    if unexpected_keys:
+                        logger.warning(
+                            f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                            f" {unexpected_keys}. "
+                        )
 
             # Make sure the trainable params are in float32. This is again needed since the base models
             # are in `weight_dtype`. More details:
             # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
-            if self.args.mixed_precision == "fp16":
+            if self.args.mixed_precision == "fp16" and self.args.training_type == "lora":
                 # only upcast trainable parameters (LoRA) into fp32
                 cast_training_params([transformer_])
 
@@ -886,12 +899,15 @@ class Trainer:
         if accelerator.is_main_process:
             # TODO: consider factoring this out when supporting other types of training algos.
             self.transformer = unwrap_model(accelerator, self.transformer)
-            transformer_lora_layers = get_peft_model_state_dict(self.transformer)
-
-            self.model_config["pipeline_cls"].save_lora_weights(
-                save_directory=self.args.output_dir,
-                transformer_lora_layers=transformer_lora_layers,
-            )
+            if self.args.training_type == "lora":
+                transformer_lora_layers = get_peft_model_state_dict(self.transformer)
+                self.model_config["pipeline_cls"].save_lora_weights(
+                    save_directory=self.args.output_dir,
+                    transformer_lora_layers=transformer_lora_layers,
+                )
+            else:
+                # TODO: The upcasting could be made CLI configurable.
+                self.transformer.to(torch.float32).save_pretrained(self.args.output_dir)
 
         self.validate(step=global_step, final_validation=True)
 
@@ -941,8 +957,16 @@ class Trainer:
         else:
             # `torch_dtype` is manually set within `initialize_pipeline()`.
             self._delete_components()
+            if self.args.training_type == "sft":
+                transformer = self.model_config["load_diffusion_models"](
+                    model_id=self.args.output_dir,
+                    subfolder=None,
+                )["transformer"]
+            else:
+                transformer = None
             pipeline = self.model_config["initialize_pipeline"](
                 model_id=self.args.pretrained_model_name_or_path,
+                transformer=transformer,
                 device=accelerator.device,
                 revision=self.args.revision,
                 cache_dir=self.args.cache_dir,
@@ -950,7 +974,8 @@ class Trainer:
                 enable_tiling=self.args.enable_tiling,
                 enable_model_cpu_offload=self.args.enable_model_cpu_offload,
             )
-            pipeline.load_lora_weights(self.args.output_dir)
+            if self.args.training_type == "lora":
+                pipeline.load_lora_weights(self.args.output_dir)
 
         all_processes_artifacts = []
         prompts_to_filenames = {}

Short run: https://wandb.ai/sayakpaul/finetrainers-ltxv/runs/vrpkc3wr

Command:
export NCCL_P2P_DISABLE=1
export TORCH_NCCL_ENABLE_MONITORING=0
export FINETRAINERS_LOG_LEVEL=DEBUG

GPU_IDS="0,1,2,3,4,5,6,7"

DATA_ROOT="video-dataset-disney"
CAPTION_COLUMN="prompt.txt"
VIDEO_COLUMN="videos.txt"
OUTPUT_DIR="ltxv_disney"

ID_TOKEN="BW_STYLE"

# Model arguments
model_cmd="--model_name ltx_video \
  --pretrained_model_name_or_path Lightricks/LTX-Video"

# Dataset arguments
dataset_cmd="--data_root $DATA_ROOT \
  --video_column $VIDEO_COLUMN \
  --caption_column $CAPTION_COLUMN \
  --id_token $ID_TOKEN \
  --video_resolution_buckets 49x512x768 \
  --caption_dropout_p 0.05"

# Dataloader arguments
dataloader_cmd="--dataloader_num_workers 0"

# Diffusion arguments
diffusion_cmd="--flow_resolution_shifting"

# Training arguments
training_cmd="--training_type sft \
  --seed 42 \
  --mixed_precision bf16 \
  --precompute_conditions \
  --batch_size 1 \
  --train_steps 20 \
  --gradient_accumulation_steps 4 \
  --gradient_checkpointing \
  --checkpointing_steps 10 \
  --checkpointing_limit 2 \
  --enable_slicing \
  --enable_tiling"

# Optimizer arguments
optimizer_cmd="--optimizer adamw \
  --lr 5e-6 \
  --lr_scheduler constant_with_warmup \
  --lr_warmup_steps 2000 \
  --lr_num_cycles 1 \
  --weight_decay 1e-2"

# Validation arguments
validation_prompts=$(cat <<EOF
$ID_TOKEN A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions.@@@49x512x768:::$ID_TOKEN A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage@@@49x512x768
EOF
)

validation_cmd="--validation_prompts \"$validation_prompts\" \
   --validation_steps 10 \
   --num_validation_videos 1"

# Miscellaneous arguments
miscellaneous_cmd="--tracker_name finetrainers-ltxv \
  --resume_from_checkpoint=latest \
  --output_dir $OUTPUT_DIR \
  --nccl_timeout 1800 \
  --report_to wandb"

cmd="accelerate launch --config_file accelerate_configs/deepspeed.yaml --gpu_ids $GPU_IDS train.py \
  $model_cmd \
  $dataset_cmd \
  $dataloader_cmd \
  $diffusion_cmd \
  $training_cmd \
  $optimizer_cmd \
  $validation_cmd \
  $miscellaneous_cmd"

echo "Running command: $cmd"
eval $cmd
echo -ne "-------------------- Finished executing script --------------------\n\n"
Directory contents:
checkpoint-10                        final-22-0-2-BW_STYLE-A-black-and-whit.mp4       validation-20-0-2-BW_STYLE-A-black-and-whit.mp4
checkpoint-20                        final-22-1-2-BW_STYLE-A-woman-with-lon.mp4       validation-20-1-2-BW_STYLE-A-woman-with-lon.mp4
config.json                          validation-10-0-2-BW_STYLE-A-black-and-whit.mp4
diffusion_pytorch_model.safetensors  validation-10-1-2-BW_STYLE-A-woman-with-lon.mp4

@sayakpaul
Copy link
Collaborator

Thanks for the further changes! There's still a couple of things we need to do (as reflected in the patch in #192 (comment)). So, let's maybe wait a bit till @a-r-r-o-w replies.

@sayakpaul
Copy link
Collaborator

@a-r-r-o-w regardless of the PR, I think the precomputed dataset related changes need to be merged in soon:

diff --git a/finetrainers/dataset.py b/finetrainers/dataset.py
index 6054e49..e83d2d7 100644
--- a/finetrainers/dataset.py
+++ b/finetrainers/dataset.py
@@ -353,13 +353,13 @@ class ImageOrVideoDatasetWithResizeAndRectangleCrop(ImageOrVideoDataset):
 
 
 class PrecomputedDataset(Dataset):
-    def __init__(self, data_root: str) -> None:
+    def __init__(self, data_root: str, model_name: str = None, cleaned_model_id: str = None) -> None:
         super().__init__()
 
         self.data_root = Path(data_root)
-
-        self.latents_path = self.data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_LATENTS_DIR_NAME
-        self.conditions_path = self.data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_CONDITIONS_DIR_NAME
+        precomputation_dir = self.data_root / f"{model_name}_{cleaned_model_id}_{PRECOMPUTED_DIR_NAME}"
+        self.latents_path = precomputation_dir / PRECOMPUTED_LATENTS_DIR_NAME
+        self.conditions_path = precomputation_dir / PRECOMPUTED_CONDITIONS_DIR_NAME
 
         self.latent_conditions = sorted(os.listdir(self.latents_path))
         self.text_conditions = sorted(os.listdir(self.conditions_path))

diff --git a/finetrainers/trainer.py b/finetrainers/trainer.py
index d843387..e4d99df 100644
--- a/finetrainers/trainer.py
+++ b/finetrainers/trainer.py
@@ -252,7 +252,7 @@ class Trainer:
         if not should_precompute:
             logger.info("Precomputed conditions and latents found. Loading precomputed data.")
             self.dataloader = torch.utils.data.DataLoader(
-                PrecomputedDataset(self.args.data_root),
+                PrecomputedDataset(self.args.data_root, model_name=self.args.model_name, cleaned_model_id=cleaned_model_id),
                 batch_size=self.args.batch_size,
                 shuffle=True,
                 collate_fn=collate_fn,

Can quickly submit a PR. LMK.

@a-r-r-o-w
Copy link
Owner

Actually accelerate will take care of it. Here's a reference:

Let's not make any changes that are not absolutely necessary for making this PR run. The handling of device change if training type is not full-finetune is not necessary IMO. If something is redundant, we can refactor in separate PR.

Note the changes I had to make to:

Allow resuming and saving checkpoints.

Don't we already support resuming/saving checkpoints. #106

Allow loading of precomputed latents.

Yes, let's do this in separate PR like you mentioned here. Thanks!

@sayakpaul, could you open a PR with your patch changes against this PR so that it's easier to review? That way it will be a bit more clear what the exact additional changes needed are. I think we have to handle the following:

  • Save/Load model hook
  • Validation should not load lora weights, so needs to be done differently
  • Anything else I've missed but you noted

I see in your patch that you handle both well, so that's nice, thanks.

@sayakpaul
Copy link
Collaborator

Yes, on it. Doing it.

Don't we already support resuming/saving checkpoints. #106

Didn't account for full fine-tuning as only LoRA was supported back then.

@sayakpaul
Copy link
Collaborator

Let's not make any changes that are not absolutely necessary for making this PR run. The handling of device change if training type is not full-finetune is not necessary IMO. If something is redundant, we can refactor in separate PR.

Yes, onboarded. No problems on that.

@a-r-r-o-w
Copy link
Owner

Didn't account for full fine-tuning as only LoRA was supported back then.

Ohhh okay, sorry my bad.

@sayakpaul
Copy link
Collaborator

@a-r-r-o-w I just pushed my updates from #202 to this branch. @ArEnSc I hope that is okay with you :)

I would suggest we do a round of testing for both LoRA and full fine-tuning before the merge to ensure LoRA was at least not broken because of the recent pushes (although it was already tested).

@ArEnSc
Copy link
Contributor Author

ArEnSc commented Jan 10, 2025

@a-r-r-o-w I just pushed my updates from #202 to this branch. @ArEnSc I hope that is okay with you :)

I would suggest we do a round of testing for both LoRA and full fine-tuning before the merge to ensure LoRA was at least not broken because of the recent pushes (although it was already tested).

yep I can do that I have a dataset to test this with

@sayakpaul
Copy link
Collaborator

Thanks, let's go! We will help too.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants