14 changes: 14 additions & 0 deletions mlx_tune/rl_trainers.py
@@ -105,6 +105,16 @@ def _save_adapters_and_config(model, adapter_path: Path):
return False


def _set_native_train_mode(model, actual_model):
"""Ensure wrapped native RL models run with training-time behavior enabled."""
seen = set()
for candidate in (model, actual_model):
if candidate is None or id(candidate) in seen or not hasattr(candidate, "train"):
continue
candidate.train()
seen.add(id(candidate))


class DPOConfig:
"""
Configuration for Direct Preference Optimization training.
@@ -442,6 +452,7 @@ def _train_native(self):

# Get actual model
actual_model = self.model.model if hasattr(self.model, 'model') else self.model
_set_native_train_mode(self.model, actual_model)

# Create optimizer
lr_schedule = optim.cosine_decay(self.learning_rate, self.iters)
@@ -681,6 +692,7 @@ def _train_native(self):
print(f"✓ Prepared {len(tokenized_data)} preference pairs")

actual_model = self.model.model if hasattr(self.model, 'model') else self.model
_set_native_train_mode(self.model, actual_model)
lr_schedule = optim.cosine_decay(self.learning_rate, self.iters)
optimizer = optim.AdamW(learning_rate=lr_schedule)

@@ -995,6 +1007,7 @@ def train(self):
self.model._apply_lora()

actual_model = self.model.model if hasattr(self.model, 'model') else self.model
_set_native_train_mode(self.model, actual_model)
lr_schedule = optim.cosine_decay(self.learning_rate, self.iters)
optimizer = optim.AdamW(learning_rate=lr_schedule)

@@ -1127,6 +1140,7 @@ def train(self):
print(f"✓ Prepared {len(tokenized_data)} preference pairs")

actual_model = self.model.model if hasattr(self.model, 'model') else self.model
_set_native_train_mode(self.model, actual_model)
lr_schedule = optim.cosine_decay(self.learning_rate, self.iters)
optimizer = optim.AdamW(learning_rate=lr_schedule)

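Note on the new helper: _set_native_train_mode calls train() on both the wrapper and the unwrapped model because the two can be distinct objects, and the id-based "seen" set prevents a double call when they are the same object. Below is a minimal sketch of the behavior it guarantees, assuming mlx.nn semantics; PlainWrapper is a hypothetical stand-in for the trainer's model wrapper.

import mlx.nn as nn

from mlx_tune.rl_trainers import _set_native_train_mode

class PlainWrapper:
    """Hypothetical wrapper: holds an inner module and exposes no train() of its own."""
    def __init__(self, model):
        self.model = model

inner = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.5))
inner.eval()  # e.g. left in inference mode after a generation pass

wrapper = PlainWrapper(inner)
_set_native_train_mode(wrapper, inner)  # wrapper has no train() and is skipped; inner flips to training
assert inner.training  # dropout is active again for the native loss computation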
82 changes: 82 additions & 0 deletions tests/test_rl_trainers_integration.py
@@ -155,6 +155,88 @@ def grpo_dataset():
]


@pytest.mark.integration
@pytest.mark.parametrize(
("trainer_factory", "loss_attr", "dataset_name"),
[
("dpo", "compute_dpo_loss", "preference"),
("orpo", "compute_orpo_loss", "preference"),
("kto", "compute_kto_loss", "kto"),
("simpo", "compute_simpo_loss", "preference"),
],
)
def test_native_rl_trainers_enable_train_mode(
monkeypatch,
mock_tokenizer,
preference_dataset,
kto_dataset,
tmp_path,
trainer_factory,
loss_attr,
dataset_name,
):
"""Native gradient-based RL trainers should switch models into training mode."""
import mlx_tune.rl_trainers as rl_trainers
from mlx_tune import DPOTrainer, DPOConfig, ORPOTrainer, ORPOConfig, KTOTrainer, SimPOTrainer

monkeypatch.setattr(rl_trainers, "HAS_NATIVE_TRAINING", True)
monkeypatch.setattr(rl_trainers, "_save_adapters_and_config", lambda *args, **kwargs: True)

observed = {"training": None}

def pair_loss(model, chosen_ids, rejected_ids, chosen_lengths, rejected_lengths, *args, **kwargs):
observed["training"] = model.training
loss = model(chosen_ids).mean() - model(rejected_ids).mean()
return loss, mx.array(chosen_lengths.sum() + rejected_lengths.sum())

def kto_loss(model, input_ids, lengths, labels, *args, **kwargs):
observed["training"] = model.training
loss = model(input_ids).mean() + labels.mean() * 0.0
return loss, mx.array(lengths.sum())

monkeypatch.setattr(rl_trainers, loss_attr, kto_loss if dataset_name == "kto" else pair_loss)

model = MockModelWrapper(SmallLanguageModel())
mx.eval(model.model.parameters())

if trainer_factory == "dpo":
trainer = DPOTrainer(
model=model,
train_dataset=preference_dataset,
tokenizer=mock_tokenizer,
args=DPOConfig(max_steps=1, logging_steps=1, save_steps=10, output_dir=str(tmp_path / "dpo")),
)
elif trainer_factory == "orpo":
trainer = ORPOTrainer(
model=model,
train_dataset=preference_dataset,
tokenizer=mock_tokenizer,
args=ORPOConfig(max_steps=1, logging_steps=1, save_steps=10, output_dir=str(tmp_path / "orpo")),
)
elif trainer_factory == "kto":
trainer = KTOTrainer(
model=model,
train_dataset=kto_dataset,
tokenizer=mock_tokenizer,
max_steps=1,
logging_steps=1,
output_dir=str(tmp_path / "kto"),
)
else:
trainer = SimPOTrainer(
model=model,
train_dataset=preference_dataset,
tokenizer=mock_tokenizer,
max_steps=1,
logging_steps=1,
output_dir=str(tmp_path / "simpo"),
)

trainer.train()

assert observed["training"] is True


# =============================================================================
# DPO TRAINER INTEGRATION TESTS
# =============================================================================
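The parametrized test works by swapping each trainer's native loss function for a probe that records model.training at the first forward call, so a pass proves training mode was enabled before any gradient step rather than merely at some later point. To run only these cases locally (assuming the repository's usual pytest setup):

pytest tests/test_rl_trainers_integration.py -k test_native_rl_trainers_enable_train_mode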