
improve data loader performance

Open giovp opened this issue 1 year ago • 7 comments

So I've been wanting to take another look at this for a long time. I used https://github.com/benfred/py-spy with the speedscope output format; see the screenshot below.

[screenshot: py-spy profile viewed in speedscope]

I've been doing this on the xenium_rep_1 dataset from the paper, using the following code (adapted from @LucaMarconato's code):

```python
from pathlib import Path

import spatialdata as sd
import torch
import torchvision.transforms.v2 as T
from spatialdata.dataloader.datasets import ImageTilesDataset
from torch.utils.data import DataLoader

xeniumrep1 = Path("/path/to/xenium_rep1_data_aligned.zarr")
sdata1 = sd.read_zarr(xeniumrep1)

visium = Path("/path/to/visium_data_aligned.zarr")
sdata3 = sd.read_zarr(visium)

TILE_SCALE = 10.0
REGION = "xeniumrep1"
sdata = sdata1
# Tile the Visium CytAssist H&E image around the Xenium cells.
sdata.images["hne"] = sdata3.images["CytAssist_FFPE_Human_Breast_Cancer_full_image"]


def get_ds(sdata: sd.SpatialData) -> ImageTilesDataset:
    img_size = 224

    # ToImage turns the HWC array into a CHW tensor; ToDtype(scale=True) replaces
    # the deprecated ToTensor and rescales uint8 values to float32 in [0, 1].
    transform_tv = T.Compose(
        [
            T.ToImage(),
            T.Resize((img_size, img_size), antialias=True, interpolation=T.InterpolationMode.BICUBIC),
            T.ToDtype(torch.float32, scale=True),
        ]
    )

    def transform(output):
        image, anno = output
        instance_id, celltype = anno[:, 0].squeeze(), anno[:, 1].squeeze()
        # Materialise only this tile: (c, y, x) dask array -> (y, x, c) numpy array.
        image = transform_tv(image.data.transpose(1, 2, 0).compute(scheduler="single-threaded"))
        return {"img": image, "instance_id": instance_id.tolist(), "celltype": celltype.tolist()}

    radius = sdata.shapes["cell_circles"]["radius"]
    # Use a large radius (mean + 2 std) so that tiles cover most of the cells.
    large_radius = radius.mean() + 2 * radius.std()
    sdata.shapes["cell_circles"]["radius"] = large_radius
    instance_key = sdata.tables["table"].uns["spatialdata_attrs"]["instance_key"]

    return ImageTilesDataset(
        sdata=sdata,
        regions_to_images={"cell_circles": "hne"},
        regions_to_coordinate_systems={"cell_circles": "aligned"},
        return_annotations=[instance_key, "celltype_major"],
        tile_scale=TILE_SCALE,
        transform=transform,
        table_name="table",
    )


ds = get_ds(sdata)
dl = DataLoader(
    ds,
    batch_size=256,
    num_workers=0,
    shuffle=False,
)
```

This made me realize that, if we want to return the array, then instantiating a SpatialImage|MultiscaleSpatialImage is an unnecessary step: the dask array can simply be returned. This halved the fetch step (summed over 6 iterations) from ~43 s to ~23 s; see the screenshot below.

[screenshot: profile after returning the dask array directly]

I think the fetch step is ultimately what we want to improve, as it is the one that streams the tiles from the zarr array to the GPU. Now the two main blocks are the transform call and the compute call. The transform call is shown under compute, but it is effectively the wrapper call where all the DataArray.isel calls happen, i.e. where the crops are defined, transformed and set before the computation is actually triggered with compute.

[screenshot: breakdown of the fetch step]

I wonder what the next step to chase performance gains could be. One option would be to "prepare" the transformation beforehand on the full array, and then trigger it only at tile creation in compute (whereas now, transformation and tile creation are done jointly for each tile). I think this would require significant refactoring though, so I wonder whether it makes sense at all, and whether anyone has other ideas to explore @scverse/spatialdata
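The idea of resolving all tile geometry up front and only materialising pixels at tile creation can be sketched as follows. This is a minimal sketch, not spatialdata code: `PrecomputedTileDataset` is a hypothetical class, tiles are assumed to be square boxes already expressed in the image's pixel (intrinsic) coordinates, and a NumPy array stands in for the lazy dask/zarr array.

```python
import numpy as np


class PrecomputedTileDataset:
    """Hypothetical dataset: tile geometry is resolved once in __init__,
    so __getitem__ only slices the image and materialises the crop."""

    def __init__(self, image, centers_px, half_size_px):
        # image: array-like of shape (c, y, x). A NumPy array is used here, but any
        # object supporting basic slicing (e.g. a lazy dask or zarr array) would do.
        self.image = image
        _, height, width = image.shape
        centers = np.asarray(centers_px, dtype=float)  # (n_tiles, 2) as (y, x)
        lo = np.floor(centers - half_size_px).astype(int)
        hi = np.ceil(centers + half_size_px).astype(int)
        # Clip every box to the image bounds in one vectorised step.
        upper = np.array([height, width])
        lo = np.clip(lo, 0, upper)
        hi = np.clip(hi, 0, upper)
        self.bounds = np.concatenate([lo, hi], axis=1)  # (n_tiles, 4): y0, x0, y1, x1

    def __len__(self):
        return len(self.bounds)

    def __getitem__(self, idx):
        y0, x0, y1, x1 = self.bounds[idx]
        crop = self.image[:, y0:y1, x0:x1]
        # With a dask array this is the single point where data would be read
        # (e.g. via crop.compute()); np.asarray covers the NumPy case.
        return np.asarray(crop)
```

The per-item cost is then one slice plus one read, with no coordinate transformation or xarray indexing inside the loop.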

giovp · May 24 '24 09:05

Codecov Report

Attention: Patch coverage is 71.42857% with 2 lines in your changes missing coverage. Please review.

Project coverage is 92.52%. Comparing base (8d902d4) to head (7adc03f). Report is 8 commits behind head on main.

❗ Current head 7adc03f differs from pull request most recent head 7feb03b

Please upload reports for the commit 7feb03b to get more accurate results.

Additional details and impacted files
```diff
@@            Coverage Diff             @@
##             main     #565      +/-   ##
==========================================
- Coverage   92.53%   92.52%   -0.02%     
==========================================
  Files          43       42       -1     
  Lines        6003     6008       +5     
==========================================
+ Hits         5555     5559       +4     
- Misses        448      449       +1     
```

| Files | Coverage Δ |
|---|---|
| src/spatialdata/dataloader/datasets.py | 90.73% <100.00%> (+0.04%) ⬆️ |
| src/spatialdata/_core/query/spatial_query.py | 94.67% <50.00%> (-0.51%) ⬇️ |

... and 6 files with indirect coverage changes

codecov[bot] · May 24 '24 09:05

Super cool analysis! I'll also try it out (which commands did you use to open py-spy? Or did you set it up to be integrated with your IDE?)

If most of the time is spent outside dask_image.ndinterp.affine_transform() (the core function used in transform()), then I think that preparing everything before and calling affine_transform() at the end would be a good approach.

But my bet (I need to check this by running the profiler) is that the problem is that we load the same chunks multiple times. I think that using .persist() to automatically cache some Dask chunks, and ordering the cells so that we randomize the chunks first and then the cells inside each chunk, would lead to performance improvements.
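The ordering suggested above (randomise the chunk order first, then the cells inside each chunk) can be sketched as below. This assumes tile centres are given in pixel coordinates and that the image uses a single regular chunk size; `chunk_aware_order` is a hypothetical helper, not part of spatialdata.

```python
import numpy as np


def chunk_aware_order(centers_px, chunk_size, seed=0):
    """Return a permutation of tile indices that visits chunks in random order
    and shuffles tiles only within each chunk (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    centers = np.asarray(centers_px)
    # Chunk coordinates per tile, e.g. (y // 1024, x // 1024) for 1024x1024 chunks.
    chunk_ids = (centers // chunk_size).astype(int)
    # Collapse each (row, col) chunk pair into a single group label.
    _, group = np.unique(chunk_ids, axis=0, return_inverse=True)
    group = group.ravel()  # some NumPy versions return a 2-D inverse when axis is set
    order = []
    for g in rng.permutation(group.max() + 1):
        members = np.flatnonzero(group == g)
        order.extend(rng.permutation(members).tolist())
    return np.array(order)
```

Consecutive indices then hit the same chunk, so a small chunk cache (or a persisted array) is reused before moving on. In a `DataLoader`, such an order could be supplied through a custom `torch.utils.data.Sampler` instead of `shuffle=True`.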

This second approach has the advantage that it involves only the dataloader class and does not require changes in the transformation code.

LucaMarconato · May 25 '24 13:05

I reviewed the code, and it looks good to me. We could merge this already, or first explore the .persist() approach above in this PR.

LucaMarconato · May 25 '24 13:05

> I'll also try it out (which commands did you use to open py-spy? Or did you set it up to be integrated with your IDE?)

I've just changed the output format in py-spy:

```shell
py-spy record --format speedscope -o profile.speedscope.json -- python process_xenium.py
```
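To compare the fetch step across runs without opening the UI, the speedscope JSON can also be summarised programmatically. A minimal sketch, assuming the file follows the speedscope file-format schema (a `shared.frames` list plus `profiles` of type `"sampled"` with parallel `samples`/`weights` arrays); `summarize_speedscope` is a hypothetical helper, not a py-spy feature.

```python
import json
from collections import Counter


def summarize_speedscope(path_or_dict, top=10):
    """Aggregate inclusive time per function name across all sampled profiles
    in a speedscope JSON file (hypothetical helper)."""
    if isinstance(path_or_dict, dict):
        data = path_or_dict
    else:
        with open(path_or_dict) as fh:
            data = json.load(fh)
    frames = data["shared"]["frames"]
    inclusive = Counter()
    for profile in data["profiles"]:
        if profile.get("type") != "sampled":
            continue  # evented profiles store open/close events instead of stacks
        for stack, weight in zip(profile["samples"], profile["weights"]):
            # Count each function once per sample, even if it recurses.
            for name in {frames[i]["name"] for i in stack}:
                inclusive[name] += weight
    return inclusive.most_common(top)
```

For example, `summarize_speedscope("profile.speedscope.json")` would list the functions with the largest inclusive time first.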

This was just a push to get the code onto another machine, but let me explain what's next.

I've realized that computing the transformed bounding box in the implicit coordinate system takes a fair amount of time, and it could instead be done only once, in the same way the tile coordinates dataset is built. I will therefore:

  • move the transformation out of the bounding box query and do it only once at init.
  • enable returning gexp data from different layers.
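Doing the bounding-box transformation once at init amounts to mapping every tile's box into the image's pixel space in one vectorised pass. A minimal sketch, assuming circles (centre plus radius, as in `cell_circles`) and a known 3x3 affine matrix from the query coordinate system to pixel space (e.g. the inverse of the image's transformation to that system); `transform_bboxes` is hypothetical and does not use the spatialdata transformations API.

```python
import numpy as np


def transform_bboxes(centers, radii, affine):
    """Map circles from a coordinate system to axis-aligned pixel-space boxes.

    centers: (n, 2) array of (x, y); radii: (n,) array;
    affine: 3x3 homogeneous matrix from the coordinate system to pixel space.
    Returns an (n, 4) array of (x_min, y_min, x_max, y_max).
    """
    centers = np.asarray(centers, dtype=float)
    r = np.asarray(radii, dtype=float)[:, None]
    mins, maxs = centers - r, centers + r
    # The four corners of each box, shape (n, 4, 2).
    corners = np.stack(
        [
            np.column_stack([mins[:, 0], mins[:, 1]]),
            np.column_stack([maxs[:, 0], mins[:, 1]]),
            np.column_stack([mins[:, 0], maxs[:, 1]]),
            np.column_stack([maxs[:, 0], maxs[:, 1]]),
        ],
        axis=1,
    )
    # Homogeneous coordinates: append a 1 to each corner, shape (n, 4, 3).
    ones = np.ones(corners.shape[:-1] + (1,))
    homogeneous = np.concatenate([corners, ones], axis=-1)
    # One matrix product for all corners of all boxes.
    mapped = homogeneous @ np.asarray(affine, dtype=float).T
    xy = mapped[..., :2]
    # Under rotation or shear the mapped box is not axis-aligned, so take its extent.
    return np.concatenate([xy.min(axis=1), xy.max(axis=1)], axis=1)
```

The resulting array can be stored next to the tile coordinates, so each `__getitem__` only reads its precomputed row.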

The dataset will have only one type of output, which will be a dictionary of the following form:

```python
{
    "tile": tile,                # the image tile
    "annotations": annotations,  # list of annotations
    "gexp": gexp,                # list of gene expression values
}
```
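A single item in this format, with the expression values taken from a selectable layer, might be assembled as below. This is a sketch under assumptions: `make_sample` is hypothetical, layers are modelled as a plain mapping from layer name to an (n_cells, n_genes) matrix with `None` meaning the main matrix `"X"`, and `gexp` is returned as a 1-D array rather than a Python list so that it batches into one matrix.

```python
import numpy as np


def make_sample(tile, annotations, layers, row, layer=None):
    """Assemble one dataset item in the proposed dictionary format (hypothetical).

    tile: array of shape (c, y, x); annotations: per-cell values such as
    [instance_id, celltype]; layers: mapping of layer name -> (n_cells, n_genes)
    matrix; row: the cell's row in the table; layer=None selects "X".
    """
    matrix = layers["X" if layer is None else layer]
    values = matrix[row]
    if hasattr(values, "toarray"):  # scipy.sparse rows are densified first
        values = values.toarray()
    gexp = np.asarray(values).ravel()  # 1-D expression vector for this cell
    return {"tile": tile, "annotations": list(annotations), "gexp": gexp}
```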

wdyt?

What I won't do here, but would be useful to work on next, is:

  • enable working with multiple tables (doing something similar to https://docs.lamin.ai/lamindb.core.mappedcollection )
  • enable encoding labels
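Both follow-ups can be sketched in a few lines. The multi-table part uses cumulative offsets to map a global index to (table, local row), which is the general idea behind concatenated-collection loaders; it does not reproduce the lamindb `MappedCollection` API. `MultiTableIndex` and `encode_labels` are hypothetical names.

```python
import numpy as np


class MultiTableIndex:
    """Hypothetical global index over several tables: global position i is
    mapped to (table number, local row) via cumulative offsets."""

    def __init__(self, table_lengths):
        self.offsets = np.concatenate([[0], np.cumsum(table_lengths)])

    def __len__(self):
        return int(self.offsets[-1])

    def locate(self, i):
        if not 0 <= i < len(self):
            raise IndexError(i)
        # side="right" skips empty tables that share the same offset.
        t = int(np.searchsorted(self.offsets, i, side="right") - 1)
        return t, int(i - self.offsets[t])


def encode_labels(labels):
    """Map string labels to integer codes; returns (codes, sorted categories)."""
    categories, codes = np.unique(np.asarray(labels), return_inverse=True)
    return codes.ravel(), categories.tolist()
```

For example, with tables of 3 and 2 cells, global index 4 maps to row 1 of the second table.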

giovp · May 28 '24 11:05

Thanks for the explanation. Yes, I think that operating on the transformation at the preprocessing stage is a good approach to improve performance. Also, the option to specify the layer will be useful.

Regarding the return type, would you remove the SpatialData return type or still leave it as an option?

LucaMarconato · May 28 '24 12:05

> Regarding the return type, would you remove the SpatialData return type or still leave it as an option?

That's a good question. I would potentially leave it, but then technically the dataloader would fail, as the default collate_fn only accepts array/mapping[str, array]/list[array]. wdyt?
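To illustrate why a dictionary output batches cleanly, here is a NumPy stand-in for a mapping-aware collate function. It is a sketch, not PyTorch's `default_collate`: that function also converts arrays to tensors and transposes nested sequences position by position instead of gathering them, so the `annotations` field would come out differently there. `collate_dicts` is hypothetical.

```python
import numpy as np


def collate_dicts(samples):
    """Batch a list of per-tile dictionaries: arrays under the same key are
    stacked along a new batch axis, anything else is gathered into a list."""
    batch = {}
    for key in samples[0]:
        values = [s[key] for s in samples]
        if all(isinstance(v, np.ndarray) for v in values):
            batch[key] = np.stack(values)  # requires equal shapes, e.g. resized tiles
        else:
            batch[key] = values
    return batch
```

A SpatialData object has no such field-by-field stacking rule, which is why it needs a custom `collate_fn`.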

giovp · May 28 '24 13:05

Ok, then I would probably move the default away from returning SpatialData (but still leave it as an option for users). I think a good default would be one compatible with the default collate_fn.

LucaMarconato · May 28 '24 15:05

Closing in favour of #687.

giovp · Sep 03 '24 18:09