Error in aggregate over Image2DModel by Labels2DModel when transform is more complex than translation
Bug description
I stumbled upon an error when trying to aggregate an Image2DModel by a Labels2DModel in the case where one of the two has a transformation attached that is more complex than a translation.
Reproducer
```python
import numpy as np
from spatialdata.models import Image2DModel, Labels2DModel
from spatialdata import SpatialData, transform, aggregate
from spatialdata.transformations import Identity, Translation, Scale

sdata = SpatialData()
sdata['image'] = Image2DModel.parse(
    np.ones((1, 10, 10)),
    dims=('c', 'y', 'x'),
    transformations={
        # 'global': Scale([0.9] * 2, axes=['y', 'x']),  # doesn't work with aggregate
        # 'global': Translation([1] * 2, axes=['y', 'x']),  # works with aggregate
        'global': Identity(),  # works with aggregate
    },
)
labels_array = np.zeros((10, 10)).astype('int')
labels_array[5:8, 5:8] = 1  # Create a small square with label 1
sdata['labels'] = Labels2DModel.parse(
    labels_array,
    dims=('y', 'x'),
    transformations={
        # 'global': Translation([6] * 2, axes=['y', 'x']),
        'global': Identity(),
    },
)
sdata.aggregate(
    values='image',
    by='labels',
    target_coordinate_system='global',
)
```

With the Scale transformation enabled (the first commented line), this raises:

```
ValueError: input arrays must have equal shapes
```
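For what it's worth, a back-of-the-envelope illustration of the Scale case (not spatialdata code; it only assumes that the transformed raster's shape follows the element's extent in the target coordinate system):

```python
# Illustration only (assumption: the transformed raster's shape follows the element's
# extent in the target coordinate system).
image_shape = (10, 10)  # (y, x) shape of the image data
scale = 0.9
transformed_image_shape = tuple(int(round(s * scale)) for s in image_shape)  # (9, 9)
labels_shape = (10, 10)  # the labels stay on their original grid under Identity
print(transformed_image_shape, labels_shape)
# (9, 9) (10, 10) -> the element-wise comparison inside aggregate cannot work
```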
I just realised that, independently of the transformations, different shapes of the image and labels elements will also make aggregate fail:
```python
import numpy as np
from spatialdata.models import Image2DModel, Labels2DModel
from spatialdata import SpatialData, transform, aggregate
from spatialdata.transformations import Identity, Translation, Scale

sdata = SpatialData()
sdata['image'] = Image2DModel.parse(
    np.ones((1, 20, 20)),
    dims=('c', 'y', 'x'),
    transformations={
        'global': Identity(),
    },
)
labels_array = np.zeros((10, 10)).astype('int')
labels_array[5:8, 5:8] = 1  # Create a small square with label 1
sdata['labels'] = Labels2DModel.parse(
    labels_array,
    dims=('y', 'x'),
    transformations={
        'global': Identity(),
    },
)
sdata.aggregate(
    values='image',
    by='labels',
    target_coordinate_system='global',
)
```
Quick diagnosis attempt:
In aggregate, the image and labels elements are transformed into the desired target coordinate system: https://github.com/scverse/spatialdata/blob/7604a3d2325079293ff523c5dff4483dffc890cb/src/spatialdata/_core/operations/aggregate.py#L164-L166
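The mismatch can also be observed outside of aggregate by transforming the two elements manually. A rough sketch (the exact transform() signature differs between spatialdata versions, so treat this as illustrative):

```python
from spatialdata import transform

# Mimic what aggregate does internally: bring both raster elements into 'global'.
# (Recent spatialdata exposes to_coordinate_system; older versions instead took the
# transformation object, e.g. obtained via spatialdata.transformations.get_transformation.)
image_global = transform(sdata['image'], to_coordinate_system='global')
labels_global = transform(sdata['labels'], to_coordinate_system='global')

# With the Scale([0.9] * 2) variant of the first reproducer, the (y, x) shapes differ.
print(image_global.shape, labels_global.shape)
```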
However, when combining image and labels elements, the resulting transformed rasters can have different shapes. Two possible causes, aligning with the two cases reported above:
- the image and labels elements occupy different physical extents (this could be worked around by first performing a spatial query on the elements; see the bounding box sketch at the end of this issue)
- (floating point) differences in the calculation of the output shape (original issue reproducer); see the toy example right after this list.
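To illustrate the second point with plain Python (no spatialdata involved): two extents that are nominally equal can round to different pixel counts depending on how they were computed.

```python
import math

extent_a = (0.1 + 0.2) * 10      # 3.0000000000000004
extent_b = 0.1 * 10 + 0.2 * 10   # 3.0
# Once rounded to an integer number of pixels, the "same" extent gives off-by-one shapes.
print(math.ceil(extent_a), math.ceil(extent_b))  # 4 3
```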
I guess what could fix this would be if transform() for raster elements accepted the desired output origin and shape as arguments. aggregate could then compute these values once for the "values" element and pass them to the transform calls of both elements, so that it keeps working with precisely aligned raster elements.
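To make the proposal concrete, here is a hypothetical sketch; neither compute_output_grid() nor the output_* keyword arguments exist in the current API, the names are purely illustrative:

```python
# Hypothetical: aggregate computes the output grid once from the "values" element and
# passes it to the transform calls of both raster elements (illustrative names only).
output_origin, output_shape = compute_output_grid(sdata['image'], to_coordinate_system='global')

values_t = transform(sdata['image'], to_coordinate_system='global',
                     output_origin=output_origin, output_shape=output_shape)
by_t = transform(sdata['labels'], to_coordinate_system='global',
                 output_origin=output_origin, output_shape=output_shape)

# values_t and by_t would now be pixel-aligned with identical (y, x) shapes,
# so the label-wise aggregation could proceed safely.
```

In the meantime, for the second reproducer a bounding box query (as mentioned in the first bullet above) might serve as a partial workaround by cropping the image to the physical extent of the labels, although the floating point issue could still produce an off-by-one shape:

```python
from spatialdata import bounding_box_query

# Crop the image to the labels' extent in 'global' before aggregating (partial workaround).
image_cropped = bounding_box_query(
    sdata['image'],
    axes=('y', 'x'),
    min_coordinate=[0, 0],
    max_coordinate=[10, 10],
    target_coordinate_system='global',
)
```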