[FEATURE] MLX equivalent for jax.lax.scan and jax.lax.while_loop

Open · sachinraja13 opened this issue 1 year ago · 0 comments

Would it be possible to have MLX equivalents for jax.lax.scan and jax.lax.while_loop?
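For reference, the two JAX primitives have simple loop semantics. The sketch below follows the pure-Python reference semantics given in the JAX documentation, simplified: it omits `jax.lax.scan` options such as `length` and `reverse`, and does no pytree handling. In JAX the point of these primitives is that the whole loop is traced into a single compiled operation instead of being driven from Python. The names `scan` and `while_loop` here are plain Python functions, not an existing MLX API.

```python
import numpy as np

def scan(f, init, xs):
    """Simplified reference semantics of jax.lax.scan.

    f maps (carry, x) -> (new_carry, y). The carry is threaded through
    every step and the per-step outputs y are stacked along axis 0.
    """
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)

def while_loop(cond_fun, body_fun, init_val):
    """Reference semantics of jax.lax.while_loop.

    The number of iterations depends on the loop state, i.e. on values.
    """
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val

# Usage: a running sum via scan, and a value-dependent loop via while_loop.
total, partial_sums = scan(lambda c, x: (c + x, c + x), 0, np.array([1, 2, 3, 4]))
doubled = while_loop(lambda v: v < 100, lambda v: v * 2, 1)
```

`scan` suits loops with a trip count fixed by the input's leading dimension. `while_loop` is the one needed when the stopping condition is computed from array values.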

Matching algorithms, such as the Hungarian matching used for object detection in DETR, need boolean arrays to be evaluated in order to compute optimal matchings correctly. As far as I understand, this makes it infeasible to write an MLX-specific implementation that can be compiled to optimise the end-to-end training process.
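The sketch below shows where the boolean evaluation comes in, using a deliberately simple stand-in. It is a hypothetical greedy matcher, not the Hungarian algorithm (it can return a suboptimal assignment), written in the `cond_fn` / `body_fn` state-passing style that `jax.lax.while_loop` expects. `cond_fn` reduces boolean masks to a Python `bool`. In eager NumPy or MLX that forces the value to be read back on every iteration, whereas a `while_loop`-style primitive would let the condition stay inside the compiled graph. All function names here are illustrative assumptions.

```python
import numpy as np

# Hypothetical illustration: a GREEDY matcher (not the Hungarian algorithm),
# expressed as a condition function plus a body function over a loop state.

def cond_fn(state):
    free_preds, free_tgts, _ = state
    # The loop condition is derived from boolean arrays, so deciding whether
    # to iterate again requires evaluating them.
    return bool(free_preds.any() and free_tgts.any())

def make_body_fn(cost):
    def body_fn(state):
        free_preds, free_tgts, assign = state
        # Mask out already-matched predictions (rows) and targets (columns).
        valid = free_preds[:, None] & free_tgts[None, :]
        masked = np.where(valid, cost, np.inf)
        # Pick the cheapest remaining (prediction, target) pair.
        i, j = np.unravel_index(np.argmin(masked), masked.shape)
        free_preds, free_tgts, assign = free_preds.copy(), free_tgts.copy(), assign.copy()
        free_preds[i] = False
        free_tgts[j] = False
        assign[i] = j
        return free_preds, free_tgts, assign
    return body_fn

def greedy_match(cost):
    """Return assign[pred] = target index, or -1 if the prediction is unmatched."""
    cost = np.asarray(cost, dtype=float)
    n_pred, n_tgt = cost.shape
    state = (np.ones(n_pred, dtype=bool),
             np.ones(n_tgt, dtype=bool),
             np.full(n_pred, -1, dtype=int))
    body_fn = make_body_fn(cost)
    # Plain Python driver loop; a while_loop primitive would replace this.
    while cond_fn(state):
        state = body_fn(state)
    return state[2]
```

In this greedy version the trip count happens to equal `min(n_pred, n_tgt)`. In the Hungarian algorithm proper, the inner search loops run a number of times that depends on the cost values, which is why a value-dependent loop primitive matters there.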

What I'm not sure about is how much performance benefit could be gained by compiling the loss computation itself, since it depends on the matching. From what I see in the documentation, compiling can yield significant performance improvements.

sachinraja13 · Sep 27 '24 13:09