pytorch-metric-learning

Add wrapper for self supervised loss

Open · KevinMusgrave opened this issue

A common use case is to have embeddings and ref_emb be augmented versions of each other. For most losses right now you have to create labels to indicate which embeddings correspond with which ref_emb. A wrapper that does this for the user would be nice.

Something like:

```python
loss_fn = SelfSupervisedWrapper(TripletMarginLoss())
loss = loss_fn(embeddings, ref_emb)
```
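For reference, this is roughly the manual boilerplate such a wrapper would hide (a sketch assuming row i of `embeddings` and row i of `ref_emb` are views of the same image, using the existing `ref_emb`/`ref_labels` arguments of the losses):

```python
import torch
from pytorch_metric_learning import losses

loss_fn = losses.TripletMarginLoss()

embeddings = torch.randn(32, 128)  # one augmented view per image
ref_emb = torch.randn(32, 128)     # the other augmented view, row-aligned

# Rows with the same index are views of the same image, so give each
# row pair the same positional label.
labels = torch.arange(embeddings.shape[0])
loss = loss_fn(embeddings, labels, ref_emb=ref_emb, ref_labels=labels)
```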

KevinMusgrave avatar Jan 01 '22 04:01 KevinMusgrave

I've actually been meaning to implement a sort of extended version of this after reading the paper *What Should Not Be Contrastive in Contrastive Learning*.

For example, suppose you had augmentation transformations AUG_1, AUG_2, and AUG_3 where:

- AUG_1 deals with random cropping
- AUG_2 deals with color jittering
- AUG_3 deals with texture randomization

Typically you'd conduct contrastive learning by training an encoder model E(I) = X so that

```
# AUG_i_j = AUG_i(AUG_j)

E(I) == E(AUG_1_2_3(I))
```

in order to make the encoder invariant to arbitrary attributes of the input images and force it to focus on their deeper semantics.
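As a purely illustrative example of that composition, with torchvision transforms standing in for AUG_1, AUG_2, and AUG_3 (texture randomization has no stock torchvision transform, so a grayscale transform is used as a placeholder):

```python
import torch
from torchvision import transforms

AUG_1 = transforms.RandomResizedCrop(224)           # random cropping
AUG_2 = transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # color jittering
AUG_3 = transforms.RandomGrayscale(p=1.0)           # placeholder for texture randomization

# AUG_1_2_3(I) = AUG_1(AUG_2(AUG_3(I))): Compose applies the list front to back.
AUG_1_2_3 = transforms.Compose([AUG_3, AUG_2, AUG_1])

image = torch.rand(3, 256, 256)  # dummy image tensor
augmented = AUG_1_2_3(image)     # the encoder is trained so that E(image) matches E(augmented)
```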

But there are cases where the downstream task requires the model to be sensitive to those attributes, and the trained invariance towards them hurts performance.

So the paper suggests having a backbone encoder B(I) = X with multiple heads E1(X) = X1, E2(X) = X2, E3(X) = X3, and training them so that:

```
X == E1(B(AUG_1_2(I))) == E2(B(AUG_1_3(I))) == E3(B(AUG_2_3(I)))
```
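A minimal sketch of that backbone-plus-heads setup (the module names, dimensions, and ResNet choice are all just for illustration):

```python
import torch.nn as nn
from torchvision import models

class MultiHeadEncoder(nn.Module):
    """Shared backbone B with one projection head E_i per view (illustrative only)."""

    def __init__(self, embedding_size=128, num_heads=3):
        super().__init__()
        backbone = models.resnet18(weights=None)
        backbone.fc = nn.Identity()  # B(I) = X, a 512-d feature vector
        self.backbone = backbone
        self.heads = nn.ModuleList(
            nn.Linear(512, embedding_size) for _ in range(num_heads)
        )

    def forward(self, images):
        x = self.backbone(images)
        # X1 = E1(X), X2 = E2(X), X3 = E3(X)
        return [head(x) for head in self.heads]
```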

And a nice way to do this sort of "multi-view contrastive learning" would be to do:

```python
loss_fn = SelfSupervisedWrapper(TripletMarginLoss())
loss = loss_fn(embeddings, ref_emb_1, ref_emb_2, ref_emb_3)
```

where ref_emb_1, ref_emb_2, ref_emb_3 are X1, X2, X3, respectively, from the above example.
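One possible shape for that wrapper (just a sketch of the idea, not a settled design: it reuses positional labels and sums the wrapped loss over each reference view):

```python
import torch

class SelfSupervisedWrapper(torch.nn.Module):
    """Hypothetical wrapper that accepts any number of reference views."""

    def __init__(self, loss_fn):
        super().__init__()
        self.loss_fn = loss_fn

    def forward(self, embeddings, *ref_embs):
        # embeddings[i] and ref_emb[i] are always views of the same image.
        labels = torch.arange(embeddings.shape[0], device=embeddings.device)
        return sum(
            self.loss_fn(embeddings, labels, ref_emb=ref_emb, ref_labels=labels)
            for ref_emb in ref_embs
        )
```

With a single reference view this reduces to the two-argument call above.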

I do feel like this would be a useful feature without many drawbacks compared to what you just suggested, because if it's made to support an arbitrary number of views, it naturally also supports:

```python
loss_fn = SelfSupervisedWrapper(TripletMarginLoss())
loss = loss_fn(embeddings, ref_emb)
```

Let me know what you think!

cwkeam avatar Jan 01 '22 05:01 cwkeam

Yes I think supporting an arbitrary number of views (ref_emb1, ref_emb2, ...) is a good idea 👍

KevinMusgrave avatar Jan 01 '22 10:01 KevinMusgrave

losses.SelfSupervisedLoss is now available in v2.0.0.
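Usage follows the pattern proposed above (dummy tensors shown only to make the snippet runnable):

```python
import torch
from pytorch_metric_learning import losses

loss_fn = losses.SelfSupervisedLoss(losses.TripletMarginLoss())

embeddings = torch.randn(32, 128)  # anchor views
ref_emb = torch.randn(32, 128)     # augmented views, row-aligned with embeddings
loss = loss_fn(embeddings, ref_emb)
```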

The multiple reference embeddings aren't yet available, as I'm not sure what the best approach is for computing the loss. I've created an issue for this though: #580

KevinMusgrave avatar Jan 29 '23 18:01 KevinMusgrave