jax
Propagate the available axes to the partitioning function
See #20864 for more context, and see the added test for a reproducer.
Wonderful! I tested this (cherry-picked on top of 0.4.26) with my ring attention algorithm, and it looks like it works. :)