pytorch-metric-learning
Add wrapper for self supervised loss
A common use case is for embeddings and ref_emb to be augmented versions of each other. With most losses right now, you have to create labels to indicate which embeddings correspond to which ref_emb. A wrapper that does this for the user would be nice.
Something like:
```python
loss_fn = SelfSupervisedWrapper(TripletMarginLoss())
loss = loss_fn(embeddings, ref_emb)
```
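As a minimal sketch of what such a wrapper might do internally: give embeddings[i] and ref_emb[i] the same label and pass both tensors to the wrapped loss. This assumes the wrapped loss accepts ref_emb/ref_labels, as pytorch-metric-learning losses do in recent versions; SelfSupervisedWrapper itself is hypothetical here.

```python
import torch
from pytorch_metric_learning.losses import TripletMarginLoss

class SelfSupervisedWrapper(torch.nn.Module):
    """Hypothetical wrapper: embeddings[i] and ref_emb[i] are two views of
    the same input, so both get label i; all other pairs are negatives."""

    def __init__(self, loss_fn):
        super().__init__()
        self.loss_fn = loss_fn

    def forward(self, embeddings, ref_emb):
        labels = torch.arange(embeddings.size(0), device=embeddings.device)
        return self.loss_fn(embeddings, labels, ref_emb=ref_emb, ref_labels=labels)

loss_fn = SelfSupervisedWrapper(TripletMarginLoss())
embeddings = torch.randn(32, 128)  # one set of augmented views
ref_emb = torch.randn(32, 128)     # the other views, aligned by index
loss = loss_fn(embeddings, ref_emb)
```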
I've actually been meaning to implement a sort of extended version of this after reading the paper What Should Not Be Contrastive in Contrastive Learning.
For example, suppose you had augmentation transformations AUG_1, AUG_2, and AUG_3 where:
- AUG_1 dealt with random cropping
- AUG_2 dealt with color jittering
- AUG_3 dealt with texture randomization
Typically you'd conduct contrastive learning by training an encoder model E(I) = X such that
```
# AUG_i_j = AUG_i(AUG_j)
E(I) == E(AUG_1_2_3(I))
```
so that the encoder is invariant to arbitrary attributes of the input images and focuses on their deeper semantics.
But there are cases where the downstream task requires the model to be sensitive to those attributes, and the model's learned invariance towards them hurts performance.
So the paper suggests having a backbone encoder B(I) = X with multiple heads E1(X) = X1, E2(X) = X2, E3(X) = X3, and training:
```
X == E1(B(AUG_1_2(I))) == E2(B(AUG_1_3(I))) == E3(B(AUG_2_3(I)))
```
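To make the setup concrete, here is a hedged illustration of the backbone-plus-heads wiring. The augmentations are identity placeholders and the layer sizes are arbitrary, so this is a structural sketch, not the paper's architecture:

```python
import torch
import torch.nn as nn

# Placeholders for the composed augmentation pipelines; in practice each
# would compose two of {random cropping, color jittering, texture randomization}.
aug_1_2 = aug_1_3 = aug_2_3 = lambda img: img

backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 128), nn.ReLU())  # B
heads = nn.ModuleList([nn.Linear(128, 128) for _ in range(3)])                  # E1, E2, E3

images = torch.randn(8, 3, 32, 32)        # a batch of inputs I
x = backbone(images)                       # X  = B(I)
x1 = heads[0](backbone(aug_1_2(images)))   # X1 = E1(B(AUG_1_2(I)))
x2 = heads[1](backbone(aug_1_3(images)))   # X2 = E2(B(AUG_1_3(I)))
x3 = heads[2](backbone(aug_2_3(images)))   # X3 = E3(B(AUG_2_3(I)))
```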
And a nice way to do this sort of "multi-view contrastive learning" would be to do:
```python
loss_fn = SelfSupervisedWrapper(TripletMarginLoss())
loss = loss_fn(embeddings, ref_emb_1, ref_emb_2, ref_emb_3)
```
where ref_emb_1, ref_emb_2, and ref_emb_3 are X1, X2, and X3, respectively, from the example above.
I do feel this would be a useful feature with few drawbacks compared to what you just suggested because, if it's made to support an arbitrary number of views, it naturally also supports:
```python
loss_fn = SelfSupervisedWrapper(TripletMarginLoss())
loss = loss_fn(embeddings, ref_emb)
```
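For what it's worth, one hedged sketch of handling arbitrary views is to compute the wrapped loss against each reference set independently and average. Whether averaging is the right reduction is exactly the open design question; the class name and approach below are assumptions, not the library's API:

```python
import torch

class MultiViewSelfSupervisedWrapper(torch.nn.Module):
    """Hypothetical: computes the wrapped loss between embeddings and each
    ref_emb_i separately, then averages. With a single view this reduces
    to the two-tensor case sketched earlier."""

    def __init__(self, loss_fn):
        super().__init__()
        self.loss_fn = loss_fn

    def forward(self, embeddings, *ref_embs):
        labels = torch.arange(embeddings.size(0), device=embeddings.device)
        losses = [
            self.loss_fn(embeddings, labels, ref_emb=r, ref_labels=labels)
            for r in ref_embs
        ]
        return torch.stack(losses).mean()
```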
Let me know what you think!
Yes I think supporting an arbitrary number of views (ref_emb1, ref_emb2, ...) is a good idea 👍
losses.SelfSupervisedLoss is now available in v2.0.0.
The multiple reference embeddings aren't yet available, as I'm not sure what the best approach is for computing the loss. I've created an issue for this, though: #580
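For anyone landing here, the released single-view wrapper is used like this (per the v2.0.0 docs):

```python
import torch
from pytorch_metric_learning import losses

loss_fn = losses.SelfSupervisedLoss(losses.TripletMarginLoss())

embeddings = torch.randn(32, 128)  # model output for a batch
ref_emb = torch.randn(32, 128)     # augmented versions, aligned by index
loss = loss_fn(embeddings, ref_emb)
```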