
give users a way to type-query effects, e.g. in eval_shape

Open mattjj opened this issue 2 years ago • 4 comments

We use jax.eval_shape as a way to query function types, but it doesn't surface effect types. We need a way to do that, either built into jax.eval_shape or something else.
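A quick illustration of the gap. This is a sketch that uses only existing public APIs, and details may differ across JAX versions. `pure` and `noisy` are made-up examples, and `jax.debug.print` is there only as an easy way to introduce an effect. `jax.eval_shape` reports the same output type for both functions. The `ClosedJaxpr` returned by `jax.make_jaxpr` does expose the effect set through `.effects`.

```python
import jax
import jax.numpy as jnp


def pure(x):
    return jnp.sin(x) * 2.0


def noisy(x):
    # jax.debug.print adds a debug-callback effect to the traced program
    jax.debug.print("x = {}", x)
    return jnp.sin(x) * 2.0


spec = jax.ShapeDtypeStruct((3,), jnp.float32)

# eval_shape: output shape/dtype only, nothing about effects
out_pure = jax.eval_shape(pure, spec)
out_noisy = jax.eval_shape(noisy, spec)

# make_jaxpr: the ClosedJaxpr carries the effect set, at the cost of building a jaxpr
effects_pure = jax.make_jaxpr(pure)(spec).effects
effects_noisy = jax.make_jaxpr(noisy)(spec).effects
print(out_pure, out_noisy, effects_pure, effects_noisy)
```

The catch is that reading `.effects` this way requires building a full jaxpr, which is more work than a type query needs.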

mattjj, Oct 23 '23 18:10

jax.infer_type?

sharadmv, Oct 25 '23 19:10

FWIW I think there are basically 4 things a user might ask for when abstractly interpreting:

  1. output struct + avals
  2. jaxpr
  3. effects
  4. closed-over constants.

Right now eval_shape gives you just the first, make_jaxpr gives you the first two (three?), and closure_convert gives you some of the last (only inexact arrays, but in the presence of e.g. custom_vmap you need all arrays).

Rather than a new API for specifically effects, I'd probably argue for a unified API that offers any combination of these, ideally in an "optimal" way, e.g. not forming a jaxpr if all that is needed is the effect set.
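Here is one hypothetical shape such a unified query could take. `abstract_query` and `AbstractInfo` are made-up names, not JAX APIs. The sketch is not "optimal" in the sense above, because any request beyond the output struct still builds a full jaxpr with `jax.make_jaxpr(..., return_shape=True)`. Also, `consts` here means the constants captured in the `ClosedJaxpr`, which is not necessarily the set that `closure_convert` hoists.

```python
import dataclasses
from typing import Any, Callable, Optional

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class AbstractInfo:
    # Hypothetical result type: each field stays None unless it was requested
    out_struct: Any = None
    jaxpr: Optional[Any] = None
    effects: Optional[Any] = None
    consts: Optional[list] = None


def abstract_query(fun: Callable, *args, want=("out_struct",)) -> AbstractInfo:
    """Hypothetical unified abstract-evaluation query (not a JAX API)."""
    info = AbstractInfo()
    if set(want) <= {"out_struct"}:
        # Only shapes requested: eval_shape returns them without exposing a jaxpr
        info.out_struct = jax.eval_shape(fun, *args)
        return info
    # Anything else: trace once and pull the requested pieces from the result
    closed, out_struct = jax.make_jaxpr(fun, return_shape=True)(*args)
    if "out_struct" in want:
        info.out_struct = out_struct
    if "jaxpr" in want:
        info.jaxpr = closed
    if "effects" in want:
        info.effects = closed.effects
    if "consts" in want:
        info.consts = list(closed.consts)
    return info


spec = jax.ShapeDtypeStruct((4,), jnp.float32)
info = abstract_query(lambda x: jnp.cos(x), spec, want=("out_struct", "effects"))
shape_only = abstract_query(lambda x: jnp.cos(x), spec)
```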

patrick-kidger, Oct 25 '23 19:10

Bump: I could really use a way to get a jaxpr abstractly, without needing to pass inputs. While testing cloudpickling of jax-jitted functions today, I found that:

  • functionality is identical before and after cloudpickling
  • jaxprs are identical before and after cloudpickling
  • `__dict__` gets wiped
  • the hash is therefore different

To test code that serializes/deserializes jitted functions, we need to make equality checks.

The problem is that this only works via the jaxpr. Making a jaxpr requires passing an input, but you need to know the function to know its input shape. So there's a chicken-and-egg problem in this (passing) test:

```python
from typing import Any, Callable

import jax

# Stand-ins so the snippet runs on its own: the real PjitFunction type lives in
# jaxlib (its import path varies by version), and DEBUG was defined elsewhere.
PjitFunction = Callable[..., Any]
DEBUG = False


def jit_fns_eq(
    j1: PjitFunction,
    j2: PjitFunction,
    *,
    debug: bool = DEBUG,
) -> bool:
    "returns True if two jax.jit functions have the same jaxpr"
    # Trace with a scalar placeholder: only valid for single-scalar-input functions
    j1_jaxpr = jax.make_jaxpr(j1)(1.0)
    j2_jaxpr = jax.make_jaxpr(j2)(1.0)
    if debug:
        print("j1_jaxpr", j1_jaxpr)
        print("j2_jaxpr", j2_jaxpr)
    return repr(j1_jaxpr) == repr(j2_jaxpr)


def test_jit_fns_eq():
    def jnp_func1(x):
        return jax.numpy.sin(jax.numpy.cos(x))

    def jnp_func2(x):
        return jax.numpy.cos(jax.numpy.sin(x))

    jitted1 = jax.jit(jnp_func1)
    jitted2 = jax.jit(jnp_func2)

    assert jit_fns_eq(jitted1, jitted1), "false negative on equality check 1,1"
    assert jit_fns_eq(jitted2, jitted2), "false negative on equality check 2,2"
    assert not jit_fns_eq(jitted1, jitted2), "false positive on equality check 1,2"
    assert not jit_fns_eq(jitted2, jitted1), "false positive on equality check 2,1"
    print("PASS!")


test_jit_fns_eq()
```


This approach only works for jax functions that take a single scalar input, because I wouldn't know what input shapes to pass for arbitrary functions unless I made some other object to store that information.

```diff
+    def jnp_func3(x, y):
+        return jnp_func1(x) * jnp_func2(y)
+
+    jitted3 = jax.jit(jnp_func3)
+    assert jit_fns_eq(jitted3, jitted3), "false negative on equality check 3,3"
```


In other words, the requirement to pass inputs to make a jaxpr makes it hard to write a signature-agnostic equality function for jitted functions.

You can get the number of parameters, but each one could have arbitrary rank and dtype, and there's no way to know that from the function alone.
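Here is a sketch of the "other object to store that information" idea: keep abstract input specs (`jax.ShapeDtypeStruct`) next to the jitted function, so the jaxpr can be built without any concrete data. `JitWithSpecs` and `specs_jit_eq` are hypothetical helpers. Comparing printed jaxprs is the same heuristic `jit_fns_eq` uses. This avoids guessing shapes, but the specs still have to come from somewhere.

```python
import jax
import jax.numpy as jnp


class JitWithSpecs:
    """Hypothetical helper (not a JAX API): a jitted function bundled with the
    abstract input specs needed to trace it."""

    def __init__(self, fun, *in_specs):
        self.fn = jax.jit(fun)
        self.in_specs = in_specs  # one jax.ShapeDtypeStruct per positional argument

    def jaxpr(self):
        # make_jaxpr accepts ShapeDtypeStructs, so no concrete arrays are needed
        return jax.make_jaxpr(self.fn)(*self.in_specs)


def specs_jit_eq(a: JitWithSpecs, b: JitWithSpecs) -> bool:
    # Same heuristic as jit_fns_eq: compare the printed jaxprs
    return repr(a.jaxpr()) == repr(b.jaxpr())


def func1(x):
    return jnp.sin(jnp.cos(x))


def func2(x):
    return jnp.cos(jnp.sin(x))


def func4(x, y):
    return jnp.einsum("x,xy->y", x, y)


scalar = jax.ShapeDtypeStruct((), jnp.float32)
vec = jax.ShapeDtypeStruct((2,), jnp.float32)
mat = jax.ShapeDtypeStruct((2, 3), jnp.float32)

j1 = JitWithSpecs(func1, scalar)
j2 = JitWithSpecs(func2, scalar)
j4 = JitWithSpecs(func4, vec, mat)  # rank/shape info travels with the function
```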

Here's a temporary, hackish, sort-of-working equality check for jit functions:

```diff
 def jit_fns_eq(
     j1: PjitFunction,
     j2: PjitFunction,
     *,
     debug: bool = DEBUG,
 ) -> bool:
     "returns True if two jax.jit functions have the same jaxpr"
-    j1_jaxpr = jax.make_jaxpr(j1)(1.0)
-    j2_jaxpr = jax.make_jaxpr(j2)(1.0)
+    sig1 = inspect.signature(j1)
+    sig2 = inspect.signature(j2)
+    n_args_1 = len(sig1.parameters)
+    n_args_2 = len(sig2.parameters)
+    j1_jaxpr = jax.make_jaxpr(j1)(*(1.0 for _ in range(n_args_1)))  # fails if rank/shape wrong?
+    j2_jaxpr = jax.make_jaxpr(j2)(*(1.0 for _ in range(n_args_2)))
     if debug:
         print("j1_jaxpr", j1_jaxpr)
         print("j2_jaxpr", j2_jaxpr)
     return repr(j1_jaxpr) == repr(j2_jaxpr)
```


Does anybody know an easy way to reproduce the case where a function needs its input arrays to have a specific rank or shape?

bionicles, Jun 28 '24 18:06

Ah, never mind: a standard vector-matrix einsum does it.

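```python
import jax
import jax.numpy as jnp


def jnp_func4(x, y):
    return jnp.einsum("x,xy->y", x, y)


jitted4 = jax.jit(jnp_func4)

rng = jax.random.key(42)
rng1, rng2, rng3 = jax.random.split(rng, 3)
x1 = jax.random.uniform(rng1, (2,))
y1 = jax.random.uniform(rng2, (2, 3))
a1 = jnp_func4(x1, y1)  # works with correctly shaped inputs
print(a1)
print("a1 ok")

# Uses the signature-based jit_fns_eq above: the scalar 1.0 placeholders don't
# match the "x,xy->y" subscripts, so tracing raises an error here.
assert jit_fns_eq(jitted4, jitted4), "false negative on equality check 4,4"
```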


The space of input signatures is way too big, so we need some way to get a jaxpr without knowing the input signature. TL;DR: wanted, a bijection: `jaxpr: Jaxpr = jaxpr_from_jit_fn(jit_fn)` and `jit_fn: PjitFunction = jit_fn_from_jaxpr(jaxpr)`.
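For a fixed input signature, something close to the wished-for pair can be sketched today. The two function names below come from the wish above and are hypothetical, not JAX APIs. `jaxpr_as_fun` is real, but its import location has moved between versions, so the sketch includes a fallback. This is not a bijection: `jaxpr_from_jit_fn` still needs input specs, and the rebuilt function is only valid for that one signature.

```python
import jax
import jax.numpy as jnp

try:  # newer JAX exposes this under jax.extend.core
    from jax.extend.core import jaxpr_as_fun
except ImportError:  # fallback for older releases
    from jax.core import jaxpr_as_fun


def jaxpr_from_jit_fn(jit_fn, *in_specs):
    # Hypothetical helper: a jaxpr exists only for a given input signature,
    # so the specs are still required
    return jax.make_jaxpr(jit_fn)(*in_specs)


def jit_fn_from_jaxpr(closed_jaxpr):
    # Hypothetical helper: jaxpr_as_fun evaluates the jaxpr and returns a flat
    # list of outputs; jax.jit compiles that evaluation
    return jax.jit(jaxpr_as_fun(closed_jaxpr))


def jnp_func4(x, y):
    return jnp.einsum("x,xy->y", x, y)


specs = (jax.ShapeDtypeStruct((2,), jnp.float32),
         jax.ShapeDtypeStruct((2, 3), jnp.float32))
closed = jaxpr_from_jit_fn(jax.jit(jnp_func4), *specs)
rebuilt = jit_fn_from_jaxpr(closed)

x = jnp.arange(2.0, dtype=jnp.float32)
y = jnp.ones((2, 3), dtype=jnp.float32)
(out,) = rebuilt(x, y)  # only valid for (2,) and (2, 3) float32 inputs
```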

bionicles, Jun 28 '24 18:06