
[pallas] align interpreter load/store with masked behaviour

Open · oliverdutton opened this issue 9 months ago · 1 comment

(and adds stride support)
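Stride support amounts to turning a strided window into an explicit list of indices, which can then be read with an ordinary gather. The following is a minimal sketch of that idea in plain JAX, not the PR's code. `strided_indices` and `strided_load` are hypothetical helpers. `size` and `stride` are assumed to be static Python ints, because they fix the output shape.

```python
import jax
import jax.numpy as jnp

def strided_indices(start, size, stride):
    # Hypothetical helper: explicit indices start, start + stride, ...,
    # start + (size - 1) * stride. `size` and `stride` must be static ints
    # (they determine the output shape); `start` may be a traced value.
    return start + stride * jnp.arange(size)

@jax.jit
def strided_load(x, start):
    # Materialise the indices, then read them all with a single gather.
    idx = strided_indices(start, 4, 2)
    return x[idx]

out = strided_load(jnp.arange(10), 1)  # elements 1, 3, 5, 7
```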

Fixes https://github.com/google/jax/issues/21143

Implements a jittable masked gather/scatter so that, for load/store/swap, no indexing is actually performed for masked-out elements.

For load, any masked-out index is redirected to the first element of the array instead.
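The load rule can be sketched in plain JAX as shown below. This is an illustration under assumptions, not the interpreter's actual code. `masked_load` is a hypothetical name. The `other` fill value for masked lanes is also an assumption of this sketch, chosen to mirror the `other` argument of a masked Pallas load.

```python
import jax
import jax.numpy as jnp

@jax.jit
def masked_load(x, idx, mask, other=0.0):
    # Masked-out lanes are redirected to element 0, which always exists,
    # so the gather below never reads an out-of-bounds location.
    safe_idx = jnp.where(mask, idx, 0)
    vals = x[safe_idx]
    # Whatever was read for masked lanes is discarded and replaced by `other`.
    return jnp.where(mask, vals, other)

x = jnp.arange(4.0)
idx = jnp.array([1, 3, 100])  # 100 is out of bounds, but that lane is masked
mask = jnp.array([True, True, False])
out = masked_load(x, idx, mask)  # [1., 3., 0.]
```

Redirecting to a fixed in-bounds index keeps the gather's shape and indices fully static-shaped and jittable. A data-dependent filter of the index array would not have that property.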

For swap (and store), masked-out indices are likewise redirected to the first element, with special rules to make sure the first element itself ends up with the correct value.
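The difficulty with store is that every masked lane now writes to element 0. If a genuine, unmasked lane also targets element 0, JAX does not specify which of the duplicate scatter updates wins.

The sketch below shows one possible way to handle this, not necessarily the rule used in the PR. After the scatter, it recomputes element 0 explicitly. `masked_swap` is a hypothetical name. The sketch assumes that unmasked lanes target distinct indices, and it zeroes the returned old values for masked lanes.

```python
import jax
import jax.numpy as jnp

@jax.jit
def masked_swap(x, idx, vals, mask):
    # Redirect masked-out lanes to element 0 so no out-of-bounds write occurs.
    safe_idx = jnp.where(mask, idx, 0)
    # Values currently stored at the (safe) indices; a swap returns these.
    old = x[safe_idx]
    # Masked lanes write back what they just read, so they are no-ops
    # everywhere except, possibly, at the contested element 0.
    new_x = x.at[safe_idx].set(jnp.where(mask, vals, old))
    # Fix element 0: take the value from the first unmasked lane that writes
    # to index 0, or keep the original x[0] if no unmasked lane does.
    hits0 = mask & (idx == 0)
    first_hit = jnp.argmax(hits0.astype(jnp.int32))
    x0 = jnp.where(jnp.any(hits0), vals[first_hit], x[0])
    new_x = new_x.at[0].set(x0)
    # Old values of masked lanes are meaningless, so zero them out.
    return new_x, jnp.where(mask, old, jnp.zeros_like(old))

x = jnp.arange(5.0)
new_x, old = masked_swap(
    x,
    jnp.array([2, 0, 7]),          # lane 2 (index 7) is out of bounds...
    jnp.array([20.0, 10.0, 99.0]),
    jnp.array([True, True, False]),  # ...but masked out
)
```

Here element 0 is targeted by both an unmasked lane (value 10) and a masked lane. The fix-up guarantees that the unmasked write is the one that survives.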

The dynamic_slices currently used are replaced with explicit index materialisation and gathers/scatters. The advantage of doing it this way is that it can be combined with checkify(f, errors=checkify.index_checks) in interpreter mode to check for any unmasked out-of-bounds (OOB) indexing, which is (I think, and believe should be) undefined behaviour.
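The checkify combination can be illustrated outside Pallas with a plain JAX function. This is a minimal sketch, assuming a masked gather like the one described above; `masked_load` is a hypothetical name, and the example does not exercise the Pallas interpreter. Because masked lanes are redirected to element 0, only unmasked out-of-bounds indices reach the gather. As a result, `checkify.index_checks` reports exactly the cases the PR treats as errors.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def masked_load(x, idx, mask):
    # Redirect masked lanes to element 0 so that only unmasked lanes can
    # ever produce an out-of-bounds gather.
    safe_idx = jnp.where(mask, idx, 0)
    return jnp.where(mask, x[safe_idx], 0.0)

checked = jax.jit(checkify.checkify(masked_load, errors=checkify.index_checks))

x = jnp.arange(4.0)
idx = jnp.array([1, 9])  # index 9 is out of bounds for a length-4 array

# Index 9 is masked out, so the gather stays in bounds and nothing is reported.
err_masked, _ = checked(x, idx, jnp.array([True, False]))
# The same index left unmasked makes the gather go out of bounds.
err_unmasked, _ = checked(x, idx, jnp.array([True, True]))
# err_unmasked.get() returns the error message; err_unmasked.throw() raises it.
```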

[Apologies: this reopens a previous pull request that I had done badly, having not checked contributing.md.]

oliverdutton, May 18 '24 21:05