
MONAI UNet training error: differences in tensor size

Open · Alisoltan82 opened this issue 1 year ago · 1 comment

Good day,

Tensor sizes for image and label, checked from the train dataloader just before training:

from monai.utils import first  # fetches the first batch from an iterable

data = first(train_loader)
data['image'][6].shape, data['label'][6].shape

(torch.Size([1, 90, 90, 40]), torch.Size([1, 90, 90, 40]))
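
A quick check (a sketch, using the shapes printed above) makes the issue visible: with strides=(2, 2, 2, 2) the UNet halves each spatial dimension four times, so every dimension should be divisible by 2**4 = 16, and none of 90, 90, 40 is:

# Sketch: compare each spatial dim against the UNet's total downsampling
# factor. strides=(2, 2, 2, 2) means four halvings, i.e. a factor of 16.
downsample_factor = 2 ** 4
for dim in (90, 90, 40):
    print(dim, dim % downsample_factor == 0)
# 90 False
# 90 False
# 40 False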

Model:

import torch
from monai.networks.nets import UNet
from monai.networks.layers import Norm

device = torch.device("cuda")
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=3,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
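
A dummy forward pass (a sketch, reusing the model and device defined above) reproduces the mismatch without running the full training loop:

# Sketch: a zero-filled batch with the reported shape (batch=1, channel=1,
# spatial 90x90x40) triggers the same skip-connection error as training.
x = torch.zeros(1, 1, 90, 90, 40, device=device)
with torch.no_grad():
    model(x)
# RuntimeError: Sizes of tensors must match except in dimension 1.
# Expected size 5 but got size 6 for tensor number 1 in the list.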

Error:

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 5 but got size 6 for tensor number 1 in the list.

Any ideas on what to check? And which dimensions do the sizes 5 and 6 refer to?

Note: the labels have a background class plus 2 foreground classes, not 1. Thanks.

Alisoltan82 — Jan 22 '24 17:01
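
The 5 and 6 in the message are consistent with the depth axis: 40 halves to 20, 10, and 5 along the encoder, the fourth downsampling takes 5 to 3 (5/2 rounded up), and the decoder doubles that back to 6, which no longer matches the 5 kept by the skip connection when torch.cat runs. Padding every volume so each spatial dimension is a multiple of 16 avoids this; a sketch with MONAI's DivisiblePadd dictionary transform (k=16 is the minimal choice for this network; the follow-up below uses k=64, which also satisfies it):

import torch
from monai.transforms import DivisiblePadd

# Sketch: pad image and label so every spatial dimension becomes a
# multiple of 16 (the product of the four stride-2 downsamplings).
pad = DivisiblePadd(keys=["image", "label"], k=16)
sample = pad({"image": torch.zeros(1, 90, 90, 40), "label": torch.zeros(1, 90, 90, 40)})
print(sample["image"].shape)  # (1, 96, 96, 48)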

The full error message, after using divisible padding with k = 64, and after epoch 2:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
File <timed exec>:44

File /opt/conda/lib/python3.10/site-packages/monai/inferers/utils.py:229, in sliding_window_inference(inputs, roi_size, sw_batch_size, predictor, overlap, mode, sigma_scale, padding_mode, cval, sw_device, device, progress, roi_weight_map, process_fn, buffer_steps, buffer_dim, with_coord, *args, **kwargs)
    227     seg_prob_out = predictor(win_data, unravel_slice, *args, **kwargs)  # batched patch
    228 else:
--> 229     seg_prob_out = predictor(win_data, *args, **kwargs)  # batched patch
    231 # convert seg_prob_out to tuple seg_tuple, this does not allocate new memory.
    232 dict_keys, seg_tuple = _flatten_struct(seg_prob_out)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/monai/networks/nets/unet.py:300, in UNet.forward(self, x)
    299 def forward(self, x: torch.Tensor) -> torch.Tensor:
--> 300     x = self.model(x)
    301     return x

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/container.py:217, in Sequential.forward(self, input)
    215 def forward(self, input):
    216     for module in self:
--> 217         input = module(input)
    218     return input

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/monai/networks/layers/simplelayers.py:129, in SkipConnection.forward(self, x)
    128 def forward(self, x: torch.Tensor) -> torch.Tensor:
--> 129     y = self.submodule(x)
    131     if self.mode == "cat":
    132         return torch.cat([x, y], dim=self.dim)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/container.py:217, in Sequential.forward(self, input)
    215 def forward(self, input):
    216     for module in self:
--> 217         input = module(input)
    218     return input

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/monai/networks/layers/simplelayers.py:132, in SkipConnection.forward(self, x)
    129 y = self.submodule(x)
    131 if self.mode == "cat":
--> 132     return torch.cat([x, y], dim=self.dim)
    133 if self.mode == "add":
    134     return torch.add(x, y)

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 11 but got size 12 for tensor number 1 in the list.
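
This second traceback starts inside sliding_window_inference, so the size constraint now applies to the validation patches rather than the (now padded) training volumes: the windows handed to the model have shape roi_size, and a roi_size entry that is not a multiple of 16 would reproduce the same cat error, which would explain why training runs but validation fails at epoch 2. A sketch, reusing the model and device from above (the roi_size and input values here are illustrative; the thread does not show the actual validation code):

import torch
from monai.inferers import sliding_window_inference

# Sketch: each window fed to the model has shape roi_size, so each entry
# must be divisible by 16, just like the training inputs.
roi_size = (96, 96, 48)  # illustrative values, each a multiple of 16
val_inputs = torch.zeros(1, 1, 128, 128, 64, device=device)  # stand-in for a validation batch
val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size=4, predictor=model)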

Alisoltan82 — Jan 23 '24 09:01