feat: LoRA SFT support for DTensorV2 path #1556
base: main
Changes from all commits: c4f13da, 286b471, 509e0ad, ef5e932, bac01be, 641b985, b1a0fb6, 20c357c, b33ce28
```diff
@@ -21,6 +21,19 @@ class DTensorConfigDisabled(TypedDict):
     enabled: Literal[False]


+class LoRAConfig(TypedDict):
+    enabled: bool
+    target_modules: NotRequired[list[str]]
+    exclude_modules: NotRequired[list[str]]
+    match_all_linear: NotRequired[bool]
+    dim: NotRequired[int]
+    alpha: NotRequired[int]
+    dropout: NotRequired[float]
+    dropout_position: NotRequired[Literal["pre", "post"]]
+    lora_A_init: NotRequired[str]
+    use_triton: NotRequired[bool]
```
Comment on lines +24 to +35
Contributor

Missing `lora_dtype` field. Apply this diff to add the missing field:

```diff
 class LoRAConfig(TypedDict):
     enabled: bool
     target_modules: NotRequired[list[str]]
     exclude_modules: NotRequired[list[str]]
     match_all_linear: NotRequired[bool]
     dim: NotRequired[int]
     alpha: NotRequired[int]
     dropout: NotRequired[float]
     dropout_position: NotRequired[Literal["pre", "post"]]
     lora_A_init: NotRequired[str]
+    lora_dtype: NotRequired[str]
     use_triton: NotRequired[bool]
```
Member

I think we need to fix this.
Contributor

@joyang-nv Confirmed! I verified that `lora_dtype` is read from the configuration files, so the TypedDict should declare it. Here's the fix to add the missing field:

```python
class LoRAConfig(TypedDict):
    enabled: bool
    target_modules: NotRequired[list[str]]
    exclude_modules: NotRequired[list[str]]
    match_all_linear: NotRequired[bool]
    dim: NotRequired[int]
    alpha: NotRequired[int]
    dropout: NotRequired[float]
    dropout_position: NotRequired[Literal["pre", "post"]]
    lora_A_init: NotRequired[str]
    lora_dtype: NotRequired[str]
    use_triton: NotRequired[bool]
```

This will ensure type consistency with the configuration files.
Contributor

Could you make all of these required if they are required, as at https://github.com/NVIDIA-NeMo/RL/blob/main/nemo_rl/models/policy/__init__.py#L191?
```diff
 class DTensorConfig(TypedDict):
     enabled: Literal[True]
     env_vars: NotRequired[dict[str, str] | None]
@@ -32,6 +45,7 @@ class DTensorConfig(TypedDict):
     context_parallel_size: int
     custom_parallel_plan: str | None
     clear_cache_every_n_steps: NotRequired[int | None]
+    lora: NotRequired[LoRAConfig | None]


 class SequencePackingConfigDisabled(TypedDict):
```
```diff
@@ -14,19 +14,25 @@
 import gc
 import itertools
 import math
 import os
 import warnings
 from collections import defaultdict
 from contextlib import AbstractContextManager, contextmanager, nullcontext
 from typing import Any, Generator, Optional, cast

+import nemo_automodel.components._peft.lora as _lora_mod
 import ray
 import torch
 import zmq
 from accelerate import init_empty_weights
 from nemo_automodel import (
     NeMoAutoModelForSequenceClassification,
 )
+from nemo_automodel.components._peft.lora import (
+    PeftConfig,
+    apply_lora_to_linear_modules,
+)
 from nemo_automodel.components.distributed.cp_utils import (
     create_context_parallel_ctx,
     get_train_context,
@@ -94,6 +100,15 @@
 from nemo_rl.utils.packed_tensor import packed_broadcast_producer


+# TODO: @ruit remove this once the bump Automodel to 2d20e33a19d5e53a271b1403b507475e68ad14dc (https://github.com/NVIDIA-NeMo/RL/issues/1586)
+def _patched_init_lora_weights(self, init_method: str):
```
Contributor

I can see this function silently getting ignored. Could you add a unit test that does a check and will fail when it's okay to remove? Something in spirit to this.
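One possible shape for such a guard, sketched as a version pin rather than with the real imports (the version numbers and the `nemo_automodel` lookup are placeholders; a real test would read `importlib.metadata.version("nemo_automodel")` and pin the first release containing commit 2d20e33a):

```python
def patch_still_needed(installed: str, fixed_in: str = "0.2.0") -> bool:
    """True while the installed Automodel predates the upstream xavier fix.

    "0.2.0" is a placeholder for the first release containing the fix.
    """
    def as_tuple(v: str) -> tuple[int, ...]:
        return tuple(int(part) for part in v.split("."))
    return as_tuple(installed) < as_tuple(fixed_in)


def test_remove_monkey_patch_when_bumped() -> None:
    # Stand-in; really importlib.metadata.version("nemo_automodel").
    installed = "0.1.0"
    # Fails as soon as the pin is crossed, reminding us to delete
    # _patched_init_lora_weights and the class-level patch.
    assert patch_still_needed(installed), (
        "Automodel has been bumped; remove _patched_init_lora_weights"
    )


test_remove_monkey_patch_when_bumped()
```

The point is that the TODO becomes self-enforcing: once the dependency bump lands, the test fails loudly instead of the patch lingering unnoticed.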
|
||||
| if init_method == "xavier": | ||||
| nn.init.xavier_normal_(self.lora_A.weight.data) | ||||
| else: | ||||
| nn.init.kaiming_uniform_(self.lora_A.weight.data, a=math.sqrt(5)) | ||||
| self.lora_B.weight.data.zero_() | ||||
|
|
||||
|
|
||||
| @ray.remote( | ||||
| runtime_env=get_runtime_env_for_policy_worker("dtensor_policy_worker_v2") | ||||
| ) # pragma: no cover | ||||
|
|
@@ -223,6 +238,19 @@ def __init__( | |||
|
|
||||
| full_state_dict = None | ||||
| model_state_dict_keys = None | ||||
|
|
||||
| # lora config | ||||
| lora_cfg = self.cfg["dtensor_cfg"].get("lora", None) | ||||
| self.peft_config = None | ||||
| self.lora_enabled = lora_cfg is not None and lora_cfg["enabled"] | ||||
| # patch the init_lora_weights method to use the xavier initialization | ||||
| _lora_mod.LinearLoRA.init_lora_weights = _patched_init_lora_weights | ||||
|
Comment on lines +246 to +247
Contributor

Global monkey-patch affects all LoRA instances in the process. Patching `_lora_mod.LinearLoRA.init_lora_weights` at the class level changes every `LinearLoRA` in the process, and re-assigning it in each worker's `__init__` is redundant. Consider applying the patch in a more scoped manner, or add a guard to prevent re-patching:

```diff
+# Guard to prevent re-patching in multi-worker scenarios
+_LORA_INIT_PATCHED = False
+
 # TODO: @ruit remove this once the bump Automodel to 2d20e33a19d5e53a271b1403b507475e68ad14dc
 def _patched_init_lora_weights(self, init_method: str):
     ...
 ...
-        # patch the init_lora_weights method to use the xavier initialization
-        _lora_mod.LinearLoRA.init_lora_weights = _patched_init_lora_weights
+        # patch the init_lora_weights method to use the xavier initialization
+        global _LORA_INIT_PATCHED
+        if not _LORA_INIT_PATCHED:
+            _lora_mod.LinearLoRA.init_lora_weights = _patched_init_lora_weights
+            _LORA_INIT_PATCHED = True
```
```diff
+        if self.lora_enabled:
+            # Always use float32 since FSDP requires all parameters to be in the same dtype.
+            # autocast should cast the weights to the correct dtype during the forward pass.
+            cfg_dict_with_dtype = {**lora_cfg, "lora_dtype": "torch.float32"}
+            self.peft_config = PeftConfig.from_dict(cfg_dict_with_dtype)

         if self.rank == 0:
             print(f"[Rank {self.rank}] Loading model {model_name} on CPU...")
             model = model_class.from_pretrained(
@@ -234,6 +262,9 @@ def __init__(
                 torch_dtype=str(model_config.torch_dtype),
             )

+            if self.peft_config is not None:
+                apply_lora_to_linear_modules(model, self.peft_config)
```
Contributor

Looks like this function is called twice (on this line and line 291). Is that expected?
```diff
         full_state_dict = model.state_dict()
         # Store the original model state dict keys before any parallelization
         model_state_dict_keys = list(full_state_dict.keys())
@@ -256,6 +287,8 @@ def __init__(
             trust_remote_code=True,
             torch_dtype=str(model_config.torch_dtype),
         )
+        if self.lora_enabled:
+            apply_lora_to_linear_modules(self.model, self.peft_config)

         if self.model.config.pad_token_id is None:
             self.model.config.pad_token_id = tokenizer.pad_token_id
@@ -1894,6 +1927,9 @@ def save_checkpoint(
                 "peft_config",
             }
         }
+        if self.lora_enabled:
+            checkpoint_kwargs["is_peft"] = True
+            checkpoint_kwargs["peft_config"] = self.peft_config

         save_checkpoint(
             model=self.model,
```
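The conditional PEFT wiring above is small enough to factor into a pure helper, which also makes it trivially testable (`build_peft_kwargs` is a hypothetical name, not part of the worker):

```python
from typing import Any


def build_peft_kwargs(lora_enabled: bool, peft_config: Any) -> dict[str, Any]:
    # Mirrors the diff: attach PEFT metadata only when LoRA is active, so
    # non-LoRA checkpoints keep their original save_checkpoint kwargs.
    kwargs: dict[str, Any] = {}
    if lora_enabled:
        kwargs["is_peft"] = True
        kwargs["peft_config"] = peft_config
    return kwargs


print(build_peft_kwargs(False, None))  # {}
print(build_peft_kwargs(True, "cfg"))  # {'is_peft': True, 'peft_config': 'cfg'}
```

Keeping it a pure function would let the double-application question above be covered by a unit test instead of a manual trace through `__init__`.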
```diff
@@ -114,6 +114,9 @@ def __init__(
         if use_v2:
             worker_builder_cls = "nemo_rl.models.policy.dtensor_policy_worker_v2.DTensorPolicyWorkerV2"
         else:
+            assert config.get("lora", {}).get("enabled", False) is False, (
+                "LoRA is not supported for DTensorPolicyWorker V1"
+            )
```
Comment on lines +117 to +119
Contributor

Bug: incorrect config path for the LoRA check. The assertion checks `config.get("lora", ...)`, but the LoRA settings live under `config["dtensor_cfg"]["lora"]`, so this guard never fires. Apply this diff to fix the config path:

```diff
-            assert config.get("lora", {}).get("enabled", False) is False, (
+            assert config.get("dtensor_cfg", {}).get("lora", {}).get("enabled", False) is False, (
                 "LoRA is not supported for DTensorPolicyWorker V1"
             )
```
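The corrected assertion leans on chained `dict.get` with `{}` defaults, which degrades gracefully when any nesting level is absent; a quick standalone check (config shapes are illustrative):

```python
def lora_enabled(config: dict) -> bool:
    # Each .get falls back to {}, so a missing level yields False
    # instead of raising KeyError.
    return config.get("dtensor_cfg", {}).get("lora", {}).get("enabled", False)


print(lora_enabled({}))                                            # False
print(lora_enabled({"dtensor_cfg": {}}))                           # False
print(lora_enabled({"dtensor_cfg": {"lora": {"enabled": True}}}))  # True
```

With the original single-level lookup, the first expression is what the assert effectively tested, which is why it passed vacuously for every config.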
```diff
             worker_builder_cls = (
                 "nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker"
             )
```
```diff
@@ -0,0 +1,46 @@
+#!/bin/bash
+
+# clean up checkpoint directory on exit
+trap "rm -rf /tmp/lora_sft_checkpoints" EXIT
+
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
+PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..)
+# Mark the current repo as safe, since wandb fetches metadata about the repo
+git config --global --add safe.directory $PROJECT_ROOT
+
+set -eou pipefail
+
+EXP_NAME=$(basename $0 .sh)
+EXP_DIR=$SCRIPT_DIR/$EXP_NAME
+LOG_DIR=$EXP_DIR/logs
+JSON_METRICS=$EXP_DIR/metrics.json
+RUN_LOG=$EXP_DIR/run.log
+export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-}
+
+rm -rf $EXP_DIR $LOG_DIR
+mkdir -p $EXP_DIR $LOG_DIR
+
+cd $PROJECT_ROOT
+uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \
+    $PROJECT_ROOT/examples/run_sft.py \
+    policy.model_name=Qwen/Qwen3-0.6B \
+    cluster.gpus_per_node=2 \
+    sft.max_num_steps=3 \
+    sft.val_batches=1 \
+    sft.val_period=3 \
+    policy.dtensor_cfg.lora.enabled=true \
+    logger.tensorboard_enabled=true \
+    logger.log_dir=$LOG_DIR \
+    logger.wandb_enabled=false \
+    logger.monitor_gpus=true \
+    checkpointing.enabled=true \
+    checkpointing.save_period=3 \
+    checkpointing.checkpoint_dir=/tmp/lora_sft_checkpoints \
+    $@ \
```
Contributor

Quote the array expansion to prevent word splitting. Static analysis correctly identifies that unquoted `$@` re-splits elements; quote it so forwarded arguments containing whitespace survive:

```diff
-    $@ \
+    "$@" \
```

🪛 Shellcheck (0.11.0): [error] 39-39: Double quote array expansions to avoid re-splitting elements. (SC2068)
Member

Oh, I think this is a good catch!
Contributor

@joyang-nv, thank you for confirming! I noticed something interesting though: there's a learning in the knowledge base (ID: …) that conflicts with the current suggestion to quote `$@`.
```diff
+    2>&1 | tee $RUN_LOG
+
+uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS
+
+uv run tests/check_metrics.py $JSON_METRICS \
+    'data["train/loss"]["3"] < 5.9'
```

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you please help document these flags?