diff --git a/docs/guides/sft.md b/docs/guides/sft.md
index 982052f074..e4e4915626 100644
--- a/docs/guides/sft.md
+++ b/docs/guides/sft.md
@@ -161,3 +161,47 @@ As long as your custom dataset has the `formatted_ds` and `task_spec` attributes
 ## Evaluate the Trained Model
 
 Upon completion of the training process, you can refer to our [evaluation guide](eval.md) to assess model capabilities.
+
+
+## LoRA Configuration
+
+NeMo RL supports LoRA (Low-Rank Adaptation) for parameter-efficient fine-tuning. LoRA reduces the number of trainable parameters by learning low-rank weight-update matrices while keeping the base model frozen.
+
+### Configuration Parameters
+
+The LoRA configuration is specified under the `policy.lora_cfg` section:
+
+```yaml
+policy:
+  lora_cfg:
+    enabled: False            # Set to True to enable LoRA fine-tuning
+    target_modules: []        # List of module names to apply LoRA to
+    exclude_modules: []       # List of module names to exclude from LoRA
+    match_all_linear: true    # Apply LoRA to all linear layers
+    dim: 8                    # LoRA rank (r): controls adaptation capacity
+    alpha: 32                 # LoRA scaling factor (the LoRA update is scaled by alpha/dim)
+    dropout: 0.0              # Dropout probability for LoRA layers
+    dropout_position: "post"  # Dropout position: "pre" or "post"
+    lora_A_init: "xavier"     # Initialization method: "xavier" or "uniform"
+    use_triton: true          # Use Triton-optimized kernels
+```
+
+### Parameter Details
+- **`enabled`** (bool): Whether to enable LoRA training
+- **`target_modules`** (list): Specific module names to apply LoRA to. If empty and `match_all_linear=true`, LoRA is applied to all linear layers
+- **`exclude_modules`** (list): Module names to exclude from LoRA
+- **`match_all_linear`** (bool): When `true`, applies LoRA to all linear layers (overrides `target_modules`)
+- **`dim`** (int): LoRA rank (r). Lower values mean fewer trainable parameters but less capacity. Typical values: 4, 8, 16, 32, 64
+- **`alpha`** (int): LoRA scaling factor. The LoRA update is scaled by `alpha/dim`. Typical values: 16, 32, 64
+- **`dropout`** (float): Dropout probability for regularization
+- **`dropout_position`** (str): Apply dropout before (`"pre"`) or after (`"post"`) the LoRA update
+- **`lora_A_init`** (str): Initialization method for the LoRA A matrix
+- **`use_triton`** (bool): Use Triton-optimized kernels for better performance. Only used by the DTensor v2 path. **Note**: [Automodel does not support Triton for TP > 1](https://github.com/NVIDIA-NeMo/Automodel/blob/b2db55eee98dfe81a8bfe5e23ac4e57afd8ab261/nemo_automodel/recipes/llm/train_ft.py#L199), so set this to `false` when `tensor_parallel_size > 1` to avoid compatibility issues
+
+### Example Usage
+
+```bash
+uv run examples/run_sft.py policy.lora_cfg.enabled=True
+```
+
+For more details on LoRA, see [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685).
\ No newline at end of file
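The parameter table above describes `dim`, `alpha`, `dropout`, and `dropout_position` in isolation; the sketch below shows how such parameters typically interact in a LoRA-wrapped linear layer. This is a minimal, illustrative PyTorch module, not NeMo Automodel's `LinearLoRA` implementation — the class name, the exact dropout placement, and the initialization details are assumptions made for the example.

```python
import math

import torch
import torch.nn as nn


class LoRALinearSketch(nn.Module):
    """Illustrative LoRA wrapper: y = W x + (alpha / dim) * B(A(x)), with W frozen."""

    def __init__(self, base: nn.Linear, dim: int = 8, alpha: int = 32,
                 dropout: float = 0.0, dropout_position: str = "post",
                 lora_A_init: str = "xavier"):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)  # base layer stays frozen; only A and B train
        self.lora_A = nn.Linear(base.in_features, dim, bias=False)
        self.lora_B = nn.Linear(dim, base.out_features, bias=False)
        self.scaling = alpha / dim  # the low-rank update is scaled by alpha/dim
        self.dropout = nn.Dropout(dropout)
        self.dropout_position = dropout_position
        if lora_A_init == "xavier":
            nn.init.xavier_normal_(self.lora_A.weight)
        else:
            nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B.weight)  # B starts at zero, so training starts from the base model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.dropout(x) if self.dropout_position == "pre" else x
        update = self.lora_B(self.lora_A(h)) * self.scaling
        if self.dropout_position == "post":
            update = self.dropout(update)
        return self.base(x) + update


# With dim=8 on a 4096x4096 projection, only ~65K of the ~16.8M weights are trainable.
layer = LoRALinearSketch(nn.Linear(4096, 4096), dim=8, alpha=32)
print(sum(p.numel() for p in layer.parameters() if p.requires_grad))  # 65536
```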
diff --git a/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v3.yaml b/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v3.yaml
index 77ff8aac89..82950cf153 100644
--- a/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v3.yaml
+++ b/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v3.yaml
@@ -8,6 +8,9 @@ policy:
   tokenizer:
     name: meta-llama/Llama-3.2-1B
   make_sequence_length_divisible_by: 1
+  lora_cfg:
+    enabled: true
+    dim: 32
 data:
   dataset_name: openmathinstruct2
   prompt_file: examples/prompts/math.txt
diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml
index 05b299a34e..9316ecccba 100644
--- a/examples/configs/sft.yaml
+++ b/examples/configs/sft.yaml
@@ -36,6 +36,7 @@ policy:
   offload_optimizer_for_logprob: false
 
   dtensor_cfg:
+    _v2: true
     enabled: true
     env_vars: {}
     cpu_offload: False
@@ -44,6 +45,19 @@ policy:
     tensor_parallel_size: 1
     context_parallel_size: 1
     custom_parallel_plan: null
+
+  # LoRA (Low-Rank Adaptation) Configuration
+  lora_cfg:
+    enabled: False # Set to True to enable LoRA fine-tuning
+    target_modules: [] # List of module names to apply LoRA to (an empty list with match_all_linear=true applies LoRA to all linear layers)
+    exclude_modules: [] # List of module names to exclude from LoRA
+    match_all_linear: true # If True, applies LoRA to all linear layers (overrides target_modules)
+    dim: 8 # LoRA rank (r): lower rank = fewer parameters but less capacity. Typical values: 4, 8, 16, 32, 64
+    alpha: 32 # LoRA scaling factor: the LoRA update is scaled by alpha/dim. Typical values: 16, 32, 64
+    dropout: 0.0 # Dropout probability applied to LoRA layers (0.0 = no dropout)
+    dropout_position: "post" # Where to apply dropout: "pre" (before LoRA) or "post" (after LoRA)
+    lora_A_init: "xavier" # Initialization method for LoRA A matrix: "xavier" or "uniform"
+    use_triton: true # Use Triton-optimized kernels for LoRA (faster but requires flash-attn).
Disable when tensor_parallel_size > 1 dynamic_batching: enabled: false diff --git a/nemo_rl/models/policy/__init__.py b/nemo_rl/models/policy/__init__.py index dad82594d1..aa3dad5203 100644 --- a/nemo_rl/models/policy/__init__.py +++ b/nemo_rl/models/policy/__init__.py @@ -131,6 +131,23 @@ class MegatronConfig(TypedDict): distributed_data_parallel_config: MegatronDDPConfig +class LoRAConfigDisabled(TypedDict): + enabled: Literal[False] + + +class LoRAConfig(TypedDict): + enabled: Literal[True] + target_modules: list[str] + exclude_modules: list[str] + match_all_linear: bool + dim: int + alpha: int + dropout: float + dropout_position: Literal["pre", "post"] + lora_A_init: str + use_triton: bool + + class TokenizerConfig(TypedDict): name: str chat_template: NotRequired[str] @@ -189,6 +206,7 @@ class PolicyConfig(TypedDict): reward_model_cfg: NotRequired[RewardModelConfig] dtensor_cfg: DTensorConfig | DTensorConfigDisabled megatron_cfg: NotRequired[MegatronConfig | MegatronConfigDisabled] + lora_cfg: NotRequired[LoRAConfig | LoRAConfigDisabled] hf_config_overrides: NotRequired[dict[str, Any]] dynamic_batching: DynamicBatchingConfig | DynamicBatchingConfigDisabled sequence_packing: NotRequired[SequencePackingConfig | SequencePackingConfigDisabled] diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index 434f850423..4310e630d7 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -112,7 +112,12 @@ def __init__( if use_v2: worker_builder_cls = "nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2" else: - worker_builder_cls = "nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker" + assert config.get("lora_cfg", {}).get("enabled", False) is False, ( + "LoRA is not supported for DTensorPolicyWorker V1" + ) + worker_builder_cls = ( + "nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker" + ) tp_size = config["dtensor_cfg"]["tensor_parallel_size"] cp_size = config["dtensor_cfg"]["context_parallel_size"] diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index 4b8bf56d42..4b1202208c 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -14,18 +14,24 @@ import gc import itertools +import math import os import warnings from collections import defaultdict from contextlib import AbstractContextManager, contextmanager, nullcontext from typing import Any, Generator, Optional, cast +import nemo_automodel.components._peft.lora as _lora_mod import ray import torch from accelerate import init_empty_weights from nemo_automodel import ( NeMoAutoModelForSequenceClassification, ) +from nemo_automodel.components._peft.lora import ( + PeftConfig, + apply_lora_to_linear_modules, +) from nemo_automodel.components.distributed.cp_utils import ( create_context_parallel_ctx, get_train_context, @@ -93,6 +99,15 @@ from nemo_rl.utils.packed_tensor import packed_broadcast_producer +# TODO: @ruit remove this once the bump Automodel to 2d20e33a19d5e53a271b1403b507475e68ad14dc (https://github.com/NVIDIA-NeMo/RL/issues/1586) +def _patched_init_lora_weights(self, init_method: str): + if init_method == "xavier": + nn.init.xavier_normal_(self.lora_A.weight.data) + else: + nn.init.kaiming_uniform_(self.lora_A.weight.data, a=math.sqrt(5)) + self.lora_B.weight.data.zero_() + + @ray.remote( 
runtime_env=get_runtime_env_for_policy_worker("dtensor_policy_worker_v2") ) # pragma: no cover @@ -222,6 +237,23 @@ def __init__( full_state_dict = None model_state_dict_keys = None + + # lora config + lora_cfg = self.cfg.get("lora_cfg", None) + self.peft_config = None + self.lora_enabled = lora_cfg is not None and lora_cfg["enabled"] + # patch the init_lora_weights method to use the xavier initialization + _lora_mod.LinearLoRA.init_lora_weights = _patched_init_lora_weights + if self.lora_enabled: + if self.cfg["dtensor_cfg"]["tensor_parallel_size"] > 1: + assert not lora_cfg["use_triton"], ( + "Triton is not supported when tensor_parallel_size > 1" + ) + # Always use float32 since FSDP requires all parameters to be in the same dtype. + # autocast should cast the weights to the correct dtype during the forward pass. + cfg_dict_with_dtype = {**lora_cfg, "lora_dtype": "torch.float32"} + self.peft_config = PeftConfig.from_dict(cfg_dict_with_dtype) + if self.rank == 0: print(f"[Rank {self.rank}] Loading model {model_name} on CPU...") model = model_class.from_pretrained( @@ -233,6 +265,9 @@ def __init__( torch_dtype=str(model_config.torch_dtype), ) + if self.peft_config is not None: + apply_lora_to_linear_modules(model, self.peft_config) + full_state_dict = model.state_dict() # Store the original model state dict keys before any parallelization model_state_dict_keys = list(full_state_dict.keys()) @@ -255,6 +290,8 @@ def __init__( trust_remote_code=True, torch_dtype=str(model_config.torch_dtype), ) + if self.lora_enabled: + apply_lora_to_linear_modules(self.model, self.peft_config) if self.model.config.pad_token_id is None: self.model.config.pad_token_id = tokenizer.pad_token_id @@ -1857,6 +1894,9 @@ def save_checkpoint( "peft_config", } } + if self.lora_enabled: + checkpoint_kwargs["is_peft"] = True + checkpoint_kwargs["peft_config"] = self.peft_config save_checkpoint( model=self.model, diff --git a/tests/functional/test_automodel_lora_sft.sh b/tests/functional/test_automodel_lora_sft.sh new file mode 100644 index 0000000000..b2baf88170 --- /dev/null +++ b/tests/functional/test_automodel_lora_sft.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +# clean up checkpoint directory on exit +trap "rm -rf /tmp/lora_sft_checkpoints" EXIT + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..) 
+# Mark the current repo as safe, since wandb fetches metadata about the repo +git config --global --add safe.directory $PROJECT_ROOT + +set -eou pipefail + +EXP_NAME=$(basename $0 .sh) +EXP_DIR=$SCRIPT_DIR/$EXP_NAME +LOG_DIR=$EXP_DIR/logs +JSON_METRICS=$EXP_DIR/metrics.json +RUN_LOG=$EXP_DIR/run.log +export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-} + +rm -rf $EXP_DIR $LOG_DIR +mkdir -p $EXP_DIR $LOG_DIR + +cd $PROJECT_ROOT +uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \ + $PROJECT_ROOT/examples/run_sft.py \ + policy.model_name=Qwen/Qwen3-0.6B \ + cluster.gpus_per_node=2 \ + sft.max_num_steps=3 \ + sft.val_batches=1 \ + sft.val_period=3 \ + policy.dtensor_cfg.lora.enabled=true \ + logger.tensorboard_enabled=true \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=false \ + logger.monitor_gpus=true \ + checkpointing.enabled=true \ + checkpointing.save_period=3 \ + checkpointing.checkpoint_dir=/tmp/lora_sft_checkpoints \ + "$@" \ + 2>&1 | tee $RUN_LOG + +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +uv run tests/check_metrics.py $JSON_METRICS \ + 'data["train/loss"]["3"] < 5.9' + diff --git a/tests/unit/models/dtensor/test_lora.py b/tests/unit/models/dtensor/test_lora.py new file mode 100644 index 0000000000..eac60eb6fe --- /dev/null +++ b/tests/unit/models/dtensor/test_lora.py @@ -0,0 +1,333 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math + +import pytest + +# Skip entire module if nemo_automodel is not available +pytest_plugins = [] +try: + import nemo_automodel # noqa: F401 +except ImportError: + pytest.skip("nemo_automodel not available", allow_module_level=True) + +import torch +import torch.nn as nn +from nemo_automodel.components._peft.lora import ( + LinearLoRA, + PeftConfig, + apply_lora_to_linear_modules, +) + +from nemo_rl.models.policy.workers.dtensor_policy_worker_v2 import ( + _patched_init_lora_weights, +) + + +class SimpleLoraMock(nn.Module): + """Simple mock LoRA module for testing initialization.""" + + def __init__(self, in_features=128, out_features=256, lora_dim=8): + super().__init__() + self.lora_A = nn.Linear(in_features, lora_dim, bias=False) + self.lora_B = nn.Linear(lora_dim, out_features, bias=False) + + +@pytest.mark.parametrize("init_method", ["xavier"]) +def test_lora_init_differs_from_upstream_buggy_version(init_method): + """ + Test that our patched LoRA initialization differs from the buggy upstream version. + + Remove this test once Automodel is bumped to commit 2d20e33a19d5e53a271b1403b507475e68ad14dc or later. 
+ + Issue: https://github.com/NVIDIA-NeMo/RL/issues/1586 + """ + torch.manual_seed(42) + + # Create two identical LoRA modules + lora_buggy = LinearLoRA(nn.Linear(16, 16)) + lora_patched = LinearLoRA(nn.Linear(16, 16)) + + # Copy initial weights to ensure identical starting point + lora_patched.lora_A.weight.data.copy_(lora_buggy.lora_A.weight.data) + lora_patched.lora_B.weight.data.copy_(lora_buggy.lora_B.weight.data) + + # Apply buggy upstream initialization + torch.manual_seed(42) + lora_buggy.init_lora_weights(init_method) + + # Apply our patched initialization + torch.manual_seed(42) + _patched_init_lora_weights(lora_patched, init_method) + + # For xavier method, they should differ (that's the bug) + + # Assert that weights differ due to the upstream bug + are_equal_A = torch.allclose( + lora_buggy.lora_A.weight.data, + lora_patched.lora_A.weight.data, + atol=1e-6, + rtol=1e-6, + ) + + assert not are_equal_A, ( + "LoRA A weights should differ for xavier initialization. " + "If this assertion fails, the upstream bug has been fixed in Automodel. " + "You can:\n" + "1. Remove the patch in nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py\n" + "2. Remove the patching call\n" + "3. Close issue: https://github.com/NVIDIA-NeMo/RL/issues/1586\n" + "4. Delete this test" + ) + + # LoRA B should always be zero-initialized (both implementations do this correctly) + are_equal_B = torch.allclose( + lora_buggy.lora_B.weight.data, + lora_patched.lora_B.weight.data, + atol=0, + rtol=0, + ) + assert are_equal_B, "LoRA B weights should both be zero" + assert torch.all(lora_buggy.lora_B.weight.data == 0), ( + "LoRA B should be zero-initialized" + ) + assert torch.all(lora_patched.lora_B.weight.data == 0), ( + "LoRA B should be zero-initialized" + ) + + +def test_lora_init_statistical_properties(): + """ + Additional test to verify the statistical properties of the patched initialization. + This ensures our fix produces reasonable weight distributions. 
+ """ + torch.manual_seed(42) + + lora = SimpleLoraMock(in_features=512, out_features=1024, lora_dim=32) + + # Test xavier initialization + _patched_init_lora_weights(lora, "xavier") + + # Xavier normal should have mean ≈ 0 and specific std + mean_A = lora.lora_A.weight.data.mean().item() + std_A = lora.lora_A.weight.data.std().item() + + assert abs(mean_A) < 0.1, f"Xavier normal should have mean ≈ 0, got {mean_A}" + # Xavier normal std = sqrt(2 / (fan_in + fan_out)) + expected_std = math.sqrt(2.0 / (512 + 32)) + assert abs(std_A - expected_std) < 0.05, ( + f"Xavier normal std should be ≈ {expected_std}, got {std_A}" + ) + + # LoRA B should be all zeros + assert torch.all(lora.lora_B.weight.data == 0), "LoRA B should be zero-initialized" + + # Test kaiming initialization + lora2 = SimpleLoraMock(in_features=512, out_features=1024, lora_dim=32) + _patched_init_lora_weights(lora2, "kaiming") + + mean_A2 = lora2.lora_A.weight.data.mean().item() + assert abs(mean_A2) < 0.1, f"Kaiming should have mean ≈ 0, got {mean_A2}" + assert torch.all(lora2.lora_B.weight.data == 0), "LoRA B should be zero-initialized" + + +class DummyModel(nn.Module): + """A dummy neural network model with two linear layers used for testing LoRA injection.""" + + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(16, 16) + self.linear2 = nn.Linear(16, 16) + self.config = {} + + def forward(self, x): + """Forward pass through two linear layers with ReLU activation in between.""" + x = self.linear1(x).relu() + x = self.linear2(x) + return x + + +class DummyModelNoConfig(nn.Module): + """Same as DummyModel but without a `config` attribute.""" + + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(16, 16) + self.linear2 = nn.Linear(16, 16) + + def forward(self, x): + x = self.linear1(x).relu() + x = self.linear2(x) + return x + + +@pytest.fixture +def dummy_input(): + """Provides a dummy input tensor for model testing.""" + return torch.randn(2, 16, requires_grad=True) + + +@pytest.fixture +def model(): + """Instantiates and returns a DummyModel instance.""" + return DummyModel() + + +@pytest.fixture +def model_no_config(): + """Instantiates a model that has no `config` attr.""" + return DummyModelNoConfig() + + +def test_lora_patch_applies_to_selected_module(model): + """Tests that LoRA is only applied to specified target modules.""" + apply_lora_to_linear_modules( + model, PeftConfig(target_modules=["linear1"], dim=4, alpha=8) + ) + assert isinstance(model.linear1, LinearLoRA) + assert not isinstance(model.linear2, LinearLoRA) + + +def test_lora_patch_on_model_without_config(model_no_config): + """LoRA should still patch correctly even if the model lacks `config`.""" + apply_lora_to_linear_modules( + model_no_config, PeftConfig(target_modules=["linear1"], dim=4, alpha=8) + ) + assert isinstance(model_no_config.linear1, LinearLoRA) + assert not isinstance(model_no_config.linear2, LinearLoRA) + + +def test_lora_layers_are_trainable(): + """Ensures that LoRA layers are trainable while base weights remain frozen.""" + base = nn.Linear(16, 16) + lora = LinearLoRA(base, dim=4, alpha=8) + + assert lora.weight.requires_grad is False + assert lora.lora_A.weight.requires_grad + assert lora.lora_B.weight.requires_grad + if lora.bias is not None: + assert lora.bias.requires_grad is False + + +def test_forward_output_consistency(dummy_input): + """Verifies that model output shape remains the same after LoRA patching, + but values change due to the added LoRA components. 
+ """ + base = DummyModel() + model = DummyModel() + apply_lora_to_linear_modules( + model, PeftConfig(target_modules=["linear1"], dim=4, alpha=8) + ) + + base.eval() + model.eval() + + with torch.no_grad(): + out1 = base(dummy_input) + out2 = model(dummy_input) + + assert out1.shape == out2.shape + assert not torch.allclose(out1, out2), "Output should differ due to LoRA injection" + + +def test_backward_pass(dummy_input): + """Checks that backpropagation works and gradients are correctly computed + when LoRA is applied. + """ + model = DummyModel() + apply_lora_to_linear_modules( + model, PeftConfig(target_modules=["linear1"], dim=4, alpha=8) + ) + output = model(dummy_input) + loss = output.sum() + loss.backward() + + grads = [p.grad for p in model.parameters() if p.requires_grad] + assert any(g is not None for g in grads), "Some parameters should receive gradients" + assert all(torch.isfinite(g).all() for g in grads if g is not None), ( + "Gradients should be finite" + ) + + +def test_backward_pass_without_config(dummy_input, model_no_config): + """Backward pass must succeed on a model without `config`.""" + apply_lora_to_linear_modules( + model_no_config, PeftConfig(target_modules=["linear1"], dim=4, alpha=8) + ) + out = model_no_config(dummy_input) + loss = out.sum() + loss.backward() + + grads = [p.grad for p in model_no_config.parameters() if p.requires_grad] + assert any(g is not None for g in grads) + assert all(torch.isfinite(g).all() for g in grads if g is not None) + + +def test_apply_lora_respects_wildcard(model): + """Validates that wildcard matching correctly applies LoRA to all matching modules.""" + apply_lora_to_linear_modules( + model, PeftConfig(target_modules=[".*"], dim=4, alpha=8) + ) + assert isinstance(model.linear1, LinearLoRA) + assert isinstance(model.linear2, LinearLoRA) + + +def test_no_patch_on_non_matching_module(model): + """Confirms that no modules are patched if target pattern doesn't match any names.""" + apply_lora_to_linear_modules( + model, PeftConfig(target_modules=["nonexistent_module"], dim=4, alpha=8) + ) + assert not isinstance(model.linear1, LinearLoRA) + assert not isinstance(model.linear2, LinearLoRA) + + +def test_lora_patch_with_dtype_string(model): + """Tests that LoRA can be applied with dtype specified as string.""" + apply_lora_to_linear_modules( + model, + PeftConfig( + target_modules=["linear1"], dim=4, alpha=8, lora_dtype="torch.bfloat16" + ), + ) + assert isinstance(model.linear1, LinearLoRA) + assert model.linear1.lora_A.weight.dtype == torch.bfloat16 + assert model.linear1.lora_B.weight.dtype == torch.bfloat16 + assert not isinstance(model.linear2, LinearLoRA) + + +def test_dropout_pre_post_effects(dummy_input): + """Tests that different dropout positions ('pre' vs 'post') lead to different outputs.""" + base = nn.Linear(16, 16) + lora_pre = LinearLoRA(base, dim=4, alpha=8, dropout=0.5, dropout_position="pre") + lora_post = LinearLoRA(base, dim=4, alpha=8, dropout=0.5, dropout_position="post") + + with torch.no_grad(): + lora_pre.lora_A.weight.uniform_() + lora_pre.lora_B.weight.uniform_() + + lora_post.lora_A.weight.copy_(lora_pre.lora_A.weight) + lora_post.lora_B.weight.copy_(lora_pre.lora_B.weight) + + lora_pre.train() + lora_post.train() + + out_pre = lora_pre(dummy_input) + out_post = lora_post(dummy_input) + + assert out_pre.shape == out_post.shape + assert not torch.allclose(out_pre, out_post), ( + "Dropout positions should affect output differently" + ) diff --git a/tests/unit/models/policy/test_dtensor_worker.py 
b/tests/unit/models/policy/test_dtensor_worker.py index e9d495ab24..99ec9bcfff 100644 --- a/tests/unit/models/policy/test_dtensor_worker.py +++ b/tests/unit/models/policy/test_dtensor_worker.py @@ -39,6 +39,7 @@ def create_test_config( activation_checkpointing: bool = False, custom_parallel_plan: str | None = None, dtensor_v2: bool = False, + enable_loras: bool = False, ) -> PolicyConfig: return { "model_name": model_name, @@ -75,6 +76,18 @@ def create_test_config( "tensor_parallel_size": tp, "context_parallel_size": cp, "custom_parallel_plan": custom_parallel_plan, + "lora": { + "enabled": enable_loras, + "target_modules": [], + "exclude_modules": [], + "match_all_linear": True, + "dim": 32, + "alpha": 32, + "dropout": 0.0, + "dropout_position": "post", + "lora_A_init": "xavier", + "use_triton": True, + }, }, "dynamic_batching": { "enabled": True, @@ -106,6 +119,118 @@ def create_test_config( } +def update_lora_config( + config: PolicyConfig, + enabled: bool = True, + target_modules: list[str] = [], + exclude_modules: list[str] = [], + match_all_linear: bool = True, + dim: int = 32, + alpha: int = 32, + dropout: float = 0.0, + dropout_position: str = "post", + lora_A_init: str = "xavier", + use_triton: bool = True, +): + if enabled: + config["dtensor_cfg"]["_v2"] = True + + config["dtensor_cfg"]["lora"].update( + { + "enabled": enabled, + "target_modules": target_modules, + "exclude_modules": exclude_modules, + "match_all_linear": match_all_linear, + "dim": dim, + "alpha": alpha, + "dropout": dropout, + "dropout_position": dropout_position, + "lora_A_init": lora_A_init, + "use_triton": use_triton, + } + ) + + +def _get_use_v2(request) -> bool: + # Get the use_v2 parameter from the test function + marks = getattr(request.function, "pytestmark", []) + for mark in marks: + if ( + hasattr(mark, "args") + and len(mark.args) > 1 + and "use_v2" in str(mark.args[0]) + ): + for p in mark.args[1]: + if isinstance(p, bool): + return p + + # If multiple parametrize decorators, we need to check the node id + if hasattr(request, "node") and hasattr(request.node, "callspec"): + return request.node.callspec.params.get("use_v2", False) + + return False + + +def create_test_batch( + batch_size: int = 8, + seq_len: int = 128, + vocab_size: int = 32000, + mode: str = "train", +) -> BatchedDataDict: + # set random seed + torch.manual_seed(66) + # Create test input_ids and attention_mask + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len)) + attention_mask = torch.ones(batch_size, seq_len) + # Calculate input_lengths (all sequences are full length in this test) + input_lengths = attention_mask.sum(dim=1).to(torch.int32) + data = BatchedDataDict( + { + "input_ids": input_ids, + "input_lengths": input_lengths, + "attention_mask": attention_mask, + **( + { + "labels": torch.randint(0, vocab_size, (batch_size, seq_len)), + "sample_mask": torch.ones(batch_size).cuda(), + } + if mode == "train" + else {} + ), + } + ) + data = data.to("cpu") + return data + + +def calculate_token_logprobs(model_name: str, data: BatchedDataDict): + data = data.to("cuda") + input_ids = data["input_ids"] + + with torch.no_grad(): + # run the log prob of regular hf model here + hf_model = AutoModelForCausalLM.from_pretrained( + model_name, device_map="cuda", torch_dtype=torch.float32 + ) + hf_model.eval() + outputs = hf_model(**data) + + log_probs = torch.nn.functional.log_softmax( + outputs.logits.to(torch.float32), dim=-1 + ) + next_tokens = input_ids[:, 1:] + log_probs = log_probs[:, :-1] + token_logprobs = 
log_probs.gather(dim=-1, index=next_tokens.unsqueeze(-1)).squeeze( + -1 + ) + token_logprobs = torch.cat( + [torch.zeros_like(token_logprobs[:, :1]), token_logprobs], dim=1 + ).cpu() + + data = data.to("cpu") + return token_logprobs + + @pytest.fixture(scope="module") def two_gpu_virtual_cluster(): cluster_name = "test" @@ -122,20 +247,67 @@ def two_gpu_virtual_cluster(): cluster.shutdown() -@pytest.fixture(scope="function") -def gc_collect(): - """Helper function to force garbage collection after a test""" - import gc +@pytest.fixture +def base_setup(request, two_gpu_virtual_cluster): + params = request.param if hasattr(request, "param") else None + assert params is not None, "params is not set" + + mode = params["mode"] + model_fixture_name = params["model_fixture_name"] + specified_config = params["specified_config"] + enable_loras = params["enable_loras"] + lora_config = params["lora_config"] + model_name = request.getfixturevalue(model_fixture_name) - yield - gc.collect() + policy = None + data = None + loss_fn = None + + try: + use_v2 = _get_use_v2(request) + config = create_test_config(model_name, dtensor_v2=use_v2, **specified_config) + + if enable_loras: + update_lora_config(config, **lora_config) + + tokenizer = get_tokenizer(config["tokenizer"]) + print(f"Creating {mode} Policy with {specified_config}...") + policy = Policy( + cluster=two_gpu_virtual_cluster, + config=config, + tokenizer=tokenizer, + init_reference_model=False, + ) + print("Creating test batch...") + data = create_test_batch(mode=mode) + + if mode == "train": + # Create loss function + loss_fn: LossFunction = SimpleLoss() + yield policy, data, loss_fn + elif mode == "logprob": + token_logprobs = calculate_token_logprobs(model_name, data) + yield policy, data, token_logprobs + + except Exception as e: + print(f"Error during setup: {e}") + pytest.skip(f"Setup failed: {e}") + finally: + print("Cleaning up resources for test") + if policy: + policy.shutdown() @pytest.fixture def policy_setup(request, two_gpu_virtual_cluster, tiny_llama_model_path): """Setup and teardown for policy tests - creates a virtual cluster and policy.""" - use_v2 = request.param if hasattr(request, "param") else False - config = create_test_config(tiny_llama_model_path, dtensor_v2=use_v2) + params = request.param if hasattr(request, "param") else {} + use_v2 = params.get("dtensor_v2", False) + enable_loras = params.get("enable_loras", False) + + config = create_test_config( + tiny_llama_model_path, dtensor_v2=use_v2, enable_loras=enable_loras + ) tokenizer = get_tokenizer(config["tokenizer"]) config["generation"] = configure_generation_config(config["generation"], tokenizer) @@ -148,9 +320,221 @@ def policy_setup(request, two_gpu_virtual_cluster, tiny_llama_model_path): policy.shutdown() +@pytest.fixture( + params=[ + # model_fixture_name tp cp sp cpu act + ("tiny_llama_model_path", 1, 1, False, False, False), + ("tiny_llama_model_path", 1, 1, True, False, False), + ("tiny_llama_model_path", 1, 1, False, True, False), + ("tiny_llama_model_path", 1, 1, False, False, True), + ("tiny_llama_model_path", 1, 2, False, False, False), + ("tiny_qwen2_model_path", 1, 1, True, True, False), + ("tiny_qwen2_model_path", 1, 1, True, False, True), + ("tiny_qwen2_model_path", 1, 1, False, True, True), + ("tiny_qwen2_model_path", 1, 1, True, True, True), + ("tiny_qwen2_model_path", 1, 2, False, False, False), + ("tiny_qwen3_model_path", 1, 1, True, True, False), + ("tiny_qwen3_model_path", 1, 1, True, False, True), + ("tiny_qwen3_model_path", 1, 1, False, 
True, True), + ("tiny_qwen3_model_path", 1, 1, True, True, True), + ("tiny_qwen3_model_path", 1, 2, False, False, False), + ( + "tiny_gemma3_model_path", + 1, + 1, + True, + True, + False, + ), # gemma3 doesn't support spda + ("tiny_gemma3_model_path", 1, 1, True, False, True), + ("tiny_gemma3_model_path", 1, 1, False, True, True), + ("tiny_gemma3_model_path", 1, 1, True, True, True), + # CP doesn't support gemma3 due to spda input has attent_mask != None. + # Nemotron-H doesn't support SP https://github.com/NVIDIA-NeMo/RL/issues/881 + # ("tiny_nemotron5_h_model_path", 1, 1, True, True, False), + # ("tiny_nemotron5_h_model_path", 1, 1, True, False, True), + # ("tiny_nemotron5_h_model_path", 1, 1, True, True, True), + ("tiny_nemotron5_h_model_path", 1, 1, False, False, False), + ("tiny_nemotron5_h_model_path", 1, 1, False, True, True), + # nemotron5_h doesn't support cp + ] +) +def training_setup(request, two_gpu_virtual_cluster): + """Setup and teardown specifically for training tests.""" + request.param = { + "mode": "train", + "enable_loras": False, + "lora_config": None, + "model_fixture_name": request.param[0], + "specified_config": { + "tp": request.param[1], + "cp": request.param[2], + "sp": request.param[3], + "cpu_offload": request.param[4], + "activation_checkpointing": request.param[5], + }, + } + yield from base_setup.__wrapped__(request, two_gpu_virtual_cluster) + + +@pytest.fixture( + params=[ + # TP=2, CP=1 + ("tiny_qwen2_model_path", 2, 1, False, True, False), + ("tiny_qwen2_model_path", 2, 1, False, False, False), + ("tiny_llama_model_path", 2, 1, False, False, False), + ("tiny_llama_model_path", 2, 1, False, True, False), + ("tiny_llama_model_path", 2, 1, False, True, True), + ("tiny_qwen3_model_path", 2, 1, False, True, False), + ("tiny_qwen3_model_path", 2, 1, False, False, False), + ("tiny_gemma3_model_path", 2, 1, False, True, False), + ("tiny_gemma3_model_path", 2, 1, False, False, False), + # TP=1, CP=2 + ("tiny_qwen2_model_path", 1, 2, False, True, False), + ("tiny_qwen2_model_path", 1, 2, False, False, False), + ("tiny_llama_model_path", 1, 2, False, False, False), + ("tiny_llama_model_path", 1, 2, False, True, False), + ("tiny_llama_model_path", 1, 2, False, True, True), + ("tiny_qwen3_model_path", 1, 2, False, True, False), + ("tiny_qwen3_model_path", 1, 2, False, False, False), + ] +) +def logprob_setup(request, two_gpu_virtual_cluster): + """Setup and teardown specifically for logprob tests.""" + request.param = { + "mode": "logprob", + "enable_loras": False, + "lora_config": None, + "model_fixture_name": request.param[0], + "specified_config": { + "tp": request.param[1], + "cp": request.param[2], + "sp": request.param[3], + "cpu_offload": request.param[4], + "activation_checkpointing": request.param[5], + }, + } + yield from base_setup.__wrapped__(request, two_gpu_virtual_cluster) + + +@pytest.fixture( + params=[ + # model_name, target_modules, exclude_modules, match_all_linear, dim, alpha, dropout, dropout_position, lora_A_init, use_triton + ("tiny_llama_model_path", [], [], True, 16, 32, 0.0, "post", "xavier", True), + ("tiny_qwen2_model_path", [], [], True, 32, 32, 0.0, "pre", "xavier", True), + ( + "tiny_qwen2_model_path", + ["q_proj", "k_proj", "*gate_proj*", "*up_proj*", "*down_proj*"], + [], + False, + 32, + 16, + 0.0, + "post", + "uniform", + True, + ), + ( + "tiny_qwen2_model_path", + [], + ["q_proj", "k_proj"], + False, + 32, + 16, + 0.0, + "post", + "uniform", + True, + ), + ] +) +def training_with_lora_setup(request, two_gpu_virtual_cluster): + 
"""Setup and teardown specifically for training with lora tests.""" + request.param = { + "mode": "train", + "enable_loras": True, + "model_fixture_name": request.param[0], + "specified_config": {}, + "lora_config": { + "target_modules": request.param[1], + "exclude_modules": request.param[2], + "match_all_linear": request.param[3], + "dim": request.param[4], + "alpha": request.param[5], + "dropout": request.param[6], + "dropout_position": request.param[7], + "lora_A_init": request.param[8], + "use_triton": request.param[9], + }, + } + yield from base_setup.__wrapped__(request, two_gpu_virtual_cluster) + + +@pytest.fixture( + params=[ + # model_name, target_modules, exclude_modules, match_all_linear, dim, alpha, dropout, dropout_position, lora_A_init, use_triton + ("tiny_llama_model_path", [], [], True, 16, 32, 0.0, "post", "xavier", True), + ("tiny_qwen2_model_path", [], [], True, 32, 32, 0.0, "pre", "xavier", True), + ( + "tiny_qwen2_model_path", + ["q_proj", "k_proj", "*gate_proj*", "*up_proj*", "*down_proj*"], + [], + False, + 32, + 16, + 0.0, + "post", + "uniform", + True, + ), + ( + "tiny_qwen2_model_path", + [], + ["q_proj", "k_proj"], + False, + 32, + 16, + 0.0, + "post", + "uniform", + True, + ), + ] +) +def logprob_with_lora_setup(request, two_gpu_virtual_cluster): + """Setup and teardown specifically for logprob with lora tests.""" + request.param = { + "mode": "logprob", + "enable_loras": True, + "model_fixture_name": request.param[0], + "specified_config": {}, + "lora_config": { + "target_modules": request.param[1], + "exclude_modules": request.param[2], + "match_all_linear": request.param[3], + "dim": request.param[4], + "alpha": request.param[5], + "dropout": request.param[6], + "dropout_position": request.param[7], + "lora_A_init": request.param[8], + "use_triton": request.param[9], + }, + } + yield from base_setup.__wrapped__(request, two_gpu_virtual_cluster) + + @pytest.mark.hf_gated @pytest.mark.timeout(360) -@pytest.mark.parametrize("policy_setup", [True, False], indirect=True) +# @pytest.mark.parametrize("policy_setup", [True, False], indirect=True) +@pytest.mark.parametrize( + "policy_setup", + [ + {"dtensor_v2": True, "enable_loras": False}, + {"dtensor_v2": True, "enable_loras": True}, + {"dtensor_v2": False, "enable_loras": False}, + ], + indirect=True, +) def test_lm_policy_init(policy_setup): policy = policy_setup @@ -227,153 +611,12 @@ def test_lm_policy_init(policy_setup): ) -@pytest.fixture -def training_setup(request, two_gpu_virtual_cluster): - """Setup and teardown specifically for training tests.""" - # Get the use_v2 parameter from the test function - use_v2 = getattr(request.function, "pytestmark", []) - use_v2_value = False - for mark in use_v2: - if ( - hasattr(mark, "args") - and len(mark.args) > 1 - and "use_v2" in str(mark.args[0]) - ): - for param_set in mark.args[1]: - if isinstance(param_set, bool): - use_v2_value = param_set - break - - # If multiple parametrize decorators, we need to check the node id - if hasattr(request, "node") and hasattr(request.node, "callspec"): - if "use_v2" in request.node.callspec.params: - use_v2_value = request.node.callspec.params["use_v2"] - - ( - model_fixture_name, - tp, - cp, - sp, - cpu_offload, - activation_checkpointing, - ) = request.param - - # Get the actual model path from the requested fixture - model_name = request.getfixturevalue(model_fixture_name) - policy = None - data = None - loss_fn = None - - try: - config = create_test_config( - model_name, - tp, - cp, - sp, - cpu_offload, - 
activation_checkpointing, - dtensor_v2=use_v2_value, - ) - tokenizer = get_tokenizer(config["tokenizer"]) - print( - f"Creating training Policy with tp={tp}, cpu_offload={cpu_offload}, sequence_parallel={sp}, activation_checkpointing={activation_checkpointing}..." - ) - policy = Policy( - cluster=two_gpu_virtual_cluster, - config=config, - tokenizer=tokenizer, - init_reference_model=False, - ) - - # Create a test batch - print("Creating test batch...") - # set random seed - torch.manual_seed(42) - - # Create test input_ids and attention_mask - input_ids = torch.randint(0, 32000, (8, 128)) # 8 sequences, each of length 128 - attention_mask = torch.ones(8, 128) - - # Calculate input_lengths (all sequences are full length in this test) - input_lengths = attention_mask.sum(dim=1).to(torch.int32) - - data = BatchedDataDict( - { - "input_ids": input_ids, - "input_lengths": input_lengths, - "attention_mask": attention_mask, # Keep for compatibility with loss functions - "labels": torch.randint(0, 32000, (8, 128)), - "sample_mask": torch.ones(8), - } - ) - - # Create loss function - loss_fn: LossFunction = SimpleLoss() - - # Provide the resources to the test - yield policy, data, loss_fn - - except Exception as e: - print(f"Error during training setup: {e}") - pytest.skip(f"Training setup failed: {e}") - finally: - # Clean up after the test - print("Cleaning up resources for test") - policy.shutdown() - - -@pytest.mark.hf_gated -@pytest.mark.timeout(360) -@pytest.mark.parametrize("use_v2", [True, False]) -@pytest.mark.parametrize( - "training_setup", - [ - # model_fixture_name tp cp sp cpu act - ("tiny_llama_model_path", 1, 1, False, False, False), - ("tiny_llama_model_path", 1, 1, True, False, False), - ("tiny_llama_model_path", 1, 1, False, True, False), - ("tiny_llama_model_path", 1, 1, False, False, True), - ("tiny_llama_model_path", 1, 2, False, False, False), - ("tiny_qwen2_model_path", 1, 1, True, True, False), - ("tiny_qwen2_model_path", 1, 1, True, False, True), - ("tiny_qwen2_model_path", 1, 1, False, True, True), - ("tiny_qwen2_model_path", 1, 1, True, True, True), - ("tiny_qwen2_model_path", 1, 2, False, False, False), - ("tiny_qwen3_model_path", 1, 1, True, True, False), - ("tiny_qwen3_model_path", 1, 1, True, False, True), - ("tiny_qwen3_model_path", 1, 1, False, True, True), - ("tiny_qwen3_model_path", 1, 1, True, True, True), - ("tiny_qwen3_model_path", 1, 2, False, False, False), - ( - "tiny_gemma3_model_path", - 1, - 1, - True, - True, - False, - ), # gemma3 doesn't support spda - ("tiny_gemma3_model_path", 1, 1, True, False, True), - ("tiny_gemma3_model_path", 1, 1, False, True, True), - ("tiny_gemma3_model_path", 1, 1, True, True, True), - # CP doesn't support gemma3 due to spda input has attent_mask != None. 
- # Nemotron-H doesn't support SP https://github.com/NVIDIA-NeMo/RL/issues/881 - # ("tiny_nemotron5_h_model_path", 1, 1, True, True, False), - # ("tiny_nemotron5_h_model_path", 1, 1, True, False, True), - # ("tiny_nemotron5_h_model_path", 1, 1, True, True, True), - ("tiny_nemotron5_h_model_path", 1, 1, False, False, False), - ("tiny_nemotron5_h_model_path", 1, 1, False, True, True), - # nemotron5_h doesn't support cp - ], - indirect=True, -) -def test_dtensor_worker_training(use_v2, training_setup): +def _test_dtensor_worker_training(policy, data, loss_fn): def verify_loss_tensor(loss_tensor): assert not torch.isnan(loss_tensor).any(), "Loss should not be NaN" assert not torch.isinf(loss_tensor).any(), "Loss should not be Inf" return loss_tensor - policy, data, loss_fn = training_setup - # Verify resources were created properly assert policy is not None, "Training policy was not created properly" assert data is not None, "Test data was not created properly" @@ -423,136 +666,22 @@ def verify_loss_tensor(loss_tensor): ) -@pytest.fixture -def logprob_setup(request, two_gpu_virtual_cluster): - """Setup and teardown specifically for training tests.""" - # Get the use_v2 parameter from the test function - use_v2_value = False - if hasattr(request, "node") and hasattr(request.node, "callspec"): - if "use_v2" in request.node.callspec.params: - use_v2_value = request.node.callspec.params["use_v2"] - - ( - model_fixture_name, - tp, - cp, - sp, - cpu_offload, - activation_checkpointing, - ) = request.param - - # Get the actual model path from the requested fixture - model_name = request.getfixturevalue(model_fixture_name) - policy = None - data = None - - try: - config = create_test_config( - model_name, - tp, - cp, - sp, - cpu_offload, - activation_checkpointing, - dtensor_v2=use_v2_value, - ) - tokenizer = get_tokenizer(config["tokenizer"]) - print( - f"Creating logprob Policy with tp={tp}, cpu_offload={cpu_offload}, sequence_parallel={sp}, activation_checkpointing={activation_checkpointing}..." 
- ) - policy = Policy( - cluster=two_gpu_virtual_cluster, - config=config, - tokenizer=tokenizer, - init_reference_model=False, - ) - - # Create a test batch - print("Creating test batch...") - # set random seed - torch.manual_seed(66) - - # Create test input_ids and attention_mask - input_ids = torch.randint( - 0, 32000, (8, 128) - ).cuda() # 8 sequences, each of length 128 - attention_mask = torch.ones(8, 128).cuda() - - # Calculate input_lengths (all sequences are full length in this test) - input_lengths = attention_mask.sum(dim=1).to(torch.int32).cuda() - - data = BatchedDataDict( - { - "input_ids": input_ids, - "input_lengths": input_lengths, - "attention_mask": attention_mask, # Keep for compatibility with loss functions - } - ) - - with torch.no_grad(): - # run the log prob of regular hf model here - hf_model = AutoModelForCausalLM.from_pretrained( - model_name, device_map="cuda", torch_dtype=torch.float32 - ) - hf_model.eval() - outputs = hf_model(**data) - - log_probs = torch.nn.functional.log_softmax( - outputs.logits.to(torch.float32), dim=-1 - ) - next_tokens = input_ids[:, 1:] - log_probs = log_probs[:, :-1] - token_logprobs = log_probs.gather( - dim=-1, index=next_tokens.unsqueeze(-1) - ).squeeze(-1) - token_logprobs = torch.cat( - [torch.zeros_like(token_logprobs[:, :1]), token_logprobs], dim=1 - ).cpu() - - data = data.to("cpu") - - # Provide the resources to the test - yield policy, data, token_logprobs - - except Exception as e: - print(f"Error during training setup: {e}") - pytest.skip(f"Training setup failed: {e}") - finally: - # Clean up after the test - print("Cleaning up resources for test") - policy.shutdown() +@pytest.mark.hf_gated +@pytest.mark.timeout(360) +@pytest.mark.parametrize("use_v2", [True, False]) +def test_dtensor_worker_training(use_v2, training_setup): + policy, data, loss_fn = training_setup + _test_dtensor_worker_training(policy, data, loss_fn) @pytest.mark.hf_gated @pytest.mark.timeout(360) -@pytest.mark.parametrize("use_v2", [True, False]) -@pytest.mark.parametrize( - "logprob_setup", - [ - # TP=2, CP=1 - ("tiny_qwen2_model_path", 2, 1, False, True, False), - ("tiny_qwen2_model_path", 2, 1, False, False, False), - ("tiny_llama_model_path", 2, 1, False, False, False), - ("tiny_llama_model_path", 2, 1, False, True, False), - ("tiny_llama_model_path", 2, 1, False, True, True), - ("tiny_qwen3_model_path", 2, 1, False, True, False), - ("tiny_qwen3_model_path", 2, 1, False, False, False), - ("tiny_gemma3_model_path", 2, 1, False, True, False), - ("tiny_gemma3_model_path", 2, 1, False, False, False), - # TP=1, CP=2 - ("tiny_qwen2_model_path", 1, 2, False, True, False), - ("tiny_qwen2_model_path", 1, 2, False, False, False), - ("tiny_llama_model_path", 1, 2, False, False, False), - ("tiny_llama_model_path", 1, 2, False, True, False), - ("tiny_llama_model_path", 1, 2, False, True, True), - ("tiny_qwen3_model_path", 1, 2, False, True, False), - ("tiny_qwen3_model_path", 1, 2, False, False, False), - ], - indirect=True, -) -def test_dtensor_worker_logprob_tp2_or_cp2_matches_unsharded(use_v2, logprob_setup): - policy, data, logprobs = logprob_setup +def test_dtensor_worker_training_with_lora(training_with_lora_setup): + policy, data, loss_fn = training_with_lora_setup + _test_dtensor_worker_training(policy, data, loss_fn) + +def _test_dtensor_worker_logprob(policy, data, logprobs): # Verify resources were created properly assert policy is not None, "Policy was not created properly" assert data is not None, "Test data was not created properly" @@ -567,6 
+696,21 @@ def test_dtensor_worker_logprob_tp2_or_cp2_matches_unsharded(use_v2, logprob_set ) +@pytest.mark.hf_gated +@pytest.mark.timeout(360) +@pytest.mark.parametrize("use_v2", [True, False]) +def test_dtensor_worker_logprob_tp2_or_cp2_matches_unsharded(use_v2, logprob_setup): + policy, data, logprobs = logprob_setup + _test_dtensor_worker_logprob(policy, data, logprobs) + + +@pytest.mark.hf_gated +@pytest.mark.timeout(360) +def test_dtensor_worker_logprob_with_lora(logprob_with_lora_setup): + policy, data, logprobs = logprob_with_lora_setup + _test_dtensor_worker_logprob(policy, data, logprobs) + + @pytest.mark.hf_gated @pytest.mark.parametrize("use_v2", [True, False]) def test_dtensor_tp_and_tied_model_with_custom_parallel_plan( diff --git a/tests/unit/utils/test_automodel_checkpoint.py b/tests/unit/utils/test_automodel_checkpoint.py index 9906a1522f..ba2980908f 100644 --- a/tests/unit/utils/test_automodel_checkpoint.py +++ b/tests/unit/utils/test_automodel_checkpoint.py @@ -26,6 +26,11 @@ except ImportError: pytest.skip("nemo_automodel not available", allow_module_level=True) +from nemo_automodel.components._peft.lora import ( + PeftConfig, + apply_lora_to_linear_modules, +) + from nemo_rl.utils.automodel_checkpoint import ( detect_checkpoint_format, load_checkpoint, @@ -52,6 +57,9 @@ def forward(self, x): x = layer(x) return x + def apply_lora(self, lora_config: PeftConfig): + apply_lora_to_linear_modules(self, lora_config) + @pytest.fixture def mock_model(): @@ -66,6 +74,53 @@ def mock_optimizer(): return torch.optim.Adam(model.parameters()) +@pytest.fixture +def mock_lora_config(): + """Create a simple mock LORA configuration for testing.""" + return PeftConfig( + target_modules=[], + match_all_linear=True, + dim=2, + alpha=2, + dropout=0.1, + dropout_position="post", + lora_A_init="xavier", + use_triton=False, + ) + + +@pytest.fixture +def mock_distributed(): + """Mock torch.distributed calls for non-distributed tests.""" + with ( + patch("torch.distributed.is_initialized", return_value=False), + patch("torch.distributed.get_rank", return_value=0), + ): + yield + torch.distributed.destroy_process_group() + + +@pytest.fixture +def init_distributed(): + """Initialize a single-process distributed environment for testing.""" + + # Only initialize if not already initialized + if not torch.distributed.is_initialized(): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "29500" # Free port + os.environ["RANK"] = "0" + os.environ["WORLD_SIZE"] = "1" + + # Use gloo backend for CPU-only tests + torch.distributed.init_process_group(backend="gloo", rank=0, world_size=1) + + yield + + # Cleanup + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + @pytest.mark.automodel class TestDetectCheckpointFormat: """Test the detect_checkpoint_format function.""" @@ -418,3 +473,80 @@ def test_save_and_load_model_and_optimizer(self, mock_experiment): check_dict_equality(new_model.state_dict(), original_model_state) check_dict_equality(new_optimizer.state_dict(), original_optimizer_state) assert new_scheduler.state_dict() == original_scheduler_state + + def test_save_and_load_model_with_lora( + self, mock_experiment, mock_lora_config, init_distributed + ): + """Test saving and loading both model and optimizer with LORA.""" + test_model, _, _ = mock_experiment + lora_config = mock_lora_config + + test_model.apply_lora(lora_config) + lora_state_dict = test_model.state_dict() + + # Assert LoRA weights exist for layers.0 (Linear 4->4) + assert 
"layers.0.lora_A.weight" in lora_state_dict, ( + "layers.0.lora_A.weight not found" + ) + assert "layers.0.lora_B.weight" in lora_state_dict, ( + "layers.0.lora_B.weight not found" + ) + + # Assert LoRA weights exist for layers.3 (Linear 4->1) + assert "layers.3.lora_A.weight" in lora_state_dict, ( + "layers.3.lora_A.weight not found" + ) + assert "layers.3.lora_B.weight" in lora_state_dict, ( + "layers.3.lora_B.weight not found" + ) + + assert lora_state_dict["layers.0.lora_A.weight"].shape == (2, 4), ( + f"Expected layers.0.lora_A.weight shape (2, 4), got {lora_state_dict['layers.0.lora_A.weight'].shape}" + ) + assert lora_state_dict["layers.0.lora_B.weight"].shape == (4, 2), ( + f"Expected layers.0.lora_B.weight shape (4, 2), got {lora_state_dict['layers.0.lora_B.weight'].shape}" + ) + + # For layers.3: Linear(4, 1) with dim=2 + # lora_A: (dim, in_features) = (2, 4) + # lora_B: (out_features, dim) = (1, 2) + assert lora_state_dict["layers.3.lora_A.weight"].shape == (2, 4), ( + f"Expected layers.3.lora_A.weight shape (2, 4), got {lora_state_dict['layers.3.lora_A.weight'].shape}" + ) + assert lora_state_dict["layers.3.lora_B.weight"].shape == (1, 2), ( + f"Expected layers.3.lora_B.weight shape (1, 2), got {lora_state_dict['layers.3.lora_B.weight'].shape}" + ) + + with TemporaryDirectory() as tmp_dir: + weights_path = os.path.join(tmp_dir, "test_model") + save_checkpoint( + model=test_model, + weights_path=weights_path, + model_save_format="safetensors", + is_peft=True, + peft_config=lora_config, + ) + + # Verify files are created + assert os.path.exists(weights_path) + files = os.listdir(os.path.join(weights_path, "model")) + assert any(f.endswith(".safetensors") for f in files) + + # Create a new model with different weights + new_model = TestModel() + new_model.apply_lora(lora_config) + # Initialize with different values + for param in new_model.parameters(): + param.data.fill_(999.0) + + # Load the checkpoint for peft need distributed(refer to nemo_automodel/components/checkpoint/stateful_wrappers.py:load_state_dict) + load_checkpoint(model=new_model, weights_path=weights_path) + # peft only save lora weights, so we need to filter out the non-lora weights + lora_params_original = { + k: v for k, v in lora_state_dict.items() if "lora" in k + } + lora_params_loaded = { + k: v for k, v in new_model.state_dict().items() if "lora" in k + } + # Verify the weights match the original + check_dict_equality(lora_params_loaded, lora_params_original)