[FEATURE] MLX equivalent for jax.lax.scan and jax.lax.while_loop
Would it be possible to have MLX equivalents for jax.lax.scan and jax.lax.while_loop?
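For reference, here is a plain-Python sketch of what the two JAX primitives compute. It follows the pure-Python semantics given in the JAX documentation. It is not MLX code and not how JAX implements them, and it has none of the compilation benefits. One simplification: `jax.lax.scan` stacks the per-step outputs into an array, while this sketch returns a list.

```python
def scan(f, init, xs):
    """Reference semantics of jax.lax.scan: thread a carry through f
    and collect the per-step outputs (as a list here, not a stacked array)."""
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


def while_loop(cond_fun, body_fun, init_val):
    """Reference semantics of jax.lax.while_loop: apply body_fun
    while cond_fun evaluates to True on the current value."""
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val


# Usage: a running sum with scan, and repeated doubling with while_loop.
total, running = scan(lambda c, x: (c + x, c + x), 0, [1, 2, 3])
doubled = while_loop(lambda v: v < 10, lambda v: v * 2, 1)
```

The value of these primitives in JAX is that the loop stays inside the compiled graph, even when the number of iterations depends on runtime values.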
Matching algorithms such as the Hungarian matching used in object detection (DETR) need boolean arrays to be evaluated in order to compute optimal matchings, because the control flow depends on the values in those arrays. As far as I understand, this makes it impractical to write an MLX-specific implementation that can be compiled to optimise the end-to-end training process.
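To illustrate the kind of data-dependent loop involved, here is a hypothetical greedy matcher. It is deliberately *not* the Hungarian algorithm and is not guaranteed to be optimal; an optimal host-side solver would be something like `scipy.optimize.linear_sum_assignment`. The relevant point is the loop condition: it reduces boolean masks to concrete Python booleans on every iteration. A graph that is traced once cannot express this without a `while_loop`-style primitive.

```python
def greedy_match(cost):
    """Hypothetical greedy matcher (NOT Hungarian, NOT optimal).

    Repeatedly takes the cheapest remaining entry of a rectangular cost
    matrix (list of lists) and returns the chosen (row, col) pairs.
    """
    n_rows, n_cols = len(cost), len(cost[0])
    row_free = [True] * n_rows  # boolean mask of unassigned rows
    col_free = [True] * n_cols  # boolean mask of unassigned columns
    pairs = []
    # The condition must be evaluated to a concrete bool each iteration;
    # this is the step a traced/compiled graph cannot express directly.
    while any(row_free) and any(col_free):
        best = None
        for i in range(n_rows):
            if not row_free[i]:
                continue
            for j in range(n_cols):
                if col_free[j] and (best is None or cost[i][j] < cost[best[0]][best[1]]):
                    best = (i, j)
        i, j = best
        row_free[i] = False
        col_free[j] = False
        pairs.append((i, j))
    return pairs


# Usage: greedy picks a total cost of 6 here, while the optimal assignment
# (0, 1), (1, 0), (2, 2) costs 5, which is why DETR uses an optimal solver.
cost = [[4, 1, 3], [2, 0, 5], [3, 2, 2]]
pairs = greedy_match(cost)
```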
What I'm not sure about is how much of a performance benefit can be obtained if the loss computation (which depends on the matching) can itself be compiled. From what I see in the documentation, compiling can yield significant performance improvements.