
Fix native RL trainers training mode #12

Open

omercelik wants to merge 1 commit into ARahim3:main from omercelik:omercelik/fix-native-rl-train-mode

Conversation

@omercelik

Summary

This PR fixes native preference-optimization training, which previously ran with models left in inference mode.

The immediate user-facing failure was native DPO crashing on Qwen3.5-style models with:

ValueError: [Primitive::vjp] Not implemented for CustomKernel

The root cause was not the DPO loss itself. The problem was that the native RL trainer paths were building gradients without explicitly switching the model to training mode first.

For Qwen3.5-style models, that matters because the linear-attention path changes behavior based on self.training. When the model stays in inference mode, it can take an inference-only Metal custom-kernel path. That kernel does not provide the VJP needed for gradient computation, so native training fails.

This PR ensures the native gradient-based RL trainer paths enter training mode before calling nn.value_and_grad(...).

Root Cause

In mlx_tune/rl_trainers.py, native RL training extracted the underlying trainable model and immediately built:

loss_and_grad = nn.value_and_grad(actual_model, loss_fn)

but it never called train() first.

That means wrapped models could still be in inference mode when the forward pass used for backpropagation ran.

For Qwen3.5-style models, this is especially important because the model code uses self.training to decide whether to use the inference-oriented kernel path.

In practice, explicitly calling train() before native DPO training avoids the crash and allows a bf16 DPO pilot step to complete.
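The dispatch pattern at the heart of the bug can be illustrated with a toy module (names and behavior here are illustrative only; the real Qwen3.5 linear-attention code and its Metal kernels are more involved):

```python
class ToyLinearAttention:
    """Toy stand-in for a module whose kernel choice is gated on self.training."""

    def __init__(self):
        # like most frameworks' modules, start in inference mode
        self.training = False

    def __call__(self, x):
        if self.training:
            # differentiable path: plain ops with a defined VJP
            return [2.0 * v for v in x]
        # inference-only fast path: stands in for a custom Metal kernel
        # that provides no VJP, so any gradient pass through it fails
        raise ValueError("[Primitive::vjp] Not implemented for CustomKernel")


attn = ToyLinearAttention()
attn.training = True        # what calling train() effectively does
print(attn([1.0, 2.0]))     # [2.0, 4.0]
```

If the flag is never flipped, the inference-only branch is taken during the forward pass that backpropagation traces, which is exactly the reported crash.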

What Changed

Added a small helper in mlx_tune/rl_trainers.py:

  • _set_native_train_mode(model, actual_model)

This calls train() on both:

  • the wrapper object, if it exposes train()
  • the unwrapped underlying model used for gradient computation

The helper is then called before native optimizer/gradient setup in the RL trainer paths that actually use native gradient-based optimization:

  • DPOTrainer._train_native()
  • ORPOTrainer._train_native()
  • KTOTrainer._train_native()
  • SimPOTrainer._train_native()

Why The Fix Covers More Than DPO

Although DPO was the reported failure, the same bug pattern existed in other native RL trainer paths.

ORPO, KTO, and SimPO also:

  • unwrap actual_model
  • construct nn.value_and_grad(...)
  • perform native gradient updates
  • previously did not switch the model into training mode first

Since the issue is structural and identical in those paths, fixing only DPO would leave the same latent bug elsewhere.

I did not expand the fix to GRPO because its native path is materially different and does not currently use the same value_and_grad(...) training loop.

What This PR Does Not Change

This PR is intentionally narrow.

It does not:

  • add any fallback-to-SFT behavior
  • change dataset handling or dataset semantics
  • change DPO/ORPO/KTO/SimPO loss definitions
  • change tokenization, padding, or batching behavior
  • alter GRPO behavior

Tests Added

Added a regression test in tests/test_rl_trainers_integration.py that verifies native RL trainer paths enter the loss function with model.training == True.

This test covers:

  • DPO
  • ORPO
  • KTO
  • SimPO

The test uses patched loss functions to directly observe the training flag seen by the model during native training.
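The observation pattern the test relies on looks roughly like this (names are illustrative; the real test in tests/test_rl_trainers_integration.py patches the actual trainer internals):

```python
def run_native_train_mode_check(trainer_step):
    """Drive a stubbed native training step with a patched loss function
    that records the training flag the model exposes mid-forward."""
    seen = {}

    class StubModel:
        def __init__(self):
            self.training = False

        def train(self):
            self.training = True

    model = StubModel()

    def patched_loss(m, batch):
        # runs where the real loss function would: during the forward pass
        seen["training"] = m.training
        return 0.0

    trainer_step(model, patched_loss)
    return seen.get("training")


def fixed_trainer_step(model, loss_fn):
    # the fixed path: enter training mode before gradient setup
    model.train()
    loss_fn(model, batch=None)


print(run_native_train_mode_check(fixed_trainer_step))  # True
```

Running the same check against a step that skips `train()` returns False, which is how the regression test would catch a reintroduction of the bug.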

Validation

Ran the full pytest suite after the fix.

Result:

  • 504 passed
  • 1 skipped
  • 12 deselected

The focused RL regression test also passed, and the full RL integration test file passed as part of the full suite.

Manual Smoke Validation

In addition to the automated test coverage, I manually validated the reported failure mode against a real Qwen3.5 bf16 native DPO setup.

Smoke setup:

  • model: mlx-community/Qwen3.5-4B-MLX-bf16
  • trainer: native DPOTrainer
  • data: 8 preference pairs from a prepared DPO dataset
  • config: LoRA adapters, max_steps=1

Result:

  • native DPO completed successfully
  • observed training log included:
    • Step 1/1 | Loss: 0.6914
    • DPO Training Complete!

This is the same code path that previously failed with:

ValueError: [Primitive::vjp] Not implemented for CustomKernel

After the fix, that smoke run no longer reproduces the failure.

Files Changed

  • mlx_tune/rl_trainers.py
  • tests/test_rl_trainers_integration.py

Rationale

The correct fix is to ensure native training code actually runs in training mode before gradient computation.

That matches expected ML training semantics, matches how native SFT-style training is already handled elsewhere, and avoids inference-only code paths being used during backpropagation.

@ARahim3
Owner

ARahim3 commented Mar 27, 2026

Hi @omercelik, thanks for the detailed PR and the thorough write-up!

This is a well-identified issue — the missing train() call before nn.value_and_grad() in the native RL paths makes total sense as a root cause for the CustomKernel VJP error on Qwen3.5.

I'm currently doing some local work on the RL trainer paths myself (improving end-to-end workflows, adding dedicated examples, etc.), so I'd like to take some time to review this properly against my local changes to make sure everything fits together cleanly. I'll follow up once I've had a chance to go through it in detail.

Really appreciate the contribution and the quality of the PR — the focused scope and the regression test are exactly the right approach.

@ARahim3
Owner

ARahim3 commented Mar 27, 2026

Hi @omercelik, thanks again for the detailed write-up. While I was working on dedicated End-to-End examples and testing each RL method with Qwen3.5, I ran into the same CustomKernel VJP issue and arrived at the same fix — calling train() before nn.value_and_grad() is the standard approach and the obvious solution once you trace the Gated DeltaNet kernel dispatch.

I've just shipped v0.4.11, which includes this fix across all 5 RL trainers (including GRPO, which I also rewrote), along with new dedicated examples, KTOConfig/SimPOConfig, and other RL improvements. Since this overlaps with your PR, I'm going to close it — but your analysis of the root cause was solid work.

I really appreciate your work. Please let me know if you face any other issues with the training examples or anything else.
