
Improve Custom Partitioning API Documentation

Open ASKabalan opened this issue 1 year ago • 1 comments

I am trying to use jax.experimental.custom_partitioning for a 3D FFT primitive.

I have been basing my implementation on the example in the public API documentation.

I feel that a few things are missing.

In the first example:

@custom_partitioning
def f(*args):
  return ...

def propagate_user_sharding(mesh, user_shape):
  '''Update the sharding of the op from a user's shape.sharding.'''
  user_sharding = jax.tree_map(lambda x: x.sharding, user_shape)

def partition(mesh, arg_shapes, result_shape):
  def lower_fn(*args):
    ... builds computation on per-device shapes ...
  result_shardings = jax.tree_map(lambda x: x.sharding, result_shape)
  arg_shardings = jax.tree_map(lambda x: x.sharding, arg_shapes)
  # result_sharding and arg_shardings may optionally be modified and the
  # partitioner will insert collectives to reshape.
  return mesh, lower_fn, result_sharding, arg_shardings

def infer_sharding_from_operands(mesh, arg_shapes, shape):
  '''Compute the result sharding from the sharding of the operands.'''
  arg_shardings = jax.tree_map(lambda x: x.sharding, arg_shapes)


f.def_partition(partition, propagate_user_sharding, infer_sharding_from_operands)

The example rules do not return anything, so the reader cannot infer what each rule is supposed to return or with what type (plus there is a typo: partition builds result_shardings but returns result_sharding).
Also, am I correct that lower_fn is where we call our primitive on the per-device shards, with access to the sharding information, and that we can use collectives inside it if needed, just like in this example?
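
For reference, here is my current understanding of what each rule is supposed to return (this is only my own reading of the API, not something stated in the docs, and the concrete types are my guesses):

def propagate_user_sharding(mesh, user_shape):
  # Return the sharding this op should adopt from its user's shape/sharding
  # (here simply passed through unchanged).
  return jax.tree_util.tree_map(lambda x: x.sharding, user_shape)

def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
  arg_shardings = jax.tree_util.tree_map(lambda x: x.sharding, arg_shapes)
  # Return the sharding(s) of the result, matching the pytree structure of
  # result_shape; here, simply reuse the first operand's sharding.
  return arg_shardings[0]

def partition(mesh, arg_shapes, result_shape):
  def lower_fn(*args):
    # Called with the per-device shards; collectives over the mesh axes
    # (lax.psum, lax.all_gather, ...) may be used here if needed.
    ...
  result_shardings = jax.tree_util.tree_map(lambda x: x.sharding, result_shape)
  arg_shardings = jax.tree_util.tree_map(lambda x: x.sharding, arg_shapes)
  # A 4-tuple: the mesh, the per-shard lowering, the output sharding(s),
  # and the input sharding(s) that lower_fn expects.
  return mesh, lower_fn, result_shardings, arg_shardings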

It would also be useful to show an example with static argnums, like this one.
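
For instance, something along these lines (a minimal sketch of what I have in mind; the function itself is just a placeholder):

from functools import partial
import jax.numpy as jnp
from jax.experimental.custom_partitioning import custom_partitioning

# Mark the axis argument as static so it is treated as a compile-time
# constant rather than a traced array argument.
@partial(custom_partitioning, static_argnums=(1,))
def my_fft(x, axis):
  return jnp.fft.fft(x, axis=axis)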

The signature has shape as the third argument, which is vague, since in the second example the same argument is called result_shape:

def infer_sharding_from_operands(mesh, arg_shapes, shape):

For the FFT example, I find the return values of partition and infer_sharding_from_operands very ambiguous, and it would be much better to give an example whose output sharding differs from the input sharding.

Though I am working on a 3D FFT myself, I think a simple matmul example on two matrices, where the output shape differs from both input shapes (so the output sharding cannot be the same as the inputs'), would be more useful.
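
Something like the following is what I have in mind (only a rough sketch under my own assumptions: a 2D mesh with axes 'a' and 'b', rows of x sharded over 'a', columns of w sharded over 'b'):

import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P
from jax.experimental.custom_partitioning import custom_partitioning

@custom_partitioning
def sharded_matmul(x, w):          # x: (m, k), w: (k, n)
  return jnp.matmul(x, w)

def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
  # The (m, n) output takes its row sharding from x and its column
  # sharding from w, so it differs from both input shardings.
  return NamedSharding(mesh, P('a', 'b'))

def partition(mesh, arg_shapes, result_shape):
  # Ask the partitioner to hand lower_fn row blocks of x and column blocks
  # of w; it inserts collectives if the callers' shardings differ.
  arg_shardings = (NamedSharding(mesh, P('a', None)),
                   NamedSharding(mesh, P(None, 'b')))
  result_sharding = NamedSharding(mesh, P('a', 'b'))

  def lower_fn(x_block, w_block):
    # x_block: (m/a, k), w_block: (k, n/b); the local matmul already
    # produces the correct (m/a, n/b) block of the output.
    return jnp.matmul(x_block, w_block)

  return mesh, lower_fn, result_sharding, arg_shardings

sharded_matmul.def_partition(
    infer_sharding_from_operands=infer_sharding_from_operands,
    partition=partition)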

After a chat with someone from the NVIDIA JAX team, who kindly explained some things to me, I understand it a bit better, but I still can't help noticing that infer_sharding_from_operands (which takes mesh, arg_shapes, result_shape) and partition (which takes the same arguments and returns essentially the same thing, apart from the mesh and the primitive call) feel somewhat redundant.

The person from NVIDIA said that if they mismatch, this can cause unwanted collective calls.

My questions are:

  • How can we define the expected output sharding or shape of the primitive, similarly to the way a custom call defines them via ir.RankedTensorType and the layouts?
  • How can we set an output sharding that tells the partitioner to insert collectives to reshard our output?
    • For example, in the matmul case, can we tell the partitioner to set the output sharding back to what the input sharding was, to prepare for the next step? (See the sketch after this list.)
  • How can we override the user sharding? There is an example here, but it confuses me because the input that partition sees is not the same as the inferred (overridden?) input.
  • What is propagate_user_sharding for? It is tested here, but it just returns the sharding exactly as in the default case, so I don't understand its purpose.
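
To make the second question concrete, here is roughly what I would expect to write for the matmul case (again only a guess, reusing the imports and the 'a'/'b' mesh axes from the sketch above):

from jax import lax

def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
  # Desired: the (m, n) output keeps the same row sharding as x and is
  # replicated over columns, i.e. the layout the next op expects.
  return NamedSharding(mesh, P('a', None))

def partition(mesh, arg_shapes, result_shape):
  arg_shardings = (NamedSharding(mesh, P('a', None)),
                   NamedSharding(mesh, P(None, 'b')))
  result_sharding = NamedSharding(mesh, P('a', None))

  def lower_fn(x_block, w_block):
    # Local block matmul gives (m/a, n/b); gather the column blocks over
    # the 'b' axis so every device ends up with its full (m/a, n) rows.
    y_block = jnp.matmul(x_block, w_block)
    return lax.all_gather(y_block, 'b', axis=1, tiled=True)

  return mesh, lower_fn, result_sharding, arg_shardings

Is this the intended way, or should the resharding be expressed only through the returned result sharding and left to the partitioner?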

I noticed that the three callbacks map directly onto functions in the XLA codebase, but on the JAX side I don't feel that they are all necessary, or maybe I have not understood their purpose correctly.

Thank you for your help

ASKabalan avatar Feb 19 '24 20:02 ASKabalan

Note, an example is being added here: https://github.com/google/jax/pull/20179

nouiz avatar Mar 28 '24 14:03 nouiz