add pjit input-output forwarding rule
When an inner jit simply returns some of its inputs unchanged as outputs, we can prune those outputs from the pjit call and use the caller's values for them directly.
import jax
@jax.jit
def f(x):
    return x, 2 * x
print(jax.make_jaxpr(lambda: f(3))())
Before:
{ lambda ; . let
    a:i32[] b:i32[] = pjit[
      name=f
      jaxpr={ lambda ; c:i32[]. let d:i32[] = mul 2 c in (c, d) }
    ] 3
  in (a, b) }
After:
{ lambda ; . let
    a:i32[] = pjit[
      name=f
      jaxpr={ lambda ; b:i32[]. let c:i32[] = mul 2 b in (c,) }
    ] 3
  in (3, a) }
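The rewrite in the before/after jaxprs can be modeled on a toy IR. The sketch below is hypothetical and self-contained: `Var`, `Call`, `forwarded_inputs`, and `prune_forwarded` are illustrative names, not JAX internals, and the real rule operates on pjit equations inside jaxprs. It finds inner outputs that are literally one of the inner inputs, drops them from the call, and records a substitution so the caller uses its own argument instead.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union


# Hypothetical toy IR (not JAX's jaxpr classes): a variable is a named object,
# and literals are plain Python ints.
@dataclass(frozen=True)
class Var:
    name: str


Atom = Union[Var, int]


@dataclass
class Call:
    """Hypothetical call site: `outs = inner(*args)`."""
    args: List[Atom]     # caller's values passed to the inner function
    invars: List[Var]    # inner function's parameters
    outvars: List[Var]   # inner function's results
    outs: List[Var]      # caller-side binders for those results


def forwarded_inputs(invars: List[Var], outvars: List[Var]) -> List[Optional[int]]:
    """For each inner output, the index of the input it returns unchanged, else None."""
    index = {v: i for i, v in enumerate(invars)}
    return [index.get(v) for v in outvars]


def prune_forwarded(call: Call) -> Tuple[Call, Dict[Var, Atom]]:
    """Drop forwarded outputs from `call`.

    Returns the pruned call plus a substitution mapping each dropped
    caller-side binder to the caller's own argument.
    """
    fwd = forwarded_inputs(call.invars, call.outvars)
    subst: Dict[Var, Atom] = {}
    kept_outvars, kept_outs = [], []
    for binder, inner_out, src in zip(call.outs, call.outvars, fwd):
        if src is None:
            # Output is computed inside the inner function: keep it.
            kept_outvars.append(inner_out)
            kept_outs.append(binder)
        else:
            # Output just forwards input `src`: use the caller's value.
            subst[binder] = call.args[src]
    return Call(call.args, call.invars, kept_outvars, kept_outs), subst


# Mirror of the example above: inner f(c) returns (c, 2 * c), called with 3.
c, d = Var("c"), Var("d")
a, b = Var("a"), Var("b")
call = Call(args=[3], invars=[c], outvars=[c, d], outs=[a, b])
pruned, subst = prune_forwarded(call)
final_outputs = [subst.get(v, v) for v in (a, b)]
print(pruned.outvars, final_outputs)
```

As in the "After" jaxpr, only the computed output survives the call, and the forwarded one is replaced by the literal 3 the caller passed in.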
The motivating application is related to a dynamic shapes experiment, where it's useful to keep static shapes static.