
Reworked Merger that supports multiclass

the-lay opened this issue 2 years ago.

Resolves #20

I have pushed a fairly large rework of Merger. It now has three new or updated keyword arguments: `ignore_channels: bool = False`, `logits_n: Optional[int] = None` and `logits_dim: int = 0`.
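As a rough illustration of what merging multiclass output involves (the case that `logits_n` and `logits_dim` describe), here is a NumPy-only sketch. Each tile's per-class scores are added into an accumulator that has an extra class axis, a weight map counts how many tiles covered each pixel, and the sum is divided by those weights. This is a hypothetical illustration, not Merger's actual implementation: the function `merge_logits` and its arguments are made up for this example, the class axis is fixed to the last dimension, and every tile gets uniform weight (no windowing).

```python
import numpy as np


def merge_logits(tiles, corners, out_hw, n_classes):
    """Hypothetical sketch: average overlapping per-class tile outputs.

    tiles:   array of shape (n_tiles, th, tw, n_classes), classes in the last axis
    corners: list of (row, col) top-left positions of each tile on the canvas
    out_hw:  (H, W) of the (padded) output canvas
    """
    acc = np.zeros((*out_hw, n_classes), dtype=np.float64)  # summed class scores
    weights = np.zeros(out_hw, dtype=np.float64)  # number of tiles covering each pixel
    th, tw = tiles.shape[1:3]
    for tile, (r, c) in zip(tiles, corners):
        acc[r:r + th, c:c + tw, :] += tile
        weights[r:r + th, c:c + tw] += 1.0
    # pixels covered by no tile keep a score of 0 instead of dividing by zero
    return acc / np.maximum(weights, 1.0)[..., np.newaxis]


# Two 4x4 tiles overlapping by 50% along the columns of a 4x6 canvas, 3 classes
tiles = np.stack([np.full((4, 4, 3), 1.0), np.full((4, 4, 3), 3.0)])
merged = merge_logits(tiles, corners=[(0, 0), (0, 2)], out_hw=(4, 6), n_classes=3)
print(merged.shape)  # (4, 6, 3)
print(merged[0, :, 0])  # [1. 1. 2. 2. 3. 3.]
```

The overlapping columns (2 and 3) end up as the mean of both tiles, while the class axis is carried through unchanged, which is why the merged result gains a trailing dimension of size `n_classes`.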

Here's an example of how I imagine it could all be used. It's significantly more flexible, but the API may now be a bit too complex.

@jordancaraballo, please take a look and let me know what you think. Am I missing anything? If not, in the next commits I will fix the tests and make sure I didn't break anything else.

```python
import numpy as np
from tiler import Tiler, Merger


# Let's say you have an image of size 5000x3000 pixels with 4 channels in the last dimension
image_shape = (5000, 3000, 4)
image_channel_dimension = -1
# and you want to split it into tiles of 256x256 pixels (all 4 channels) with 50% overlap
tile_shape = (256, 256, 4)
tile_overlap = 0.5
# to feed into a segmentation network with 10 output classes (in the last dimension) in batches of 128 tiles
# (so the network output has a shape of (128, 256, 256, 10))
output_classes = 10
output_classes_dim = -1
batch_size = 128

image = np.random.rand(*image_shape)
tiler = Tiler(
    data_shape=image_shape,
    tile_shape=tile_shape,
    overlap=tile_overlap,  # without this, tile_overlap above would be unused
    channel_dimension=image_channel_dimension,
)
merger = Merger(
    tiler,
    ignore_channels=True,  # "turns off" Tiler's channel dimension: the output has classes, not image channels
    logits_n=output_classes,  # how many logits/segmentation classes the output has
    logits_dim=output_classes_dim,  # and in which dimension
)

print("Processing batches...")
for batch_id, batch in tiler(image, batch_size=batch_size):
    print(f"\tBatch: #{batch_id}, with data of shape {batch.shape}")

    # simulating network output of shape (n, 256, 256, 10), where n is 128
    # except for the last batch, which can contain fewer tiles
    output = np.random.rand(len(batch), *tile_shape[:-1], output_classes)
    print(f"\tWe simulate NN output with shape of {output.shape} and add it to Merger")

    # adding output into Merger
    merger.add_batch(batch_id, batch_size, output)

print("Processing finished.")

raw_merge_result = merger.merge(argmax=None, unpad=False)
print(f"Shape of the raw merge result: {raw_merge_result.shape}")  # (5120, 3072, 10)

unpad_merge_result = merger.merge(argmax=None, unpad=True)
print(f"Shape of the unpad merge result: {unpad_merge_result.shape}")  # (5000, 3000, 10)

argmaxed_merge_result = merger.merge(argmax=output_classes_dim, unpad=True)
print(f"Shape of the argmaxed merge result: {argmaxed_merge_result.shape}")  # (5000, 3000)
```
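The shapes in the comments above follow from the tiling geometry. As a sanity check, the sketch below recomputes them under two assumptions: `tile_overlap = 0.5` means tiles start every 128 pixels, and Tiler pads only at the end of each axis so the last tile fits completely (which is what the raw (5120, 3072) shape suggests). `tiling_geometry` is a hypothetical helper, not part of tiler. It also shows that the number of tiles is generally not a multiple of `batch_size`, so the last batch can be much smaller than 128. Unpadding and argmax are illustrated on a small stand-in array, not on real Merger output.

```python
import math

import numpy as np


def tiling_geometry(length, tile, overlap):
    """Hypothetical helper: number of tiles along one axis and the padded length.

    Assumes tiles start every `step = tile - overlap` pixels and the data is
    padded at the end so that the last tile fits completely.
    """
    step = tile - overlap
    n_tiles = math.ceil((length - overlap) / step)
    padded = (n_tiles - 1) * step + tile
    return n_tiles, padded


tile, overlap = 256, int(256 * 0.5)  # tile_overlap = 0.5 -> 128 pixels
n_rows, pad_h = tiling_geometry(5000, tile, overlap)
n_cols, pad_w = tiling_geometry(3000, tile, overlap)
print(n_rows, n_cols, pad_h, pad_w)  # 39 23 5120 3072

n_tiles = n_rows * n_cols
n_batches = math.ceil(n_tiles / 128)
last_batch = n_tiles - (n_batches - 1) * 128
print(n_tiles, n_batches, last_batch)  # 897 8 1

# unpad=True keeps only the original extent; argmax over the class axis removes it
raw_small = np.random.rand(6, 4, 10)  # stand-in padded canvas with 10 class scores
data_h, data_w = 5, 3  # stand-in original (unpadded) extent
unpadded = raw_small[:data_h, :data_w]  # (5, 3, 10)
labels = np.argmax(unpadded, axis=-1)  # (5, 3): one class index per pixel
```

With no overlap (a step of 256), the same helper gives 20 x 12 = 240 tiles and the same padded shape of (5120, 3072), so the shape comments hold either way; only the tile count and the size of the last batch change.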

the-lay, Jun 12 '22 19:06