
RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

Open robmarkcole opened this issue 11 months ago • 8 comments

Description

I believe this relates to https://github.com/Lightning-AI/pytorch-lightning/issues/20456. Is there a workaround?

Steps to reproduce

from torchgeo.datasets import BigEarthNet
from torchgeo.datamodules import BigEarthNetDataModule
from torchgeo.trainers import ClassificationTask
from lightning.pytorch import Trainer

# Load the dataset
train_dataset = BigEarthNet(root="data", download=False) # already downloaded

# Setup datamodule
datamodule = BigEarthNetDataModule(bands='all', batch_size=16)
datamodule.setup('fit')
datamodule.setup('test')

# Define the model
num_bands = len(datamodule.maxs)
num_classes = 19  # Default number of classes

task = ClassificationTask(
    weights=True,
    num_classes=num_classes,
    in_channels=num_bands,
    lr=1e-3
)

# Train the model
trainer = Trainer(max_epochs=10)
trainer.fit(task, datamodule)

Returns

File /home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/functional.py:3059, in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
   3057 if size_average is not None or reduce is not None:
   3058     reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3059 return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)

RuntimeError: Expected floating point type for target with class probabilities, got Long

Version

0.6.2

robmarkcole avatar Dec 30 '24 11:12 robmarkcole

I tried setting target = target.float() and got:

RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

robmarkcole avatar Dec 30 '24 13:12 robmarkcole

This error is because BigEarthNet is a multilabel classification dataset but you're using ClassificationTask which is for multiclass problems. Changing ClassificationTask -> MultiLabelClassificationTask should do the trick.

isaaccorley avatar Dec 30 '24 18:12 isaaccorley
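The distinction between multiclass and multilabel targets can be reproduced outside torchgeo. The sketch below is a minimal illustration, assuming a batch of 4 samples and 19 classes (both values are hypothetical) and using binary cross-entropy as a typical multilabel loss; it is not a claim about which loss either task class uses internally.

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: 4 samples, 19 classes.
logits = torch.randn(4, 19)
# Multi-hot labels stored as integers: one 0/1 flag per class.
target = torch.randint(0, 2, (4, 19))

# Multiclass loss: a target with the same shape as the logits is read as
# class probabilities, which must be floating point, so a Long target fails.
try:
    F.cross_entropy(logits, target)
    multiclass_error = None
except RuntimeError as err:
    multiclass_error = str(err)

# Multilabel loss: every class is treated as an independent binary decision.
multilabel_loss = F.binary_cross_entropy_with_logits(logits, target.float())
```

Casting the target to float silences the first error but still applies a multiclass loss to multi-hot labels, which is why switching the task type is the more appropriate fix.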

After switching to MultiLabelClassificationTask, I get RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

Resolved by setting:

        x = batch['image'].to(self.device)

robmarkcole avatar Dec 31 '24 08:12 robmarkcole
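The workaround above moves only the image. A slightly more general version, sketched here under the assumption that the batch is a flat dict (as the `batch['image']` access suggests), moves every tensor-like value. `move_batch_to_device` is a hypothetical helper, not part of torchgeo or Lightning, and it is duck-typed on `.to()` so that it does not depend on a particular tensor library.

```python
from typing import Any


def move_batch_to_device(batch: dict[str, Any], device: Any) -> dict[str, Any]:
    """Return a copy of `batch` with every tensor-like value moved to `device`."""
    moved = {}
    for key, value in batch.items():
        # Anything exposing a callable .to() is treated as tensor-like;
        # other values (file names, metadata) are passed through unchanged.
        to = getattr(value, "to", None)
        moved[key] = to(device) if callable(to) else value
    return moved
```

Inside a task's training_step or validation_step this would be called as `batch = move_batch_to_device(batch, self.device)` before the forward pass. It only masks the underlying device-transfer problem, so it is best treated as a temporary measure.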

Yep, that's definitely https://github.com/Lightning-AI/pytorch-lightning/issues/20456.

adamjstewart avatar Dec 31 '24 12:12 adamjstewart

@robmarkcole this could also be https://github.com/kornia/kornia/issues/3066. There was a bug in kornia that moved one of the tensors to the CPU, and I encountered a similar error because of it. I would check in the on_after_batch_transfer function whether applying the augmentation is causing the error.

nilsleh avatar Jan 07 '25 10:01 nilsleh
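One way to run the check suggested above is to snapshot the device of every batch entry before and after the augmentation and compare the two. The helpers below are a hypothetical sketch (neither function exists in torchgeo, kornia, or Lightning) and assume the batch is a flat dict whose tensor values expose a `.device` attribute.

```python
from typing import Any


def describe_devices(batch: dict[str, Any]) -> dict[str, str]:
    """Map each batch key to the device of its value, for values that have one."""
    return {
        key: str(value.device)
        for key, value in batch.items()
        if hasattr(value, "device")
    }


def find_device_changes(before: dict[str, str], after: dict[str, str]) -> list[str]:
    """List the keys whose device differs between two describe_devices() snapshots."""
    return [key for key in before if key in after and before[key] != after[key]]
```

Calling `describe_devices(batch)` immediately before and after the augmentation step in on_after_batch_transfer, then passing both snapshots to `find_device_changes`, would show which entry, if any, ends up on the CPU.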

Is this still an issue, or was this fixed in newer versions of our dependencies?

adamjstewart avatar Feb 09 '25 16:02 adamjstewart

I am experiencing the same bug when using SemanticSegmentationTask to reproduce something similar to the LEVIR-CD+ change detection example notebook with a custom dataset. I am using the latest stable release of torchgeo:

torchgeo.__version__, pl.__version__, torch.__version__, kornia.__version__
('0.6.2', '2.4.0', '2.5.1', '0.7.4')

When I pip install the 0.7.0.dev0 version into my conda env I get the same error.

torchgeo.__version__, pl.__version__, torch.__version__, kornia.__version__
('0.7.0.dev0', '2.4.0', '2.5.1', '0.7.4')

Note that the error also occurs when I remove all kornia augmentations from my custom DataModule.

Overriding training_step and validation_step in the SemanticSegmentationTask class as suggested above (moving the batch to the device) fixes the error.

hkristen avatar Feb 14 '25 08:02 hkristen

Is anyone still experiencing this? I've spent a lot of time debugging why I get differences relative to my legacy code, and one of the differences is the .to(self.device) call introduced to work around this issue.

robmarkcole avatar Apr 24 '25 17:04 robmarkcole