Alternative to nonzero()
This function is still not implemented.
- Do you know an efficient workaround until the feature is implemented?
- Regarding the implementation, my first intuition is that we need a dedicated kernel to compute, in parallel, the indices where the condition is satisfied. Is there another way that doesn't require a new kernel? I can work on this feature.
Example
Goal:
import mlx.core as mx

a = mx.array([1, 2, 3, 4, 5])
indices = mx.nonzero(a < 2)  # -> should return [0], the indices where the condition holds
Current workaround:
a = mx.array([1, 2, 3, 4, 5])
indices = [i for i, x in enumerate(a) if x.item() < 2]  # falls back to a Python-level loop
It's pretty unlikely we will implement this in the near future because the output shape depends on the input data. MLX is currently not set up well to deal with operations like that.
- In most cases (but not all) there is a good workaround, usually something like mask and reduce; see the sketch after this list.
- In the instances when there is not a good workaround, you have to `eval` the graph. We could expose a `nonzero` which did that to avoid going through NumPy or Python, but I'm not that comfortable with the idea.
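A minimal sketch of what the mask-and-reduce pattern can look like; the array and condition here are purely illustrative:

import mlx.core as mx

a = mx.array([1, 2, 3, 4, 5])
mask = a < 2  # fixed-shape boolean mask

num_matches = mx.sum(mask)                        # how many elements satisfy the condition
sum_matches = mx.sum(mx.where(mask, a, 0))        # sum over only the selected values
min_match = mx.min(mx.where(mask, a, mx.max(a)))  # min over only the selected values

Every intermediate here has a static shape, so the graph can be built without evaluating the data.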
Out of curiosity, could you share more details on your computation? I'm wondering if there is a workaround or not (it will help us understand how much to prioritize this).
Hmm I see, one workaround would be running a first kernel to count the number of values that satisfy the condition, then allocating the appropriate memory, and then running the actual kernel that computes nonzero(). However, this is not optimal. Do you have any idea how other frameworks handle this scenario?
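For illustration, here is a NumPy sketch of that two-pass idea (classic stream compaction: count, take an exclusive prefix sum, then scatter); the explicit loop stands in for what each GPU thread would do, and the array is just an example.

import numpy as np

a = np.array([1, 2, 3, 4, 5])
mask = (a < 2).astype(np.int64)  # 1 where the condition holds

# Pass 1: count the matches so the output can be allocated exactly.
count = int(mask.sum())
out = np.empty(count, dtype=np.int64)

# An exclusive prefix sum assigns each match its slot in the output.
positions = np.cumsum(mask) - mask

# Pass 2: scatter the matching indices (one "thread" per element).
for i in range(len(a)):
    if mask[i]:
        out[positions[i]] = i

print(out)  # [0]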
Our computation is straightforward. We want to get the indices where values are between a given lower and upper bound.
mask = (array >= lower_bound) & (array <= upper_bound)
indices = mx.array([i for i, e in enumerate(mask) if e.item()])  # what we want: the indices where the mask is True
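If an upper bound on the number of matches is known ahead of time, one possible fixed-shape workaround is to pull the matching indices to the front with an argsort of the mask. This is only a sketch with made-up values; it assumes mx.argsort and mx.take behave as described, and the selected indices come back in ascending order only if the sort is stable.

import mlx.core as mx

array = mx.array([1, 5, 2, 8, 3])
lower_bound, upper_bound = 2, 4
k = 3  # assumed known upper bound on the number of matches

mask = (array >= lower_bound) & (array <= upper_bound)
order = mx.argsort(1 - mask.astype(mx.int32))  # indices of True entries sort to the front
indices = order[:k]                            # always k entries, so the shape is static
valid = mx.take(mask, indices)                 # marks which of the k slots are real matches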
> Hmm I see, one workaround would be running a first kernel to count the number of values that satisfy the condition, then allocating the appropriate memory, and then running the actual kernel that computes nonzero(). However, this is not optimal. Do you have any idea how other frameworks handle this scenario?
- In JAX you can't JIT through ops like `nonzero` unless you explicitly pass a `size` parameter (see the sketch after this list).
- I haven't looked at PyTorch MPS, but my guess is that if it's supported it does something like what you said, which indeed is highly non-optimal.
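For reference, the JAX escape hatch looks roughly like this (the condition, size, and fill value are illustrative):

import jax
import jax.numpy as jnp

@jax.jit
def bounded_nonzero(x):
    # `size` gives the output a static shape so the op can be traced under jit;
    # unused slots are padded with `fill_value`.
    return jnp.nonzero(x < 2, size=3, fill_value=-1)

print(bounded_nonzero(jnp.array([1, 2, 3, 4, 5])))  # roughly: (Array([0, -1, -1], dtype=int32),)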
> We want to get the indices where values are between a given lower and upper bound.
Makes sense, and what do you use those for?
> - In JAX you can't JIT through ops like `nonzero` unless you explicitly pass a `size` parameter.
> - I haven't looked at PyTorch MPS, but my guess is that if it's supported it does something like what you said, which indeed is highly non-optimal.
Ok, it seems everybody got stuck on this feature, haha.
> Makes sense, and what do you use those for?
We finally found a workaround using another architecture, so we can close this issue if you don't think working on nonzero() is a priority yet.