[Pallas] Fix integer array indexing
Fixes https://github.com/google/jax/issues/22783
Note: the code works now but still needs some cleanup.
Blocked by https://github.com/google/jax/pull/23534
Closing due to a merge conflict caused by https://github.com/google/jax/pull/23583. I will open another PR.
Superseded by https://github.com/google/jax/pull/23758