improve data loader performance
I've been wanting to take another look at this for a long time. I profiled with py-spy (https://github.com/benfred/py-spy) using the speedscope output format; see the screenshot below.
I've been doing this on the xenium_rep_1 dataset from the paper, using the following code (adapted from @LucaMarconato's code):
import json
import numpy as np
import pandas as pd
import torchvision.transforms.v2 as T
from spatialdata.dataloader.datasets import ImageTilesDataset
from spatialdata.transformations import Scale, get_transformation
from spatialdata.transformations import Sequence as SequenceTransformation
from torch.utils.data import DataLoader
from tqdm import tqdm
from pathlib import Path
import spatialdata as sd

xeniumrep1 = Path(
    "/path/to/xenium_rep1_data_aligned.zarr"
)
sdata1 = sd.read_zarr(xeniumrep1)
visium = Path(
    "/path/to/visium_data_aligned.zarr"
)
sdata3 = sd.read_zarr(visium)

TILE_SCALE = 10.0
REGION = "xeniumrep1"
sdata = sdata1
sdata.images["hne"] = sdata3.images["CytAssist_FFPE_Human_Breast_Cancer_full_image"]


def get_ds(sdata: sd.SpatialData):
    img_size = 224
    transform_tv = T.Compose(
        [
            T.ToImage(),
            T.Resize((img_size, img_size), antialias=True, interpolation=T.InterpolationMode.BICUBIC),
            T.ToTensor(),
        ]
    )

    def transform(output):
        image, anno = output
        instance_id, celltype = anno[:, 0].squeeze(), anno[:, 1].squeeze()
        image = transform_tv(image.data.transpose(1, 2, 0).compute(scheduler="single-threaded"))
        out = {"img": image, "instance_id": instance_id.tolist(), "celltype": celltype.tolist()}
        return out

    mu = sdata.shapes["cell_circles"]["radius"].mean()
    std = sdata.shapes["cell_circles"]["radius"].std()
    # large radius to cover most of the cells
    large_radius = mu + 2 * std
    neighbors_contex = large_radius
    sdata.shapes["cell_circles"]["radius"] = neighbors_contex
    instance_key = sdata.tables["table"].uns["spatialdata_attrs"]["instance_key"]
    ds = ImageTilesDataset(
        sdata=sdata,
        regions_to_images={"cell_circles": "hne"},
        regions_to_coordinate_systems={"cell_circles": "aligned"},
        return_annotations=[instance_key, "celltype_major"],
        tile_scale=TILE_SCALE,
        transform=transform,
        table_name="table",
    )
    return ds


ds = get_ds(sdata)
dl = DataLoader(
    ds,
    batch_size=256,
    num_workers=0,
    shuffle=False,
)
This made me realize that, if we want to return the array, then instantiating the SpatialImage|MultiscaleSpatialImage is an unnecessary step: the dask array could simply be returned. This roughly halved the fetch step (across 6 iterations) from ~43s to ~23s total, see below.
I think the fetch step is ultimately what we want to improve, as it's the one that streams the tiles from the zarr array to the GPU. Now the two main blocks are the transform call and the compute call. The transform call is shown under compute in the profile, but it's effectively the wrapper call where all the DataArray.isel calls happen, i.e. where the crops are defined, transformed and set, before the computation is actually triggered with compute.
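For anyone wanting to reproduce the fetch timing without a profiler, a minimal sketch follows. It is not the measurement used in this PR: `time_fetch` is a hypothetical helper, and `fake_loader` is a stand-in for the `DataLoader` built above.

```python
import time
from typing import Iterable


def time_fetch(loader: Iterable, n_iterations: int = 6) -> float:
    """Return the total seconds spent pulling up to `n_iterations` batches."""
    iterator = iter(loader)
    total = 0.0
    for _ in range(n_iterations):
        start = time.perf_counter()
        try:
            next(iterator)  # only the fetch itself is timed
        except StopIteration:
            break  # the loader ran out of batches early
        total += time.perf_counter() - start
    return total


# Stand-in for the DataLoader: any iterable of batches works here.
fake_loader = [list(range(256)) for _ in range(10)]
elapsed = time_fetch(fake_loader, n_iterations=6)
print(f"fetch time over 6 iterations: {elapsed:.6f}s")
```

Passing the real `dl` instead of `fake_loader` would time the same six iterations; the numbers will depend on the machine and storage.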
I wonder what the next step could be to chase further performance gains. I think one option would be to "prepare" the transformation beforehand on the full array, and then trigger it only at tile creation in the compute call (whereas now, transformation and tile creation are done jointly for each tile). I think this would require significant refactoring though, so I wonder if it makes sense at all, and if anyone has other ideas to explore @scverse/spatialdata
Codecov Report
Attention: Patch coverage is 71.42857% with 2 lines in your changes are missing coverage. Please review.
Project coverage is 92.52%. Comparing base (8d902d4) to head (7adc03f). Report is 8 commits behind head on main.
:exclamation: Current head 7adc03f differs from pull request most recent head 7feb03b
Please upload reports for the commit 7feb03b to get more accurate results.
Additional details and impacted files
@@ Coverage Diff @@
## main #565 +/- ##
==========================================
- Coverage 92.53% 92.52% -0.02%
==========================================
Files 43 42 -1
Lines 6003 6008 +5
==========================================
+ Hits 5555 5559 +4
- Misses 448 449 +1
| Files | Coverage Δ | |
|---|---|---|
| src/spatialdata/dataloader/datasets.py | 90.73% <100.00%> (+0.04%) | :arrow_up: |
| src/spatialdata/_core/query/spatial_query.py | 94.67% <50.00%> (-0.51%) | :arrow_down: |
Super cool analysis! I'll also try it out (which commands did you use to open py-spy? Or did you set it up to be integrated with your IDE?)
If most of the time is spent outside dask_image.ndinterp.affine_transform() (the core function used in transform()), then I think that preparing everything before and calling affine_transform() at the end would be a good approach.
But my bet (I need to check by running the profiler) is that the problem is that we load the same chunks multiple times. I think that using .persist() to cache some Dask chunks, and ordering the cells so that we randomize the chunks first and then the cells inside each chunk, would lead to performance improvements.
This second approach has the advantage that it involves only the dataloader class and does not require changes in the transformation code.
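The ordering idea can be sketched in plain Python. This is only an illustration under stated assumptions: `chunk_aware_order` is a hypothetical helper, each cell is assigned to the single chunk containing its centroid (real tiles may straddle several chunks), and the centroids and chunk size below are made up.

```python
import random
from collections import defaultdict


def chunk_aware_order(centroids, chunk_size, seed=0):
    """Return cell indices ordered so that cells sharing a chunk are consecutive.

    centroids: list of (y, x) pixel coordinates, one per cell (assumed input).
    chunk_size: (height, width) of one image chunk in pixels.
    """
    rng = random.Random(seed)
    by_chunk = defaultdict(list)
    for idx, (y, x) in enumerate(centroids):
        key = (int(y // chunk_size[0]), int(x // chunk_size[1]))
        by_chunk[key].append(idx)

    chunk_keys = list(by_chunk)
    rng.shuffle(chunk_keys)  # randomize the chunks first
    order = []
    for key in chunk_keys:
        cells = by_chunk[key]
        rng.shuffle(cells)  # then randomize the cells inside each chunk
        order.extend(cells)
    return order


centroids = [(10, 10), (900, 40), (20, 30), (950, 10), (15, 1200)]
order = chunk_aware_order(centroids, chunk_size=(512, 512))
print(order)
```

Since the `sampler` argument of `DataLoader` accepts any iterable of indices, an order like this could be passed there, so consecutive fetches would tend to hit chunks that are already cached.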
I reviewed the code, looks good to me. We could merge this already or explore first the .persist() approach above in this PR.
> I'll also try it out (which commands did you use to open py-spy? Or did you set it up to be integrated with your IDE?)
I've just changed the format in py-spy
py-spy record --format speedscope -o profile.speedscope.json -- python process_xenium.py
This was just a push to get the code onto another machine. But let me explain what's next.
I've realized that computing the transformed bounding box in the implicit coordinate system takes a fair amount of time, and it could in fact be done only once, in the same way the tile coords dataset is built. I will therefore:
- move the transformation out of the bounding box query and do it only once at init.
- enable returning gexp data from different layers.
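As a rough illustration of the first point, the sketch below transforms every bounding box once in `__init__`, so per-item access is only a lookup. It is not the spatialdata implementation: `PrecomputedTiles` and `transform_bbox` are hypothetical names, the transformation is assumed to be a plain 2x3 affine matrix, and the tile extent is assumed to be the circle diameter times `tile_scale`.

```python
from typing import List, Sequence, Tuple

BBox = Tuple[float, float, float, float]  # (min_x, min_y, max_x, max_y)
Affine = Sequence[Sequence[float]]  # 2x3 matrix: [[a, b, tx], [c, d, ty]]


def transform_bbox(bbox: BBox, affine: Affine) -> BBox:
    """Map the four corners through `affine` and return their axis-aligned hull."""
    min_x, min_y, max_x, max_y = bbox
    corners = [(min_x, min_y), (min_x, max_y), (max_x, min_y), (max_x, max_y)]
    xs, ys = [], []
    for x, y in corners:
        xs.append(affine[0][0] * x + affine[0][1] * y + affine[0][2])
        ys.append(affine[1][0] * x + affine[1][1] * y + affine[1][2])
    return (min(xs), min(ys), max(xs), max(ys))


class PrecomputedTiles:
    """Hypothetical container: bounding boxes are transformed once, at init."""

    def __init__(self, centers, radii, tile_scale: float, to_pixels: Affine):
        self.pixel_bboxes: List[BBox] = []
        for (cx, cy), r in zip(centers, radii):
            half = r * tile_scale  # assumption: extent = diameter * tile_scale
            bbox = (cx - half, cy - half, cx + half, cy + half)
            self.pixel_bboxes.append(transform_bbox(bbox, to_pixels))

    def __len__(self) -> int:
        return len(self.pixel_bboxes)

    def __getitem__(self, idx: int) -> BBox:
        # No transformation here: only a lookup of the precomputed box.
        return self.pixel_bboxes[idx]


# Example: one cell, and a transformation that halves the coordinates.
tiles = PrecomputedTiles(
    centers=[(100.0, 200.0)],
    radii=[5.0],
    tile_scale=10.0,
    to_pixels=[[0.5, 0.0, 0.0], [0.0, 0.5, 0.0]],
)
print(tiles[0])
```

With the boxes already in pixel coordinates, fetching a tile would reduce to slicing the image array with the stored box.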
The dataset will have only one type of output, a dictionary of the following form:
{
"tile":tile,
"annotations":list of annotations,
"gexp": list of gexp,
}
wdyt?
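A minimal sketch of what such an item could look like from the consumer side, assuming a map-style dataset. `DictTileDataset` is a hypothetical class and the nested lists stand in for real tiles and expression vectors; only the three keys come from the proposal above.

```python
from typing import Any, Dict, List


class DictTileDataset:
    """Hypothetical map-style dataset whose items follow the proposed layout."""

    def __init__(self, tiles: List[Any], annotations: List[list], gexp: List[list]):
        if not (len(tiles) == len(annotations) == len(gexp)):
            raise ValueError("tiles, annotations and gexp must have the same length")
        self._tiles = tiles
        self._annotations = annotations
        self._gexp = gexp

    def __len__(self) -> int:
        return len(self._tiles)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        return {
            "tile": self._tiles[idx],
            "annotations": self._annotations[idx],
            "gexp": self._gexp[idx],
        }


# Two toy 2x2 single-channel "tiles" with one annotation and three genes each.
dataset = DictTileDataset(
    tiles=[[[0, 1], [2, 3]], [[4, 5], [6, 7]]],
    annotations=[[0], [1]],
    gexp=[[0.0, 1.5, 2.0], [3.0, 0.0, 0.5]],
)
item = dataset[1]
print(sorted(item))
```

A single mapping per item keeps the output shape the same regardless of which annotations or layers are requested. PyTorch's default collate function handles mappings of arrays and equal-length sequences, so items like these should batch without a custom `collate_fn`.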
What I won't do here, but would be useful to work on next:
- support working with multiple tables (doing something similar to https://docs.lamin.ai/lamindb.core.mappedcollection )
- support encoding labels
Thanks for the explanation. Yes, I think that operating on the transformation at the preprocessing stage is a good approach to improve performance. Also, the option to specify the layer will be useful.
Regarding the return type, would you remove the SpatialData return type or still leave it as an option?
> Regarding the return type, would you remove the SpatialData return type or still leave it as an option?
That's a good question. I would potentially leave it, but then technically the dataloader would fail, as the default collate_fn only accepts array/mapping[str, array]/list[array]. wdyt?
Ok, then I would probably move the default away from returning SpatialData (but still leave it as an option to the users). I think a good default would be one compatible with the default collate_fn.
Closing in favour of #687