
Add CropOrPadAtCenter transform class

Open themantalope opened this issue 1 year ago • 4 comments

🚀 Feature

A crop-or-pad transform that lets the user crop or pad an image while specifying which point of the input image should become the center of the output. For example, suppose an organ is centered at location [45, 62, 101] in a volume of size [368, 512, 128], and I want a new image of size [128, 128, 128] whose center (i.e. [64, 64, 64]) maps to the point [45, 62, 101] in the original image.
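To make the example concrete, the per-axis arithmetic can be sketched in plain Python. The helper name `window_pad_and_crop` is hypothetical and not part of TorchIO; the sketch assumes the output window starts at `center - target // 2` along each axis, so the requested center lands at index `target // 2` of the output.

```python
def window_pad_and_crop(center, shape, target):
    """Return (pad, crop) as flat (low, high) pairs, one pair per axis.

    Hypothetical helper, not part of TorchIO. Along each axis the output
    window is [c - t // 2, c - t // 2 + t), which always has t voxels.
    """
    pad, crop = [], []
    for c, s, t in zip(center, shape, target):
        start = c - t // 2
        stop = start + t
        lower = max(0, -start)    # voxels missing before the image starts
        upper = max(0, stop - s)  # voxels missing after the image ends
        pad += [lower, upper]
        # Voxels to remove from each side of the padded image
        crop += [start + lower, s + upper - stop]
    return tuple(pad), tuple(crop)


pad, crop = window_pad_and_crop([45, 62, 101], [368, 512, 128], [128, 128, 128])
print(pad)   # (19, 0, 2, 0, 0, 37)
print(crop)  # (0, 259, 0, 386, 37, 0)
```

For the volume above, the first two axes need padding at the low end and the third at the high end; after padding and cropping, every axis has exactly 128 voxels.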

Motivation

I am developing datasets with tumors and need to crop and pad around each tumor.

Pitch

Described above. Proposed code below.

Alternatives

N/A

Additional context

Here is my proposed code. I've tested it locally, but it probably needs more robust testing.

import torchio as tio

class CropOrPadAtCenter(tio.Transform):
    
    def __init__(self, center, target_shape, image_or_mask_name=None, **kwargs):
        super().__init__(**kwargs)
        self.target_shape = target_shape
        self.center = center
        self.image_or_mask_name = image_or_mask_name
        

    def apply_transform(self, subject):
        # If no image name was given, use the first image in the subject.
        # Local variables are used so repeated calls do not mutate the transform.
        if self.image_or_mask_name is None:
            image = subject.get_first_image()
        else:
            image = subject[self.image_or_mask_name]
        spatial_shape = image.shape[-3:]

        # The requested center must lie inside the image
        if not all(0 <= c < s for c, s in zip(self.center, spatial_shape)):
            raise ValueError(
                f'Center {self.center} is outside the image of shape {spatial_shape}'
            )

        # Along each axis the output window is [c - t // 2, c - t // 2 + t),
        # which has exactly t voxels even when t is odd
        pad = []
        crop = []
        for c, s, t in zip(self.center, spatial_shape, self.target_shape):
            start = c - t // 2
            stop = start + t
            # Pad wherever the window extends beyond the image
            lower = max(0, -start)
            upper = max(0, stop - s)
            pad.extend([lower, upper])
            # tio.Crop expects the number of voxels to remove from each side,
            # expressed here relative to the padded image
            crop.extend([start + lower, s + upper - stop])

        subject = tio.Pad(tuple(pad))(subject)
        subject = tio.Crop(tuple(crop))(subject)
        return subject

        

themantalope avatar Nov 05 '24 20:11 themantalope

Hi, thanks for sharing. Actually, the same behavior can be achieved with TorchIO's LabelSampler.

Here is a code example:

import torch
import torchio as tio

center_patch = [132,100,170]
patch_size = 128
around_center = 0

suj = tio.datasets.Colin27()
label_proba = torch.zeros_like(suj.t1.data)
# Mark a small cube around the desired center as the only region to sample from
lo = [c - around_center for c in center_patch]
hi = [c + around_center + 1 for c in center_patch]
label_proba[0, lo[0]:hi[0], lo[1]:hi[1], lo[2]:hi[2]] = 1

img = tio.LabelMap(tensor = label_proba, affine=suj.t1.affine)

suj.add_image(img, 'label_proba')

# Pad to handle the worst-case scenario where the center is at the border of the image (so extend by patch_size // 2 + around_center // 2)

t_pad = tio.Pad(patch_size//2 + around_center//2)
suj = t_pad(suj)

lab_s = tio.LabelSampler(patch_size, 'label_proba', {0: 0, 1: 1})
generator = lab_s(suj, num_patches=10)

for patch in generator:
    locations = patch[tio.LOCATION]
    print(f'loc {locations} ')
# Note that locations holds 3 start indices and 3 end indices in the padded subject's coordinates (not the patch center).
# They differ from the original center_patch because the volume was padded, but the result is as expected.
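If the patch center is needed in the original (unpadded) coordinates, the location can be converted back. This is a minimal sketch in plain Python; `patch_center_in_original` is a hypothetical helper, and it assumes uniform padding on every side and that the sampled voxel sits at index `patch_size // 2` within the patch.

```python
def patch_center_in_original(location, padding):
    """Map a patch location back to the center in unpadded coordinates.

    Hypothetical helper. `location` is (i0, j0, k0, i1, j1, k1) in the padded
    subject; `padding` is the number of voxels added before each axis.
    """
    starts, stops = location[:3], location[3:]
    return [
        int(a) + (int(b) - int(a)) // 2 - padding
        for a, b in zip(starts, stops)
    ]


# Values matching the example above: patch_size = 128, padding = 64
location = (132, 100, 170, 260, 228, 298)
print(patch_center_in_original(location, padding=64))  # [132, 100, 170]
```

With `padding == patch_size // 2`, the start indices in padded coordinates happen to equal the original center, which is a quick sanity check on the printed locations.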

Would this suit your needs?

romainVala avatar Nov 06 '24 08:11 romainVala

I see, yes, that would work for what I'm looking for. I may need some additional code to ensure that I'm getting all voxels in the sampled image, but this is basically what I want.

themantalope avatar Nov 07 '24 01:11 themantalope

What do you mean by all voxels? You will only get voxels that are less than patch_size/2 away from the center.

romainVala avatar Nov 07 '24 11:11 romainVala

I mean ensuring that the sampled patch contains all positive voxels in the label.

I guess the difference here is that I'm specifying the center, so there is more control over where the sample is drawn from, which is important for my use case.
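A rough sketch of that check, in plain Python: given a chosen center and patch size, test whether every positive voxel falls inside the patch window. `patch_contains_label` is a hypothetical helper, and the window convention (`center - patch_size // 2` as the start) is an assumption; the voxel coordinates could come from something like `label.data[0].nonzero()`.

```python
def patch_contains_label(center, patch_size, label_voxels):
    """Return True if all positive voxels lie inside the patch window.

    Hypothetical helper. The window along each axis is assumed to be
    [c - patch_size // 2, c - patch_size // 2 + patch_size).
    """
    starts = [c - patch_size // 2 for c in center]
    stops = [s + patch_size for s in starts]
    return all(
        lo <= v < hi
        for voxel in label_voxels
        for v, lo, hi in zip(voxel, starts, stops)
    )


center = [132, 100, 170]
inside = [(70, 40, 110), (195, 163, 233)]
print(patch_contains_label(center, 128, inside))                       # True
print(patch_contains_label(center, 128, inside + [(196, 100, 170)]))   # False
```

If the check fails, either the patch size or the chosen center would need to change before sampling.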

themantalope avatar Nov 07 '24 14:11 themantalope