[pallas] align interpreter load/store with masked behaviour
(and adds stride support)
Fixes https://github.com/google/jax/issues/21143
Implements a jittable masked gather/scatter so that, for load/store/swap, masked-out indices are never accessed.
For load, any masked-out index is redirected to the first element of the array instead.
For swap (and store), masked-out indices are also redirected to the first element, with special rules to make sure the first element itself ends up with the correct value.
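The index redirection described above can be sketched in plain JAX. This is a minimal 1-D illustration under stated assumptions, not the code in this PR: the helper names `masked_load` and `masked_store`, the `other` fill value, and the fix-up rule for element 0 are all hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def masked_load(x, idx, mask, other):
    # Redirect masked-out lanes to index 0 so the gather never sees their
    # (possibly out-of-bounds) indices.
    safe_idx = jnp.where(mask, idx, 0)
    vals = x[safe_idx]
    # Discard whatever was read for masked-out lanes.
    return jnp.where(mask, vals, other)


@jax.jit
def masked_store(x, idx, mask, val):
    # Masked-out lanes are redirected to index 0 and write back x[0].
    safe_idx = jnp.where(mask, idx, 0)
    safe_val = jnp.where(mask, val, x[0])
    out = x.at[safe_idx].set(safe_val)
    # Element 0 may now receive several conflicting writes, so it is fixed up
    # explicitly: use the value of an unmasked lane that targets index 0 if
    # one exists, otherwise keep the original x[0].
    hits_first = mask & (idx == 0)
    first_val = jnp.where(hits_first.any(), val[jnp.argmax(hits_first)], x[0])
    return out.at[0].set(first_val)


x = jnp.arange(5.0)
idx = jnp.array([0, 2, 7, 9])          # 7 and 9 are out of bounds
mask = jnp.array([True, True, False, False])
val = jnp.array([10.0, 20.0, 30.0, 40.0])

loaded = masked_load(x, idx, mask, -1.0)
stored = masked_store(x, idx, mask, val)
```

In this sketch the out-of-bounds indices 7 and 9 never reach the gather or scatter, because their lanes are masked out and redirected before indexing.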
The dynamic_slices currently in use are replaced with explicit index materialisation and gathers/scatters. The advantage of this approach is that it can be combined with checkify(f, errors=checkify.index_checks) in interpreter mode to check for any unmasked OOB indexing, which is (I think, and believe should be) undefined behaviour.
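As a rough illustration of the checkify combination, the sketch below uses ordinary JAX functions rather than a Pallas kernel in interpreter mode. The functions `raw_load` and `masked_load` are hypothetical stand-ins; only `checkify.checkify` and `checkify.index_checks` are taken from the description above.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def raw_load(x, idx):
    # Gathers with the indices exactly as given, including any OOB ones.
    return x[idx]


def masked_load(x, idx, mask):
    # Masked-out lanes are redirected to index 0 before the gather.
    safe_idx = jnp.where(mask, idx, 0)
    return jnp.where(mask, x[safe_idx], 0.0)


x = jnp.arange(5.0)
idx = jnp.array([0, 2, 7])             # 7 is out of bounds
mask = jnp.array([True, True, False])  # ...but its lane is masked out

checked_raw = jax.jit(
    checkify.checkify(raw_load, errors=checkify.index_checks))
checked_masked = jax.jit(
    checkify.checkify(masked_load, errors=checkify.index_checks))

err_raw, _ = checked_raw(x, idx)
err_masked, out = checked_masked(x, idx, mask)

# err.get() returns None when no check failed, otherwise an error message.
print("raw:", err_raw.get())
print("masked:", err_masked.get())
```

The unmasked gather should report an out-of-bounds error, while the masked version should pass, since the only OOB index sits in a masked-out lane.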
[Apologies, this reopens a previous request that I had done badly, having not checked contributing.md.]