pytorch-metric-learning
Hello, is this package able to train on a multi-label dataset with one-hot encoding?
It is a great package that improves my efficiency. When I test on CIFAR-10, for the one-hot labels I can use
label = torch.argmax(label, dim=1)
to convert them to class indices. But when I test on a multi-label dataset, I can't find a nice way to handle the labels.
At first I saw this issue, which shows a way to pass in multi-label data, but I want to customize it further, because I need to construct a similarity matrix:
label = torch.matmul(label, label.t())
# For a multi-label dataset, if two samples share at least one label, I mark them as the same
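For example (an illustrative batch just to show what I mean), with three samples and four classes:

import torch

# samples 0 and 1 share class 1; sample 2 shares no class with the others
label = torch.tensor([[1., 1., 0., 0.],
                      [0., 1., 1., 0.],
                      [0., 0., 0., 1.]])
sim = torch.matmul(label, label.t())
# sim = [[2., 1., 0.],
#        [1., 2., 0.],
#        [0., 0., 1.]]
# off-diagonal entries > 0  -> positive pair (samples 0 and 1)
# off-diagonal entries == 0 -> negative pair (0 and 2, 1 and 2)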
I hope to receive a response from you soon. Thank you.
Unfortunately there isn't a way to pass a custom label-comparison function into miners or loss functions. It would be a good idea to add this feature though, so I will keep this issue open.
Edit:
Actually, I think you can write a miner to accomplish what you're talking about:
from pytorch_metric_learning.miners import BaseMiner

class CustomMiner(BaseMiner):
    def mine(self, embeddings, labels, ref_emb, ref_labels):
        # compare labels and ref_labels however you want, then
        # return a tuple (a1, p, a2, n), where (a1, p) are the
        # positive pair indices and (a2, n) are the negative pair indices
        ...

miner = CustomMiner()
pairs = miner(embeddings, labels)
loss = loss_fn(embeddings, indices_tuple=pairs)
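For the "two samples that share at least one label are a positive pair" rule specifically, a minimal sketch of such a miner could look like the following. This assumes labels is a multi-hot tensor of shape (batch_size, num_classes) and that no separate reference set is passed to the miner; the class name MultiLabelMiner is just an illustrative choice, not part of the library. Also note that newer versions of the library may validate that labels are 1-D in the miner's forward pass; if your version rejects 2-D labels, you can compute the same indices tuple outside a miner and pass it to the loss via indices_tuple.

import torch
from pytorch_metric_learning.miners import BaseMiner

class MultiLabelMiner(BaseMiner):
    def mine(self, embeddings, labels, ref_emb, ref_labels):
        # entry (i, j) counts how many labels samples i and j share
        sim = torch.matmul(labels.float(), ref_labels.float().t())
        pos_mask = sim > 0   # at least one shared label -> positive
        neg_mask = sim == 0  # no shared labels -> negative
        # remove self-comparisons (this sketch assumes ref_labels is the
        # same batch as labels, which is the default when no ref_emb is given)
        pos_mask.fill_diagonal_(False)
        a1, p = torch.where(pos_mask)
        a2, n = torch.where(neg_mask)
        return a1, p, a2, n

miner = MultiLabelMiner()
pairs = miner(embeddings, labels)  # labels: multi-hot, shape (batch_size, num_classes)
loss = loss_fn(embeddings, indices_tuple=pairs)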
It's not ideal but it's the only workaround I can think of.
Thank you for your response. I will try it.