Partition a Pallas Kernel as Its Pure JAX Counterpart
I have been writing Pallas kernels and was recently made aware that they are not automatically partitioned.
I guess the general case is complicated; in my own case I am going with the `custom_partitioning` decorator.
However, when the Pallas kernel is simply a faster version of a function built only from JAX primitives, I was hoping to reuse the partitioning mechanism of the JAX function. To this end, I would like to propose a new decorator, `partition_like`. For example, given a JAX function `jax_fn` and a Pallas-kernel-based function `pallas_fn`, one would write:

    from functools import partial
    import jax

    @jax.jit
    @partial(partition_like, fn=jax_fn)  # instead of the full custom_partitioning definition
    def pallas_fn(...):
        ...

Conceptually, I think this is doable, probably by first lowering `jax_fn` with the SPMD compiler and then swapping in the Pallas kernel.
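To make the first half of that idea concrete, here is a minimal sketch of asking the compiler which shardings it picks for the pure-JAX function. The body of `jax_fn`, the mesh axis name `"x"`, and the array shape are assumptions made up for illustration. The swap-in of the Pallas kernel is the part that would need new machinery, so it is not shown.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


# Hypothetical stand-in for the pure-JAX reference function.
def jax_fn(x):
    return jnp.tanh(x) * 2.0


# 1D mesh over whatever devices are available (works with a single device too).
devices = np.array(jax.devices())
mesh = Mesh(devices, ("x",))

# Shard the rows of the input across the "x" mesh axis.
row_sharded = NamedSharding(mesh, P("x", None))
x = jax.device_put(jnp.ones((8 * devices.size, 128), jnp.float32), row_sharded)

# Lower and compile jax_fn, letting the SPMD partitioner decide the layout.
compiled = jax.jit(jax_fn).lower(x).compile()

# Shardings chosen by the compiler: a pytree mirroring the arguments,
# and the sharding(s) of the result.
in_shardings = compiled.input_shardings
out_sharding = compiled.output_shardings
print(in_shardings, out_sharding)

# The compiled executable can be called directly on the sharded input.
y = compiled(x)
```

A `partition_like` decorator could, in principle, use this kind of information to decide how to shard the operands of the Pallas version, but how to wire that into lowering is exactly the open question here.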
I am willing to contribute, but I will need more guidance on JAX core and the actual implementation.
Maybe not quite the same thing, but similar in spirit: it would be nice if `pallas_call` could inherit replication rules for use with `shard_map`, so we don't have to pass `check_rep=False`.
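For context, a minimal sketch of the workaround being described. It assumes a JAX version where `jax.experimental.shard_map.shard_map` still accepts `check_rep` (newer releases expose `jax.shard_map`, where the flag is named `check_vma`). The kernel, the mesh axis name `"x"`, and the shapes are made up for illustration, and `interpret=True` is used only so the sketch does not need a GPU or TPU.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P


# Toy Pallas kernel: elementwise add-one on the local block.
def add_one_kernel(x_ref, o_ref):
    o_ref[...] = x_ref[...] + 1.0


def pallas_add_one(x):
    # interpret=True runs the kernel through the Pallas interpreter,
    # so this sketch does not depend on a particular accelerator.
    return pl.pallas_call(
        add_one_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,
    )(x)


devices = np.array(jax.devices())
mesh = Mesh(devices, ("x",))

# Each device runs the kernel on its own row shard. check_rep=False
# disables the replication check, which is the flag the comment above
# would like to avoid passing.
sharded_add_one = jax.jit(
    shard_map(
        pallas_add_one,
        mesh=mesh,
        in_specs=P("x", None),
        out_specs=P("x", None),
        check_rep=False,
    )
)

rows = 8 * devices.size
x = jnp.arange(rows * 128, dtype=jnp.float32).reshape(rows, 128)
y = sharded_add_one(x)
```

If `pallas_call` carried a replication rule, the `check_rep=False` argument could be dropped and `shard_map` could keep its checks enabled for the rest of the function.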