Ilia Sergachev
The added `Result` structure will be used to add support for slicing.
CuDnnThunk, currently used for GEMM fusions, is capable of executing arbitrary cuDNN graphs. Moving FMHA to it lets us remove a lot of specialized runtime code. The overview of the change...
This requires https://github.com/openxla/xla/pull/15399 to work. The code for jax/_src/cudnn/fusion.py was provided by @hawkinsp.
The latter one is [deprecated](https://github.com/openxla/xla/commit/8ad80bd15bd3ce237f9e999234bbb3c7a1ff3653) in XLA.
Disable the XLA autotuning fallback to cuBLAS so that the tested fusion always executes through cuDNN.
📝 Summary of Changes

All-gathers can only run on the major-most physical dimension, concatenating buffers from ranks. When an all-gather on a logical dimension index > 0 is requested,...
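
To illustrate the constraint above, here is a minimal pure-Python sketch (not the XLA implementation) that models each rank's buffer as a nested list. The helper names `all_gather_major`, `transpose`, and `all_gather_dim1` are hypothetical and exist only for this example. It shows one generic way a gather along dimension 1 can be expressed with only a major-dimension concatenation, by transposing before and after; whether the change described here takes this approach is not stated in the text.

```python
def all_gather_major(buffers):
    """Concatenate per-rank 2D buffers along dimension 0 (the major-most one)."""
    gathered = []
    for buf in buffers:
        gathered.extend(list(row) for row in buf)
    return gathered


def transpose(matrix):
    """Swap dimensions 0 and 1 of a rectangular nested list."""
    return [list(col) for col in zip(*matrix)]


def all_gather_dim1(buffers):
    """Gather along dimension 1 using only a major-dimension gather.

    Assumption for illustration: move dimension 1 to the front, gather on
    the major dimension, then move it back.
    """
    transposed = [transpose(buf) for buf in buffers]
    return transpose(all_gather_major(transposed))


# Two simulated ranks, each holding a 2x2 buffer.
rank_buffers = [
    [[0, 1], [2, 3]],
    [[10, 11], [12, 13]],
]

# Reference result: concatenate matching rows from every rank directly.
direct = [sum((buf[r] for buf in rank_buffers), []) for r in range(2)]
assert all_gather_dim1(rank_buffers) == direct
```

The extra transposes are the cost of requesting a non-major gather dimension in this model; a gather on dimension 0 needs none.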