torchio

GridAggregator does not support smaller output than input

Open • wahabk opened this issue 2 years ago • 1 comment

Is there an existing issue for this?

  • [X] I have searched the existing issues

Bug summary

`tio.inference.GridSampler` and `GridAggregator` do not allow the model output to be spatially smaller than the input patch.

I was going to submit this as a feature request before opening a PR. However, I realised that tio already supports this for some combinations of `patch_overlap` and `overlap_mode`, so I believe it should be treated as a bug.

Code for reproduction

# This is not a MWE but a test named `test_inference_smaller.py`

from torch.utils.data import DataLoader
from torchio import DATA
from torchio import LOCATION
from torchio.data.inference import GridAggregator
from torchio.data.inference import GridSampler

from ...utils import TorchioTestCase


class TestInference(TorchioTestCase):
    """Tests for `inference` module."""
    def test_inference_no_padding(self):
        self.try_inference(None)

    def test_inference_padding(self):
        self.try_inference(3)

    def try_inference(self, padding_mode):
        for mode in ["crop", "average", "hann"]:
            for n in 17, 27:
                patch_size = 10, 15, n
                patch_overlap = 0, 0, 0  # important: no overlap, unlike the usual inference test
                batch_size = 6

                grid_sampler = GridSampler(
                    self.sample_subject,
                    patch_size,
                    patch_overlap,
                    padding_mode=padding_mode,
                )
                aggregator = GridAggregator(grid_sampler, overlap_mode=mode)
                patch_loader = DataLoader(grid_sampler, batch_size=batch_size)
                for patches_batch in patch_loader:
                    input_tensor = patches_batch['t1'][DATA]
                    locations = patches_batch[LOCATION]
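                    logits = model(input_tensor)  # stand-in for a real network
                    # Crop one voxel from every border to mimic a model whose
                    # output is spatially smaller than its input
                    i_ini, j_ini, k_ini = 1, 1, 1
                    i_fin, j_fin, k_fin = patch_size[0]-1, patch_size[1]-1, patch_size[2]-1
                    outputs = logits[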
                        :,
                        :,
                        i_ini:i_fin,
                        j_ini:j_fin,
                        k_ini:k_fin,
                    ]
                    aggregator.add_batch(outputs, locations)

                output = aggregator.get_output_tensor()
                assert (output == -5).all()
                assert output.shape == self.sample_subject.t1.shape


def model(tensor):
    tensor[:] = -5
    return tensor

Actual outcome

This raises a `RuntimeError` when `patch_overlap` is smaller than the difference between the input and output patch sizes and the overlap mode is anything other than `crop`.

Below is the output of running `pytest tests/data/inference/test_inference_smaller.py`:

Error messages

==================================================================================================== FAILURES =====================================================================================================
_____________________________________________________________________________________ TestInference.test_inference_no_padding _____________________________________________________________________________________

self = <tests.data.inference.test_inference_smaller.TestInference testMethod=test_inference_no_padding>

    def test_inference_no_padding(self):
>       self.try_inference(None)

test_inference_smaller.py:13: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
test_inference_smaller.py:47: in try_inference
    aggregator.add_batch(outputs, locations)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <torchio.data.inference.aggregator.GridAggregator object at 0x7f8353643bb0>
batch_tensor = tensor([[[[[-5, -5, -5,  ..., -5, -5, -5],
           [-5, -5, -5,  ..., -5, -5, -5],
           [-5, -5, -5,  ..., -5..., -5, -5],
           [-5, -5, -5,  ..., -5, -5, -5],
           [-5, -5, -5,  ..., -5, -5, -5]]]]], dtype=torch.int32)
locations = array([[ 0,  0,  0, 10, 15, 17],
       [ 0,  0, 13, 10, 15, 30],
       [ 0,  5,  0, 10, 20, 17],
       [ 0,  5, 13, 10, 20, 30]])

    def add_batch(
            self,
            batch_tensor: torch.Tensor,
            locations: torch.Tensor,
    ) -> None:
        """Add batch processed by a CNN to the output prediction volume.
    
        Args:
            batch_tensor: 5D tensor, typically the output of a convolutional
                neural network, e.g. ``batch['image'][torchio.DATA]``.
            locations: 2D tensor with shape :math:`(B, 6)` representing the
                patch indices in the original image. They are typically
                extracted using ``batch[torchio.LOCATION]``.
        """
        batch = batch_tensor.cpu()
        locations = locations.cpu().numpy()
        patch_sizes = locations[:, 3:] - locations[:, :3]
        # There should be only one patch size
        assert len(np.unique(patch_sizes, axis=0)) == 1
        input_spatial_shape = tuple(batch.shape[-3:])
        target_spatial_shape = tuple(patch_sizes[0])
        if input_spatial_shape != target_spatial_shape:
            message = (
                f'The shape of the input batch, {input_spatial_shape},'
                ' does not match the shape of the target location,'
                f' which is {target_spatial_shape}'
            )
>           raise RuntimeError(message)
E           RuntimeError: The shape of the input batch, (8, 13, 15), does not match the shape of the target location, which is (10, 15, 17)

../../../src/torchio/data/inference/aggregator.py:153: RuntimeError
______________________________________________________________________________________ TestInference.test_inference_padding _______________________________________________________________________________________

self = <tests.data.inference.test_inference_smaller.TestInference testMethod=test_inference_padding>

    def test_inference_padding(self):
>       self.try_inference(3)

test_inference_smaller.py:16: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
test_inference_smaller.py:47: in try_inference
    aggregator.add_batch(outputs, locations)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <torchio.data.inference.aggregator.GridAggregator object at 0x7f835149ca90>
batch_tensor = tensor([[[[[-5, -5, -5,  ..., -5, -5, -5],
           [-5, -5, -5,  ..., -5, -5, -5],
           [-5, -5, -5,  ..., -5..., -5, -5],
           [-5, -5, -5,  ..., -5, -5, -5],
           [-5, -5, -5,  ..., -5, -5, -5]]]]], dtype=torch.int32)
locations = array([[ 0,  0,  0, 10, 15, 17],
       [ 0,  0, 13, 10, 15, 30],
       [ 0,  5,  0, 10, 20, 17],
       [ 0,  5, 13, 10, 20, 30]])

    def add_batch(
            self,
            batch_tensor: torch.Tensor,
            locations: torch.Tensor,
    ) -> None:
        """Add batch processed by a CNN to the output prediction volume.
    
        Args:
            batch_tensor: 5D tensor, typically the output of a convolutional
                neural network, e.g. ``batch['image'][torchio.DATA]``.
            locations: 2D tensor with shape :math:`(B, 6)` representing the
                patch indices in the original image. They are typically
                extracted using ``batch[torchio.LOCATION]``.
        """
        batch = batch_tensor.cpu()
        locations = locations.cpu().numpy()
        patch_sizes = locations[:, 3:] - locations[:, :3]
        # There should be only one patch size
        assert len(np.unique(patch_sizes, axis=0)) == 1
        input_spatial_shape = tuple(batch.shape[-3:])
        target_spatial_shape = tuple(patch_sizes[0])
        if input_spatial_shape != target_spatial_shape:
            message = (
                f'The shape of the input batch, {input_spatial_shape},'
                ' does not match the shape of the target location,'
                f' which is {target_spatial_shape}'
            )
>           raise RuntimeError(message)
E           RuntimeError: The shape of the input batch, (8, 13, 15), does not match the shape of the target location, which is (10, 15, 17)

../../../src/torchio/data/inference/aggregator.py:153: RuntimeError
================================================================================================ warnings summary =================================================================================================
test_inference_smaller.py: 16 warnings
  /home/wahab/miniconda3/envs/torchioenv/lib/python3.10/site-packages/SimpleITK/extra.py:183: DeprecationWarning: Converting `np.character` to a dtype is deprecated. The current result is `np.dtype(np.str_)` which is not strictly correct. Note that `np.character` is generally deprecated and 'S1' should be used.
    _np_sitk = {np.dtype(np.character): sitkUInt8,

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============================================================================================= short test summary info =============================================================================================
FAILED test_inference_smaller.py::TestInference::test_inference_no_padding - RuntimeError: The shape of the input batch, (8, 13, 15), does not match the shape of the target location, which is (10, 15, 17)
FAILED test_inference_smaller.py::TestInference::test_inference_padding - RuntimeError: The shape of the input batch, (8, 13, 15), does not match the shape of the target location, which is (10, 15, 17)
========================================================================================= 2 failed, 16 warnings in 0.97s =========================================================================================

Expected outcome

I believe tio should be able to handle smaller outputs. My model's predictions are terrible even with averaging or Hann windowing. Unfortunately, most popular model libraries (such as the great MONAI) only provide models whose output is the same size as the input. But in my application it is crucial to let the model see a larger input ROI than the semantic label output it produces, by using unpadded (valid) convolutions, as this gives context for the prediction. The original U-Net paper uses unpadded convolutions, so its outputs are smaller than its inputs.
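As a rough illustration of why unpadded convolutions give smaller outputs, the sketch below computes the spatial shape after a stack of stride-1 valid convolutions. One 3×3×3 valid convolution on a (10, 15, 17) patch yields (8, 13, 15), the same shape as the cropped batch in the error above, and two 3×3 valid convolutions on a 572×572 tile yield 568×568, as in the first U-Net block. `valid_conv_output_shape` is a hypothetical helper for illustration, not part of TorchIO or MONAI.

```python
def valid_conv_output_shape(input_shape, kernel_size=3, num_convs=1):
    """Spatial shape after `num_convs` unpadded (valid), stride-1 convolutions.

    Each valid convolution removes (kernel_size - 1) voxels along every
    spatial axis, because the kernel must fit entirely inside the input.
    """
    shrink = num_convs * (kernel_size - 1)
    output_shape = tuple(size - shrink for size in input_shape)
    if any(size <= 0 for size in output_shape):
        raise ValueError("input is too small for this many valid convolutions")
    return output_shape


print(valid_conv_output_shape((10, 15, 17)))  # (8, 13, 15)
print(valid_conv_output_shape((572, 572), num_convs=2))  # (568, 568)
```

Stacking more valid convolutions shrinks the output further, which is why the gap between input and output patch sizes grows with network depth.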

I plan to open a PR tomorrow with a fix. I believe this can be fixed by changing only `GridAggregator`; `GridSampler` can be left as is:

  • [x] In `GridAggregator.add_batch()`, check whether the aggregator input (the model output) is smaller than the sampler output before comparing it to the location patch size
  • [x] Create a variable in the aggregator called `patch_diffs`, the difference between `input_spatial_shape` and `target_spatial_shape`
  • [x] Set each dimension of `self.patch_overlap` to the corresponding value in `patch_diffs` if the overlap is smaller
  • [ ] ~Edit each location before cropping by adding half the diffs to `i_ini` etc. and subtracting half the diffs from `i_fin`~
  • [x] Write a new unit test (Let me know if this can be improved)
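The bookkeeping in the first three items can be sketched in plain Python. This is a hypothetical helper written for illustration, not the code from the PR: it computes `patch_diffs` as the input patch size minus the model output size and raises each overlap dimension to at least that difference.

```python
def reconcile_overlap(input_patch_size, output_patch_size, patch_overlap):
    """Return (patch_diffs, adjusted_overlap) for a model with a smaller output.

    Hypothetical helper; all arguments are (i, j, k) tuples of voxel counts.
    """
    patch_diffs = tuple(
        inp - out for inp, out in zip(input_patch_size, output_patch_size)
    )
    if any(diff < 0 for diff in patch_diffs):
        raise ValueError("the model output must not be larger than its input")
    # Where the overlap is smaller than the size difference, use the difference
    adjusted_overlap = tuple(
        max(overlap, diff) for overlap, diff in zip(patch_overlap, patch_diffs)
    )
    return patch_diffs, adjusted_overlap


print(reconcile_overlap((10, 15, 17), (8, 13, 15), (0, 0, 0)))
# ((2, 2, 2), (2, 2, 2))
```

This only covers the aggregator side; the follow-up comment further down notes that the sampler also needs to know about the output size.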

If you see a problem with this happening behind the scenes, should a `model_output_size` argument be added to `GridAggregator` or `GridSampler`? Or should the aggregator raise a warning when it detects a smaller output behind the scenes?

This is a bit confusing, even in the code, because the model's output is the aggregator's input. I've tried to be clear here; let me know if I haven't.

System info

Platform:   Linux-5.4.0-131-generic-x86_64-with-glibc2.27
TorchIO:    0.18.86
PyTorch:    1.13.0+cu117
SimpleITK:  2.2.0 (ITK 5.3)
NumPy:      1.23.4
Python:     3.10.8 (main, Nov  4 2022, 13:48:29) [GCC 11.2.0]

wahabk, Nov 19 '22 17:11

After some experimentation, I realised that this can't be done by changing the aggregator alone.

The model output size must be communicated to `GridSampler` beforehand, because the input tensor has to be padded.
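To see why the sampler has to know about the smaller output, this pure-Python sketch tracks, along a single axis, which voxels of the original image receive a prediction. It assumes patches that step by the output size, an extra final patch aligned to the end of the (padded) axis, and a model that returns the central `output_size` voxels of each patch; `covered_voxels` is hypothetical and is not TorchIO's actual grid logic.

```python
def covered_voxels(image_length, patch_size, output_size, pad):
    """Indices of a 1D image that receive a prediction (hypothetical helper).

    Patches of `patch_size` voxels are taken from the image padded by `pad`
    voxels on each side. Each patch yields a centred output of `output_size`
    voxels, and consecutive patches are `output_size` voxels apart.
    """
    diff = patch_size - output_size
    if output_size <= 0 or diff < 0 or diff % 2:
        raise ValueError("output must be smaller than the patch by an even amount")
    margin = diff // 2  # voxels lost on each side of every patch
    padded_length = image_length + 2 * pad
    if padded_length < patch_size:
        raise ValueError("padded image is smaller than one patch")
    starts = list(range(0, padded_length - patch_size + 1, output_size))
    if starts[-1] + patch_size < padded_length:
        # Extra patch aligned to the end so the last voxels are sampled
        starts.append(padded_length - patch_size)
    covered = set()
    for start in starts:
        for padded_index in range(start + margin, start + margin + output_size):
            original_index = padded_index - pad  # back to unpadded coordinates
            if 0 <= original_index < image_length:
                covered.add(original_index)
    return covered


missing = set(range(17)) - covered_voxels(17, 10, 8, pad=0)
print(sorted(missing))  # [0, 16]
print(covered_voxels(17, 10, 8, pad=1) == set(range(17)))  # True
```

Under these assumptions, without padding the outermost `margin` voxels on each side never get a prediction, whatever the aggregator does; padding the input by half the size difference before sampling closes that gap.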

My immediate fix for this is to pad with PyTorch:

import torch

# F.pad lists pad widths starting from the LAST dimension, so reverse (i, j, k)
patch_tensor = torch.nn.functional.pad(self.image, (
    self.patch_diffs[2], self.patch_diffs[2],
    self.patch_diffs[1], self.patch_diffs[1],
    self.patch_diffs[0], self.patch_diffs[0]),
    mode='reflect')

However, could this be done with some argument to `tio.Subject` instead?

Edit: ignore this, as I figured out how tio uses torch padding.
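For comparison with the `torch.nn.functional.pad` call above, here is a NumPy sketch of the same per-axis reflect padding. Unlike `F.pad`, whose pad tuple starts from the last dimension, `numpy.pad` takes one `(before, after)` pair per axis in natural order, which makes the `(i, j, k)` mapping explicit. The `(C, I, J, K)` layout and the `pad_spatial` name are assumptions for illustration.

```python
import numpy as np


def pad_spatial(image, pad_per_side, mode="reflect"):
    """Pad the three spatial axes of a (C, I, J, K) array (hypothetical helper).

    `pad_per_side` is (pi, pj, pk): voxels added before AND after each axis.
    The channel axis is left unpadded.
    """
    if image.ndim != 4:
        raise ValueError("expected a (C, I, J, K) array")
    pi, pj, pk = pad_per_side
    pad_width = ((0, 0), (pi, pi), (pj, pj), (pk, pk))
    return np.pad(image, pad_width, mode=mode)


image = np.arange(10 * 15 * 17).reshape(1, 10, 15, 17)
padded = pad_spatial(image, (1, 2, 3))
print(padded.shape)  # (1, 12, 19, 23)
# Reflect mode mirrors around the edge voxel without repeating it
print(padded[0, 0, 2, 3] == image[0, 1, 0, 0])  # True
```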

wahabk, Nov 20 '22 11:11

Closing this stale issue (https://github.com/fepegar/torchio/pull/1002#issuecomment-2094494726).

fepegar, May 04 '24 22:05