
Alternative to nonzero()

Open TristanBilot opened this issue 1 year ago • 5 comments

nonzero() is still not implemented in MLX.

  • Do you know an efficient workaround until the feature is implemented?
  • Regarding the implementation, my first intuition is that we need a dedicated kernel to find, in parallel, the indices where the condition is satisfied. Is there another way that doesn't require a new kernel? I can work on this feature.

Example

Goal:

a = mx.array([1, 2, 3, 4, 5])
indices = mx.nonzero(a < 2) # -> should return [0]

Current workaround:

a = mx.array([1, 2, 3, 4, 5])
indices = [i for i, x in enumerate(a) if x < 2]
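For reference, the same indices can be computed without a Python-level loop by evaluating the array and handing it to NumPy, which is the "go through NumPy" route mentioned later in this thread. A minimal sketch, using a NumPy array as a stand-in for the MLX array; with an actual mx.array one would first convert it with np.array(a), which forces evaluation of the graph:

```python
import numpy as np

# Stand-in for the MLX array in the example above. For an mx.array `a`,
# np.array(a) evaluates it and copies the data into a NumPy array.
a = np.array([1, 2, 3, 4, 5])

# np.flatnonzero returns the indices of the True entries of the flattened mask.
indices = np.flatnonzero(a < 2)
print(indices.tolist())  # [0]
```

The cost is a synchronization point plus a copy, which is exactly the trade-off discussed below.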

TristanBilot, Jan 27 '24 16:01

It's pretty unlikely we will implement this in the near future because the output shape depends on the input data. MLX is currently not set up well to deal with operations like that.

  • In most cases (but not all) there is a good workaround (usually something like mask and reduce).
  • In the instances when there is not a good workaround, you have to eval the graph. We could expose a nonzero that does this internally, to avoid going through NumPy or Python, but I'm not that comfortable with the idea.
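As a rough illustration of the "mask and reduce" pattern, assuming the indices would only be used to compute aggregates: keep the boolean mask and reduce with it, so every intermediate has a shape known in advance and nothing forces an evaluation. Sketched with NumPy; MLX has mx.where, mx.sum and mx.argmax that follow the same pattern. The helper name masked_range_stats is hypothetical.

```python
import numpy as np

def masked_range_stats(a, lower, upper):
    """Hypothetical helper: aggregates over elements in [lower, upper]
    without ever materializing their indices."""
    mask = (a >= lower) & (a <= upper)   # same shape as `a`
    count = mask.sum()                   # scalar: how many elements match
    total = np.where(mask, a, 0).sum()   # scalar: sum of matching elements
    # argmax on a boolean mask gives the first True position; it returns 0
    # when nothing matches, so it is only meaningful when count > 0.
    first = np.argmax(mask)
    return count, total, first

a = np.array([1, 2, 3, 4, 5])
print(masked_range_stats(a, 2, 4))
```

Whether this applies depends on what the indices are used for, which is why the question below matters.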

Out of curiosity, could you share more details on your computation? I'm wondering whether there is a workaround or not (it will help us understand how much to prioritize this).

awni, Jan 27 '24 17:01

Hmm, I see. One workaround would be to run a first kernel that counts the number of values satisfying the condition, allocate the appropriate memory, and then run the actual kernel that computes nonzero(). However, this is not optimal. Do you have any idea how other frameworks handle this scenario?
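The two-pass scheme described above is essentially stream compaction: count the matches, give each match an output slot via a prefix sum, allocate, then scatter. A minimal sequential sketch in NumPy standing in for what would be two GPU kernels (the function name nonzero_two_pass is hypothetical); the int(...) on the count marks the point where a lazy framework would have to evaluate the graph before it can allocate:

```python
import numpy as np

def nonzero_two_pass(mask):
    """Hypothetical helper: indices of True entries via count + prefix sum + scatter."""
    mask = np.asarray(mask, dtype=bool).ravel()
    n = mask.size
    # Pass 1: count matches. This data-dependent size forces a sync.
    count = int(mask.sum())
    # Inclusive prefix sum minus one gives each True entry its output slot.
    pos = np.cumsum(mask) - 1
    # Allocate exactly `count` slots plus one dump slot for False entries.
    out = np.empty(count + 1, dtype=np.int64)
    # Pass 2: every element writes its own index; False entries all land in
    # the dump slot at position `count`, which is dropped below.
    dest = np.where(mask, pos, count)
    out[dest] = np.arange(n)
    return out[:count]

print(nonzero_two_pass([False, True, False, True, True]).tolist())  # [1, 3, 4]
```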

Our computation is straightforward. We want to get the indices where values are between a given lower and upper bound.

mask = (array >= lower_bound) & (array <= upper_bound)
indices = mx.array([i for i, e in enumerate(mask) if e]) # Here we want to get the indices based on the mask

TristanBilot, Jan 27 '24 19:01

> Do you have any idea how other frameworks handle this scenario?

  • In JAX you can't JIT through ops like nonzero unless you explicitly pass a size parameter.
  • I haven't looked at PyTorch MPS but my guess is if it's supported it does something like what you said which indeed is highly non-optimal.
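For context, JAX's jnp.nonzero accepts size and fill_value arguments: it returns the indices of the first size True elements and pads with fill_value if there are fewer, so the output shape is fixed at trace time. A NumPy sketch of the same idea (not JAX's implementation; the name nonzero_fixed_size and the default fill of -1 are assumptions for illustration). Because size is static, the output shape never depends on the data:

```python
import numpy as np

def nonzero_fixed_size(mask, size, fill_value=-1):
    """Hypothetical helper: indices of True entries, truncated or padded
    to exactly `size` elements."""
    mask = np.asarray(mask, dtype=bool).ravel()
    n = mask.size
    # Output slot of each True entry (inclusive prefix sum minus one).
    pos = np.cumsum(mask) - 1
    # True entries that fit go to their slot; everything else goes to a
    # dump slot at index `size`, which is dropped at the end.
    dest = np.where(mask & (pos < size), pos, size)
    out = np.full(size + 1, fill_value, dtype=np.int64)
    out[dest] = np.arange(n)
    return out[:size]

print(nonzero_fixed_size([False, True, False, True, True], size=4).tolist())  # [1, 3, 4, -1]
```

The caller has to pick a size bound up front and treat fill_value entries as padding.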

> We want to get the indices where values are between a given lower and upper bound.

Makes sense, and what do you use those for?

awni, Jan 27 '24 20:01

> • In JAX you can't JIT through ops like nonzero unless you explicitly pass a size parameter.
> • I haven't looked at PyTorch MPS but my guess is if it's supported it does something like what you said which indeed is highly non-optimal.

OK, it seems everybody got stuck on this feature, haha.

> Makes sense, and what do you use those for?

We finally found a workaround using another architecture, so we can close this issue if you don't think working on nonzero() is a priority yet.

TristanBilot, Jan 27 '24 22:01