Full Finetuning for LTX, possibly extended to other models #192

Open: wants to merge 24 commits into base: main

Commits (24), with the changes below shown from 1 commit:
eeea82a
Full Finetuning for LTX possibily extended to other models.
ArEnSc Jan 7, 2025
f0db0cc
Change name of the flag
ArEnSc Jan 7, 2025
4cd5a8e
Used disable grad for component on lora fine tuning enabled
ArEnSc Jan 7, 2025
cb9381b
Suggestions Addressed
ArEnSc Jan 7, 2025
bd02e6d
Merge branch 'main' into feature/full-finetuning
ArEnSc Jan 8, 2025
72ec207
Merge branch 'main' into feature/full-finetuning
sayakpaul Jan 8, 2025
d0ee9c3
Switching to Full FineTuning
ArEnSc Jan 9, 2025
19bba0a
Run linter.
ArEnSc Jan 9, 2025
acffc2d
parse subfolder when needed.
sayakpaul Jan 9, 2025
8188f8a
tackle saving and loading hooks.
sayakpaul Jan 9, 2025
5183405
tackle validation.
sayakpaul Jan 9, 2025
162e6cd
fix subfolder bug.
sayakpaul Jan 9, 2025
2c6f549
Merge branch 'main' into auxiliary-support-ff-2
sayakpaul Jan 9, 2025
28b3e84
Merge branch 'main' into auxiliary-support-ff-2
sayakpaul Jan 10, 2025
c0f3889
remove __class__.
sayakpaul Jan 10, 2025
6d59769
Merge branch 'main' into feature/full-finetuning
sayakpaul Jan 10, 2025
a422f7f
updates
sayakpaul Jan 10, 2025
34da4c5
Merge branch 'main' into feature/full-finetuning
sayakpaul Jan 10, 2025
014960f
Merge branch 'main' into feature/full-finetuning
sayakpaul Jan 10, 2025
d6821c3
refactor
a-r-r-o-w Jan 10, 2025
06dd96c
remove unnecessary changes
a-r-r-o-w Jan 11, 2025
1f304b3
handle saving of final model weights correctly
a-r-r-o-w Jan 11, 2025
491b35f
Merge branch 'main' into feature/full-finetuning
a-r-r-o-w Jan 11, 2025
ca957e5
remove unnecessary changes
a-r-r-o-w Jan 11, 2025
2 changes: 1 addition & 1 deletion finetrainers/args.py
@@ -455,7 +455,7 @@ def _add_training_arguments(parser: argparse.ArgumentParser) -> None:
"--training_type",
type=str,
default=None,
help="Type of training to perform. Choose between ['lora']",
help="Type of training to perform. Choose between ['lora','finetune']",
Contributor (author) commented:

Added this here. Should it be called full_finetune?

)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
Collaborator commented:

Can happen in another PR, but we could also provide some info so users know that if rank and other LoRA-related arguments are provided alongside training_type=="sft", they will be ignored.

WDYT?
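One possible shape for that check, as a minimal sketch (not part of this PR; it assumes the LoRA arguments default to None when unset, and the helper name is made up):

import logging

logger = logging.getLogger(__name__)


def _warn_on_ignored_lora_args(args) -> None:
    # Warn when LoRA-only arguments are supplied for a non-LoRA run.
    if args.training_type == "lora":
        return
    ignored = [
        name
        for name, value in (
            ("rank", args.rank),
            ("lora_alpha", args.lora_alpha),
            ("target_modules", args.target_modules),
        )
        if value is not None
    ]
    if ignored:
        logger.warning(
            f"training_type={args.training_type!r} does not use LoRA; "
            f"ignoring: {', '.join(ignored)}"
        )

Such a helper could be called right after argument parsing, in whichever PR picks this up.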

parser.add_argument(
15 changes: 15 additions & 0 deletions finetrainers/ltx_video/ltx_video_lora.py
@@ -321,3 +321,18 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int
"forward_pass": forward_pass,
"validation": validation,
}

Contributor (author) commented:

I didn't know if you two wanted to enable this by swapping out prepare parameters; I thought it might have been over-engineering to do that, so I just made a copy of this config.

Owner commented:

This is perfect and the intended usage. This will eventually be refactored out into model specs to add some syntactic sugar and make the code easier to follow.

LTX_VIDEO_T2V_FT_CONFIG = {
"pipeline_cls": LTXPipeline,
"load_condition_models": load_condition_models,
"load_latent_models": load_latent_models,
"load_diffusion_models": load_diffusion_models,
"initialize_pipeline": initialize_pipeline,
"prepare_conditions": prepare_conditions,
"prepare_latents": prepare_latents,
"post_latent_preparation": post_latent_preparation,
"collate_fn": collate_fn_t2v,
"forward_pass": forward_pass,
"validation": validation,
}
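For context, a hypothetical illustration of the "model specs" refactor mentioned in the owner's comment above; the class name and field types are assumptions, and nothing below exists in the codebase yet:

# Hypothetical only: a dataclass the dict configs in this file could be folded
# into later. Field names mirror the dict keys above.
from dataclasses import dataclass
from typing import Any, Callable, Dict


@dataclass(frozen=True)
class ModelSpec:
    pipeline_cls: Any
    load_condition_models: Callable[..., Dict[str, Any]]
    load_latent_models: Callable[..., Dict[str, Any]]
    load_diffusion_models: Callable[..., Dict[str, Any]]
    initialize_pipeline: Callable[..., Any]
    prepare_conditions: Callable[..., Dict[str, Any]]
    prepare_latents: Callable[..., Dict[str, Any]]
    post_latent_preparation: Callable[..., Dict[str, Any]]
    collate_fn: Callable[..., Dict[str, Any]]
    forward_pass: Callable[..., Dict[str, Any]]
    validation: Callable[..., Any]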

3 changes: 2 additions & 1 deletion finetrainers/models.py
@@ -2,7 +2,7 @@

from .cogvideox import COGVIDEOX_T2V_LORA_CONFIG
from .hunyuan_video import HUNYUAN_VIDEO_T2V_LORA_CONFIG
from .ltx_video import LTX_VIDEO_T2V_LORA_CONFIG
from .ltx_video import LTX_VIDEO_T2V_LORA_CONFIG, LTX_VIDEO_T2V_FT_CONFIG


SUPPORTED_MODEL_CONFIGS = {
@@ -11,6 +11,7 @@
},
"ltx_video": {
"lora": LTX_VIDEO_T2V_LORA_CONFIG,
"finetune": LTX_VIDEO_T2V_FT_CONFIG,
(a-r-r-o-w marked this conversation as resolved.)
},
"cogvideox": {
"lora": COGVIDEOX_T2V_LORA_CONFIG,
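With the new entry registered, selecting the full-finetune config becomes a plain nested lookup. A minimal sketch, assuming get_config_from_model_name resolves through SUPPORTED_MODEL_CONFIGS by model name and then training type:

# Assumes a plain dict lookup, as models.py suggests; not a documented API.
from finetrainers.models import SUPPORTED_MODEL_CONFIGS

config = SUPPORTED_MODEL_CONFIGS["ltx_video"]["finetune"]
print(sorted(config.keys()))  # collate_fn, forward_pass, load_*_models, prepare_*, validation, ...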
71 changes: 44 additions & 27 deletions finetrainers/trainer.py
@@ -99,6 +99,8 @@ def __init__(self, args: Args) -> None:
self.state.model_name = self.args.model_name
self.model_config = get_config_from_model_name(self.args.model_name, self.args.training_type)

# Components list
self.components = []
def prepare_dataset(self) -> None:
# TODO(aryan): Make a background process for fetching
logger.info("Initializing dataset and dataloader")
@@ -153,6 +155,17 @@ def _set_components(self, components: Dict[str, Any]) -> None:
self.transformer_config = self.transformer.config if self.transformer is not None else self.transformer_config
self.vae_config = self.vae.config if self.vae is not None else self.vae_config

self.components = [self.tokenizer,
self.tokenizer_2,
self.tokenizer_3,
self.text_encoder,
self.text_encoder_2,
self.text_encoder_3,
self.transformer,
self.unet,
self.vae]


def _delete_components(self) -> None:
self.tokenizer = None
self.tokenizer_2 = None
@@ -167,6 +180,8 @@ def _delete_components(self) -> None:
free_memory()
torch.cuda.synchronize(self.state.accelerator.device)

self.components = None

def prepare_models(self) -> None:
logger.info("Initializing models")

@@ -189,6 +204,16 @@ def prepare_models(self) -> None:
if self.args.enable_tiling:
self.vae.enable_tiling()

def _disable_grad_for_components(self, components:list):
for component in components:
if component is not None:
component.requires_grad_(False)

def _enable_grad_for_components(self, components:list):
for component in components:
if component is not None:
component.requires_grad_(True)

def prepare_precomputations(self) -> None:
if not self.args.precompute_conditions:
return
@@ -237,16 +262,11 @@ def collate_fn(batch):
self._set_components(condition_components)
self._move_components_to_device()

# TODO(aryan): refactor later. for now only lora is supported
components_to_disable_grads = [
self._disable_grad_for_components(components=[
self.text_encoder,
self.text_encoder_2,
self.text_encoder_3,
]
for component in components_to_disable_grads:
if component is not None:
component.requires_grad_(False)

])
if self.args.caption_dropout_p > 0 and self.args.caption_dropout_technique == "empty":
logger.warning(
"Caption dropout is not supported with precomputation yet. This will be supported in the future."
@@ -300,12 +320,7 @@ def collate_fn(batch):
self._set_components(latent_components)
self._move_components_to_device()

# TODO(aryan): refactor later
components_to_disable_grads = [self.vae]
for component in components_to_disable_grads:
if component is not None:
component.requires_grad_(False)

self._disable_grad_for_components(components=[self.vae])
if self.vae is not None:
if self.args.enable_slicing:
self.vae.enable_slicing()
@@ -363,17 +378,18 @@ def prepare_trainable_parameters(self) -> None:
diffusion_components = self.model_config["load_diffusion_models"](**self._get_load_components_kwargs())
self._set_components(diffusion_components)

# TODO(aryan): refactor later. for now only lora is supported
components_to_disable_grads = [
self._disable_grad_for_components(components=[
self.text_encoder,
self.text_encoder_2,
self.text_encoder_3,
self.transformer,
self.vae,
]
for component in components_to_disable_grads:
if component is not None:
component.requires_grad_(False)
])

if self.args.training_type == "full_finetune":
logger.info("Full Fine Tuning Enabled")
self._enable_grad_for_components(components=[self.transformer])
else:
logger.info("Lora Fine Tuning Enabled")

# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
@@ -398,13 +414,14 @@
if self.args.gradient_checkpointing:
self.transformer.enable_gradient_checkpointing()

transformer_lora_config = LoraConfig(
r=self.args.rank,
lora_alpha=self.args.lora_alpha,
init_lora_weights=True,
target_modules=self.args.target_modules,
)
self.transformer.add_adapter(transformer_lora_config)
if self.args.training_type == "lora":
transformer_lora_config = LoraConfig(
r=self.args.rank,
lora_alpha=self.args.lora_alpha,
init_lora_weights=True,
target_modules=self.args.target_modules,
)
self.transformer.add_adapter(transformer_lora_config)

# Enable TF32 for faster training on Ampere GPUs: https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if self.args.allow_tf32 and torch.cuda.is_available():
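The _disable_grad_for_components/_enable_grad_for_components helpers added in this file implement a freeze-then-selectively-unfreeze pattern. A self-contained toy showing the behavior with stand-in nn.Linear modules (nothing below comes from the repo itself):

from torch import nn


def disable_grad_for_components(components: list) -> None:
    for component in components:
        if component is not None:
            component.requires_grad_(False)


def enable_grad_for_components(components: list) -> None:
    for component in components:
        if component is not None:
            component.requires_grad_(True)


# Stand-ins for the text encoder, transformer, and VAE.
text_encoder, transformer, vae = nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8)

disable_grad_for_components([text_encoder, transformer, vae, None])
enable_grad_for_components([transformer])  # the "full_finetune" branch re-enables only the transformer

print(any(p.requires_grad for p in text_encoder.parameters()))  # False
print(all(p.requires_grad for p in transformer.parameters()))   # True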