jax

Propagate the available axes to the partitioning function

Open · superbobry opened this issue · 1 comment

See #20864 for more context and the added test for a reproducer.

superbobry, Apr 23 '24 13:04

Wonderful! I tested this (cherry-picked on top of 0.4.26) with my ring attention algorithm, and it looks like it works. :)

nshepperd, Apr 24 '24 16:04