Indexing into dense `tiledb.Array` with arbitrary `np.ndarray` indices

Open hanslovsky opened this issue 11 months ago • 0 comments

Summary

I am sharing my experience and looking for recommendations for indexing into dense tiledb.Arrays with arbitrary integer indices (np.ndarray).

Background

I am considering tiledb as a unified data backend for a pipeline that can work on either data streamed from disk or in-memory data (mem://). Each step operates on an immutable input and creates one or more new 2D tiledb.Arrays. Some of our operations select rows or columns based on a (possibly cheap) criterion. To avoid creating arrays unnecessarily, I experimented with wrapping each tiledb.Array in a view class that carries an optional selection (either a slice or an np.ndarray). My hope was to minimize the creation of new tiledb.Array instances by using these views as much as possible, deferring the call to tiledb.Array.__getitem__ or tiledb.Array.multi_index until I actually work with the data (instead of just subsetting). In doing so, I ran into a scenario where reading data out of a tiledb.Array became extremely slow. I was able to reproduce this with the following example:

  1. Create a 10k by 10k np.ndarray.
  2. Copy it into a mem:// tiledb.Array.
  3. Access the array via multi_index[np.arange(10k), np.arange(10k)] (fast).
  4. Access the array via multi_index[np.arange(10k)[::2], np.arange(10k)[::2]] (slow).
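
For illustration, the view class mentioned in the background might look roughly like the sketch below. This is not the class from my pipeline and not part of TileDB-Py: `LazyView2D`, `select`, and `materialize` are hypothetical names, a plain np.ndarray stands in for the tiledb.Array backend, and each selection is normalized to integer positions for simplicity. It only shows the idea of composing selections without reading data and indexing the backend once at the end.

```python
import numpy as np


class LazyView2D:
    """Hypothetical deferred-selection wrapper (not part of TileDB-Py)."""

    def __init__(self, backend, rows=None, cols=None):
        self.backend = backend
        n_rows, n_cols = backend.shape
        # Store each axis selection as explicit integer positions into the backend.
        self.rows = np.arange(n_rows) if rows is None else np.asarray(rows)
        self.cols = np.arange(n_cols) if cols is None else np.asarray(cols)

    @property
    def shape(self):
        return (len(self.rows), len(self.cols))

    def select(self, row_sel=slice(None), col_sel=slice(None)):
        # Compose the new selection with the stored one; no data is read here.
        return LazyView2D(self.backend, self.rows[row_sel], self.cols[col_sel])

    def materialize(self):
        # Only here is the backend indexed, as a per-axis (outer) selection.
        # For a tiledb.Array backend this is where __getitem__ or multi_index
        # would be called instead of np.ix_.
        return self.backend[np.ix_(self.rows, self.cols)]


data = np.arange(36).reshape(6, 6)      # NumPy stand-in for a dense array
view = LazyView2D(data).select(slice(None, None, 2), np.array([1, 3, 5]))
view = view.select(np.array([0, 2]))    # subset the rows of the view again
result = view.materialize()             # the only point where data is read
```

With a design like this, the final call is exactly the kind of access that steps 3 and 4 exercise: one integer index array per axis.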

Minimal Working Example

I have a minimal working example with requirements.txt here

This is the Python script:

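import json  # only needed for the commented-out stats dump at the end
import time
from uuid import uuid4

import numpy as np
import tiledb


def select(arr: tiledb.DenseArray, sl1, sl2) -> np.ndarray:
    """Read `arr` with selection `sl1`, apply `sl2` in memory, and print the elapsed time."""
    t0 = time.perf_counter()
    match sl1:
        case (slice(), slice()):
            # Plain slices on both axes: use regular indexing.
            preselected = arr[sl1]
        case _:
            # Index arrays or lists of slices: use multi_index and take the anonymous attribute.
            preselected = arr.multi_index[sl1][""]

    ndarr = preselected[sl2]
    t1 = time.perf_counter()
    dt = t1 - t0
    print(f"{int(1000*dt)}ms")
    return ndarr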


def make_memory_array(arr: np.ndarray, ctx: tiledb.Ctx | None = None) -> tiledb.DenseArray:
    if ctx is None:
        ctx = tiledb.Ctx()
    dims = [tiledb.Dim(f"{d}", domain=(0, s-1)) for d, s in enumerate(arr.shape)]
    domain = tiledb.Domain(*dims, ctx=ctx)
    schema = tiledb.ArraySchema(domain=domain, attrs=[tiledb.Attr(name="", dtype=arr.dtype)], ctx=ctx, sparse=False)
    uri = f"mem://{uuid4()}"
    tiledb.Array.create(uri, ctx=ctx, schema=schema)
    with tiledb.open(uri, mode="w", ctx=ctx) as a:
        a[...] = arr
    return tiledb.open(uri, mode="r", ctx=ctx)


if __name__ == "__main__":
    n = 10_000
    data = np.arange(n*n).reshape((n, n))
    print(f"{data.dtype=} {data.size=}")
    cont = np.arange(n)
    ctx = tiledb.Ctx()

    # This is the slow part for non-contiguous:
    #     self.pyquery.submit()
    # https://github.com/TileDB-Inc/TileDB-Py/blob/f0267e6f4fac0e00af66c012e23b2c95b95c75fc/tiledb/multirange_indexing.py#L360

    tdb_arr = make_memory_array(data, ctx=ctx)
    assert np.all(tdb_arr[...] == data)

    for step in (None, 2):
        # tiledb.stats_reset()
        # tiledb.stats_enable()
        step_label = str(step)  # separate name so the array size `n` is not shadowed

        sl = slice(None, None, step)
        ref_sel = data[sl, sl]
        def _validate(selected):
            assert selected.shape == ref_sel.shape, f"{selected.shape} != {ref_sel.shape}"
            assert np.all(selected == ref_sel)

        prefix = f"step={step_label:<6}"
        print(f"{prefix:<6} ref           ", end="  ")
        sel = select(tdb_arr, (sl, sl), slice(None))
        _validate(sel)

        ind = cont[sl]
        print(f"{prefix:<6} indices       ", end="  ")
        sel = select(tdb_arr, (ind, ind), slice(None))
        _validate(sel)

        print(f"{prefix:<6} indices split ", end="  ")
        sel = select(tdb_arr, (ind, slice(None)), (slice(None), ind))
        _validate(sel)

        # multi_index slices include both endpoints, so slice(s, s) selects the single index s.
        slices = [slice(s, s) for s in ind]
        print(f"{prefix:<6} slices        ", end="  ")
        sel = select(tdb_arr, (slices, slices), (slice(None), slice(None)))
        _validate(sel)

        print(f"{prefix:<6} slices split  ", end="  ")
        sel = select(tdb_arr, (slices, slice(None)), (slice(None), ind))
        _validate(sel)

        # s = tiledb.stats_dump(json=True)
        # with open(f"{step_label}.json", "w") as f:
        #     json.dump(json.loads(s), f, indent=2)

Example output:

data.dtype=dtype('int64') data.size=100000000
step=None   ref             521ms
step=None   indices         535ms
step=None   indices split   825ms
step=None   slices          579ms
step=None   slices split    881ms
step=2      ref             523ms
step=2      indices         22383ms
step=2      indices split   1194ms
step=2      slices          16637ms
step=2      slices split    933ms
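
To put these numbers in proportion, here is a small piece of arithmetic on the timings printed above. The dictionary simply copies the reported values; the cell counts assume the 10k by 10k array from the script, where step=2 keeps every second row and column. Nothing here measures anything new.

```python
# Timings in milliseconds, copied from the example output above.
timings_ms = {
    None: {"ref": 521, "indices": 535, "indices split": 825, "slices": 579, "slices split": 881},
    2: {"ref": 523, "indices": 22383, "indices split": 1194, "slices": 16637, "slices split": 933},
}

n = 10_000
cells_full = n * n                      # step=None returns every cell
cells_strided = (n // 2) * (n // 2)     # step=2 returns every second row and column

data_ratio = cells_full / cells_strided
# How much longer each access pattern takes for step=2 than for step=None.
slowdown = {k: timings_ms[2][k] / timings_ms[None][k] for k in timings_ms[None]}

print(f"step=2 returns {data_ratio:.0f}x fewer cells")
for name, factor in slowdown.items():
    print(f"{name:<14} step=2 takes {factor:5.1f}x the step=None time")
```

In other words, the "indices" and "slices" patterns return a quarter of the cells but take roughly 42 and 29 times as long, while the "ref" and "split" patterns stay within a factor of about 1.5.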

Observations:

  1. tiledb.Array.multi_index[np.ndarray, np.ndarray] can be extremely slow. This is surprising because the contiguous case (step=None) loads four times as much data yet finishes in 535ms, compared to 22383ms for step=2.
  2. I can probably work around this with tiledb.Array.multi_index[np.ndarray, slice(None)][""][:, np.ndarray]. At 1194ms for step=2, this is reasonably fast and is likely compatible with our access pattern.
  3. I suspect that my example is close to worst-case performance. I tried to debug into the tiledb code as far as possible, but I could not follow it into anything implemented in C++. I am not sure what would need to change in the tiledb.Array implementation to optimize a use case like the one presented here.
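
The workaround from the second observation can be written as a small helper. This is only a sketch: `select_rows_then_cols` and `read_rows` are hypothetical names, and a plain np.ndarray stands in for the tiledb.Array so the snippet runs without TileDB. The multi_index call in the docstring is the same one used by the "indices split" case in the script.

```python
import numpy as np


def select_rows_then_cols(read_rows, row_idx, col_idx):
    """Hypothetical helper: index one axis in the backend, the other in memory.

    `read_rows` is any callable returning full rows for the given row indices.
    For a tiledb.Array `arr` with an anonymous attribute, as in the script above,
    it could be `lambda r: arr.multi_index[r, :][""]`.
    """
    rows = read_rows(row_idx)           # shape (len(row_idx), n_cols)
    return rows[:, col_idx]             # column selection happens in NumPy


data = np.arange(100).reshape(10, 10)   # NumPy stand-in for the dense array
ind = np.arange(10)[::2]
selected = select_rows_then_cols(lambda r: data[r, :], ind, ind)
```

The trade-off is that every column of the selected rows is read first, so the intermediate result holds len(row_idx) times n_cols values before the column selection shrinks it.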

hanslovsky, Jan 28 '25 06:01