[Feature] searchsorted

Open weirdykid opened this issue 1 year ago • 14 comments

Is there an equivalent to np.searchsorted or a way that I could reasonably implement something similar with the existing ops?

weirdykid avatar Jul 07 '24 18:07 weirdykid

We don't have it, but you can certainly do a binary search with existing ops. Here's something that works for 1D arrays; it should be fairly straightforward to support an axis parameter if you need it:

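import mlx.core as mx

def searchsorted(a, b):
    # Lower-bound binary search over a sorted 1D array `a`, run for every
    # element of `b` at once (leftmost insertion index, as in np.searchsorted).
    axis = 0
    size = a.shape[axis]
    # The answer lies in [0, size], so size.bit_length() halvings are enough.
    steps = size.bit_length()
    lower = mx.zeros(b.shape, dtype=mx.uint32)
    upper = mx.full(b.shape, vals=size, dtype=mx.uint32)
    for _ in range(steps):
        mid = (lower + upper) // 2
        # Clamp the gather index: mid can equal size once lower == upper == size.
        in_range = lower < upper
        go_right = in_range & (a[mx.minimum(mid, size - 1)] < b)
        lower = mx.where(go_right, mid + 1, lower)
        upper = mx.where(go_right, upper, mid)
    return lower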

It will also be a lot faster if you mx.compile it, particularly if you call it multiple times with the same shapes.
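For reference, here is a minimal plain-Python sketch of the per-query loop that a vectorized binary search runs in lockstep for every element of b. It uses ordinary lists rather than MLX arrays so it can be run anywhere, and the name `lower_bound` is made up for illustration. It can serve as a correctness check for any array version, since it should agree with `bisect.bisect_left` (and with np.searchsorted's default side="left").

```python
from bisect import bisect_left

def lower_bound(a, x):
    """Leftmost index at which x could be inserted into the sorted list a."""
    lower, upper = 0, len(a)
    while lower < upper:
        mid = (lower + upper) // 2
        if a[mid] < x:
            lower = mid + 1  # x belongs strictly to the right of mid
        else:
            upper = mid      # x belongs at mid or to its left
    return lower

a = [1, 2, 3, 4]
queries = [0.5, 1, 2.5, 4, 5]
print([lower_bound(a, q) for q in queries])  # [0, 0, 2, 3, 4]
```

The search interval at least halves on every pass, which is why a fixed number of iterations (about log2 of the sorted size) is enough in the array version.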

awni avatar Jul 07 '24 19:07 awni

I'm open to adding a little binary search implementation like that into MLX to support searchsorted. We could use the above as a starting point.

awni avatar Jul 07 '24 19:07 awni

Another option is to do something like the following. It's linear in a but probably quite a bit faster, especially for small arrays, since it uses far fewer operations:

def searchsorted(a, b):
    return (a[None, :] < b[:, None]).sum(axis=1)
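The one-liner works because, in a sorted array, the leftmost insertion index of a value equals the number of entries strictly smaller than it. A small plain-Python sketch of the same idea (lists instead of MLX arrays, and the name `searchsorted_by_counting` is made up for illustration):

```python
from bisect import bisect_left

def searchsorted_by_counting(a, b):
    # For each query, count how many entries of a are strictly smaller.
    return [sum(1 for v in a if v < q) for q in b]

a = [10, 20, 20, 30]
b = [5, 20, 25, 40]
print(searchsorted_by_counting(a, b))  # [0, 1, 3, 4]
```

Note that the broadcasted comparison builds a full len(b) by len(a) boolean matrix before summing, so memory grows with the product of the two sizes. Replacing `<` with `<=` would give the rightmost insertion index instead.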

awni avatar Jul 08 '24 00:07 awni

Ah okay, I think I can make do with these workarounds for the time being. Thanks!!

weirdykid avatar Jul 08 '24 01:07 weirdykid

@awni can I start working on this issue, since it was last active on July 7?

Saanidhyavats avatar Sep 12 '24 04:09 Saanidhyavats

By all means

awni avatar Sep 12 '24 04:09 awni

Do we have to go with the linear approach suggested above, (a[None, :] < b[:, None]).sum(axis=1), or with the one NumPy uses?

Saanidhyavats avatar Sep 18 '24 22:09 Saanidhyavats

I would compare

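import mlx.core as mx

def searchsorted(a, b):
    # Lower-bound binary search over a sorted 1D array `a`, run for every
    # element of `b` at once (leftmost insertion index, as in np.searchsorted).
    axis = 0
    size = a.shape[axis]
    # The answer lies in [0, size], so size.bit_length() halvings are enough.
    steps = size.bit_length()
    lower = mx.zeros(b.shape, dtype=mx.uint32)
    upper = mx.full(b.shape, vals=size, dtype=mx.uint32)
    for _ in range(steps):
        mid = (lower + upper) // 2
        # Clamp the gather index: mid can equal size once lower == upper == size.
        in_range = lower < upper
        go_right = in_range & (a[mx.minimum(mid, size - 1)] < b)
        lower = mx.where(go_right, mid + 1, lower)
        upper = mx.where(go_right, upper, mid)
    return lower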

and

def searchsorted(a, b):
    return (a[None, :] < b[:, None]).sum(axis=1)

And see which is faster. Presumably there will be a size at which the first is faster but it will start out slower. We could try to dispatch based on that. Or just use the more scalable version.

awni avatar Sep 19 '24 14:09 awni

Let A and B be the lengths of the arrays a and b. The first case (binary search) takes O(B log A) time and O(B) space; the second case (linear comparison) takes O(AB) time and O(AB) space.

From a scalability point of view (comparing both time and space complexity), the first case looks more appropriate, right?
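To put rough numbers on those bounds, the sketch below counts element comparisons only, for a few example sizes. It ignores constant factors, kernel launch overhead, and parallelism, so it is not a runtime prediction; the helper name `comparison_counts` is made up for illustration.

```python
import math

def comparison_counts(sorted_size, search_size):
    # Binary search: about log2(A) comparisons per query.
    binary = search_size * math.ceil(math.log2(sorted_size))
    # Linear comparison: every query is compared against every element.
    linear = search_size * sorted_size
    return binary, linear

for sorted_size, search_size in [(1024, 1024), (16384, 1024), (2097152, 1024)]:
    binary, linear = comparison_counts(sorted_size, search_size)
    print(f"A={sorted_size:>8} B={search_size}: binary={binary}, linear={linear}")
```

For A = 2097152 and B = 1024 this gives 21504 comparisons for the binary search against 2147483648 for the linear version.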

Saanidhyavats avatar Sep 19 '24 18:09 Saanidhyavats

The constant factors of the logarithmic approach are quite a bit larger, so it is not as simple as that. The following timings are from my laptop. Also note that mx.compile helps the binary search quite a bit. These are all on the GPU as well; the CPU could be faster for some searches.

Sorted size | Search size | Binary search | Binary search compiled | Linear search
------------+-------------+---------------+------------------------+----------------
     1024   |        1    |     0.65 ms   |             0.38 ms    |      0.14 ms
     1024   |        4    |     0.58 ms   |             0.37 ms    |      0.14 ms
     1024   |       16    |     0.56 ms   |             0.36 ms    |      0.14 ms
     1024   |       64    |     0.56 ms   |             0.36 ms    |      0.14 ms
     1024   |      256    |     0.56 ms   |             0.35 ms    |      0.17 ms
     1024   |     1024    |     0.57 ms   |             0.35 ms    |      0.22 ms
    16384   |        1    |     0.76 ms   |             0.41 ms    |      0.21 ms
    16384   |        4    |     0.75 ms   |             0.42 ms    |      0.14 ms
    16384   |       16    |     0.74 ms   |             0.45 ms    |      0.16 ms
    16384   |       64    |     0.75 ms   |             0.43 ms    |      0.22 ms
    16384   |      256    |     0.74 ms   |             0.42 ms    |      0.41 ms
    16384   |     1024    |     0.80 ms   |             0.44 ms    |      1.18 ms
  2097152   |        1    |     1.02 ms   |             0.53 ms    |      0.61 ms
  2097152   |        4    |     0.98 ms   |             0.55 ms    |      0.91 ms
  2097152   |       16    |     1.00 ms   |             0.55 ms    |      2.11 ms
  2097152   |       64    |     1.01 ms   |             0.55 ms    |      7.77 ms
  2097152   |      256    |     1.02 ms   |             0.56 ms    |     34.49 ms
  2097152   |     1024    |     1.03 ms   |             0.57 ms    |    132.58 ms

The TL;DR is that if you want to search in fewer than 16k elements, or if you are only searching for 1-2 elements, it doesn't make much sense to use the binary search. If, on the other hand, you are searching for a lot of elements in a large sorted array (millions of elements), then you can expect a 100x improvement using binary search :-) .

angeloskath avatar Sep 19 '24 20:09 angeloskath

Thank you @angeloskath for the info. From an application point of view, I think cases involving more than 16k elements will be rare. Shall I implement the linear search then? Another idea is to implement both and choose between them based on a condition such as the sorted size.
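A minimal sketch of what such a dispatch could look like, in plain Python with lists standing in for MLX arrays. The cutoffs are assumptions read off the benchmark table above (linear below 16k sorted elements or for 1-2 queries), and `choose_strategy` and both constants are hypothetical names; real thresholds would need to be measured per device.

```python
from bisect import bisect_left

LINEAR_MAX_SORTED_SIZE = 16 * 1024  # assumed cutoff, from the table above
LINEAR_MAX_QUERIES = 2              # assumed cutoff, from the TL;DR above

def choose_strategy(sorted_size, search_size):
    # Small sorted arrays or very few queries: the linear version wins.
    if sorted_size < LINEAR_MAX_SORTED_SIZE or search_size <= LINEAR_MAX_QUERIES:
        return "linear"
    return "binary"

def searchsorted(a, b):
    if choose_strategy(len(a), len(b)) == "linear":
        # Count the entries strictly smaller than each query.
        return [sum(1 for v in a if v < q) for q in b]
    # Stand-in for the vectorized binary search.
    return [bisect_left(a, q) for q in b]

print(choose_strategy(1024, 1024))     # linear
print(choose_strategy(2097152, 1024))  # binary
```

Both branches return the same indices, so the choice only affects speed and memory use.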

Saanidhyavats avatar Sep 21 '24 04:09 Saanidhyavats

@awni in which particular module shall we implement this feature?

Saanidhyavats avatar Sep 26 '24 04:09 Saanidhyavats

Is there any update on the searchsorted function in MLX? Or are we still relying on

def searchsorted(a, b):
    return (a[None, :] < b[:, None]).sum(axis=1)

for sizes below 16k?

hemantaph avatar Jul 21 '25 18:07 hemantaph

There is no update. You can use the above for now. If anyone wants to work on the issue, it's open... otherwise we'll try to get to it when we have the bandwidth.

awni avatar Jul 21 '25 19:07 awni