pytorch-metric-learning
How to adapt NTXentLoss for a segmentation task?
I'm guessing you have an embedding and a label for each pixel in the image. You can pass all of these embeddings to NTXentLoss:
from pytorch_metric_learning.losses import NTXentLoss
loss_fn = NTXentLoss()
pixel_loss = loss_fn(embeddings, labels)
However, since there are so many pixels in an image, you will probably run out of memory. So you can try randomly sampling a reasonable number of triplets, and passing those into the loss function.
from pytorch_metric_learning.utils import loss_and_miner_utils as lmu
indices_tuple = lmu.get_random_triplet_indices(labels, t_per_anchor = 1)
pixel_loss = loss_fn(embeddings, labels, indices_tuple)
t_per_anchor means "triplets per anchor", so the larger you make it, the higher the memory consumption will be.
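For a rough sense of the scaling, here's a small sketch (the label count and class count below are made up, not taken from your setup); each anchor contributes about t_per_anchor triplets, so the returned index tensors grow roughly linearly with that value:

import torch
from pytorch_metric_learning.utils import loss_and_miner_utils as lmu

labels = torch.randint(0, 5, (1000,))  # 1000 fake pixel labels, 5 classes

a, p, n = lmu.get_random_triplet_indices(labels, t_per_anchor=1)
print(a.shape)  # about 1000 triplets: one per anchor

a, p, n = lmu.get_random_triplet_indices(labels, t_per_anchor=5)
print(a.shape)  # about 5000 triplets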
Thanks for your comment!
Yes, I have an embedding and a label for each pixel in the image. There are 5 classes in total, the batch size is 4, the embedding shape is torch.Size([4, 32, 384, 384]), and the label shape is torch.Size([4, 384, 384]).
But I got an error like this:
Traceback (most recent call last): File "train.py", line 198, in
main() File "train.py", line 174, in main train_loss, train_dices = train(model, train_loader, optimizer, LOSS_FUNC, lr_sheduler, device) File "train.py", line 51, in train contrastive_loss = contrastive_loss_func(contrastive_feature, label) File "/research/dept8/qdou/zwang/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in call_impl result = self.forward(*input, **kwargs) File "/research/dept8/qdou/zwang/anaconda3/lib/python3.8/site-packages/pytorch_metric_learning/losses/base_metric_loss_function.py", line 32, in forward loss_dict = self.compute_loss(embeddings, labels, indices_tuple) File "/research/dept8/qdou/zwang/anaconda3/lib/python3.8/site-packages/pytorch_metric_learning/losses/generic_pair_loss.py", line 14, in compute_loss indices_tuple = lmu.convert_to_pairs(indices_tuple, labels) File "/research/dept8/qdou/zwang/anaconda3/lib/python3.8/site-packages/pytorch_metric_learning/utils/loss_and_miner_utils.py", line 57, in convert_to_pairs return get_all_pairs_indices(labels) File "/research/dept8/qdou/zwang/anaconda3/lib/python3.8/site-packages/pytorch_metric_learning/utils/loss_and_miner_utils.py", line 41, in get_all_pairs_indices matches.fill_diagonal(0) RuntimeError: all dimensions of input must be of equal length
I am wondering how to solve this?
You need to reshape the embeddings to have shape (N, D), and labels to have shape (N,).
Something like this might work, though I haven't confirmed that the reshaping of embeddings matches the reshaping of labels:
embeddings = embeddings.permute(0, 2, 3, 1)        # (B, C, H, W) -> (B, H, W, C)
embeddings = embeddings.contiguous().view(-1, 32)  # (B*H*W, 32)
labels = labels.view(-1)                           # (B*H*W,)
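If you want to convince yourself that the two reshapes line up, here's a hypothetical sanity check (the random tensors are just stand-ins for your real ones): a flat index i in the reshaped tensors should map back to the same pixel in both originals.

import torch

raw_embeddings = torch.randn(4, 32, 384, 384)    # (B, C, H, W), stand-in for your features
raw_labels = torch.randint(0, 5, (4, 384, 384))  # (B, H, W), stand-in for your labels

B, C, H, W = raw_embeddings.shape
embeddings = raw_embeddings.permute(0, 2, 3, 1).contiguous().view(-1, C)
labels = raw_labels.view(-1)

# Pick an arbitrary flat pixel index and recover its (batch, row, col) position
i = 12345
b, h, w = i // (H * W), (i // W) % H, i % W
assert torch.equal(embeddings[i], raw_embeddings[b, :, h, w])
assert labels[i] == raw_labels[b, h, w]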
Thanks very much for the help! As you said, I ran out of memory, but I hit this problem when trying to randomly sample a reasonable number of triplets.
Traceback (most recent call last): File "train.py", line 212, in
main() File "train.py", line 186, in main train_loss, train_dices = train(model, train_loader, optimizer, LOSS_FUNC, lr_sheduler, device) File "train.py", line 51, in train indices_tuple = lmu.get_random_triplet_indices(label, t_per_anchor=1) File "/research/dept8/qdou/zwang/anaconda3/lib/python3.8/site-packages/pytorch_metric_learning/utils/loss_and_miner_utils.py", line 119, in get_random_triplet_indices p_inds_ = p_inds_[~torch.eye(n_a).bool()].view((n_a, n_a - 1)) RuntimeError: [enforce fail at CPUAllocator.cpp:64] . DefaultCPUAllocator: can't allocate memory: you tried to allocate 1322049238416 bytes. Error code 12 (Cannot allocate memory)
My code is:
contrastive_loss_func = NTXentLoss(temperature=0.1)
contrastive_feature = contrastive_feature.permute(0, 2, 3, 1)
contrastive_feature = contrastive_feature.contiguous().view(-1, 32)
label = label.view(-1)
indices_tuple = lmu.get_random_triplet_indices(label, t_per_anchor=1)
contrastive_loss = contrastive_loss_func(contrastive_feature, label, indices_tuple)
Hmm I see, because the batch size is huge (4 × 384 × 384 ≈ 589,000), that function isn't able to create the necessary matrices.
I'll have to think about how to solve this large-batch problem. In the meantime, I think the only workaround would be to randomly sample pixels, to reduce the batch size.
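As a rough sketch of that workaround (not a library feature; num_samples is just an arbitrary budget, and the other variable names come from your snippet above):

import torch
from pytorch_metric_learning.utils import loss_and_miner_utils as lmu

# Subsample pixels after flattening, before building triplets.
# contrastive_feature is (N, 32) and label is (N,) from the snippet above.
num_samples = 4096  # tune to fit your memory
idx = torch.randperm(contrastive_feature.size(0), device=contrastive_feature.device)[:num_samples]
sub_features = contrastive_feature[idx]
sub_labels = label[idx]

indices_tuple = lmu.get_random_triplet_indices(sub_labels, t_per_anchor=1)
contrastive_loss = contrastive_loss_func(sub_features, sub_labels, indices_tuple)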
Thanks a lot! I have sampled pixels randomly for training, but it doesn't seem to work. Looking forward to the repo update!
Hi, has there been any update on this issue recently?
Sorry, I haven't gotten around to this yet.