xla
Handle multiple users in all-gather dynamic-slice cancellation. Add CancelAllGatherDynamicSlice pass
In some models with poor SPMD partitioning, I have found the following pattern:
all-gather.1 = all-gather(x)
dot.1 = dot(all-gather.1, y)
dynamic-slice.1 = dynamic-slice(all-gather.1) // can be cancelled
Here the all-gather has multiple users, yet the dynamic-slice can still be cancelled. The same applies to all-reduce and reduce-scatter. My changes add support for multiple users, but whether that is safe also depends on how this utility is used by the internal TPU compiler and the GPU ReduceScatterCreator pass. My changes assume the cancellation runs as follows:
- Find a dynamic-slice
- Check whether the dynamic-slice can be cancelled
- Delete the dynamic-slice but do not delete the collective
- The DCE pass deletes the collective if it has no users left
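
To make the steps above concrete, here is a minimal toy sketch in Python. It is not XLA code: `Instr`, `can_cancel`, `cancel_dynamic_slices`, `dce`, and the `slices_own_shard` flag are all hypothetical names for illustration, and the cancellation check is reduced to a precomputed flag instead of the real offset analysis. It also assumes the dynamic-slice itself is never a root.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)  # identity-based equality/hash, like IR node pointers
class Instr:
    """Toy stand-in for an HLO instruction (hypothetical, not an XLA class)."""
    name: str
    opcode: str
    operands: list = field(default_factory=list)
    # Assumption: a real pass would derive this from the slice offsets and
    # the partition id; here it is simply a precomputed flag.
    slices_own_shard: bool = False


def users(instrs, target):
    """Return every instruction that reads `target` as an operand."""
    return [i for i in instrs if any(op is target for op in i.operands)]


def can_cancel(ds):
    """Step 2: the slice undoes the all-gather and recovers the local shard."""
    return (ds.opcode == "dynamic-slice"
            and ds.operands[0].opcode == "all-gather"
            and ds.slices_own_shard)


def cancel_dynamic_slices(instrs):
    """Steps 1-3: drop cancellable slices but leave the collective alone."""
    kept = []
    for ds in instrs:
        if can_cancel(ds):
            shard = ds.operands[0].operands[0]  # the all-gather's input
            for u in users(instrs, ds):
                u.operands = [shard if op is ds else op for op in u.operands]
        else:
            kept.append(ds)
    return kept


def dce(instrs, roots):
    """Step 4: keep only instructions reachable from the roots."""
    live, stack = set(), list(roots)
    while stack:
        i = stack.pop()
        if i not in live:
            live.add(i)
            stack.extend(i.operands)
    return [i for i in instrs if i in live]


def example(with_dot):
    """Build the pattern from the description, with or without the dot user."""
    x = Instr("x", "parameter")
    ag = Instr("all-gather.1", "all-gather", [x])
    ds = Instr("dynamic-slice.1", "dynamic-slice", [ag], slices_own_shard=True)
    neg = Instr("negate.1", "negate", [ds])  # some consumer of the slice
    instrs, roots = [x, ag, ds, neg], [neg]
    if with_dot:
        y = Instr("y", "parameter")
        dot = Instr("dot.1", "dot", [ag, y])
        instrs += [y, dot]
        roots.append(dot)
    instrs = cancel_dynamic_slices(instrs)
    return [i.name for i in dce(instrs, roots)]
```

In this toy model, `example(True)` keeps `all-gather.1` because `dot.1` still uses it, while `example(False)` leaves the all-gather without users, so the DCE step removes it. In both cases `dynamic-slice.1` is gone and `negate.1` reads `x` directly.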
This workflow supports removing dynamic-slices even when the collective has multiple users. It is what we use in our internal Neuron workflow. I would be interested to hear thoughts on this.