LTX Image2Video LoRA #150

Open · wants to merge 3 commits into main
Conversation

a-r-r-o-w (Owner)

To run the following example, you need at least 24 GB of VRAM (because 161 frames are used; if you configure the resolution buckets to use fewer frames, the VRAM requirement drops).

To run it, you need the same dataset format we've been using so far:

  • prompts.txt
  • videos.txt
  • videos/

Since this is image-to-video training, you also need validation images in addition to prompts. As an example, you can simply take the first frame of each training video:

mkdir -p /path/to/dataset/images
cd /path/to/dataset/videos

# For Windows, you will have to convert this bash snippet to Powershell/CMD
for file in *; do ffmpeg -y -i "$file" -frames:v 1 "../images/${file%.*}.png"; done
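If you prefer Python (or are on Windows), here is a rough cross-platform equivalent of the bash loop above. It assumes ffmpeg is on your PATH; the dataset path is a placeholder.

```python
import subprocess
from pathlib import Path

dataset_root = Path("/path/to/dataset")  # placeholder path
images_dir = dataset_root / "images"
images_dir.mkdir(parents=True, exist_ok=True)

# Extract the first frame of every training video into images/<name>.png
for video in sorted((dataset_root / "videos").iterdir()):
    output = images_dir / f"{video.stem}.png"
    subprocess.run(
        ["ffmpeg", "-y", "-i", str(video), "-frames:v", "1", str(output)],
        check=True,
    )
```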

Ideally, you should also validate with starting images that do not come from your training videos, to verify that the LoRA generalizes.

Script:
#!/bin/bash

# export TORCH_LOGS="+dynamo,recompiles,graph_breaks"
# export TORCHDYNAMO_VERBOSE=1
export WANDB_MODE="offline"
export NCCL_P2P_DISABLE=1
export TORCH_NCCL_ENABLE_MONITORING=0

GPU_IDS="2,3"

DATA_ROOT="/raid/aryan/video-dataset-disney"
CAPTION_COLUMN="prompts.txt"
VIDEO_COLUMN="videos.txt"

ID_TOKEN="afkx"

# Model arguments
model_cmd="--model_name ltx_video \
  --pretrained_model_name_or_path a-r-r-o-w/LTX-Video-diffusers"

# Dataset arguments
dataset_cmd="--data_root $DATA_ROOT \
  --video_column $VIDEO_COLUMN \
  --caption_column $CAPTION_COLUMN \
  --id_token $ID_TOKEN \
  --video_resolution_buckets 49x512x768 161x512x768 \
  --precompute_conditions"

# Dataloader arguments
dataloader_cmd="--dataloader_num_workers 0"

# Diffusion arguments
diffusion_cmd="--flow_resolution_shifting"

# Training arguments
training_cmd="--training_type lora_i2v \
  --seed 42 \
  --mixed_precision bf16 \
  --batch_size 1 \
  --train_steps 1000 \
  --target_modules to_q to_k to_v to_out.0 \
  --gradient_accumulation_steps 1 \
  --gradient_checkpointing \
  --checkpointing_steps 500 \
  --checkpointing_limit 2 \
  --enable_slicing \
  --enable_tiling"

# Optimizer arguments
optimizer_cmd="--optimizer adamw \
  --lr 1e-4 \
  --lr_scheduler cosine_with_restarts \
  --lr_warmup_steps 100 \
  --lr_num_cycles 1 \
  --beta1 0.9 \
  --beta2 0.95 \
  --weight_decay 1e-4 \
  --epsilon 1e-8 \
  --max_grad_norm 1.0"

# Validation arguments
validation_cmd="--validation_prompts \"$ID_TOKEN A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions.@@@49x512x768:::$ID_TOKEN A static black and white scene of three anthropomorphic characters is depicted in a series of animated frames. The first character, with oversized shoes, sings or speaks into a megaphone held by the central figure, who has an exaggerated open mouth. The third character, wearing a chef's hat, controls the megaphone's volume. Musical notes surround the characters, hinting at music, but there is no movement. The plain background focuses attention on the characters' interaction, maintaining a consistent composition throughout the scene.@@@49x512x768:::$ID_TOKEN A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions.@@@161x512x768:::$ID_TOKEN A static black and white scene of three anthropomorphic characters is depicted in a series of animated frames. The first character, with oversized shoes, sings or speaks into a megaphone held by the central figure, who has an exaggerated open mouth. The third character, wearing a chef's hat, controls the megaphone's volume. Musical notes surround the characters, hinting at music, but there is no movement. The plain background focuses attention on the characters' interaction, maintaining a consistent composition throughout the scene.@@@161x512x768:::$ID_TOKEN Three mechanical towers, each with unique designs on their tops, are connected by pipes on a flat surface against a cloudy black-and-white sky. Rain clouds appear, and water droplets fall, setting an industrial scene. The scene shifts to a ship's interior, where two cartoon characters interact amidst nautical elements. A claw-like appendage emerges, reaching for the ship, causing tension. The appendage extends, tilting the ship dramatically, introducing movement and a sense of urgency, highlighting potential danger within the ship's confines.@@@161x512x768:::$ID_TOKEN A tranquil, monochromatic scene unfolds with an animated character standing on a dock beside a body of water, a lighthouse visible in the background. The character is startled, running away from the viewer's perspective. A mechanical device with levers and a counter is introduced, attached to a rope, with a smaller, humanoid figure near it. 
The smaller figure engages with the machine, pulling the rope, while the larger figure remains in motion, creating a sense of tension and urgency in this classic, hand-drawn animated sequence.@@@161x512x768:::$ID_TOKEN A dynamic sequence unfolds on the deck of a ship, where a small, mouse-like character with large ears and short pants enthusiastically steers the vessel using a wheel. A larger, bulky character with a long pole engages in a playful confrontation, asserting dominance or playfully provoking the smaller one. Expressive gestures and movements convey emotions and intentions, set against a nautical backdrop featuring a steering wheel, life preserver, and bell. The two characters interact in a lively, competitive, or friendly exchange.@@@161x512x768:::$ID_TOKEN Mickey Mouse, with his distinctive round ears and black gloves, crouches attentively beside a pig character lying on the ground. The pig undergoes subtle movements, possibly preparing for action, while Mickey's posture and expression change as they interact. Mickey's mouth opens in surprise, and his body language suggests engagement. The pig raises its head or moves its limbs, indicating alertness or response. The background remains unchanged, with consistent lighting and shading highlighting the characters against a plain backdrop, capturing a series of moments of communication and reaction between the two.@@@161x512x768\" \
  --validation_images \"$DATA_ROOT/images/a3c275fc2eb0a67168a7c58a6a9adb14.png:::$DATA_ROOT/images/bf06573576ae0ea4d27a178b4d6e95a1.png:::$DATA_ROOT/images/a3c275fc2eb0a67168a7c58a6a9adb14.png:::$DATA_ROOT/images/bf06573576ae0ea4d27a178b4d6e95a1.png:::$DATA_ROOT/images/1094139d474e65852826d64a1b4aa520.png:::$DATA_ROOT/images/3108dd567bd8669967bc83e0bc50dab2.png:::$DATA_ROOT/images/1d50a3d9703f152758d5422c8b48010f.png:::$DATA_ROOT/images/12e51adf1acbf7acbb703a96a464a39b.png\" \
  --num_validation_videos 1 \
  --validation_steps 50"


REGULARIZATION_CMDS=("--caption_dropout_p 0.00" "--caption_dropout_p 0.05" "--caption_dropout_p 0.05 --image_condition_dropout_p 0.05 --image_condition_noise_scale 0.00" "--caption_dropout_p 0.05 --image_condition_dropout_p 0.05 --image_condition_noise_scale 0.1" "--caption_dropout_p 0.05 --image_condition_dropout_p 0.1 --image_condition_noise_scale 0.1")

LORA_CONFIGS=("--rank 128 --lora_alpha 64" "--rank 128 --lora_alpha 128" "--rank 128 --lora_alpha 256" "--rank 64 --lora_alpha 64" "--rank 256 --lora_alpha 256")

for regularization_cmd in "${REGULARIZATION_CMDS[@]}"
do
  sanitized_regularization_cmd=$(echo "$regularization_cmd" | sed 's/--//g; s/ /_/g; s/=/-/g')
  
  miscellaneous_cmd="--tracker_name finetrainers-ltxv \
    --output_dir /raid/aryan/ltx-video-$sanitized_regularization_cmd \
    --nccl_timeout 1800 \
    --report_to wandb"
  
  # Note: the LoRA config is fixed at rank 128 / alpha 128 for this regularization sweep;
  # the LoRA rank/alpha sweep runs separately below.
  cmd="accelerate launch --config_file accelerate_configs/uncompiled_2.yaml --gpu_ids $GPU_IDS train.py \
    $model_cmd \
    $dataset_cmd \
    $dataloader_cmd \
    $diffusion_cmd \
    $training_cmd --rank 128 --lora_alpha 128 \
    $optimizer_cmd \
    $validation_cmd \
    $regularization_cmd \
    $miscellaneous_cmd"

  echo "Running command: $cmd"
  eval $cmd
  echo -ne "-------------------- Finished executing script --------------------\n\n"
done

for lora_config in "${LORA_CONFIGS[@]}"
do
  sanitized_lora_config=$(echo "$lora_config" | sed 's/--//g; s/ /_/g; s/=/-/g')
  regularization_cmd="--caption_dropout_p 0.05"
  
  miscellaneous_cmd="--tracker_name finetrainers-ltxv \
    --output_dir /raid/aryan/ltx-video-$sanitized_lora_config \
    --nccl_timeout 1800 \
    --report_to wandb"
  
  cmd="accelerate launch --config_file accelerate_configs/uncompiled_2.yaml --gpu_ids $GPU_IDS train.py \
    $model_cmd \
    $dataset_cmd \
    $dataloader_cmd \
    $diffusion_cmd \
    $training_cmd \
    $optimizer_cmd \
    $validation_cmd \
    $regularization_cmd \
    $lora_config \
    $miscellaneous_cmd"

  echo "Running command: $cmd"
  eval $cmd
  echo -ne "-------------------- Finished executing script --------------------\n\n"
done
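For reference, the --validation_prompts / --validation_images strings above follow a simple convention: each prompt is suffixed with "@@@FxHxW" and entries are joined with ":::", with a matching ":::"-separated image list. A small Python sketch that builds such strings (the helper structure and example values are mine, not part of the PR):

```python
# Illustrative only: builds ":::"-separated validation strings in the "<prompt>@@@FxHxW" format.
ID_TOKEN = "afkx"

entries = [
    ("A black and white animated scene unfolds with an anthropomorphic goat ...", "49x512x768", "images/first_frame_a.png"),
    ("A static black and white scene of three anthropomorphic characters ...", "161x512x768", "images/first_frame_b.png"),
]

validation_prompts = ":::".join(f"{ID_TOKEN} {prompt}@@@{resolution}" for prompt, resolution, _ in entries)
validation_images = ":::".join(image for _, _, image in entries)

print(validation_prompts)
print(validation_images)
```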

Note that I don't yet know whether training works. I've queued some runs that should finish overnight, assuming no bugs cause a crash.

a-r-r-o-w requested a review from sayakpaul on December 25, 2024.
finetrainers/trainer.py (review thread outdated; resolved)
sayakpaul (Collaborator) left a comment:

Thanks!

@@ -11,6 +11,7 @@ class Args:
The arguments for the finetrainers training script.

Args:
TODO: write informational docstring
Collaborator: Yeah, this should be clubbed into separate PRs. Just a TODO note is fine!

@@ -391,6 +392,41 @@ def _add_diffusion_arguments(parser: argparse.ArgumentParser) -> None:
)


def _add_regularization_arguments(parser: argparse.ArgumentParser) -> None:
Collaborator: Beautiful!

@@ -391,6 +392,41 @@ def _add_diffusion_arguments(parser: argparse.ArgumentParser) -> None:
)


def _add_regularization_arguments(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--caption_dropout_p",
Collaborator: Let's validate this to always be in [0, 1] if not already?

help="Technique to use for caption dropout.",
)
parser.add_argument(
"--image_condition_dropout_p",
Collaborator: Same as above.
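A minimal sketch of what the suggested [0, 1] validation could look like for both arguments (the helper name and defaults here are illustrative, not from the PR):

```python
import argparse

def _probability(value: str) -> float:
    """Parse a float and require it to lie in [0, 1]."""
    p = float(value)
    if not 0.0 <= p <= 1.0:
        raise argparse.ArgumentTypeError(f"expected a value in [0, 1], got {p}")
    return p

parser = argparse.ArgumentParser()
parser.add_argument("--caption_dropout_p", type=_probability, default=0.00)
parser.add_argument("--image_condition_dropout_p", type=_probability, default=0.00)
```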

finetrainers/args.py (review thread resolved)
else:
    # Map from [0, 1] to [0, image_condition_noise_scale]
    scale_factor = random.random() * image_condition_noise_scale
    # :/ Because we don't have torch.randn_like
Collaborator: Are we referring to this? Also, do we always have to add noise to the conditional latent? For Flux Control, we keep the conditional latent clean.

a-r-r-o-w (Owner, Author), Dec 26, 2024: Yes, torch.randn_like does not support a generator argument, which is why I first create an empty tensor and then call normal_ to fill it with Gaussian noise. I would like to maintain 100% reproducible runs, so it is vital we always do this where needed, and work toward that if we aren't there already.

Here, I made an educated guess that adding noise to the image would serve as a good regularizer, because we don't have information on how the model was trained. I haven't done enough experiments to see whether it yields any benefit, so I will be removing this.
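For context, a minimal sketch of the empty-tensor-plus-normal_ pattern described above (shapes and the seed are placeholders):

```python
import torch

generator = torch.Generator().manual_seed(42)
latents = torch.zeros(1, 161, 128)  # dummy stand-in for the packed latents

# torch.randn_like has no `generator` argument, so allocate an empty tensor
# and fill it in place for reproducible noise.
noise = torch.empty_like(latents).normal_(generator=generator)
```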

a-r-r-o-w (Owner, Author): For Flux as well, we chose not to add any noise, but typically this should be experimented with so that the model learns to pay attention to control signals even when they are noisy. I have seen this done in at least a few places now, so I assumed it would make sense, but unless I run a large 10k-50k step run it would be hard to evaluate its effect.

Collaborator: Makes sense.

sayakpaul (Collaborator), Dec 26, 2024: So, IMO, it's okay to keep the argument for experimentation purposes. I will try to do this for Flux Control too and see if we get any effects there.

finetrainers/ltx_video/ltx_video_lora.py (two review threads resolved)
video_noisy_latents = (1.0 - sigmas) * latents[:, image_frame_end_offset:] + sigmas * noise[
    :, image_frame_end_offset:
]
noisy_latents = torch.cat([image_latents, video_noisy_latents], dim=1)
Collaborator: Shouldn't this be noisy_latents = torch.cat([video_noisy_latents, image_latents], dim=1)?

This is what we do in Cog, too:

noisy_model_input = torch.cat([noisy_video_latents, image_latents], dim=2)

a-r-r-o-w (Owner, Author): Cog is a different architecture that uses channel-wise concatenated latents. LTXV does not use channel-wise concatenation. Instead, for LTX we keep the first frame clean and noise all the other frames. We also don't perform any denoising on the first, clean frame (it is not part of the loss either), and only denoise the remaining frames.
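To make the ordering concrete, here is a small standalone sketch of that scheme. The toy shapes, variable names, and the flow-matching velocity target are my assumptions, not the PR's exact code.

```python
import torch
import torch.nn.functional as F

B, F_latent, C = 1, 5, 8                    # batch, latent frames, channels (toy sizes)
latents = torch.randn(B, F_latent, C)
noise = torch.randn(B, F_latent, C)
sigmas = torch.rand(B, 1, 1)

image_frame_end_offset = 1                  # number of leading latent frames kept clean
image_latents = latents[:, :image_frame_end_offset]

# Only the video frames are noised; the clean image latents stay in front (frame dim, not channels).
video_noisy_latents = (1.0 - sigmas) * latents[:, image_frame_end_offset:] + sigmas * noise[:, image_frame_end_offset:]
noisy_latents = torch.cat([image_latents, video_noisy_latents], dim=1)

# The loss only covers the noised frames; the conditioning frame is excluded.
model_pred = torch.randn_like(video_noisy_latents)                 # stand-in for the transformer output
target = noise[:, image_frame_end_offset:] - latents[:, image_frame_end_offset:]
loss = F.mse_loss(model_pred, target)
```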

Collaborator: Wow, that is cool. Thanks for explaining. Then let’s make this a note?

Collaborator: This and this suggest we should definitely experiment with whether noising the conditional latent is a good regularizer in this case.

a-r-r-o-w (Owner, Author): @yoavhacohen I made the assumption that the image latent is always kept clean and no denoising is applied to it. Is that the case when training LTXV? Or do you add a varying amount of noise based on a randomly sampled timestep (different from that of the other frame latents) and perform denoising on it too? This is the first time I've seen this technique of per-token/per-frame denoising levels, so I'm not sure what to do without making guesses :/

yoavhacohen: I recommend training text-to-video and image-to-video models simultaneously, adding a varying amount of noise based on a randomly sampled small timestep.
For training with image conditioning, you just need to determine the noise scheduler to apply to the tokens that correspond to the conditioning frame (versus the other tokens).
I recommend making the implementation generic rather than specific to the first frame - it should support conditioning on any subset of tokens, not just those corresponding to the first frame.
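A rough sketch of what such a generic, per-token scheme could look like. The conditioning mask, the 0.05 scale, and all names here are illustrative assumptions, not LTX-Video's actual training code.

```python
import torch

B, S, C = 1, 6, 8                              # batch, sequence (latent-frame) tokens, channels
latents = torch.randn(B, S, C)
noise = torch.randn(B, S, C)

# Conditioning can be any subset of tokens, not just those of the first frame.
conditioning_mask = torch.zeros(B, S, dtype=torch.bool)
conditioning_mask[:, 0] = True

# Per-token noise levels: conditioning tokens get a much smaller, but non-zero, sigma.
sigmas = torch.rand(B, S)
sigmas = torch.where(conditioning_mask, sigmas * 0.05, sigmas)

sigmas = sigmas.unsqueeze(-1)
noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
```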

Collaborator: Thank you! And generally speaking, the noise to be added to the conditioning frame — should it have a smaller magnitude than the one being added to the rest of the tokens?

a-r-r-o-w (Owner, Author): @yoavhacohen That sounds great, thank you! Since we don't support the idea of per-token timesteps in diffusers, I think we might have to write a custom scheduler step implementation - will give this a stab soon.
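As a rough illustration of what such a custom step could look like, here is a per-token Euler step for a flow-matching (rectified-flow) formulation. This is a hypothetical sketch, not an existing diffusers API.

```python
import torch

def per_token_flow_step(sample: torch.Tensor, velocity: torch.Tensor,
                        sigmas: torch.Tensor, sigmas_next: torch.Tensor) -> torch.Tensor:
    """One Euler step where each token can sit at its own noise level.

    sample, velocity: [B, S, C]; sigmas, sigmas_next: [B, S, 1] per-token noise levels.
    Conditioning tokens can be given sigmas == sigmas_next so they are left untouched.
    """
    return sample + (sigmas_next - sigmas) * velocity
```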

yoavhacohen, Dec 27, 2024: The noise applied to conditioning tokens should be reduced compared to other tokens in the sequence - that's actually what defines them as conditioning tokens.

finetrainers/trainer.py (review thread resolved)
# Map from [0, 1] to [0, image_condition_noise_scale]
scale_factor = random.random() * image_condition_noise_scale
# :/ Because we don't have torch.randn_like
latents[:, :, 0] = (
a-r-r-o-w (Owner, Author): @yoavhacohen Would love to know your thoughts on this as well. It was an educated guess based on some other training code I've come across for image-latent regularisation. Since we're not entirely sure how LTXV was trained, this may well not be helpful and could even cause worse results. I'm still experimenting, but any training details for things like this would be super awesome 🤗

yoavhacohen: You should add noise to the conditioning tokens in the same way as the other tokens - just use a different noise level.

image_or_video = image_or_video.to(device=device, dtype=vae.dtype)
image_or_video = image_or_video.permute(0, 2, 1, 3, 4).contiguous() # [B, C, F, H, W] -> [B, F, C, H, W]

# Note: we separately encode the image and video because there is a 4x compression applied. We only want to condition
a-r-r-o-w (Owner, Author): Thanks @yoavhacohen, I'll make the update soon! It makes sense in hindsight, since this was mostly guesswork.

sayakpaul mentioned this pull request on Jan 2, 2025.
@UsernamesAreSilly4: What are some settings I can use if I want to train on an RTX 3060 with 12 GB of VRAM? I read that if you turn off validation, you only need 11 GB of VRAM. What other optimizations can I use?

@sayakpaul (Collaborator): You could try gradient_checkpointing, precompute_conditions, use_8bit_bnb, or DeepSpeed.

@scarbain commented Jan 4, 2025: Can we test this PR, or do the outstanding TODO comments mean training is broken and it's not worth testing until they're done? :)

@a-r-r-o-w (Owner, Author): Hi @scarbain. Sorry, I haven't had the time to bring this to completion yet. My last training run did not yield particularly interesting results, and I have yet to address some comments from the original model author, so I would recommend waiting until it's merged unless you have ample GPU resources for testing/debugging.

I will try to complete it soon, after a new suite of memory optimizations lands in the coming days.

@scarbain: Hi @a-r-r-o-w! I'm sorry for asking again; I certainly don't want to pressure you on this, because all your work on this repository is a great gift to the community and you should prioritise however you want. Do you have an approximate ETA for completing this PR? :)
