pytorch-metric-learning
What loss is suitable for one anchor, multiple positives, and multiple negatives?
Apologies for the late reply.
You can use the ref_emb argument to separate anchors from positives and negatives.
For example, using ContrastiveLoss:
from pytorch_metric_learning.losses import ContrastiveLoss
loss_fn = ContrastiveLoss()
# anchors has shape NxD
# anchor_labels has shape N
# ref_emb has shape MxD
# ref_labels has shape M
loss = loss_fn(anchors, anchor_labels, ref_emb=ref_emb, ref_labels=ref_labels)
Positive pairs will be formed by embeddings in anchors and ref_emb that have the same label.
Negative pairs will be formed by embeddings in anchors and ref_emb that have different labels.
You can have multiple positive pairs and negative pairs for any of the embeddings in anchors. In the extreme case, you could have a single embedding in anchors (shape 1xD), and many positive and negative embeddings in ref_emb.