Using TripletMarginLoss and MultiSimilarityMiner with a labeled triplet dataset

Open · puzzlecollector opened this issue 2 years ago · 7 comments

@KevinMusgrave I have a question about using pytorch-metric-learning with a labeled triplet dataset.

I have around one million (query, positive, negative) triplets for training. Looking at the example usage code:

```python
from pytorch_metric_learning import miners, losses

miner = miners.MultiSimilarityMiner()
loss_func = losses.TripletMarginLoss()

# your training loop (model, optimizer and dataloader are defined elsewhere)
for i, (data, labels) in enumerate(dataloader):
    optimizer.zero_grad()
    embeddings = model(data)
    hard_pairs = miner(embeddings, labels)  # mined hard positive and negative pairs
    loss = loss_func(embeddings, labels, hard_pairs)
    loss.backward()
    optimizer.step()
```

It appears that I need both the embeddings and the labels. From my understanding, when the triplet margin loss is calculated, embeddings with the same label form anchor-positive pairs and embeddings with different labels form anchor-negative pairs. Based on this, I implemented my dataloader so that it returns the following.

Suppose my batch size is 5. Because each sample is a labeled triplet, a batch contains

(q1, p1, n1), (q2, p2, n2), (q3, p3, n3), (q4, p4, n4), (q5, p5, n5)

where qi, pi, and ni (1 ≤ i ≤ 5) are the query, positive, and negative, respectively. When forming labels, I know for sure that qi and pi form a positive pair, but I am less sure (although it is probably the case) that qi forms a negative pair with qj, pj, and nj for j ≠ i. Regardless, I form the labels as follows:

- q1 and p1 have label 0, and n1 has label 1
- q2 and p2 have label 2, and n2 has label 3
- ...
- q5 and p5 have label 8, and n5 has label 9

In general, qi and pi have label 2(i - 1) and ni has label 2i - 1 (equivalently, 2i and 2i + 1 if i is counted from 0).
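A minimal sketch of a `collate_fn` that implements this labeling, assuming each dataset item is a (query, positive, negative) tuple of equal-shape tensors. `triplet_collate` is a hypothetical name; for tokenized text you would stack each tokenizer field instead of raw tensors. The index i is counted from 0 here, so the labels are 2i and 2i + 1.

```python
import torch

def triplet_collate(batch):
    """Flatten a list of (query, positive, negative) tensors into one batch.

    Triplet i (0-indexed) contributes q_i and p_i with label 2*i and
    n_i with label 2*i + 1, so every negative gets a label of its own.
    """
    data, labels = [], []
    for i, (q, p, n) in enumerate(batch):
        data.extend([q, p, n])
        labels.extend([2 * i, 2 * i, 2 * i + 1])
    return torch.stack(data), torch.tensor(labels)
```

It would be passed as `DataLoader(dataset, batch_size=5, collate_fn=triplet_collate)`, so each batch of 5 triplets becomes 15 embeddings with 10 distinct labels.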

With this label assignment, training looks stable and converges much faster (probably because more triplets are formed?) than when I skip this scheme and compute the triplet margin loss only on the batch's own triplets in each forward pass.

I would really appreciate it if you could let me know whether I have assigned the labels correctly. Thanks! :)

puzzlecollector · Aug 08 '22 09:08

Yes, I think your labels are correct.

> With these assignment of labels, I train my network and it seems to train fine and it seems to converge much faster (probably due to more formations of triplets?)

That sounds reasonable. In your example, each qi can be paired with pi and any of the five negatives, so there will be at least 25 triplets instead of 5. The miner may also be helping convergence.
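As a rough check on the counting, this plain-Python sketch enumerates every (anchor, positive, negative) index combination that the label scheme allows; `count_all_triplets` is a hypothetical helper. Using only each qi as anchor, pi as positive, and the five ni as negatives gives 5 × 5 = 25. Letting pi also act as an anchor, and letting the other queries and positives act as negatives, gives 10 anchor-positive pairs × 13 negatives = 130 for a batch of five triplets. This assumes TripletMarginLoss without a miner enumerates all such label-consistent triplets; the library docs are the authority on that.

```python
from itertools import product

def count_all_triplets(labels):
    """Count (anchor, positive, negative) index triplets allowed by labels:
    anchor and positive are different items with the same label, and the
    negative has a different label from the anchor."""
    n = len(labels)
    return sum(
        1
        for a, p, neg in product(range(n), repeat=3)
        if a != p and labels[a] == labels[p] and labels[neg] != labels[a]
    )

# Five (q, p, n) triplets labeled as in the question: q_i, p_i -> 2i; n_i -> 2i + 1
labels = [2 * i + offset for i in range(5) for offset in (0, 0, 1)]
print(count_all_triplets(labels))  # 130
```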

KevinMusgrave · Aug 08 '22 12:08

@KevinMusgrave About the miner: does it automatically select only some of the triplets (i.e., hard negative triplets) instead of using all of the triplets that can be formed in the batch, to make training more effective?

puzzlecollector · Aug 08 '22 14:08

Yes, though it's a bit more complicated because MultiSimilarityMiner finds hard pairs. Those pairs are converted to triplets inside TripletMarginLoss by combining positive and negative pairs that have the same anchor. If you want something that specifically finds hard triplets, you can use TripletMarginMiner.
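A toy sketch of the pair-to-triplet idea, in plain Python on index pairs. Pair miners such as MultiSimilarityMiner return a tuple of index tensors (a1, p, a2, n); for readability this sketch takes lists of (anchor, positive) and (anchor, negative) tuples instead. `pairs_to_triplets` is a hypothetical helper that illustrates the "same anchor" rule, not the library's code.

```python
def pairs_to_triplets(pos_pairs, neg_pairs):
    """Combine mined pairs into (anchor, positive, negative) triplets.

    pos_pairs: (anchor, positive) index tuples.
    neg_pairs: (anchor, negative) index tuples.
    A positive pair and a negative pair are combined only when they share
    the same anchor index.
    """
    negatives_by_anchor = {}
    for a, n in neg_pairs:
        negatives_by_anchor.setdefault(a, []).append(n)
    return [(a, p, n) for a, p in pos_pairs for n in negatives_by_anchor.get(a, [])]

# Anchor 1 has no mined negative and anchor 3 has no mined positive,
# so neither contributes a triplet.
print(pairs_to_triplets([(0, 1), (1, 0)], [(0, 2), (0, 5), (3, 4)]))
# [(0, 1, 2), (0, 1, 5)]
```

One consequence of this rule: a mined pair whose anchor has no partner of the opposite kind produces no triplet, so the number of triplets the loss sees depends on how the mined pairs overlap.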

KevinMusgrave · Aug 08 '22 15:08

@KevinMusgrave This is a more general question: since miners select hard pairs or triplets instead of showing the model every pair or triplet that can be formed, could training without a miner be more beneficial for training and performance? I know hard triplets are important for training, but showing the model more training examples per batch, rather than limiting it to only the hard ones, might help. So I am considering the following training plan:

  1. First, train with only the TripletMarginLoss from pytorch-metric-learning (no miner) so the model sees as many examples as possible. Train for about 1-2 epochs.
  2. Then train again, this time with TripletMarginMiner, so that the model specifically learns from harder examples.

Does this sound reasonable or do you have any recommendations? Thanks :)
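The two-stage plan can be sketched with a hand-rolled loss in plain PyTorch, so it runs without the library. `triplet_loss_all` is a hypothetical name: with `hard_only=False` it averages over every label-consistent triplet (stage 1), and with `hard_only=True` it averages only over triplets that still violate the margin, which is a loose stand-in for mining with TripletMarginMiner (stage 2), not the library's implementation. The Euclidean distance and the 0.2 margin are assumptions.

```python
import torch

def triplet_loss_all(embeddings, labels, margin=0.2, hard_only=False):
    """Triplet margin loss over every label-consistent triplet in the batch.

    hard_only=False averages over all triplets (stage 1 of the plan);
    hard_only=True averages only over triplets that still violate the
    margin, a loose stand-in for mining hard triplets (stage 2).
    """
    # Pairwise Euclidean distances; the small epsilon keeps sqrt differentiable at 0.
    diff = embeddings.unsqueeze(1) - embeddings.unsqueeze(0)
    dist = (diff.pow(2).sum(dim=-1) + 1e-12).sqrt()
    same = labels.unsqueeze(0) == labels.unsqueeze(1)
    eye = torch.eye(len(labels), dtype=torch.bool, device=labels.device)
    pos_mask = same & ~eye          # [a, p]: same label, different item
    neg_mask = ~same                # [a, n]: different label
    # valid[a, p, n] marks every usable (anchor, positive, negative) triplet.
    valid = pos_mask.unsqueeze(2) & neg_mask.unsqueeze(1)
    # per_triplet[a, p, n] = relu(d(a, p) - d(a, n) + margin)
    per_triplet = torch.relu(dist.unsqueeze(2) - dist.unsqueeze(1) + margin)[valid]
    if hard_only:
        per_triplet = per_triplet[per_triplet > 0]
    if per_triplet.numel() == 0:
        return embeddings.sum() * 0.0  # no usable triplets: zero loss that keeps the graph
    return per_triplet.mean()
```

A two-stage run could then call it with `hard_only=(epoch >= warmup_epochs)`, where `warmup_epochs` (for example 1 or 2) is a hypothetical setting. Averaging only over violating triplets keeps easy triplets from diluting the gradient once most of them are already satisfied.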

Some more background on my problem: I am training large language models such as DeBERTa-Large and BigBird on a USPTO patent-claims dataset for patent similarity search. Because the models are large and my computing resources are limited, I can only fit very small minibatches on my GPU (batch size <= 10), and the models seem to have a hard time learning with such small batches.

puzzlecollector · Aug 09 '22 07:08

> Does this sound reasonable or do you have any recommendations? Thanks :)

Sounds reasonable. Actually I would try skipping step 2 and see how effective the loss is by itself. I would use ContrastiveLoss, MultiSimilarityLoss, or NTXentLoss (also known as InfoNCE) because they tend to perform better than TripletMarginLoss.
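For intuition about the NTXentLoss (InfoNCE) suggestion, here is a hand-rolled sketch in plain PyTorch that works with the same label scheme: each anchor-positive pair is scored against all embeddings whose label differs from the anchor's. `info_nce_loss` is a hypothetical name, the 0.07 temperature is an assumed value, and the library's NTXentLoss may differ in details. In practice you would swap `losses.NTXentLoss()` in for `losses.TripletMarginLoss()` in the training loop.

```python
import torch
import torch.nn.functional as F

def info_nce_loss(embeddings, labels, temperature=0.07):
    """InfoNCE / NT-Xent over every anchor-positive pair in the batch.

    For a positive pair (a, p), the negatives are all items whose label
    differs from a's:  -log(exp(s_ap/t) / (exp(s_ap/t) + sum_n exp(s_an/t))).
    """
    emb = F.normalize(embeddings, dim=1)
    sim = emb @ emb.T / temperature                    # scaled cosine similarity
    same = labels.unsqueeze(0) == labels.unsqueeze(1)
    eye = torch.eye(len(labels), dtype=torch.bool, device=labels.device)
    pos_mask = same & ~eye
    neg_mask = ~same
    a_idx, p_idx = pos_mask.nonzero(as_tuple=True)
    if a_idx.numel() == 0:
        return embeddings.sum() * 0.0                  # no positive pairs in the batch
    pos_logits = sim[a_idx, p_idx]                     # (P,)
    # Row of the pair's anchor; entries that are not negatives are masked out.
    neg_logits = sim[a_idx].masked_fill(~neg_mask[a_idx], float("-inf"))  # (P, N)
    logits = torch.cat([pos_logits.unsqueeze(1), neg_logits], dim=1)
    # The positive sits in column 0, so cross-entropy with target 0 is -log softmax.
    target = torch.zeros(len(a_idx), dtype=torch.long, device=labels.device)
    return F.cross_entropy(logits, target)
```

Unlike the triplet loss, every negative in the batch contributes to every positive pair through the softmax denominator, with closer negatives weighted more heavily.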

> I can fit very small minibatch sizes in my GPU (<=10 batch size). It seems like the models have a difficult time learning with small batch sizes.

I don't have experience with really small batch sizes. Maybe gradient accumulation could help?
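A minimal gradient-accumulation sketch in PyTorch, assuming `micro_batches` yields `(data, labels)` pairs that fit on the GPU; `train_with_accumulation` is a hypothetical name. One caveat for metric learning: each micro-batch's loss only compares embeddings inside that micro-batch, so accumulation averages gradients across micro-batches but does not give the loss more in-batch negatives per comparison.

```python
import torch

def train_with_accumulation(model, optimizer, loss_fn, micro_batches, accum_steps):
    """Step the optimizer once every `accum_steps` micro-batches.

    Each micro-batch loss is divided by accum_steps, so the accumulated
    gradient is the average of the micro-batch gradients. A trailing
    group smaller than accum_steps is not stepped in this sketch.
    """
    model.train()
    optimizer.zero_grad()
    for step, (data, labels) in enumerate(micro_batches, start=1):
        loss = loss_fn(model(data), labels) / accum_steps
        loss.backward()                 # gradients add up in each parameter's .grad
        if step % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
```

For a loss that is a plain mean over samples (checked below with MSE on a linear model), two equal micro-batches with `accum_steps=2` produce exactly the same update as one full batch; for pairwise metric losses the result is only an approximation, for the reason above.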

KevinMusgrave · Aug 09 '22 10:08

@KevinMusgrave I see. Yes, I am trying gradient accumulation.

I will try the three losses you mentioned. I already have a checkpoint trained with TripletMarginLoss and MultiSimilarityMiner. Instead of starting training from scratch with randomly initialized weights, would it be better to continue training from that checkpoint?

For example, say I have a model trained with TripletMarginLoss + MultiSimilarityMiner. Can I use this checkpoint as the starting point for training with ContrastiveLoss? Do you think this would speed up training, or would it confuse the model? Thank you as always for your kind answers :)

Another question: I am currently reading the paper on the lifted structure loss, which is also implemented in pytorch-metric-learning. Would you recommend this loss for my problem? (Have you experimented with it, and if so, did it work well?)

puzzlecollector · Aug 10 '22 02:08

> Then can I use this checkpoint to train it with ContrastiveLoss this time? In your opinion do you think this speeds up training or would it confuse the model?

I've never tried this, but yes, I think it would be faster than starting from randomly initialized weights.
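A minimal sketch of starting the ContrastiveLoss run from the earlier checkpoint, assuming the checkpoint is a plain `state_dict` saved with `torch.save`. `resume_with_new_loss` and the AdamW learning rate are placeholders. Re-creating the optimizer instead of restoring its state is one reasonable choice when the loss changes, not a requirement.

```python
import torch

def resume_with_new_loss(model, checkpoint_path, lr=2e-5):
    """Load weights from an earlier run and create a fresh optimizer.

    Only the model weights are restored. The optimizer is re-created so
    that its state (e.g. AdamW moment estimates) built up under the old
    loss is not carried over; lr=2e-5 is a placeholder value.
    """
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    return model, optimizer
```

Training then proceeds as before, with `loss_func = losses.ContrastiveLoss()` in place of the TripletMarginLoss.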

> (did you get to experiment with this loss in the past and if so did it work well?)

I haven't experimented with it.
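For reference while reading the paper, here is the lifted structured loss as formulated there, where $D_{i,j}$ is the Euclidean distance between embeddings $i$ and $j$, $\alpha$ is the margin, and $\mathcal{P}$ and $\mathcal{N}$ are the sets of positive and negative pairs in the batch. For each positive pair, the log-sum-exp acts as a smooth version of the hardest negative of either endpoint. pytorch-metric-learning's LiftedStructureLoss may differ in details, so its documentation is the authority for the library version.

$$
J_{i,j} = \log\left(\sum_{(i,k)\in\mathcal{N}} \exp\left(\alpha - D_{i,k}\right) + \sum_{(j,l)\in\mathcal{N}} \exp\left(\alpha - D_{j,l}\right)\right) + D_{i,j}
$$

$$
\tilde{J} = \frac{1}{2\lvert\mathcal{P}\rvert} \sum_{(i,j)\in\mathcal{P}} \max\left(0,\, J_{i,j}\right)^{2}
$$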

KevinMusgrave · Aug 10 '22 10:08