From 63353f085b6825063554f4d034ceaca4f179b707 Mon Sep 17 00:00:00 2001
From: Omer Celik
Date: Thu, 26 Mar 2026 23:54:52 +0000
Subject: [PATCH] Fix native RL trainers training mode

The native training paths for DPO, ORPO, KTO, and SimPO never switched
the model into training mode before running their update loops. Wrapped
models are commonly left in eval mode by loading and generation code,
so modules whose behavior depends on the training flag (e.g. dropout)
ran with inference-time behavior during native RL training.

Add a _set_native_train_mode helper that calls .train() on both the
wrapper and the unwrapped model (deduplicated by identity, and only for
objects that actually expose a train() method), and invoke it in every
native training path right after the actual model is resolved.

Add a parametrized integration test that stubs each trainer's loss
function and asserts the model reports training == True by the time the
loss is evaluated.
---
 mlx_tune/rl_trainers.py               | 14 +++++
 tests/test_rl_trainers_integration.py | 82 +++++++++++++++++++++++++++
 2 files changed, 96 insertions(+)

diff --git a/mlx_tune/rl_trainers.py b/mlx_tune/rl_trainers.py
index d920c10..c6aca34 100644
--- a/mlx_tune/rl_trainers.py
+++ b/mlx_tune/rl_trainers.py
@@ -105,6 +105,16 @@ def _save_adapters_and_config(model, adapter_path: Path):
         return False
 
 
+def _set_native_train_mode(model, actual_model):
+    """Ensure wrapped native RL models run with training-time behavior enabled."""
+    seen = set()
+    for candidate in (model, actual_model):
+        if candidate is None or id(candidate) in seen or not hasattr(candidate, "train"):
+            continue
+        candidate.train()
+        seen.add(id(candidate))
+
+
 class DPOConfig:
     """
     Configuration for Direct Preference Optimization training.
@@ -442,6 +452,7 @@ def _train_native(self):
 
         # Get actual model
         actual_model = self.model.model if hasattr(self.model, 'model') else self.model
+        _set_native_train_mode(self.model, actual_model)
 
         # Create optimizer
         lr_schedule = optim.cosine_decay(self.learning_rate, self.iters)
@@ -681,6 +692,7 @@ def _train_native(self):
         print(f"✓ Prepared {len(tokenized_data)} preference pairs")
 
         actual_model = self.model.model if hasattr(self.model, 'model') else self.model
+        _set_native_train_mode(self.model, actual_model)
 
         lr_schedule = optim.cosine_decay(self.learning_rate, self.iters)
         optimizer = optim.AdamW(learning_rate=lr_schedule)
@@ -995,6 +1007,7 @@ def train(self):
         self.model._apply_lora()
 
         actual_model = self.model.model if hasattr(self.model, 'model') else self.model
+        _set_native_train_mode(self.model, actual_model)
 
         lr_schedule = optim.cosine_decay(self.learning_rate, self.iters)
         optimizer = optim.AdamW(learning_rate=lr_schedule)
@@ -1127,6 +1140,7 @@ def train(self):
         print(f"✓ Prepared {len(tokenized_data)} preference pairs")
 
         actual_model = self.model.model if hasattr(self.model, 'model') else self.model
+        _set_native_train_mode(self.model, actual_model)
 
         lr_schedule = optim.cosine_decay(self.learning_rate, self.iters)
         optimizer = optim.AdamW(learning_rate=lr_schedule)
diff --git a/tests/test_rl_trainers_integration.py b/tests/test_rl_trainers_integration.py
index cfadc55..80956af 100644
--- a/tests/test_rl_trainers_integration.py
+++ b/tests/test_rl_trainers_integration.py
@@ -155,6 +155,88 @@ def grpo_dataset():
     ]
 
 
+@pytest.mark.integration
+@pytest.mark.parametrize(
+    ("trainer_factory", "loss_attr", "dataset_name"),
+    [
+        ("dpo", "compute_dpo_loss", "preference"),
+        ("orpo", "compute_orpo_loss", "preference"),
+        ("kto", "compute_kto_loss", "kto"),
+        ("simpo", "compute_simpo_loss", "preference"),
+    ],
+)
+def test_native_rl_trainers_enable_train_mode(
+    monkeypatch,
+    mock_tokenizer,
+    preference_dataset,
+    kto_dataset,
+    tmp_path,
+    trainer_factory,
+    loss_attr,
+    dataset_name,
+):
+    """Native gradient-based RL trainers should switch models into training mode."""
+    import mlx_tune.rl_trainers as rl_trainers
+    from mlx_tune import DPOTrainer, DPOConfig, ORPOTrainer, ORPOConfig, KTOTrainer, SimPOTrainer
+
+    monkeypatch.setattr(rl_trainers, "HAS_NATIVE_TRAINING", True)
+    monkeypatch.setattr(rl_trainers, "_save_adapters_and_config", lambda *args, **kwargs: True)
+
+    observed = {"training": None}
+
+    def pair_loss(model, chosen_ids, rejected_ids, chosen_lengths, rejected_lengths, *args, **kwargs):
+        observed["training"] = model.training
+        loss = model(chosen_ids).mean() - model(rejected_ids).mean()
+        return loss, mx.array(chosen_lengths.sum() + rejected_lengths.sum())
+
+    def kto_loss(model, input_ids, lengths, labels, *args, **kwargs):
+        observed["training"] = model.training
+        loss = model(input_ids).mean() + labels.mean() * 0.0
+        return loss, mx.array(lengths.sum())
+
+    monkeypatch.setattr(rl_trainers, loss_attr, kto_loss if dataset_name == "kto" else pair_loss)
+
+    model = MockModelWrapper(SmallLanguageModel())
+    mx.eval(model.model.parameters())
+
+    if trainer_factory == "dpo":
+        trainer = DPOTrainer(
+            model=model,
+            train_dataset=preference_dataset,
+            tokenizer=mock_tokenizer,
+            args=DPOConfig(max_steps=1, logging_steps=1, save_steps=10, output_dir=str(tmp_path / "dpo")),
+        )
+    elif trainer_factory == "orpo":
+        trainer = ORPOTrainer(
+            model=model,
+            train_dataset=preference_dataset,
+            tokenizer=mock_tokenizer,
+            args=ORPOConfig(max_steps=1, logging_steps=1, save_steps=10, output_dir=str(tmp_path / "orpo")),
+        )
+    elif trainer_factory == "kto":
+        trainer = KTOTrainer(
+            model=model,
+            train_dataset=kto_dataset,
+            tokenizer=mock_tokenizer,
+            max_steps=1,
+            logging_steps=1,
+            output_dir=str(tmp_path / "kto"),
+        )
+    else:
+        trainer = SimPOTrainer(
+            model=model,
+            train_dataset=preference_dataset,
+            tokenizer=mock_tokenizer,
+            max_steps=1,
+            logging_steps=1,
+            output_dir=str(tmp_path / "simpo"),
+        )
+
+    trainer.train()
+
+    assert observed["training"] is True
+
+
 # =============================================================================
 # DPO TRAINER INTEGRATION TESTS
 # =============================================================================
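
A minimal sketch of the train/eval behavior the fix relies on, using only
stock mlx.nn; the standalone `drop` module below is illustrative, not from
the codebase. Module.train() flips the `training` flag that modules such as
nn.Dropout consult, which is exactly what _set_native_train_mode toggles
before each native loop:

    import mlx.core as mx
    import mlx.nn as nn

    drop = nn.Dropout(0.5)
    x = mx.ones((4,))

    drop.eval()             # inference mode: dropout is the identity
    assert drop.training is False
    print(drop(x))          # array([1, 1, 1, 1], dtype=float32)

    drop.train()            # training mode: mask and rescale by 1 / (1 - p)
    assert drop.training is True
    print(drop(x))          # random mix of 0s and 2s

Running the new test on its own should exercise all four native paths:

    pytest tests/test_rl_trainers_integration.py -m integration \
        -k test_native_rl_trainers_enable_train_mode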