
Partition a Pallas kernel like its pure-JAX counterpart

Open: luyug opened this issue 1 year ago · 1 comment

I have been writing Pallas kernels and have recently become aware that they are not automatically partitioned.

I guess the general case is complicated; in my own case, I am using the custom_partitioning decorator.
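For reference, a minimal sketch of the custom_partitioning route for an elementwise Pallas kernel might look like the following. The kernel (add_one_kernel), the wrappers add_one_pallas and add_one, and the pass-through sharding rules are illustrative assumptions, not code from this issue. Because the op is elementwise, every shard can run the unmodified kernel, so both callbacks simply forward the input's sharding. interpret=True is assumed so the sketch runs on CPU; recent JAX versions that use the Shardy partitioner may additionally require a sharding_rule argument to def_partition, which is omitted here.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.custom_partitioning import custom_partitioning


def add_one_kernel(x_ref, o_ref):
    # Hypothetical elementwise kernel: read the whole block, add 1, write back.
    o_ref[...] = x_ref[...] + 1


def add_one_pallas(x):
    # interpret=True lets the sketch run on CPU; drop it on GPU/TPU.
    return pl.pallas_call(
        add_one_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,
    )(x)


@custom_partitioning
def add_one(x):
    return add_one_pallas(x)


def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
    # Elementwise op: the output is sharded exactly like the input.
    return arg_shapes[0].sharding


def partition(mesh, arg_shapes, result_shape):
    x_sharding = arg_shapes[0].sharding
    # Returns (mesh, per-shard function, output sharding, input shardings).
    # Each device runs the same Pallas call on its local shard.
    return mesh, add_one_pallas, x_sharding, (x_sharding,)


add_one.def_partition(
    infer_sharding_from_operands=infer_sharding_from_operands,
    partition=partition,
)
```

Usage is just jax.jit(add_one)(x); the sharding rules only matter once x is sharded across several devices. Writing these two callbacks by hand for every kernel is exactly the boilerplate the proposal below tries to avoid.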

However, in cases where the Pallas kernel is a faster version of a function built only from JAX primitives, I was hoping to reuse the partitioning of that JAX function. To this end, I would like to propose a new decorator, partition_like. For example, given a JAX function jax_fn and a Pallas-kernel-based function pallas_fn:

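```python
from functools import partial

import jax

# partition_like is the decorator proposed in this issue; it does not exist in JAX.
@jax.jit
@partial(partition_like, fn=jax_fn)  # instead of the full custom_partitioning definition
def pallas_fn(*args):
    ...
```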
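To make the premise concrete, a hypothetical jax_fn/pallas_fn pair could look like this: a pure-JAX SiLU that XLA already knows how to partition, and a Pallas kernel computing the same values one row-block at a time. The choice of SiLU, the block size, and interpret=True (to run on CPU) are assumptions for illustration only; partition_like would need exactly this kind of numerical equivalence to be safe.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def jax_fn(x):
    # Reference built only from JAX primitives; the SPMD partitioner can shard it.
    return x * jax.nn.sigmoid(x)


def silu_kernel(x_ref, o_ref):
    # Same math, applied to one (block_rows, cols) block per grid step.
    x = x_ref[...]
    o_ref[...] = x * jax.nn.sigmoid(x)


def pallas_fn(x, block_rows=8):
    rows, cols = x.shape
    assert rows % block_rows == 0, "sketch assumes rows divisible by block_rows"
    # index_map returns block indices: grid step i covers row-block i.
    spec = pl.BlockSpec(block_shape=(block_rows, cols), index_map=lambda i: (i, 0))
    return pl.pallas_call(
        silu_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(rows // block_rows,),
        in_specs=[spec],
        out_specs=spec,
        interpret=True,  # CPU-friendly for the sketch; remove on GPU/TPU
    )(x)
```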

Conceptually, I think this is probably doable by first lowering jax_fn with the SPMD compiler and then swapping in the Pallas kernel. I am willing to contribute, but I will need more guidance on JAX core and the actual implementation.
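The first half of that idea, asking the SPMD partitioner how it shards jax_fn, can already be done with the ahead-of-time lowering API, as in the sketch below. reference_output_sharding is a hypothetical helper name; the second half (substituting the Pallas kernel into the partitioned program) is what the proposal asks JAX to provide and is not shown.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def jax_fn(x):
    return x * jax.nn.sigmoid(x)


def reference_output_sharding(fn, x):
    """Hypothetical helper: compile fn for an input laid out like x and
    report the output sharding the SPMD partitioner chose."""
    compiled = jax.jit(fn).lower(x).compile()
    return compiled.output_shardings


# Shard rows over every available device (a single device on a plain CPU run).
mesh = Mesh(np.array(jax.devices()), ("x",))
rows = 8 * len(jax.devices())
x = jax.device_put(
    jnp.ones((rows, 128), jnp.float32), NamedSharding(mesh, P("x", None))
)
out_sharding = reference_output_sharding(jax_fn, x)
```

For an elementwise jax_fn the chosen output sharding should match the input's; a partition_like decorator could, in principle, feed shardings obtained this way into custom_partitioning's callbacks.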

Posted by luyug on Feb 12 '24 10:02

Maybe not quite the same thing, but similar in spirit: it would be nice if pallas_call could inherit replication rules for use with shard_map, so that we don't have to pass check_rep=False.
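For context, the workaround the comment refers to looks roughly like this sketch: a Pallas call wrapped in shard_map with the replication check turned off. The kernel and names are hypothetical, interpret=True is assumed for CPU, and the import is hedged because newer JAX releases expose jax.shard_map with the flag renamed to check_vma, while older ones use jax.experimental.shard_map with check_rep.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.sharding import Mesh, PartitionSpec as P

try:
    # Newer JAX: jax.shard_map, where the flag is called check_vma.
    from jax import shard_map
    NO_CHECK = {"check_vma": False}
except ImportError:
    # Older JAX (as in this issue): experimental shard_map with check_rep.
    from jax.experimental.shard_map import shard_map
    NO_CHECK = {"check_rep": False}


def add_one_kernel(x_ref, o_ref):
    o_ref[...] = x_ref[...] + 1


def add_one_pallas(x):
    return pl.pallas_call(
        add_one_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # CPU-friendly for the sketch
    )(x)


mesh = Mesh(np.array(jax.devices()), ("x",))

# Each device runs the kernel on its own row-shard. The check is disabled
# because, per the comment, pallas_call had no replication rule.
add_one_sharded = jax.jit(
    shard_map(add_one_pallas, mesh=mesh, in_specs=P("x"), out_specs=P("x"), **NO_CHECK)
)
```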

Posted by ppham27 on Feb 14 '24 05:02