[Feature] searchsorted
Is there an equivalent to np.searchsorted or a way that I could reasonably implement something similar with the existing ops?
We don't have it, but you can certainly do a binary search with existing ops. Here's something that works for 1D arrays; it should be fairly straightforward to support an axis parameter if you need it:
import math
import mlx.core as mx

def searchsorted(a, b):
    axis = 0
    size = a.shape[axis]
    # Enough halving steps to narrow each search interval to a single index
    steps = math.ceil(math.log2(size))
    upper = size
    lower = 0
    # Start every query at the midpoint of the sorted array
    indices = mx.full(b.shape, vals=size // 2, dtype=mx.uint32)
    for _ in range(steps):
        lt = b < a[indices]
        # Midpoint of the half-interval each query moves into next
        new_indices = mx.where(lt, (lower + indices) // 2, (indices + upper) // 2)
        lower = mx.where(lt, lower, indices)
        upper = mx.where(lt, indices, upper)
        indices = new_indices
    return indices
Also, it will be a lot faster if you mx.compile it, particularly if you are using it multiple times with the same shapes.
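A minimal sketch of what that could look like, assuming the searchsorted defined above (the example arrays are made up for illustration):

import mlx.core as mx

# mx.compile traces the function and reuses the compiled graph for
# repeated calls with the same input shapes.
compiled_searchsorted = mx.compile(searchsorted)

a = mx.sort(mx.random.uniform(shape=(1024,)))
b = mx.random.uniform(shape=(64,))
out = compiled_searchsorted(a, b)
mx.eval(out)  # force evaluation of the lazy result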
I'm open to adding a little binary search implementation like that into MLX to support searchsorted. We could use the above as a starting point.
Another option is to do something like the following. It's linear in a, but probably quite a bit faster, especially for small arrays, since it's far fewer operations:
def searchsorted(a, b):
    # Count how many elements of a are strictly less than each element of b
    return (a[None, :] < b[:, None]).sum(axis=1)
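For reference, a quick sanity check of the linear variant against NumPy (assuming mlx and numpy are installed); as far as I can tell it matches np.searchsorted with the default side="left":

import numpy as np
import mlx.core as mx

a = mx.array([1, 3, 5, 7, 9])
b = mx.array([0, 4, 9, 10])
print(searchsorted(a, b))                         # expected: [0, 2, 4, 5]
print(np.searchsorted(np.array(a), np.array(b)))  # [0, 2, 4, 5]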
Ah okay, I think I can make these workarounds work for the time being. Thanks!!
@awni can I start working on this issue since it was last active on July 7?
By all means
> Another option is to do something like the following. It's linear in a but probably quite a bit faster, especially for small arrays, since it's far fewer operations: def searchsorted(a, b): return (a[None, :] < b[:, None]).sum(axis=1)
Do we have to go with this approach, or with the one in NumPy?
I would compare
def searchsorted(a, b):
    axis = 0
    size = a.shape[axis]
    steps = math.ceil(math.log2(size))
    upper = size
    lower = 0
    indices = mx.full(b.shape, vals=size // 2, dtype=mx.uint32)
    for _ in range(steps):
        lt = b < a[indices]
        new_indices = mx.where(lt, (lower + indices) // 2, (indices + upper) // 2)
        lower = mx.where(lt, lower, indices)
        upper = mx.where(lt, indices, upper)
        indices = new_indices
    return indices
and
def searchsorted(a, b):
    return (a[None, :] < b[:, None]).sum(axis=1)
And see which is faster. Presumably there will be a size at which the first is faster but it will start out slower. We could try to dispatch based on that. Or just use the more scalable version.
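A rough timing sketch along those lines; the bench helper and the names searchsorted_binary / searchsorted_linear are just placeholders for the two definitions above:

import time
import mlx.core as mx

def bench(fn, a, b, iters=100):
    mx.eval(fn(a, b))  # warm-up; also triggers compilation if fn is compiled
    start = time.perf_counter()
    for _ in range(iters):
        mx.eval(fn(a, b))  # force evaluation of the lazy graph each iteration
    return (time.perf_counter() - start) / iters

a = mx.sort(mx.random.uniform(shape=(16384,)))
b = mx.random.uniform(shape=(256,))
print("binary:", bench(searchsorted_binary, a, b))
print("linear:", bench(searchsorted_linear, a, b))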
Assuming A and B denote the lengths of arrays a and b: the first case is O(B log A) time and O(B) space, while the second case is O(A·B) time and O(A·B) space.
From a scalability point of view (comparing time and space complexity), the first case looks more appropriate, right?
The constant factors of the logarithmic approach are quite a bit larger, so it is not as simple as that. The following is on my laptop. Also note that mx.compile helps the binary search quite a bit. These are all on the GPU as well; the CPU could be faster for some searches.
Sorted size | Search size | Binary search | Binary search compiled | Linear search
------------+-------------+---------------+------------------------+----------------
1024 | 1 | 0.65 ms | 0.38 ms | 0.14 ms
1024 | 4 | 0.58 ms | 0.37 ms | 0.14 ms
1024 | 16 | 0.56 ms | 0.36 ms | 0.14 ms
1024 | 64 | 0.56 ms | 0.36 ms | 0.14 ms
1024 | 256 | 0.56 ms | 0.35 ms | 0.17 ms
1024 | 1024 | 0.57 ms | 0.35 ms | 0.22 ms
16384 | 1 | 0.76 ms | 0.41 ms | 0.21 ms
16384 | 4 | 0.75 ms | 0.42 ms | 0.14 ms
16384 | 16 | 0.74 ms | 0.45 ms | 0.16 ms
16384 | 64 | 0.75 ms | 0.43 ms | 0.22 ms
16384 | 256 | 0.74 ms | 0.42 ms | 0.41 ms
16384 | 1024 | 0.80 ms | 0.44 ms | 1.18 ms
2097152 | 1 | 1.02 ms | 0.53 ms | 0.61 ms
2097152 | 4 | 0.98 ms | 0.55 ms | 0.91 ms
2097152 | 16 | 1.00 ms | 0.55 ms | 2.11 ms
2097152 | 64 | 1.01 ms | 0.55 ms | 7.77 ms
2097152 | 256 | 1.02 ms | 0.56 ms | 34.49 ms
2097152 | 1024 | 1.03 ms | 0.57 ms | 132.58 ms
The TL;DR is that if you want to search in fewer than 16k elements, or if you are only searching for 1-2 elements, it doesn't make much sense to use the binary search. If, on the other hand, you are searching for a lot of elements in a large sorted array (millions of elements), then you can expect a 100x improvement using binary search :-) .
Thank you @angeloskath for the info. From an application point of view, I think cases involving more than 16k elements will be rare. Shall I implement linear search then? Another idea is to implement both and choose one based on a condition such as the sorted size.
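For what it's worth, a hedged sketch of that dispatch idea: the 16384 cutoff comes from the benchmark above and is only a heuristic, and the names searchsorted_linear / searchsorted_binary are placeholders for the two variants (which would need to agree on boundary/tie conventions):

def searchsorted_dispatch(a, b, cutoff=16384):
    # Linear scan wins for small sorted arrays; binary search scales better.
    if a.shape[0] < cutoff:
        return searchsorted_linear(a, b)
    return searchsorted_binary(a, b)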
@awni in which particular module shall we implement this feature?
Is there any new update on the searchsorted function in MLX? Or are we still relying on
def searchsorted(a, b):
    return (a[None, :] < b[:, None]).sum(axis=1)
for sizes < 16k?
There is no update. You can use the above for now. If anyone wants to work on the issue it's open... we'll try to get to it otherwise when we have bandwidth.