Conversation
Hi @omercelik, thanks for the detailed PR and the thorough write-up! This is a well-identified issue; the missing `train()` call is exactly the kind of subtle bug that is easy to overlook. I'm currently doing some local work on the RL trainer paths myself (improving end-to-end workflows, adding dedicated examples, etc.), so I'd like to take some time to review this properly against my local changes to make sure everything fits together cleanly. I'll follow up once I've had a chance to go through it in detail. Really appreciate the contribution and the quality of the PR: the focused scope and the regression test are exactly the right approach.
Hi @omercelik, thanks again for the detailed write-up. While I was working on dedicated end-to-end examples and testing each RL method with Qwen3.5, I ran into the same issue. I've just shipped v0.4.11, which includes this fix across all 5 RL trainers (including GRPO, which I also rewrote), along with new dedicated examples. I really appreciate your work. Please let me know if you face any other issues with the training examples or anything else.
## Summary
This fixes the native preference-optimization trainers running with models left in inference mode.
The immediate user-facing failure was native DPO crashing on Qwen3.5-style models with:
The root cause was not the DPO loss itself. The problem was that the native RL trainer paths were building gradients without explicitly switching the model to training mode first.
For Qwen3.5-style models, that matters because the linear-attention path changes behavior based on `self.training`. When the model stays in inference mode, it can take an inference-only Metal custom-kernel path. That kernel does not provide the VJP needed for gradient computation, so native training fails.

This PR ensures the native gradient-based RL trainer paths enter training mode before calling `nn.value_and_grad(...)`.

## Root Cause
In `mlx_tune/rl_trainers.py`, native RL training extracted the underlying trainable model and immediately built `nn.value_and_grad(...)`, but it never called `train()` first.

That means wrapped models could still be in inference mode when the forward pass used for backpropagation ran.
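The failure pattern can be illustrated with a small pure-Python mock (no MLX dependency; `Inner`, `Wrapper`, and the error message are illustrative stand-ins, not the real mlx_tune or Qwen3.5 classes):

```python
class Inner:
    """Stands in for a model whose forward path branches on self.training."""
    def __init__(self):
        self.training = False  # models commonly load in inference mode

    def train(self):
        self.training = True

    def __call__(self, x):
        if self.training:
            return x * 2  # differentiable training path
        # stand-in for an inference-only kernel with no VJP
        raise RuntimeError("inference-only kernel: no VJP available")


class Wrapper:
    """Stands in for the wrapped model the native RL trainers receive."""
    def __init__(self, inner):
        self.inner = inner
        self.training = False

    def train(self):
        self.training = True
        self.inner.train()


wrapper = Wrapper(Inner())
actual_model = wrapper.inner  # what the trainer extracts

# Building the gradient forward without calling train() first crashes:
try:
    actual_model(3.0)
except RuntimeError as e:
    print(e)  # inference-only kernel: no VJP available

wrapper.train()           # the fix: enter training mode first
print(actual_model(3.0))  # 6.0
```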
For Qwen3.5-style models, this is especially important because the model code uses `self.training` to decide whether to take the inference-oriented kernel path.

In practice, explicitly calling `train()` before native DPO training avoids the crash and allows a bf16 DPO pilot step to complete.

## What Changed
Added a small helper in `mlx_tune/rl_trainers.py`:

- `_set_native_train_mode(model, actual_model)`

This calls `train()` on both the wrapped `model` and the underlying `actual_model`.

The helper is then called before native optimizer/gradient setup in the RL trainer paths that actually use native gradient-based optimization:

- `DPOTrainer._train_native()`
- `ORPOTrainer._train_native()`
- `KTOTrainer._train_native()`
- `SimPOTrainer._train_native()`

## Why The Fix Covers More Than DPO
Although DPO was the reported failure, the same bug pattern existed in other native RL trainer paths. `ORPO`, `KTO`, and `SimPO` also:

- extract the underlying `actual_model`
- build `nn.value_and_grad(...)` without first entering training mode

Since the issue is structural and identical in those paths, fixing only DPO would leave the same latent bug elsewhere.
I did not expand the fix to GRPO because its native path is materially different and does not currently use the same `value_and_grad(...)` training loop.

## What This PR Does Not Change
This PR is intentionally narrow.
It does not:
## Tests Added
Added a regression test in `tests/test_rl_trainers_integration.py` that verifies native RL trainer paths enter the loss function with `model.training == True`.

This test covers:

- `DPOTrainer`
- `ORPOTrainer`
- `KTOTrainer`
- `SimPOTrainer`

The test uses patched loss functions to directly observe the training flag seen by the model during native training.
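The shape of such a regression test can be sketched as follows (a standalone illustration with stub classes, not the actual test in `tests/test_rl_trainers_integration.py`):

```python
observed = {}

def patched_loss(model, batch):
    # Record the training flag the model exposes at the moment the loss runs.
    observed["training"] = model.training
    return 0.0

class StubModel:
    def __init__(self):
        self.training = False  # starts in inference mode, like a loaded model
    def train(self):
        self.training = True

def native_train_step(model, loss_fn, batch):
    # Mimics the fixed trainer path: enter training mode, then run the loss.
    model.train()
    return loss_fn(model, batch)

model = StubModel()
native_train_step(model, patched_loss, batch=None)
assert observed["training"] is True  # the model was in training mode
```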
## Validation

Ran the full pytest suite after the fix.

Result:

- 504 passed
- 1 skipped
- 12 deselected

The focused RL regression test also passed, and the full RL integration file passed as part of the full suite.
## Manual Smoke Validation
In addition to the automated test coverage, I manually validated the reported failure mode against a real Qwen3.5 bf16 native DPO setup.
Smoke setup:

- Model: `mlx-community/Qwen3.5-4B-MLX-bf16`
- Trainer: `DPOTrainer`
- `max_steps=1`

Result:

```
Step 1/1 | Loss: 0.6914
DPO Training Complete!
```

This is the same code path that previously failed with:
After the fix, that smoke run no longer reproduces the failure.
## Files Changed

- `mlx_tune/rl_trainers.py`
- `tests/test_rl_trainers_integration.py`

## Rationale
The correct fix is to ensure native training code actually runs in training mode before gradient computation.
That matches expected ML training semantics, matches how native SFT-style training is already handled elsewhere, and avoids inference-only code paths being used during backpropagation.