segmentation_models_pytorch_3d

Tried to use timm-3d encoders, but the forward pass fails with an AssertionError

Status: Open · opened by KeesariVigneshwarReddy 6 months ago · 0 comments

Program

# Minimal reproduction: build a 3D U-Net on a timm-3d MaxViT encoder and run a
# single forward pass on a random volume. As the error log shows, the forward
# pass raises AssertionError inside timm_3d's window_partition.
import segmentation_models_pytorch_3d as smp
import torch

# The 'tu-' prefix selects a timm-3d backbone, which is wrapped by
# segmentation_models_pytorch_3d/encoders/timm_universal.py.
encoder_name = 'tu-maxvit_base_tf_224.in21k'
model = smp.Unet(
    encoder_name=encoder_name,
    encoder_weights=None,  # no pretrained weights requested
    in_channels=3,  # matches dim 1 (channels) of the input tensor below
    classes=1,
)
# The input is (batch=4, channels=3, 96, 96, 96). timm_3d later permutes it to
# channels-last and unpacks it as (B, H, W, D, C) in window_partition.
# NOTE(review): the call fails because the deepest MaxViT stage sees a spatial
# side of 3 (presumably 96 / 32), which is not divisible by the window/partition
# size of 2 reported in the assertion. A side length that keeps every stage's
# size even (presumably any multiple of 64, e.g. 64 or 128) should avoid the
# assertion. TODO: confirm against timm_3d's partition size for this encoder.
# In a notebook the trailing `.shape` is displayed; in a plain script it is discarded.
model(torch.rand(4,3,96,96,96)).shape

Error logs

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
/tmp/ipykernel_31/4109723633.py in <cell line: 0>()
----> 1 model(torch.rand(4,3,96,96,96)).shape

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1734             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735         else:
-> 1736             return self._call_impl(*args, **kwargs)
   1737 
   1738     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1745                 or _global_backward_pre_hooks or _global_backward_hooks
   1746                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747             return forward_call(*args, **kwargs)
   1748 
   1749         result = None

/usr/local/lib/python3.11/dist-packages/segmentation_models_pytorch_3d/base/model.py in forward(self, x)
     46         self.check_input_shape(x)
     47 
---> 48         features = self.encoder(x)
     49         decoder_output = self.decoder(*features)
     50 

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1734             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735         else:
-> 1736             return self._call_impl(*args, **kwargs)
   1737 
   1738     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1745                 or _global_backward_pre_hooks or _global_backward_hooks
   1746                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747             return forward_call(*args, **kwargs)
   1748 
   1749         result = None

/usr/local/lib/python3.11/dist-packages/segmentation_models_pytorch_3d/encoders/timm_universal.py in forward(self, x)
     28 
     29     def forward(self, x):
---> 30         features = self.model(x)
     31         features = [
     32             x,

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1734             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735         else:
-> 1736             return self._call_impl(*args, **kwargs)
   1737 
   1738     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1745                 or _global_backward_pre_hooks or _global_backward_hooks
   1746                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747             return forward_call(*args, **kwargs)
   1748 
   1749         result = None

/usr/local/lib/python3.11/dist-packages/timm_3d/models/_features.py in forward(self, x)
    280 
    281     def forward(self, x) -> (List[torch.Tensor]):
--> 282         return list(self._collect(x).values())
    283 
    284 

/usr/local/lib/python3.11/dist-packages/timm_3d/models/_features.py in _collect(self, x)
    234                 x = module(x) if first_or_last_module else checkpoint(module, x)
    235             else:
--> 236                 x = module(x)
    237 
    238             if name in self.return_layers:

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1734             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735         else:
-> 1736             return self._call_impl(*args, **kwargs)
   1737 
   1738     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1745                 or _global_backward_pre_hooks or _global_backward_hooks
   1746                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747             return forward_call(*args, **kwargs)
   1748 
   1749         result = None

/usr/local/lib/python3.11/dist-packages/timm_3d/models/maxxvit.py in forward(self, x)
   1097             x = checkpoint_seq(self.blocks, x)
   1098         else:
-> 1099             x = self.blocks(x)
   1100         return x
   1101 

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1734             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735         else:
-> 1736             return self._call_impl(*args, **kwargs)
   1737 
   1738     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1745                 or _global_backward_pre_hooks or _global_backward_hooks
   1746                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747             return forward_call(*args, **kwargs)
   1748 
   1749         result = None

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/container.py in forward(self, input)
    248     def forward(self, input):
    249         for module in self:
--> 250             input = module(input)
    251         return input
    252 

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1734             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735         else:
-> 1736             return self._call_impl(*args, **kwargs)
   1737 
   1738     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1745                 or _global_backward_pre_hooks or _global_backward_hooks
   1746                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747             return forward_call(*args, **kwargs)
   1748 
   1749         result = None

/usr/local/lib/python3.11/dist-packages/timm_3d/models/maxxvit.py in forward(self, x)
    987             x = x.permute(0, 2, 3, 4, 1)  # to NHWDC (channels-last)
    988         if self.attn_block is not None:
--> 989             x = self.attn_block(x)
    990         x = self.attn_grid(x)
    991         if not self.nchwd_attn:

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1734             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735         else:
-> 1736             return self._call_impl(*args, **kwargs)
   1737 
   1738     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1745                 or _global_backward_pre_hooks or _global_backward_hooks
   1746                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747             return forward_call(*args, **kwargs)
   1748 
   1749         result = None

/usr/local/lib/python3.11/dist-packages/timm_3d/models/maxxvit.py in forward(self, x)
    765         tmp = self.norm1(x)
    766         # print("!K!", tmp.shape)
--> 767         tmp2 = self._partition_attn(tmp)
    768         # print("!L!", tmp2.shape)
    769         x = x + self.drop_path1(self.ls1(tmp2))

/usr/local/lib/python3.11/dist-packages/timm_3d/models/maxxvit.py in _partition_attn(self, x)
    746         if self.partition_block:
    747             # print('W part', img_size)
--> 748             partitioned = window_partition(x, self.partition_size)
    749             # print(partitioned.shape)
    750         else:

/usr/local/lib/python3.11/dist-packages/timm_3d/models/maxxvit.py in window_partition(x, window_size)
    649 def window_partition(x, window_size: List[int]):
    650     B, H, W, D, C = x.shape
--> 651     _assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})')
    652     _assert(W % window_size[1] == 0, f'height ({W}) must be divisible by window ({window_size[1]})')
    653     _assert(D % window_size[2] == 0, f'height ({D}) must be divisible by window ({window_size[2]})')

/usr/local/lib/python3.11/dist-packages/torch/__init__.py in _assert(condition, message)
   2038             _assert, (condition,), condition, message
   2039         )
-> 2040     assert condition, message
   2041 
   2042 

AssertionError: height (3) must be divisible by window (2)

Posted by KeesariVigneshwarReddy on Apr 21 '25 04:04