diff --git a/torchtitan/__init__.py b/torchtitan/__init__.py
index 52c3ff3e22..efb781da17 100644
--- a/torchtitan/__init__.py
+++ b/torchtitan/__init__.py
@@ -6,6 +6,9 @@
 from importlib.metadata import version
 
+# Import to register lora module.
+import torchtitan.components.lora  # noqa: F401
+
 # Import to register quantization modules.
 import torchtitan.components.quantization  # noqa: F401
 
diff --git a/torchtitan/components/lora.py b/torchtitan/components/lora.py
new file mode 100644
index 0000000000..def9a1c6bd
--- /dev/null
+++ b/torchtitan/components/lora.py
@@ -0,0 +1,177 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torchtitan.config import JobConfig
+from torchtitan.distributed import ParallelDims
+from torchtitan.protocols.model_converter import register_model_converter
+from torchtitan.tools.logging import logger
+
+
+class LoRALinear(nn.Module):
+    """LoRA wrapper for any linear layer.
+
+    Wraps an existing linear layer and adds LoRA adapters.
+    Implements: x -> linear(x) + (alpha / rank) * B @ A @ x
+
+    See: https://arxiv.org/abs/2106.09685
+
+    Args:
+        linear: The linear layer to wrap (nn.Linear, Float8Linear, etc.)
+        rank: Rank of the low-rank approximation.
+        alpha: Scaling factor.
+        dropout: Dropout probability.
+    """
+
+    def __init__(
+        self,
+        linear: nn.Module,
+        rank: int,
+        alpha: float,
+        dropout: float = 0.0,
+    ):
+        super().__init__()
+        self.linear = linear
+        self.rank = rank
+        self.alpha = alpha
+        self.scaling = alpha / rank
+
+        self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
+
+        # Get dtype from the linear layer's weight
+        dtype = linear.weight.dtype if hasattr(linear, "weight") else None
+
+        # LoRA layers on meta device
+        self.lora_a = nn.Linear(linear.in_features, rank, bias=False, device="meta", dtype=dtype)
+        self.lora_b = nn.Linear(rank, linear.out_features, bias=False, device="meta", dtype=dtype)
+
+    @property
+    def weight(self):
+        """Expose wrapped linear's weight for compatibility."""
+        return self.linear.weight
+
+    @property
+    def bias(self):
+        """Expose wrapped linear's bias for compatibility."""
+        return self.linear.bias
+
+    @property
+    def in_features(self):
+        """Expose wrapped linear's in_features for compatibility."""
+        return self.linear.in_features
+
+    @property
+    def out_features(self):
+        """Expose wrapped linear's out_features for compatibility."""
+        return self.linear.out_features
+
+    def initialize_parameters(self):
+        """Initialize LoRA parameters after materialization."""
+        nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
+        nn.init.zeros_(self.lora_b.weight)
+
+    def adapter_params(self) -> list[str]:
+        """Return names of LoRA adapter parameters."""
+        return ["lora_a.weight", "lora_b.weight"]
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Base linear forward (works with nn.Linear, Float8Linear, etc.)
+        out = self.linear(x)
+
+        # LoRA path - use modules directly to preserve gradient flow through DTensor
+        lora_x = self.dropout(x)
+        lora_hidden = self.lora_a(lora_x)  # [batch, seq, rank]
+        lora_out = self.lora_b(lora_hidden)  # [batch, seq, out_features]
+
+        # Both out and lora_out are plain tensors (use_local_output=True in TP layer_plan)
+        return out + self.scaling * lora_out
+
+
+class LoRAConverter:
+    """Model converter that adds LoRA adapters to Linear layers."""
+
+    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
+        self.rank = job_config.lora.rank
+        self.alpha = job_config.lora.alpha
+        self.dropout = job_config.lora.dropout
+        self._lora_modules: list[LoRALinear] = []
+
+        logger.info(
+            f"LoRA training active with rank={self.rank}, alpha={self.alpha}, "
+            f"dropout={self.dropout}"
+        )
+
+    def convert(self, model: nn.Module) -> None:
+        """In-place conversion of the model to use LoRA adapters."""
+        self._apply_lora(model)
+        self._hook_init_weights(model)
+
+        logger.info(f"Converted {len(self._lora_modules)} linear modules to LoRALinear")
+
+    def _apply_lora(self, model: nn.Module) -> None:
+        """Replace Linear layers with LoRALinear wrappers."""
+        for module in list(model.modules()):
+            for name, child in list(module._modules.items()):
+                if isinstance(child, nn.Linear) and not isinstance(child, LoRALinear):
+                    if name == "output":
+                        continue
+                    lora_linear = LoRALinear(
+                        linear=child,
+                        rank=self.rank,
+                        alpha=self.alpha,
+                        dropout=self.dropout,
+                    )
+                    setattr(module, name, lora_linear)
+                    self._lora_modules.append(lora_linear)
+
+    def _hook_init_weights(self, model: nn.Module) -> None:
+        """Hook into init_weights to freeze base params and initialize LoRA."""
+        original_init_weights = model.init_weights
+        lora_modules = self._lora_modules
+        model_ref = [model]
+
+        def new_init_weights(*args, **kwargs):
+            if callable(original_init_weights):
+                original_init_weights(*args, **kwargs)
+
+            for ll in lora_modules:
+                ll.initialize_parameters()
+
+            m = model_ref[0]
+
+            trainable_count = 0
+            frozen_count = 0
+            for name, param in m.named_parameters():
+                if "lora_a" in name or "lora_b" in name:
+                    param.requires_grad_(True)
+                    trainable_count += 1
+                else:
+                    param.requires_grad_(False)
+                    frozen_count += 1
+
+            total_params = sum(p.numel() for p in m.parameters())
+            trainable_params = sum(p.numel() for p in m.parameters() if p.requires_grad)
+            logger.info(
+                f"LoRA: frozen {frozen_count} params, trainable {trainable_count} params, "
+                f"trainable params: {trainable_params:,} / {total_params:,} "
+                f"({100 * trainable_params / total_params:.2f}%)"
+            )
+
+        object.__setattr__(model, "init_weights", new_init_weights)
+
+    def post_optimizer_hook(self, model: Union[nn.Module, list[nn.Module]]) -> None:
+        """Post-optimizer hook (no-op for LoRA)."""
+        pass
+
+
+# Register the LoRA converter
+register_model_converter(LoRAConverter, "lora")
diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index 108c38efba..ec89ae39a4 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -694,6 +694,21 @@ class ActivationCheckpoint:
     """
 
+
+@dataclass
+class LoRA:
+    rank: int = 8
+    """Rank of the low-rank approximation. Higher rank means more trainable parameters but potentially better quality."""
+
+    alpha: float = 16.0
+    """
+    Scaling factor for the low-rank approximation.
+    The LoRA output is scaled by (alpha / rank), so a higher alpha means a stronger LoRA effect.
+    """
+
+    dropout: float = 0.0
+    """Dropout probability applied to the LoRA layers. 0.0 means no dropout."""
+
 
 @dataclass
 class Compile:
     enable: bool = False
@@ -991,6 +1006,7 @@ class JobConfig:
     activation_checkpoint: ActivationCheckpoint = field(
         default_factory=ActivationCheckpoint
     )
+    lora: LoRA = field(default_factory=LoRA)
     compile: Compile = field(default_factory=Compile)
     quantize: Quantize = field(default_factory=Quantize)
     comm: Comm = field(default_factory=Comm)
diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py
index b6c339fa67..980c626481 100644
--- a/torchtitan/models/llama3/infra/parallelize.py
+++ b/torchtitan/models/llama3/infra/parallelize.py
@@ -22,6 +22,7 @@
     SequenceParallel,
 )
 
+from torchtitan.components.lora import LoRALinear
 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.config.job_config import Compile as CompileConfig
 from torchtitan.distributed import ParallelDims
@@ -30,6 +31,28 @@
 from torchtitan.tools.logging import logger
 
+
+def _get_tp_path(module: nn.Module, path: str) -> str:
+    """Get TP path, redirecting to .linear for LoRALinear wrappers."""
+    parts = path.split(".")
+    target = module
+    for part in parts:
+        target = getattr(target, part, None)
+        if target is None:
+            return path
+    return f"{path}.linear" if isinstance(target, LoRALinear) else path
+
+
+def _is_lora_linear(module: nn.Module, path: str) -> bool:
+    """Check if the module at path is LoRALinear."""
+    parts = path.split(".")
+    target = module
+    for part in parts:
+        target = getattr(target, part, None)
+        if target is None:
+            return False
+    return isinstance(target, LoRALinear)
+
 
 # for selective op activation checkpointing
 _op_sac_save_list = {
     torch.ops.aten.mm.default,
@@ -160,21 +183,27 @@ def apply_tp(
     # transformer block's inputs)
     # 2. Parallelize the root norm layer over the sequence dim
     # 3. Parallelize the final linear output layer
+
+    # Handle output layer - redirect to .linear if wrapped with LoRALinear
+    output_path = _get_tp_path(model, "output")
+
+    # Build plan for top-level modules
+    top_level_plan = {
+        "tok_embeddings": RowwiseParallel(
+            input_layouts=Replicate(),
+            output_layouts=Shard(1),
+        ),
+        "norm": SequenceParallel(),
+        output_path: ColwiseParallel(
+            input_layouts=Shard(1),
+            output_layouts=Shard(-1) if loss_parallel else Replicate(),
+            use_local_output=not loss_parallel,
+        ),
+    }
     parallelize_module(
         model,
         tp_mesh,
-        {
-            "tok_embeddings": RowwiseParallel(
-                input_layouts=Replicate(),
-                output_layouts=Shard(1),
-            ),
-            "norm": SequenceParallel(),
-            "output": ColwiseParallel(
-                input_layouts=Shard(1),
-                output_layouts=Shard(-1) if loss_parallel else Replicate(),
-                use_local_output=not loss_parallel,
-            ),
-        },
+        top_level_plan,
     )
 
     # Parallel styles used for transformer block linear weights and their
@@ -203,8 +232,40 @@ def apply_tp(
     # NOTE: At the cost of model code change, we can accelerate Sequence Parallel
    # by folding (and unfolding) the batch dimension and the sequence dimension.
     # Examples can be found at https://github.com/pytorch/torchtitan/pull/437
+
+    # LoRA parallel styles for colwise and rowwise layers
+    # For colwise base layers: lora_a uses replicated input/output layouts; lora_b uses colwise_parallel()
+    lora_a_colwise = RowwiseParallel(
+        input_layouts=Replicate(),
+        output_layouts=Replicate(),
+        use_local_output=True,
+    )
+    # For rowwise base layers: lora_a takes the sharded input and produces a replicated
+    # DTensor output; lora_b then takes that replicated input and produces Shard(1) output
+    lora_a_rowwise = RowwiseParallel(
+        input_layouts=Shard(-1),
+        output_layouts=Replicate(),
+        use_local_output=False,  # Keep as DTensor for lora_b to consume
+    )
+    # lora_b for rowwise: ColwiseParallel with Shard(1) output to match the base linear
+    lora_b_rowwise = ColwiseParallel(
+        input_layouts=Replicate(),
+        output_layouts=Shard(1),
+        use_local_output=True,
+    )
+
     # pyrefly: ignore [not-callable]
     for transformer_block in model.layers.values():
+        # Build TP paths, redirecting to the .linear submodule for LoRALinear wrappers.
+        # This avoids TP trying to distribute the LoRA adapter parameters.
+        wq_path = _get_tp_path(transformer_block, "attention.wq")
+        wk_path = _get_tp_path(transformer_block, "attention.wk")
+        wv_path = _get_tp_path(transformer_block, "attention.wv")
+        wo_path = _get_tp_path(transformer_block, "attention.wo")
+        w1_path = _get_tp_path(transformer_block, "feed_forward.w1")
+        w2_path = _get_tp_path(transformer_block, "feed_forward.w2")
+        w3_path = _get_tp_path(transformer_block, "feed_forward.w3")
+
         layer_plan = {
             "attention_norm": SequenceParallel(),
             # NOTE: when the fourth argument (positions) is not None, its input layout
@@ -213,20 +274,46 @@ def apply_tp(
                 input_layouts=(Shard(1), None, None, None),
                 desired_input_layouts=(Replicate(), None, None, None),
             ),
-            "attention.wq": colwise_parallel(),
-            "attention.wk": colwise_parallel(),
-            "attention.wv": colwise_parallel(),
-            "attention.wo": rowwise_parallel(output_layouts=Shard(1)),
+            wq_path: colwise_parallel(),
+            wk_path: colwise_parallel(),
+            wv_path: colwise_parallel(),
+            wo_path: rowwise_parallel(output_layouts=Shard(1)),
             "ffn_norm": SequenceParallel(),
             "feed_forward": prepare_module_input(
                 input_layouts=(Shard(1),),
                 desired_input_layouts=(Replicate(),),
             ),
-            "feed_forward.w1": colwise_parallel(),
-            "feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)),
-            "feed_forward.w3": colwise_parallel(),
+            w1_path: colwise_parallel(),
+            w2_path: rowwise_parallel(output_layouts=Shard(1)),
+            w3_path: colwise_parallel(),
         }
+        # Add LoRA parallelism to layer_plan if the module is LoRALinear
+        # Colwise layers (wq, wk, wv, w1, w3): lora_b(lora_a(x)) behaves like colwise_parallel()
+        if _is_lora_linear(transformer_block, "attention.wq"):
+            layer_plan["attention.wq.lora_a"] = lora_a_colwise
+            layer_plan["attention.wq.lora_b"] = colwise_parallel()
+        if _is_lora_linear(transformer_block, "attention.wk"):
+            layer_plan["attention.wk.lora_a"] = lora_a_colwise
+            layer_plan["attention.wk.lora_b"] = colwise_parallel()
+        if _is_lora_linear(transformer_block, "attention.wv"):
+            layer_plan["attention.wv.lora_a"] = lora_a_colwise
+            layer_plan["attention.wv.lora_b"] = colwise_parallel()
+        if _is_lora_linear(transformer_block, "feed_forward.w1"):
+            layer_plan["feed_forward.w1.lora_a"] = lora_a_colwise
+            layer_plan["feed_forward.w1.lora_b"] = colwise_parallel()
+        if _is_lora_linear(transformer_block, "feed_forward.w3"):
+            layer_plan["feed_forward.w3.lora_a"] = lora_a_colwise
+            layer_plan["feed_forward.w3.lora_b"] = colwise_parallel()
+
+        # Rowwise layers (wo, w2): lora_b(lora_a(x)) behaves like rowwise_parallel(output_layouts=Shard(1))
+        if _is_lora_linear(transformer_block, "attention.wo"):
+            layer_plan["attention.wo.lora_a"] = lora_a_rowwise
+            layer_plan["attention.wo.lora_b"] = lora_b_rowwise
+        if _is_lora_linear(transformer_block, "feed_forward.w2"):
+            layer_plan["feed_forward.w2.lora_a"] = lora_a_rowwise
+            layer_plan["feed_forward.w2.lora_b"] = lora_b_rowwise
+
         parallelize_module(
             # pyrefly: ignore [bad-argument-type]
             module=transformer_block,
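A minimal standalone sketch of the update rule implemented by LoRALinear and of the freeze/zero-init behavior that LoRAConverter wires into init_weights. It is illustration only and not part of the patch: it assumes a plain nn.Linear base with no meta-device initialization, no dropout, and no tensor parallelism, and TinyLoRALinear is a hypothetical name.

# Standalone sketch of y = W x + (alpha / rank) * B(A(x)) with a frozen base weight.
import math

import torch
import torch.nn as nn


class TinyLoRALinear(nn.Module):
    def __init__(self, base: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        self.scaling = alpha / rank
        self.lora_a = nn.Linear(base.in_features, rank, bias=False)
        self.lora_b = nn.Linear(rank, base.out_features, bias=False)
        # Same init scheme as the patch: Kaiming-uniform A, zero B, so the adapter
        # contributes nothing at step 0 and training starts from the base model's output.
        nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_b.weight)
        # Freeze the base layer; only the adapters receive gradients.
        base.weight.requires_grad_(False)
        if base.bias is not None:
            base.bias.requires_grad_(False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))


if __name__ == "__main__":
    layer = TinyLoRALinear(nn.Linear(4096, 4096, bias=False), rank=8, alpha=16.0)
    out = layer(torch.randn(2, 16, 4096))
    trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
    total = sum(p.numel() for p in layer.parameters())
    print(out.shape, f"{100 * trainable / total:.2f}% trainable")

With rank=8 on a 4096x4096 projection, the adapters add 2 * 8 * 4096 = 65,536 trainable parameters, roughly 0.4% of the layer, and zero-initializing lora_b means the first forward pass matches the frozen base layer exactly.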