spatialdata
Shape mismatch in dataloader when rasterization is not used.
Originally observed by @ilia-kats
When constructing an ImageTilesDataset without rasterization, a shape-mismatch bug can occur: the returned tiles do not all have the same size. For instance, the bug is triggered when building a DataLoader from the tiles, because the default collate function fails when it tries to stack arrays of different shapes into a single batch.
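PyTorch's default collate function batches samples by stacking them along a new leading axis, which only works if every sample has the same shape. The sketch below reproduces that failure with NumPy's np.stack as a stand-in, so it runs without torch or spatialdata. The tile sizes (10x10 versus 11x11) are illustrative and not taken from this issue. The helper crop_to_common_shape is a hypothetical workaround added for illustration only; it is not part of spatialdata or PyTorch.

```python
import numpy as np


def stack_tiles(tiles):
    """Mimic default batching: stack tiles along a new leading (batch) axis.

    np.stack, like the torch.stack call used by PyTorch's default collate,
    requires every input array to have exactly the same shape.
    """
    return np.stack(tiles, axis=0)


def crop_to_common_shape(tiles):
    """Hypothetical workaround: crop every tile to the smallest shape in the batch."""
    min_shape = tuple(min(dims) for dims in zip(*(t.shape for t in tiles)))
    return [t[tuple(slice(0, n) for n in min_shape)] for t in tiles]


# Three (channels, y, x) tiles; the last one is one pixel larger in y and x.
tiles = [np.zeros((3, 10, 10)), np.zeros((3, 10, 10)), np.zeros((3, 11, 11))]

try:
    stack_tiles(tiles)
except ValueError as err:
    print("stacking failed:", err)

batch = stack_tiles(crop_to_common_shape(tiles))
print(batch.shape)  # (3, 3, 10, 10)
```

Cropping silently discards pixels at the tile border, so it hides the symptom rather than fixing the cause. Passing a collate_fn that returns the batch as a plain list is another way to sidestep stacking.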
Below is code that reproduces the issue on one of the datasets from https://github.com/giovp/spatialdata-sandbox. Uncommenting the two rasterization lines shows that the bug only affects queries in the intrinsic coordinate system (i.e., pixel space).
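```python
import spatialdata as sd
import torch

# Visium dataset from spatialdata-sandbox (local path on the reporter's machine).
sdata = sd.read_zarr("/Users/macbook/embl/projects/basel/spatialdata-sandbox/visium_2.1.0_1_io/data.zarr")
IMAGE_ELEMENT = "CytAssist_FFPE_Human_Colon_Post_Xenium_Rep1_hires_image"
SHAPES_ELEMENT = "CytAssist_FFPE_Human_Colon_Post_Xenium_Rep1"
COORDINATE_SYSTEM = "downscaled_hires"

# Alternative dataset:
# sdata = sd.read_zarr("/Users/macbook/Desktop/mousebrain.zarr")
# IMAGE_ELEMENT = "mousebrain_hires_image"
# SHAPES_ELEMENT = "mousebrain"
# COORDINATE_SYSTEM = "downscaled_hires"

# One image tile per shape, cropped around it; tile_scale=1.5 scales the tile size.
dataset = sd.dataloader.ImageTilesDataset(
    sdata,
    {SHAPES_ELEMENT: IMAGE_ELEMENT},
    {SHAPES_ELEMENT: COORDINATE_SYSTEM},
    tile_scale=1.5,
    # Uncomment to rasterize every tile to a fixed width; the bug then disappears.
    # rasterize=True,
    # rasterize_kwargs={"target_width": 224},
)

# Put all tiles (one per shape) into a single batch.
dloader = torch.utils.data.DataLoader(
    [tile.images[IMAGE_ELEMENT].to_numpy() for tile in dataset],
    batch_size=sdata.shapes[SHAPES_ELEMENT].shape[0],
)

# Fetching the batch runs the default collate function, which fails
# because the tiles do not all have the same shape.
data = next(iter(dloader))
print(data.shape)
```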
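The issue does not pin down the root cause. One plausible mechanism, offered here only as an assumption and not confirmed against spatialdata's query code, is that a window of fixed but non-integer width in pixel space covers a different number of pixels depending on its sub-pixel offset. The sketch assumes pixel i spans [i, i + 1) with its center at i + 0.5, and counts the pixel centers inside a window. Both helpers (pixels_in_window and fixed_size_window) are hypothetical names introduced for illustration.

```python
import math


def pixels_in_window(start, width):
    """Count pixel centers (assumed at 0.5, 1.5, 2.5, ...) inside [start, start + width]."""
    end = start + width
    first = math.ceil(start - 0.5)  # smallest i with i + 0.5 >= start
    last = math.floor(end - 0.5)    # largest i with i + 0.5 <= end
    return max(0, last - first + 1)


def fixed_size_window(center, half_width):
    """Hypothetical fix: round the tile size once, independently of the center.

    Returns a half-open pixel index range [start, stop) whose length is always
    round(2 * half_width), whatever the sub-pixel position of the center.
    """
    size = round(2 * half_width)
    start = math.floor(center - size / 2)
    return start, start + size


# Same window width, different sub-pixel offsets: the pixel count changes.
print(pixels_in_window(0.2, 3.4))  # 4
print(pixels_in_window(0.6, 3.4))  # 3

# Snapping to a whole number of pixels gives the same size for every center.
for center in (10.0, 10.3, 10.7):
    start, stop = fixed_size_window(center, 1.7)
    print(center, stop - start)  # size 3 each time
```

Under this assumption, rasterization would avoid the problem because it resamples every tile to the same target_width, which is consistent with the observation that the bug disappears when the two rasterization lines are uncommented.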