3 changes: 3 additions & 0 deletions torchtitan/__init__.py
@@ -6,6 +6,9 @@

from importlib.metadata import version

# Import to register lora module.
import torchtitan.components.lora # noqa: F401

# Import to register quantization modules.
import torchtitan.components.quantization # noqa: F401

177 changes: 177 additions & 0 deletions torchtitan/components/lora.py
@@ -0,0 +1,177 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchtitan.config import JobConfig
from torchtitan.distributed import ParallelDims
from torchtitan.protocols.model_converter import register_model_converter
from torchtitan.tools.logging import logger


class LoRALinear(nn.Module):
"""LoRA wrapper for any linear layer.

Wraps an existing linear layer and adds LoRA adapters.
Implements: x -> linear(x) + (alpha / rank) * B @ A @ x

See: https://arxiv.org/abs/2106.09685

Args:
linear: The linear layer to wrap (nn.Linear, Float8Linear, etc.)
rank: Rank of the low-rank approximation.
alpha: Scaling factor.
dropout: Dropout probability.
"""

def __init__(
self,
linear: nn.Module,
rank: int,
alpha: float,
dropout: float = 0.0,
):
super().__init__()
self.linear = linear
self.rank = rank
self.alpha = alpha
self.scaling = alpha / rank

self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()

# Get dtype from the linear layer's weight
dtype = linear.weight.dtype if hasattr(linear, "weight") else None

# LoRA layers on meta device
self.lora_a = nn.Linear(linear.in_features, rank, bias=False, device="meta", dtype=dtype)
self.lora_b = nn.Linear(rank, linear.out_features, bias=False, device="meta", dtype=dtype)

@property
def weight(self):
"""Expose wrapped linear's weight for compatibility."""
return self.linear.weight

@property
def bias(self):
"""Expose wrapped linear's bias for compatibility."""
return self.linear.bias

@property
def in_features(self):
"""Expose wrapped linear's in_features for compatibility."""
return self.linear.in_features

@property
def out_features(self):
"""Expose wrapped linear's out_features for compatibility."""
return self.linear.out_features

def initialize_parameters(self):
"""Initialize LoRA parameters after materialization."""
nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_b.weight)

def adapter_params(self) -> list[str]:
"""Return names of LoRA adapter parameters."""
return ["lora_a.weight", "lora_b.weight"]

def forward(self, x: torch.Tensor) -> torch.Tensor:
# Base linear forward (works with nn.Linear, Float8Linear, etc.)
out = self.linear(x)

# LoRA path - use modules directly to preserve gradient flow through DTensor
lora_x = self.dropout(x)
lora_hidden = self.lora_a(lora_x) # [batch, seq, rank]
lora_out = self.lora_b(lora_hidden) # [batch, seq, out_features]

# Both out and lora_out are plain tensors (use_local_output=True in TP layer_plan)
return out + self.scaling * lora_out


class LoRAConverter:
"""Model converter that adds LoRA adapters to Linear layers."""

def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
self.rank = job_config.lora.rank
self.alpha = job_config.lora.alpha
self.dropout = job_config.lora.dropout
self._lora_modules: list[LoRALinear] = []

logger.info(
f"LoRA training active with rank={self.rank}, alpha={self.alpha}, "
f"dropout={self.dropout}"
)

def convert(self, model: nn.Module) -> None:
"""Inplace conversion of the model to use LoRA adapters."""
self._apply_lora(model)
self._hook_init_weights(model)

logger.info(f"Converted {len(self._lora_modules)} linear modules to LoRALinear")

def _apply_lora(self, model: nn.Module) -> None:
"""Replace Linear layers with LoRALinear wrappers."""
for module in list(model.modules()):
for name, child in list(module._modules.items()):
if isinstance(child, nn.Linear) and not isinstance(child, LoRALinear):
if name == "output":
    # Skip the final output projection (LM head); it is left as a plain Linear.
    continue
lora_linear = LoRALinear(
linear=child,
rank=self.rank,
alpha=self.alpha,
dropout=self.dropout,
)
setattr(module, name, lora_linear)
self._lora_modules.append(lora_linear)

def _hook_init_weights(self, model: nn.Module) -> None:
"""Hook into init_weights to freeze base params and initialize LoRA."""
original_init_weights = model.init_weights
lora_modules = self._lora_modules
model_ref = [model]

def new_init_weights(*args, **kwargs):
if callable(original_init_weights):
original_init_weights(*args, **kwargs)

for ll in lora_modules:
ll.initialize_parameters()

m = model_ref[0]

trainable_count = 0
frozen_count = 0
for name, param in m.named_parameters():
if "lora_a" in name or "lora_b" in name:
param.requires_grad_(True)
trainable_count += 1
else:
param.requires_grad_(False)
frozen_count += 1

total_params = sum(p.numel() for p in m.parameters())
trainable_params = sum(p.numel() for p in m.parameters() if p.requires_grad)
logger.info(
f"LoRA: frozen {frozen_count} params, trainable {trainable_count} params, "
f"trainable params: {trainable_params:,} / {total_params:,} "
f"({100 * trainable_params / total_params:.2f}%)"
)

object.__setattr__(model, "init_weights", new_init_weights)

def post_optimizer_hook(self, model: Union[nn.Module, list[nn.Module]]) -> None:
"""Post-optimizer hook (no-op for LoRA)."""
pass


# Register the LoRA converter
register_model_converter(LoRAConverter, "lora")
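A minimal usage sketch of the wrapper above, assuming the module is importable as torchtitan.components.lora; the layer sizes are arbitrary, the base linear is built directly on CPU, and the adapters are materialized off the meta device by hand (in training this is handled by the init_weights hook installed by LoRAConverter).

import torch
import torch.nn as nn

from torchtitan.components.lora import LoRALinear

base = nn.Linear(512, 1024, bias=False)
lora = LoRALinear(base, rank=8, alpha=16.0, dropout=0.0)

# Materialize the meta-device adapters, then apply the LoRA init
# (kaiming_uniform_ for lora_a, zeros_ for lora_b) so the wrapper starts as a no-op.
lora.lora_a.to_empty(device="cpu")
lora.lora_b.to_empty(device="cpu")
lora.initialize_parameters()

x = torch.randn(2, 16, 512)
assert torch.allclose(lora(x), base(x))  # lora_b starts at zero, so the output is unchanged

# Freeze the base weight, keeping only the adapters trainable,
# mirroring what LoRAConverter does inside its init_weights hook.
base.weight.requires_grad_(False)
trainable = sum(p.numel() for p in lora.parameters() if p.requires_grad)
print(trainable)  # 8 * 512 + 1024 * 8 = 12288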
16 changes: 16 additions & 0 deletions torchtitan/config/job_config.py
@@ -694,6 +694,21 @@ class ActivationCheckpoint:
"""


@dataclass
class LoRA:
rank: int = 8
"""Rank of the low-rank approximation. Higher rank = more parameters but better quality."""

alpha: float = 16.0
"""
Scaling factor for the low-rank approximation.
The LoRA output is scaled by (alpha / rank), so higher alpha means stronger LoRA effect.
"""

dropout: float = 0.0
"""Dropout probability applied to the LoRA layers. 0.0 means no dropout."""


@dataclass
class Compile:
enable: bool = False
@@ -991,6 +1006,7 @@ class JobConfig:
activation_checkpoint: ActivationCheckpoint = field(
default_factory=ActivationCheckpoint
)
lora: LoRA = field(default_factory=LoRA)
compile: Compile = field(default_factory=Compile)
quantize: Quantize = field(default_factory=Quantize)
comm: Comm = field(default_factory=Comm)
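For reference, a small sketch of how the defaults of the LoRA dataclass above combine into the scaling factor that LoRALinear applies; the arithmetic and the forward expression follow directly from the code in this diff.

from torchtitan.config.job_config import LoRA

cfg = LoRA()                    # rank=8, alpha=16.0, dropout=0.0
scaling = cfg.alpha / cfg.rank  # 16.0 / 8 = 2.0

# LoRALinear computes: out = linear(x) + scaling * lora_b(lora_a(dropout(x)))
# Doubling alpha doubles the adapter's contribution; doubling rank at fixed alpha
# halves the per-direction scale while adding more low-rank directions, so keeping
# alpha proportional to rank holds the effective scale constant across rank sweeps.
assert scaling == 2.0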
125 changes: 106 additions & 19 deletions torchtitan/models/llama3/infra/parallelize.py
@@ -22,6 +22,7 @@
SequenceParallel,
)

from torchtitan.components.lora import LoRALinear
from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
from torchtitan.config.job_config import Compile as CompileConfig
from torchtitan.distributed import ParallelDims
@@ -30,6 +31,28 @@
from torchtitan.tools.logging import logger


def _get_tp_path(module: nn.Module, path: str) -> str:
"""Get TP path, redirecting to .linear for LoRALinear wrappers."""
parts = path.split(".")
target = module
for part in parts:
target = getattr(target, part, None)
if target is None:
return path
return f"{path}.linear" if isinstance(target, LoRALinear) else path


def _is_lora_linear(module: nn.Module, path: str) -> bool:
"""Check if module at path is LoRALinear."""
parts = path.split(".")
target = module
for part in parts:
target = getattr(target, part, None)
if target is None:
return False
return isinstance(target, LoRALinear)


# for selective op activation checkpointing
_op_sac_save_list = {
torch.ops.aten.mm.default,
@@ -160,21 +183,27 @@ def apply_tp(
# transformer block's inputs)
# 2. Parallelize the root norm layer over the sequence dim
# 3. Parallelize the final linear output layer

# Handle output layer - redirect to .linear if wrapped with LoRALinear
output_path = _get_tp_path(model, "output")

# Build plan for top-level modules
top_level_plan = {
"tok_embeddings": RowwiseParallel(
input_layouts=Replicate(),
output_layouts=Shard(1),
),
"norm": SequenceParallel(),
output_path: ColwiseParallel(
input_layouts=Shard(1),
output_layouts=Shard(-1) if loss_parallel else Replicate(),
use_local_output=not loss_parallel,
),
}
parallelize_module(
model,
tp_mesh,
{
"tok_embeddings": RowwiseParallel(
input_layouts=Replicate(),
output_layouts=Shard(1),
),
"norm": SequenceParallel(),
"output": ColwiseParallel(
input_layouts=Shard(1),
output_layouts=Shard(-1) if loss_parallel else Replicate(),
use_local_output=not loss_parallel,
),
},
top_level_plan,
)

# Parallel styles used for transformer block linear weights and their
@@ -203,8 +232,40 @@ def apply_tp(
# NOTE: At the cost of model code change, we can accelerate Sequence Parallel
# by folding (and unfolding) the batch dimension and the sequence dimension.
# Examples can be found at https://github.com/pytorch/torchtitan/pull/437

# LoRA parallel styles for layers whose base linear is colwise- or rowwise-sharded.
# For colwise base layers: lora_a is rowwise-sharded with replicated input/output,
# and lora_b is sharded colwise like the base linear.
lora_a_colwise = RowwiseParallel(
input_layouts=Replicate(),
output_layouts=Replicate(),
use_local_output=True,
)
# For rowwise: lora_a takes sharded input, produces Replicate output (DTensor, no use_local_output)
# Then lora_b takes Replicate input and produces Shard(1) output
lora_a_rowwise = RowwiseParallel(
input_layouts=Shard(-1),
output_layouts=Replicate(),
use_local_output=False, # Keep as DTensor for lora_b to consume
)
# lora_b for rowwise: ColwiseParallel with Shard(1) output to match base linear
lora_b_rowwise = ColwiseParallel(
input_layouts=Replicate(),
output_layouts=Shard(1),
use_local_output=True,
)

# pyrefly: ignore [not-callable]
for transformer_block in model.layers.values():
# Build TP paths, redirecting to .linear submodule for LoRALinear wrappers
# This avoids TP trying to distribute LoRA adapter parameters
wq_path = _get_tp_path(transformer_block, "attention.wq")
wk_path = _get_tp_path(transformer_block, "attention.wk")
wv_path = _get_tp_path(transformer_block, "attention.wv")
wo_path = _get_tp_path(transformer_block, "attention.wo")
w1_path = _get_tp_path(transformer_block, "feed_forward.w1")
w2_path = _get_tp_path(transformer_block, "feed_forward.w2")
w3_path = _get_tp_path(transformer_block, "feed_forward.w3")

layer_plan = {
"attention_norm": SequenceParallel(),
# NOTE: when the fourth argument (positions) is not None, its input layout
Expand All @@ -213,20 +274,46 @@ def apply_tp(
input_layouts=(Shard(1), None, None, None),
desired_input_layouts=(Replicate(), None, None, None),
),
"attention.wq": colwise_parallel(),
"attention.wk": colwise_parallel(),
"attention.wv": colwise_parallel(),
"attention.wo": rowwise_parallel(output_layouts=Shard(1)),
wq_path: colwise_parallel(),
wk_path: colwise_parallel(),
wv_path: colwise_parallel(),
wo_path: rowwise_parallel(output_layouts=Shard(1)),
"ffn_norm": SequenceParallel(),
"feed_forward": prepare_module_input(
input_layouts=(Shard(1),),
desired_input_layouts=(Replicate(),),
),
"feed_forward.w1": colwise_parallel(),
"feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)),
"feed_forward.w3": colwise_parallel(),
w1_path: colwise_parallel(),
w2_path: rowwise_parallel(output_layouts=Shard(1)),
w3_path: colwise_parallel(),
}

# Add LoRA parallelism to layer_plan if module is LoRALinear
# Colwise layers (wq, wk, wv, w1, w3): lora_b(lora_a(x)) behaves like colwise_parallel()
if _is_lora_linear(transformer_block, "attention.wq"):
layer_plan["attention.wq.lora_a"] = lora_a_colwise
layer_plan["attention.wq.lora_b"] = colwise_parallel()
if _is_lora_linear(transformer_block, "attention.wk"):
layer_plan["attention.wk.lora_a"] = lora_a_colwise
layer_plan["attention.wk.lora_b"] = colwise_parallel()
if _is_lora_linear(transformer_block, "attention.wv"):
layer_plan["attention.wv.lora_a"] = lora_a_colwise
layer_plan["attention.wv.lora_b"] = colwise_parallel()
if _is_lora_linear(transformer_block, "feed_forward.w1"):
layer_plan["feed_forward.w1.lora_a"] = lora_a_colwise
layer_plan["feed_forward.w1.lora_b"] = colwise_parallel()
if _is_lora_linear(transformer_block, "feed_forward.w3"):
layer_plan["feed_forward.w3.lora_a"] = lora_a_colwise
layer_plan["feed_forward.w3.lora_b"] = colwise_parallel()

# Rowwise layers (wo, w2): lora_b(lora_a(x)) behaves like rowwise_parallel(output_layouts=Shard(1))
if _is_lora_linear(transformer_block, "attention.wo"):
layer_plan["attention.wo.lora_a"] = lora_a_rowwise
layer_plan["attention.wo.lora_b"] = lora_b_rowwise
if _is_lora_linear(transformer_block, "feed_forward.w2"):
layer_plan["feed_forward.w2.lora_a"] = lora_a_rowwise
layer_plan["feed_forward.w2.lora_b"] = lora_b_rowwise

parallelize_module(
# pyrefly: ignore [bad-argument-type]
module=transformer_block,
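To make the TP-plan redirection concrete, a sketch of what _get_tp_path and _is_lora_linear return for a LoRA-wrapped versus a plain linear, assuming the module-level helpers above are importable from the parallelize module; ToyAttention and ToyBlock are made-up stand-ins, not the llama3 modules.

import torch.nn as nn

from torchtitan.components.lora import LoRALinear
from torchtitan.models.llama3.infra.parallelize import _get_tp_path, _is_lora_linear


class ToyAttention(nn.Module):
    # Made-up stand-in for a transformer block's attention module.
    def __init__(self):
        super().__init__()
        self.wq = LoRALinear(nn.Linear(64, 64, bias=False), rank=8, alpha=16.0)
        self.wo = nn.Linear(64, 64, bias=False)  # left un-wrapped


class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.attention = ToyAttention()


block = ToyBlock()

# Wrapped layer: the base weight lives one level deeper, so the TP plan targets
# "attention.wq.linear" and the adapters get their own layer_plan entries.
assert _get_tp_path(block, "attention.wq") == "attention.wq.linear"
assert _is_lora_linear(block, "attention.wq")

# Un-wrapped layer: the path is returned unchanged and no adapter entries are added.
assert _get_tp_path(block, "attention.wo") == "attention.wo"
assert not _is_lora_linear(block, "attention.wo")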