
How to adapt NTXent loss for a segmentation task?


Kyfafyd · Nov 11 '20

I'm guessing you have an embedding and a label for each pixel in the image. You can pass all of these embeddings to NTXentLoss:

from pytorch_metric_learning.losses import NTXentLoss

loss_fn = NTXentLoss()
# embeddings: (N, D) tensor, labels: (N,) tensor with one class per embedding
pixel_loss = loss_fn(embeddings, labels)

However, since there are so many pixels in an image, you will probably run out of memory. So you can try randomly sampling a reasonable number of triplets and passing those into the loss function.

from pytorch_metric_learning.utils import loss_and_miner_utils as lmu

# sample 1 random triplet per anchor instead of forming all pairs
indices_tuple = lmu.get_random_triplet_indices(labels, t_per_anchor=1)
pixel_loss = loss_fn(embeddings, labels, indices_tuple)

t_per_anchor means triplets per anchor, so the larger you make it, the more triplets get sampled and the higher the memory consumption will be.

KevinMusgrave · Nov 11 '20

Thanks for your comment! Yes, I have an embedding and a label for each pixel in the image. There are 5 classes in total and the batch size is 4: the embedding shape is torch.Size([4, 32, 384, 384]) and the label shape is torch.Size([4, 384, 384]). But I got an error like this:

Traceback (most recent call last):
  File "train.py", line 198, in <module>
    main()
  File "train.py", line 174, in main
    train_loss, train_dices = train(model, train_loader, optimizer, LOSS_FUNC, lr_sheduler, device)
  File "train.py", line 51, in train
    contrastive_loss = contrastive_loss_func(contrastive_feature, label)
  File "/research/dept8/qdou/zwang/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/research/dept8/qdou/zwang/anaconda3/lib/python3.8/site-packages/pytorch_metric_learning/losses/base_metric_loss_function.py", line 32, in forward
    loss_dict = self.compute_loss(embeddings, labels, indices_tuple)
  File "/research/dept8/qdou/zwang/anaconda3/lib/python3.8/site-packages/pytorch_metric_learning/losses/generic_pair_loss.py", line 14, in compute_loss
    indices_tuple = lmu.convert_to_pairs(indices_tuple, labels)
  File "/research/dept8/qdou/zwang/anaconda3/lib/python3.8/site-packages/pytorch_metric_learning/utils/loss_and_miner_utils.py", line 57, in convert_to_pairs
    return get_all_pairs_indices(labels)
  File "/research/dept8/qdou/zwang/anaconda3/lib/python3.8/site-packages/pytorch_metric_learning/utils/loss_and_miner_utils.py", line 41, in get_all_pairs_indices
    matches.fill_diagonal_(0)
RuntimeError: all dimensions of input must be of equal length

How can I solve this?

Kyfafyd · Nov 12 '20

You need to reshape the embeddings to have shape (N, D), and labels to have shape (N,).

Something like this might work, though I haven't confirmed that the reshaping of embeddings matches the reshaping of labels.

# (B, C, H, W) -> (B, H, W, C), then flatten so each row is one pixel's embedding
embeddings = embeddings.permute(0, 2, 3, 1)
embeddings = embeddings.contiguous().view(-1, 32)
# flatten labels to match: one label per pixel
labels = labels.view(-1)
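
As a quick sanity check (a minimal sketch using your shapes), you can verify that this flattening orders pixels the same way for embeddings and labels:

import torch

B, C, H, W = 4, 32, 384, 384
emb = torch.randn(B, C, H, W)
emb_flat = emb.permute(0, 2, 3, 1).contiguous().view(-1, C)

# after permute + view, pixel (b, h, w) lands at flat index b*H*W + h*W + w,
# which is exactly the order labels.view(-1) produces for a (B, H, W) tensor
b, h, w = 1, 100, 200
assert torch.equal(emb_flat[b * H * W + h * W + w], emb[b, :, h, w])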

KevinMusgrave · Nov 12 '20

Thanks very much for the help! As you said, I ran out of memory. But I met this problem when trying to randomly sample a reasonable number of triplets:

Traceback (most recent call last):
  File "train.py", line 212, in <module>
    main()
  File "train.py", line 186, in main
    train_loss, train_dices = train(model, train_loader, optimizer, LOSS_FUNC, lr_sheduler, device)
  File "train.py", line 51, in train
    indices_tuple = lmu.get_random_triplet_indices(label, t_per_anchor=1)
  File "/research/dept8/qdou/zwang/anaconda3/lib/python3.8/site-packages/pytorch_metric_learning/utils/loss_and_miner_utils.py", line 119, in get_random_triplet_indices
    p_inds_ = p_inds_[~torch.eye(n_a).bool()].view((n_a, n_a - 1))
RuntimeError: [enforce fail at CPUAllocator.cpp:64] . DefaultCPUAllocator: can't allocate memory: you tried to allocate 1322049238416 bytes. Error code 12 (Cannot allocate memory)

My code is:

from pytorch_metric_learning.losses import NTXentLoss
from pytorch_metric_learning.utils import loss_and_miner_utils as lmu

contrastive_loss_func = NTXentLoss(temperature=0.1)
# flatten (B, C, H, W) features into (B*H*W, 32) pixel embeddings
contrastive_feature = contrastive_feature.permute(0, 2, 3, 1)
contrastive_feature = contrastive_feature.contiguous().view(-1, 32)
label = label.view(-1)
indices_tuple = lmu.get_random_triplet_indices(label, t_per_anchor=1)
contrastive_loss = contrastive_loss_func(contrastive_feature, label, indices_tuple)

Kyfafyd · Nov 13 '20

Hmm, I see: because the batch size is huge (4 × 384 × 384 ≈ 590,000 pixels), that function isn't able to create the necessary matrices.

I'll have to think about how to solve this large-batch problem. In the meantime, I think the only workaround is to randomly sample pixels to reduce the batch size.
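
Something along these lines might work as a rough sketch (the helper name sampled_pixel_loss and the num_samples value are just placeholders, not part of the library):

import torch
from pytorch_metric_learning.losses import NTXentLoss
from pytorch_metric_learning.utils import loss_and_miner_utils as lmu

loss_fn = NTXentLoss(temperature=0.1)

def sampled_pixel_loss(embeddings, labels, num_samples=4096):
    # embeddings: (B, C, H, W), labels: (B, H, W)
    B, C, H, W = embeddings.shape
    emb_flat = embeddings.permute(0, 2, 3, 1).contiguous().view(-1, C)
    lab_flat = labels.view(-1)
    # keep a random subset of pixels so the triplet sampling fits in memory
    idx = torch.randperm(lab_flat.size(0), device=lab_flat.device)[:num_samples]
    emb_flat, lab_flat = emb_flat[idx], lab_flat[idx]
    indices_tuple = lmu.get_random_triplet_indices(lab_flat, t_per_anchor=1)
    return loss_fn(emb_flat, lab_flat, indices_tuple)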

KevinMusgrave · Nov 13 '20

Thanks a lot! I have sampled pixels randomly for training, but it doesn't seem to work. Looking forward to the repo update!

Kyfafyd · Nov 14 '20

Hi, has there been any update on this issue recently?

Kyfafyd · Jan 26 '21

Sorry, I haven't gotten around to this yet.

KevinMusgrave · Jan 26 '21