give users a way to type-query effects, e.g. in eval_shape
We use jax.eval_shape as a way to query function types, but it doesn't surface effect types. We need a way to do that, either built into jax.eval_shape or something else.
jax.infer_type?
FWIW I think there are basically 4 things a user might ask for when abstractly interpreting:
- output struct + avals
- jaxpr
- effects
- closed-over constants.
Right now eval_shape gives you just the first, make_jaxpr gives you the first two (three?), and closure_convert gives you some of the last (only inexact arrays, but in the presence of e.g. custom_vmap you need all arrays).
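A small illustration of the gap, using only existing public APIs. `jax.eval_shape` reports output shapes and dtypes but has no effects field, while the jaxpr returned by `jax.make_jaxpr` records its effect set in `.jaxpr.effects`. The function names `pure` and `effectful` are illustrative, and `jax.debug.print` is used only as a convenient effectful primitive.

```python
import jax
import jax.numpy as jnp


def pure(x):
    return jnp.sin(x) * 2.0


def effectful(x):
    # jax.debug.print adds a debug effect to the traced program.
    jax.debug.print("x = {}", x)
    return jnp.sin(x) * 2.0


x = jnp.ones((3,), dtype=jnp.float32)

# eval_shape: output structure + avals only; nothing about effects.
out_struct = jax.eval_shape(effectful, x)
print(out_struct.shape, out_struct.dtype)

# make_jaxpr: the jaxpr also carries the effect set.
pure_effects = jax.make_jaxpr(pure)(x).jaxpr.effects
effectful_effects = jax.make_jaxpr(effectful)(x).jaxpr.effects
print(pure_effects)       # empty
print(effectful_effects)  # non-empty
```

Getting the effects this way means building a full jaxpr, which is exactly what a dedicated query could avoid.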
Rather than a new API for specifically effects, I'd probably argue for a unified API that offers any combination of these, ideally in an "optimal" way, e.g. not forming a jaxpr if all that is needed is the effect set.
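For comparison, a rough stand-in for such a unified query can be assembled from existing pieces. This is a sketch only: `AbstractInfo` and `abstract_query` are hypothetical names, and unlike the proposal it always traces a full jaxpr via `jax.make_jaxpr(..., return_shape=True)`, so it is not "optimal" when only the effect set is wanted.

```python
from typing import Any, Callable, NamedTuple

import jax
import jax.numpy as jnp


class AbstractInfo(NamedTuple):
    # Hypothetical bundle of the four things listed above.
    out_shape: Any  # pytree of objects with .shape/.dtype
    jaxpr: Any      # the traced jaxpr (without its consts)
    effects: Any    # effect set recorded on the jaxpr
    consts: list    # constants held by the ClosedJaxpr


def abstract_query(fun: Callable, *args) -> AbstractInfo:
    """Hypothetical helper: trace `fun` once and return every query together."""
    closed_jaxpr, out_shape = jax.make_jaxpr(fun, return_shape=True)(*args)
    return AbstractInfo(
        out_shape=out_shape,
        jaxpr=closed_jaxpr.jaxpr,
        effects=closed_jaxpr.jaxpr.effects,
        consts=list(closed_jaxpr.consts),
    )


w = jnp.arange(4.0)


def f(x):
    return {"y": x @ w}  # closes over the array w


# Abstract input: no concrete data is needed, only shape and dtype.
info = abstract_query(f, jax.ShapeDtypeStruct((4,), jnp.float32))
print(info.out_shape)
print(info.effects)
```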
Bump: I could really use a way to get a jaxpr abstractly without needing to pass inputs. While testing cloudpickle on jitted JAX functions today, I found that:
- functionality is identical before and after cloudpickling
- jaxprs are identical before and after cloudpickling
- dict gets wiped
- hash is thus different
To test code that serializes / deserializes jitted functions, we need to make equality checks.
The problem is that this only works with the jaxpr: you need to pass an input to make the jaxpr, but you need to know the function to know the input shape, so there's a chicken-or-egg problem in this (passing) test:
```python
from typing import Any, Callable

import jax

DEBUG = False
# jax.jit returns a jaxlib PjitFunction; a Callable alias keeps this snippet self-contained.
PjitFunction = Callable[..., Any]


def jit_fns_eq(
    j1: PjitFunction,
    j2: PjitFunction,
    *,
    debug: bool = DEBUG,
) -> bool:
    "returns True if two jax.jit functions have the same jaxpr"
    j1_jaxpr = jax.make_jaxpr(j1)(1.0)
    j2_jaxpr = jax.make_jaxpr(j2)(1.0)
    if debug:
        print("j1_jaxpr", j1_jaxpr)
        print("j2_jaxpr", j2_jaxpr)
    return repr(j1_jaxpr) == repr(j2_jaxpr)


def test_jit_fns_eq():
    def jnp_func1(x):
        return jax.numpy.sin(jax.numpy.cos(x))

    def jnp_func2(x):
        return jax.numpy.cos(jax.numpy.sin(x))

    jitted1 = jax.jit(jnp_func1)
    jitted2 = jax.jit(jnp_func2)
    assert jit_fns_eq(jitted1, jitted1), "false negative on equality check 1,1"
    assert jit_fns_eq(jitted2, jitted2), "false negative on equality check 2,2"
    assert not jit_fns_eq(jitted1, jitted2), "false positive on equality check 1,2"
    assert not jit_fns_eq(jitted2, jitted1), "false positive on equality check 2,1"
    print("PASS!")


test_jit_fns_eq()
```
This approach only works for JAX functions that accept a single scalar input, because I wouldn't know what input shapes to pass for an arbitrary function unless I made some other structure to hold that information. For example, adding a two-argument function to the test breaks it:
```diff
+    def jnp_func3(x, y):
+        return jnp_func1(x) * jnp_func2(y)
+
+    jitted3 = jax.jit(jnp_func3)
+    assert jit_fns_eq(jitted3, jitted3), "false negative on equality check 3,3"
```
That is, the requirement to pass inputs to make a jaxpr makes it hard to write a signature-agnostic equality function for jitted functions.
You can get the number of parameters, but each one could have an arbitrary rank and dtype, and there's no way to know that.
Here's a temporary, hacky, sort-of-working equality check for jitted functions:
```diff
 def jit_fns_eq(
     j1: PjitFunction,
     j2: PjitFunction,
     *,
     debug: bool = DEBUG,
 ) -> bool:
     "returns True if two jax.jit functions have the same jaxpr"
-    j1_jaxpr = jax.make_jaxpr(j1)(1.0)
-    j2_jaxpr = jax.make_jaxpr(j2)(1.0)
+    sig1 = inspect.signature(j1)  # requires `import inspect`
+    sig2 = inspect.signature(j2)
+    n_args_1 = len(sig1.parameters)
+    n_args_2 = len(sig2.parameters)
+    j1_jaxpr = jax.make_jaxpr(j1)(*(1.0 for _ in range(n_args_1)))  # fails if rank/shape wrong?
+    j2_jaxpr = jax.make_jaxpr(j2)(*(1.0 for _ in range(n_args_2)))
     if debug:
         print("j1_jaxpr", j1_jaxpr)
         print("j2_jaxpr", j2_jaxpr)
     return repr(j1_jaxpr) == repr(j2_jaxpr)
```
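The hack still guesses that every argument is a scalar. A partial alternative, assuming the caller can supply the signature, is to pass `jax.ShapeDtypeStruct` placeholders to `jax.make_jaxpr` instead of concrete values. This avoids materializing any data, but it does not remove the need to know shapes and dtypes up front, which is the real limitation here. `jit_fns_eq_with_specs` is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp


def jit_fns_eq_with_specs(j1, j2, *arg_specs) -> bool:
    """Hypothetical helper: compare two jitted functions by their jaxprs,
    traced against caller-supplied abstract argument specs (no real data)."""
    j1_jaxpr = jax.make_jaxpr(j1)(*arg_specs)
    j2_jaxpr = jax.make_jaxpr(j2)(*arg_specs)
    return repr(j1_jaxpr) == repr(j2_jaxpr)


def f(x, y):
    return jnp.sin(x) * jnp.cos(y)


def g(x, y):
    return jnp.cos(x) * jnp.sin(y)


# The caller must still know the signature: two float32 vectors of length 5.
vec = jax.ShapeDtypeStruct((5,), jnp.float32)
same = jit_fns_eq_with_specs(jax.jit(f), jax.jit(f), vec, vec)
different = jit_fns_eq_with_specs(jax.jit(f), jax.jit(g), vec, vec)
print(same, different)
```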
Does anybody know an easy way to reproduce the issue, i.e. a function that needs arrays of a specific rank or shape?
Ah, never mind: a standard vector-matrix einsum works.
```python
import jax
import jax.numpy as jnp


def jnp_func4(x, y):
    return jnp.einsum("x,xy->y", x, y)


jitted4 = jax.jit(jnp_func4)
rng = jax.random.key(42)
rng1, rng2, rng3 = jax.random.split(rng, 3)
x1 = jax.random.uniform(rng1, (2,))
y1 = jax.random.uniform(rng2, (2, 3))
a1 = jnp_func4(x1, y1)
print(a1)
print("a1 ok")
# Errors out with jit_fns_eq: scalar 1.0 placeholders don't match the "x,xy->y" ranks.
assert jit_fns_eq(jitted4, jitted4), "false negative on equality check 4,4"
```
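To make the shape dependence concrete, a small probe can check whether a function traces for a given abstract signature. It shows that only some signatures are valid for `jnp_func4`, and that the only way to find one is to already know it (or to guess). `traces_ok` is a hypothetical helper, and it catches exceptions broadly because tracing errors for bad shapes vary in type by operation.

```python
import jax
import jax.numpy as jnp


def traces_ok(fn, *arg_specs) -> bool:
    """Hypothetical probe: does `fn` trace successfully for these abstract args?"""
    try:
        jax.make_jaxpr(fn)(*arg_specs)
        return True
    except Exception:
        # Rank/shape mismatches raise different exception types depending on
        # the op (often ValueError or TypeError), so catch broadly here.
        return False


def jnp_func4(x, y):
    return jnp.einsum("x,xy->y", x, y)


jitted4 = jax.jit(jnp_func4)

scalar = jax.ShapeDtypeStruct((), jnp.float32)
wrong_rank = traces_ok(jitted4, scalar, scalar)
mismatched = traces_ok(
    jitted4,
    jax.ShapeDtypeStruct((2,), jnp.float32),
    jax.ShapeDtypeStruct((3, 4), jnp.float32),  # 'x' is 2 vs 3
)
valid = traces_ok(
    jitted4,
    jax.ShapeDtypeStruct((2,), jnp.float32),
    jax.ShapeDtypeStruct((2, 3), jnp.float32),
)
print(wrong_rank, mismatched, valid)
```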
The space of input signatures is way too big; we need some way to get a jaxpr without knowing the input signature.
tl;dr: I want a bijection: `jaxpr: Jaxpr = jaxpr_from_jit_fn(jit_fn)` and `jit_fn: PjitFunction = jit_fn_from_jaxpr(jaxpr)`.