From 7570ffab4062a2cc5185bd934684464d7388401c Mon Sep 17 00:00:00 2001 From: mori360 Date: Fri, 16 Jan 2026 15:59:38 -0800 Subject: [PATCH 01/20] add lora folder --- torchtitan/experiments/lora/debug_model.toml | 97 +++++++ torchtitan/experiments/lora/lora.py | 272 +++++++++++++++++++ torchtitan/experiments/lora/train.py | 35 +++ 3 files changed, 404 insertions(+) create mode 100644 torchtitan/experiments/lora/debug_model.toml create mode 100644 torchtitan/experiments/lora/lora.py create mode 100644 torchtitan/experiments/lora/train.py diff --git a/torchtitan/experiments/lora/debug_model.toml b/torchtitan/experiments/lora/debug_model.toml new file mode 100644 index 0000000000..91d81992c8 --- /dev/null +++ b/torchtitan/experiments/lora/debug_model.toml @@ -0,0 +1,97 @@ +# LoRA Debug Model Configuration +# +# This config is specifically for LoRA fine-tuning experiments. +# Use with: torchtitan.experiments.lora.train +# +# Example: +# CONFIG_FILE="./torchtitan/experiments/lora/debug_model.toml" \ +# ./run_train.sh torchtitan.experiments.lora.train + +[job] +dump_folder = "./outputs" +description = "Llama 3 LoRA debug training" +print_config = false + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 10 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "llama3" +flavor = "debugmodel" +# test folder with tokenizer.json, for debug purpose only +hf_assets_path = "./tests/assets/tokenizer" +# Enable LoRA converter for fine-tuning +converters = ["lora"] + +[lora] +# LoRA configuration (all values are optional, defaults shown below) +rank = 8 +alpha = 16.0 +dropout = 0.0 +apply_to_all_linears = true + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps +decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps +decay_type = "linear" +min_lr_factor = 0.0 + +[training] +local_batch_size = 8 +seq_len = 2048 +max_norm = 1.0 # grad norm clipping +steps = 10 +dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false +pipeline_parallel_degree = 1 +pipeline_parallel_schedule = "Interleaved1F1B" +context_parallel_degree = 1 + +[checkpoint] +enable = false +folder = "checkpoint" +interval = 10 +last_save_model_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = "selective" # ["none", "selective", "full"] +selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[compile] +enable=false +components = ["model", "loss"] + +[quantize.linear.float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output"] + +[validation] +enable = false +dataset = "c4_validation" +freq = 5 +steps = 10 diff --git a/torchtitan/experiments/lora/lora.py b/torchtitan/experiments/lora/lora.py new file mode 100644 index 0000000000..abd8112e1b --- /dev/null +++ b/torchtitan/experiments/lora/lora.py @@ -0,0 +1,272 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math +from dataclasses import dataclass +from typing import List, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torchtitan.config import JobConfig +from torchtitan.distributed import ParallelDims +from torchtitan.protocols.model_converter import register_model_converter +from torchtitan.tools.logging import logger + + +@dataclass +class LoRAConfig: + """Configuration for LoRA (Low-Rank Adaptation) fine-tuning. + + Args: + rank: Rank of the low-rank approximation. Default: 8. + alpha: Scaling factor for the low-rank approximation. Default: 16.0. + dropout: Dropout probability for LoRA layers. Default: 0.0. + apply_to_all_linears: If True, apply LoRA to all Linear layers. + If False, only apply to attention layers (wq, wk, wv, wo). Default: True. + """ + + rank: int = 8 + """Rank of the low-rank approximation""" + + alpha: float = 16.0 + """Scaling factor for the low-rank approximation""" + + dropout: float = 0.0 + """Dropout probability for LoRA layers""" + + apply_to_all_linears: bool = True + """If True, apply LoRA to all Linear layers. If False, only apply to attention layers.""" + + +def get_lora_config(job_config: JobConfig) -> LoRAConfig: + """Get LoRA config from job_config, using defaults if not specified. + + The LoRA config can be specified in the TOML file under [lora] section: + ```toml + [lora] + rank = 8 + alpha = 16.0 + dropout = 0.0 + apply_to_all_linears = true + ``` + + If not specified, default values from LoRAConfig will be used. + """ + lora_config = LoRAConfig() + + # Check if job_config has a 'lora' attribute (from custom config) + if hasattr(job_config, "lora"): + lora_section = job_config.lora + if hasattr(lora_section, "rank"): + lora_config.rank = lora_section.rank + if hasattr(lora_section, "alpha"): + lora_config.alpha = lora_section.alpha + if hasattr(lora_section, "dropout"): + lora_config.dropout = lora_section.dropout + if hasattr(lora_section, "apply_to_all_linears"): + lora_config.apply_to_all_linears = lora_section.apply_to_all_linears + + return lora_config + + +class LoRALinear(nn.Module): + """LoRA linear layer as introduced in `LoRA: Low-Rank Adaptation of Large Language Models `_. + + LoRA perturbs a given layer via a low-rank approximation where only + the rank decomposition matrices are trainable. In a linear layer instead of + :math:`x \\mapsto W_0x` a LoRALinear layer is defined as + :math:`x \\mapsto W_0x + (\\alpha / r)BAx`, where :math:`r` is the rank of + the matrices :math:`A` and :math:`B` and :math:`\\alpha` is a scaling factor. + As in the original implementation, we support dropout before multiplication + by the low-rank matrices. + + Args: + in_dim (int): input dimension + out_dim (int): output dimension + rank (int): rank of the low-rank approximation + alpha (float): scaling factor for the low-rank approximation + dropout (float): dropout probability. Default: 0.0 + use_bias (bool): whether to include bias in the original linear layer. 
+ Default: False + + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + rank: int, + alpha: float, + dropout: float = 0.0, + use_bias: bool = False, + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.rank = rank + self.alpha = alpha + self.use_bias = use_bias + + # Setup weight and bias + linear = nn.Linear(in_features=in_dim, out_features=out_dim, bias=self.use_bias) + weight = linear.weight + bias = linear.bias if self.use_bias else None + + # 'self.disabled' is a flag showing whether to turn off LoRA adapters, + # this can be used in DPO for treating the lora adapters as the policy model + # and disabling it to treat the base model as the reference model + self.disabled = False + self.register_parameter("weight", nn.Parameter(weight)) + self.register_parameter( + "bias", nn.Parameter(bias) if bias is not None else None + ) + self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity() + self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False) + self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False) + self.merged = False + self.initialize_parameters() + + def to_empty( + self, *, device: Optional[Union[str, torch.device, int]], recurse: bool = True + ): + self.lora_a.to_empty(device=device, recurse=recurse) + self.lora_b.to_empty(device=device, recurse=recurse) + + def initialize_parameters(self): + # Initialize as in + # https://github.com/microsoft/LoRA/blob/4c0333854cb905966f8cc4e9a74068c1e507c7b7/loralib/layers.py#L119 + _lora_a_init_params(self.lora_a) + _lora_b_init_params(self.lora_b) + + def adapter_params(self) -> list[str]: + """ + Return a list of strings corresponding to the names of the ``nn.Parameter`` s in + the model coming from the adapter. + + For LoRA this means lora_a.weight and lora_b.weight. + """ + # NOTE: this function has to be updated if the names of "lora_a" and "lora_b" + # in this module change. + adapter_params = ["lora_a.weight", "lora_b.weight"] + return adapter_params + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): input tensor with shape ``(..., in_dim)`` + + Returns: + torch.Tensor: output tensor with shape ``(..., out_dim)`` + + """ + out = F.linear(x, self.weight, self.bias) + if self.disabled: + return out + lora_out = self.lora_a(self.dropout(x)) + lora_out = (self.alpha / self.rank) * self.lora_b(lora_out) + return out + lora_out + + +def _lora_a_init_params(x: nn.Linear) -> None: + """ + Initialize LoRA A weight to Kaiming uniform. + """ + nn.init.kaiming_uniform_(x.weight, a=math.sqrt(5)) + + +def _lora_b_init_params(x: nn.Linear) -> None: + """ + Initialize LoRA B weight to zeros. + """ + nn.init.zeros_(x.weight) + + +class LoRAConverter: + """Model converter that adds LoRA adapters to Linear layers. + + This converter replaces nn.Linear layers with LoRALinear layers and sets + requires_grad=True only for LoRA parameters, freezing all other parameters. 
+ + Configuration can be specified in the TOML file under [lora] section: + ```toml + [lora] + rank = 8 + alpha = 16.0 + dropout = 0.0 + apply_to_all_linears = true + ``` + """ + + def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): + lora_config = get_lora_config(job_config) + self.rank = lora_config.rank + self.alpha = lora_config.alpha + self.dropout = lora_config.dropout + self.apply_to_all_linears = lora_config.apply_to_all_linears + + logger.info( + f"LoRA config: rank={self.rank}, alpha={self.alpha}, " + f"dropout={self.dropout}, apply_to_all_linears={self.apply_to_all_linears}" + ) + + def convert(self, model: nn.Module) -> None: + """Inplace conversion of the model to use LoRA adapters.""" + # First, freeze all parameters + for param in model.parameters(): + param.requires_grad = False + + # Collect all Linear layers to replace (to avoid modifying while iterating) + replacements = [] + for name, module in model.named_modules(): + for child_name, child in module.named_children(): + if isinstance(child, nn.Linear) and not isinstance(child, LoRALinear): + replacements.append((module, child_name, child)) + + # Replace Linear layers with LoRALinear + for parent_module, child_name, child in replacements: + lora_linear = LoRALinear( + in_dim=child.in_features, + out_dim=child.out_features, + rank=self.rank, + alpha=self.alpha, + dropout=self.dropout, + use_bias=child.bias is not None, + ) + # First move to the same device and dtype as the original weights + lora_linear = lora_linear.to( + device=child.weight.device, dtype=child.weight.dtype + ) + # Then copy the original weights (after dtype conversion) + lora_linear.weight.data.copy_(child.weight.data) + if child.bias is not None: + lora_linear.bias.data.copy_(child.bias.data) + # Replace the module + setattr(parent_module, child_name, lora_linear) + + # Enable gradients only for LoRA parameters + for name, param in model.named_parameters(): + if "lora_a" in name or "lora_b" in name: + param.requires_grad = True + + # Log the number of trainable parameters + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + logger.info( + f"LoRA adapters added. Trainable parameters: {trainable_params:,} / {total_params:,} " + f"({100 * trainable_params / total_params:.2f}%)" + ) + + def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]) -> None: + """Post-optimizer hook (no-op for LoRA).""" + pass + + +# Register the LoRA converter +register_model_converter(LoRAConverter, "lora") diff --git a/torchtitan/experiments/lora/train.py b/torchtitan/experiments/lora/train.py new file mode 100644 index 0000000000..477d40e228 --- /dev/null +++ b/torchtitan/experiments/lora/train.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +LoRA Training Entry Point + +This module provides a training entry point that enables LoRA (Low-Rank Adaptation) +for fine-tuning large language models. It imports the LoRA module to register the +LoRAConverter, then delegates to the main training logic. 
+ +Usage: + Run training with LoRA enabled: + ``` + CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" \\ + ./run_train.sh --training.steps 10 torchtitan.experiments.lora.train + ``` + + Make sure to add "lora" to the model.converters list in your config: + ```toml + [model] + converters = ["lora"] + ``` +""" + +# Import LoRA module to register the LoRAConverter with the model converter registry +import torchtitan.experiments.lora.lora # noqa: F401 + +from torchtitan.train import main, Trainer + + +if __name__ == "__main__": + main(Trainer) From 463444eb26a8cf5b3d83bcea275671e78467f010 Mon Sep 17 00:00:00 2001 From: mori360 Date: Tue, 20 Jan 2026 11:20:33 -0800 Subject: [PATCH 02/20] draft --- torchtitan/__init__.py | 3 + torchtitan/components/lora/__init__.py | 19 ++++ .../{experiments => components}/lora/lora.py | 24 ++--- torchtitan/config/job_config.py | 39 ++++++++ torchtitan/experiments/lora/debug_model.toml | 97 ------------------- torchtitan/experiments/lora/train.py | 35 ------- .../llama3/train_configs/debug_model.toml | 10 +- 7 files changed, 79 insertions(+), 148 deletions(-) create mode 100644 torchtitan/components/lora/__init__.py rename torchtitan/{experiments => components}/lora/lora.py (93%) delete mode 100644 torchtitan/experiments/lora/debug_model.toml delete mode 100644 torchtitan/experiments/lora/train.py diff --git a/torchtitan/__init__.py b/torchtitan/__init__.py index 52c3ff3e22..8683c72492 100644 --- a/torchtitan/__init__.py +++ b/torchtitan/__init__.py @@ -9,6 +9,9 @@ # Import to register quantization modules. import torchtitan.components.quantization # noqa: F401 +# Import to register lora module. +import torchtitan.components.lora # noqa: F401 + try: __version__ = version("torchtitan") except Exception as e: diff --git a/torchtitan/components/lora/__init__.py b/torchtitan/components/lora/__init__.py new file mode 100644 index 0000000000..5d176ed805 --- /dev/null +++ b/torchtitan/components/lora/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchtitan.components.lora.lora import ( + get_lora_config, + LoRAConfig, + LoRAConverter, + LoRALinear, +) + +__all__ = [ + "get_lora_config", + "LoRAConfig", + "LoRAConverter", + "LoRALinear", +] diff --git a/torchtitan/experiments/lora/lora.py b/torchtitan/components/lora/lora.py similarity index 93% rename from torchtitan/experiments/lora/lora.py rename to torchtitan/components/lora/lora.py index abd8112e1b..ad0eef3022 100644 --- a/torchtitan/experiments/lora/lora.py +++ b/torchtitan/components/lora/lora.py @@ -57,21 +57,13 @@ def get_lora_config(job_config: JobConfig) -> LoRAConfig: If not specified, default values from LoRAConfig will be used. 
""" - lora_config = LoRAConfig() - - # Check if job_config has a 'lora' attribute (from custom config) - if hasattr(job_config, "lora"): - lora_section = job_config.lora - if hasattr(lora_section, "rank"): - lora_config.rank = lora_section.rank - if hasattr(lora_section, "alpha"): - lora_config.alpha = lora_section.alpha - if hasattr(lora_section, "dropout"): - lora_config.dropout = lora_section.dropout - if hasattr(lora_section, "apply_to_all_linears"): - lora_config.apply_to_all_linears = lora_section.apply_to_all_linears - - return lora_config + lora_section = job_config.lora + return LoRAConfig( + rank=lora_section.rank, + alpha=lora_section.alpha, + dropout=lora_section.dropout, + apply_to_all_linears=lora_section.apply_to_all_linears, + ) class LoRALinear(nn.Module): @@ -252,6 +244,8 @@ def convert(self, model: nn.Module) -> None: for name, param in model.named_parameters(): if "lora_a" in name or "lora_b" in name: param.requires_grad = True + else: + param.requires_grad = False # Log the number of trainable parameters total_params = sum(p.numel() for p in model.parameters()) diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 108c38efba..c27a86fbb9 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -694,6 +694,44 @@ class ActivationCheckpoint: """ +@dataclass +class LoRA: + """Configuration for LoRA (Low-Rank Adaptation) fine-tuning. + + LoRA is a parameter-efficient fine-tuning technique that freezes the pretrained + model weights and injects trainable low-rank decomposition matrices into each + layer of the Transformer architecture. + + To enable LoRA, add "lora" to the model.converters list in your config: + [model] + converters = ["lora"] + + [lora] + rank = 8 + alpha = 16.0 + dropout = 0.0 + apply_to_all_linears = true + """ + + rank: int = 8 + """Rank of the low-rank approximation. Higher rank = more parameters but better quality.""" + + alpha: float = 16.0 + """ + Scaling factor for the low-rank approximation. + The LoRA output is scaled by (alpha / rank), so higher alpha means stronger LoRA effect. + """ + + dropout: float = 0.0 + """Dropout probability applied to the LoRA layers. 0.0 means no dropout.""" + + apply_to_all_linears: bool = True + """ + If True, apply LoRA to all nn.Linear layers in the model. + If False, only apply to attention layers (wq, wk, wv, wo). + """ + + @dataclass class Compile: enable: bool = False @@ -991,6 +1029,7 @@ class JobConfig: activation_checkpoint: ActivationCheckpoint = field( default_factory=ActivationCheckpoint ) + lora: LoRA = field(default_factory=LoRA) compile: Compile = field(default_factory=Compile) quantize: Quantize = field(default_factory=Quantize) comm: Comm = field(default_factory=Comm) diff --git a/torchtitan/experiments/lora/debug_model.toml b/torchtitan/experiments/lora/debug_model.toml deleted file mode 100644 index 91d81992c8..0000000000 --- a/torchtitan/experiments/lora/debug_model.toml +++ /dev/null @@ -1,97 +0,0 @@ -# LoRA Debug Model Configuration -# -# This config is specifically for LoRA fine-tuning experiments. 
-# Use with: torchtitan.experiments.lora.train -# -# Example: -# CONFIG_FILE="./torchtitan/experiments/lora/debug_model.toml" \ -# ./run_train.sh torchtitan.experiments.lora.train - -[job] -dump_folder = "./outputs" -description = "Llama 3 LoRA debug training" -print_config = false - -[profiling] -enable_profiling = false -save_traces_folder = "profile_trace" -profile_freq = 10 -enable_memory_snapshot = false -save_memory_snapshot_folder = "memory_snapshot" - -[metrics] -log_freq = 1 -disable_color_printing = false -enable_tensorboard = false -save_tb_folder = "tb" -enable_wandb = false - -[model] -name = "llama3" -flavor = "debugmodel" -# test folder with tokenizer.json, for debug purpose only -hf_assets_path = "./tests/assets/tokenizer" -# Enable LoRA converter for fine-tuning -converters = ["lora"] - -[lora] -# LoRA configuration (all values are optional, defaults shown below) -rank = 8 -alpha = 16.0 -dropout = 0.0 -apply_to_all_linears = true - -[optimizer] -name = "AdamW" -lr = 8e-4 -eps = 1e-8 - -[lr_scheduler] -warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps -decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps -decay_type = "linear" -min_lr_factor = 0.0 - -[training] -local_batch_size = 8 -seq_len = 2048 -max_norm = 1.0 # grad norm clipping -steps = 10 -dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) - -[parallelism] -data_parallel_replicate_degree = 1 -data_parallel_shard_degree = -1 -fsdp_reshard_after_forward = "default" # default / never / always -tensor_parallel_degree = 1 -enable_async_tensor_parallel = false -pipeline_parallel_degree = 1 -pipeline_parallel_schedule = "Interleaved1F1B" -context_parallel_degree = 1 - -[checkpoint] -enable = false -folder = "checkpoint" -interval = 10 -last_save_model_only = false -export_dtype = "float32" -async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] - -[activation_checkpoint] -mode = "selective" # ["none", "selective", "full"] -selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy - -[compile] -enable=false -components = ["model", "loss"] - -[quantize.linear.float8] -enable_fsdp_float8_all_gather = false -precompute_float8_dynamic_scale_for_fsdp = false -filter_fqns = ["output"] - -[validation] -enable = false -dataset = "c4_validation" -freq = 5 -steps = 10 diff --git a/torchtitan/experiments/lora/train.py b/torchtitan/experiments/lora/train.py deleted file mode 100644 index 477d40e228..0000000000 --- a/torchtitan/experiments/lora/train.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -LoRA Training Entry Point - -This module provides a training entry point that enables LoRA (Low-Rank Adaptation) -for fine-tuning large language models. It imports the LoRA module to register the -LoRAConverter, then delegates to the main training logic. 
- -Usage: - Run training with LoRA enabled: - ``` - CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" \\ - ./run_train.sh --training.steps 10 torchtitan.experiments.lora.train - ``` - - Make sure to add "lora" to the model.converters list in your config: - ```toml - [model] - converters = ["lora"] - ``` -""" - -# Import LoRA module to register the LoRAConverter with the model converter registry -import torchtitan.experiments.lora.lora # noqa: F401 - -from torchtitan.train import main, Trainer - - -if __name__ == "__main__": - main(Trainer) diff --git a/torchtitan/models/llama3/train_configs/debug_model.toml b/torchtitan/models/llama3/train_configs/debug_model.toml index 7760667edd..6bd5905882 100644 --- a/torchtitan/models/llama3/train_configs/debug_model.toml +++ b/torchtitan/models/llama3/train_configs/debug_model.toml @@ -22,7 +22,15 @@ name = "llama3" flavor = "debugmodel" # test folder with tokenizer.json, for debug purpose only hf_assets_path = "./tests/assets/tokenizer" -# converters = ["float8"] +# Enable LoRA converter for fine-tuning +converters = ["lora"] + +[lora] +# LoRA configuration (all values are optional, defaults shown below) +rank = 8 +alpha = 16.0 +dropout = 0.0 +apply_to_all_linears = true [optimizer] name = "AdamW" From d19cf5089fe53f2aa5acf6bbbb628f4f4c0e7b69 Mon Sep 17 00:00:00 2001 From: mori360 Date: Tue, 20 Jan 2026 13:28:25 -0800 Subject: [PATCH 03/20] fix optim.step --- torchtitan/components/lora/lora.py | 45 +++++++++++++++++++ .../llama3/train_configs/debug_model.toml | 6 +-- torchtitan/train.py | 2 +- 3 files changed, 49 insertions(+), 4 deletions(-) diff --git a/torchtitan/components/lora/lora.py b/torchtitan/components/lora/lora.py index ad0eef3022..a5c855faa3 100644 --- a/torchtitan/components/lora/lora.py +++ b/torchtitan/components/lora/lora.py @@ -128,6 +128,7 @@ def to_empty( ): self.lora_a.to_empty(device=device, recurse=recurse) self.lora_b.to_empty(device=device, recurse=recurse) + return self def initialize_parameters(self): # Initialize as in @@ -135,6 +136,14 @@ def initialize_parameters(self): _lora_a_init_params(self.lora_a) _lora_b_init_params(self.lora_b) + def reset_parameters(self): + """Reset LoRA parameters. 
Called by init_weights during model initialization.""" + _lora_a_init_params(self.lora_a) + _lora_b_init_params(self.lora_b) + # Ensure LoRA params have requires_grad=True after reset + self.lora_a.weight.requires_grad = True + self.lora_b.weight.requires_grad = True + def adapter_params(self) -> list[str]: """ Return a list of strings corresponding to the names of the ``nn.Parameter`` s in @@ -200,6 +209,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): self.alpha = lora_config.alpha self.dropout = lora_config.dropout self.apply_to_all_linears = lora_config.apply_to_all_linears + self._converted_model: Optional[nn.Module] = None logger.info( f"LoRA config: rank={self.rank}, alpha={self.alpha}, " @@ -247,6 +257,20 @@ def convert(self, model: nn.Module) -> None: else: param.requires_grad = False + # Store reference for post_optimizer_hook to re-initialize LoRA params + self._converted_model = model + + # Wrap the original init_weights to also initialize LoRA parameters + original_init_weights = model.init_weights + + def init_weights_with_lora(*args, **kwargs): + # Call the original init_weights + original_init_weights(*args, **kwargs) + # Initialize LoRA parameters and ensure requires_grad is set + self._init_lora_params(model) + + model.init_weights = init_weights_with_lora + # Log the number of trainable parameters total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum( @@ -257,6 +281,27 @@ def convert(self, model: nn.Module) -> None: f"({100 * trainable_params / total_params:.2f}%)" ) + def _init_lora_params(self, model: nn.Module) -> None: + """Initialize LoRA parameters and set requires_grad after model initialization.""" + for name, module in model.named_modules(): + if isinstance(module, LoRALinear): + # Re-initialize LoRA parameters + module.reset_parameters() + + # Re-freeze base model params and unfreeze LoRA params + # This is necessary because init_weights may have touched some params + for name, param in model.named_parameters(): + if "lora_a" in name or "lora_b" in name: + param.requires_grad = True + else: + param.requires_grad = False + + # Log trainable parameter count after init + trainable_params = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + logger.info(f"LoRA params initialized. 
Trainable parameters: {trainable_params:,}") + def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]) -> None: """Post-optimizer hook (no-op for LoRA).""" pass diff --git a/torchtitan/models/llama3/train_configs/debug_model.toml b/torchtitan/models/llama3/train_configs/debug_model.toml index 6bd5905882..2b96426a82 100644 --- a/torchtitan/models/llama3/train_configs/debug_model.toml +++ b/torchtitan/models/llama3/train_configs/debug_model.toml @@ -6,12 +6,12 @@ print_config = false [profiling] enable_profiling = false save_traces_folder = "profile_trace" -profile_freq = 10 +profile_freq = 1000 enable_memory_snapshot = false save_memory_snapshot_folder = "memory_snapshot" [metrics] -log_freq = 1 +log_freq = 10 disable_color_printing = false enable_tensorboard = false save_tb_folder = "tb" @@ -47,7 +47,7 @@ min_lr_factor = 0.0 local_batch_size = 8 seq_len = 2048 max_norm = 1.0 # grad norm clipping -steps = 10 +steps = 1000 dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) [parallelism] diff --git a/torchtitan/train.py b/torchtitan/train.py index 3c255ccdaa..2c8ae7a541 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -140,7 +140,7 @@ def __init__(self, job_config: JobConfig): f"with {json.dumps(dataclasses.asdict(model_args), indent=2, ensure_ascii=False)}" ) with ( - torch.device("meta"), + torch.device("cuda"), utils.set_default_dtype(TORCH_DTYPE_MAP[job_config.training.dtype]), ): # pyrefly: ignore[bad-instantiation] From bd389957d39f0c54d5f3334f9fce6b09fb3a8411 Mon Sep 17 00:00:00 2001 From: mori360 Date: Tue, 20 Jan 2026 13:35:20 -0800 Subject: [PATCH 04/20] clean --- torchtitan/components/lora/lora.py | 30 +------------------ .../llama3/train_configs/debug_model.toml | 16 +++------- torchtitan/train.py | 2 +- 3 files changed, 6 insertions(+), 42 deletions(-) diff --git a/torchtitan/components/lora/lora.py b/torchtitan/components/lora/lora.py index a5c855faa3..1bd2cc1d55 100644 --- a/torchtitan/components/lora/lora.py +++ b/torchtitan/components/lora/lora.py @@ -29,33 +29,14 @@ class LoRAConfig: apply_to_all_linears: If True, apply LoRA to all Linear layers. If False, only apply to attention layers (wq, wk, wv, wo). Default: True. """ - rank: int = 8 - """Rank of the low-rank approximation""" - alpha: float = 16.0 - """Scaling factor for the low-rank approximation""" - dropout: float = 0.0 - """Dropout probability for LoRA layers""" - apply_to_all_linears: bool = True - """If True, apply LoRA to all Linear layers. If False, only apply to attention layers.""" def get_lora_config(job_config: JobConfig) -> LoRAConfig: """Get LoRA config from job_config, using defaults if not specified. - - The LoRA config can be specified in the TOML file under [lora] section: - ```toml - [lora] - rank = 8 - alpha = 16.0 - dropout = 0.0 - apply_to_all_linears = true - ``` - - If not specified, default values from LoRAConfig will be used. """ lora_section = job_config.lora return LoRAConfig( @@ -131,8 +112,6 @@ def to_empty( return self def initialize_parameters(self): - # Initialize as in - # https://github.com/microsoft/LoRA/blob/4c0333854cb905966f8cc4e9a74068c1e507c7b7/loralib/layers.py#L119 _lora_a_init_params(self.lora_a) _lora_b_init_params(self.lora_b) @@ -150,9 +129,8 @@ def adapter_params(self) -> list[str]: the model coming from the adapter. For LoRA this means lora_a.weight and lora_b.weight. + # TODO: update names when supporting EP """ - # NOTE: this function has to be updated if the names of "lora_a" and "lora_b" - # in this module change. 
adapter_params = ["lora_a.weight", "lora_b.weight"] return adapter_params @@ -174,16 +152,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def _lora_a_init_params(x: nn.Linear) -> None: - """ - Initialize LoRA A weight to Kaiming uniform. - """ nn.init.kaiming_uniform_(x.weight, a=math.sqrt(5)) def _lora_b_init_params(x: nn.Linear) -> None: - """ - Initialize LoRA B weight to zeros. - """ nn.init.zeros_(x.weight) diff --git a/torchtitan/models/llama3/train_configs/debug_model.toml b/torchtitan/models/llama3/train_configs/debug_model.toml index 2b96426a82..7760667edd 100644 --- a/torchtitan/models/llama3/train_configs/debug_model.toml +++ b/torchtitan/models/llama3/train_configs/debug_model.toml @@ -6,12 +6,12 @@ print_config = false [profiling] enable_profiling = false save_traces_folder = "profile_trace" -profile_freq = 1000 +profile_freq = 10 enable_memory_snapshot = false save_memory_snapshot_folder = "memory_snapshot" [metrics] -log_freq = 10 +log_freq = 1 disable_color_printing = false enable_tensorboard = false save_tb_folder = "tb" @@ -22,15 +22,7 @@ name = "llama3" flavor = "debugmodel" # test folder with tokenizer.json, for debug purpose only hf_assets_path = "./tests/assets/tokenizer" -# Enable LoRA converter for fine-tuning -converters = ["lora"] - -[lora] -# LoRA configuration (all values are optional, defaults shown below) -rank = 8 -alpha = 16.0 -dropout = 0.0 -apply_to_all_linears = true +# converters = ["float8"] [optimizer] name = "AdamW" @@ -47,7 +39,7 @@ min_lr_factor = 0.0 local_batch_size = 8 seq_len = 2048 max_norm = 1.0 # grad norm clipping -steps = 1000 +steps = 10 dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) [parallelism] diff --git a/torchtitan/train.py b/torchtitan/train.py index 2c8ae7a541..3c255ccdaa 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -140,7 +140,7 @@ def __init__(self, job_config: JobConfig): f"with {json.dumps(dataclasses.asdict(model_args), indent=2, ensure_ascii=False)}" ) with ( - torch.device("cuda"), + torch.device("meta"), utils.set_default_dtype(TORCH_DTYPE_MAP[job_config.training.dtype]), ): # pyrefly: ignore[bad-instantiation] From e80de66a6788194d34a4e20e13336a09907343a5 Mon Sep 17 00:00:00 2001 From: mori360 Date: Tue, 20 Jan 2026 13:52:36 -0800 Subject: [PATCH 05/20] polish logger --- torchtitan/components/lora/lora.py | 47 ++++++++++++++++-------------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/torchtitan/components/lora/lora.py b/torchtitan/components/lora/lora.py index 1bd2cc1d55..29f8fda6de 100644 --- a/torchtitan/components/lora/lora.py +++ b/torchtitan/components/lora/lora.py @@ -38,12 +38,12 @@ class LoRAConfig: def get_lora_config(job_config: JobConfig) -> LoRAConfig: """Get LoRA config from job_config, using defaults if not specified. 
""" - lora_section = job_config.lora + lora_config = job_config.lora return LoRAConfig( - rank=lora_section.rank, - alpha=lora_section.alpha, - dropout=lora_section.dropout, - apply_to_all_linears=lora_section.apply_to_all_linears, + rank=lora_config.rank, + alpha=lora_config.alpha, + dropout=lora_config.dropout, + apply_to_all_linears=lora_config.apply_to_all_linears, ) @@ -101,7 +101,6 @@ def __init__( self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity() self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False) self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False) - self.merged = False self.initialize_parameters() def to_empty( @@ -152,10 +151,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def _lora_a_init_params(x: nn.Linear) -> None: + """ + Initialize LoRA A weight to Kaiming uniform. + """ nn.init.kaiming_uniform_(x.weight, a=math.sqrt(5)) def _lora_b_init_params(x: nn.Linear) -> None: + """ + Initialize LoRA B weight to zeros. + """ nn.init.zeros_(x.weight) @@ -164,15 +169,6 @@ class LoRAConverter: This converter replaces nn.Linear layers with LoRALinear layers and sets requires_grad=True only for LoRA parameters, freezing all other parameters. - - Configuration can be specified in the TOML file under [lora] section: - ```toml - [lora] - rank = 8 - alpha = 16.0 - dropout = 0.0 - apply_to_all_linears = true - ``` """ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): @@ -184,7 +180,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): self._converted_model: Optional[nn.Module] = None logger.info( - f"LoRA config: rank={self.rank}, alpha={self.alpha}, " + f"LoRA training active with rank={self.rank}, alpha={self.alpha}, " f"dropout={self.dropout}, apply_to_all_linears={self.apply_to_all_linears}" ) @@ -211,11 +207,11 @@ def convert(self, model: nn.Module) -> None: dropout=self.dropout, use_bias=child.bias is not None, ) - # First move to the same device and dtype as the original weights + # Move to the same device and dtype as the original weights lora_linear = lora_linear.to( device=child.weight.device, dtype=child.weight.dtype ) - # Then copy the original weights (after dtype conversion) + # Copy the original weights (after dtype conversion) lora_linear.weight.data.copy_(child.weight.data) if child.bias is not None: lora_linear.bias.data.copy_(child.bias.data) @@ -243,22 +239,27 @@ def init_weights_with_lora(*args, **kwargs): model.init_weights = init_weights_with_lora - # Log the number of trainable parameters + # Log conversion summary total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum( p.numel() for p in model.parameters() if p.requires_grad ) logger.info( - f"LoRA adapters added. 
Trainable parameters: {trainable_params:,} / {total_params:,} " + f"Swapped to LoRALinear layers with {len(replacements)} linear modules converted" + ) + logger.info( + f"Trainable params: {trainable_params:,} / {total_params:,} " f"({100 * trainable_params / total_params:.2f}%)" ) def _init_lora_params(self, model: nn.Module) -> None: """Initialize LoRA parameters and set requires_grad after model initialization.""" + lora_layer_count = 0 for name, module in model.named_modules(): if isinstance(module, LoRALinear): # Re-initialize LoRA parameters module.reset_parameters() + lora_layer_count += 1 # Re-freeze base model params and unfreeze LoRA params # This is necessary because init_weights may have touched some params @@ -268,11 +269,13 @@ def _init_lora_params(self, model: nn.Module) -> None: else: param.requires_grad = False - # Log trainable parameter count after init trainable_params = sum( p.numel() for p in model.parameters() if p.requires_grad ) - logger.info(f"LoRA params initialized. Trainable parameters: {trainable_params:,}") + logger.info( + f"LoRA parameters initialized for {lora_layer_count} layers, " + f"trainable params: {trainable_params:,}" + ) def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]) -> None: """Post-optimizer hook (no-op for LoRA).""" From 0f56f28e0016bc55973375874cf2cb340d460a71 Mon Sep 17 00:00:00 2001 From: mori360 Date: Tue, 20 Jan 2026 14:49:05 -0800 Subject: [PATCH 06/20] clean --- torchtitan/config/job_config.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index c27a86fbb9..e03ab9dd51 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -696,23 +696,6 @@ class ActivationCheckpoint: @dataclass class LoRA: - """Configuration for LoRA (Low-Rank Adaptation) fine-tuning. - - LoRA is a parameter-efficient fine-tuning technique that freezes the pretrained - model weights and injects trainable low-rank decomposition matrices into each - layer of the Transformer architecture. - - To enable LoRA, add "lora" to the model.converters list in your config: - [model] - converters = ["lora"] - - [lora] - rank = 8 - alpha = 16.0 - dropout = 0.0 - apply_to_all_linears = true - """ - rank: int = 8 """Rank of the low-rank approximation. Higher rank = more parameters but better quality.""" From bd5fd5105126c1842eec764096b6460dfb4be7bd Mon Sep 17 00:00:00 2001 From: mori360 Date: Tue, 20 Jan 2026 15:12:25 -0800 Subject: [PATCH 07/20] lint --- torchtitan/components/lora/lora.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/torchtitan/components/lora/lora.py b/torchtitan/components/lora/lora.py index 29f8fda6de..ebd6778623 100644 --- a/torchtitan/components/lora/lora.py +++ b/torchtitan/components/lora/lora.py @@ -6,7 +6,7 @@ import math from dataclasses import dataclass -from typing import List, Optional, Union +from typing import cast, List, Optional, Union import torch import torch.nn as nn @@ -29,6 +29,7 @@ class LoRAConfig: apply_to_all_linears: If True, apply LoRA to all Linear layers. If False, only apply to attention layers (wq, wk, wv, wo). Default: True. """ + rank: int = 8 alpha: float = 16.0 dropout: float = 0.0 @@ -36,8 +37,7 @@ class LoRAConfig: def get_lora_config(job_config: JobConfig) -> LoRAConfig: - """Get LoRA config from job_config, using defaults if not specified. 
- """ + """Get LoRA config from job_config, using defaults if not specified.""" lora_config = job_config.lora return LoRAConfig( rank=lora_config.rank, @@ -142,7 +142,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: torch.Tensor: output tensor with shape ``(..., out_dim)`` """ - out = F.linear(x, self.weight, self.bias) + out = F.linear( + x, cast(torch.Tensor, self.weight), cast(Optional[torch.Tensor], self.bias) + ) if self.disabled: return out lora_out = self.lora_a(self.dropout(x)) @@ -212,9 +214,9 @@ def convert(self, model: nn.Module) -> None: device=child.weight.device, dtype=child.weight.dtype ) # Copy the original weights (after dtype conversion) - lora_linear.weight.data.copy_(child.weight.data) - if child.bias is not None: - lora_linear.bias.data.copy_(child.bias.data) + cast(torch.Tensor, lora_linear.weight).data.copy_(child.weight.data) + if lora_linear.bias is not None: + cast(torch.Tensor, lora_linear.bias).data.copy_(child.bias.data) # Replace the module setattr(parent_module, child_name, lora_linear) @@ -233,17 +235,16 @@ def convert(self, model: nn.Module) -> None: def init_weights_with_lora(*args, **kwargs): # Call the original init_weights - original_init_weights(*args, **kwargs) + if callable(original_init_weights): + original_init_weights(*args, **kwargs) # Initialize LoRA parameters and ensure requires_grad is set self._init_lora_params(model) - model.init_weights = init_weights_with_lora + object.__setattr__(model, "init_weights", init_weights_with_lora) # Log conversion summary total_params = sum(p.numel() for p in model.parameters()) - trainable_params = sum( - p.numel() for p in model.parameters() if p.requires_grad - ) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) logger.info( f"Swapped to LoRALinear layers with {len(replacements)} linear modules converted" ) @@ -269,9 +270,7 @@ def _init_lora_params(self, model: nn.Module) -> None: else: param.requires_grad = False - trainable_params = sum( - p.numel() for p in model.parameters() if p.requires_grad - ) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) logger.info( f"LoRA parameters initialized for {lora_layer_count} layers, " f"trainable params: {trainable_params:,}" From 6995cd3ab6445c505b7998b2c72fd591156be0de Mon Sep 17 00:00:00 2001 From: mori360 Date: Tue, 20 Jan 2026 15:34:48 -0800 Subject: [PATCH 08/20] remove reset --- torchtitan/components/lora/lora.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/torchtitan/components/lora/lora.py b/torchtitan/components/lora/lora.py index ebd6778623..4887221a9f 100644 --- a/torchtitan/components/lora/lora.py +++ b/torchtitan/components/lora/lora.py @@ -114,20 +114,10 @@ def initialize_parameters(self): _lora_a_init_params(self.lora_a) _lora_b_init_params(self.lora_b) - def reset_parameters(self): - """Reset LoRA parameters. Called by init_weights during model initialization.""" - _lora_a_init_params(self.lora_a) - _lora_b_init_params(self.lora_b) - # Ensure LoRA params have requires_grad=True after reset - self.lora_a.weight.requires_grad = True - self.lora_b.weight.requires_grad = True - def adapter_params(self) -> list[str]: """ Return a list of strings corresponding to the names of the ``nn.Parameter`` s in the model coming from the adapter. - - For LoRA this means lora_a.weight and lora_b.weight. 
# TODO: update names when supporting EP """ adapter_params = ["lora_a.weight", "lora_b.weight"] @@ -259,7 +249,7 @@ def _init_lora_params(self, model: nn.Module) -> None: for name, module in model.named_modules(): if isinstance(module, LoRALinear): # Re-initialize LoRA parameters - module.reset_parameters() + module.initialize_parameters() lora_layer_count += 1 # Re-freeze base model params and unfreeze LoRA params From d929ecda01a134468bc694454a09604cc66d5966 Mon Sep 17 00:00:00 2001 From: mori360 Date: Tue, 20 Jan 2026 16:17:29 -0800 Subject: [PATCH 09/20] lint --- torchtitan/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchtitan/__init__.py b/torchtitan/__init__.py index 8683c72492..efb781da17 100644 --- a/torchtitan/__init__.py +++ b/torchtitan/__init__.py @@ -6,12 +6,12 @@ from importlib.metadata import version -# Import to register quantization modules. -import torchtitan.components.quantization # noqa: F401 - # Import to register lora module. import torchtitan.components.lora # noqa: F401 +# Import to register quantization modules. +import torchtitan.components.quantization # noqa: F401 + try: __version__ = version("torchtitan") except Exception as e: From b4f0587e182db1f679a1e40bb0fbd15b429b1757 Mon Sep 17 00:00:00 2001 From: mori360 Date: Wed, 21 Jan 2026 15:15:16 -0800 Subject: [PATCH 10/20] optimize memory --- torchtitan/components/lora/lora.py | 249 ++++++++++++++++------------- 1 file changed, 141 insertions(+), 108 deletions(-) diff --git a/torchtitan/components/lora/lora.py b/torchtitan/components/lora/lora.py index 4887221a9f..a8ccc9b3b9 100644 --- a/torchtitan/components/lora/lora.py +++ b/torchtitan/components/lora/lora.py @@ -6,7 +6,7 @@ import math from dataclasses import dataclass -from typing import cast, List, Optional, Union +from typing import Optional, Union import torch import torch.nn as nn @@ -47,26 +47,73 @@ def get_lora_config(job_config: JobConfig) -> LoRAConfig: ) +class _LoRALinearFunction(torch.autograd.Function): + """Memory-efficient LoRA linear computation. + + Forward: out = X @ W.T + bias + scale * (X @ A.T @ B.T) + + Memory optimizations: + - Only saves X, A, B for backward + - Uses in-place addmm_ operations + """ + + @staticmethod + def forward(ctx, X, W, bias, A, B, scale): # type: ignore[override] + orig_shape = X.shape + X_2d = X.view(-1, X.shape[-1]) if X.dim() == 3 else X + + out = torch.empty(X_2d.shape[0], W.shape[0], dtype=X.dtype, device=X.device) + torch.mm(X_2d, W.t(), out=out) + + if bias is not None: + out.add_(bias) + + out.addmm_(X_2d @ A.T, B.T, alpha=scale) + + if X.dim() == 3: + out = out.view(orig_shape[0], orig_shape[1], -1) + + ctx.custom_saved_tensors = (W, scale) + ctx.save_for_backward(A, B, X) + ctx.has_bias = bias is not None + return out + + @staticmethod + def backward(ctx, dY): # type: ignore[override] + W, scale = ctx.custom_saved_tensors + A, B, X = ctx.saved_tensors + + batch, seq_len, hd = X.shape + dY = dY.reshape(-1, dY.shape[-1]) + X = X.reshape(-1, X.shape[-1]) + A, B = A.t(), B.t() + + d_A = torch.empty_like(A) + d_B = torch.empty_like(B) + d_A.addmm_(X.t(), dY @ B.t(), alpha=scale, beta=0) + d_B.addmm_(A.t() @ X.t(), dY, alpha=scale, beta=0) + + dX = dY @ W + dX.addmm_(dY @ B.t(), A.t(), alpha=scale) + d_bias = dY.sum(dim=0) if ctx.has_bias else None + + return dX.view(batch, seq_len, hd), None, d_bias, d_A.t(), d_B.t(), None + + class LoRALinear(nn.Module): - """LoRA linear layer as introduced in `LoRA: Low-Rank Adaptation of Large Language Models `_. 
+ """LoRA linear layer. - LoRA perturbs a given layer via a low-rank approximation where only - the rank decomposition matrices are trainable. In a linear layer instead of - :math:`x \\mapsto W_0x` a LoRALinear layer is defined as - :math:`x \\mapsto W_0x + (\\alpha / r)BAx`, where :math:`r` is the rank of - the matrices :math:`A` and :math:`B` and :math:`\\alpha` is a scaling factor. - As in the original implementation, we support dropout before multiplication - by the low-rank matrices. + Implements: x -> W_0 @ x + (alpha / rank) * B @ A @ x - Args: - in_dim (int): input dimension - out_dim (int): output dimension - rank (int): rank of the low-rank approximation - alpha (float): scaling factor for the low-rank approximation - dropout (float): dropout probability. Default: 0.0 - use_bias (bool): whether to include bias in the original linear layer. - Default: False + See: https://arxiv.org/abs/2106.09685 + Args: + in_dim: Input dimension. + out_dim: Output dimension. + rank: Rank of the low-rank approximation. + alpha: Scaling factor. + dropout: Dropout probability. + use_bias: Whether to include bias. """ def __init__( @@ -77,6 +124,9 @@ def __init__( alpha: float, dropout: float = 0.0, use_bias: bool = False, + weight: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + dtype: Optional[torch.dtype] = None, ): super().__init__() self.in_dim = in_dim @@ -84,24 +134,40 @@ def __init__( self.rank = rank self.alpha = alpha self.use_bias = use_bias + self.disabled = False - # Setup weight and bias - linear = nn.Linear(in_features=in_dim, out_features=out_dim, bias=self.use_bias) - weight = linear.weight - bias = linear.bias if self.use_bias else None + # Setup weight - reuse provided tensor or create on meta device + if weight is not None: + self.register_parameter("weight", nn.Parameter(weight, requires_grad=False)) + else: + self.register_parameter( + "weight", + nn.Parameter( + torch.empty(out_dim, in_dim, device="meta", dtype=dtype), + requires_grad=False, + ), + ) + + # Setup bias + if use_bias: + if bias is not None: + self.register_parameter("bias", nn.Parameter(bias, requires_grad=False)) + else: + self.register_parameter( + "bias", + nn.Parameter( + torch.empty(out_dim, device="meta", dtype=dtype), + requires_grad=False, + ), + ) + else: + self.register_parameter("bias", None) - # 'self.disabled' is a flag showing whether to turn off LoRA adapters, - # this can be used in DPO for treating the lora adapters as the policy model - # and disabling it to treat the base model as the reference model - self.disabled = False - self.register_parameter("weight", nn.Parameter(weight)) - self.register_parameter( - "bias", nn.Parameter(bias) if bias is not None else None - ) self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity() - self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False) - self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False) - self.initialize_parameters() + + # LoRA layers on meta device + self.lora_a = nn.Linear(in_dim, rank, bias=False, device="meta", dtype=dtype) + self.lora_b = nn.Linear(rank, out_dim, bias=False, device="meta", dtype=dtype) def to_empty( self, *, device: Optional[Union[str, torch.device, int]], recurse: bool = True @@ -115,53 +181,35 @@ def initialize_parameters(self): _lora_b_init_params(self.lora_b) def adapter_params(self) -> list[str]: - """ - Return a list of strings corresponding to the names of the ``nn.Parameter`` s in - the model coming from the adapter. 
- # TODO: update names when supporting EP - """ - adapter_params = ["lora_a.weight", "lora_b.weight"] - return adapter_params + """Return names of LoRA adapter parameters.""" + return ["lora_a.weight", "lora_b.weight"] def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x (torch.Tensor): input tensor with shape ``(..., in_dim)`` - - Returns: - torch.Tensor: output tensor with shape ``(..., out_dim)`` - - """ - out = F.linear( - x, cast(torch.Tensor, self.weight), cast(Optional[torch.Tensor], self.bias) - ) if self.disabled: - return out - lora_out = self.lora_a(self.dropout(x)) - lora_out = (self.alpha / self.rank) * self.lora_b(lora_out) - return out + lora_out + return F.linear(x, self.weight, self.bias) # type: ignore[arg-type] + + return _LoRALinearFunction.apply( + self.dropout(x), + self.weight, + self.bias, + self.lora_a.weight, + self.lora_b.weight, + self.alpha / self.rank, + ) def _lora_a_init_params(x: nn.Linear) -> None: - """ - Initialize LoRA A weight to Kaiming uniform. - """ + """Initialize LoRA A weight to Kaiming uniform.""" nn.init.kaiming_uniform_(x.weight, a=math.sqrt(5)) def _lora_b_init_params(x: nn.Linear) -> None: - """ - Initialize LoRA B weight to zeros. - """ + """Initialize LoRA B weight to zeros.""" nn.init.zeros_(x.weight) class LoRAConverter: - """Model converter that adds LoRA adapters to Linear layers. - - This converter replaces nn.Linear layers with LoRALinear layers and sets - requires_grad=True only for LoRA parameters, freezing all other parameters. - """ + """Model converter that adds LoRA adapters to Linear layers.""" def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): lora_config = get_lora_config(job_config) @@ -178,87 +226,72 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): def convert(self, model: nn.Module) -> None: """Inplace conversion of the model to use LoRA adapters.""" - # First, freeze all parameters - for param in model.parameters(): - param.requires_grad = False - - # Collect all Linear layers to replace (to avoid modifying while iterating) replacements = [] - for name, module in model.named_modules(): + for module in model.modules(): for child_name, child in module.named_children(): if isinstance(child, nn.Linear) and not isinstance(child, LoRALinear): replacements.append((module, child_name, child)) - # Replace Linear layers with LoRALinear for parent_module, child_name, child in replacements: + has_bias = child.bias is not None + original_weight = child.weight.data + original_bias = child.bias.data if has_bias else None + + # Break reference chain before creating new module + child.weight = None # type: ignore[assignment] + if has_bias: + child.bias = None # type: ignore[assignment] + lora_linear = LoRALinear( in_dim=child.in_features, out_dim=child.out_features, rank=self.rank, alpha=self.alpha, dropout=self.dropout, - use_bias=child.bias is not None, - ) - # Move to the same device and dtype as the original weights - lora_linear = lora_linear.to( - device=child.weight.device, dtype=child.weight.dtype + use_bias=has_bias, + weight=original_weight, + bias=original_bias, + dtype=child.weight.dtype + if child.weight is not None + else original_weight.dtype, ) - # Copy the original weights (after dtype conversion) - cast(torch.Tensor, lora_linear.weight).data.copy_(child.weight.data) - if lora_linear.bias is not None: - cast(torch.Tensor, lora_linear.bias).data.copy_(child.bias.data) - # Replace the module setattr(parent_module, child_name, lora_linear) - # Enable gradients 
only for LoRA parameters - for name, param in model.named_parameters(): - if "lora_a" in name or "lora_b" in name: - param.requires_grad = True - else: - param.requires_grad = False - - # Store reference for post_optimizer_hook to re-initialize LoRA params + self._set_lora_requires_grad(model) self._converted_model = model - # Wrap the original init_weights to also initialize LoRA parameters + # Wrap init_weights to also initialize LoRA parameters original_init_weights = model.init_weights def init_weights_with_lora(*args, **kwargs): - # Call the original init_weights if callable(original_init_weights): original_init_weights(*args, **kwargs) - # Initialize LoRA parameters and ensure requires_grad is set self._init_lora_params(model) object.__setattr__(model, "init_weights", init_weights_with_lora) - # Log conversion summary total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - logger.info( - f"Swapped to LoRALinear layers with {len(replacements)} linear modules converted" - ) + logger.info(f"Converted {len(replacements)} linear modules to LoRALinear") logger.info( f"Trainable params: {trainable_params:,} / {total_params:,} " f"({100 * trainable_params / total_params:.2f}%)" ) + def _set_lora_requires_grad(self, model: nn.Module) -> None: + """Set requires_grad: True for LoRA params, False for others.""" + for name, param in model.named_parameters(): + param.requires_grad = "lora_a" in name or "lora_b" in name + def _init_lora_params(self, model: nn.Module) -> None: - """Initialize LoRA parameters and set requires_grad after model initialization.""" + """Initialize LoRA parameters after model initialization.""" lora_layer_count = 0 - for name, module in model.named_modules(): + for module in model.modules(): if isinstance(module, LoRALinear): - # Re-initialize LoRA parameters module.initialize_parameters() lora_layer_count += 1 - # Re-freeze base model params and unfreeze LoRA params - # This is necessary because init_weights may have touched some params - for name, param in model.named_parameters(): - if "lora_a" in name or "lora_b" in name: - param.requires_grad = True - else: - param.requires_grad = False + self._set_lora_requires_grad(model) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) logger.info( @@ -266,7 +299,7 @@ def _init_lora_params(self, model: nn.Module) -> None: f"trainable params: {trainable_params:,}" ) - def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]) -> None: + def post_optimizer_hook(self, model: Union[nn.Module, list[nn.Module]]) -> None: """Post-optimizer hook (no-op for LoRA).""" pass From 40e6963f5a3925a617aec56a34498e0b28f02591 Mon Sep 17 00:00:00 2001 From: mori360 Date: Wed, 21 Jan 2026 15:53:42 -0800 Subject: [PATCH 11/20] remove apply to all linears --- torchtitan/components/lora/lora.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/torchtitan/components/lora/lora.py b/torchtitan/components/lora/lora.py index a8ccc9b3b9..840fbf87db 100644 --- a/torchtitan/components/lora/lora.py +++ b/torchtitan/components/lora/lora.py @@ -26,14 +26,12 @@ class LoRAConfig: rank: Rank of the low-rank approximation. Default: 8. alpha: Scaling factor for the low-rank approximation. Default: 16.0. dropout: Dropout probability for LoRA layers. Default: 0.0. - apply_to_all_linears: If True, apply LoRA to all Linear layers. - If False, only apply to attention layers (wq, wk, wv, wo). Default: True. 
+ TODO: add support to layers to apply, e.g. only attention layers or all linear. """ rank: int = 8 alpha: float = 16.0 dropout: float = 0.0 - apply_to_all_linears: bool = True def get_lora_config(job_config: JobConfig) -> LoRAConfig: @@ -43,7 +41,6 @@ def get_lora_config(job_config: JobConfig) -> LoRAConfig: rank=lora_config.rank, alpha=lora_config.alpha, dropout=lora_config.dropout, - apply_to_all_linears=lora_config.apply_to_all_linears, ) @@ -216,12 +213,11 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): self.rank = lora_config.rank self.alpha = lora_config.alpha self.dropout = lora_config.dropout - self.apply_to_all_linears = lora_config.apply_to_all_linears self._converted_model: Optional[nn.Module] = None logger.info( f"LoRA training active with rank={self.rank}, alpha={self.alpha}, " - f"dropout={self.dropout}, apply_to_all_linears={self.apply_to_all_linears}" + f"dropout={self.dropout}" ) def convert(self, model: nn.Module) -> None: From 3d34310b3c54361f6b2007616ea926f1cecd682d Mon Sep 17 00:00:00 2001 From: mori360 Date: Wed, 21 Jan 2026 16:16:04 -0800 Subject: [PATCH 12/20] remove apply to all linears --- torchtitan/config/job_config.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index e03ab9dd51..ec89ae39a4 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -708,12 +708,6 @@ class LoRA: dropout: float = 0.0 """Dropout probability applied to the LoRA layers. 0.0 means no dropout.""" - apply_to_all_linears: bool = True - """ - If True, apply LoRA to all nn.Linear layers in the model. - If False, only apply to attention layers (wq, wk, wv, wo). - """ - @dataclass class Compile: From ecb1b0305c6569f464c46c021fc34dea40340b16 Mon Sep 17 00:00:00 2001 From: mori360 Date: Thu, 22 Jan 2026 09:26:26 -0800 Subject: [PATCH 13/20] remove lora config --- torchtitan/components/lora/__init__.py | 4 --- torchtitan/components/lora/lora.py | 34 +++----------------------- 2 files changed, 3 insertions(+), 35 deletions(-) diff --git a/torchtitan/components/lora/__init__.py b/torchtitan/components/lora/__init__.py index 5d176ed805..0f60b90a4c 100644 --- a/torchtitan/components/lora/__init__.py +++ b/torchtitan/components/lora/__init__.py @@ -5,15 +5,11 @@ # LICENSE file in the root directory of this source tree. from torchtitan.components.lora.lora import ( - get_lora_config, - LoRAConfig, LoRAConverter, LoRALinear, ) __all__ = [ - "get_lora_config", - "LoRAConfig", "LoRAConverter", "LoRALinear", ] diff --git a/torchtitan/components/lora/lora.py b/torchtitan/components/lora/lora.py index 840fbf87db..2f7b6500dd 100644 --- a/torchtitan/components/lora/lora.py +++ b/torchtitan/components/lora/lora.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. import math -from dataclasses import dataclass from typing import Optional, Union import torch @@ -18,32 +17,6 @@ from torchtitan.tools.logging import logger -@dataclass -class LoRAConfig: - """Configuration for LoRA (Low-Rank Adaptation) fine-tuning. - - Args: - rank: Rank of the low-rank approximation. Default: 8. - alpha: Scaling factor for the low-rank approximation. Default: 16.0. - dropout: Dropout probability for LoRA layers. Default: 0.0. - TODO: add support to layers to apply, e.g. only attention layers or all linear. 
- """ - - rank: int = 8 - alpha: float = 16.0 - dropout: float = 0.0 - - -def get_lora_config(job_config: JobConfig) -> LoRAConfig: - """Get LoRA config from job_config, using defaults if not specified.""" - lora_config = job_config.lora - return LoRAConfig( - rank=lora_config.rank, - alpha=lora_config.alpha, - dropout=lora_config.dropout, - ) - - class _LoRALinearFunction(torch.autograd.Function): """Memory-efficient LoRA linear computation. @@ -209,10 +182,9 @@ class LoRAConverter: """Model converter that adds LoRA adapters to Linear layers.""" def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): - lora_config = get_lora_config(job_config) - self.rank = lora_config.rank - self.alpha = lora_config.alpha - self.dropout = lora_config.dropout + self.rank = job_config.lora.rank + self.alpha = job_config.lora.alpha + self.dropout = job_config.lora.dropout self._converted_model: Optional[nn.Module] = None logger.info( From 8da82cd3c04f3d2b3655ad8e30a4940d416ce3d5 Mon Sep 17 00:00:00 2001 From: mori360 Date: Thu, 22 Jan 2026 15:19:24 -0800 Subject: [PATCH 14/20] remove autograd --- torchtitan/components/lora/lora.py | 75 +++++------------------------- 1 file changed, 12 insertions(+), 63 deletions(-) diff --git a/torchtitan/components/lora/lora.py b/torchtitan/components/lora/lora.py index 2f7b6500dd..74b174c30d 100644 --- a/torchtitan/components/lora/lora.py +++ b/torchtitan/components/lora/lora.py @@ -17,59 +17,6 @@ from torchtitan.tools.logging import logger -class _LoRALinearFunction(torch.autograd.Function): - """Memory-efficient LoRA linear computation. - - Forward: out = X @ W.T + bias + scale * (X @ A.T @ B.T) - - Memory optimizations: - - Only saves X, A, B for backward - - Uses in-place addmm_ operations - """ - - @staticmethod - def forward(ctx, X, W, bias, A, B, scale): # type: ignore[override] - orig_shape = X.shape - X_2d = X.view(-1, X.shape[-1]) if X.dim() == 3 else X - - out = torch.empty(X_2d.shape[0], W.shape[0], dtype=X.dtype, device=X.device) - torch.mm(X_2d, W.t(), out=out) - - if bias is not None: - out.add_(bias) - - out.addmm_(X_2d @ A.T, B.T, alpha=scale) - - if X.dim() == 3: - out = out.view(orig_shape[0], orig_shape[1], -1) - - ctx.custom_saved_tensors = (W, scale) - ctx.save_for_backward(A, B, X) - ctx.has_bias = bias is not None - return out - - @staticmethod - def backward(ctx, dY): # type: ignore[override] - W, scale = ctx.custom_saved_tensors - A, B, X = ctx.saved_tensors - - batch, seq_len, hd = X.shape - dY = dY.reshape(-1, dY.shape[-1]) - X = X.reshape(-1, X.shape[-1]) - A, B = A.t(), B.t() - - d_A = torch.empty_like(A) - d_B = torch.empty_like(B) - d_A.addmm_(X.t(), dY @ B.t(), alpha=scale, beta=0) - d_B.addmm_(A.t() @ X.t(), dY, alpha=scale, beta=0) - - dX = dY @ W - dX.addmm_(dY @ B.t(), A.t(), alpha=scale) - d_bias = dY.sum(dim=0) if ctx.has_bias else None - - return dX.view(batch, seq_len, hd), None, d_bias, d_A.t(), d_B.t(), None - - class LoRALinear(nn.Module): """LoRA linear layer. 
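
The removal above, together with the hunks that follow, replaces the custom autograd function with the standard LoRA composition out = W x + (alpha / rank) * B(A(x)), computed with ordinary module calls. A minimal standalone sketch of that formulation with illustrative dimensions (the real layer reuses the frozen base weight and creates the adapters on the meta device):

    import torch
    import torch.nn as nn

    in_dim, out_dim, rank, alpha = 4096, 4096, 8, 16.0
    scaling = alpha / rank                        # 2.0 with rank=8, alpha=16

    base = nn.Linear(in_dim, out_dim, bias=False)
    lora_a = nn.Linear(in_dim, rank, bias=False)
    lora_b = nn.Linear(rank, out_dim, bias=False)

    x = torch.randn(2, 16, in_dim)
    out = base(x) + scaling * lora_b(lora_a(x))   # same math as the simplified forward

    # Adapter cost: rank * (in_dim + out_dim) = 65,536 trainable weights,
    # roughly 0.4% of the 16,777,216 frozen weights in this base projection.
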
@@ -103,6 +50,7 @@ def __init__( self.out_dim = out_dim self.rank = rank self.alpha = alpha + self.scaling = alpha / rank self.use_bias = use_bias self.disabled = False @@ -155,17 +103,18 @@ def adapter_params(self) -> list[str]: return ["lora_a.weight", "lora_b.weight"] def forward(self, x: torch.Tensor) -> torch.Tensor: + # Base linear: out = x @ W.T + bias + out = F.linear(x, self.weight, self.bias) # type: ignore[arg-type] + if self.disabled: - return F.linear(x, self.weight, self.bias) # type: ignore[arg-type] - - return _LoRALinearFunction.apply( - self.dropout(x), - self.weight, - self.bias, - self.lora_a.weight, - self.lora_b.weight, - self.alpha / self.rank, - ) + return out + + # LoRA path: out += scale * (x @ A.T @ B.T) + x = self.dropout(x) + lora_out = self.lora_b(self.lora_a(x)) + out = out + self.scaling * lora_out + + return out def _lora_a_init_params(x: nn.Linear) -> None: From 702ab261adc953bb565e0abfb4217fe1e7a80328 Mon Sep 17 00:00:00 2001 From: mori360 Date: Thu, 22 Jan 2026 16:00:50 -0800 Subject: [PATCH 15/20] lint --- torchtitan/components/lora/__init__.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/torchtitan/components/lora/__init__.py b/torchtitan/components/lora/__init__.py index 0f60b90a4c..00fdde5ddb 100644 --- a/torchtitan/components/lora/__init__.py +++ b/torchtitan/components/lora/__init__.py @@ -4,10 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from torchtitan.components.lora.lora import ( - LoRAConverter, - LoRALinear, -) +from torchtitan.components.lora.lora import LoRAConverter, LoRALinear __all__ = [ "LoRAConverter", From 861928eb9193b457e782e768f7b43289cb89fa65 Mon Sep 17 00:00:00 2001 From: mori360 Date: Mon, 26 Jan 2026 14:33:18 -0800 Subject: [PATCH 16/20] move lora.py to comments, address comments --- torchtitan/__init__.py | 6 +- torchtitan/components/{lora => }/lora.py | 98 ++++++++---------------- torchtitan/components/lora/__init__.py | 12 --- torchtitan/protocols/model_converter.py | 4 + 4 files changed, 41 insertions(+), 79 deletions(-) rename torchtitan/components/{lora => }/lora.py (70%) delete mode 100644 torchtitan/components/lora/__init__.py diff --git a/torchtitan/__init__.py b/torchtitan/__init__.py index efb781da17..8683c72492 100644 --- a/torchtitan/__init__.py +++ b/torchtitan/__init__.py @@ -6,12 +6,12 @@ from importlib.metadata import version -# Import to register lora module. -import torchtitan.components.lora # noqa: F401 - # Import to register quantization modules. import torchtitan.components.quantization # noqa: F401 +# Import to register lora module. +import torchtitan.components.lora # noqa: F401 + try: __version__ = version("torchtitan") except Exception as e: diff --git a/torchtitan/components/lora/lora.py b/torchtitan/components/lora.py similarity index 70% rename from torchtitan/components/lora/lora.py rename to torchtitan/components/lora.py index 74b174c30d..2dac402ad4 100644 --- a/torchtitan/components/lora/lora.py +++ b/torchtitan/components/lora.py @@ -30,7 +30,6 @@ class LoRALinear(nn.Module): rank: Rank of the low-rank approximation. alpha: Scaling factor. dropout: Dropout probability. - use_bias: Whether to include bias. 
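
One property worth noting about the initialization kept throughout this series (Kaiming-uniform A, zero-initialized B): a freshly converted layer produces exactly the same output as the original linear until the adapters are trained. A quick self-contained check with toy dimensions and the default scaling of 2.0:

    import math
    import torch
    import torch.nn as nn

    lin = nn.Linear(8, 8)
    lora_a = nn.Linear(8, 2, bias=False)
    lora_b = nn.Linear(2, 8, bias=False)
    nn.init.kaiming_uniform_(lora_a.weight, a=math.sqrt(5))
    nn.init.zeros_(lora_b.weight)

    x = torch.randn(3, 8)
    out = lin(x) + 2.0 * lora_b(lora_a(x))
    assert torch.equal(out, lin(x))   # zero-initialized B makes the LoRA path a no-op
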
""" def __init__( @@ -40,7 +39,6 @@ def __init__( rank: int, alpha: float, dropout: float = 0.0, - use_bias: bool = False, weight: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, dtype: Optional[torch.dtype] = None, @@ -51,8 +49,6 @@ def __init__( self.rank = rank self.alpha = alpha self.scaling = alpha / rank - self.use_bias = use_bias - self.disabled = False # Setup weight - reuse provided tensor or create on meta device if weight is not None: @@ -67,19 +63,16 @@ def __init__( ) # Setup bias - if use_bias: - if bias is not None: - self.register_parameter("bias", nn.Parameter(bias, requires_grad=False)) - else: - self.register_parameter( - "bias", - nn.Parameter( - torch.empty(out_dim, device="meta", dtype=dtype), - requires_grad=False, - ), - ) + if bias is not None: + self.register_parameter("bias", nn.Parameter(bias, requires_grad=False)) else: - self.register_parameter("bias", None) + self.register_parameter( + "bias", + nn.Parameter( + torch.empty(out_dim, device="meta", dtype=dtype), + requires_grad=False, + ), + ) self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity() @@ -95,8 +88,8 @@ def to_empty( return self def initialize_parameters(self): - _lora_a_init_params(self.lora_a) - _lora_b_init_params(self.lora_b) + nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5)) + nn.init.zeros_(self.lora_b.weight) def adapter_params(self) -> list[str]: """Return names of LoRA adapter parameters.""" @@ -106,9 +99,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Base linear: out = x @ W.T + bias out = F.linear(x, self.weight, self.bias) # type: ignore[arg-type] - if self.disabled: - return out - # LoRA path: out += scale * (x @ A.T @ B.T) x = self.dropout(x) lora_out = self.lora_b(self.lora_a(x)) @@ -117,16 +107,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return out -def _lora_a_init_params(x: nn.Linear) -> None: - """Initialize LoRA A weight to Kaiming uniform.""" - nn.init.kaiming_uniform_(x.weight, a=math.sqrt(5)) - - -def _lora_b_init_params(x: nn.Linear) -> None: - """Initialize LoRA B weight to zeros.""" - nn.init.zeros_(x.weight) - - class LoRAConverter: """Model converter that adds LoRA adapters to Linear layers.""" @@ -134,7 +114,6 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): self.rank = job_config.lora.rank self.alpha = job_config.lora.alpha self.dropout = job_config.lora.dropout - self._converted_model: Optional[nn.Module] = None logger.info( f"LoRA training active with rank={self.rank}, alpha={self.alpha}, " @@ -143,39 +122,30 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): def convert(self, model: nn.Module) -> None: """Inplace conversion of the model to use LoRA adapters.""" - replacements = [] - for module in model.modules(): + num_replaced = 0 + for module in list(model.modules()): for child_name, child in module.named_children(): + # TODO: Add support for GroupedExperts. 
if isinstance(child, nn.Linear) and not isinstance(child, LoRALinear): - replacements.append((module, child_name, child)) - - for parent_module, child_name, child in replacements: - has_bias = child.bias is not None - original_weight = child.weight.data - original_bias = child.bias.data if has_bias else None - - # Break reference chain before creating new module - child.weight = None # type: ignore[assignment] - if has_bias: - child.bias = None # type: ignore[assignment] - - lora_linear = LoRALinear( - in_dim=child.in_features, - out_dim=child.out_features, - rank=self.rank, - alpha=self.alpha, - dropout=self.dropout, - use_bias=has_bias, - weight=original_weight, - bias=original_bias, - dtype=child.weight.dtype - if child.weight is not None - else original_weight.dtype, - ) - setattr(parent_module, child_name, lora_linear) - - self._set_lora_requires_grad(model) - self._converted_model = model + original_weight = child.weight.data + original_bias = child.bias.data if child.bias is not None else None + + # Break reference chain before creating new module + child.weight = None # type: ignore[assignment] + child.bias = None # type: ignore[assignment] + + lora_linear = LoRALinear( + in_dim=child.in_features, + out_dim=child.out_features, + rank=self.rank, + alpha=self.alpha, + dropout=self.dropout, + weight=original_weight, + bias=original_bias, + dtype=original_weight.dtype, + ) + setattr(module, child_name, lora_linear) + num_replaced += 1 # Wrap init_weights to also initialize LoRA parameters original_init_weights = model.init_weights @@ -189,7 +159,7 @@ def init_weights_with_lora(*args, **kwargs): total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - logger.info(f"Converted {len(replacements)} linear modules to LoRALinear") + logger.info(f"Converted {num_replaced} linear modules to LoRALinear") logger.info( f"Trainable params: {trainable_params:,} / {total_params:,} " f"({100 * trainable_params / total_params:.2f}%)" diff --git a/torchtitan/components/lora/__init__.py b/torchtitan/components/lora/__init__.py deleted file mode 100644 index 00fdde5ddb..0000000000 --- a/torchtitan/components/lora/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from torchtitan.components.lora.lora import LoRAConverter, LoRALinear - -__all__ = [ - "LoRAConverter", - "LoRALinear", -] diff --git a/torchtitan/protocols/model_converter.py b/torchtitan/protocols/model_converter.py index cb4804be6f..39d46e938d 100644 --- a/torchtitan/protocols/model_converter.py +++ b/torchtitan/protocols/model_converter.py @@ -58,6 +58,10 @@ class ModelConvertersContainer(ModelConverter): """ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): + #if {"lora", "float8"}.issubset(job_config.model.converters): + # raise NotImplementedError( + # "LoRA is incompatible with FP8 linear subclass." 
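
The converter rebinds init_weights on the model instance so that LoRA re-initialization runs right after the model's own weight init. A toy version of that instance-level override pattern (class and attribute names here are illustrative, not from the patch):

    import torch.nn as nn

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 4)

        def init_weights(self):
            nn.init.ones_(self.linear.weight)

    m = Toy()
    original_init = m.init_weights

    def init_with_extra(*args, **kwargs):
        original_init(*args, **kwargs)    # run the model's own init first
        nn.init.zeros_(m.linear.bias)     # then piggyback the extra initialization

    # Attach the wrapper directly on the instance, bypassing nn.Module's
    # attribute handling, so m.init_weights() now runs both steps.
    object.__setattr__(m, "init_weights", init_with_extra)
    m.init_weights()
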
+ # ) converter_classes = [ _registry_model_converter_cls[name] for name in job_config.model.converters ] From b476b339b65f2a095acc4ff9b543fc56720ccd7a Mon Sep 17 00:00:00 2001 From: mori360 Date: Mon, 26 Jan 2026 14:38:47 -0800 Subject: [PATCH 17/20] lint --- torchtitan/__init__.py | 6 +++--- torchtitan/protocols/model_converter.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/torchtitan/__init__.py b/torchtitan/__init__.py index 8683c72492..efb781da17 100644 --- a/torchtitan/__init__.py +++ b/torchtitan/__init__.py @@ -6,12 +6,12 @@ from importlib.metadata import version -# Import to register quantization modules. -import torchtitan.components.quantization # noqa: F401 - # Import to register lora module. import torchtitan.components.lora # noqa: F401 +# Import to register quantization modules. +import torchtitan.components.quantization # noqa: F401 + try: __version__ = version("torchtitan") except Exception as e: diff --git a/torchtitan/protocols/model_converter.py b/torchtitan/protocols/model_converter.py index 39d46e938d..a44c14fb6e 100644 --- a/torchtitan/protocols/model_converter.py +++ b/torchtitan/protocols/model_converter.py @@ -58,10 +58,10 @@ class ModelConvertersContainer(ModelConverter): """ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): - #if {"lora", "float8"}.issubset(job_config.model.converters): - # raise NotImplementedError( - # "LoRA is incompatible with FP8 linear subclass." - # ) + if {"lora", "float8"}.issubset(job_config.model.converters): + raise NotImplementedError( + "LoRA is incompatible with FP8 linear subclass now." + ) converter_classes = [ _registry_model_converter_cls[name] for name in job_config.model.converters ] From cc249c65d8b182fc516e73062789775c9e86a3ec Mon Sep 17 00:00:00 2001 From: mori360 Date: Tue, 27 Jan 2026 15:49:01 -0800 Subject: [PATCH 18/20] enable lora wrapper --- torchtitan/components/lora.py | 156 +++++++++++++++------------------- 1 file changed, 69 insertions(+), 87 deletions(-) diff --git a/torchtitan/components/lora.py b/torchtitan/components/lora.py index 2dac402ad4..358d981b77 100644 --- a/torchtitan/components/lora.py +++ b/torchtitan/components/lora.py @@ -18,15 +18,15 @@ class LoRALinear(nn.Module): - """LoRA linear layer. + """LoRA wrapper for any linear layer. - Implements: x -> W_0 @ x + (alpha / rank) * B @ A @ x + Wraps an existing linear layer and adds LoRA adapters. + Implements: x -> linear(x) + (alpha / rank) * B @ A @ x See: https://arxiv.org/abs/2106.09685 Args: - in_dim: Input dimension. - out_dim: Output dimension. + linear: The linear layer to wrap (nn.Linear, Float8Linear, etc.) rank: Rank of the low-rank approximation. alpha: Scaling factor. dropout: Dropout probability. 
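
Because the wrapper keeps the original module and only exposes weight, bias, in_features and out_features through the delegation properties defined just below, code that inspects those attributes keeps working and no weights are copied. A small check of that behavior, assuming the torchtitan.components.lora module path used in this series and illustrative dimensions:

    import torch.nn as nn
    from torchtitan.components.lora import LoRALinear

    base = nn.Linear(16, 32)
    wrapped = LoRALinear(base, rank=4, alpha=8.0)

    assert wrapped.weight is base.weight                          # same Parameter, no copy
    assert (wrapped.in_features, wrapped.out_features) == (16, 32)
    # The adapter matrices live on the meta device until the model is
    # materialized and init_weights runs.
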
@@ -34,51 +34,47 @@ class LoRALinear(nn.Module): def __init__( self, - in_dim: int, - out_dim: int, + linear: nn.Module, rank: int, alpha: float, dropout: float = 0.0, - weight: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, - dtype: Optional[torch.dtype] = None, ): super().__init__() - self.in_dim = in_dim - self.out_dim = out_dim + self.linear = linear + self.in_dim = linear.in_features + self.out_dim = linear.out_features self.rank = rank self.alpha = alpha self.scaling = alpha / rank - # Setup weight - reuse provided tensor or create on meta device - if weight is not None: - self.register_parameter("weight", nn.Parameter(weight, requires_grad=False)) - else: - self.register_parameter( - "weight", - nn.Parameter( - torch.empty(out_dim, in_dim, device="meta", dtype=dtype), - requires_grad=False, - ), - ) - - # Setup bias - if bias is not None: - self.register_parameter("bias", nn.Parameter(bias, requires_grad=False)) - else: - self.register_parameter( - "bias", - nn.Parameter( - torch.empty(out_dim, device="meta", dtype=dtype), - requires_grad=False, - ), - ) - self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity() + # Get dtype from the linear layer's weight + dtype = linear.weight.dtype if hasattr(linear, 'weight') else None + # LoRA layers on meta device - self.lora_a = nn.Linear(in_dim, rank, bias=False, device="meta", dtype=dtype) - self.lora_b = nn.Linear(rank, out_dim, bias=False, device="meta", dtype=dtype) + self.lora_a = nn.Linear(self.in_dim, rank, bias=False, device="meta", dtype=dtype) + self.lora_b = nn.Linear(rank, self.out_dim, bias=False, device="meta", dtype=dtype) + + @property + def weight(self): + """Expose wrapped linear's weight for compatibility.""" + return self.linear.weight + + @property + def bias(self): + """Expose wrapped linear's bias for compatibility.""" + return self.linear.bias + + @property + def in_features(self): + """Expose wrapped linear's in_features for compatibility.""" + return self.in_dim + + @property + def out_features(self): + """Expose wrapped linear's out_features for compatibility.""" + return self.out_dim def to_empty( self, *, device: Optional[Union[str, torch.device, int]], recurse: bool = True @@ -96,8 +92,8 @@ def adapter_params(self) -> list[str]: return ["lora_a.weight", "lora_b.weight"] def forward(self, x: torch.Tensor) -> torch.Tensor: - # Base linear: out = x @ W.T + bias - out = F.linear(x, self.weight, self.bias) # type: ignore[arg-type] + # Base linear forward (works with nn.Linear, Float8Linear, etc.) + out = self.linear(x) # LoRA path: out += scale * (x @ A.T @ B.T) x = self.dropout(x) @@ -122,70 +118,56 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): def convert(self, model: nn.Module) -> None: """Inplace conversion of the model to use LoRA adapters.""" - num_replaced = 0 - for module in list(model.modules()): - for child_name, child in module.named_children(): - # TODO: Add support for GroupedExperts. 
- if isinstance(child, nn.Linear) and not isinstance(child, LoRALinear): - original_weight = child.weight.data - original_bias = child.bias.data if child.bias is not None else None + init_weights_fn = model.init_weights + lora_count = 0 + + def make_init_wrapper(prev_fn, ll: LoRALinear | None = None, final_log: bool = False): + def wrapped(*args, **kwargs): + if callable(prev_fn): + prev_fn(*args, **kwargs) + if ll is not None: + ll.initialize_parameters() + ll.lora_a.weight.requires_grad = True + ll.lora_b.weight.requires_grad = True + if final_log: + trainable_params = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + logger.info( + f"LoRA parameters initialized for {lora_count} layers, " + f"trainable params: {trainable_params:,}" + ) + return wrapped - # Break reference chain before creating new module - child.weight = None # type: ignore[assignment] - child.bias = None # type: ignore[assignment] + for module in list(model.modules()): + for param in module._parameters.values(): + if param is not None: + param.requires_grad_(False) + for name, child in list(module._modules.items()): + if isinstance(child, nn.Linear) and not isinstance(child, LoRALinear): lora_linear = LoRALinear( - in_dim=child.in_features, - out_dim=child.out_features, + linear=child, rank=self.rank, alpha=self.alpha, dropout=self.dropout, - weight=original_weight, - bias=original_bias, - dtype=original_weight.dtype, ) - setattr(module, child_name, lora_linear) - num_replaced += 1 - - # Wrap init_weights to also initialize LoRA parameters - original_init_weights = model.init_weights - - def init_weights_with_lora(*args, **kwargs): - if callable(original_init_weights): - original_init_weights(*args, **kwargs) - self._init_lora_params(model) + setattr(module, name, lora_linear) + lora_count += 1 + init_weights_fn = make_init_wrapper(init_weights_fn, lora_linear) - object.__setattr__(model, "init_weights", init_weights_with_lora) + # Add final logging wrapper + init_weights_fn = make_init_wrapper(init_weights_fn, final_log=True) + object.__setattr__(model, "init_weights", init_weights_fn) total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - logger.info(f"Converted {num_replaced} linear modules to LoRALinear") + logger.info(f"Converted {lora_count} linear modules to LoRALinear") logger.info( f"Trainable params: {trainable_params:,} / {total_params:,} " f"({100 * trainable_params / total_params:.2f}%)" ) - def _set_lora_requires_grad(self, model: nn.Module) -> None: - """Set requires_grad: True for LoRA params, False for others.""" - for name, param in model.named_parameters(): - param.requires_grad = "lora_a" in name or "lora_b" in name - - def _init_lora_params(self, model: nn.Module) -> None: - """Initialize LoRA parameters after model initialization.""" - lora_layer_count = 0 - for module in model.modules(): - if isinstance(module, LoRALinear): - module.initialize_parameters() - lora_layer_count += 1 - - self._set_lora_requires_grad(model) - - trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - logger.info( - f"LoRA parameters initialized for {lora_layer_count} layers, " - f"trainable params: {trainable_params:,}" - ) - def post_optimizer_hook(self, model: Union[nn.Module, list[nn.Module]]) -> None: """Post-optimizer hook (no-op for LoRA).""" pass From 2773b1c7b8f59b1e39b5fec11779783d72dc46e4 Mon Sep 17 00:00:00 2001 From: mori360 Date: Tue, 27 Jan 2026 15:50:10 -0800 Subject: [PATCH 
19/20] remove warning --- torchtitan/protocols/model_converter.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/torchtitan/protocols/model_converter.py b/torchtitan/protocols/model_converter.py index a44c14fb6e..cb4804be6f 100644 --- a/torchtitan/protocols/model_converter.py +++ b/torchtitan/protocols/model_converter.py @@ -58,10 +58,6 @@ class ModelConvertersContainer(ModelConverter): """ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): - if {"lora", "float8"}.issubset(job_config.model.converters): - raise NotImplementedError( - "LoRA is incompatible with FP8 linear subclass now." - ) converter_classes = [ _registry_model_converter_cls[name] for name in job_config.model.converters ] From 452680278084410ee12ce6db295a7d91345c6f23 Mon Sep 17 00:00:00 2001 From: mori360 Date: Fri, 30 Jan 2026 15:49:27 -0800 Subject: [PATCH 20/20] enable tp --- torchtitan/components/lora.py | 108 +++++++-------- torchtitan/models/llama3/infra/parallelize.py | 125 +++++++++++++++--- 2 files changed, 160 insertions(+), 73 deletions(-) diff --git a/torchtitan/components/lora.py b/torchtitan/components/lora.py index 358d981b77..def9a1c6bd 100644 --- a/torchtitan/components/lora.py +++ b/torchtitan/components/lora.py @@ -41,8 +41,6 @@ def __init__( ): super().__init__() self.linear = linear - self.in_dim = linear.in_features - self.out_dim = linear.out_features self.rank = rank self.alpha = alpha self.scaling = alpha / rank @@ -53,8 +51,8 @@ def __init__( dtype = linear.weight.dtype if hasattr(linear, 'weight') else None # LoRA layers on meta device - self.lora_a = nn.Linear(self.in_dim, rank, bias=False, device="meta", dtype=dtype) - self.lora_b = nn.Linear(rank, self.out_dim, bias=False, device="meta", dtype=dtype) + self.lora_a = nn.Linear(linear.in_features, rank, bias=False, device="meta", dtype=dtype) + self.lora_b = nn.Linear(rank, linear.out_features, bias=False, device="meta", dtype=dtype) @property def weight(self): @@ -69,21 +67,15 @@ def bias(self): @property def in_features(self): """Expose wrapped linear's in_features for compatibility.""" - return self.in_dim + return self.linear.in_features @property def out_features(self): """Expose wrapped linear's out_features for compatibility.""" - return self.out_dim - - def to_empty( - self, *, device: Optional[Union[str, torch.device, int]], recurse: bool = True - ): - self.lora_a.to_empty(device=device, recurse=recurse) - self.lora_b.to_empty(device=device, recurse=recurse) - return self + return self.linear.out_features def initialize_parameters(self): + """Initialize LoRA parameters after materialization.""" nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5)) nn.init.zeros_(self.lora_b.weight) @@ -95,12 +87,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Base linear forward (works with nn.Linear, Float8Linear, etc.) 
out = self.linear(x) - # LoRA path: out += scale * (x @ A.T @ B.T) - x = self.dropout(x) - lora_out = self.lora_b(self.lora_a(x)) - out = out + self.scaling * lora_out + # LoRA path - use modules directly to preserve gradient flow through DTensor + lora_x = self.dropout(x) + lora_hidden = self.lora_a(lora_x) # [batch, seq, rank] + lora_out = self.lora_b(lora_hidden) # [batch, seq, out_features] - return out + # Both out and lora_out are plain tensors (use_local_output=True in TP layer_plan) + return out + self.scaling * lora_out class LoRAConverter: @@ -110,6 +103,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): self.rank = job_config.lora.rank self.alpha = job_config.lora.alpha self.dropout = job_config.lora.dropout + self._lora_modules: list[LoRALinear] = [] logger.info( f"LoRA training active with rank={self.rank}, alpha={self.alpha}, " @@ -118,34 +112,18 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): def convert(self, model: nn.Module) -> None: """Inplace conversion of the model to use LoRA adapters.""" - init_weights_fn = model.init_weights - lora_count = 0 - - def make_init_wrapper(prev_fn, ll: LoRALinear | None = None, final_log: bool = False): - def wrapped(*args, **kwargs): - if callable(prev_fn): - prev_fn(*args, **kwargs) - if ll is not None: - ll.initialize_parameters() - ll.lora_a.weight.requires_grad = True - ll.lora_b.weight.requires_grad = True - if final_log: - trainable_params = sum( - p.numel() for p in model.parameters() if p.requires_grad - ) - logger.info( - f"LoRA parameters initialized for {lora_count} layers, " - f"trainable params: {trainable_params:,}" - ) - return wrapped + self._apply_lora(model) + self._hook_init_weights(model) - for module in list(model.modules()): - for param in module._parameters.values(): - if param is not None: - param.requires_grad_(False) + logger.info(f"Converted {len(self._lora_modules)} linear modules to LoRALinear") + def _apply_lora(self, model: nn.Module) -> None: + """Replace Linear layers with LoRALinear wrappers.""" + for module in list(model.modules()): for name, child in list(module._modules.items()): if isinstance(child, nn.Linear) and not isinstance(child, LoRALinear): + if name == "output": + continue lora_linear = LoRALinear( linear=child, rank=self.rank, @@ -153,20 +131,42 @@ def wrapped(*args, **kwargs): dropout=self.dropout, ) setattr(module, name, lora_linear) - lora_count += 1 - init_weights_fn = make_init_wrapper(init_weights_fn, lora_linear) + self._lora_modules.append(lora_linear) + + def _hook_init_weights(self, model: nn.Module) -> None: + """Hook into init_weights to freeze base params and initialize LoRA.""" + original_init_weights = model.init_weights + lora_modules = self._lora_modules + model_ref = [model] + + def new_init_weights(*args, **kwargs): + if callable(original_init_weights): + original_init_weights(*args, **kwargs) + + for ll in lora_modules: + ll.initialize_parameters() + + m = model_ref[0] + + trainable_count = 0 + frozen_count = 0 + for name, param in m.named_parameters(): + if "lora_a" in name or "lora_b" in name: + param.requires_grad_(True) + trainable_count += 1 + else: + param.requires_grad_(False) + frozen_count += 1 - # Add final logging wrapper - init_weights_fn = make_init_wrapper(init_weights_fn, final_log=True) - object.__setattr__(model, "init_weights", init_weights_fn) + total_params = sum(p.numel() for p in m.parameters()) + trainable_params = sum(p.numel() for p in m.parameters() if p.requires_grad) + logger.info( + f"LoRA: 
frozen {frozen_count} params, trainable {trainable_count} params, " + f"trainable params: {trainable_params:,} / {total_params:,} " + f"({100 * trainable_params / total_params:.2f}%)" + ) - total_params = sum(p.numel() for p in model.parameters()) - trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - logger.info(f"Converted {lora_count} linear modules to LoRALinear") - logger.info( - f"Trainable params: {trainable_params:,} / {total_params:,} " - f"({100 * trainable_params / total_params:.2f}%)" - ) + object.__setattr__(model, "init_weights", new_init_weights) def post_optimizer_hook(self, model: Union[nn.Module, list[nn.Module]]) -> None: """Post-optimizer hook (no-op for LoRA).""" diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index b6c339fa67..980c626481 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -22,6 +22,7 @@ SequenceParallel, ) +from torchtitan.components.lora import LoRALinear from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.config.job_config import Compile as CompileConfig from torchtitan.distributed import ParallelDims @@ -30,6 +31,28 @@ from torchtitan.tools.logging import logger +def _get_tp_path(module: nn.Module, path: str) -> str: + """Get TP path, redirecting to .linear for LoRALinear wrappers.""" + parts = path.split(".") + target = module + for part in parts: + target = getattr(target, part, None) + if target is None: + return path + return f"{path}.linear" if isinstance(target, LoRALinear) else path + + +def _is_lora_linear(module: nn.Module, path: str) -> bool: + """Check if module at path is LoRALinear.""" + parts = path.split(".") + target = module + for part in parts: + target = getattr(target, part, None) + if target is None: + return False + return isinstance(target, LoRALinear) + + # for selective op activation checkpointing _op_sac_save_list = { torch.ops.aten.mm.default, @@ -160,21 +183,27 @@ def apply_tp( # transformer block's inputs) # 2. Parallelize the root norm layer over the sequence dim # 3. Parallelize the final linear output layer + + # Handle output layer - redirect to .linear if wrapped with LoRALinear + output_path = _get_tp_path(model, "output") + + # Build plan for top-level modules + top_level_plan = { + "tok_embeddings": RowwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(1), + ), + "norm": SequenceParallel(), + output_path: ColwiseParallel( + input_layouts=Shard(1), + output_layouts=Shard(-1) if loss_parallel else Replicate(), + use_local_output=not loss_parallel, + ), + } parallelize_module( model, tp_mesh, - { - "tok_embeddings": RowwiseParallel( - input_layouts=Replicate(), - output_layouts=Shard(1), - ), - "norm": SequenceParallel(), - "output": ColwiseParallel( - input_layouts=Shard(1), - output_layouts=Shard(-1) if loss_parallel else Replicate(), - use_local_output=not loss_parallel, - ), - }, + top_level_plan, ) # Parallel styles used for transformer block linear weights and their @@ -203,8 +232,40 @@ def apply_tp( # NOTE: At the cost of model code change, we can accelerate Sequence Parallel # by folding (and unfolding) the batch dimension and the sequence dimension. 
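
The two path helpers above decide, per fully qualified module path, whether the TP plan should target a wrapped projection's inner .linear, so the frozen base weight is sharded while the adapters are covered by their own plan entries further down. A toy illustration of the redirection, assuming the helpers and LoRALinear are importable as defined in this series:

    import torch.nn as nn
    from torchtitan.components.lora import LoRALinear
    from torchtitan.models.llama3.infra.parallelize import _get_tp_path, _is_lora_linear

    class Block(nn.Module):
        def __init__(self):
            super().__init__()
            self.wq = LoRALinear(nn.Linear(8, 8), rank=2, alpha=4.0)
            self.wk = nn.Linear(8, 8)

    blk = Block()
    assert _get_tp_path(blk, "wq") == "wq.linear"   # wrapped: shard the inner linear
    assert _get_tp_path(blk, "wk") == "wk"          # plain linear: path unchanged
    assert _is_lora_linear(blk, "wq") and not _is_lora_linear(blk, "wk")
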
# Examples can be found at https://github.com/pytorch/torchtitan/pull/437 + + # LoRA parallel styles for colwise and rowwise layers + # For colwise: lora_a uses Replicate, lora_b uses colwise_parallel + lora_a_colwise = RowwiseParallel( + input_layouts=Replicate(), + output_layouts=Replicate(), + use_local_output=True, + ) + # For rowwise: lora_a takes sharded input, produces Replicate output (DTensor, no use_local_output) + # Then lora_b takes Replicate input and produces Shard(1) output + lora_a_rowwise = RowwiseParallel( + input_layouts=Shard(-1), + output_layouts=Replicate(), + use_local_output=False, # Keep as DTensor for lora_b to consume + ) + # lora_b for rowwise: ColwiseParallel with Shard(1) output to match base linear + lora_b_rowwise = ColwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(1), + use_local_output=True, + ) + # pyrefly: ignore [not-callable] for transformer_block in model.layers.values(): + # Build TP paths, redirecting to .linear submodule for LoRALinear wrappers + # This avoids TP trying to distribute LoRA adapter parameters + wq_path = _get_tp_path(transformer_block, "attention.wq") + wk_path = _get_tp_path(transformer_block, "attention.wk") + wv_path = _get_tp_path(transformer_block, "attention.wv") + wo_path = _get_tp_path(transformer_block, "attention.wo") + w1_path = _get_tp_path(transformer_block, "feed_forward.w1") + w2_path = _get_tp_path(transformer_block, "feed_forward.w2") + w3_path = _get_tp_path(transformer_block, "feed_forward.w3") + layer_plan = { "attention_norm": SequenceParallel(), # NOTE: when the fourth argument (positions) is not None, its input layout @@ -213,20 +274,46 @@ def apply_tp( input_layouts=(Shard(1), None, None, None), desired_input_layouts=(Replicate(), None, None, None), ), - "attention.wq": colwise_parallel(), - "attention.wk": colwise_parallel(), - "attention.wv": colwise_parallel(), - "attention.wo": rowwise_parallel(output_layouts=Shard(1)), + wq_path: colwise_parallel(), + wk_path: colwise_parallel(), + wv_path: colwise_parallel(), + wo_path: rowwise_parallel(output_layouts=Shard(1)), "ffn_norm": SequenceParallel(), "feed_forward": prepare_module_input( input_layouts=(Shard(1),), desired_input_layouts=(Replicate(),), ), - "feed_forward.w1": colwise_parallel(), - "feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)), - "feed_forward.w3": colwise_parallel(), + w1_path: colwise_parallel(), + w2_path: rowwise_parallel(output_layouts=Shard(1)), + w3_path: colwise_parallel(), } + # Add LoRA parallelism to layer_plan if module is LoRALinear + # Colwise layers (wq, wk, wv, w1, w3): lora_b(lora_a(x)) behaves like colwise_parallel() + if _is_lora_linear(transformer_block, "attention.wq"): + layer_plan["attention.wq.lora_a"] = lora_a_colwise + layer_plan["attention.wq.lora_b"] = colwise_parallel() + if _is_lora_linear(transformer_block, "attention.wk"): + layer_plan["attention.wk.lora_a"] = lora_a_colwise + layer_plan["attention.wk.lora_b"] = colwise_parallel() + if _is_lora_linear(transformer_block, "attention.wv"): + layer_plan["attention.wv.lora_a"] = lora_a_colwise + layer_plan["attention.wv.lora_b"] = colwise_parallel() + if _is_lora_linear(transformer_block, "feed_forward.w1"): + layer_plan["feed_forward.w1.lora_a"] = lora_a_colwise + layer_plan["feed_forward.w1.lora_b"] = colwise_parallel() + if _is_lora_linear(transformer_block, "feed_forward.w3"): + layer_plan["feed_forward.w3.lora_a"] = lora_a_colwise + layer_plan["feed_forward.w3.lora_b"] = colwise_parallel() + + # Rowwise layers (wo, w2): 
lora_b(lora_a(x)) behaves like rowwise_parallel(output_layouts=Shard(1)) + if _is_lora_linear(transformer_block, "attention.wo"): + layer_plan["attention.wo.lora_a"] = lora_a_rowwise + layer_plan["attention.wo.lora_b"] = lora_b_rowwise + if _is_lora_linear(transformer_block, "feed_forward.w2"): + layer_plan["feed_forward.w2.lora_a"] = lora_a_rowwise + layer_plan["feed_forward.w2.lora_b"] = lora_b_rowwise + parallelize_module( # pyrefly: ignore [bad-argument-type] module=transformer_block,