
Generalized Mixup

A-Jacobson opened this issue 3 years ago • 3 comments

Our current implementation of mixup works directly on the targets. That's why we need a dense loss, or have to convert class indices to one-hot vectors, like so:

 if check_for_index_targets(y):
     # Index targets: convert to one-hot before interpolating.
     y_onehot = F.one_hot(y, num_classes=n_classes)
     y_shuffled_onehot = F.one_hot(y_shuffled, num_classes=n_classes)
     y_mix = ((1. - interpolation_lambda) * y_onehot + interpolation_lambda * y_shuffled_onehot)
 else:
     # Targets are already dense/one-hot: interpolate them directly.
     y_mix = ((1. - interpolation_lambda) * y + interpolation_lambda * y_shuffled)
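For reference, the mixed dense targets then have to go through a dense loss; roughly a soft cross entropy along these lines (just a sketch, the function name and logits are illustrative rather than our actual implementation):

import torch.nn.functional as F

def soft_cross_entropy(logits, y_mix):
    # Cross entropy against soft (probability) targets of shape (batch, n_classes).
    return -(y_mix * F.log_softmax(logits, dim=1)).sum(dim=1).mean()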

But we can instead mix the outputs of an arbitrary loss function after splitting the batch:

import numpy as np
import torch


def mixup_data(x, y, alpha=1.0, use_cuda=True):
    '''Returns mixed inputs, pairs of targets, and lambda'''
    # Sample the interpolation coefficient from a Beta(alpha, alpha) distribution.
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    # Pair each example with a randomly chosen partner from the same batch.
    batch_size = x.size()[0]
    if use_cuda:
        index = torch.randperm(batch_size).cuda()
    else:
        index = torch.randperm(batch_size)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    # Interpolate the losses instead of the targets.
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
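Rough usage sketch in a training loop (model, loader, and optimizer are placeholders, and F.cross_entropy stands in for whatever base criterion we end up with):

import torch.nn.functional as F

for x, y in loader:
    mixed_x, y_a, y_b, lam = mixup_data(x, y, alpha=1.0, use_cuda=False)
    pred = model(mixed_x)
    # Interpolate the two losses with the same lambda used to mix the inputs.
    loss = mixup_criterion(F.cross_entropy, pred, y_a, y_b, lam)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()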

This has the same effect, allows us to work with arbitrary losses (or even combinations of losses), and is more faithful to the original implementation seen here: https://github.com/facebookresearch/mixup-cifar10/blob/main/train.py

Though this does require us to be able to swap out model.criterion for mixup_criterion, and it changes the format of the batch. @Landanjs thoughts?

A-Jacobson avatar Feb 04 '22 08:02 A-Jacobson

@coryMosaicML you know all the things about this?

A-Jacobson avatar Feb 04 '22 18:02 A-Jacobson

A few thoughts:

  • This interpolates losses rather than interpolating class labels. Either way will give the same result for cross entropy, because cross_entropy(input, lam * target_1 + (1 - lam) * target_2) = lam * cross_entropy(input, target_1) + (1 - lam) * cross_entropy(input, target_2), but for other cost functions (incl. MSE) the two ways will give different results; a quick numerical check is sketched after this list. Perhaps it's ok to only support cross entropy for mixup. We could also simply define mixup for other cost functions to behave this way, though it's somewhat different from the way it's usually described.
  • We'd also need to make sure this will compose correctly with other algorithms that do similar things to the loss such as label smoothing and cutmix. In the special case of cross entropy it should be possible to implement these in a similar way with no weird ordering effects, though bookkeeping might become complex for combinations like cutmix+mixup. More complex label smoothing strategies (like the self smoothing we explored before) will have issues.
  • This will use less memory, at the cost of computing the cost function multiple times (twice, for two-way mixup).
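(A quick numerical check of the cross entropy identity from the first bullet; shapes and values are arbitrary, and passing probability targets to F.cross_entropy needs PyTorch >= 1.10.)

import torch
import torch.nn.functional as F

logits = torch.randn(8, 10)
y1, y2 = torch.randint(10, (8,)), torch.randint(10, (8,))
lam = 0.3
mixed_targets = lam * F.one_hot(y1, 10).float() + (1 - lam) * F.one_hot(y2, 10).float()

lhs = F.cross_entropy(logits, mixed_targets)
rhs = lam * F.cross_entropy(logits, y1) + (1 - lam) * F.cross_entropy(logits, y2)
print(torch.allclose(lhs, rhs))  # True: cross entropy is linear in the target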

Overall I think this is probably a good idea and will save some headaches. We'll need to be careful how we handle something like a composition of label smoothing, mixup, and cutmix, as it will be easy to introduce weird ordering effects there. We should also make it clear that this is intended for cross entropy losses. Perhaps we can keep the old 'dense' label behavior available as an option with something like dense_labels=True, defaulting to False.
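Something like this, as a very rough sketch of that option (names and signature are hypothetical, not actual Composer API):

import torch.nn.functional as F

def mixup_criterion(criterion, pred, y_a, y_b, lam, dense_labels=False, n_classes=None):
    if dense_labels:
        # Old behavior: interpolate one-hot targets and call the loss once.
        # The criterion has to accept dense/soft targets in this branch.
        y_mix = lam * F.one_hot(y_a, n_classes).float() + (1 - lam) * F.one_hot(y_b, n_classes).float()
        return criterion(pred, y_mix)
    # New behavior: interpolate the losses computed against each set of index targets.
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)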

coryMosaicML avatar Feb 07 '22 22:02 coryMosaicML

Like Cory said, the problem you end up with when messing with the loss function is that, for almost anything but xent, you can't just take a convex combination of the losses.
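For example, with MSE the two quantities generally differ, since MSE isn't linear in the target (a quick illustrative check, arbitrary shapes):

import torch
import torch.nn.functional as F

pred = torch.randn(8, 10)
t1, t2 = torch.randn(8, 10), torch.randn(8, 10)
lam = 0.3
loss_of_mixed_targets = F.mse_loss(pred, lam * t1 + (1 - lam) * t2)
mix_of_losses = lam * F.mse_loss(pred, t1) + (1 - lam) * F.mse_loss(pred, t2)
print(torch.allclose(loss_of_mixed_targets, mix_of_losses))  # False in general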

To compose with a custom loss or other losses, you need to take in the other loss function as an argument and call that, rather than hardcoding F.cross_entropy(). This creates weird plumbing problems: the mixup algorithm needs to be passed a callable, and then if you also want to label smooth, you need to pass the mixup algorithm's loss function as a callable to the label smoothing loss function, and so on.
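To make that concrete, a hypothetical sketch of the nesting (all names made up; the point is just that each wrapper takes the next loss as a callable and the signatures stop lining up):

import torch.nn.functional as F

def make_label_smoothing_criterion(base_criterion, n_classes, smoothing=0.1):
    def criterion(pred, target):
        # Smooth index targets into dense ones, so the wrapped loss has to
        # accept soft targets (works for F.cross_entropy on PyTorch >= 1.10).
        soft = F.one_hot(target, n_classes).float() * (1 - smoothing) + smoothing / n_classes
        return base_criterion(pred, soft)
    return criterion

def make_mixup_criterion(base_criterion, lam):
    def criterion(pred, y_a, y_b):
        # Mixup's loss needs *two* target tensors, so it no longer looks like
        # a standard criterion(pred, target).
        return lam * base_criterion(pred, y_a) + (1 - lam) * base_criterion(pred, y_b)
    return criterion

# Composing them means threading callables through, and whoever writes the
# training loop has to know which signature survives on the outside:
criterion = make_mixup_criterion(make_label_smoothing_criterion(F.cross_entropy, n_classes=10), lam=0.3)
# loss = criterion(pred, y_a, y_b)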

You also hit subtleties with the user wanting to add loss terms that aren't just a function of the labels (e.g., an MoE balancing loss, or a multitask loss). So you might not be able to safely pass in an overall loss to the algorithm, and instead need the user to pull out just the part that depends on the labels. (Or maybe not, if all of these losses are convex combos of other losses?)

So I won't object to having a special-cased implementation that works for the common case of cross entropy, but FWIW, trying to enable composition with the loss-based implementations in the first research repo was a nightmare and I gave up on it pretty fast. (But it was easy once I switched to just messing with the labels.)

dblalock avatar Feb 07 '22 23:02 dblalock

Closing. Tracking elsewhere as low pri

mvpatel2000 avatar Jun 22 '23 21:06 mvpatel2000