Tristan Bilot
Hmm, I see. One workaround would be to run a first kernel that counts the number of values satisfying the condition, then allocate the appropriate amount of memory, and then run the...
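A minimal pure-Python sketch of the count, allocate, fill idea, assuming the condition is "nonzero" and using a hypothetical helper `nonzero_two_pass`. It does not use any MLX API. An exclusive prefix sum over the match flags gives both the total count (which sizes the allocation) and each match's output slot, so the second pass could in principle write in parallel.

```python
from itertools import accumulate


def nonzero_two_pass(values):
    """Return indices of nonzero entries via count -> allocate -> fill.

    Hypothetical helper for illustration only; not an MLX function.
    """
    # Pass 1: flag matches and take an exclusive prefix sum.
    # offsets[i] is the output slot for element i; offsets[-1] is the match count.
    flags = [1 if v != 0 else 0 for v in values]
    offsets = list(accumulate(flags, initial=0))
    count = offsets[-1]

    # The output size is only known at this point, so allocation sits between passes.
    out = [0] * count

    # Pass 2: each matching element writes its index into its own slot,
    # so on a GPU these writes would not need atomics.
    for i, flag in enumerate(flags):
        if flag:
            out[offsets[i]] = i
    return out


print(nonzero_two_pass([0, 3, 0, 0, 7, 1]))  # [1, 4, 5]
```

The cost is an extra pass plus a host-side sync to read `count` before allocating, which is what makes this awkward inside a lazily evaluated or compiled graph.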
> * In Jax you can't JIT through ops like `nonzero` unless you explicitly pass a `size` parameter.
> * I haven't looked at PyTorch MPS but my guess is...
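For context on why a `size` argument unblocks JIT: it makes the output shape static. In JAX this is `jnp.nonzero(x, size=k, fill_value=...)`, which pads with `fill_value` when there are fewer than `k` matches and truncates when there are more. The sketch below only mimics those semantics in pure Python for a 1-D input, using a hypothetical `nonzero_fixed_size` helper (JAX itself returns a tuple of arrays, one per dimension).

```python
def nonzero_fixed_size(values, size, fill_value=0):
    """Mimic a static-shape nonzero: always return exactly `size` indices.

    Hypothetical pure-Python illustration; not the JAX implementation.
    """
    idx = [i for i, v in enumerate(values) if v != 0]
    # Truncate extra matches and pad missing ones, so the output length
    # depends only on `size`, never on the data itself.
    idx = idx[:size]
    return idx + [fill_value] * (size - len(idx))


print(nonzero_fixed_size([0, 3, 0, 7], size=4, fill_value=-1))  # [1, 3, -1, -1]
print(nonzero_fixed_size([5, 6, 7], size=2))                    # [0, 1]
```

Because the shape no longer depends on the data, no count pass or sync is needed; the trade-off is that the caller must pick `size` up front and handle the padded entries.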
Any updates?
It would be great to have updated benchmarks using the latest version of MLX. Anyone is welcome to open a PR adding results for more recent versions!