Optimize writing performance for `MultiscaleSpatialImage`
As observed by @ArneDefauw, unnecessary loading operations are performed when calling `write_multiscale()` on a list of lazy tensors derived from `to_multiscale()`.
Optimizing the order in which the data is computed and written to disk, so as to avoid loading the same chunks two or more times, would probably lead to a drastic performance improvement, up to 10-fold.
I include a minimal example to reproduce the observed behaviour.
If `arr.persist()` is called, the code completes in ~10 s; if `arr.persist()` is commented out, the code completes in ~50 s, i.e. ~10 s per scale level, because `_some_function` is called 5 times.
As an alternative to using `.persist()`, writing the array to a zarr store and then loading it back evidently 'solves' the problem in a similar way.
```python
import os
import tempfile
import time

import dask.array as da
import numpy as np

import spatialdata
from spatialdata.datasets import blobs

sdata = blobs()

start = time.time()

with tempfile.TemporaryDirectory() as temp_dir:
    sdata.write(os.path.join(temp_dir, "sdata_blobs_dummy.zarr"))

    def _some_function(arr):
        # Stand-in for an expensive per-chunk computation.
        arr = arr * 2
        time.sleep(10)
        return arr

    arr = sdata["blobs_image"].data
    arr = da.map_blocks(_some_function, arr, dtype=float, meta=np.array((), dtype=float))

    arr = arr.persist()
    # or, as an alternative to persist, write to an intermediate zarr store:
    # dask_zarr_path = os.path.join(temp_dir, "dask_array.zarr")
    # arr.to_zarr(dask_zarr_path, overwrite=True)
    # arr = da.from_zarr(dask_zarr_path)

    se = spatialdata.models.Image2DModel.parse(
        arr,
        scale_factors=[2, 2, 2, 2],
    )
    sdata["blobs_image_processed"] = se
    sdata.write_element("blobs_image_processed")

print(time.time() - start)
```
I found that Dask was invoked for each pyramid level separately, which means it could not optimize the computational graph across all pyramid levels. An alternative is to pass all the Dask arrays to Dask at once, so that it optimizes a single graph, reuses intermediate results, and avoids recomputation.
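A minimal sketch of this idea with plain Dask, independent of spatialdata (the `expensive` function, the pyramid construction via `da.coarsen`, and the `pyramid.zarr` layout are illustrative assumptions, not the actual implementation):

```python
import dask
import dask.array as da
import numpy as np

def expensive(block):
    # Stand-in for a costly per-chunk operation at scale 0.
    return block * 2

base = da.map_blocks(expensive, da.ones((1024, 1024), chunks=512))

# Build a small pyramid by repeatedly 2x-downsampling the lazy base array.
pyramid = [base]
for _ in range(2):
    pyramid.append(da.coarsen(np.mean, pyramid[-1], {0: 2, 1: 2}))

# Naive: one compute per level, so the expensive base chunks are
# recomputed for every level (f + 1 times in total):
# for i, level in enumerate(pyramid):
#     level.to_zarr("pyramid.zarr", component=f"scale{i}", overwrite=True)

# Shared graph: collect the delayed writes and compute them together, so
# Dask evaluates the expensive base chunks once and reuses them everywhere.
writes = [
    level.to_zarr("pyramid.zarr", component=f"scale{i}", overwrite=True, compute=False)
    for i, level in enumerate(pyramid)
]
dask.compute(*writes)
```

Because all the delayed writes share the same underlying graph keys for `base`, a single `dask.compute` call evaluates those chunks once instead of once per level.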
The performance improvement should be: $$n := \textrm{number of pixels of scale level 0}$$ $$f := \textrm{number of scale factors}$$ $$\mathcal{O}(n \cdot (f+1)) \rightarrow \mathcal{O}(n)$$ Previously, each of the $f+1$ pyramid levels recomputed all scale-0 chunks; with a single shared graph they are computed only once, and the downsampled levels themselves add only a geometrically decreasing amount of extra work.
With your example I could achieve the same expected improvement of 5× ($f = 4$ scale factors, i.e. 5 pyramid levels, so `_some_function` now runs once instead of five times).
Thanks @aeisenbarth for addressing this, it's a great performance improvement.