
add pjit input-output forwarding rule

Open: mattjj opened this issue 9 months ago (1 comment)

When an inner jit simply forwards some of its inputs unchanged to its outputs, we can prune those outputs from the inner jaxpr and have the caller use its own value for them instead.

import jax

@jax.jit
def f(x):
  return x, 2 * x

print(jax.make_jaxpr(lambda: f(3))())

Before:

{ lambda ; . let
    a:i32[] b:i32[] = pjit[
      name=f
      jaxpr={ lambda ; c:i32[]. let d:i32[] = mul 2 c in (c, d) }
    ] 3
  in (a, b) }

After:

{ lambda ; . let
    a:i32[] = pjit[
      name=f
      jaxpr={ lambda ; b:i32[]. let c:i32[] = mul 2 b in (c,) }
    ] 3
  in (3, a) }
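The Before/After rewrite can be modeled in a few lines of plain Python. This is a minimal sketch on a hypothetical toy representation (`InnerCall` and `forward_inputs` are illustrative names, not JAX internals). Variables are plain strings compared by name, whereas a real jaxpr would compare variable objects by identity. Each output of the inner call that is literally one of its inputs is dropped, and the caller substitutes its own argument for it. All other outputs are kept and renumbered.

```python
from dataclasses import dataclass


@dataclass
class InnerCall:
    """Hypothetical toy stand-in for a pjit equation (not a JAX API)."""
    args: list[str]     # caller-side operands, e.g. ["3"]
    invars: list[str]   # inner jaxpr input binders, e.g. ["c"]
    outvars: list[str]  # inner jaxpr outputs, e.g. ["c", "d"]


def forward_inputs(call: InnerCall) -> tuple[InnerCall, list]:
    """Prune outputs that are just inputs.

    Returns the pruned call plus one entry per ORIGINAL output: either the
    caller's own argument (forwarded output) or ("kept", j), meaning the
    j-th output of the pruned call.
    """
    kept: list[str] = []
    caller_outputs: list = []
    for v in call.outvars:
        if v in call.invars:
            # This output is an input: the caller already holds that value.
            caller_outputs.append(call.args[call.invars.index(v)])
        else:
            # Genuinely computed output: keep it, remembering its new position.
            caller_outputs.append(("kept", len(kept)))
            kept.append(v)
    return InnerCall(call.args, call.invars, kept), caller_outputs


# Mirror the example: the inner jaxpr returns (c, d) where c is its input.
before = InnerCall(args=["3"], invars=["c"], outvars=["c", "d"])
after, outs = forward_inputs(before)
print(after.outvars)  # ['d']
print(outs)           # ['3', ('kept', 0)]
```

The result corresponds to the "After" jaxpr: the inner call returns only the product, and the caller's outputs become `(3, a)`. Inputs are deliberately left alone in this sketch. In the example, `3` is still needed inside the inner jaxpr by `mul`.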

The motivating application is related to a dynamic shapes experiment, where it's useful to keep static shapes static.

mattjj, May 24 '24 04:05