Concurrency-safe writing to SpatialData
Zarr has the nice capability of supporting parallel reading and writing. That means we can have workflows with parallel processing steps that write data to different parts of the same Zarr dataset.
Since SpatialData uses Zarr as its backend, it would be good to preserve that capability.
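As a reminder of what Zarr itself allows, here is a minimal sketch (file name, shapes and chunking are illustrative, not taken from SpatialData): several worker processes write disjoint, chunk-aligned regions of the same on-disk array without any explicit synchronization.
from concurrent.futures import ProcessPoolExecutor

import zarr


def write_row(i: int) -> None:
    # Each worker re-opens the array and writes exactly one chunk (one row).
    zarr.open_array("example.zarr", mode="r+")[i, :] = i


if __name__ == "__main__":
    # Create the array once; chunks of one row keep the parallel writes non-overlapping.
    zarr.open_array("example.zarr", mode="w", shape=(4, 100), chunks=(1, 100), dtype="i4")
    with ProcessPoolExecutor() as executor:
        list(executor.map(write_row, range(4)))
    print(zarr.open_array("example.zarr", mode="r")[:, 0])  # -> [0 1 2 3]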
Use cases
- As a developer, I can parallelize workflows that each process a separate omics dataset and write it to a common SpatialData destination, skipping a final concatenation of many SpatialData objects. This avoids writing the same data twice (once as a per-dataset SpatialData and later once more as a concatenated SpatialData).
- As a developer, I can parallelize incremental construction of a SpatialData object for a single omics dataset (parallel segmentation of multiple images into labels; and writing multiple annotation dataframes to the same region).
Current status
In my experience, concurrent writing currently fails for several reasons:
- The table does not support incremental, independent changes to different parts of the table. Instead, the complete table must be deleted before an updated table can be set. This means table.uns["spatialdata_attrs"] is always deleted (thus temporarily missing and invalid) and written again. Since the table operations are not atomic, there are race conditions when other processes try to set the table at the same time. Here, we need an atomic table update function.
- Adding elements can also cause race conditions (even if they seem to be independent elements).
- The in-memory data structures can get out of sync. In particular, when reading a SpatialData object, its elements dictionary is kept in memory. If a second process adds an element and writes it to disk, the first process is not notified and its in-memory representation does not see the new element (see the sketch after this list). For this reason, in-memory or cached data structures should be minimized so that every read access reads directly from disk (as Zarr does).
- Probably, modifying image slices of file-backed Dask arrays will write through to disk without problems, but this is something I have not tested.
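To make the third point concrete, here is a minimal sketch of the staleness problem (the path and element names are illustrative; it assumes a backed SpatialData store already exists on disk):
import numpy as np
import spatialdata
from spatialdata.models import Image2DModel

# Two handles to the same backed SpatialData store, as two processes would have.
sdata_a = spatialdata.read_zarr("sdata.zarr")
sdata_b = spatialdata.read_zarr("sdata.zarr")

# The second handle adds an element and writes it to disk.
image = Image2DModel.parse(np.zeros((1, 10, 10)), dims=("c", "y", "x"))
sdata_b.add_image(name="image1", image=image)

# The first handle still serves its cached elements dictionary and does not see it.
print("image1" in sdata_a.images)  # False until sdata_a is re-read from disk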
Requirements
In my use case, I need:
- Adding (independent) labels elements in parallel from different processes
- Adding annotations to different regions in parallel (e.g. independent observation rows)
I don't need:
- Modifying array slices of the same image element in parallel
- Modifying columns of the same observation in parallel
- Modifying array data and transformations of the same element in parallel from different processes
Test cases
The following is a set of test cases that compares concurrent modification
- of SpatialData elements (sdata.add_image), the SpatialData table, and plain Zarr arrays
- and each for in-memory SpatialData objects and backed datasets
- for multi-threading, multi-processing and the sequential case as baseline
Currently, the sequential case works in all of these scenarios (as expected), and plain Zarr supports all of the concurrency cases.
Pytest test cases
from concurrent.futures import (
Executor,
Future,
ProcessPoolExecutor,
ThreadPoolExecutor,
as_completed,
)
from pathlib import Path
from typing import Union
import numpy as np
import pandas as pd
import pytest
import spatialdata
import zarr
from anndata import AnnData
from numpy.random import default_rng
from spatialdata import SpatialData
from spatialdata.models import Image2DModel, TableModel
INSTANCE_KEY = "instance_id"
REGION_KEY = "region"
RNG = default_rng(seed=0)
class SequentialExecutor(Executor):
def __init__(self, *args, **kwargs):
pass
def submit(self, fn, /, *args, **kwargs):
result = fn(*args, **kwargs)
future = Future()
future.set_result(result)
return future
def _get_table(
region: Union[str, list[str]] = "region0",
region_key: str = REGION_KEY,
instance_key: str = INSTANCE_KEY,
n_obs: int = 100,
n_vars: int = 10,
) -> AnnData:
adata = AnnData(
RNG.normal(size=(n_obs, n_vars)),
obs=pd.DataFrame(RNG.normal(size=(n_obs, 3)), columns=["a", "b", "c"]),
)
adata.obs[instance_key] = np.arange(adata.n_obs)
if isinstance(region, str):
adata.obs[region_key] = region
elif isinstance(region, list):
adata.obs[region_key] = RNG.choice(region, size=adata.n_obs)
return TableModel.parse(
adata=adata, region=region, region_key=region_key, instance_key=instance_key
)
def _create_sdata(n_obs: int = 100, n_vars: int = 10, region: str = "region0") -> SpatialData:
return SpatialData(table=_get_table(region=region, n_obs=n_obs, n_vars=n_vars))
def _update_table(sdata: SpatialData, new_table: AnnData):
"""
Place-holder for a SpatialData function that updates the table in place.
Currently, the table can (almost) only be set once, and needs to be
deleted and completely rewritten.
"""
    if sdata.table is None:
        # No existing table yet: just set the new table.
        sdata.table = new_table
        return
updated_table = AnnData.concatenate(sdata.table, new_table, join="outer")
if sdata.table is not None:
del sdata.table
sdata.table = TableModel.parse(
updated_table,
region_key=REGION_KEY,
instance_key=INSTANCE_KEY,
region=list(updated_table.obs[REGION_KEY].unique()),
)
def process_that_updates_table(
sdata: SpatialData, process_num: int, n_obs: int = 100, n_vars: int = 10
):
region = f"region{process_num}"
new_table_data = _get_table(region=region, n_obs=n_obs, n_vars=n_vars)
_update_table(sdata, new_table_data)
@pytest.mark.parametrize(
"executor_cls",
[
ProcessPoolExecutor,
ThreadPoolExecutor,
SequentialExecutor,
],
)
def test_concurrently_write_to_table(
executor_cls: type[Executor],
tmp_path: Path,
n_obs: int = 100,
n_vars: int = 10,
n_multi: int = 10,
):
sdata = _create_sdata(region="region0", n_obs=n_obs, n_vars=n_vars)
with executor_cls(max_workers=min(8, n_multi)) as executor:
futures = [
executor.submit(
process_that_updates_table, sdata, process_num, n_obs=n_obs, n_vars=n_vars
)
for process_num in range(1, n_multi + 1)
]
[f.result() for f in as_completed(futures, timeout=5)]
actual_sdata = sdata
actual_regions = set(actual_sdata.table.obs[REGION_KEY].unique())
expected_regions = {f"region{n}" for n in range(0, n_multi + 1)}
assert actual_regions == expected_regions
expected_n_obs = n_obs * (1 + n_multi)
assert actual_sdata.table.n_obs == expected_n_obs
def process_that_updates_table_to_backed_spatialdata(
path: Path, process_num: int, n_obs: int = 100, n_vars: int = 10
):
sdata = spatialdata.read_zarr(path)
region = f"region{process_num}"
new_table_data = _get_table(region=region, n_obs=n_obs, n_vars=n_vars)
_update_table(sdata, new_table_data)
@pytest.mark.parametrize(
"executor_cls",
[
ProcessPoolExecutor,
ThreadPoolExecutor,
SequentialExecutor,
],
)
def test_concurrently_write_to_table_to_backed_spatialdata(
executor_cls: type[Executor],
tmp_path: Path,
n_obs: int = 100,
n_vars: int = 100,
n_multi: int = 10,
):
spatialdata_path = tmp_path / "spatialdata.zarr"
sdata = _create_sdata(n_obs=n_obs, n_vars=n_vars, region="region0")
sdata.write(spatialdata_path)
with executor_cls(max_workers=min(8, n_multi)) as executor:
futures = [
executor.submit(
process_that_updates_table_to_backed_spatialdata,
path=spatialdata_path,
process_num=process_num,
n_obs=n_obs,
n_vars=n_vars,
)
for process_num in range(1, n_multi + 1)
]
[f.result() for f in as_completed(futures, timeout=5)]
actual_sdata = spatialdata.read_zarr(spatialdata_path)
actual_regions = set(actual_sdata.table.obs[REGION_KEY].unique())
expected_regions = {f"region{n}" for n in range(0, n_multi + 1)}
assert actual_regions == expected_regions
expected_n_obs = n_obs * (1 + n_multi)
assert actual_sdata.table.n_obs == expected_n_obs
def process_that_adds_image(sdata: SpatialData, process_num: int, width: int = 100):
element_name = f"image{process_num}"
image = Image2DModel.parse(
np.full(shape=(1, width, width), fill_value=process_num), dims=("c", "y", "x")
)
sdata.add_image(name=element_name, image=image)
@pytest.mark.parametrize(
"executor_cls",
[
ProcessPoolExecutor,
ThreadPoolExecutor,
SequentialExecutor,
],
)
def test_concurrently_add_element(
executor_cls: type[Executor], tmp_path: Path, width: int = 10, n_multi: int = 10
):
sdata = SpatialData()
with executor_cls(max_workers=min(8, n_multi)) as executor:
futures = [
executor.submit(process_that_adds_image, sdata, process_num, width=width)
for process_num in range(1, n_multi + 1)
]
[f.result() for f in as_completed(futures, timeout=5)]
actual_sdata = sdata
actual_image_names = set(actual_sdata.images.keys())
expected_image_names = {f"image{n}" for n in range(1, n_multi + 1)}
assert actual_image_names == expected_image_names
def process_that_adds_image_to_backed_spatialdata(path: Path, process_num: int, width: int = 100):
# Note: Read access also needs to be synced by locking because a SpatialData structure
# currently being modified (written) may be inconsistent and can cause a read failure.
sdata = spatialdata.read_zarr(path)
element_name = f"image{process_num}"
image = Image2DModel.parse(
np.full(shape=(1, width, width), fill_value=process_num), dims=("c", "y", "x")
)
sdata.add_image(name=element_name, image=image)
@pytest.mark.parametrize(
"executor_cls",
[
ProcessPoolExecutor,
ThreadPoolExecutor,
SequentialExecutor,
],
)
def test_concurrently_add_element_to_backed_spatialdata(
executor_cls: type[Executor], tmp_path: Path, width: int = 10, n_multi: int = 10
):
spatialdata_path = tmp_path / "spatialdata.zarr"
sdata = SpatialData()
sdata.write(spatialdata_path)
with executor_cls(max_workers=min(8, n_multi)) as executor:
futures = [
executor.submit(
process_that_adds_image_to_backed_spatialdata,
path=spatialdata_path,
process_num=process_num,
width=width,
)
for process_num in range(1, n_multi + 1)
]
[f.result() for f in as_completed(futures, timeout=5)]
actual_sdata = spatialdata.read_zarr(spatialdata_path)
actual_image_names = set(actual_sdata.images.keys())
expected_image_names = {f"image{n}" for n in range(1, n_multi + 1)}
assert actual_image_names == expected_image_names
def process_that_adds_zarr_array(path: Path, process_num: int, width: int = 100):
group = zarr.open_group(path)
array_name = f"array{process_num}"
group[array_name] = np.full(shape=(width, width), fill_value=process_num)
@pytest.mark.parametrize(
"executor_cls", [ProcessPoolExecutor, ThreadPoolExecutor, SequentialExecutor]
)
def test_zarr_concurrently_add_array(
executor_cls, tmp_path: Path, width: int = 10, n_multi: int = 10
):
path = tmp_path / "group.zarr"
group = zarr.open_group(path)
with executor_cls(max_workers=min(8, n_multi)) as executor:
futures = [
executor.submit(
process_that_adds_zarr_array, path=path, process_num=process_num, width=width
)
for process_num in range(1, n_multi + 1)
]
[f.result() for f in as_completed(futures, timeout=5)]
actual_group = zarr.open_group(path)
actual_array_names = set(actual_group.array_keys())
expected_array_names = {f"array{n}" for n in range(1, n_multi + 1)}
assert actual_array_names == expected_array_names
Workarounds
A workaround (and possible solution) is to use locks. Since thread locks are not visible across processes, nor across multiple systems (cluster nodes) accessing the same file system over the network, file locks seem to be the most suitable option. The downside is very frequent creation and deletion of the lock file (but we have write access anyway).
- It is also important that the lock is re-entrant, so that
  - recursive functions can acquire the same lock again without dead-locking,
  - and parent functions guarded by a lock can call a child function that is also guarded by the same lock.
- Whenever a process fails, the lock must be released (finally clause, already handled by filelock).
- Read access also needs to be guarded, because a SpatialData object that is being written by another process can be inconsistent and invalid. In particular, table.uns["spatialdata_attrs"] can cause validation errors if it fails to be read.
- Another limitation of this implementation is that it allows only a single reader or writer at a time. Theoretically, exclusion is only required when a writer is involved; any number of readers could read at the same time. See readers-writer lock and the readerwriterlock package (a thread-level sketch follows this list).
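To illustrate the readers-writer idea, here is a minimal thread-level sketch using the readerwriterlock package (this assumes its documented rwlock API; it does not coordinate across processes, so a cross-process variant would still need file locks):
from readerwriterlock import rwlock

# One lock per SpatialData store; the "fair" variant avoids starving readers or writers.
sdata_rw_lock = rwlock.RWLockFair()


def read_something():
    # Many readers may hold the read lock at the same time.
    with sdata_rw_lock.gen_rlock():
        ...  # e.g. spatialdata.read_zarr(path)


def write_something():
    # A writer gets exclusive access: no readers and no other writers.
    with sdata_rw_lock.gen_wlock():
        ...  # e.g. sdata.add_image(...)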
Here is an implementation using the package filelock:
from collections.abc import Generator
from contextlib import contextmanager
from multiprocessing import Lock
from pathlib import Path
from typing import Union
from filelock import FileLock
from spatialdata import SpatialData
def _refresh_spatialdata(sdata: SpatialData):
# Refresh a SpatialData instance from disk, keeping the same in-memory Python object.
import spatialdata
if sdata.is_backed():
sdata_reloaded = spatialdata.read_zarr(sdata.path)
for attr in ["_images", "_labels", "_points", "_shapes", "_table"]:
setattr(sdata, attr, getattr(sdata_reloaded, attr))
@contextmanager
def spatialdata_locked(
    sdata_or_path: Union[SpatialData, str, Path], timeout: int = -1
) -> Generator["SpatialData", None, None]:
"""
    Context manager to ensure safe SpatialData modification without race conditions.

    This works reliably only with file-backed SpatialData, not with fully in-memory SpatialData.
Args:
sdata_or_path: The SpatialData to lock during execution of the yield block
        timeout: An optional timeout for how long to wait to acquire the lock. When the file is
locked by another process for longer than the timeout, an error is raised.
If negative, it never times out.
Yields:
The locked SpatialData object or path
"""
import spatialdata
path = sdata_or_path if isinstance(sdata_or_path, (str, Path)) else sdata_or_path.path
if path is not None:
# SpatialData is backed.
# Use a file lock.
path = Path(path)
lock = FileLock(path.with_suffix(path.suffix + ".lock"), timeout=timeout)
else:
# SpatialData is fully in-memory, no location to write a lock file to.
# Use a multiprocessing lock.
# This is not reliable.
lock = Lock()
with lock:
if isinstance(sdata_or_path, spatialdata.SpatialData) and sdata_or_path.is_backed():
sdata = sdata_or_path
# Ensure we work with the latest version of the file
# and not the passed in-memory representation.
_refresh_spatialdata(sdata)
yield sdata
else:
yield sdata_or_path
Usage:
import skimage.segmentation
import spatialdata
from spatialdata.models import Labels2DModel
from spatialdata.transformations import get_transformation

# Note: Read access also needs to be synced by locking because a SpatialData structure
# currently being modified (written) may be inconsistent and can cause a read failure.
with spatialdata_locked("/tmp/my_sdata.zarr"):
    sdata = spatialdata.read_zarr("/tmp/my_sdata.zarr")

# Do some long processing. Since we don't hold the lock, other processes are allowed to do
# their work and commit it to the SpatialData on disk.
# For example:
image = sdata.images["image1"].data[0].compute()  # first channel as a 2D (y, x) array
labels_array = skimage.segmentation.watershed(image)
labels = Labels2DModel.parse(
    labels_array,
    dims=("y", "x"),
    transformations=get_transformation(sdata.images["image1"], get_all=True),
)

# If another process holds the lock, we will wait until it is released. Otherwise,
# we will go straight into adding the result, while other processes will have to wait.
with spatialdata_locked(sdata):
    sdata.add_labels(name="labels1", labels=labels)
For an official implementation, it would be good to look into how Zarr does it and to get their input and experience.
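For reference, Zarr's own answer to writes that may touch the same chunk is a synchronizer object passed at open time. A minimal sketch with the zarr-python 2.x API (paths and shapes are illustrative; ProcessSynchronizer uses file locks via the fasteners package):
import zarr

# Chunk-level file locks protect concurrent writes that may hit the same chunk.
synchronizer = zarr.ProcessSynchronizer("example.sync")
z = zarr.open_array(
    "example.zarr",
    mode="a",
    shape=(100, 100),
    chunks=(10, 10),
    dtype="i4",
    synchronizer=synchronizer,
)
z[0, :] = 42  # locks are acquired and released transparently per chunk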