Reinforcement Learning from Human Feedback (RLHF) examples: Direct Preference Optimization (DPO)

Open · danilopeixoto opened this issue 5 months ago · 7 comments

Introduce a Reinforcement Learning from Human Feedback (RLHF) example, such as the Direct Preference Optimization (DPO) method.

Paper

Direct Preference Optimization: Your Language Model is Secretly a Reward Model

Notes

Direct Preference Optimization (DPO): A Simplified Explanation by João Lages

Implementation examples

Possible MLX implementation

Policy and reference log probabilities:

import mlx.core as mx
import mlx.nn as nn

def get_batched_logps(model, inputs, targets):
    logits, _ = model(inputs)
    logits = logits.astype(mx.float32)

    # Mask out padding tokens (id 0) and gather per-token log probabilities.
    loss_mask = targets != 0
    per_token_logps = mx.take_along_axis(nn.log_softmax(logits), targets[..., None], axis=2).squeeze(2)

    # The batch stacks chosen sequences first and rejected sequences second,
    # so splitting in two yields (chosen_logps, rejected_logps).
    return tuple((per_token_logps * loss_mask).sum(-1).split(2))
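
A minimal usage sketch, assuming the chosen and rejected examples are concatenated along the batch axis and a frozen copy of the model provides the reference log probabilities (the names reference_model, chosen_inputs, etc. are hypothetical, not part of the snippet above):

# Chosen examples first, rejected examples second, along the batch axis.
inputs = mx.concatenate([chosen_inputs, rejected_inputs], axis=0)
targets = mx.concatenate([chosen_targets, rejected_targets], axis=0)

# Reference log probabilities come from a frozen copy of the model
# and carry no gradients.
reference_chosen_logps, reference_rejected_logps = map(
    mx.stop_gradient, get_batched_logps(reference_model, inputs, targets)
)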

Loss:

def dpo_loss(model, beta, label_smoothing, reference_chosen_logps, reference_rejected_logps, inputs, targets):
    chosen_logps, rejected_logps = get_batched_logps(model, inputs, targets)

    # Log-ratio differences for the policy and the frozen reference model.
    pi_logratios = chosen_logps - rejected_logps
    reference_logratios = reference_chosen_logps - reference_rejected_logps

    logits = pi_logratios - reference_logratios
    # Conservative DPO loss; label_smoothing = 0 recovers the standard DPO loss.
    losses = -nn.log_sigmoid(beta * logits) * (1.0 - label_smoothing) - nn.log_sigmoid(-beta * logits) * label_smoothing

    # Implicit rewards and diagnostics for logging.
    chosen_rewards = beta * (chosen_logps - reference_chosen_logps)
    rejected_rewards = beta * (rejected_logps - reference_rejected_logps)
    reward_accuracies = (chosen_rewards > rejected_rewards).astype(mx.float32)
    reward_margins = chosen_rewards - rejected_rewards

    # Number of non-padding tokens, useful for throughput reporting.
    ntoks = (inputs != 0).sum()

    return (
        losses.mean(),
        chosen_rewards.mean(),
        rejected_rewards.mean(),
        reward_accuracies.mean(),
        reward_margins.mean(),
        ntoks,
    )
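
For reference, the loss computed above corresponds to the conservative DPO objective, where $y_w$ and $y_l$ denote the chosen and rejected completions and $\varepsilon$ is label_smoothing ($\varepsilon = 0$ recovers the standard DPO loss):

$$
h = \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)},
\qquad
\mathcal{L}_{\mathrm{DPO}} = -(1 - \varepsilon) \log \sigma(\beta h) - \varepsilon \log \sigma(-\beta h)
$$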

Beta: The temperature parameter for the DPO loss is typically set in the range of 0.1 to 0.5. The reference model is ignored when beta equals 0.

Label smoothing: This parameter controls the conservativeness of the DPO loss, under the assumption that preferences are noisy and can be flipped with probability label_smoothing.

Note that label_smoothing > 0 defines the conservative DPO (cDPO) loss.
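
A rough training-step sketch under the same assumptions (policy_model, the hyperparameter values, and the optimizer choice are illustrative; the eventual mlx-lm trainer wiring may look different):

import mlx.optimizers as optim

optimizer = optim.Adam(learning_rate=1e-6)

def train_step(inputs, targets, reference_chosen_logps, reference_rejected_logps):
    def loss_fn(model):
        # Only the mean loss is needed for the gradient; the remaining
        # outputs are logging metrics.
        loss, *_ = dpo_loss(
            model,
            0.1,  # beta
            0.0,  # label_smoothing
            reference_chosen_logps,
            reference_rejected_logps,
            inputs,
            targets,
        )
        return loss

    loss, grads = nn.value_and_grad(policy_model, loss_fn)(policy_model)
    optimizer.update(policy_model, grads)
    mx.eval(policy_model.parameters(), optimizer.state)
    return loss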

danilopeixoto · Mar 01 '24

@danilopeixoto I've been thinking about having this in MLX LM recently. Any interest in sending a PR?

It might make sense to do it after we have a more manageable config (https://github.com/ml-explore/mlx-examples/pull/503), but that should land soon!

awni · Mar 01 '24

To be more concrete, I'm envisioning you just set the loss in the config, e.g. cross_entropy or dpo.

awni · Mar 01 '24

This would be an awesome addition to mlx_examples! 🔥

ivanfioravanti · Mar 19 '24

I'm very, very excited for this! I don't have the technical expertise to implement DPO directly, but I would love to help in other ways (config, code cleanup) if necessary!

N8python · Mar 26 '24

That would make MLX really useful for production, not just a research tool!

lin72h · Mar 27 '24

+500 waiting for this

kishoretvk · Apr 11 '24

Waiting for this. When will DPO training be supported?

developerlin · May 16 '24