[stablehlo] Add a pass to force scatters to legal areas
stablehlo.scatter allows scattering outside of the tensor bounds. In these cases the scatter is a no-op, even if only partially outside of the bounds. This pass implements that behavior by validating indices, forcing them within bounds, and conditionally updating values. This also forces non-unique indices, as multiple out-of-bounds regions may overlap when conditionally applied.
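A minimal sketch of the three steps above (validate, force within bounds, conditionally update), written in plain Python for a 1-D operand. The helper name `scatter_1d` and its signature are hypothetical and only illustrate the intended semantics; this is not the pass's actual lowering, and it assumes every update window is no longer than the operand.

```python
def scatter_1d(operand, starts, updates):
    """Scatter fixed-size windows into a 1-D list, skipping invalid windows.

    Hypothetical helper for illustration only. Assumes each window is no
    longer than the operand.
    """
    result = list(operand)
    n = len(result)
    for start, window in zip(starts, updates):
        w = len(window)
        # 1. Validate: the whole window must fit inside the operand.
        in_bounds = 0 <= start and start + w <= n
        # 2. Force within bounds so the memory access itself is always legal.
        safe = min(max(start, 0), n - w)
        # 3. Conditionally update: write back the old value when invalid.
        for j, value in enumerate(window):
            result[safe + j] = value if in_bounds else result[safe + j]
    return result


out = scatter_1d([0] * 5, [0, 4, -1, 3], [[1, 2], [3, 4], [5, 6], [7, 8]])
print(out)
```

In this example the window starting at 4 is clamped to start at 3, the same location as the valid window that starts at 3. That kind of overlap is why the indices can no longer be assumed unique once clamping is applied.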
Do we know how terrible of code this generates? This is a lot of stuff to do per element and I suspect it may break vectorization and such - we'd at least want to ensure it all ends up in a single dispatch region.
I am pretty ignorant of how dispatch region formation plays with the LinalgExt scatter operation. If we can fuse in the index / conditional dispatches then it seems plausible the whole thing can merge into a single dispatch; however, I would need to check. Right now we are blocked on determining what the dispatch region formation looks like, as linalg_ext.scatter only supports a single target tensor. Working on fixing that in a follow-up.
Vectorization is scatter dependent. Assuming we are writing some contiguous blocks we should be safe for vectorization. Even without these changes vectorization borks if we are scattering along the bottom-most dimension.
Overall I was hoping to conditionally enable this for the PJRT backend and have a flag for iree-compile. I feel like this behavior is largely unused by any reasonable model.
sounds good - @MaheshRavishankar may be able to speak to the formation questions - I mostly am just curious if we're sending all scatters down a scalar path with multiple dispatches/materialized transients between dispatches/etc. If this is needed for correctness then it's needed, but would be good to track making it right if it's not doing the right thing already!
Realized there is one mistake preventing single dispatch region formation: we do a reduction across the indices to determine whether all are in bounds. I am willing to fix this in a follow-up; I would just want to get this initially working with some linalg_ext changes and validate.
yeah, that reduction is likely to be an issue - we don't really want to serialize everything like that - is there a reason to reduce?
(I'd expect this to be output[i] = indices[i] < limit ? input[indices[i]] : 0; or something per-element, not requiring any global tensor reduction)
I expect it will just be a bunch of slices feeding into a single linalg.generic with the per-index check inside. Honestly, this implementation was the first I had any confidence in, and the multi-slice larger generic approach was simply harder to read. Happy to switch to it before I turn this on by default; my priorities right now are addressing the correctness / crashing issues first, then tweaking things for performance.
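To make the two shapes discussed above concrete, here is a rough plain-Python sketch for scalar updates into a 2-D operand. Both function names are hypothetical and neither is the code in this PR: `scatter_two_pass` first materializes a validity mask by reducing over each index's components and then applies it in a second loop, while `scatter_fused` evaluates the same check inline, per index, with no intermediate mask.

```python
def scatter_two_pass(operand, indices, updates):
    """Reduce each index to a validity flag first, then apply the updates."""
    shape = (len(operand), len(operand[0]))
    # Pass 1: materialize a mask (an intermediate that lives between passes).
    mask = [all(0 <= i < d for i, d in zip(idx, shape)) for idx in indices]
    # Pass 2: apply only the updates whose index was fully in bounds.
    result = [list(row) for row in operand]
    for (r, c), value, ok in zip(indices, updates, mask):
        if ok:
            result[r][c] = value
    return result


def scatter_fused(operand, indices, updates):
    """Check each index inline; no mask is materialized."""
    rows, cols = len(operand), len(operand[0])
    result = [list(row) for row in operand]
    for (r, c), value in zip(indices, updates):
        if 0 <= r < rows and 0 <= c < cols:
            result[r][c] = value
    return result


zeros = [[0, 0, 0], [0, 0, 0]]
idx = [(0, 1), (2, 0), (1, -1), (1, 2)]
print(scatter_fused(zeros, idx, [5, 6, 7, 8]))
```

Both variants produce the same result; the difference is only whether the bounds check lives in a separate step that could end up as its own dispatch with a materialized transient.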
@rsuderman still valid?