
Allow tuple inputs to `scatter_dimension` in `jax.lax.psum_scatter`

Open vvvm23 opened this issue 1 year ago • 0 comments

Would it be possible for `scatter_dimension` in `jax.lax.psum_scatter` to accept a tuple?

For example, suppose we have an array of shape `(batch_size, sequence_length, dim)`, sharded as `(dp, sp, tp)`, that we pass to a `shard_map` function. We perform some operations and want to do a `psum` across both `sp` and `tp`. We also want to return an array with the same sharding, so we should also scatter along the last two axes.

This is not easily achievable currently, as `scatter_dimension` only takes integer inputs. We can get the effect by doing a `psum_scatter` across the `(sp, tp)` axes, scattering along either dimension 1 or 2, and then doing a `dynamic_slice_in_dim` on the other.
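To make the setup concrete, here is a minimal sketch of one way to get the summed-and-scattered result today. Note that it swaps in a different technique from the workaround described above: it chains two single-axis `psum_scatter` calls, one per mesh axis, instead of a multi-axis `psum_scatter` followed by `dynamic_slice_in_dim`. The mesh sizes, array shapes, the function name `reduce_and_scatter`, and the use of 8 emulated CPU devices are all assumptions made for illustration, not part of the original report.

```python
import os

# Assumption for this demo: emulate 8 CPU devices. Must be set before importing jax.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older JAX releases

# Hypothetical mesh and problem sizes, chosen only for illustration.
DP, SP, TP = 2, 2, 2
B, S, D = 4, 8, 6

devices = np.array(jax.devices("cpu")[: DP * SP * TP]).reshape(DP, SP, TP)
mesh = Mesh(devices, ("dp", "sp", "tp"))
spec = P("dp", "sp", "tp")


def reduce_and_scatter(partial):
    # `partial` stands in for a per-device partial result of shape (B/DP, S, D).
    # Sum over "sp" and keep only this device's chunk of dimension 1.
    y = jax.lax.psum_scatter(partial, "sp", scatter_dimension=1, tiled=True)
    # Sum over "tp" and keep only this device's chunk of dimension 2.
    return jax.lax.psum_scatter(y, "tp", scatter_dimension=2, tiled=True)


# Global input of shape (B, SP*S, TP*D): each device's block is one (B/DP, S, D) partial.
x_np = np.arange(B * SP * S * TP * D, dtype=np.float32).reshape(B, SP * S, TP * D)
x = jax.device_put(x_np, NamedSharding(mesh, spec))

f = jax.jit(shard_map(reduce_and_scatter, mesh=mesh, in_specs=spec, out_specs=spec))
out = f(x)  # global shape (B, S, D), sharded as (dp, sp, tp)

# Reference: sum the SP x TP partial blocks directly with NumPy.
expected = x_np.reshape(B, SP, S, TP, D).sum(axis=(1, 3))
```

Each device ends up with a `(B/DP, S/SP, D/TP)` block of the full sum, which is the layout a tuple-valued `scatter_dimension` would presumably produce in a single call.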

If it did accept tuple inputs, we could simply set `scatter_dimension`(s) to `(1, 2)`.

Hope that makes sense. Let me know if you want a code example to illustrate the point further.

vvvm23, Sep 12 '24 13:09