Indexing into dense `tiledb.Array` with arbitrary `np.ndarray` indices
Summary
I am sharing my experience and looking for recommendations on indexing into dense `tiledb.Array`s with arbitrary integer indices (`np.ndarray`).
Background
I am considering TileDB as a unified data backend for a pipeline that can work either on streaming data from disk or on in-memory data (`mem://`). Each step operates on an immutable input and creates a new (collection of) 2D `tiledb.Array`.

Some of our operations select rows/columns based on some (possibly cheap) criterion. To avoid unnecessary creation of arrays, I experimented with wrapping the `tiledb.Array`s in a view class with an optional selection (either a `slice` or an `np.ndarray`). My hope was to minimize the creation of new `tiledb.Array` instances, use those views as much as possible, and defer calling `tiledb.Array.__getitem__` or `tiledb.Array.multi_index` until I actually work with the data (instead of just subsetting).

Doing this, I ran into a scenario where getting data out of `tiledb.Array`s became extremely slow. I was able to reproduce it with the following example:
- Create a 10k by 10k `np.ndarray`.
- Copy it into a `mem://` `tiledb.Array`.
- Access the array via `multi_index[np.arange(10k), np.arange(10k)]` (fast).
- Access the array via `multi_index[np.arange(10k)[::2], np.arange(10k)[::2]]` (slow).
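To make the deferred-view idea from above concrete, here is a minimal sketch of such a view class. All names (`ArrayView`, `resolve`, `_compose`) are my own illustrative inventions, and a plain `np.ndarray` stands in for an opened `tiledb.DenseArray`:

```python
import numpy as np


def _compose(current, new, length):
    # Turn the current selection into explicit indices, then apply
    # the new selection on top of it; chained subsetting therefore
    # never touches the backing array.
    idx = np.arange(length)[current]
    return idx[new]


class ArrayView:
    """Hypothetical lazy view over a 2D array-like backing store.

    Keeps a per-axis selection (slice or integer np.ndarray) and only
    indexes the backing array when resolve() is called.
    """

    def __init__(self, arr, shape, rows=slice(None), cols=slice(None)):
        self.arr = arr      # e.g. an opened tiledb.DenseArray
        self.shape = shape
        self.rows = rows
        self.cols = cols

    def select(self, rows=None, cols=None):
        # Returns a new view; the input view stays immutable.
        r = self.rows if rows is None else _compose(self.rows, rows, self.shape[0])
        c = self.cols if cols is None else _compose(self.cols, cols, self.shape[1])
        return ArrayView(self.arr, self.shape, r, c)

    def resolve(self):
        # Two-step indexing: rows first, then columns on the in-memory
        # result. For a tiledb.DenseArray one would call multi_index here.
        data = self.arr[self.rows, :]
        return data[:, self.cols]


# usage with a plain NumPy backing array
a = np.arange(16).reshape(4, 4)
v = ArrayView(a, a.shape).select(rows=np.array([0, 2])).select(cols=np.array([1, 3]))
out = v.resolve()  # rows {0, 2} crossed with columns {1, 3}
```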
Minimal Working Example
I have a minimal working example with requirements.txt here
This is the Python script:
```python
import json
import time
from uuid import uuid4

import numpy as np
import tiledb


def select(arr: tiledb.DenseArray, sl1, sl2) -> np.ndarray:
    t0 = time.time()
    match sl1:
        case (slice(), slice()):
            preselected = arr[sl1]
        case _:
            preselected = arr.multi_index[sl1][""]
    ndarr = preselected[sl2]
    t1 = time.time()
    dt = t1 - t0
    print(f"{int(1000 * dt)}ms")
    return ndarr


def make_memory_array(arr: np.ndarray, ctx: tiledb.Ctx | None = None) -> tiledb.DenseArray:
    if ctx is None:
        ctx = tiledb.Ctx()
    dims = [tiledb.Dim(f"{d}", domain=(0, s - 1)) for d, s in enumerate(arr.shape)]
    domain = tiledb.Domain(*dims, ctx=ctx)
    schema = tiledb.ArraySchema(domain=domain, attrs=[tiledb.Attr(name="", dtype=arr.dtype)], ctx=ctx, sparse=False)
    uri = f"mem://{uuid4()}"
    tiledb.Array.create(uri, ctx=ctx, schema=schema)
    with tiledb.open(uri, mode="w", ctx=ctx) as a:
        a[...] = arr
    return tiledb.open(uri, mode="r", ctx=ctx)


if __name__ == "__main__":
    n = 10_000
    data = np.arange(n * n).reshape((n, n))
    print(f"{data.dtype=} {data.size=}")
    cont = np.arange(n)
    ctx = tiledb.Ctx()
    # This is the slow part for non-contiguous indices:
    # self.pyquery.submit()
    # https://github.com/TileDB-Inc/TileDB-Py/blob/f0267e6f4fac0e00af66c012e23b2c95b95c75fc/tiledb/multirange_indexing.py#L360
    tdb_arr = make_memory_array(data, ctx=ctx)
    assert np.all(tdb_arr[...] == data)
    for step in (None, 2):
        # tiledb.stats_reset()
        # tiledb.stats_enable()
        step_str = str(step)
        sl = slice(None, None, step)
        ref_sel = data[sl, sl]

        def _validate(selected):
            assert selected.shape == ref_sel.shape, f"{selected.shape} != {ref_sel.shape}"
            assert np.all(selected == ref_sel)

        prefix = f"step={step_str:<6}"
        print(f"{prefix:<6} ref ", end=" ")
        sel = select(tdb_arr, (sl, sl), slice(None))
        _validate(sel)

        ind = cont[sl]
        print(f"{prefix:<6} indices ", end=" ")
        sel = select(tdb_arr, (ind, ind), slice(None))
        _validate(sel)

        print(f"{prefix:<6} indices split ", end=" ")
        sel = select(tdb_arr, (ind, slice(None)), (slice(None), ind))
        _validate(sel)

        # multi_index ranges are inclusive, so slice(s, s) selects the single index s
        slices = [slice(s, s) for s in ind]
        print(f"{prefix:<6} slices ", end=" ")
        sel = select(tdb_arr, (slices, slices), (slice(None), slice(None)))
        _validate(sel)

        print(f"{prefix:<6} slices split ", end=" ")
        sel = select(tdb_arr, (slices, slice(None)), (slice(None), ind))
        _validate(sel)

        # s = tiledb.stats_dump(json=True)
        # with open(f"{step_str}.json", "w") as f:
        #     json.dump(json.loads(s), f, indent=2)
```
with example output:

```
data.dtype=dtype('int64') data.size=100000000
step=None   ref            521ms
step=None   indices        535ms
step=None   indices split  825ms
step=None   slices         579ms
step=None   slices split   881ms
step=2      ref            523ms
step=2      indices        22383ms
step=2      indices split  1194ms
step=2      slices         16637ms
step=2      slices split   933ms
```
Observations:
- `tiledb.Array.multi_index[np.ndarray, np.ndarray]` can be extremely slow. This is surprising because other access patterns are much faster, even though they load much more data.
- I can probably work around this with `tiledb.Array.multi_index[np.ndarray, slice(None)][""][:, np.ndarray]`. This seems reasonably fast and is likely compatible with our access pattern.
- I figure my example is probably near worst-case performance. I tried to debug into the TileDB code as much as possible, but I was not able to follow anything implemented in C++. I am not sure what would need to change in the `tiledb.Array` implementation to optimize a use case like the one presented here.
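For reference, the "indices split" workaround can be wrapped in a small helper. This is a sketch under the MWE's assumptions (2D dense array, single anonymous attribute `""`); `select_split` is my own name, not a TileDB API:

```python
import numpy as np


def select_split(arr, row_idx: np.ndarray, col_idx: np.ndarray) -> np.ndarray:
    """Avoid multi_index with integer arrays on both dimensions.

    Fetches whole rows via multi_index (integer array on one dimension
    only), then subsets the columns of the resulting in-memory ndarray.
    Assumes a single attribute named "" as in the MWE schema.
    """
    rows = arr.multi_index[row_idx, :][""]  # rows from TileDB, all columns
    return rows[:, col_idx]                 # column subset in NumPy
```

In the timings above this corresponds to the "indices split" case (about 1.2s instead of 22s for `step=2`).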