MONAI
Cannot run V-Net on Medical Decathlon data
**Describe the bug**
PyTorch raises a tensor size mismatch error when V-Net is trained on Medical Decathlon data (`Task04_Hippocampus`).
**To Reproduce**
```python
import monai
from monai.apps import DecathlonDataset
from monai.transforms import LoadImaged, EnsureChannelFirstd, ScaleIntensityd, ToTensord, Compose
from monai.networks.nets import VNet
from monai.losses.dice import DiceLoss
import torch


def train_one_epoch(train_loader, loss_fn, optimizer, epoch):
    running_loss = 0.
    example_ct = 0
    for batch_idx, dict_item in enumerate(train_loader):
        images = dict_item['image']
        labels = dict_item['label']
        print("Shape of images", images.shape)
        print("Shape of labels", labels.shape)
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * images.size(0)
        example_ct += images.size(0)
        metrics = {"train/train_loss": loss.item(),
                   "train/epoch": epoch,
                   "train/example_ct": example_ct
                   }
        print(metrics)
    return running_loss / example_ct


def train_loop(train_loader, val_loader):
    loss_fn = DiceLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    best_vloss = 1_000_000.
    for epoch in range(3):
        print(f"Epoch:{epoch+1}")
        model.train()
        avg_train_loss = train_one_epoch(train_loader, loss_fn, optimizer, epoch)
        print("train loss", avg_train_loss)


if __name__ == "__main__":
    transform = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ]
    )
    train_data = DecathlonDataset(
        root_dir="./", task="Task04_Hippocampus", transform=transform, section="validation", seed=12345, download=False
    )
    model = VNet(spatial_dims=3, in_channels=1, out_channels=1, act='elu')
    train_loader = monai.data.DataLoader(
        train_data, batch_size=1, num_workers=2, persistent_workers=True)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device {device}")
    model.to(device)
    train_loop(train_loader, val_loader=None)
```
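V-Net's encoder halves every spatial dimension four times, so the decoder's skip connections only line up when each spatial size is a multiple of 16 (2^4). The sketch below is an assumption-based illustration, not a confirmed fix: it computes the padded shape such a constraint would require. In MONAI a similar effect could plausibly be obtained by adding `DivisiblePadd(keys=["image", "label"], k=16)` to the `Compose` list after `EnsureChannelFirstd`; that has not been tested against this dataset here. The helper names `divisible_shape` and `pad_amounts` and the example shape are hypothetical.

```python
def divisible_shape(spatial_shape, k=16):
    """Round each spatial dimension up to the nearest multiple of k.

    k=16 assumes four stride-2 down transitions, as in MONAI's VNet (2**4 == 16).
    """
    # -(-n // k) is ceiling division for positive integers.
    return tuple(-(-n // k) * k for n in spatial_shape)


def pad_amounts(spatial_shape, k=16):
    """Total voxels to add along each dimension to reach divisible_shape."""
    target = divisible_shape(spatial_shape, k)
    return tuple(t - n for n, t in zip(spatial_shape, target))


# Hypothetical volume size, for illustration only.
shape = (35, 51, 35)
print(divisible_shape(shape))  # (48, 64, 48)
print(pad_amounts(shape))      # (13, 13, 13)
```

Padding (rather than cropping) keeps every labelled voxel; resizing to a fixed multiple-of-16 shape would be another option if all volumes should share one size.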
**Expected behavior**
Training runs without errors.
**Screenshots**
```
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 4 but got size 5 for tensor number 1 in the list.
```
**Complete Traceback**
```
Traceback (most recent call last):
  File "/home/linn/vnet/train.py", line 65, in <module>
    train_loop(train_loader, val_loader=None)
  File "/home/linn/vnet/train.py", line 39, in train_loop
    avg_train_loss = train_one_epoch(train_loader, loss_fn, optimizer, epoch)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/linn/vnet/train.py", line 19, in train_one_epoch
    outputs = model(images)
              ^^^^^^^^^^^^^
  File "/data/linn/miniconda3/envs/monoai/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/linn/miniconda3/envs/monoai/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/linn/miniconda3/envs/monoai/lib/python3.11/site-packages/monai/networks/nets/vnet.py", line 274, in forward
    x = self.up_tr256(out256, out128)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/linn/miniconda3/envs/monoai/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/linn/miniconda3/envs/monoai/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/linn/miniconda3/envs/monoai/lib/python3.11/site-packages/monai/networks/nets/vnet.py", line 165, in forward
    xcat = torch.cat((out, skipxdo), 1)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/linn/miniconda3/envs/monoai/lib/python3.11/site-packages/monai/data/meta_tensor.py", line 282, in __torch_function__
    ret = super().__torch_function__(func, types, args, kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/linn/miniconda3/envs/monoai/lib/python3.11/site-packages/torch/_tensor.py", line 1443, in __torch_function__
    ret = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
```
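The traceback fails inside `up_tr256`, where the upsampled deep features are concatenated with the `out128` skip tensor. A minimal pure-Python sketch of the sizes along one spatial axis reproduces the "expected size 4, got size 5" pattern. It assumes, as in MONAI's `vnet.py`, that each down transition is a kernel-2, stride-2 convolution and each up transition a kernel-2, stride-2 transposed convolution; the function name `vnet_skip_mismatch` and the input sizes are hypothetical.

```python
def vnet_skip_mismatch(size, levels=4):
    """Trace one spatial axis through a V-Net-style encoder/decoder.

    Returns (level, upsampled, skip) for the first decoder step whose sizes
    disagree, or None if every skip connection lines up. Level 3 is the first
    decoder step (up_tr256 in the traceback), level 0 the last.
    """
    # Encoder: skips[0] is the full-resolution size; each down transition
    # (kernel 2, stride 2, no padding) floors the size to n // 2.
    skips = [size]
    for _ in range(levels):
        skips.append(skips[-1] // 2)

    # Decoder: each up transition doubles the size, then the result is
    # concatenated with the encoder output of the same level.
    current = skips[-1]
    for level in range(levels - 1, -1, -1):
        upsampled = current * 2
        skip = skips[level]
        if upsampled != skip:
            return (level, upsampled, skip)
        current = upsampled
    return None


print(vnet_skip_mismatch(43))  # (3, 4, 5): fails at the first decoder step
print(vnet_skip_mismatch(35))  # (1, 16, 17): fails later in the decoder
print(vnet_skip_mismatch(48))  # None: 48 is a multiple of 16
```

Any size that is a multiple of 16 returns `None`, which is why padding or resizing the volumes to such sizes is a plausible workaround.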
**Environment**
Ensuring you use the relevant Python executable, please paste the output of:
```
python -c "import monai; monai.config.print_debug_info()"
```
```
================================
Printing MONAI config...
================================
MONAI version: 1.3.1
Numpy version: 1.26.4
Pytorch version: 2.3.1+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 96bfda00c6bd290297f5e3514ea227c6be4d08b4
MONAI __file__: /data/<username>/miniconda3/envs/monoai/lib/python3.11/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.2.1
scikit-image version: NOT INSTALLED or UNKNOWN VERSION.
scipy version: NOT INSTALLED or UNKNOWN VERSION.
Pillow version: NOT INSTALLED or UNKNOWN VERSION.
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: NOT INSTALLED or UNKNOWN VERSION.
tqdm version: 4.66.4
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: NOT INSTALLED or UNKNOWN VERSION.
pandas version: NOT INSTALLED or UNKNOWN VERSION.
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies

================================
Printing system config...
================================
`psutil` required for `print_system_info`

================================
Printing GPU config...
================================
Num GPUs: 1
Has CUDA: True
CUDA version: 12.1
cuDNN enabled: True
NVIDIA_TF32_OVERRIDE: None
TORCH_ALLOW_TF32_CUBLAS_OVERRIDE: None
cuDNN version: 8902
Current device: 0
Library compiled for CUDA architectures: ['sm_50', 'sm_60', 'sm_70', 'sm_75', 'sm_80', 'sm_86', 'sm_90']
GPU 0 Name: NVIDIA A100-PCIE-40GB
GPU 0 Is integrated: False
GPU 0 Is multi GPU board: False
GPU 0 Multi processor count: 108
GPU 0 Total memory (GB): 39.4
GPU 0 CUDA capability (maj.min): 8.0
```