mlx-examples
Reinforcement Learning from Human Feedback (RLHF) examples: Direct Preference Optimization (DPO)
Introduce a Reinforcement Learning from Human Feedback (RLHF) example, such as the Direct Preference Optimization (DPO) method.
Paper
Direct Preference Optimization: Your Language Model is Secretly a Reward Model
Notes
Direct Preference Optimization (DPO): A Simplified Explanation by João Lages

Implementation examples
- huggingface/trl: TRL - Transformer Reinforcement Learning
- eric-mitchell/direct-preference-optimization: Direct Preference Optimization
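For reference, the objective from the DPO paper, transcribed here for convenience (σ is the logistic function, β the temperature, and y_w / y_l the chosen / rejected completions):

```math
\mathcal{L}_\mathrm{DPO}(\pi_\theta;\pi_\mathrm{ref}) = -\,\mathbb{E}_{(x,\,y_w,\,y_l)\sim\mathcal{D}}\left[\log \sigma\!\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_\mathrm{ref}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_\mathrm{ref}(y_l \mid x)}\right)\right]
```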
Possible MLX implementation
Policy and reference log probabilities:
```python
import mlx.core as mx
import mlx.nn as nn


def get_batched_logps(model, inputs, targets):
    logits, _ = model(inputs)  # model returns (logits, cache)
    logits = logits.astype(mx.float32)
    loss_mask = targets != 0  # 0 is assumed to be the padding id
    # target-token log-probs, summed per sequence, then split into the
    # (chosen, rejected) halves of the batch
    per_token_logps = mx.take_along_axis(
        nn.log_softmax(logits), targets[..., None], axis=2
    ).squeeze(2)
    return tuple((per_token_logps * loss_mask).sum(-1).split(2))
```
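A minimal usage sketch, assuming a toy stand-in model (`ToyLM` below is hypothetical) that follows the mlx-examples convention of returning `(logits, cache)`, with 0 as the padding id. In practice the batch would come from a preference dataset where each chosen/rejected pair shares a prompt:

```python
import mlx.core as mx
import mlx.nn as nn


class ToyLM(nn.Module):
    # stand-in for a real model that returns (logits, cache)
    def __init__(self, vocab_size=16, dims=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dims)
        self.out = nn.Linear(dims, vocab_size)

    def __call__(self, inputs):
        return self.out(self.embed(inputs)), None


model = ToyLM()

# toy token ids; row i of `chosen` and row i of `rejected` share the same prompt
chosen = mx.array([[1, 5, 8, 2, 0, 0], [1, 9, 4, 7, 3, 2]])
rejected = mx.array([[1, 5, 6, 2, 0, 0], [1, 9, 4, 4, 2, 0]])

tokens = mx.concatenate([chosen, rejected], axis=0)
inputs, targets = tokens[:, :-1], tokens[:, 1:]  # shift for next-token prediction

chosen_logps, rejected_logps = get_batched_logps(model, inputs, targets)  # shapes: (2,), (2,)
```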
Loss:
```python
def dpo_loss(model, beta, label_smoothing, reference_chosen_logps,
             reference_rejected_logps, inputs, targets):
    chosen_logps, rejected_logps = get_batched_logps(model, inputs, targets)

    pi_logratios = chosen_logps - rejected_logps
    reference_logratios = reference_chosen_logps - reference_rejected_logps
    logits = pi_logratios - reference_logratios

    # conservative DPO loss; label_smoothing = 0 recovers the standard DPO loss
    losses = (
        -nn.log_sigmoid(beta * logits) * (1.0 - label_smoothing)
        - nn.log_sigmoid(-beta * logits) * label_smoothing
    )

    # implicit rewards and training diagnostics
    chosen_rewards = beta * (chosen_logps - reference_chosen_logps)
    rejected_rewards = beta * (rejected_logps - reference_rejected_logps)
    reward_accuracies = (chosen_rewards > rejected_rewards).astype(mx.float32)
    reward_margins = chosen_rewards - rejected_rewards
    ntoks = (inputs != 0).sum()

    return (
        losses.mean(),
        chosen_rewards.mean(),
        rejected_rewards.mean(),
        reward_accuracies.mean(),
        reward_margins.mean(),
        ntoks,
    )
```
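A rough sketch of how a single update might wire this together (hypothetical, not a settled mlx-lm API). It assumes `model` and `reference` hold the same pretrained weights, loaded separately, and follows the `nn.value_and_grad` pattern from the LoRA example; the reference log-probs are computed outside the gradient so they act as constants:

```python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

beta, label_smoothing = 0.1, 0.0
optimizer = optim.Adam(learning_rate=1e-6)


def dpo_step(model, reference, inputs, targets):
    # frozen reference policy: no gradients flow through these log-probs
    ref_chosen, ref_rejected = get_batched_logps(reference, inputs, targets)

    def loss_fn(model):
        loss, *_ = dpo_loss(
            model, beta, label_smoothing, ref_chosen, ref_rejected, inputs, targets
        )
        return loss

    loss, grads = nn.value_and_grad(model, loss_fn)(model)
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state)
    return loss
```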
Beta: The temperature parameter for the DPO loss is typically set in the range of 0.1 to 0.5. The reference model is ignored when beta equals 0.
Label smoothing: This parameter controls how conservative the DPO loss is, under the assumption that preference labels are noisy and may be flipped with probability label_smoothing.
Note
label_smoothing > 0 defines the conservative DPO loss.
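Written out, the smoothed loss computed above is, with ε = label_smoothing and z the policy/reference log-ratio difference (the `logits` variable in `dpo_loss`):

```math
\mathcal{L} = -(1-\varepsilon)\,\log\sigma(\beta z) \;-\; \varepsilon\,\log\sigma(-\beta z),
\qquad
z = \log\frac{\pi_\theta(y_w \mid x)}{\pi_\mathrm{ref}(y_w \mid x)} - \log\frac{\pi_\theta(y_l \mid x)}{\pi_\mathrm{ref}(y_l \mid x)}
```

Setting ε = 0 recovers the standard DPO loss above.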
@danilopeixoto I've been thinking about having this in MLX LM recently. Any interest in sending a PR?
It might make sense to do it after we have a more manageable config (https://github.com/ml-explore/mlx-examples/pull/503), but that should land soon!
To be more concrete, I'm envisioning that you just set the loss in the config, e.g. cross_entropy or dpo.
This would be an awesome addition to mlx_examples! 🔥
I'm very, very excited for this! I don't have the technical expertise to implement DPO directly, but would love to help in other ways (config, code cleanup) if necessary!
That would make MLX really useful for production, not just a research tool!
+500 waiting for this
Waiting for this. When will DPO training be supported?