
Hello, is this package able to train on a multi-label dataset with one-hot encoding?

Open Mahiro2211 opened this issue 1 year ago • 2 comments

This is a great package that improves my efficiency. When I test on CIFAR-10 with one-hot labels, I can use

label = torch.argmax(label,dim=1)

to convert the one-hot labels. But when I try it on a multi-label dataset, I can't find a good way to handle the labels.

At first I saw this issue, which shows a way to pass in multi-labels, but I want to customize it further, because I need to construct a similarity matrix:

label = torch.matmul(label, label.t())
# For a multi-label dataset: if two samples share at least one label, I mark them as matching
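For instance, with a small batch of made-up multi-hot labels, the product followed by `> 0` gives a boolean match matrix (a quick sketch, not from my actual dataset):

```python
import torch

# Made-up multi-hot labels: 3 samples, 4 classes
label = torch.tensor([[1., 0., 1., 0.],
                      [0., 0., 1., 0.],
                      [0., 1., 0., 0.]])

# Entry (i, j) counts the labels shared by samples i and j;
# thresholding at > 0 marks pairs that share at least one label.
shared = torch.matmul(label, label.t())
same = shared > 0
```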

I hope to receive a response from you soon. Thank you.

Mahiro2211 avatar Oct 14 '23 05:10 Mahiro2211

Unfortunately there isn't a way to pass a custom label-comparison function into miners or loss functions. It would be a good feature to add though, so I will keep this issue open.

Edit:

Actually I think you can write a miner to accomplish what you're talking about:

from pytorch_metric_learning.miners import BaseMiner

class CustomMiner(BaseMiner):
    def mine(self, embeddings, labels, ref_emb, ref_labels):
        # compare labels and ref_labels however you want
        # return a tuple (a1, p, a2, n)
        # where (a1, p) are the positive pair indices
        # and (a2, n) are the negative pair indices
        ...


miner = CustomMiner()
pairs = miner(embeddings, labels)
# loss_fn is any pair-based loss from this library
loss = loss_fn(embeddings, indices_tuple=pairs)

It's not ideal but it's the only workaround I can think of.
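For your "share at least one label = positive" convention, the comparison itself could look like the sketch below. Note that `multilabel_pairs` is a made-up helper, not part of the library; it builds the indices tuple directly from multi-hot labels, which you could use either inside a custom miner's `mine` method or on its own:

```python
import torch

def multilabel_pairs(labels):
    # labels: (batch, num_classes) multi-hot tensor.
    # Positive pair: two samples sharing at least one label.
    # Negative pair: two samples sharing no labels.
    shared = torch.matmul(labels.float(), labels.float().t())
    same = shared > 0
    same.fill_diagonal_(False)  # drop self-pairs
    diff = shared == 0          # diagonal is already > 0 for any labeled sample
    a1, p = torch.where(same)
    a2, n = torch.where(diff)
    return a1, p, a2, n

# Made-up batch: 3 samples, 3 classes
labels = torch.tensor([[1, 0, 1],
                       [0, 0, 1],
                       [0, 1, 0]])
a1, p, a2, n = multilabel_pairs(labels)
```

The resulting tuple can then be passed to a loss function as `indices_tuple`, as in the snippet above.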

KevinMusgrave avatar Oct 15 '23 23:10 KevinMusgrave

Thank you for your response. I will try it.

Mahiro2211 avatar Oct 16 '23 05:10 Mahiro2211