From 00c59308ab5bd397350cc879df01c5eb70657705 Mon Sep 17 00:00:00 2001 From: Sahil Modi Date: Wed, 19 Nov 2025 13:51:40 -0800 Subject: [PATCH 01/10] initial commit --- .gitignore | 3 + .pre-commit-config.yaml | 74 +++--- .../llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v3.yaml | 13 + .../recipes/llm/sft-tmblog-llama3.1-8b.yaml | 53 ++++ examples/configs/sft_lora.yaml | 213 ++++++++++++++++ nemo_rl/algorithms/sft.py | 1 + .../datasets/response_datasets/__init__.py | 5 + .../data/datasets/response_datasets/tulu3.py | 62 +++++ nemo_rl/data/llm_message_utils.py | 46 ++-- nemo_rl/models/policy/__init__.py | 14 ++ .../workers/dtensor_policy_worker_v2.py | 235 +++++++++++++++++- 11 files changed, 652 insertions(+), 67 deletions(-) create mode 100644 examples/configs/recipes/llm/sft-tmblog-llama3.1-8b.yaml create mode 100644 examples/configs/sft_lora.yaml create mode 100644 nemo_rl/data/datasets/response_datasets/tulu3.py diff --git a/.gitignore b/.gitignore index 5d5611d1c2..e9f8f029a0 100644 --- a/.gitignore +++ b/.gitignore @@ -46,3 +46,6 @@ code_snapshots*/ # Runtime env *runtime_env.yaml !default_runtime_env.yaml + +# temp +slurm/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 879c913a6e..5b614efcfc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -40,18 +40,18 @@ repos: exclude: '^\.github/' types: [file] - - repo: local - hooks: - - id: pyrefly-typecheck - name: pyrefly check - entry: uv run --group dev pyrefly check - types_or: [python, pyi] - language: system - pass_filenames: false # Pyrefly reads config & project roots itself. - args: [] - require_serial: true - additional_dependencies: [] - minimum_pre_commit_version: "2.9.2" + # - repo: local + # hooks: + # - id: pyrefly-typecheck + # name: pyrefly check + # entry: uv run --group dev pyrefly check + # types_or: [python, pyi] + # language: system + # pass_filenames: false # Pyrefly reads config & project roots itself. + # args: [] + # require_serial: true + # additional_dependencies: [] + # minimum_pre_commit_version: "2.9.2" # This pre-commit hook ensures that the config file is minimized and reflects exactly what you # intend to merge. Without it, you might run experiments with one config, but when merging upstream, @@ -63,28 +63,28 @@ repos: # # If this check is disruptive, you can disable the pre-commit hook locally. However, before a recipe # is accepted upstream, we expect the config to be minimized. 
- - repo: local - hooks: - - id: configs-minimize-check-llm - name: minimize-check llm recipes - language: system - pass_filenames: false - entry: bash - args: - - -lc - - | - set -euo pipefail - base="examples/configs/dpo.yaml"; for f in examples/configs/recipes/llm/dpo-*.yaml; do [ -e "$f" ] && ./tools/config_cli.py minimize-check "$base" "$f"; done - base="examples/configs/grpo_math_1B.yaml"; for f in examples/configs/recipes/llm/grpo-*.yaml; do [ -e "$f" ] && ./tools/config_cli.py minimize-check "$base" "$f"; done - base="examples/configs/sft.yaml"; for f in examples/configs/recipes/llm/sft-*.yaml; do [ -e "$f" ] && ./tools/config_cli.py minimize-check "$base" "$f"; done - base="examples/configs/distillation_math.yaml"; for f in examples/configs/recipes/llm/distillation-*.yaml; do [ -e "$f" ] && ./tools/config_cli.py minimize-check "$base" "$f"; done - - id: configs-minimize-check-vlm - name: minimize-check vlm recipes - language: system - pass_filenames: false - entry: bash - args: - - -lc - - | - set -euo pipefail - base="examples/configs/vlm_grpo_3B.yaml"; for f in examples/configs/recipes/vlm/vlm_grpo-*.yaml; do [ -e "$f" ] && ./tools/config_cli.py minimize-check "$base" "$f"; done + # - repo: local + # hooks: + # - id: configs-minimize-check-llm + # name: minimize-check llm recipes + # language: system + # pass_filenames: false + # entry: bash + # args: + # - -lc + # - | + # set -euo pipefail + # base="examples/configs/dpo.yaml"; for f in examples/configs/recipes/llm/dpo-*.yaml; do [ -e "$f" ] && ./tools/config_cli.py minimize-check "$base" "$f"; done + # base="examples/configs/grpo_math_1B.yaml"; for f in examples/configs/recipes/llm/grpo-*.yaml; do [ -e "$f" ] && ./tools/config_cli.py minimize-check "$base" "$f"; done + # base="examples/configs/sft.yaml"; for f in examples/configs/recipes/llm/sft-*.yaml; do [ -e "$f" ] && ./tools/config_cli.py minimize-check "$base" "$f"; done + # base="examples/configs/distillation_math.yaml"; for f in examples/configs/recipes/llm/distillation-*.yaml; do [ -e "$f" ] && ./tools/config_cli.py minimize-check "$base" "$f"; done + # - id: configs-minimize-check-vlm + # name: minimize-check vlm recipes + # language: system + # pass_filenames: false + # entry: bash + # args: + # - -lc + # - | + # set -euo pipefail + # base="examples/configs/vlm_grpo_3B.yaml"; for f in examples/configs/recipes/vlm/vlm_grpo-*.yaml; do [ -e "$f" ] && ./tools/config_cli.py minimize-check "$base" "$f"; done diff --git a/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v3.yaml b/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v3.yaml index 77ff8aac89..bf34055084 100644 --- a/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v3.yaml +++ b/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v3.yaml @@ -5,6 +5,19 @@ checkpointing: checkpoint_dir: results/sft-llama3.2-1b-1n8g-fsdp2tp1 save_period: 100 policy: + dtensor_cfg: + _v2: true + lora: + enabled: false + target_modules: [] # match all linear modules takes precendence + exclude_modules: [] + match_all_linear: true + dim: 32 + alpha: 32 + dropout: 0.0 + dropout_position: "post" + lora_A_init: "xavier" + use_triton: true tokenizer: name: meta-llama/Llama-3.2-1B make_sequence_length_divisible_by: 1 diff --git a/examples/configs/recipes/llm/sft-tmblog-llama3.1-8b.yaml b/examples/configs/recipes/llm/sft-tmblog-llama3.1-8b.yaml new file mode 100644 index 0000000000..f17f85dc10 --- /dev/null +++ b/examples/configs/recipes/llm/sft-tmblog-llama3.1-8b.yaml @@ -0,0 +1,53 @@ +defaults: 
../../sft.yaml +sft: + max_num_steps: 350 + val_period: 20 + val_global_batch_size: 128 + val_micro_batch_size: 2 + val_batches: 8 +checkpointing: + checkpoint_dir: results/sft-tmblog-llama3.1-8b + save_period: 20 +policy: + model_name: meta-llama/Llama-3.1-8B + tokenizer: + name: meta-llama/Llama-3.1-8B-Instruct + chat_template: default + train_global_batch_size: 128 + train_micro_batch_size: 1 + max_total_sequence_length: 4096 + precision: "bfloat16" + dtensor_cfg: + tensor_parallel_size: 1 + _v2: true + lora: + enabled: false + target_modules: [] # match all linear modules takes precendence + exclude_modules: [] + match_all_linear: true + dim: 32 + alpha: 32 + dropout: 0.0 + dropout_position: "post" + lora_A_init: "xavier" + use_triton: true + make_sequence_length_divisible_by: 2 + optimizer: + kwargs: + lr: 2.0e-05 + weight_decay: 0.01 + eps: 1.0e-08 +data: + dataset_name: tulu3 + add_generation_prompt: true + seed: 42 +logger: + log_dir: logs/sft-tmblog-llama3.1-8b + tensorboard_enabled: false + wandb: + project: nemo-rl + name: sft-tmblog-llama3.1-8b + tensorboard: + log_dir: tb_logs-sft-dev-tulu3 +cluster: + gpus_per_node: 8 diff --git a/examples/configs/sft_lora.yaml b/examples/configs/sft_lora.yaml new file mode 100644 index 0000000000..e72b1cf93f --- /dev/null +++ b/examples/configs/sft_lora.yaml @@ -0,0 +1,213 @@ +# SFT Algorithm Configuration +sft: + ## total number of steps to train will equal + ## min((max_num_epochs * len(train_dataloader)), max_num_steps) + max_num_epochs: 1 + max_num_steps: 60 + + val_period: 10 + val_batches: 8 + val_global_batch_size: 32 + val_micro_batch_size: 1 + val_at_start: true + seed: 42 + +checkpointing: + enabled: false + checkpoint_dir: "results/sft" + metric_name: "val:val_loss" # one of "val:" or "train:" followed by the metric name + higher_is_better: false + keep_top_k: 3 + save_period: 10 + checkpoint_must_save_by: null + +policy: + model_name: "/models/Qwen3-0.6B" + tokenizer: + name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default + # chat_template can be a Jinja template string or path to a .jinja file + chat_template: "{% for message in messages %}{%- if message['role'] == 'system' %}{{'Context: ' + message['content'].strip()}}{%- elif message['role'] == 'user' %}{{' Question: ' + message['content'].strip() + ' Answer:'}}{%- elif message['role'] == 'assistant' %}{{' ' + message['content'].strip()}}{%- endif %}{% endfor %}" + chat_template_kwargs: null # can be used to pass kwargs to the chat template, e.g., enable_thinking=true + train_global_batch_size: 32 + train_micro_batch_size: 1 + max_total_sequence_length: 1024 + precision: "bfloat16" + + dtensor_cfg: + enabled: true + _v2: true + env_vars: {} + cpu_offload: False + sequence_parallel: false + activation_checkpointing: false + tensor_parallel_size: 1 + context_parallel_size: 1 + custom_parallel_plan: null + lora: + enabled: true + target_modules: [] # match all linear modules takes precendence + exclude_modules: [] + match_all_linear: true + dim: 8 + alpha: 32 + dropout: 0.0 + dropout_position: "post" + lora_A_init: "xavier" + lora_dtype: ${policy.precision} + use_triton: true + + dynamic_batching: + enabled: false + train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} + sequence_length_round: 64 + + sequence_packing: + enabled: False + train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} + algorithm: "modified_first_fit_decreasing" + 
sequence_length_round: 64 + + # makes the training sequence length divisible by the tensor parallel size + # this is useful for sequence parallel training + make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size} + max_grad_norm: 1.0 + + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 5.0e-6 + weight_decay: 0.1 + betas: [0.9, 0.98] + eps: 1e-5 + # when using Dtensor, we need to set foreach + # and fused to False + foreach: False + fused: False + + ## ignored since enabled=false, but needed for testing purposes + megatron_cfg: + enabled: false + env_vars: {} + empty_unused_memory_level: 1 + activation_checkpointing: false + tensor_model_parallel_size: 1 + expert_tensor_parallel_size: 1 + expert_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + context_parallel_size: 1 + pipeline_dtype: ${policy.precision} + num_layers_in_first_pipeline_stage: null + num_layers_in_last_pipeline_stage: null + sequence_parallel: false + freeze_moe_router: false + moe_router_dtype: null + moe_router_load_balancing_type: "aux_loss" + moe_router_bias_update_rate: 1e-3 + moe_permute_fusion: false + #gives ~20% training perf speedup with sequence packing + apply_rope_fusion: True + # gives ~25% training perf speedup with sequence packing and apply_rope_fusion + bias_activation_fusion: True + defer_fp32_logits: False + + optimizer: + optimizer: "adam" + lr: 5.0e-6 + min_lr: 4.9999e-6 + weight_decay: 0.1 + bf16: false + fp16: false + params_dtype: "float32" + + #adam + adam_beta1: 0.9 + adam_beta2: 0.98 + adam_eps: 1e-5 + + #sgd + sgd_momentum: 0.9 + + #distributed optimizer + use_distributed_optimizer: true + use_precision_aware_optimizer: true + + clip_grad: ${policy.max_grad_norm} + + # optimizer cpu offload + optimizer_cpu_offload: false + optimizer_offload_fraction: 0.0 + + scheduler: + start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} + end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} + weight_decay_incr_style: "constant" + lr_decay_style: "constant" + lr_decay_iters: 1000 + lr_warmup_iters: 50 + lr_warmup_init: 4.9999e-6 + + distributed_data_parallel_config: + grad_reduce_in_fp32: false + overlap_grad_reduce: true + overlap_param_gather: true + data_parallel_sharding_strategy: "optim_grads_params" + use_custom_fsdp: false + +data: + max_input_seq_length: ${policy.max_total_sequence_length} + add_bos: true + add_eos: true + add_generation_prompt: false + shuffle: true + num_workers: 1 + + dataset_name: "squad" + # You can use custom response datasets for training and validation. For example: + # data: + # dataset_name: ResponseDataset + # train_data_path: # e.g., /path/to/local/dataset.jsonl or hf_org/hf_dataset_name (HuggingFace) + # val_data_path: + # input_key: , default is "input" + # output_key: , default is "output" + # train_split: , default is None # used for HuggingFace datasets + # val_split: , default is None # used for HuggingFace datasets + # See https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/sft.md#datasets for more details. 
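+  # A hypothetical filled-in example (paths and key values below are placeholders,
+  # not shipped files):
+  #   dataset_name: ResponseDataset
+  #   train_data_path: /path/to/my_sft_train.jsonl
+  #   val_data_path: /path/to/my_sft_val.jsonl
+  #   input_key: prompt
+  #   output_key: response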
+ + ## unused with squad dataset + prompt_file: null + split: null + output_key: null + seed: null + + + ## OpenAI format specific configs + # train_data_path: "/path/to/train.jsonl" # Path to training data + # val_data_path: "/path/to/val.jsonl" # Path to validation data + # chat_key: "messages" # Key for messages in the data + # system_key: null # Key for system message (optional) + # system_prompt: null # Default system prompt (optional) + # tool_key: "tools" # Key for tools in the data + # use_preserving_dataset: false # If true, uses PreservingDataset to preserve heterogeneous schemas (e.g., tool calls with varying argument structures) + +logger: + log_dir: "logs" # Base directory for all logs + wandb_enabled: false # Make sure you do a ``wandb login [Your API key]'' before running + tensorboard_enabled: false + mlflow_enabled: false + swanlab_enabled: false # Disable SwanLab logging + monitor_gpus: true # If true, will monitor GPU usage and log to wandb and/or tensorboard + wandb: + project: "sft-dev" + name: "sft-dev-${data.dataset_name}" + tensorboard: + log_dir: "tb_logs-sft-dev-${data.dataset_name}" + mlflow: + experiment_name: "sft-dev" + run_name: "sft-dev-${data.dataset_name}" + gpu_monitoring: + collection_interval: 10 # How often to collect GPU usage metrics (in seconds) + flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) + +cluster: + gpus_per_node: 1 + num_nodes: 1 diff --git a/nemo_rl/algorithms/sft.py b/nemo_rl/algorithms/sft.py index ac44521ef7..f4b6f04cc7 100644 --- a/nemo_rl/algorithms/sft.py +++ b/nemo_rl/algorithms/sft.py @@ -593,6 +593,7 @@ def sft_train( f" • Training Model Floating Point Utilization: {100 * total_tflops / theoretical_tflops:.2f}%" ) metrics["train_fp_utilization"] = total_tflops / theoretical_tflops + print(f" • Grad norm: {float(metrics['grad_norm']):.4f}") print("\n⏱️ Timing:") # Display total time first, separately total_time = timing_metrics.get("total_step_time", 0) diff --git a/nemo_rl/data/datasets/response_datasets/__init__.py b/nemo_rl/data/datasets/response_datasets/__init__.py index 8e75a99a0c..32c5cda715 100644 --- a/nemo_rl/data/datasets/response_datasets/__init__.py +++ b/nemo_rl/data/datasets/response_datasets/__init__.py @@ -27,6 +27,7 @@ from nemo_rl.data.datasets.response_datasets.refcoco import RefCOCODataset from nemo_rl.data.datasets.response_datasets.response_dataset import ResponseDataset from nemo_rl.data.datasets.response_datasets.squad import SquadDataset +from nemo_rl.data.datasets.response_datasets.tulu3 import Tulu3Dataset from nemo_rl.data.datasets.utils import get_extra_kwargs @@ -50,6 +51,10 @@ def load_response_dataset(data_config, seed: int = 42): prompt_file=data_config["prompt_file"], seed=seed, ) + elif dataset_name == "tulu3": + base_dataset = Tulu3Dataset( + seed=seed, + ) elif dataset_name == "clevr_cogent": base_dataset = CLEVRCoGenTDataset( split=data_config["split"], diff --git a/nemo_rl/data/datasets/response_datasets/tulu3.py b/nemo_rl/data/datasets/response_datasets/tulu3.py new file mode 100644 index 0000000000..334d857122 --- /dev/null +++ b/nemo_rl/data/datasets/response_datasets/tulu3.py @@ -0,0 +1,62 @@ +from typing import Any + +from datasets import Dataset, load_dataset + +from nemo_rl.data.interfaces import TaskDataSpec + + +def format_sample(sample: dict[str, Any]) -> dict[str, list[dict[str, str]]]: + messages = [ + { + "role": m["role"], + "content": m["content"], + } + for m in sample["messages"] + ] + + assert messages[-1]["role"] == "assistant", ( + 
"This formatter assumes the last message is from the assistant. Only the last message will be trained on." + ) + + return {"messages": messages} + + +def prepare_tulu3_dataset(test_size: float, seed: int) -> Dataset: + dataset = load_dataset( + "allenai/tulu-3-sft-mixture", + split="train", + ) + split_ds = dataset.train_test_split(test_size=test_size, seed=seed) + + train_formatted = split_ds["train"].map( + format_sample, + remove_columns=split_ds["train"].column_names, + ) + val_formatted = split_ds["test"].map( + format_sample, + remove_columns=split_ds["test"].column_names, + ) + + return { + "train": train_formatted, + "validation": val_formatted, + } + + +class Tulu3Dataset: + def __init__( + self, + seed: int, + test_size: float = 0.05, + ): + """Initialize the Tulu3 dataset with train/validation split. + + Args: + seed: Random seed for reproducible splitting + test_size: Proportion of data to use for validation (0.0-1.0) + """ + self.formatted_ds = prepare_tulu3_dataset(test_size, seed) + + self.task_spec = TaskDataSpec( + task_name="Tulu3", + ) diff --git a/nemo_rl/data/llm_message_utils.py b/nemo_rl/data/llm_message_utils.py index c0572ce3a1..5c78b9162d 100644 --- a/nemo_rl/data/llm_message_utils.py +++ b/nemo_rl/data/llm_message_utils.py @@ -552,29 +552,29 @@ def _format_content_helper( message_chunk = formatted_message[prev_message_len_no_eos:] # Debug: Print each message turn separately (only once for the first sample) - if not hasattr(get_formatted_message_log, "_debug_printed"): - if i == 0: - # Print header only at the start of first message - print("\n" + "=" * 80) - print("DEBUG: Individual message turns from apply_chat_template") - print("=" * 80) - - print(f"\n[Turn {i + 1}/{len(message_log_strs)}] Role: {message['role']}") - print("-" * 40) - print("Extracted message chunk:") - print(repr(message_chunk)) # Using repr to show special characters - print(f"Raw text (len={len(message_chunk)}):") - print(message_chunk) - print("-" * 40) - - if i == len(message_log_strs) - 1: - # Mark as printed after processing all turns of the first sample - get_formatted_message_log._debug_printed = True - print("\n" + "=" * 80) - print("DEBUG: Complete formatted conversation:") - print("-" * 80) - print(formatted_message) - print("=" * 80 + "\n") + # if hasattr(get_formatted_message_log, "_debug_printed"): + # if i == 0: + # # Print header only at the start of first message + # print("\n" + "=" * 80) + # print("DEBUG: Individual message turns from apply_chat_template") + # print("=" * 80) + + # print(f"\n[Turn {i + 1}/{len(message_log_strs)}] Role: {message['role']}") + # print("-" * 40) + # print("Extracted message chunk:") + # print(repr(message_chunk)) # Using repr to show special characters + # print(f"Raw text (len={len(message_chunk)}):") + # print(message_chunk) + # print("-" * 40) + + # if i == len(message_log_strs) - 1: + # # Mark as printed after processing all turns of the first sample + # get_formatted_message_log._debug_printed = True + # print("\n" + "=" * 80) + # print("DEBUG: Complete formatted conversation:") + # print("-" * 80) + # print(formatted_message) + # print("=" * 80 + "\n") if i == 0: if add_bos_token: diff --git a/nemo_rl/models/policy/__init__.py b/nemo_rl/models/policy/__init__.py index dad82594d1..c1ea023b01 100644 --- a/nemo_rl/models/policy/__init__.py +++ b/nemo_rl/models/policy/__init__.py @@ -21,6 +21,19 @@ class DTensorConfigDisabled(TypedDict): enabled: Literal[False] +class LoRAConfig(TypedDict): + enabled: bool + target_modules: 
NotRequired[list[str]] + exclude_modules: NotRequired[list[str]] + match_all_linear: NotRequired[bool] + dim: NotRequired[int] + alpha: NotRequired[int] + dropout: NotRequired[float] + dropout_position: NotRequired[Literal["pre", "post"]] + lora_A_init: NotRequired[str] + use_triton: NotRequired[bool] + + class DTensorConfig(TypedDict): enabled: Literal[True] env_vars: NotRequired[dict[str, str] | None] @@ -32,6 +45,7 @@ class DTensorConfig(TypedDict): context_parallel_size: int custom_parallel_plan: str | None clear_cache_every_n_steps: NotRequired[int | None] + lora: NotRequired[LoRAConfig | None] class SequencePackingConfigDisabled(TypedDict): diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index 4b8bf56d42..639d6c2429 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -26,6 +26,13 @@ from nemo_automodel import ( NeMoAutoModelForSequenceClassification, ) +from nemo_automodel.components._peft.lora import ( + PeftConfig, + apply_lora_to_linear_modules, +) +from nemo_automodel.components._transformers.utils import ( + sliding_window_overwrite, +) from nemo_automodel.components.distributed.cp_utils import ( create_context_parallel_ctx, get_train_context, @@ -107,6 +114,23 @@ def __repr__(self) -> str: else: return f"{self.__class__.__qualname__}" + def print0(self, msg): + if self.rank == 0: + print(f"{msg}") + + def print_frozen_params_info(self, model: nn.Module): + total_frozen_params = 0 + num_frozen_layers = 0 + total_params = 0 + for name, param in model.named_parameters(): + if not param.requires_grad: + total_frozen_params += param.numel() + num_frozen_layers += 1 + total_params += param.numel() + self.print0( + f"Total frozen parameters: {total_frozen_params:,} / {total_params:,} ({num_frozen_layers} layers, {total_frozen_params / total_params * 100:.2f}%)" + ) + def __init__( self, config: PolicyConfig, @@ -183,6 +207,10 @@ def __init__( **hf_config_overrides, ) + self.print0( + f"DEBUG: Model config torch_dtype={model_config.torch_dtype}, self.dtype={self.dtype}, precision={self.cfg['precision']}" + ) + self.allow_flash_attn_args = self.check_model_allow_flash_attn_args( model_config ) @@ -222,6 +250,18 @@ def __init__( full_state_dict = None model_state_dict_keys = None + + # lora config + lora_cfg = self.cfg["dtensor_cfg"].get("lora", None) + self.peft_config = None + self.lora_enabled = lora_cfg is not None and lora_cfg["enabled"] + self._debug_lora_info_printed_during_train = False + if self.lora_enabled: + # Always use float32 since FSDP requires all parameters to be in the same dtype. + # autocast should cast the weights to the correct dtype during the forward pass. 
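+            # The dict unpacking below creates a new dict, so the lora config stored in
+            # self.cfg keeps whatever lora_dtype the user set; only the copy handed to
+            # PeftConfig.from_dict has its dtype overridden.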
+ cfg_dict_with_dtype = {**lora_cfg, "lora_dtype": torch.float32} + self.peft_config = PeftConfig.from_dict(cfg_dict_with_dtype) + if self.rank == 0: print(f"[Rank {self.rank}] Loading model {model_name} on CPU...") model = model_class.from_pretrained( @@ -233,9 +273,55 @@ def __init__( torch_dtype=str(model_config.torch_dtype), ) + # Debug: Check dtypes after from_pretrained on rank 0 + self.print0("DEBUG: Checking parameter dtypes after from_pretrained") + param_dtypes_after_load = {} + for name, param in model.named_parameters(): + dtype_str = str(param.dtype) + if dtype_str not in param_dtypes_after_load: + param_dtypes_after_load[dtype_str] = [] + param_dtypes_after_load[dtype_str].append(name) + + self.print0("Parameter dtype distribution after from_pretrained:") + for dtype_str, names in param_dtypes_after_load.items(): + self.print0(f" {dtype_str}: {len(names)} parameters") + + if self.peft_config is not None: + apply_lora_to_linear_modules(model, self.peft_config) + + # Debug: Check dtypes after LoRA application on rank 0 + self.print0("DEBUG: Checking parameter dtypes after LoRA on rank 0") + param_dtypes_after_lora = {} + for name, param in model.named_parameters(): + dtype_str = str(param.dtype) + if dtype_str not in param_dtypes_after_lora: + param_dtypes_after_lora[dtype_str] = [] + param_dtypes_after_lora[dtype_str].append(name) + + self.print0("Parameter dtype distribution after LoRA:") + for dtype_str, names in param_dtypes_after_lora.items(): + self.print0(f" {dtype_str}: {len(names)} parameters") + if len(names) <= 10: + for name in names: + self.print0(f" - {name}") + full_state_dict = model.state_dict() # Store the original model state dict keys before any parallelization model_state_dict_keys = list(full_state_dict.keys()) + + # Debug: Check dtypes in state dict before broadcast + self.print0("DEBUG: Checking state dict dtypes before broadcast") + state_dict_dtypes = {} + for name, tensor in full_state_dict.items(): + dtype_str = str(tensor.dtype) + if dtype_str not in state_dict_dtypes: + state_dict_dtypes[dtype_str] = [] + state_dict_dtypes[dtype_str].append(name) + + self.print0("State dict dtype distribution:") + for dtype_str, names in state_dict_dtypes.items(): + self.print0(f" {dtype_str}: {len(names)} tensors") + del model print(f"[Rank {self.rank}] Initializing empty model for FSDP...") @@ -255,6 +341,35 @@ def __init__( trust_remote_code=True, torch_dtype=str(model_config.torch_dtype), ) + if self.lora_enabled: + self.print0("Before LoRA:") + self.print0(self.model) + apply_lora_to_linear_modules(self.model, self.peft_config) + self.print0("After LoRA:") + self.print0(self.model) + # print all frozen parameters + self.print_frozen_params_info(self.model) + + # Debug: Check dtypes of all parameters after LoRA + self.print0( + "DEBUG: Checking parameter dtypes after LoRA application (before FSDP)" + ) + param_dtypes = {} + for name, param in self.model.named_parameters(): + dtype_str = str(param.dtype) + if dtype_str not in param_dtypes: + param_dtypes[dtype_str] = [] + param_dtypes[dtype_str].append(name) + + self.print0("Parameter dtype distribution:") + for dtype_str, names in param_dtypes.items(): + self.print0(f" {dtype_str}: {len(names)} parameters") + if len(names) <= 10: + for name in names: + self.print0(f" - {name}") + else: + self.print0(f" - First 5: {names[:5]}") + self.print0(f" - Last 5: {names[-5:]}") if self.model.config.pad_token_id is None: self.model.config.pad_token_id = tokenizer.pad_token_id @@ -340,6 +455,33 @@ def __init__( # 3) Move 
to GPU + Composable FSDP # (Initialize device mesh, shard submodules, then shard entire model) # ------------------------------------------------ + + # Debug: Check dtypes before FSDP parallelization + self.print0("DEBUG: Checking parameter dtypes before FSDP parallelization") + param_dtypes_pre_fsdp = {} + for name, param in self.model.named_parameters(): + dtype_str = str(param.dtype) + if dtype_str not in param_dtypes_pre_fsdp: + param_dtypes_pre_fsdp[dtype_str] = [] + param_dtypes_pre_fsdp[dtype_str].append(name) + + self.print0("Parameter dtype distribution before FSDP:") + for dtype_str, names in param_dtypes_pre_fsdp.items(): + self.print0(f" {dtype_str}: {len(names)} parameters") + + if len(param_dtypes_pre_fsdp) > 1: + self.print0( + "WARNING: Multiple dtypes detected before FSDP! This will cause FSDP error." + ) + self.print0( + f"MixedPrecisionPolicy config: param_dtype={self.dtype}, reduce_dtype=float32, output_dtype=float32" + ) + self.print0("Detailed dtype breakdown:") + for dtype_str, names in param_dtypes_pre_fsdp.items(): + self.print0(f" {dtype_str} parameters (showing first 10):") + for name in names[:10]: + self.print0(f" - {name}") + self.model = fsdp2_strategy_parallelize( self.model, device_mesh=self.device_mesh, @@ -499,6 +641,9 @@ def train( mbs = self.cfg["train_micro_batch_size"] local_gbs = gbs // self.dp_size total_dataset_size = torch.tensor(data.size, device="cuda") + self.print0( + f"local_gbs:{local_gbs} mbs:{mbs}, dp_size:{self.dp_size}, ds size:{data.size}" + ) torch.distributed.all_reduce( total_dataset_size, op=torch.distributed.ReduceOp.SUM, @@ -704,6 +849,37 @@ def train( ): del model_args["flash_attn_kwargs"] + # Debug: Check model parameter dtypes before forward pass (only on first iteration) + if ( + self.peft_config is not None + and not self._debug_lora_info_printed_during_train + and gb_idx == 0 + and mb_idx == 0 + and self.rank == 0 + ): + self.print0( + "DEBUG: Checking parameter dtypes before first forward pass" + ) + param_dtypes_before_forward = {} + for name, param in self.model.named_parameters(): + dtype_str = str(param.dtype) + if dtype_str not in param_dtypes_before_forward: + param_dtypes_before_forward[dtype_str] = [] + param_dtypes_before_forward[dtype_str].append(name) + + self.print0( + "Parameter dtype distribution before forward:" + ) + for ( + dtype_str, + names, + ) in param_dtypes_before_forward.items(): + self.print0( + f" {dtype_str}: {len(names)} parameters (first 5: {names[:5]})" + ) + self.print_frozen_params_info(self.model) + self._debug_lora_info_printed_during_train = True + outputs = self.model(**model_args) # Get logprobs @@ -1626,13 +1802,42 @@ def return_model_config(self) -> dict[str, Any]: @torch.no_grad() def prepare_refit_info(self) -> Optional[dict[str, Any]]: - """Prepare state dict metadata for weight refitting and IPC streaming.""" + """Prepare state dict metadata for weight refitting and IPC streaming. 
+
+        Returns:
+            dict containing:
+            - 'weights': dict mapping weight names to (shape, dtype) tuples
+            - 'lora_enabled': bool indicating if LoRA is enabled
+            - 'lora_config': the PeftConfig serialized to a dict when LoRA is enabled, otherwise None
+            - 'lora_weights': list of LoRA weight names (when LoRA is enabled)
+        """
         state_dict_info = {}
-        for name, tensor in self.model.state_dict().items():
-            # all tensor will be casted to self.dtype in stream_weights_via_ipc_zmq/broadcast_weights_for_collective
-            state_dict_info[name] = (tensor.shape, self.dtype)
+        lora_weight_names = []
+
+        # Determine which weights to include based on LoRA status
+        if self.lora_enabled:
+            # Only include LoRA weights when LoRA is enabled
+            for name, tensor in self.model.state_dict().items():
+                if self._is_lora_weight(name):
+                    # all tensors will be cast to self.dtype in stream_weights_via_ipc_zmq/broadcast_weights_for_collective
+                    state_dict_info[name] = (tensor.shape, self.dtype)
+                    lora_weight_names.append(name)
+        else:
+            # Include all weights when LoRA is not enabled
+            for name, tensor in self.model.state_dict().items():
+                # all tensors will be cast to self.dtype in stream_weights_via_ipc_zmq/broadcast_weights_for_collective
+                state_dict_info[name] = (tensor.shape, self.dtype)
+
+        refit_info = {
+            "weights": state_dict_info,
+            "lora_enabled": self.lora_enabled,
+            "lora_config": self.peft_config.to_dict()
+            if self.lora_enabled and self.peft_config
+            else None,
+            "lora_weights": lora_weight_names if self.lora_enabled else None,
+        }
 
-        return state_dict_info
+        return refit_info
 
     @torch.no_grad()
     def calibrate_qkv_fp8_scales(
@@ -1669,8 +1874,15 @@ def stream_weights_via_ipc_zmq(
         from nemo_rl.models.policy.utils import stream_weights_via_ipc_zmq_impl
 
         def dtensor_params_generator():
-            """Generator that yields (name, tensor) pairs, converting DTensors to local tensors."""
+            """Generator that yields (name, tensor) pairs, converting DTensors to local tensors.
+
+            Only yields LoRA weights when LoRA is enabled, otherwise yields all weights.
+ """ for name, tensor in self.model.state_dict().items(): + # Skip non-LoRA weights if LoRA is enabled + if self.lora_enabled and not self._is_lora_weight(name): + continue + if isinstance(tensor, DTensor): # Convert DTensor to full tensor for streaming full_tensor = tensor.full_tensor() @@ -1719,8 +1931,17 @@ def _dtensor_post_iter_func(tensor, dtype): # param_iterator will return (name, tensor), we only need tensor dtensor_post_iter_func = lambda x: _dtensor_post_iter_func(x[1], self.dtype) + # Filter state dict to only include LoRA weights if LoRA is enabled + def _filtered_state_dict_iterator(): + """Iterator that yields only LoRA weights when LoRA is enabled.""" + for name, tensor in self.model.state_dict().items(): + # Skip non-LoRA weights if LoRA is enabled + if self.lora_enabled and not self._is_lora_weight(name): + continue + yield (name, tensor) + packed_broadcast_producer( - iterator=iter(self.model.state_dict().items()), + iterator=_filtered_state_dict_iterator(), group=self.model_update_group, src=0, post_iter_func=dtensor_post_iter_func, From 5e45d3909c8aea4b261468576ff6201f919ca568 Mon Sep 17 00:00:00 2001 From: ruit Date: Mon, 24 Nov 2025 20:51:04 -0800 Subject: [PATCH 02/10] fix: update model name and configuration in sft_lora.yaml; enhance debug logging in llm_message_utils.py; adjust lora_dtype in dtensor_policy_worker_v2.py Signed-off-by: ruit --- examples/configs/sft_lora.yaml | 6 ++- nemo_rl/data/llm_message_utils.py | 46 +++++++++---------- .../workers/dtensor_policy_worker_v2.py | 8 ++-- 3 files changed, 31 insertions(+), 29 deletions(-) diff --git a/examples/configs/sft_lora.yaml b/examples/configs/sft_lora.yaml index e72b1cf93f..1efd8788de 100644 --- a/examples/configs/sft_lora.yaml +++ b/examples/configs/sft_lora.yaml @@ -22,7 +22,7 @@ checkpointing: checkpoint_must_save_by: null policy: - model_name: "/models/Qwen3-0.6B" + model_name: "Qwen/Qwen3-0.6B" tokenizer: name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default # chat_template can be a Jinja template string or path to a .jinja file @@ -33,6 +33,8 @@ policy: max_total_sequence_length: 1024 precision: "bfloat16" + offload_optimizer_for_logprob: false + dtensor_cfg: enabled: true _v2: true @@ -209,5 +211,5 @@ logger: flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) cluster: - gpus_per_node: 1 + gpus_per_node: 8 num_nodes: 1 diff --git a/nemo_rl/data/llm_message_utils.py b/nemo_rl/data/llm_message_utils.py index 5c78b9162d..7137fb6f77 100644 --- a/nemo_rl/data/llm_message_utils.py +++ b/nemo_rl/data/llm_message_utils.py @@ -552,29 +552,29 @@ def _format_content_helper( message_chunk = formatted_message[prev_message_len_no_eos:] # Debug: Print each message turn separately (only once for the first sample) - # if hasattr(get_formatted_message_log, "_debug_printed"): - # if i == 0: - # # Print header only at the start of first message - # print("\n" + "=" * 80) - # print("DEBUG: Individual message turns from apply_chat_template") - # print("=" * 80) - - # print(f"\n[Turn {i + 1}/{len(message_log_strs)}] Role: {message['role']}") - # print("-" * 40) - # print("Extracted message chunk:") - # print(repr(message_chunk)) # Using repr to show special characters - # print(f"Raw text (len={len(message_chunk)}):") - # print(message_chunk) - # print("-" * 40) - - # if i == len(message_log_strs) - 1: - # # Mark as printed after processing all turns of the first sample - # get_formatted_message_log._debug_printed = True - # print("\n" + 
"=" * 80) - # print("DEBUG: Complete formatted conversation:") - # print("-" * 80) - # print(formatted_message) - # print("=" * 80 + "\n") + if hasattr(get_formatted_message_log, "_debug_printed"): + if i == 0: + # Print header only at the start of first message + print("\n" + "=" * 80) + print("DEBUG: Individual message turns from apply_chat_template") + print("=" * 80) + + print(f"\n[Turn {i + 1}/{len(message_log_strs)}] Role: {message['role']}") + print("-" * 40) + print("Extracted message chunk:") + print(repr(message_chunk)) # Using repr to show special characters + print(f"Raw text (len={len(message_chunk)}):") + print(message_chunk) + print("-" * 40) + + if i == len(message_log_strs) - 1: + # Mark as printed after processing all turns of the first sample + get_formatted_message_log._debug_printed = True + print("\n" + "=" * 80) + print("DEBUG: Complete formatted conversation:") + print("-" * 80) + print(formatted_message) + print("=" * 80 + "\n") if i == 0: if add_bos_token: diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index 639d6c2429..7339166ada 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -30,9 +30,6 @@ PeftConfig, apply_lora_to_linear_modules, ) -from nemo_automodel.components._transformers.utils import ( - sliding_window_overwrite, -) from nemo_automodel.components.distributed.cp_utils import ( create_context_parallel_ctx, get_train_context, @@ -259,7 +256,7 @@ def __init__( if self.lora_enabled: # Always use float32 since FSDP requires all parameters to be in the same dtype. # autocast should cast the weights to the correct dtype during the forward pass. - cfg_dict_with_dtype = {**lora_cfg, "lora_dtype": torch.float32} + cfg_dict_with_dtype = {**lora_cfg, "lora_dtype": "torch.float32"} self.peft_config = PeftConfig.from_dict(cfg_dict_with_dtype) if self.rank == 0: @@ -2078,6 +2075,9 @@ def save_checkpoint( "peft_config", } } + if self.lora_enabled: + checkpoint_kwargs["is_peft"] = True + checkpoint_kwargs["peft_config"] = self.peft_config save_checkpoint( model=self.model, From 05f17b723851d95e343cdd1e40cf7b76452990ba Mon Sep 17 00:00:00 2001 From: Jonas Yang Date: Tue, 25 Nov 2025 16:53:02 -0800 Subject: [PATCH 03/10] Deepcoyp of peft config. 
Signed-off-by: Jonas Yang --- nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index 7339166ada..af8c036ea0 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -26,6 +26,7 @@ from nemo_automodel import ( NeMoAutoModelForSequenceClassification, ) +import copy from nemo_automodel.components._peft.lora import ( PeftConfig, apply_lora_to_linear_modules, @@ -284,7 +285,7 @@ def __init__( self.print0(f" {dtype_str}: {len(names)} parameters") if self.peft_config is not None: - apply_lora_to_linear_modules(model, self.peft_config) + apply_lora_to_linear_modules(model, copy.deepcopy(self.peft_config)) # Debug: Check dtypes after LoRA application on rank 0 self.print0("DEBUG: Checking parameter dtypes after LoRA on rank 0") @@ -341,7 +342,7 @@ def __init__( if self.lora_enabled: self.print0("Before LoRA:") self.print0(self.model) - apply_lora_to_linear_modules(self.model, self.peft_config) + apply_lora_to_linear_modules(self.model, copy.deepcopy(self.peft_config)) self.print0("After LoRA:") self.print0(self.model) # print all frozen parameters From 88dc96ae0e1810ee21915e50e55c8c2fcf8c3f0e Mon Sep 17 00:00:00 2001 From: ruit Date: Wed, 26 Nov 2025 18:20:52 -0800 Subject: [PATCH 04/10] remove debug code Signed-off-by: ruit --- nemo_rl/data/llm_message_utils.py | 2 +- .../workers/dtensor_policy_worker_v2.py | 153 +----------------- 2 files changed, 4 insertions(+), 151 deletions(-) diff --git a/nemo_rl/data/llm_message_utils.py b/nemo_rl/data/llm_message_utils.py index 7137fb6f77..c0572ce3a1 100644 --- a/nemo_rl/data/llm_message_utils.py +++ b/nemo_rl/data/llm_message_utils.py @@ -552,7 +552,7 @@ def _format_content_helper( message_chunk = formatted_message[prev_message_len_no_eos:] # Debug: Print each message turn separately (only once for the first sample) - if hasattr(get_formatted_message_log, "_debug_printed"): + if not hasattr(get_formatted_message_log, "_debug_printed"): if i == 0: # Print header only at the start of first message print("\n" + "=" * 80) diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index af8c036ea0..8cb043ef8c 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -112,22 +112,6 @@ def __repr__(self) -> str: else: return f"{self.__class__.__qualname__}" - def print0(self, msg): - if self.rank == 0: - print(f"{msg}") - - def print_frozen_params_info(self, model: nn.Module): - total_frozen_params = 0 - num_frozen_layers = 0 - total_params = 0 - for name, param in model.named_parameters(): - if not param.requires_grad: - total_frozen_params += param.numel() - num_frozen_layers += 1 - total_params += param.numel() - self.print0( - f"Total frozen parameters: {total_frozen_params:,} / {total_params:,} ({num_frozen_layers} layers, {total_frozen_params / total_params * 100:.2f}%)" - ) def __init__( self, @@ -205,10 +189,6 @@ def __init__( **hf_config_overrides, ) - self.print0( - f"DEBUG: Model config torch_dtype={model_config.torch_dtype}, self.dtype={self.dtype}, precision={self.cfg['precision']}" - ) - self.allow_flash_attn_args = self.check_model_allow_flash_attn_args( model_config ) @@ -271,55 +251,13 @@ def __init__( 
torch_dtype=str(model_config.torch_dtype), ) - # Debug: Check dtypes after from_pretrained on rank 0 - self.print0("DEBUG: Checking parameter dtypes after from_pretrained") - param_dtypes_after_load = {} - for name, param in model.named_parameters(): - dtype_str = str(param.dtype) - if dtype_str not in param_dtypes_after_load: - param_dtypes_after_load[dtype_str] = [] - param_dtypes_after_load[dtype_str].append(name) - - self.print0("Parameter dtype distribution after from_pretrained:") - for dtype_str, names in param_dtypes_after_load.items(): - self.print0(f" {dtype_str}: {len(names)} parameters") - if self.peft_config is not None: apply_lora_to_linear_modules(model, copy.deepcopy(self.peft_config)) - # Debug: Check dtypes after LoRA application on rank 0 - self.print0("DEBUG: Checking parameter dtypes after LoRA on rank 0") - param_dtypes_after_lora = {} - for name, param in model.named_parameters(): - dtype_str = str(param.dtype) - if dtype_str not in param_dtypes_after_lora: - param_dtypes_after_lora[dtype_str] = [] - param_dtypes_after_lora[dtype_str].append(name) - - self.print0("Parameter dtype distribution after LoRA:") - for dtype_str, names in param_dtypes_after_lora.items(): - self.print0(f" {dtype_str}: {len(names)} parameters") - if len(names) <= 10: - for name in names: - self.print0(f" - {name}") - full_state_dict = model.state_dict() # Store the original model state dict keys before any parallelization model_state_dict_keys = list(full_state_dict.keys()) - # Debug: Check dtypes in state dict before broadcast - self.print0("DEBUG: Checking state dict dtypes before broadcast") - state_dict_dtypes = {} - for name, tensor in full_state_dict.items(): - dtype_str = str(tensor.dtype) - if dtype_str not in state_dict_dtypes: - state_dict_dtypes[dtype_str] = [] - state_dict_dtypes[dtype_str].append(name) - - self.print0("State dict dtype distribution:") - for dtype_str, names in state_dict_dtypes.items(): - self.print0(f" {dtype_str}: {len(names)} tensors") - del model print(f"[Rank {self.rank}] Initializing empty model for FSDP...") @@ -340,34 +278,9 @@ def __init__( torch_dtype=str(model_config.torch_dtype), ) if self.lora_enabled: - self.print0("Before LoRA:") - self.print0(self.model) apply_lora_to_linear_modules(self.model, copy.deepcopy(self.peft_config)) - self.print0("After LoRA:") - self.print0(self.model) - # print all frozen parameters - self.print_frozen_params_info(self.model) - - # Debug: Check dtypes of all parameters after LoRA - self.print0( - "DEBUG: Checking parameter dtypes after LoRA application (before FSDP)" - ) - param_dtypes = {} - for name, param in self.model.named_parameters(): - dtype_str = str(param.dtype) - if dtype_str not in param_dtypes: - param_dtypes[dtype_str] = [] - param_dtypes[dtype_str].append(name) - - self.print0("Parameter dtype distribution:") - for dtype_str, names in param_dtypes.items(): - self.print0(f" {dtype_str}: {len(names)} parameters") - if len(names) <= 10: - for name in names: - self.print0(f" - {name}") - else: - self.print0(f" - First 5: {names[:5]}") - self.print0(f" - Last 5: {names[-5:]}") + + if self.model.config.pad_token_id is None: self.model.config.pad_token_id = tokenizer.pad_token_id @@ -453,33 +366,6 @@ def __init__( # 3) Move to GPU + Composable FSDP # (Initialize device mesh, shard submodules, then shard entire model) # ------------------------------------------------ - - # Debug: Check dtypes before FSDP parallelization - self.print0("DEBUG: Checking parameter dtypes before FSDP parallelization") - 
param_dtypes_pre_fsdp = {} - for name, param in self.model.named_parameters(): - dtype_str = str(param.dtype) - if dtype_str not in param_dtypes_pre_fsdp: - param_dtypes_pre_fsdp[dtype_str] = [] - param_dtypes_pre_fsdp[dtype_str].append(name) - - self.print0("Parameter dtype distribution before FSDP:") - for dtype_str, names in param_dtypes_pre_fsdp.items(): - self.print0(f" {dtype_str}: {len(names)} parameters") - - if len(param_dtypes_pre_fsdp) > 1: - self.print0( - "WARNING: Multiple dtypes detected before FSDP! This will cause FSDP error." - ) - self.print0( - f"MixedPrecisionPolicy config: param_dtype={self.dtype}, reduce_dtype=float32, output_dtype=float32" - ) - self.print0("Detailed dtype breakdown:") - for dtype_str, names in param_dtypes_pre_fsdp.items(): - self.print0(f" {dtype_str} parameters (showing first 10):") - for name in names[:10]: - self.print0(f" - {name}") - self.model = fsdp2_strategy_parallelize( self.model, device_mesh=self.device_mesh, @@ -639,9 +525,7 @@ def train( mbs = self.cfg["train_micro_batch_size"] local_gbs = gbs // self.dp_size total_dataset_size = torch.tensor(data.size, device="cuda") - self.print0( - f"local_gbs:{local_gbs} mbs:{mbs}, dp_size:{self.dp_size}, ds size:{data.size}" - ) + torch.distributed.all_reduce( total_dataset_size, op=torch.distributed.ReduceOp.SUM, @@ -847,37 +731,6 @@ def train( ): del model_args["flash_attn_kwargs"] - # Debug: Check model parameter dtypes before forward pass (only on first iteration) - if ( - self.peft_config is not None - and not self._debug_lora_info_printed_during_train - and gb_idx == 0 - and mb_idx == 0 - and self.rank == 0 - ): - self.print0( - "DEBUG: Checking parameter dtypes before first forward pass" - ) - param_dtypes_before_forward = {} - for name, param in self.model.named_parameters(): - dtype_str = str(param.dtype) - if dtype_str not in param_dtypes_before_forward: - param_dtypes_before_forward[dtype_str] = [] - param_dtypes_before_forward[dtype_str].append(name) - - self.print0( - "Parameter dtype distribution before forward:" - ) - for ( - dtype_str, - names, - ) in param_dtypes_before_forward.items(): - self.print0( - f" {dtype_str}: {len(names)} parameters (first 5: {names[:5]})" - ) - self.print_frozen_params_info(self.model) - self._debug_lora_info_printed_during_train = True - outputs = self.model(**model_args) # Get logprobs From f19ae50aaffbef12c7444fe397e8f3d309f890b2 Mon Sep 17 00:00:00 2001 From: ruit Date: Sun, 30 Nov 2025 00:29:46 -0800 Subject: [PATCH 05/10] add unit test and clean code Signed-off-by: ruit --- .gitignore | 5 +- .pre-commit-config.yaml | 24 +- ...=> sft-llama3.1-8b-1n8g-dtensor-lora.yaml} | 4 +- examples/configs/sft.yaml | 13 + examples/configs/sft_lora.yaml | 215 ------ nemo_rl/algorithms/sft.py | 1 - .../datasets/response_datasets/__init__.py | 5 - .../data/datasets/response_datasets/tulu3.py | 62 -- nemo_rl/models/policy/lm_policy.py | 7 +- .../workers/dtensor_policy_worker_v2.py | 23 +- tests/functional/test_automodel_lora_sft.sh | 46 ++ .../llm/sft-llama3.1-8b-1n8g-dtensor-lora.sh | 43 ++ tests/test_suites/nightly.txt | 2 + tests/unit/conftest.py | 6 +- .../unit/models/policy/test_dtensor_worker.py | 698 +++++++++++------- tests/unit/utils/test_automodel_checkpoint.py | 135 ++++ 16 files changed, 700 insertions(+), 589 deletions(-) rename examples/configs/recipes/llm/{sft-tmblog-llama3.1-8b.yaml => sft-llama3.1-8b-1n8g-dtensor-lora.yaml} (97%) delete mode 100644 examples/configs/sft_lora.yaml delete mode 100644 nemo_rl/data/datasets/response_datasets/tulu3.py 
create mode 100644 tests/functional/test_automodel_lora_sft.sh create mode 100644 tests/test_suites/llm/sft-llama3.1-8b-1n8g-dtensor-lora.sh diff --git a/.gitignore b/.gitignore index e9f8f029a0..96e5afddec 100644 --- a/.gitignore +++ b/.gitignore @@ -45,7 +45,4 @@ code_snapshots*/ # Runtime env *runtime_env.yaml -!default_runtime_env.yaml - -# temp -slurm/ +!default_runtime_env.yaml \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5b614efcfc..ee863c39fe 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -40,18 +40,18 @@ repos: exclude: '^\.github/' types: [file] - # - repo: local - # hooks: - # - id: pyrefly-typecheck - # name: pyrefly check - # entry: uv run --group dev pyrefly check - # types_or: [python, pyi] - # language: system - # pass_filenames: false # Pyrefly reads config & project roots itself. - # args: [] - # require_serial: true - # additional_dependencies: [] - # minimum_pre_commit_version: "2.9.2" + - repo: local + hooks: + - id: pyrefly-typecheck + name: pyrefly check + entry: uv run --group dev pyrefly check + types_or: [python, pyi] + language: system + pass_filenames: false # Pyrefly reads config & project roots itself. + args: [] + require_serial: true + additional_dependencies: [] + minimum_pre_commit_version: "2.9.2" # This pre-commit hook ensures that the config file is minimized and reflects exactly what you # intend to merge. Without it, you might run experiments with one config, but when merging upstream, diff --git a/examples/configs/recipes/llm/sft-tmblog-llama3.1-8b.yaml b/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-dtensor-lora.yaml similarity index 97% rename from examples/configs/recipes/llm/sft-tmblog-llama3.1-8b.yaml rename to examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-dtensor-lora.yaml index f17f85dc10..8e88441306 100644 --- a/examples/configs/recipes/llm/sft-tmblog-llama3.1-8b.yaml +++ b/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-dtensor-lora.yaml @@ -21,11 +21,11 @@ policy: tensor_parallel_size: 1 _v2: true lora: - enabled: false + enabled: true target_modules: [] # match all linear modules takes precendence exclude_modules: [] match_all_linear: true - dim: 32 + dim: 128 alpha: 32 dropout: 0.0 dropout_position: "post" diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index 05b299a34e..4d97d1f6dd 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -36,6 +36,7 @@ policy: offload_optimizer_for_logprob: false dtensor_cfg: + _v2: true enabled: true env_vars: {} cpu_offload: False @@ -44,6 +45,18 @@ policy: tensor_parallel_size: 1 context_parallel_size: 1 custom_parallel_plan: null + lora: + enabled: false + target_modules: [] # match all linear modules takes precendence + exclude_modules: [] + match_all_linear: true + dim: 8 + alpha: 32 + dropout: 0.0 + dropout_position: "post" + lora_A_init: "xavier" + lora_dtype: ${policy.precision} + use_triton: true dynamic_batching: enabled: false diff --git a/examples/configs/sft_lora.yaml b/examples/configs/sft_lora.yaml deleted file mode 100644 index 1efd8788de..0000000000 --- a/examples/configs/sft_lora.yaml +++ /dev/null @@ -1,215 +0,0 @@ -# SFT Algorithm Configuration -sft: - ## total number of steps to train will equal - ## min((max_num_epochs * len(train_dataloader)), max_num_steps) - max_num_epochs: 1 - max_num_steps: 60 - - val_period: 10 - val_batches: 8 - val_global_batch_size: 32 - val_micro_batch_size: 1 - val_at_start: true - seed: 42 - -checkpointing: - enabled: false - 
checkpoint_dir: "results/sft" - metric_name: "val:val_loss" # one of "val:" or "train:" followed by the metric name - higher_is_better: false - keep_top_k: 3 - save_period: 10 - checkpoint_must_save_by: null - -policy: - model_name: "Qwen/Qwen3-0.6B" - tokenizer: - name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default - # chat_template can be a Jinja template string or path to a .jinja file - chat_template: "{% for message in messages %}{%- if message['role'] == 'system' %}{{'Context: ' + message['content'].strip()}}{%- elif message['role'] == 'user' %}{{' Question: ' + message['content'].strip() + ' Answer:'}}{%- elif message['role'] == 'assistant' %}{{' ' + message['content'].strip()}}{%- endif %}{% endfor %}" - chat_template_kwargs: null # can be used to pass kwargs to the chat template, e.g., enable_thinking=true - train_global_batch_size: 32 - train_micro_batch_size: 1 - max_total_sequence_length: 1024 - precision: "bfloat16" - - offload_optimizer_for_logprob: false - - dtensor_cfg: - enabled: true - _v2: true - env_vars: {} - cpu_offload: False - sequence_parallel: false - activation_checkpointing: false - tensor_parallel_size: 1 - context_parallel_size: 1 - custom_parallel_plan: null - lora: - enabled: true - target_modules: [] # match all linear modules takes precendence - exclude_modules: [] - match_all_linear: true - dim: 8 - alpha: 32 - dropout: 0.0 - dropout_position: "post" - lora_A_init: "xavier" - lora_dtype: ${policy.precision} - use_triton: true - - dynamic_batching: - enabled: false - train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} - sequence_length_round: 64 - - sequence_packing: - enabled: False - train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} - algorithm: "modified_first_fit_decreasing" - sequence_length_round: 64 - - # makes the training sequence length divisible by the tensor parallel size - # this is useful for sequence parallel training - make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size} - max_grad_norm: 1.0 - - optimizer: - name: "torch.optim.AdamW" - kwargs: - lr: 5.0e-6 - weight_decay: 0.1 - betas: [0.9, 0.98] - eps: 1e-5 - # when using Dtensor, we need to set foreach - # and fused to False - foreach: False - fused: False - - ## ignored since enabled=false, but needed for testing purposes - megatron_cfg: - enabled: false - env_vars: {} - empty_unused_memory_level: 1 - activation_checkpointing: false - tensor_model_parallel_size: 1 - expert_tensor_parallel_size: 1 - expert_model_parallel_size: 1 - pipeline_model_parallel_size: 1 - context_parallel_size: 1 - pipeline_dtype: ${policy.precision} - num_layers_in_first_pipeline_stage: null - num_layers_in_last_pipeline_stage: null - sequence_parallel: false - freeze_moe_router: false - moe_router_dtype: null - moe_router_load_balancing_type: "aux_loss" - moe_router_bias_update_rate: 1e-3 - moe_permute_fusion: false - #gives ~20% training perf speedup with sequence packing - apply_rope_fusion: True - # gives ~25% training perf speedup with sequence packing and apply_rope_fusion - bias_activation_fusion: True - defer_fp32_logits: False - - optimizer: - optimizer: "adam" - lr: 5.0e-6 - min_lr: 4.9999e-6 - weight_decay: 0.1 - bf16: false - fp16: false - params_dtype: "float32" - - #adam - adam_beta1: 0.9 - adam_beta2: 0.98 - adam_eps: 1e-5 - - #sgd - sgd_momentum: 0.9 - - #distributed optimizer - use_distributed_optimizer: true - 
use_precision_aware_optimizer: true - - clip_grad: ${policy.max_grad_norm} - - # optimizer cpu offload - optimizer_cpu_offload: false - optimizer_offload_fraction: 0.0 - - scheduler: - start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} - end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} - weight_decay_incr_style: "constant" - lr_decay_style: "constant" - lr_decay_iters: 1000 - lr_warmup_iters: 50 - lr_warmup_init: 4.9999e-6 - - distributed_data_parallel_config: - grad_reduce_in_fp32: false - overlap_grad_reduce: true - overlap_param_gather: true - data_parallel_sharding_strategy: "optim_grads_params" - use_custom_fsdp: false - -data: - max_input_seq_length: ${policy.max_total_sequence_length} - add_bos: true - add_eos: true - add_generation_prompt: false - shuffle: true - num_workers: 1 - - dataset_name: "squad" - # You can use custom response datasets for training and validation. For example: - # data: - # dataset_name: ResponseDataset - # train_data_path: # e.g., /path/to/local/dataset.jsonl or hf_org/hf_dataset_name (HuggingFace) - # val_data_path: - # input_key: , default is "input" - # output_key: , default is "output" - # train_split: , default is None # used for HuggingFace datasets - # val_split: , default is None # used for HuggingFace datasets - # See https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/sft.md#datasets for more details. - - ## unused with squad dataset - prompt_file: null - split: null - output_key: null - seed: null - - - ## OpenAI format specific configs - # train_data_path: "/path/to/train.jsonl" # Path to training data - # val_data_path: "/path/to/val.jsonl" # Path to validation data - # chat_key: "messages" # Key for messages in the data - # system_key: null # Key for system message (optional) - # system_prompt: null # Default system prompt (optional) - # tool_key: "tools" # Key for tools in the data - # use_preserving_dataset: false # If true, uses PreservingDataset to preserve heterogeneous schemas (e.g., tool calls with varying argument structures) - -logger: - log_dir: "logs" # Base directory for all logs - wandb_enabled: false # Make sure you do a ``wandb login [Your API key]'' before running - tensorboard_enabled: false - mlflow_enabled: false - swanlab_enabled: false # Disable SwanLab logging - monitor_gpus: true # If true, will monitor GPU usage and log to wandb and/or tensorboard - wandb: - project: "sft-dev" - name: "sft-dev-${data.dataset_name}" - tensorboard: - log_dir: "tb_logs-sft-dev-${data.dataset_name}" - mlflow: - experiment_name: "sft-dev" - run_name: "sft-dev-${data.dataset_name}" - gpu_monitoring: - collection_interval: 10 # How often to collect GPU usage metrics (in seconds) - flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) - -cluster: - gpus_per_node: 8 - num_nodes: 1 diff --git a/nemo_rl/algorithms/sft.py b/nemo_rl/algorithms/sft.py index f4b6f04cc7..ac44521ef7 100644 --- a/nemo_rl/algorithms/sft.py +++ b/nemo_rl/algorithms/sft.py @@ -593,7 +593,6 @@ def sft_train( f" • Training Model Floating Point Utilization: {100 * total_tflops / theoretical_tflops:.2f}%" ) metrics["train_fp_utilization"] = total_tflops / theoretical_tflops - print(f" • Grad norm: {float(metrics['grad_norm']):.4f}") print("\n⏱️ Timing:") # Display total time first, separately total_time = timing_metrics.get("total_step_time", 0) diff --git a/nemo_rl/data/datasets/response_datasets/__init__.py b/nemo_rl/data/datasets/response_datasets/__init__.py index 32c5cda715..8e75a99a0c 100644 --- 
a/nemo_rl/data/datasets/response_datasets/__init__.py +++ b/nemo_rl/data/datasets/response_datasets/__init__.py @@ -27,7 +27,6 @@ from nemo_rl.data.datasets.response_datasets.refcoco import RefCOCODataset from nemo_rl.data.datasets.response_datasets.response_dataset import ResponseDataset from nemo_rl.data.datasets.response_datasets.squad import SquadDataset -from nemo_rl.data.datasets.response_datasets.tulu3 import Tulu3Dataset from nemo_rl.data.datasets.utils import get_extra_kwargs @@ -51,10 +50,6 @@ def load_response_dataset(data_config, seed: int = 42): prompt_file=data_config["prompt_file"], seed=seed, ) - elif dataset_name == "tulu3": - base_dataset = Tulu3Dataset( - seed=seed, - ) elif dataset_name == "clevr_cogent": base_dataset = CLEVRCoGenTDataset( split=data_config["split"], diff --git a/nemo_rl/data/datasets/response_datasets/tulu3.py b/nemo_rl/data/datasets/response_datasets/tulu3.py deleted file mode 100644 index 334d857122..0000000000 --- a/nemo_rl/data/datasets/response_datasets/tulu3.py +++ /dev/null @@ -1,62 +0,0 @@ -from typing import Any - -from datasets import Dataset, load_dataset - -from nemo_rl.data.interfaces import TaskDataSpec - - -def format_sample(sample: dict[str, Any]) -> dict[str, list[dict[str, str]]]: - messages = [ - { - "role": m["role"], - "content": m["content"], - } - for m in sample["messages"] - ] - - assert messages[-1]["role"] == "assistant", ( - "This formatter assumes the last message is from the assistant. Only the last message will be trained on." - ) - - return {"messages": messages} - - -def prepare_tulu3_dataset(test_size: float, seed: int) -> Dataset: - dataset = load_dataset( - "allenai/tulu-3-sft-mixture", - split="train", - ) - split_ds = dataset.train_test_split(test_size=test_size, seed=seed) - - train_formatted = split_ds["train"].map( - format_sample, - remove_columns=split_ds["train"].column_names, - ) - val_formatted = split_ds["test"].map( - format_sample, - remove_columns=split_ds["test"].column_names, - ) - - return { - "train": train_formatted, - "validation": val_formatted, - } - - -class Tulu3Dataset: - def __init__( - self, - seed: int, - test_size: float = 0.05, - ): - """Initialize the Tulu3 dataset with train/validation split. 
- - Args: - seed: Random seed for reproducible splitting - test_size: Proportion of data to use for validation (0.0-1.0) - """ - self.formatted_ds = prepare_tulu3_dataset(test_size, seed) - - self.task_spec = TaskDataSpec( - task_name="Tulu3", - ) diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index 434f850423..4310e630d7 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -112,7 +112,12 @@ def __init__( if use_v2: worker_builder_cls = "nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2" else: - worker_builder_cls = "nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker" + assert config.get("lora_cfg", {}).get("enabled", False) is False, ( + "LoRA is not supported for DTensorPolicyWorker V1" + ) + worker_builder_cls = ( + "nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker" + ) tp_size = config["dtensor_cfg"]["tensor_parallel_size"] cp_size = config["dtensor_cfg"]["context_parallel_size"] diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index 8cb043ef8c..8a6ad86d6c 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -14,6 +14,7 @@ import gc import itertools +import math import os import warnings from collections import defaultdict @@ -26,7 +27,6 @@ from nemo_automodel import ( NeMoAutoModelForSequenceClassification, ) -import copy from nemo_automodel.components._peft.lora import ( PeftConfig, apply_lora_to_linear_modules, @@ -98,6 +98,15 @@ from nemo_rl.utils.packed_tensor import packed_broadcast_producer +def _patched_init_lora_weights(self, init_method: str): + if init_method == "xavier": + nn.init.xavier_normal_(self.lora_A.weight.data) + print("Initialized LoRA weights with patched xavier initialization") + else: + nn.init.kaiming_uniform_(self.lora_A.weight.data, a=math.sqrt(5)) + self.lora_B.weight.data.zero_() + + @ray.remote( runtime_env=get_runtime_env_for_policy_worker("dtensor_policy_worker_v2") ) # pragma: no cover @@ -112,7 +121,6 @@ def __repr__(self) -> str: else: return f"{self.__class__.__qualname__}" - def __init__( self, config: PolicyConfig, @@ -233,7 +241,8 @@ def __init__( lora_cfg = self.cfg["dtensor_cfg"].get("lora", None) self.peft_config = None self.lora_enabled = lora_cfg is not None and lora_cfg["enabled"] - self._debug_lora_info_printed_during_train = False + # patch the init_lora_weights method to use the xavier initialization + # _lora_mod.LinearLoRA.init_lora_weights = _patched_init_lora_weights if self.lora_enabled: # Always use float32 since FSDP requires all parameters to be in the same dtype. # autocast should cast the weights to the correct dtype during the forward pass. 
@@ -252,12 +261,11 @@ def __init__( ) if self.peft_config is not None: - apply_lora_to_linear_modules(model, copy.deepcopy(self.peft_config)) + apply_lora_to_linear_modules(model, self.peft_config) full_state_dict = model.state_dict() # Store the original model state dict keys before any parallelization model_state_dict_keys = list(full_state_dict.keys()) - del model print(f"[Rank {self.rank}] Initializing empty model for FSDP...") @@ -278,9 +286,7 @@ def __init__( torch_dtype=str(model_config.torch_dtype), ) if self.lora_enabled: - apply_lora_to_linear_modules(self.model, copy.deepcopy(self.peft_config)) - - + apply_lora_to_linear_modules(self.model, self.peft_config) if self.model.config.pad_token_id is None: self.model.config.pad_token_id = tokenizer.pad_token_id @@ -525,7 +531,6 @@ def train( mbs = self.cfg["train_micro_batch_size"] local_gbs = gbs // self.dp_size total_dataset_size = torch.tensor(data.size, device="cuda") - torch.distributed.all_reduce( total_dataset_size, op=torch.distributed.ReduceOp.SUM, diff --git a/tests/functional/test_automodel_lora_sft.sh b/tests/functional/test_automodel_lora_sft.sh new file mode 100644 index 0000000000..cd8fc5b057 --- /dev/null +++ b/tests/functional/test_automodel_lora_sft.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +# clean up checkpoint directory on exit +trap "rm -rf /tmp/lora_sft_checkpoints" EXIT + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..) +# Mark the current repo as safe, since wandb fetches metadata about the repo +git config --global --add safe.directory $PROJECT_ROOT + +set -eou pipefail + +EXP_NAME=$(basename $0 .sh) +EXP_DIR=$SCRIPT_DIR/$EXP_NAME +LOG_DIR=$EXP_DIR/logs +JSON_METRICS=$EXP_DIR/metrics.json +RUN_LOG=$EXP_DIR/run.log +export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-} + +rm -rf $EXP_DIR $LOG_DIR +mkdir -p $EXP_DIR $LOG_DIR + +cd $PROJECT_ROOT +uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \ + $PROJECT_ROOT/examples/run_sft.py \ + policy.model_name=Qwen/Qwen3-0.6B \ + cluster.gpus_per_node=2 \ + sft.max_num_steps=3 \ + sft.val_batches=1 \ + sft.val_period=3 \ + policy.dtensor_cfg.lora.enabled=true \ + logger.tensorboard_enabled=true \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=false \ + logger.monitor_gpus=true \ + checkpointing.enabled=true \ + checkpointing.save_period=3 \ + checkpointing.checkpoint_dir=/tmp/lora_sft_checkpoints \ + $@ \ + 2>&1 | tee $RUN_LOG + +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +uv run tests/check_metrics.py $JSON_METRICS \ + 'data["train/loss"]["3"] < 5.9' + diff --git a/tests/test_suites/llm/sft-llama3.1-8b-1n8g-dtensor-lora.sh b/tests/test_suites/llm/sft-llama3.1-8b-1n8g-dtensor-lora.sh new file mode 100644 index 0000000000..2d6829f314 --- /dev/null +++ b/tests/test_suites/llm/sft-llama3.1-8b-1n8g-dtensor-lora.sh @@ -0,0 +1,43 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=50 +MAX_STEPS=50 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=30 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_sft.py \ + --config $CONFIG_PATH \ + sft.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=False \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + 
logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# TODO: memory check will fail due to OOM tracked here https://github.com/NVIDIA-NeMo/RL/issues/263 + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'data["train/loss"]["1"] < 1.0' \ + 'data["train/loss"]["50"] < 0.8' \ + 'max(data["ray/node.0.gpu.0.mem_gb"]) < 50' \ + 'mean(data["timing/train/total_step_time"], 2) < 10' +fi diff --git a/tests/test_suites/nightly.txt b/tests/test_suites/nightly.txt index 2a8249cf5f..d54ffc73b5 100644 --- a/tests/test_suites/nightly.txt +++ b/tests/test_suites/nightly.txt @@ -66,6 +66,8 @@ tests/test_suites/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v3.sh tests/test_suites/llm/sft-llama3.1-8b-1n8g-fsdp2tp2.sh # dynamic batching tests/test_suites/llm/sft-llama3.1-8b-1n8g-fsdp2tp1-dynamicbatch.sh +# Lora +tests/test_suites/llm/sft-llama3.1-8b-1n8g-dtensor-lora.sh # Functional 32b test tests/test_suites/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v3.sh diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index ab3368185c..1932e4b5e0 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -516,7 +516,8 @@ def tiny_llama_model_path(): num_key_value_heads=None, ) model = LlamaForCausalLM(config=config) - tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B") + tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-1B") + # tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B") shutil.rmtree(model_path, ignore_errors=True) model.save_pretrained(model_path) tokenizer.save_pretrained(model_path) @@ -546,7 +547,8 @@ def tiny_llama_tied_model_path(): num_key_value_heads=None, ) model = LlamaForCausalLM(config=config) - tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B") + tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-1B") + # tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B") shutil.rmtree(model_path, ignore_errors=True) model.save_pretrained(model_path) tokenizer.save_pretrained(model_path) diff --git a/tests/unit/models/policy/test_dtensor_worker.py b/tests/unit/models/policy/test_dtensor_worker.py index e9d495ab24..2ddb79f0fe 100644 --- a/tests/unit/models/policy/test_dtensor_worker.py +++ b/tests/unit/models/policy/test_dtensor_worker.py @@ -39,6 +39,7 @@ def create_test_config( activation_checkpointing: bool = False, custom_parallel_plan: str | None = None, dtensor_v2: bool = False, + enable_loras: bool = False, ) -> PolicyConfig: return { "model_name": model_name, @@ -75,6 +76,18 @@ def create_test_config( "tensor_parallel_size": tp, "context_parallel_size": cp, "custom_parallel_plan": custom_parallel_plan, + "lora": { + "enabled": enable_loras, + "target_modules": [], + "exclude_modules": [], + "match_all_linear": True, + "dim": 32, + "alpha": 32, + "dropout": 0.0, + "dropout_position": "post", + "lora_A_init": "xavier", + "use_triton": True, + }, }, "dynamic_batching": { "enabled": True, @@ -106,6 +119,118 @@ def create_test_config( } +def update_lora_config( + config: PolicyConfig, + enabled: bool = True, + target_modules: list[str] = [], + exclude_modules: list[str] = [], + match_all_linear: bool = 
True, + dim: int = 32, + alpha: int = 32, + dropout: float = 0.0, + dropout_position: str = "post", + lora_A_init: str = "xavier", + use_triton: bool = True, +): + if enabled: + config["dtensor_cfg"]["_v2"] = True + + config["dtensor_cfg"]["lora"].update( + { + "enabled": enabled, + "target_modules": target_modules, + "exclude_modules": exclude_modules, + "match_all_linear": match_all_linear, + "dim": dim, + "alpha": alpha, + "dropout": dropout, + "dropout_position": dropout_position, + "lora_A_init": lora_A_init, + "use_triton": use_triton, + } + ) + + +def _get_use_v2(request) -> bool: + # Get the use_v2 parameter from the test function + marks = getattr(request.function, "pytestmark", []) + for mark in marks: + if ( + hasattr(mark, "args") + and len(mark.args) > 1 + and "use_v2" in str(mark.args[0]) + ): + for p in mark.args[1]: + if isinstance(p, bool): + return p + + # If multiple parametrize decorators, we need to check the node id + if hasattr(request, "node") and hasattr(request.node, "callspec"): + return request.node.callspec.params.get("use_v2", False) + + return False + + +def create_test_batch( + batch_size: int = 8, + seq_len: int = 128, + vocab_size: int = 32000, + mode: str = "train", +) -> BatchedDataDict: + # set random seed + torch.manual_seed(66) + # Create test input_ids and attention_mask + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len)) + attention_mask = torch.ones(batch_size, seq_len) + # Calculate input_lengths (all sequences are full length in this test) + input_lengths = attention_mask.sum(dim=1).to(torch.int32) + data = BatchedDataDict( + { + "input_ids": input_ids, + "input_lengths": input_lengths, + "attention_mask": attention_mask, + **( + { + "labels": torch.randint(0, vocab_size, (batch_size, seq_len)), + "sample_mask": torch.ones(batch_size).cuda(), + } + if mode == "train" + else {} + ), + } + ) + data = data.to("cpu") + return data + + +def calculate_token_logprobs(model_name: str, data: BatchedDataDict): + data = data.to("cuda") + input_ids = data["input_ids"] + + with torch.no_grad(): + # run the log prob of regular hf model here + hf_model = AutoModelForCausalLM.from_pretrained( + model_name, device_map="cuda", torch_dtype=torch.float32 + ) + hf_model.eval() + outputs = hf_model(**data) + + log_probs = torch.nn.functional.log_softmax( + outputs.logits.to(torch.float32), dim=-1 + ) + next_tokens = input_ids[:, 1:] + log_probs = log_probs[:, :-1] + token_logprobs = log_probs.gather(dim=-1, index=next_tokens.unsqueeze(-1)).squeeze( + -1 + ) + token_logprobs = torch.cat( + [torch.zeros_like(token_logprobs[:, :1]), token_logprobs], dim=1 + ).cpu() + + data = data.to("cpu") + return token_logprobs + + @pytest.fixture(scope="module") def two_gpu_virtual_cluster(): cluster_name = "test" @@ -122,20 +247,67 @@ def two_gpu_virtual_cluster(): cluster.shutdown() -@pytest.fixture(scope="function") -def gc_collect(): - """Helper function to force garbage collection after a test""" - import gc +@pytest.fixture +def base_setup(request, two_gpu_virtual_cluster): + params = request.param if hasattr(request, "param") else None + assert params is not None, "params is not set" + + mode = params["mode"] + model_fixture_name = params["model_fixture_name"] + specified_config = params["specified_config"] + enable_loras = params["enable_loras"] + lora_config = params["lora_config"] + model_name = request.getfixturevalue(model_fixture_name) - yield - gc.collect() + policy = None + data = None + loss_fn = None + + try: + use_v2 = _get_use_v2(request) + config = 
create_test_config(model_name, dtensor_v2=use_v2, **specified_config) + + if enable_loras: + update_lora_config(config, **lora_config) + + tokenizer = get_tokenizer(config["tokenizer"]) + print(f"Creating {mode} Policy with {specified_config}...") + policy = Policy( + cluster=two_gpu_virtual_cluster, + config=config, + tokenizer=tokenizer, + init_reference_model=False, + ) + print("Creating test batch...") + data = create_test_batch(mode=mode) + + if mode == "train": + # Create loss function + loss_fn: LossFunction = SimpleLoss() + yield policy, data, loss_fn + elif mode == "logprob": + token_logprobs = calculate_token_logprobs(model_name, data) + yield policy, data, token_logprobs + + except Exception as e: + print(f"Error during setup: {e}") + pytest.skip(f"Setup failed: {e}") + finally: + print("Cleaning up resources for test") + if policy: + policy.shutdown() @pytest.fixture def policy_setup(request, two_gpu_virtual_cluster, tiny_llama_model_path): """Setup and teardown for policy tests - creates a virtual cluster and policy.""" - use_v2 = request.param if hasattr(request, "param") else False - config = create_test_config(tiny_llama_model_path, dtensor_v2=use_v2) + params = request.param if hasattr(request, "param") else {} + use_v2 = params.get("dtensor_v2", False) + enable_loras = params.get("enable_loras", False) + + config = create_test_config( + tiny_llama_model_path, dtensor_v2=use_v2, enable_loras=enable_loras + ) tokenizer = get_tokenizer(config["tokenizer"]) config["generation"] = configure_generation_config(config["generation"], tokenizer) @@ -148,9 +320,223 @@ def policy_setup(request, two_gpu_virtual_cluster, tiny_llama_model_path): policy.shutdown() +@pytest.fixture( + params=[ + # model_fixture_name tp cp sp cpu act + ("tiny_llama_model_path", 1, 1, False, False, False), + ("tiny_llama_model_path", 1, 1, True, False, False), + ("tiny_llama_model_path", 1, 1, False, True, False), + ("tiny_llama_model_path", 1, 1, False, False, True), + ("tiny_llama_model_path", 1, 2, False, False, False), + ("tiny_qwen2_model_path", 1, 1, True, True, False), + ("tiny_qwen2_model_path", 1, 1, True, False, True), + ("tiny_qwen2_model_path", 1, 1, False, True, True), + ("tiny_qwen2_model_path", 1, 1, True, True, True), + ("tiny_qwen2_model_path", 1, 2, False, False, False), + ("tiny_qwen3_model_path", 1, 1, True, True, False), + ("tiny_qwen3_model_path", 1, 1, True, False, True), + ("tiny_qwen3_model_path", 1, 1, False, True, True), + ("tiny_qwen3_model_path", 1, 1, True, True, True), + ("tiny_qwen3_model_path", 1, 2, False, False, False), + ( + "tiny_gemma3_model_path", + 1, + 1, + True, + True, + False, + ), # gemma3 doesn't support spda + ("tiny_gemma3_model_path", 1, 1, True, False, True), + ("tiny_gemma3_model_path", 1, 1, False, True, True), + ("tiny_gemma3_model_path", 1, 1, True, True, True), + # CP doesn't support gemma3 due to spda input has attent_mask != None. 
+ # Nemotron-H doesn't support SP https://github.com/NVIDIA-NeMo/RL/issues/881 + # ("tiny_nemotron5_h_model_path", 1, 1, True, True, False), + # ("tiny_nemotron5_h_model_path", 1, 1, True, False, True), + # ("tiny_nemotron5_h_model_path", 1, 1, True, True, True), + ("tiny_nemotron5_h_model_path", 1, 1, False, False, False), + ("tiny_nemotron5_h_model_path", 1, 1, False, True, True), + # nemotron5_h doesn't support cp + ] +) +def training_setup(request, two_gpu_virtual_cluster): + """Setup and teardown specifically for training tests.""" + request.param = { + "mode": "train", + "enable_loras": False, + "lora_config": None, + "model_fixture_name": request.param[0], + "specified_config": { + "tp": request.param[1], + "cp": request.param[2], + "sp": request.param[3], + "cpu_offload": request.param[4], + "activation_checkpointing": request.param[5], + }, + } + yield from base_setup.__wrapped__(request, two_gpu_virtual_cluster) + + +@pytest.fixture( + params=[ + # TP=2, CP=1 + ("tiny_qwen2_model_path", 2, 1, False, True, False), + ("tiny_qwen2_model_path", 2, 1, False, False, False), + ("tiny_llama_model_path", 2, 1, False, False, False), + ("tiny_llama_model_path", 2, 1, False, True, False), + ("tiny_llama_model_path", 2, 1, False, True, True), + ("tiny_qwen3_model_path", 2, 1, False, True, False), + ("tiny_qwen3_model_path", 2, 1, False, False, False), + ("tiny_gemma3_model_path", 2, 1, False, True, False), + ("tiny_gemma3_model_path", 2, 1, False, False, False), + # TP=1, CP=2 + ("tiny_qwen2_model_path", 1, 2, False, True, False), + ("tiny_qwen2_model_path", 1, 2, False, False, False), + ("tiny_llama_model_path", 1, 2, False, False, False), + ("tiny_llama_model_path", 1, 2, False, True, False), + ("tiny_llama_model_path", 1, 2, False, True, True), + ("tiny_qwen3_model_path", 1, 2, False, True, False), + ("tiny_qwen3_model_path", 1, 2, False, False, False), + ] +) +def logprob_setup(request, two_gpu_virtual_cluster): + """Setup and teardown specifically for logprob tests.""" + request.param = { + "mode": "logprob", + "enable_loras": False, + "lora_config": None, + "model_fixture_name": request.param[0], + "specified_config": { + "tp": request.param[1], + "cp": request.param[2], + "sp": request.param[3], + "cpu_offload": request.param[4], + "activation_checkpointing": request.param[5], + }, + } + yield from base_setup.__wrapped__(request, two_gpu_virtual_cluster) + + +@pytest.fixture( + params=[ + # model_name, target_modules, exclude_modules, match_all_linear, dim, alpha, dropout, dropout_position, lora_A_init, use_triton + ("tiny_llama_model_path", [], [], True, 16, 32, 0.0, "post", "xavier", True), + ("tiny_llama_model_path", [], [], True, 128, 32, 0.0, "post", "xavier", True), + ("tiny_qwen2_model_path", [], [], True, 32, 32, 0.0, "pre", "xavier", True), + ( + "tiny_qwen2_model_path", + ["q_proj", "k_proj", "*gate_proj*", "*up_proj*", "*down_proj*"], + [], + False, + 32, + 16, + 0.0, + "post", + "uniform", + True, + ), + ( + "tiny_qwen2_model_path", + [], + ["q_proj", "k_proj"], + False, + 32, + 16, + 0.0, + "post", + "uniform", + True, + ), + ] +) +def training_with_lora_setup(request, two_gpu_virtual_cluster): + """Setup and teardown specifically for training with lora tests.""" + request.param = { + "mode": "train", + "enable_loras": True, + "model_fixture_name": request.param[0], + "specified_config": {}, + "lora_config": { + "target_modules": request.param[1], + "exclude_modules": request.param[2], + "match_all_linear": request.param[3], + "dim": request.param[4], + "alpha": 
request.param[5], + "dropout": request.param[6], + "dropout_position": request.param[7], + "lora_A_init": request.param[8], + "use_triton": request.param[9], + }, + } + yield from base_setup.__wrapped__(request, two_gpu_virtual_cluster) + + +@pytest.fixture( + params=[ + # model_name, target_modules, exclude_modules, match_all_linear, dim, alpha, dropout, dropout_position, lora_A_init, use_triton + ("tiny_llama_model_path", [], [], True, 16, 32, 0.0, "post", "xavier", True), + ("tiny_llama_model_path", [], [], True, 128, 32, 0.0, "post", "xavier", True), + ("tiny_qwen2_model_path", [], [], True, 32, 32, 0.0, "pre", "xavier", True), + ( + "tiny_qwen2_model_path", + ["q_proj", "k_proj", "*gate_proj*", "*up_proj*", "*down_proj*"], + [], + False, + 32, + 16, + 0.0, + "post", + "uniform", + True, + ), + ( + "tiny_qwen2_model_path", + [], + ["q_proj", "k_proj"], + False, + 32, + 16, + 0.0, + "post", + "uniform", + True, + ), + ] +) +def logprob_with_lora_setup(request, two_gpu_virtual_cluster): + """Setup and teardown specifically for logprob with lora tests.""" + request.param = { + "mode": "logprob", + "enable_loras": True, + "model_fixture_name": request.param[0], + "specified_config": {}, + "lora_config": { + "target_modules": request.param[1], + "exclude_modules": request.param[2], + "match_all_linear": request.param[3], + "dim": request.param[4], + "alpha": request.param[5], + "dropout": request.param[6], + "dropout_position": request.param[7], + "lora_A_init": request.param[8], + "use_triton": request.param[9], + }, + } + yield from base_setup.__wrapped__(request, two_gpu_virtual_cluster) + + @pytest.mark.hf_gated @pytest.mark.timeout(360) -@pytest.mark.parametrize("policy_setup", [True, False], indirect=True) +# @pytest.mark.parametrize("policy_setup", [True, False], indirect=True) +@pytest.mark.parametrize( + "policy_setup", + [ + {"dtensor_v2": True, "enable_loras": False}, + {"dtensor_v2": True, "enable_loras": True}, + {"dtensor_v2": False, "enable_loras": False}, + ], + indirect=True, +) def test_lm_policy_init(policy_setup): policy = policy_setup @@ -227,153 +613,12 @@ def test_lm_policy_init(policy_setup): ) -@pytest.fixture -def training_setup(request, two_gpu_virtual_cluster): - """Setup and teardown specifically for training tests.""" - # Get the use_v2 parameter from the test function - use_v2 = getattr(request.function, "pytestmark", []) - use_v2_value = False - for mark in use_v2: - if ( - hasattr(mark, "args") - and len(mark.args) > 1 - and "use_v2" in str(mark.args[0]) - ): - for param_set in mark.args[1]: - if isinstance(param_set, bool): - use_v2_value = param_set - break - - # If multiple parametrize decorators, we need to check the node id - if hasattr(request, "node") and hasattr(request.node, "callspec"): - if "use_v2" in request.node.callspec.params: - use_v2_value = request.node.callspec.params["use_v2"] - - ( - model_fixture_name, - tp, - cp, - sp, - cpu_offload, - activation_checkpointing, - ) = request.param - - # Get the actual model path from the requested fixture - model_name = request.getfixturevalue(model_fixture_name) - policy = None - data = None - loss_fn = None - - try: - config = create_test_config( - model_name, - tp, - cp, - sp, - cpu_offload, - activation_checkpointing, - dtensor_v2=use_v2_value, - ) - tokenizer = get_tokenizer(config["tokenizer"]) - print( - f"Creating training Policy with tp={tp}, cpu_offload={cpu_offload}, sequence_parallel={sp}, activation_checkpointing={activation_checkpointing}..." 
- ) - policy = Policy( - cluster=two_gpu_virtual_cluster, - config=config, - tokenizer=tokenizer, - init_reference_model=False, - ) - - # Create a test batch - print("Creating test batch...") - # set random seed - torch.manual_seed(42) - - # Create test input_ids and attention_mask - input_ids = torch.randint(0, 32000, (8, 128)) # 8 sequences, each of length 128 - attention_mask = torch.ones(8, 128) - - # Calculate input_lengths (all sequences are full length in this test) - input_lengths = attention_mask.sum(dim=1).to(torch.int32) - - data = BatchedDataDict( - { - "input_ids": input_ids, - "input_lengths": input_lengths, - "attention_mask": attention_mask, # Keep for compatibility with loss functions - "labels": torch.randint(0, 32000, (8, 128)), - "sample_mask": torch.ones(8), - } - ) - - # Create loss function - loss_fn: LossFunction = SimpleLoss() - - # Provide the resources to the test - yield policy, data, loss_fn - - except Exception as e: - print(f"Error during training setup: {e}") - pytest.skip(f"Training setup failed: {e}") - finally: - # Clean up after the test - print("Cleaning up resources for test") - policy.shutdown() - - -@pytest.mark.hf_gated -@pytest.mark.timeout(360) -@pytest.mark.parametrize("use_v2", [True, False]) -@pytest.mark.parametrize( - "training_setup", - [ - # model_fixture_name tp cp sp cpu act - ("tiny_llama_model_path", 1, 1, False, False, False), - ("tiny_llama_model_path", 1, 1, True, False, False), - ("tiny_llama_model_path", 1, 1, False, True, False), - ("tiny_llama_model_path", 1, 1, False, False, True), - ("tiny_llama_model_path", 1, 2, False, False, False), - ("tiny_qwen2_model_path", 1, 1, True, True, False), - ("tiny_qwen2_model_path", 1, 1, True, False, True), - ("tiny_qwen2_model_path", 1, 1, False, True, True), - ("tiny_qwen2_model_path", 1, 1, True, True, True), - ("tiny_qwen2_model_path", 1, 2, False, False, False), - ("tiny_qwen3_model_path", 1, 1, True, True, False), - ("tiny_qwen3_model_path", 1, 1, True, False, True), - ("tiny_qwen3_model_path", 1, 1, False, True, True), - ("tiny_qwen3_model_path", 1, 1, True, True, True), - ("tiny_qwen3_model_path", 1, 2, False, False, False), - ( - "tiny_gemma3_model_path", - 1, - 1, - True, - True, - False, - ), # gemma3 doesn't support spda - ("tiny_gemma3_model_path", 1, 1, True, False, True), - ("tiny_gemma3_model_path", 1, 1, False, True, True), - ("tiny_gemma3_model_path", 1, 1, True, True, True), - # CP doesn't support gemma3 due to spda input has attent_mask != None. 
- # Nemotron-H doesn't support SP https://github.com/NVIDIA-NeMo/RL/issues/881 - # ("tiny_nemotron5_h_model_path", 1, 1, True, True, False), - # ("tiny_nemotron5_h_model_path", 1, 1, True, False, True), - # ("tiny_nemotron5_h_model_path", 1, 1, True, True, True), - ("tiny_nemotron5_h_model_path", 1, 1, False, False, False), - ("tiny_nemotron5_h_model_path", 1, 1, False, True, True), - # nemotron5_h doesn't support cp - ], - indirect=True, -) -def test_dtensor_worker_training(use_v2, training_setup): +def _test_dtensor_worker_training(policy, data, loss_fn): def verify_loss_tensor(loss_tensor): assert not torch.isnan(loss_tensor).any(), "Loss should not be NaN" assert not torch.isinf(loss_tensor).any(), "Loss should not be Inf" return loss_tensor - policy, data, loss_fn = training_setup - # Verify resources were created properly assert policy is not None, "Training policy was not created properly" assert data is not None, "Test data was not created properly" @@ -423,136 +668,22 @@ def verify_loss_tensor(loss_tensor): ) -@pytest.fixture -def logprob_setup(request, two_gpu_virtual_cluster): - """Setup and teardown specifically for training tests.""" - # Get the use_v2 parameter from the test function - use_v2_value = False - if hasattr(request, "node") and hasattr(request.node, "callspec"): - if "use_v2" in request.node.callspec.params: - use_v2_value = request.node.callspec.params["use_v2"] - - ( - model_fixture_name, - tp, - cp, - sp, - cpu_offload, - activation_checkpointing, - ) = request.param - - # Get the actual model path from the requested fixture - model_name = request.getfixturevalue(model_fixture_name) - policy = None - data = None - - try: - config = create_test_config( - model_name, - tp, - cp, - sp, - cpu_offload, - activation_checkpointing, - dtensor_v2=use_v2_value, - ) - tokenizer = get_tokenizer(config["tokenizer"]) - print( - f"Creating logprob Policy with tp={tp}, cpu_offload={cpu_offload}, sequence_parallel={sp}, activation_checkpointing={activation_checkpointing}..." 
- ) - policy = Policy( - cluster=two_gpu_virtual_cluster, - config=config, - tokenizer=tokenizer, - init_reference_model=False, - ) - - # Create a test batch - print("Creating test batch...") - # set random seed - torch.manual_seed(66) - - # Create test input_ids and attention_mask - input_ids = torch.randint( - 0, 32000, (8, 128) - ).cuda() # 8 sequences, each of length 128 - attention_mask = torch.ones(8, 128).cuda() - - # Calculate input_lengths (all sequences are full length in this test) - input_lengths = attention_mask.sum(dim=1).to(torch.int32).cuda() - - data = BatchedDataDict( - { - "input_ids": input_ids, - "input_lengths": input_lengths, - "attention_mask": attention_mask, # Keep for compatibility with loss functions - } - ) - - with torch.no_grad(): - # run the log prob of regular hf model here - hf_model = AutoModelForCausalLM.from_pretrained( - model_name, device_map="cuda", torch_dtype=torch.float32 - ) - hf_model.eval() - outputs = hf_model(**data) - - log_probs = torch.nn.functional.log_softmax( - outputs.logits.to(torch.float32), dim=-1 - ) - next_tokens = input_ids[:, 1:] - log_probs = log_probs[:, :-1] - token_logprobs = log_probs.gather( - dim=-1, index=next_tokens.unsqueeze(-1) - ).squeeze(-1) - token_logprobs = torch.cat( - [torch.zeros_like(token_logprobs[:, :1]), token_logprobs], dim=1 - ).cpu() - - data = data.to("cpu") - - # Provide the resources to the test - yield policy, data, token_logprobs - - except Exception as e: - print(f"Error during training setup: {e}") - pytest.skip(f"Training setup failed: {e}") - finally: - # Clean up after the test - print("Cleaning up resources for test") - policy.shutdown() +@pytest.mark.hf_gated +@pytest.mark.timeout(360) +@pytest.mark.parametrize("use_v2", [True, False]) +def test_dtensor_worker_training(use_v2, training_setup): + policy, data, loss_fn = training_setup + _test_dtensor_worker_training(policy, data, loss_fn) @pytest.mark.hf_gated @pytest.mark.timeout(360) -@pytest.mark.parametrize("use_v2", [True, False]) -@pytest.mark.parametrize( - "logprob_setup", - [ - # TP=2, CP=1 - ("tiny_qwen2_model_path", 2, 1, False, True, False), - ("tiny_qwen2_model_path", 2, 1, False, False, False), - ("tiny_llama_model_path", 2, 1, False, False, False), - ("tiny_llama_model_path", 2, 1, False, True, False), - ("tiny_llama_model_path", 2, 1, False, True, True), - ("tiny_qwen3_model_path", 2, 1, False, True, False), - ("tiny_qwen3_model_path", 2, 1, False, False, False), - ("tiny_gemma3_model_path", 2, 1, False, True, False), - ("tiny_gemma3_model_path", 2, 1, False, False, False), - # TP=1, CP=2 - ("tiny_qwen2_model_path", 1, 2, False, True, False), - ("tiny_qwen2_model_path", 1, 2, False, False, False), - ("tiny_llama_model_path", 1, 2, False, False, False), - ("tiny_llama_model_path", 1, 2, False, True, False), - ("tiny_llama_model_path", 1, 2, False, True, True), - ("tiny_qwen3_model_path", 1, 2, False, True, False), - ("tiny_qwen3_model_path", 1, 2, False, False, False), - ], - indirect=True, -) -def test_dtensor_worker_logprob_tp2_or_cp2_matches_unsharded(use_v2, logprob_setup): - policy, data, logprobs = logprob_setup +def test_dtensor_worker_training_with_lora(training_with_lora_setup): + policy, data, loss_fn = training_with_lora_setup + _test_dtensor_worker_training(policy, data, loss_fn) + +def _test_dtensor_worker_logprob(policy, data, logprobs): # Verify resources were created properly assert policy is not None, "Policy was not created properly" assert data is not None, "Test data was not created properly" @@ -567,6 
+698,21 @@ def test_dtensor_worker_logprob_tp2_or_cp2_matches_unsharded(use_v2, logprob_set ) +@pytest.mark.hf_gated +@pytest.mark.timeout(360) +@pytest.mark.parametrize("use_v2", [True, False]) +def test_dtensor_worker_logprob_tp2_or_cp2_matches_unsharded(use_v2, logprob_setup): + policy, data, logprobs = logprob_setup + _test_dtensor_worker_logprob(policy, data, logprobs) + + +@pytest.mark.hf_gated +@pytest.mark.timeout(360) +def test_dtensor_worker_logprob_with_lora(logprob_with_lora_setup): + policy, data, logprobs = logprob_with_lora_setup + _test_dtensor_worker_logprob(policy, data, logprobs) + + @pytest.mark.hf_gated @pytest.mark.parametrize("use_v2", [True, False]) def test_dtensor_tp_and_tied_model_with_custom_parallel_plan( diff --git a/tests/unit/utils/test_automodel_checkpoint.py b/tests/unit/utils/test_automodel_checkpoint.py index 9906a1522f..8ca930b8aa 100644 --- a/tests/unit/utils/test_automodel_checkpoint.py +++ b/tests/unit/utils/test_automodel_checkpoint.py @@ -26,6 +26,11 @@ except ImportError: pytest.skip("nemo_automodel not available", allow_module_level=True) +from nemo_automodel.components._peft.lora import ( + PeftConfig, + apply_lora_to_linear_modules, +) + from nemo_rl.utils.automodel_checkpoint import ( detect_checkpoint_format, load_checkpoint, @@ -52,6 +57,9 @@ def forward(self, x): x = layer(x) return x + def apply_lora(self, lora_config: PeftConfig): + apply_lora_to_linear_modules(self, lora_config) + @pytest.fixture def mock_model(): @@ -66,6 +74,53 @@ def mock_optimizer(): return torch.optim.Adam(model.parameters()) +@pytest.fixture +def mock_lora_config(): + """Create a simple mock LORA configuration for testing.""" + return PeftConfig( + target_modules=[], + match_all_linear=True, + dim=2, + alpha=2, + dropout=0.1, + dropout_position="post", + lora_A_init="xavier", + use_triton=False, + ) + + +@pytest.fixture +def mock_distributed(): + """Mock torch.distributed calls for non-distributed tests.""" + with ( + patch("torch.distributed.is_initialized", return_value=False), + patch("torch.distributed.get_rank", return_value=0), + ): + yield + torch.distributed.destroy_process_group() + + +@pytest.fixture +def init_distributed(): + """Initialize a single-process distributed environment for testing.""" + + # Only initialize if not already initialized + if not torch.distributed.is_initialized(): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "29500" # Free port + os.environ["RANK"] = "0" + os.environ["WORLD_SIZE"] = "1" + + # Use gloo backend for CPU-only tests + torch.distributed.init_process_group(backend="gloo", rank=0, world_size=1) + + yield + + # Cleanup + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + @pytest.mark.automodel class TestDetectCheckpointFormat: """Test the detect_checkpoint_format function.""" @@ -418,3 +473,83 @@ def test_save_and_load_model_and_optimizer(self, mock_experiment): check_dict_equality(new_model.state_dict(), original_model_state) check_dict_equality(new_optimizer.state_dict(), original_optimizer_state) assert new_scheduler.state_dict() == original_scheduler_state + + def test_save_and_load_model_with_lora( + self, mock_experiment, mock_lora_config, init_distributed + ): + """Test saving and loading both model and optimizer with LORA.""" + test_model, _, _ = mock_experiment + lora_config = mock_lora_config + + test_model.apply_lora(lora_config) + lora_state_dict = test_model.state_dict() + + # Assert LoRA weights exist for layers.0 (Linear 4->4) + assert 
"layers.0.lora_A.weight" in lora_state_dict, ( + "layers.0.lora_A.weight not found" + ) + assert "layers.0.lora_B.weight" in lora_state_dict, ( + "layers.0.lora_B.weight not found" + ) + + # Assert LoRA weights exist for layers.3 (Linear 4->1) + assert "layers.3.lora_A.weight" in lora_state_dict, ( + "layers.3.lora_A.weight not found" + ) + assert "layers.3.lora_B.weight" in lora_state_dict, ( + "layers.3.lora_B.weight not found" + ) + + assert lora_state_dict["layers.0.lora_A.weight"].shape == (2, 4), ( + f"Expected layers.0.lora_A.weight shape (2, 4), got {lora_state_dict['layers.0.lora_A.weight'].shape}" + ) + assert lora_state_dict["layers.0.lora_B.weight"].shape == (4, 2), ( + f"Expected layers.0.lora_B.weight shape (4, 2), got {lora_state_dict['layers.0.lora_B.weight'].shape}" + ) + + # For layers.3: Linear(4, 1) with dim=2 + # lora_A: (dim, in_features) = (2, 4) + # lora_B: (out_features, dim) = (1, 2) + assert lora_state_dict["layers.3.lora_A.weight"].shape == (2, 4), ( + f"Expected layers.3.lora_A.weight shape (2, 4), got {lora_state_dict['layers.3.lora_A.weight'].shape}" + ) + assert lora_state_dict["layers.3.lora_B.weight"].shape == (1, 2), ( + f"Expected layers.3.lora_B.weight shape (1, 2), got {lora_state_dict['layers.3.lora_B.weight'].shape}" + ) + + initial_distribute = torch.distributed.is_initialized() + print(f"Initial distribute: {initial_distribute}") + + with TemporaryDirectory() as tmp_dir: + weights_path = os.path.join(tmp_dir, "test_model") + save_checkpoint( + model=test_model, + weights_path=weights_path, + model_save_format="safetensors", + is_peft=True, + peft_config=lora_config, + ) + + # Verify files are created + assert os.path.exists(weights_path) + files = os.listdir(os.path.join(weights_path, "model")) + assert any(f.endswith(".safetensors") for f in files) + + # Create a new model with different weights + new_model = TestModel() + new_model.apply_lora(lora_config) + # Initialize with different values + for param in new_model.parameters(): + param.data.fill_(999.0) + + # Load the checkpoint for peft need distributed(refer to nemo_automodel/components/checkpoint/stateful_wrappers.py:load_state_dict) + load_checkpoint(model=new_model, weights_path=weights_path) + # peft only save lora weights, so we need to filter out the non-lora weights + lora_params_original = { + k: v for k, v in lora_state_dict.items() if "lora" in k + } + lora_params_loaded = { + k: v for k, v in new_model.state_dict().items() if "lora" in k + } + # Verify the weights match the original + check_dict_equality(lora_params_loaded, lora_params_original) From fbe5d582ce3fe9d2a364b7ce1d6fd88a9c4478a2 Mon Sep 17 00:00:00 2001 From: ruit Date: Sun, 30 Nov 2025 00:35:56 -0800 Subject: [PATCH 06/10] refactor: update .pre-commit-config.yaml to enable minimize-check hooks for llm and vlm recipes; remove unused sft-llama3.1-8b-1n8g-dtensor-lora configuration and related test scripts; fix tokenizer model path in unit tests Signed-off-by: ruit --- .gitignore | 2 +- .pre-commit-config.yaml | 50 ++++++++--------- .../sft-llama3.1-8b-1n8g-dtensor-lora.yaml | 53 ------------------- .../llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v3.yaml | 10 ---- .../llm/sft-llama3.1-8b-1n8g-dtensor-lora.sh | 43 --------------- tests/test_suites/nightly.txt | 2 - tests/unit/conftest.py | 6 +-- 7 files changed, 28 insertions(+), 138 deletions(-) delete mode 100644 examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-dtensor-lora.yaml delete mode 100644 tests/test_suites/llm/sft-llama3.1-8b-1n8g-dtensor-lora.sh diff --git a/.gitignore 
b/.gitignore index 96e5afddec..5d5611d1c2 100644 --- a/.gitignore +++ b/.gitignore @@ -45,4 +45,4 @@ code_snapshots*/ # Runtime env *runtime_env.yaml -!default_runtime_env.yaml \ No newline at end of file +!default_runtime_env.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ee863c39fe..879c913a6e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -63,28 +63,28 @@ repos: # # If this check is disruptive, you can disable the pre-commit hook locally. However, before a recipe # is accepted upstream, we expect the config to be minimized. - # - repo: local - # hooks: - # - id: configs-minimize-check-llm - # name: minimize-check llm recipes - # language: system - # pass_filenames: false - # entry: bash - # args: - # - -lc - # - | - # set -euo pipefail - # base="examples/configs/dpo.yaml"; for f in examples/configs/recipes/llm/dpo-*.yaml; do [ -e "$f" ] && ./tools/config_cli.py minimize-check "$base" "$f"; done - # base="examples/configs/grpo_math_1B.yaml"; for f in examples/configs/recipes/llm/grpo-*.yaml; do [ -e "$f" ] && ./tools/config_cli.py minimize-check "$base" "$f"; done - # base="examples/configs/sft.yaml"; for f in examples/configs/recipes/llm/sft-*.yaml; do [ -e "$f" ] && ./tools/config_cli.py minimize-check "$base" "$f"; done - # base="examples/configs/distillation_math.yaml"; for f in examples/configs/recipes/llm/distillation-*.yaml; do [ -e "$f" ] && ./tools/config_cli.py minimize-check "$base" "$f"; done - # - id: configs-minimize-check-vlm - # name: minimize-check vlm recipes - # language: system - # pass_filenames: false - # entry: bash - # args: - # - -lc - # - | - # set -euo pipefail - # base="examples/configs/vlm_grpo_3B.yaml"; for f in examples/configs/recipes/vlm/vlm_grpo-*.yaml; do [ -e "$f" ] && ./tools/config_cli.py minimize-check "$base" "$f"; done + - repo: local + hooks: + - id: configs-minimize-check-llm + name: minimize-check llm recipes + language: system + pass_filenames: false + entry: bash + args: + - -lc + - | + set -euo pipefail + base="examples/configs/dpo.yaml"; for f in examples/configs/recipes/llm/dpo-*.yaml; do [ -e "$f" ] && ./tools/config_cli.py minimize-check "$base" "$f"; done + base="examples/configs/grpo_math_1B.yaml"; for f in examples/configs/recipes/llm/grpo-*.yaml; do [ -e "$f" ] && ./tools/config_cli.py minimize-check "$base" "$f"; done + base="examples/configs/sft.yaml"; for f in examples/configs/recipes/llm/sft-*.yaml; do [ -e "$f" ] && ./tools/config_cli.py minimize-check "$base" "$f"; done + base="examples/configs/distillation_math.yaml"; for f in examples/configs/recipes/llm/distillation-*.yaml; do [ -e "$f" ] && ./tools/config_cli.py minimize-check "$base" "$f"; done + - id: configs-minimize-check-vlm + name: minimize-check vlm recipes + language: system + pass_filenames: false + entry: bash + args: + - -lc + - | + set -euo pipefail + base="examples/configs/vlm_grpo_3B.yaml"; for f in examples/configs/recipes/vlm/vlm_grpo-*.yaml; do [ -e "$f" ] && ./tools/config_cli.py minimize-check "$base" "$f"; done diff --git a/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-dtensor-lora.yaml b/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-dtensor-lora.yaml deleted file mode 100644 index 8e88441306..0000000000 --- a/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-dtensor-lora.yaml +++ /dev/null @@ -1,53 +0,0 @@ -defaults: ../../sft.yaml -sft: - max_num_steps: 350 - val_period: 20 - val_global_batch_size: 128 - val_micro_batch_size: 2 - val_batches: 8 -checkpointing: - checkpoint_dir: 
results/sft-tmblog-llama3.1-8b - save_period: 20 -policy: - model_name: meta-llama/Llama-3.1-8B - tokenizer: - name: meta-llama/Llama-3.1-8B-Instruct - chat_template: default - train_global_batch_size: 128 - train_micro_batch_size: 1 - max_total_sequence_length: 4096 - precision: "bfloat16" - dtensor_cfg: - tensor_parallel_size: 1 - _v2: true - lora: - enabled: true - target_modules: [] # match all linear modules takes precendence - exclude_modules: [] - match_all_linear: true - dim: 128 - alpha: 32 - dropout: 0.0 - dropout_position: "post" - lora_A_init: "xavier" - use_triton: true - make_sequence_length_divisible_by: 2 - optimizer: - kwargs: - lr: 2.0e-05 - weight_decay: 0.01 - eps: 1.0e-08 -data: - dataset_name: tulu3 - add_generation_prompt: true - seed: 42 -logger: - log_dir: logs/sft-tmblog-llama3.1-8b - tensorboard_enabled: false - wandb: - project: nemo-rl - name: sft-tmblog-llama3.1-8b - tensorboard: - log_dir: tb_logs-sft-dev-tulu3 -cluster: - gpus_per_node: 8 diff --git a/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v3.yaml b/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v3.yaml index bf34055084..6ce2cb2767 100644 --- a/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v3.yaml +++ b/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v3.yaml @@ -6,18 +6,8 @@ checkpointing: save_period: 100 policy: dtensor_cfg: - _v2: true lora: - enabled: false - target_modules: [] # match all linear modules takes precendence - exclude_modules: [] - match_all_linear: true dim: 32 - alpha: 32 - dropout: 0.0 - dropout_position: "post" - lora_A_init: "xavier" - use_triton: true tokenizer: name: meta-llama/Llama-3.2-1B make_sequence_length_divisible_by: 1 diff --git a/tests/test_suites/llm/sft-llama3.1-8b-1n8g-dtensor-lora.sh b/tests/test_suites/llm/sft-llama3.1-8b-1n8g-dtensor-lora.sh deleted file mode 100644 index 2d6829f314..0000000000 --- a/tests/test_suites/llm/sft-llama3.1-8b-1n8g-dtensor-lora.sh +++ /dev/null @@ -1,43 +0,0 @@ -#!/bin/bash -SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) -source $SCRIPT_DIR/common.env - -# ===== BEGIN CONFIG ===== -NUM_NODES=1 -STEPS_PER_RUN=50 -MAX_STEPS=50 -NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up -NUM_MINUTES=30 -# ===== END CONFIG ===== - -exit_if_max_steps_reached - -# Run the experiment -cd $PROJECT_ROOT -uv run examples/run_sft.py \ - --config $CONFIG_PATH \ - sft.max_num_steps=$MAX_STEPS \ - logger.log_dir=$LOG_DIR \ - logger.wandb_enabled=False \ - logger.wandb.project=nemo-rl \ - logger.wandb.name=$EXP_NAME \ - logger.monitor_gpus=True \ - logger.tensorboard_enabled=True \ - checkpointing.enabled=True \ - checkpointing.checkpoint_dir=$CKPT_DIR \ - $@ \ - 2>&1 | tee $RUN_LOG - -# Convert tensorboard logs to json -uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS - -# TODO: memory check will fail due to OOM tracked here https://github.com/NVIDIA-NeMo/RL/issues/263 - -# Only run metrics if the target step is reached -if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then - uv run tests/check_metrics.py $JSON_METRICS \ - 'data["train/loss"]["1"] < 1.0' \ - 'data["train/loss"]["50"] < 0.8' \ - 'max(data["ray/node.0.gpu.0.mem_gb"]) < 50' \ - 'mean(data["timing/train/total_step_time"], 2) < 10' -fi diff --git a/tests/test_suites/nightly.txt b/tests/test_suites/nightly.txt index d54ffc73b5..2a8249cf5f 100644 --- a/tests/test_suites/nightly.txt +++ 
b/tests/test_suites/nightly.txt @@ -66,8 +66,6 @@ tests/test_suites/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v3.sh tests/test_suites/llm/sft-llama3.1-8b-1n8g-fsdp2tp2.sh # dynamic batching tests/test_suites/llm/sft-llama3.1-8b-1n8g-fsdp2tp1-dynamicbatch.sh -# Lora -tests/test_suites/llm/sft-llama3.1-8b-1n8g-dtensor-lora.sh # Functional 32b test tests/test_suites/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v3.sh diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 1932e4b5e0..ab3368185c 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -516,8 +516,7 @@ def tiny_llama_model_path(): num_key_value_heads=None, ) model = LlamaForCausalLM(config=config) - tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-1B") - # tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B") + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B") shutil.rmtree(model_path, ignore_errors=True) model.save_pretrained(model_path) tokenizer.save_pretrained(model_path) @@ -547,8 +546,7 @@ def tiny_llama_tied_model_path(): num_key_value_heads=None, ) model = LlamaForCausalLM(config=config) - tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-1B") - # tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B") + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B") shutil.rmtree(model_path, ignore_errors=True) model.save_pretrained(model_path) tokenizer.save_pretrained(model_path) From f3ab0ac69aabf3fefe59d8419df0139c4651888c Mon Sep 17 00:00:00 2001 From: ruit Date: Sun, 30 Nov 2025 01:24:56 -0800 Subject: [PATCH 07/10] remove unit test param Signed-off-by: ruit --- tests/unit/models/policy/test_dtensor_worker.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/unit/models/policy/test_dtensor_worker.py b/tests/unit/models/policy/test_dtensor_worker.py index 2ddb79f0fe..99ec9bcfff 100644 --- a/tests/unit/models/policy/test_dtensor_worker.py +++ b/tests/unit/models/policy/test_dtensor_worker.py @@ -421,7 +421,6 @@ def logprob_setup(request, two_gpu_virtual_cluster): params=[ # model_name, target_modules, exclude_modules, match_all_linear, dim, alpha, dropout, dropout_position, lora_A_init, use_triton ("tiny_llama_model_path", [], [], True, 16, 32, 0.0, "post", "xavier", True), - ("tiny_llama_model_path", [], [], True, 128, 32, 0.0, "post", "xavier", True), ("tiny_qwen2_model_path", [], [], True, 32, 32, 0.0, "pre", "xavier", True), ( "tiny_qwen2_model_path", @@ -475,7 +474,6 @@ def training_with_lora_setup(request, two_gpu_virtual_cluster): params=[ # model_name, target_modules, exclude_modules, match_all_linear, dim, alpha, dropout, dropout_position, lora_A_init, use_triton ("tiny_llama_model_path", [], [], True, 16, 32, 0.0, "post", "xavier", True), - ("tiny_llama_model_path", [], [], True, 128, 32, 0.0, "post", "xavier", True), ("tiny_qwen2_model_path", [], [], True, 32, 32, 0.0, "pre", "xavier", True), ( "tiny_qwen2_model_path", From 1faf60519c8ee8219f9298883ce10c624c568583 Mon Sep 17 00:00:00 2001 From: ruit Date: Sun, 30 Nov 2025 19:09:18 -0800 Subject: [PATCH 08/10] fix: update LoRA weight initialization method in DTensorPolicyWorkerV2; adjust return value for refit_info to only include weights Signed-off-by: ruit --- .../models/policy/workers/dtensor_policy_worker_v2.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index 8a6ad86d6c..5323d85670 100644 --- 
a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -21,6 +21,7 @@ from contextlib import AbstractContextManager, contextmanager, nullcontext from typing import Any, Generator, Optional, cast +import nemo_automodel.components._peft.lora as _lora_mod import ray import torch from accelerate import init_empty_weights @@ -98,10 +99,10 @@ from nemo_rl.utils.packed_tensor import packed_broadcast_producer +# TODO: @ruit remove this once the bump Automodel to 2d20e33a19d5e53a271b1403b507475e68ad14dc def _patched_init_lora_weights(self, init_method: str): if init_method == "xavier": nn.init.xavier_normal_(self.lora_A.weight.data) - print("Initialized LoRA weights with patched xavier initialization") else: nn.init.kaiming_uniform_(self.lora_A.weight.data, a=math.sqrt(5)) self.lora_B.weight.data.zero_() @@ -242,7 +243,7 @@ def __init__( self.peft_config = None self.lora_enabled = lora_cfg is not None and lora_cfg["enabled"] # patch the init_lora_weights method to use the xavier initialization - # _lora_mod.LinearLoRA.init_lora_weights = _patched_init_lora_weights + _lora_mod.LinearLoRA.init_lora_weights = _patched_init_lora_weights if self.lora_enabled: # Always use float32 since FSDP requires all parameters to be in the same dtype. # autocast should cast the weights to the correct dtype during the forward pass. @@ -1692,8 +1693,8 @@ def prepare_refit_info(self) -> Optional[dict[str, Any]]: else None, "lora_weights": lora_weight_names if self.lora_enabled else None, } - - return refit_info + # Lora have not fully supported in DTensorPolicyWorkerV2 yet, so we only return the weights + return refit_info["weights"] @torch.no_grad() def calibrate_qkv_fp8_scales( From 63c3244d2d7e14bcb96b4809d38e3fb18ed26d03 Mon Sep 17 00:00:00 2001 From: ruit Date: Mon, 1 Dec 2025 21:05:09 -0800 Subject: [PATCH 09/10] remove grpo related code Signed-off-by: ruit --- .../workers/dtensor_policy_worker_v2.py | 61 +++---------------- 1 file changed, 8 insertions(+), 53 deletions(-) diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index 5323d85670..a77e2d2293 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -99,7 +99,7 @@ from nemo_rl.utils.packed_tensor import packed_broadcast_producer -# TODO: @ruit remove this once the bump Automodel to 2d20e33a19d5e53a271b1403b507475e68ad14dc +# TODO: @ruit remove this once the bump Automodel to 2d20e33a19d5e53a271b1403b507475e68ad14dc (https://github.com/NVIDIA-NeMo/RL/issues/1586) def _patched_init_lora_weights(self, init_method: str): if init_method == "xavier": nn.init.xavier_normal_(self.lora_A.weight.data) @@ -1659,42 +1659,13 @@ def return_model_config(self) -> dict[str, Any]: @torch.no_grad() def prepare_refit_info(self) -> Optional[dict[str, Any]]: - """Prepare state dict metadata for weight refitting and IPC streaming. 
- - Returns: - dict containing: - - 'weights': dict mapping weight names to (shape, dtype) tuples - - 'lora_enabled': bool indicating if LoRA is enabled - - 'lora_config': optional PeftConfig if LoRA is enabled - - 'lora_weights': list of LoRA weight names (when LoRA is enabled) - """ + """Prepare state dict metadata for weight refitting and IPC streaming.""" state_dict_info = {} - lora_weight_names = [] + for name, tensor in self.model.state_dict().items(): + # all tensor will be casted to self.dtype in stream_weights_via_ipc_zmq/broadcast_weights_for_collective + state_dict_info[name] = (tensor.shape, self.dtype) - # Determine which weights to include based on LoRA status - if self.lora_enabled: - # Only include LoRA weights when LoRA is enabled - for name, tensor in self.model.state_dict().items(): - if self._is_lora_weight(name): - # all tensor will be casted to self.dtype in stream_weights_via_ipc_zmq/broadcast_weights_for_collective - state_dict_info[name] = (tensor.shape, self.dtype) - lora_weight_names.append(name) - else: - # Include all weights when LoRA is not enabled - for name, tensor in self.model.state_dict().items(): - # all tensor will be casted to self.dtype in stream_weights_via_ipc_zmq/broadcast_weights_for_collective - state_dict_info[name] = (tensor.shape, self.dtype) - - refit_info = { - "weights": state_dict_info, - "lora_enabled": self.lora_enabled, - "lora_config": self.peft_config.to_dict() - if self.lora_enabled and self.peft_config - else None, - "lora_weights": lora_weight_names if self.lora_enabled else None, - } - # Lora have not fully supported in DTensorPolicyWorkerV2 yet, so we only return the weights - return refit_info["weights"] + return state_dict_info @torch.no_grad() def calibrate_qkv_fp8_scales( @@ -1731,15 +1702,8 @@ def stream_weights_via_ipc_zmq( from nemo_rl.models.policy.utils import stream_weights_via_ipc_zmq_impl def dtensor_params_generator(): - """Generator that yields (name, tensor) pairs, converting DTensors to local tensors. - - Only yields LoRA weights when LoRA is enabled, otherwise yields all weights. 
- """ + """Generator that yields (name, tensor) pairs, converting DTensors to local tensors.""" for name, tensor in self.model.state_dict().items(): - # Skip non-LoRA weights if LoRA is enabled - if self.lora_enabled and not self._is_lora_weight(name): - continue - if isinstance(tensor, DTensor): # Convert DTensor to full tensor for streaming full_tensor = tensor.full_tensor() @@ -1788,17 +1752,8 @@ def _dtensor_post_iter_func(tensor, dtype): # param_iterator will return (name, tensor), we only need tensor dtensor_post_iter_func = lambda x: _dtensor_post_iter_func(x[1], self.dtype) - # Filter state dict to only include LoRA weights if LoRA is enabled - def _filtered_state_dict_iterator(): - """Iterator that yields only LoRA weights when LoRA is enabled.""" - for name, tensor in self.model.state_dict().items(): - # Skip non-LoRA weights if LoRA is enabled - if self.lora_enabled and not self._is_lora_weight(name): - continue - yield (name, tensor) - packed_broadcast_producer( - iterator=_filtered_state_dict_iterator(), + iterator=iter(self.model.state_dict().items()), group=self.model_update_group, src=0, post_iter_func=dtensor_post_iter_func, From 6db171922ad295b9e05c7849df544f954c7879d2 Mon Sep 17 00:00:00 2001 From: ruit Date: Sun, 7 Dec 2025 06:03:53 -0800 Subject: [PATCH 10/10] feat: add LoRA configuration support for parameter-efficient fine-tuning; update related examples and documentation Signed-off-by: ruit --- docs/guides/sft.md | 42 +++ .../llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v3.yaml | 6 +- examples/configs/sft.yaml | 25 +- nemo_rl/models/policy/__init__.py | 32 +- .../workers/dtensor_policy_worker_v2.py | 6 +- tests/functional/test_automodel_lora_sft.sh | 2 +- tests/unit/models/dtensor/test_lora.py | 333 ++++++++++++++++++ tests/unit/utils/test_automodel_checkpoint.py | 3 - 8 files changed, 415 insertions(+), 34 deletions(-) create mode 100644 tests/unit/models/dtensor/test_lora.py diff --git a/docs/guides/sft.md b/docs/guides/sft.md index 982052f074..e4e4915626 100644 --- a/docs/guides/sft.md +++ b/docs/guides/sft.md @@ -161,3 +161,45 @@ As long as your custom dataset has the `formatted_ds` and `task_spec` attributes ## Evaluate the Trained Model Upon completion of the training process, you can refer to our [evaluation guide](eval.md) to assess model capabilities. + + +## LoRA Configuration + +NeMo RL supports LoRA (Low-Rank Adaptation) for parameter-efficient fine-tuning. LoRA reduces trainable parameters by using low-rank matrices for weight updates while keeping the base model frozen. + +### Configuration Parameters + +The LoRA configuration is specified under the `policy.lora_cfg` section: + +policy: + lora_cfg: + enabled: False # Set to True to enable LoRA fine-tuning + target_modules: [] # List of module names to apply LoRA + exclude_modules: [] # List of module names to exclude from LoRA + match_all_linear: true # Apply LoRA to all linear layers + dim: 8 # LoRA rank (r): controls adaptation capacity + alpha: 32 # LoRA scaling factor (effective lr = alpha/dim) + dropout: 0.0 # Dropout probability for LoRA layers + dropout_position: "post" # Dropout position: "pre" or "post" + lora_A_init: "xavier" # Initialization method: "xavier" or "uniform" + use_triton: true # Use Triton-optimized kernels + +### Parameter Details +- **`enabled`** (bool): Whether to enable LoRA training +- **`target_modules`** (list): Specific module names to apply LoRA. 
Empty with `match_all_linear=true` applies to all linear layers +- **`exclude_modules`** (list): Module names to exclude from LoRA +- **`match_all_linear`** (bool): When `true`, applies LoRA to all linear layers (overrides `target_modules`) +- **`dim`** (int): LoRA rank (r). Lower values = fewer parameters but less capacity. Typical: 4, 8, 16, 32, 64 +- **`alpha`** (int): LoRA scaling factor. Effective learning rate multiplier = `alpha/dim`. Typical: 16, 32, 64 +- **`dropout`** (float): Dropout probability for regularization +- **`dropout_position`** (str): Apply dropout before ("pre") or after ("post") LoRA +- **`lora_A_init`** (str): Initialization method for LoRA A matrix +- **`use_triton`** (bool): Use Triton-optimized kernels for better performance. Used for dtensor v2 only. **Note**: [Automodel not support triton for TP > 1](https://github.com/NVIDIA-NeMo/Automodel/blob/b2db55eee98dfe81a8bfe5e23ac4e57afd8ab261/nemo_automodel/recipes/llm/train_ft.py#L199). Set to `false` when `tensor_parallel_size > 1` to avoid compatibility issues + +### Example Usage + +```bash +uv run examples/run_sft.py policy.lora_cfg.enabled=True +``` + +For more details on LoRA, see [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685). \ No newline at end of file diff --git a/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v3.yaml b/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v3.yaml index 6ce2cb2767..82950cf153 100644 --- a/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v3.yaml +++ b/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v3.yaml @@ -5,12 +5,12 @@ checkpointing: checkpoint_dir: results/sft-llama3.2-1b-1n8g-fsdp2tp1 save_period: 100 policy: - dtensor_cfg: - lora: - dim: 32 tokenizer: name: meta-llama/Llama-3.2-1B make_sequence_length_divisible_by: 1 + lora_cfg: + enabled: true + dim: 32 data: dataset_name: openmathinstruct2 prompt_file: examples/prompts/math.txt diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index 4d97d1f6dd..9316ecccba 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -45,18 +45,19 @@ policy: tensor_parallel_size: 1 context_parallel_size: 1 custom_parallel_plan: null - lora: - enabled: false - target_modules: [] # match all linear modules takes precendence - exclude_modules: [] - match_all_linear: true - dim: 8 - alpha: 32 - dropout: 0.0 - dropout_position: "post" - lora_A_init: "xavier" - lora_dtype: ${policy.precision} - use_triton: true + + # LoRA (Low-Rank Adaptation) Configuration + lora_cfg: + enabled: False # Set to True to enable LoRA fine-tuning + target_modules: [] # List of module names to apply LoRA (empty list with match_all_linear=true applies to all linear layers) + exclude_modules: [] # List of module names to exclude from LoRA + match_all_linear: true # If True, applies LoRA to all linear layers (overrides target_modules) + dim: 8 # LoRA rank (r): lower rank = fewer parameters but less capacity. Typical values: 4, 8, 16, 32, 64 + alpha: 32 # LoRA scaling factor: effective learning rate multiplier = alpha/dim. Typical values: 16, 32, 64 + dropout: 0.0 # Dropout probability applied to LoRA layers (0.0 = no dropout) + dropout_position: "post" # Where to apply dropout: "pre" (before LoRA) or "post" (after LoRA) + lora_A_init: "xavier" # Initialization method for LoRA A matrix: "xavier" or "uniform" + use_triton: true # Use Triton-optimized kernels for LoRA (faster but requires flash-attn). 
Disable when tensor_parallel_size > 1 dynamic_batching: enabled: false diff --git a/nemo_rl/models/policy/__init__.py b/nemo_rl/models/policy/__init__.py index c1ea023b01..aa3dad5203 100644 --- a/nemo_rl/models/policy/__init__.py +++ b/nemo_rl/models/policy/__init__.py @@ -21,19 +21,6 @@ class DTensorConfigDisabled(TypedDict): enabled: Literal[False] -class LoRAConfig(TypedDict): - enabled: bool - target_modules: NotRequired[list[str]] - exclude_modules: NotRequired[list[str]] - match_all_linear: NotRequired[bool] - dim: NotRequired[int] - alpha: NotRequired[int] - dropout: NotRequired[float] - dropout_position: NotRequired[Literal["pre", "post"]] - lora_A_init: NotRequired[str] - use_triton: NotRequired[bool] - - class DTensorConfig(TypedDict): enabled: Literal[True] env_vars: NotRequired[dict[str, str] | None] @@ -45,7 +32,6 @@ class DTensorConfig(TypedDict): context_parallel_size: int custom_parallel_plan: str | None clear_cache_every_n_steps: NotRequired[int | None] - lora: NotRequired[LoRAConfig | None] class SequencePackingConfigDisabled(TypedDict): @@ -145,6 +131,23 @@ class MegatronConfig(TypedDict): distributed_data_parallel_config: MegatronDDPConfig +class LoRAConfigDisabled(TypedDict): + enabled: Literal[False] + + +class LoRAConfig(TypedDict): + enabled: Literal[True] + target_modules: list[str] + exclude_modules: list[str] + match_all_linear: bool + dim: int + alpha: int + dropout: float + dropout_position: Literal["pre", "post"] + lora_A_init: str + use_triton: bool + + class TokenizerConfig(TypedDict): name: str chat_template: NotRequired[str] @@ -203,6 +206,7 @@ class PolicyConfig(TypedDict): reward_model_cfg: NotRequired[RewardModelConfig] dtensor_cfg: DTensorConfig | DTensorConfigDisabled megatron_cfg: NotRequired[MegatronConfig | MegatronConfigDisabled] + lora_cfg: NotRequired[LoRAConfig | LoRAConfigDisabled] hf_config_overrides: NotRequired[dict[str, Any]] dynamic_batching: DynamicBatchingConfig | DynamicBatchingConfigDisabled sequence_packing: NotRequired[SequencePackingConfig | SequencePackingConfigDisabled] diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index a77e2d2293..4b1202208c 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -239,12 +239,16 @@ def __init__( model_state_dict_keys = None # lora config - lora_cfg = self.cfg["dtensor_cfg"].get("lora", None) + lora_cfg = self.cfg.get("lora_cfg", None) self.peft_config = None self.lora_enabled = lora_cfg is not None and lora_cfg["enabled"] # patch the init_lora_weights method to use the xavier initialization _lora_mod.LinearLoRA.init_lora_weights = _patched_init_lora_weights if self.lora_enabled: + if self.cfg["dtensor_cfg"]["tensor_parallel_size"] > 1: + assert not lora_cfg["use_triton"], ( + "Triton is not supported when tensor_parallel_size > 1" + ) # Always use float32 since FSDP requires all parameters to be in the same dtype. # autocast should cast the weights to the correct dtype during the forward pass. 
             cfg_dict_with_dtype = {**lora_cfg, "lora_dtype": "torch.float32"}
diff --git a/tests/functional/test_automodel_lora_sft.sh b/tests/functional/test_automodel_lora_sft.sh
index cd8fc5b057..b2baf88170 100644
--- a/tests/functional/test_automodel_lora_sft.sh
+++ b/tests/functional/test_automodel_lora_sft.sh
@@ -36,7 +36,7 @@ uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJE
     checkpointing.enabled=true \
     checkpointing.save_period=3 \
     checkpointing.checkpoint_dir=/tmp/lora_sft_checkpoints \
-    $@ \
+    "$@" \
     2>&1 | tee $RUN_LOG
 
 uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS
diff --git a/tests/unit/models/dtensor/test_lora.py b/tests/unit/models/dtensor/test_lora.py
new file mode 100644
index 0000000000..eac60eb6fe
--- /dev/null
+++ b/tests/unit/models/dtensor/test_lora.py
@@ -0,0 +1,333 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import math
+
+import pytest
+
+# Skip entire module if nemo_automodel is not available
+pytest_plugins = []
+try:
+    import nemo_automodel  # noqa: F401
+except ImportError:
+    pytest.skip("nemo_automodel not available", allow_module_level=True)
+
+import torch
+import torch.nn as nn
+from nemo_automodel.components._peft.lora import (
+    LinearLoRA,
+    PeftConfig,
+    apply_lora_to_linear_modules,
+)
+
+from nemo_rl.models.policy.workers.dtensor_policy_worker_v2 import (
+    _patched_init_lora_weights,
+)
+
+
+class SimpleLoraMock(nn.Module):
+    """Simple mock LoRA module for testing initialization."""
+
+    def __init__(self, in_features=128, out_features=256, lora_dim=8):
+        super().__init__()
+        self.lora_A = nn.Linear(in_features, lora_dim, bias=False)
+        self.lora_B = nn.Linear(lora_dim, out_features, bias=False)
+
+
+@pytest.mark.parametrize("init_method", ["xavier"])
+def test_lora_init_differs_from_upstream_buggy_version(init_method):
+    """
+    Test that our patched LoRA initialization differs from the buggy upstream version.
+
+    Remove this test once Automodel is bumped to commit 2d20e33a19d5e53a271b1403b507475e68ad14dc or later.
+
+    Issue: https://github.com/NVIDIA-NeMo/RL/issues/1586
+    """
+    torch.manual_seed(42)
+
+    # Create two identical LoRA modules
+    lora_buggy = LinearLoRA(nn.Linear(16, 16))
+    lora_patched = LinearLoRA(nn.Linear(16, 16))
+
+    # Copy initial weights to ensure identical starting point
+    lora_patched.lora_A.weight.data.copy_(lora_buggy.lora_A.weight.data)
+    lora_patched.lora_B.weight.data.copy_(lora_buggy.lora_B.weight.data)
+
+    # Apply buggy upstream initialization
+    torch.manual_seed(42)
+    lora_buggy.init_lora_weights(init_method)
+
+    # Apply our patched initialization
+    torch.manual_seed(42)
+    _patched_init_lora_weights(lora_patched, init_method)
+
+    # For xavier method, they should differ (that's the bug)
+
+    # Assert that weights differ due to the upstream bug
+    are_equal_A = torch.allclose(
+        lora_buggy.lora_A.weight.data,
+        lora_patched.lora_A.weight.data,
+        atol=1e-6,
+        rtol=1e-6,
+    )
+
+    assert not are_equal_A, (
+        "LoRA A weights should differ for xavier initialization. "
+        "If this assertion fails, the upstream bug has been fixed in Automodel. "
+        "You can:\n"
+        "1. Remove the patch in nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py\n"
+        "2. Remove the patching call\n"
+        "3. Close issue: https://github.com/NVIDIA-NeMo/RL/issues/1586\n"
+        "4. Delete this test"
+    )
+
+    # LoRA B should always be zero-initialized (both implementations do this correctly)
+    are_equal_B = torch.allclose(
+        lora_buggy.lora_B.weight.data,
+        lora_patched.lora_B.weight.data,
+        atol=0,
+        rtol=0,
+    )
+    assert are_equal_B, "LoRA B weights should both be zero"
+    assert torch.all(lora_buggy.lora_B.weight.data == 0), (
+        "LoRA B should be zero-initialized"
+    )
+    assert torch.all(lora_patched.lora_B.weight.data == 0), (
+        "LoRA B should be zero-initialized"
+    )
+
+
+def test_lora_init_statistical_properties():
+    """
+    Additional test to verify the statistical properties of the patched initialization.
+    This ensures our fix produces reasonable weight distributions.
+    """
+    torch.manual_seed(42)
+
+    lora = SimpleLoraMock(in_features=512, out_features=1024, lora_dim=32)
+
+    # Test xavier initialization
+    _patched_init_lora_weights(lora, "xavier")
+
+    # Xavier normal should have mean ≈ 0 and specific std
+    mean_A = lora.lora_A.weight.data.mean().item()
+    std_A = lora.lora_A.weight.data.std().item()
+
+    assert abs(mean_A) < 0.1, f"Xavier normal should have mean ≈ 0, got {mean_A}"
+    # Xavier normal std = sqrt(2 / (fan_in + fan_out))
+    expected_std = math.sqrt(2.0 / (512 + 32))
+    assert abs(std_A - expected_std) < 0.05, (
+        f"Xavier normal std should be ≈ {expected_std}, got {std_A}"
+    )
+
+    # LoRA B should be all zeros
+    assert torch.all(lora.lora_B.weight.data == 0), "LoRA B should be zero-initialized"
+
+    # Test kaiming initialization
+    lora2 = SimpleLoraMock(in_features=512, out_features=1024, lora_dim=32)
+    _patched_init_lora_weights(lora2, "kaiming")
+
+    mean_A2 = lora2.lora_A.weight.data.mean().item()
+    assert abs(mean_A2) < 0.1, f"Kaiming should have mean ≈ 0, got {mean_A2}"
+    assert torch.all(lora2.lora_B.weight.data == 0), "LoRA B should be zero-initialized"
+
+
+class DummyModel(nn.Module):
+    """A dummy neural network model with two linear layers used for testing LoRA injection."""
+
+    def __init__(self):
+        super().__init__()
+        self.linear1 = nn.Linear(16, 16)
+        self.linear2 = nn.Linear(16, 16)
+        self.config = {}
+
+    def forward(self, x):
+        """Forward pass through two linear layers with ReLU activation in between."""
+        x = self.linear1(x).relu()
+        x = self.linear2(x)
+        return x
+
+
+class DummyModelNoConfig(nn.Module):
+    """Same as DummyModel but without a `config` attribute."""
+
+    def __init__(self):
+        super().__init__()
+        self.linear1 = nn.Linear(16, 16)
+        self.linear2 = nn.Linear(16, 16)
+
+    def forward(self, x):
+        x = self.linear1(x).relu()
+        x = self.linear2(x)
+        return x
+
+
+@pytest.fixture
+def dummy_input():
+    """Provides a dummy input tensor for model testing."""
+    return torch.randn(2, 16, requires_grad=True)
+
+
+@pytest.fixture
+def model():
+    """Instantiates and returns a DummyModel instance."""
+    return DummyModel()
+
+
+@pytest.fixture
+def model_no_config():
+    """Instantiates a model that has no `config` attr."""
+    return DummyModelNoConfig()
+
+
+def test_lora_patch_applies_to_selected_module(model):
+    """Tests that LoRA is only applied to specified target modules."""
+    apply_lora_to_linear_modules(
+        model, PeftConfig(target_modules=["linear1"], dim=4, alpha=8)
+    )
+    assert isinstance(model.linear1, LinearLoRA)
+    assert not isinstance(model.linear2, LinearLoRA)
+
+
+def test_lora_patch_on_model_without_config(model_no_config):
+    """LoRA should still patch correctly even if the model lacks `config`."""
+    apply_lora_to_linear_modules(
+        model_no_config, PeftConfig(target_modules=["linear1"], dim=4, alpha=8)
+    )
+    assert isinstance(model_no_config.linear1, LinearLoRA)
+    assert not isinstance(model_no_config.linear2, LinearLoRA)
+
+
+def test_lora_layers_are_trainable():
+    """Ensures that LoRA layers are trainable while base weights remain frozen."""
+    base = nn.Linear(16, 16)
+    lora = LinearLoRA(base, dim=4, alpha=8)
+
+    assert lora.weight.requires_grad is False
+    assert lora.lora_A.weight.requires_grad
+    assert lora.lora_B.weight.requires_grad
+    if lora.bias is not None:
+        assert lora.bias.requires_grad is False
+
+
+def test_forward_output_consistency(dummy_input):
+    """Verifies that model output shape remains the same after LoRA patching,
+    but values change due to the added LoRA components.
+    """
+    base = DummyModel()
+    model = DummyModel()
+    apply_lora_to_linear_modules(
+        model, PeftConfig(target_modules=["linear1"], dim=4, alpha=8)
+    )
+
+    base.eval()
+    model.eval()
+
+    with torch.no_grad():
+        out1 = base(dummy_input)
+        out2 = model(dummy_input)
+
+    assert out1.shape == out2.shape
+    assert not torch.allclose(out1, out2), "Output should differ due to LoRA injection"
+
+
+def test_backward_pass(dummy_input):
+    """Checks that backpropagation works and gradients are correctly computed
+    when LoRA is applied.
+    """
+    model = DummyModel()
+    apply_lora_to_linear_modules(
+        model, PeftConfig(target_modules=["linear1"], dim=4, alpha=8)
+    )
+    output = model(dummy_input)
+    loss = output.sum()
+    loss.backward()
+
+    grads = [p.grad for p in model.parameters() if p.requires_grad]
+    assert any(g is not None for g in grads), "Some parameters should receive gradients"
+    assert all(torch.isfinite(g).all() for g in grads if g is not None), (
+        "Gradients should be finite"
+    )
+
+
+def test_backward_pass_without_config(dummy_input, model_no_config):
+    """Backward pass must succeed on a model without `config`."""
+    apply_lora_to_linear_modules(
+        model_no_config, PeftConfig(target_modules=["linear1"], dim=4, alpha=8)
+    )
+    out = model_no_config(dummy_input)
+    loss = out.sum()
+    loss.backward()
+
+    grads = [p.grad for p in model_no_config.parameters() if p.requires_grad]
+    assert any(g is not None for g in grads)
+    assert all(torch.isfinite(g).all() for g in grads if g is not None)
+
+
+def test_apply_lora_respects_wildcard(model):
+    """Validates that wildcard matching correctly applies LoRA to all matching modules."""
+    apply_lora_to_linear_modules(
+        model, PeftConfig(target_modules=[".*"], dim=4, alpha=8)
+    )
+    assert isinstance(model.linear1, LinearLoRA)
+    assert isinstance(model.linear2, LinearLoRA)
+
+
+def test_no_patch_on_non_matching_module(model):
+    """Confirms that no modules are patched if target pattern doesn't match any names."""
+    apply_lora_to_linear_modules(
+        model, PeftConfig(target_modules=["nonexistent_module"], dim=4, alpha=8)
+    )
+    assert not isinstance(model.linear1, LinearLoRA)
+    assert not isinstance(model.linear2, LinearLoRA)
+
+
+def test_lora_patch_with_dtype_string(model):
+    """Tests that LoRA can be applied with dtype specified as string."""
+    apply_lora_to_linear_modules(
+        model,
+        PeftConfig(
+            target_modules=["linear1"], dim=4, alpha=8, lora_dtype="torch.bfloat16"
+        ),
+    )
+    assert isinstance(model.linear1, LinearLoRA)
+    assert model.linear1.lora_A.weight.dtype == torch.bfloat16
+    assert model.linear1.lora_B.weight.dtype == torch.bfloat16
+    assert not isinstance(model.linear2, LinearLoRA)
+
+
+def test_dropout_pre_post_effects(dummy_input):
+    """Tests that different dropout positions ('pre' vs 'post') lead to different outputs."""
+    base = nn.Linear(16, 16)
+    lora_pre = LinearLoRA(base, dim=4, alpha=8, dropout=0.5, dropout_position="pre")
+    lora_post = LinearLoRA(base, dim=4, alpha=8, dropout=0.5, dropout_position="post")
+
+    with torch.no_grad():
+        lora_pre.lora_A.weight.uniform_()
+        lora_pre.lora_B.weight.uniform_()
+
+        lora_post.lora_A.weight.copy_(lora_pre.lora_A.weight)
+        lora_post.lora_B.weight.copy_(lora_pre.lora_B.weight)
+
+    lora_pre.train()
+    lora_post.train()
+
+    out_pre = lora_pre(dummy_input)
+    out_post = lora_post(dummy_input)
+
+    assert out_pre.shape == out_post.shape
+    assert not torch.allclose(out_pre, out_post), (
+        "Dropout positions should affect output differently"
+    )
diff --git a/tests/unit/utils/test_automodel_checkpoint.py b/tests/unit/utils/test_automodel_checkpoint.py
index 8ca930b8aa..ba2980908f 100644
--- a/tests/unit/utils/test_automodel_checkpoint.py
+++ b/tests/unit/utils/test_automodel_checkpoint.py
@@ -517,9 +517,6 @@ def test_save_and_load_model_with_lora(
         f"Expected layers.3.lora_B.weight shape (1, 2), got {lora_state_dict['layers.3.lora_B.weight'].shape}"
     )
 
-    initial_distribute = torch.distributed.is_initialized()
-    print(f"Initial distribute: {initial_distribute}")
-
     with TemporaryDirectory() as tmp_dir:
        weights_path = os.path.join(tmp_dir, "test_model")
        save_checkpoint(
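
A short aside on the `alpha`/`dim` scaling described in the new `docs/guides/sft.md` section and the `examples/configs/sft.yaml` comments: the snippet below is a minimal, self-contained sketch of the standard LoRA update from the cited paper (output = W x + (alpha/dim) * B(A(x))), written against plain PyTorch for illustration. The class and variable names (`LoRALinearSketch`, `scaling`) are hypothetical and this is not the Automodel `LinearLoRA` implementation that the patch configures.

```python
import torch
import torch.nn as nn


class LoRALinearSketch(nn.Module):
    """Illustrative LoRA wrapper: frozen base linear plus a scaled low-rank update."""

    def __init__(self, base: nn.Linear, dim: int = 8, alpha: int = 32):
        super().__init__()
        self.base = base
        # The base weights stay frozen; only the low-rank factors are trained.
        for p in self.base.parameters():
            p.requires_grad_(False)
        self.lora_A = nn.Linear(base.in_features, dim, bias=False)   # down-projection
        self.lora_B = nn.Linear(dim, base.out_features, bias=False)  # up-projection
        nn.init.zeros_(self.lora_B.weight)  # zero init: training starts from the base model
        self.scaling = alpha / dim  # the alpha/dim factor the config comments refer to

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scaling * self.lora_B(self.lora_A(x))


if __name__ == "__main__":
    layer = LoRALinearSketch(nn.Linear(16, 16), dim=8, alpha=32)
    print(layer(torch.randn(2, 16)).shape)  # torch.Size([2, 16])
```

Zero-initializing the B matrix mirrors the `lora_B` zero-init assertions in the test file above: at step zero the low-rank update contributes nothing, so training starts from the unmodified base model, and the `alpha/dim` factor then controls how strongly the learned update is mixed in.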