
Optimize writing performance for `MultiscaleSpatialImage`

Open · LucaMarconato opened this issue 1 year ago

As observed by @ArneDefauw, unnecessary loading operations are performed when calling `write_multiscale()` on a list of lazy tensors derived from `to_multiscale()`.

Optimizing the order in which the data is computed and written to disk, so as to avoid loading the same chunks two or more times, would probably lead to a drastic performance improvement, up to 10-fold.
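
A minimal, spatialdata-independent sketch of the mechanism (the function and counter here are illustrative, not spatialdata API): each pyramid level carries its own Dask graph rooted in the same lazy source, so computing the levels one by one re-executes that shared source, whereas a single `dask.compute()` over all levels deduplicates it.

import dask
import dask.array as da
import numpy as np

calls = {"n": 0}

def _expensive(block):
    calls["n"] += 1  # count how often the shared source task runs
    return block * 2

src = da.map_blocks(
    _expensive, da.ones((4, 4), chunks=(4, 4)), dtype=float, meta=np.array((), dtype=float)
)
level0 = src
level1 = src[::2, ::2]  # stand-in for a downscaled pyramid level

level0.compute()
level1.compute()
print(calls["n"])  # 2: each per-level compute() re-executes the shared task

calls["n"] = 0
dask.compute(level0, level1)
print(calls["n"])  # 1: one merged graph, the shared task runs once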

LucaMarconato · Jun 11 '24 14:06

I include a minimal example to reproduce the observed behaviour.

If `arr.persist()` is called, the code completes in ~10 s, but if `arr.persist()` is commented out, the code completes in ~50 s (i.e. 10 s per scale: `_some_function` is called 5 times, once for each of the 5 pyramid levels). As an alternative to using `.persist()`, writing to a Zarr store and then loading it back evidently 'solves' the problem in a similar way.


import os
import tempfile
import time

import dask.array as da
import numpy as np
import spatialdata
from spatialdata.datasets import blobs

sdata = blobs()

start = time.time()

with tempfile.TemporaryDirectory() as temp_dir:
    sdata.write(os.path.join(temp_dir, "sdata_blobs_dummy.zarr"))

    def _some_function(arr):
        # Simulate an expensive per-chunk operation; the sleep makes any
        # recomputation directly visible in the total runtime.
        arr = arr * 2
        time.sleep(10)
        return arr

    arr = sdata["blobs_image"].data

    # Apply the expensive function lazily to every chunk.
    arr = da.map_blocks(_some_function, arr, dtype=float, meta=np.array((), dtype=float))

    # Materialize the result once, so the pyramid levels built below reuse it
    # instead of re-executing _some_function for each level.
    arr = arr.persist()

    # or, as an alternative to persist, write to an intermediate zarr store
    # dask_zarr_path = os.path.join(temp_dir, "dask_array.zarr")
    # arr.to_zarr(dask_zarr_path, overwrite=True)
    # arr = da.from_zarr(dask_zarr_path)

    # scale_factors=[2, 2, 2, 2] builds a 5-level pyramid: scale 0 plus 4 downscales.
    se = spatialdata.models.Image2DModel.parse(
        arr,
        scale_factors=[2, 2, 2, 2],
    )

    sdata["blobs_image_processed"] = se

    # Without persist, each pyramid level triggers its own computation here.
    sdata.write_element("blobs_image_processed")


print(time.time() - start)

ArneDefauw · Jun 12 '24 07:06

I found that Dask was called for each pyramid level separately, which means Dask could not optimize the computational graph across all pyramid levels. An alternative is to pass multiple Dask arrays to Dask at once, so that it optimizes the computational graph, reuses intermediate results and avoids recomputation.
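
A sketch of that idea with plain Dask and Zarr rather than the actual spatialdata writer (the coarsen-based toy pyramid and the zarr-python v2-style create_dataset calls are just for illustration):

import dask.array as da
import numpy as np
import zarr

# Toy pyramid: every level is derived from the same lazy source graph.
src = da.random.random((1024, 1024), chunks=(256, 256))
levels = [
    src,
    da.coarsen(np.mean, src, {0: 2, 1: 2}),
    da.coarsen(np.mean, src, {0: 4, 1: 4}),
]

group = zarr.open_group("pyramid.zarr", mode="w")
targets = [
    group.create_dataset(f"scale{i}", shape=lvl.shape, chunks=lvl.chunksize, dtype=lvl.dtype)
    for i, lvl in enumerate(levels)
]

# One store call builds one merged graph: the shared source chunks are
# computed once, instead of once per level as with per-level to_zarr() calls.
da.store(levels, targets)

The key point is that `da.store()` accepts multiple source/target pairs and executes them as a single graph, so tasks common to several pyramid levels run only once.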

The performance improvement should be:

$$n := \textrm{number of pixels of scale level 0}, \qquad f := \textrm{number of scale factors}$$

$$\mathcal{O}(n \cdot (f + 1)) \rightarrow \mathcal{O}(n)$$
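
For the example above, `scale_factors=[2, 2, 2, 2]` gives $f = 4$, so the expected speedup is $f + 1 = 5$.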

With your example I could achieve the expected improvement of 5×.

aeisenbarth · Oct 05 '24 13:10

Thanks @aeisenbarth for addressing this; it's a great performance improvement.

LucaMarconato · Oct 09 '24 12:10