@@ -5,6 +5,9 @@ checkpointing:
checkpoint_dir: results/sft-llama3.2-1b-1n8g-fsdp2tp1
save_period: 100
policy:
dtensor_cfg:
lora:
dim: 32
tokenizer:
name: meta-llama/Llama-3.2-1B
make_sequence_length_divisible_by: 1
13 changes: 13 additions & 0 deletions examples/configs/sft.yaml
@@ -36,6 +36,7 @@ policy:
offload_optimizer_for_logprob: false

dtensor_cfg:
_v2: true
enabled: true
env_vars: {}
cpu_offload: False
@@ -44,6 +45,18 @@
tensor_parallel_size: 1
context_parallel_size: 1
custom_parallel_plan: null
lora:
enabled: false
target_modules: [] # match_all_linear takes precedence over this list
exclude_modules: []
match_all_linear: true
dim: 8
alpha: 32
dropout: 0.0
dropout_position: "post"
lora_A_init: "xavier"
lora_dtype: ${policy.precision}
use_triton: true

dynamic_batching:
enabled: false
14 changes: 14 additions & 0 deletions nemo_rl/models/policy/__init__.py
@@ -21,6 +21,19 @@ class DTensorConfigDisabled(TypedDict):
enabled: Literal[False]


class LoRAConfig(TypedDict):
enabled: bool
target_modules: NotRequired[list[str]]
exclude_modules: NotRequired[list[str]]
match_all_linear: NotRequired[bool]
dim: NotRequired[int]
alpha: NotRequired[int]
dropout: NotRequired[float]
dropout_position: NotRequired[Literal["pre", "post"]]
lora_A_init: NotRequired[str]
use_triton: NotRequired[bool]
Comment on lines +25 to +34
Contributor

could you please help document these flags?
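
For reference, a minimal annotated sketch of these fields as they appear in this PR; the per-field descriptions are inferred from the YAML defaults and common LoRA conventions, not taken from the PR author:

# Python 3.11+ (use typing_extensions for NotRequired on older versions)
from typing import Literal, NotRequired, TypedDict


class LoRAConfig(TypedDict):
    enabled: bool  # turn LoRA on/off for the DTensor v2 worker
    target_modules: NotRequired[list[str]]  # explicit module names to adapt
    exclude_modules: NotRequired[list[str]]  # module names to skip even if matched
    match_all_linear: NotRequired[bool]  # adapt every linear module; takes precedence over target_modules
    dim: NotRequired[int]  # LoRA rank (8 in sft.yaml, 32 in the recipe override)
    alpha: NotRequired[int]  # LoRA scaling factor
    dropout: NotRequired[float]  # dropout probability on the LoRA path
    dropout_position: NotRequired[Literal["pre", "post"]]  # apply dropout before or after the low-rank projection
    lora_A_init: NotRequired[str]  # init scheme for lora_A ("xavier", otherwise kaiming-uniform)
    lora_dtype: NotRequired[str]  # dtype for the LoRA weights (the v2 worker currently forces torch.float32)
    use_triton: NotRequired[bool]  # whether to use Triton kernels for the LoRA path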


Comment on lines +24 to +35
Contributor
@coderabbitai coderabbitai bot Dec 1, 2025

⚠️ Potential issue | 🟡 Minor

Missing lora_dtype field in LoRAConfig TypedDict.

The sft.yaml configuration includes lora_dtype: ${policy.precision}, but this field is not defined in the LoRAConfig TypedDict. This will cause type-checking inconsistencies.

Apply this diff to add the missing field:

 class LoRAConfig(TypedDict):
     enabled: bool
     target_modules: NotRequired[list[str]]
     exclude_modules: NotRequired[list[str]]
     match_all_linear: NotRequired[bool]
     dim: NotRequired[int]
     alpha: NotRequired[int]
     dropout: NotRequired[float]
     dropout_position: NotRequired[Literal["pre", "post"]]
     lora_A_init: NotRequired[str]
+    lora_dtype: NotRequired[str]
     use_triton: NotRequired[bool]
🤖 Prompt for AI Agents
In nemo_rl/models/policy/__init__.py around lines 24 to 35, the LoRAConfig
TypedDict is missing the lora_dtype field referenced in sft.yaml; add a
NotRequired[literal union or type alias] entry named lora_dtype (e.g.
NotRequired[str] or NotRequired[torch.dtype/str matching project conventions])
to the TypedDict so type-checkers recognize the configuration key, keeping
optionality consistent with other NotRequired fields.

Member

I think we need to fix this.

Contributor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Find and examine sft.yaml to verify lora_dtype usage
fd -t f "sft.yaml" --exec cat {}

Length of output: 7258


@joyang-nv Confirmed! I verified that lora_dtype: ${policy.precision} is indeed used in examples/configs/sft.yaml (line 60) but is missing from the LoRAConfig TypedDict.

Here's the fix to add the missing field:

class LoRAConfig(TypedDict):
    enabled: bool
    target_modules: NotRequired[list[str]]
    exclude_modules: NotRequired[list[str]]
    match_all_linear: NotRequired[bool]
    dim: NotRequired[int]
    alpha: NotRequired[int]
    dropout: NotRequired[float]
    dropout_position: NotRequired[Literal["pre", "post"]]
    lora_A_init: NotRequired[str]
    lora_dtype: NotRequired[str]
    use_triton: NotRequired[bool]

This will ensure type consistency with the configuration files.


Contributor

could you make all of these required if they are required when enabled=True? you can get the type checker to treat them as NotRequired when enabled is False and required when it is True by doing something like this:

https://github.com/NVIDIA-NeMo/RL/blob/main/nemo_rl/models/policy/__init__.py#L191
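
A minimal sketch of that pattern applied to the LoRA config, mirroring the DTensorConfigDisabled/DTensorConfig pair already in this file; the split below is illustrative, not the PR's actual definition:

from typing import Literal, TypedDict, Union


class LoRAConfigDisabled(TypedDict):
    enabled: Literal[False]


class LoRAConfigEnabled(TypedDict):
    # With enabled narrowed to Literal[True], the remaining fields can be
    # plain (required) keys and the type checker will flag missing ones.
    enabled: Literal[True]
    target_modules: list[str]
    exclude_modules: list[str]
    match_all_linear: bool
    dim: int
    alpha: int
    dropout: float
    dropout_position: Literal["pre", "post"]
    lora_A_init: str
    lora_dtype: str
    use_triton: bool


LoRAConfig = Union[LoRAConfigDisabled, LoRAConfigEnabled]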

Contributor

i see later on there's this comment about the lora_dtype
[screenshot: the lora_dtype comment in dtensor_policy_worker_v2.py]

but will that override the value chosen here? +1 on resolving this before merging
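
On the override question: later in this diff the worker builds cfg_dict_with_dtype = {**lora_cfg, "lora_dtype": "torch.float32"}, and in a dict merge the right-most key wins, so whatever the YAML set is replaced unconditionally. A tiny illustration with hypothetical values:

# Hypothetical values; only the merge semantics matter here.
lora_cfg = {"enabled": True, "dim": 8, "lora_dtype": "bfloat16"}  # e.g. resolved from ${policy.precision}
cfg_dict_with_dtype = {**lora_cfg, "lora_dtype": "torch.float32"}
assert cfg_dict_with_dtype["lora_dtype"] == "torch.float32"  # the YAML-selected dtype is overridden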

Contributor

Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!


class DTensorConfig(TypedDict):
enabled: Literal[True]
env_vars: NotRequired[dict[str, str] | None]
@@ -32,6 +45,7 @@ class DTensorConfig(TypedDict):
context_parallel_size: int
custom_parallel_plan: str | None
clear_cache_every_n_steps: NotRequired[int | None]
lora: NotRequired[LoRAConfig | None]


class SequencePackingConfigDisabled(TypedDict):
36 changes: 36 additions & 0 deletions nemo_rl/models/policy/dtensor_policy_worker_v2.py
@@ -14,19 +14,25 @@

import gc
import itertools
import math
import os
import warnings
from collections import defaultdict
from contextlib import AbstractContextManager, contextmanager, nullcontext
from typing import Any, Generator, Optional, cast

import nemo_automodel.components._peft.lora as _lora_mod
import ray
import torch
import zmq
from accelerate import init_empty_weights
from nemo_automodel import (
NeMoAutoModelForSequenceClassification,
)
from nemo_automodel.components._peft.lora import (
PeftConfig,
apply_lora_to_linear_modules,
)
from nemo_automodel.components.distributed.cp_utils import (
create_context_parallel_ctx,
get_train_context,
@@ -94,6 +100,15 @@
from nemo_rl.utils.packed_tensor import packed_broadcast_producer


# TODO: @ruit remove this once we bump Automodel to 2d20e33a19d5e53a271b1403b507475e68ad14dc (https://github.com/NVIDIA-NeMo/RL/issues/1586)
def _patched_init_lora_weights(self, init_method: str):
Contributor

i can see this function silently getting ignored. could you add a unit test that does a check and will fail when it's okay to remove? something in the spirit of this:

"If this fails, that means the upstream bug has been fixed. You can close this issue: https://github.com/huggingface/transformers/issues/41190"

if init_method == "xavier":
nn.init.xavier_normal_(self.lora_A.weight.data)
else:
nn.init.kaiming_uniform_(self.lora_A.weight.data, a=math.sqrt(5))
self.lora_B.weight.data.zero_()
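
Following up on the unit-test request in the comment above, a minimal sketch of such a guard test; it assumes that checking the upstream source for "xavier" handling is a reliable signal and that the test imports the module before the monkey-patch is installed:

import inspect

import nemo_automodel.components._peft.lora as _lora_mod


def test_lora_xavier_init_patch_still_needed():
    """If this fails, upstream Automodel now handles 'xavier' init itself, so the
    _patched_init_lora_weights monkey-patch (and this test) can be removed.
    See https://github.com/NVIDIA-NeMo/RL/issues/1586."""
    upstream_src = inspect.getsource(_lora_mod.LinearLoRA.init_lora_weights)
    assert "xavier" not in upstream_src, (
        "LinearLoRA.init_lora_weights now supports 'xavier'; drop the patch in "
        "dtensor_policy_worker_v2.py and delete this test."
    )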


@ray.remote(
runtime_env=get_runtime_env_for_policy_worker("dtensor_policy_worker_v2")
) # pragma: no cover
@@ -223,6 +238,19 @@ def __init__(

full_state_dict = None
model_state_dict_keys = None

# lora config
lora_cfg = self.cfg["dtensor_cfg"].get("lora", None)
self.peft_config = None
self.lora_enabled = lora_cfg is not None and lora_cfg["enabled"]
# patch the init_lora_weights method to use the xavier initialization
_lora_mod.LinearLoRA.init_lora_weights = _patched_init_lora_weights
Comment on lines +246 to +247
Contributor

⚠️ Potential issue | 🟠 Major

Global monkey-patch affects all LoRA instances in the process.

Patching _lora_mod.LinearLoRA.init_lora_weights globally affects all LoRA instances across the entire process, not just this worker. This could cause issues if multiple workers or other code paths rely on the original initialization behavior.

Consider applying the patch in a more scoped manner, or add a guard to prevent re-patching.

+# Guard to prevent re-patching in multi-worker scenarios
+_LORA_INIT_PATCHED = False
+
 # TODO: @ruit remove this once the bump Automodel to 2d20e33a19d5e53a271b1403b507475e68ad14dc
 def _patched_init_lora_weights(self, init_method: str):
     ...

 ...
-        # patch the init_lora_weights method to use the xavier initialization
-        _lora_mod.LinearLoRA.init_lora_weights = _patched_init_lora_weights
+        # patch the init_lora_weights method to use the xavier initialization
+        global _LORA_INIT_PATCHED
+        if not _LORA_INIT_PATCHED:
+            _lora_mod.LinearLoRA.init_lora_weights = _patched_init_lora_weights
+            _LORA_INIT_PATCHED = True

if self.lora_enabled:
# Always use float32 since FSDP requires all parameters to be in the same dtype.
# autocast should cast the weights to the correct dtype during the forward pass.
cfg_dict_with_dtype = {**lora_cfg, "lora_dtype": "torch.float32"}
self.peft_config = PeftConfig.from_dict(cfg_dict_with_dtype)

if self.rank == 0:
print(f"[Rank {self.rank}] Loading model {model_name} on CPU...")
model = model_class.from_pretrained(
@@ -234,6 +262,9 @@ def __init__(
torch_dtype=str(model_config.torch_dtype),
)

if self.peft_config is not None:
apply_lora_to_linear_modules(model, self.peft_config)
Contributor

Looks like this function is called twice (on this line and line 291). Is that expected?


full_state_dict = model.state_dict()
# Store the original model state dict keys before any parallelization
model_state_dict_keys = list(full_state_dict.keys())
@@ -256,6 +287,8 @@ def __init__(
trust_remote_code=True,
torch_dtype=str(model_config.torch_dtype),
)
if self.lora_enabled:
apply_lora_to_linear_modules(self.model, self.peft_config)

if self.model.config.pad_token_id is None:
self.model.config.pad_token_id = tokenizer.pad_token_id
@@ -1894,6 +1927,9 @@ def save_checkpoint(
"peft_config",
}
}
if self.lora_enabled:
checkpoint_kwargs["is_peft"] = True
checkpoint_kwargs["peft_config"] = self.peft_config

save_checkpoint(
model=self.model,
3 changes: 3 additions & 0 deletions nemo_rl/models/policy/lm_policy.py
@@ -114,6 +114,9 @@ def __init__(
if use_v2:
worker_builder_cls = "nemo_rl.models.policy.dtensor_policy_worker_v2.DTensorPolicyWorkerV2"
else:
assert config.get("lora", {}).get("enabled", False) is False, (
"LoRA is not supported for DTensorPolicyWorker V1"
)
Comment on lines +117 to +119
Contributor

⚠️ Potential issue | 🔴 Critical

Bug: Incorrect config path for LoRA check.

The assertion checks config.get("lora", ...) but LoRA configuration is nested under dtensor_cfg.lora, not at the top level. This means the guard will never trigger, allowing LoRA to be incorrectly used with V1.

Apply this diff to fix the config path:

-                assert config.get("lora", {}).get("enabled", False) is False, (
+                assert config.get("dtensor_cfg", {}).get("lora", {}).get("enabled", False) is False, (
                     "LoRA is not supported for DTensorPolicyWorker V1"
                 )
🤖 Prompt for AI Agents
In nemo_rl/models/policy/lm_policy.py around lines 117 to 119, the assertion is
checking the wrong config path — it reads config.get("lora", ...) but LoRA lives
under config["dtensor_cfg"]["lora"]; update the assertion to check
config.get("dtensor_cfg", {}).get("lora", {}).get("enabled", False) is False so
the guard correctly prevents LoRA from being enabled for DTensorPolicyWorker V1.

worker_builder_cls = (
"nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker"
)
46 changes: 46 additions & 0 deletions tests/functional/test_automodel_lora_sft.sh
@@ -0,0 +1,46 @@
#!/bin/bash

# clean up checkpoint directory on exit
trap "rm -rf /tmp/lora_sft_checkpoints" EXIT

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..)
# Mark the current repo as safe, since wandb fetches metadata about the repo
git config --global --add safe.directory $PROJECT_ROOT

set -eou pipefail

EXP_NAME=$(basename $0 .sh)
EXP_DIR=$SCRIPT_DIR/$EXP_NAME
LOG_DIR=$EXP_DIR/logs
JSON_METRICS=$EXP_DIR/metrics.json
RUN_LOG=$EXP_DIR/run.log
export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-}

rm -rf $EXP_DIR $LOG_DIR
mkdir -p $EXP_DIR $LOG_DIR

cd $PROJECT_ROOT
uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \
$PROJECT_ROOT/examples/run_sft.py \
policy.model_name=Qwen/Qwen3-0.6B \
cluster.gpus_per_node=2 \
sft.max_num_steps=3 \
sft.val_batches=1 \
sft.val_period=3 \
policy.dtensor_cfg.lora.enabled=true \
logger.tensorboard_enabled=true \
logger.log_dir=$LOG_DIR \
logger.wandb_enabled=false \
logger.monitor_gpus=true \
checkpointing.enabled=true \
checkpointing.save_period=3 \
checkpointing.checkpoint_dir=/tmp/lora_sft_checkpoints \
$@ \
Contributor
@coderabbitai coderabbitai bot Dec 1, 2025

⚠️ Potential issue | 🟡 Minor

Quote the array expansion to prevent word splitting.

Static analysis correctly identifies that $@ should be quoted as "$@" to avoid re-splitting elements containing spaces.

-    $@ \
+    "$@" \
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
$@ \
"$@" \
🧰 Tools
🪛 Shellcheck (0.11.0)

[error] 39-39: Double quote array expansions to avoid re-splitting elements.

(SC2068)

🤖 Prompt for AI Agents
In tests/functional/test_automodel_lora_sft.sh around line 39, the unquoted
array expansion `$@` may cause word splitting; change it to `"$@"` in the
command invocation so each original argument is preserved as a single parameter
(i.e., replace `$@` with `"$@"`).

Member

Oh, I think this is a good catch!

Contributor

@joyang-nv, thank you for confirming! I noticed something interesting though - there's a learning in the knowledge base (ID: ccad141e-7829-44b1-98d0-ec79c01decb2) that states:

"In the NVIDIA-NeMo/RL repository, test scripts under tests/ follow a consistent pattern: use cd $PROJECT_ROOT without quotes or error handling, and pass arguments with $@ unquoted."

This learning conflicts with the current suggestion to quote $@. Since you've confirmed that quoting is the correct approach, should I remove this obsolete learning from the knowledge base? This will help ensure future reviews recommend the proper shell scripting practice of quoting array expansions.


🧠 Learnings used
Learnt from: zpqiu
Repo: NVIDIA-NeMo/RL PR: 1324
File: tests/test_suites/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-megatron-tp2pp2cp2-pack.sh:16-30
Timestamp: 2025-10-12T14:46:55.513Z
Learning: In the NVIDIA-NeMo/RL repository, test scripts under tests/ follow a consistent pattern: use `cd $PROJECT_ROOT` without quotes or error handling, and pass arguments with `$@` unquoted. Maintain this consistency when adding new test scripts.

Learnt from: CR
Repo: NVIDIA-NeMo/RL PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-11-24T17:24:41.976Z
Learning: Applies to tests/test_suites/**/*.sh : Driver shell scripts should match the YAML base name with .sh extension and invoke training entrypoint with uv run

2>&1 | tee $RUN_LOG

uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS

uv run tests/check_metrics.py $JSON_METRICS \
'data["train/loss"]["3"] < 5.9'
