Improve the interface to JAX-compiled functions
Consider the following probabilistic model:
import aesara.tensor as at
X_at = at.matrix('X')
srng = at.random.RandomStream(0)
tau_rv = srng.halfcauchy(0, 1)
lambda_rv = srng.halfcauchy(0, 1, size=X_at.shape[-1])
beta_rv = srng.normal(0, tau_rv * lambda_rv, size=X_at.shape[-1])
eta = X_at @ beta_rv
p = at.sigmoid(-eta)
Y_rv = srng.bernoulli(p)
In a typical Bayesian modeling workflow, we want to be able to generate and use three functions:
1. A function that samples from the prior joint distribution (to initialize the values)
2. A function that computes the model's log-density
3. A function that performs posterior predictive sampling
In a workflow that uses JAX we typically want to be able to use jax.vmap with (1) and (3), and jax.grad with (2). While this is possible with Aesara-compiled functions, we have to go through unnecessary levels of indirection.
Current behavior
First, to use jax.grad we must use the vm.jit_fn attribute of the Aesara-compiled function and wrap the function so it returns a single value instead of a 1-element tuple:
import aesara
import aeppl
import jax
import numpy as np
logprob, vvs = aeppl.joint_logprob(tau_rv, lambda_rv, beta_rv, Y_rv)
logdensity_fn = aesara.function([X_at] + list(vvs), logprob, mode="JAX")
try:
    jax.grad(logdensity_fn)(np.ones((3,2)), 1., np.ones(2), np.ones(2), np.ones(3))
except Exception as e:
    print(e)
# Bad input argument to aesara function with name "<stdin>:22" at index 0 (0-based).
# Backtrace when that variable is created:
# File "<stdin>", line 3, in <module>
# The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray([[1. 1.]
# [1. 1.]
# [1. 1.]], dtype=float64)>with<JVPTrace(level=2/0)> with
# primal = array([[1., 1.],
# [1., 1.],
# [1., 1.]])
# tangent = Traced<ShapedArray(float64[3,2])>with<JaxprTrace(level=1/0)> with
# pval = (ShapedArray(float64[3,2]), None)
# recipe = LambdaBinding()
# See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
logdensity_jit_fn = aesara.function([X_at] + list(vvs), logprob, mode="JAX").vm.jit_fn
try:
    jax.grad(logdensity_jit_fn)(np.ones((3,2)), np.array(1.), np.ones(2), np.ones(2), np.ones(3))
except Exception as e:
    print(e)
# Gradient only defined for scalar-output functions. Output was (Array(-12.65285075, dtype=float64),).
def logdensity_squeezed_jit_fn(*x):
    return logdensity_jit_fn(*x)[0]
print(jax.grad(logdensity_squeezed_jit_fn)(np.ones((3,2)), np.array(1.), np.ones(2), np.ones(2), np.ones(3)))
# [[-0.88079708 -0.88079708]
# [-0.88079708 -0.88079708]
# [-0.88079708 -0.88079708]]
To be able to use jax.vmap to sample multiple values from the prior distribution or do posterior predictive sampling, we must also use the jit-compiled function directly. In addition, we must pass one PRNGKey per random variable in the graph (which does not reflect the RandomStream mechanism), wrap each key in a dictionary with the same structure as the internal random state, and pass them as the last arguments (whereas passing the key as the first argument is the JAX idiom):
prior_sample_fn = aesara.function([X_at], [tau_rv, lambda_rv, beta_rv], mode="JAX").vm.jit_fn
rng_key = jax.random.PRNGKey(0)
key1, key2, key3 = jax.random.split(rng_key, 3)
print(prior_sample_fn(np.ones((2,3)), {"jax_state": key1}, {"jax_state": key2}, {"jax_state": key3}))
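For reference, here is a rough sketch of what batched prior sampling looks like under the current interface. It assumes the dict-wrapped keys are treated as ordinary pytrees by jax.vmap, and note that the returned values also include the updated internal states (see the discussion further down):
n_samples = 100
rng_key = jax.random.PRNGKey(0)
# One key per random variable in the graph, for each of the n_samples draws
keys = jax.random.split(rng_key, 3 * n_samples).reshape((n_samples, 3, 2))

def sample_once(keys_i, X):
    # keys_i has shape (3, 2): one PRNG key per random variable
    return prior_sample_fn(
        X,
        {"jax_state": keys_i[0]},
        {"jax_state": keys_i[1]},
        {"jax_state": keys_i[2]},
    )

prior_samples = jax.vmap(sample_once, in_axes=(0, None))(keys, np.ones((2,3)))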
Expected behavior
I would expect to be able to use the compiled function just like any JAX function:
logdensity_fn = aesara.function([X_at] + list(vvs), logprob, mode="JAX")
jax.grad(logdensity_fn)(np.ones((3,2)), 1., np.ones(2), np.ones(2), np.ones(3))
With the assumption that, for random functions, the first argument must be a JAX PRNGKey:
prior_sample_fn = aesara.function([X_at], [tau_rv, lambda_rv, beta_rv], mode="JAX")
rng_key = jax.random.PRNGKey(0)
keys = jax.random.split(rng_key, 100)
samples = jax.vmap(prior_sample_fn, in_axes=(0, None))(keys, np.ones((2,3)))
Proposals
To make the compiled functions truly compatible with JAX and the rest of its ecosystem I suggest the following changes:
- Function.__call__ should be immediately compatible with JAX; we shouldn't have to fetch the vm.jit_fn attribute. An Aesara function is supposed to manage other things, like updates of SharedVariables, but I am not sure this is necessary for JAX-compiled functions. Furthermore, if it is supposed to do something that normally cannot be done with JAX, then it would be preferable to fail during transpilation rather than output a function that cannot be composed with the rest of the JAX ecosystem;
- Function.__call__ should not return a tuple when there is a single output;
- The internal random state should be represented by the rng_key directly, and not a dictionary that contains the key;
- When there are several random variables in the graph, we should only require one PRNG key;
- The PRNG key should be passed as the first input;
- Do not return the updated PRNG keys with the functions that take PRNG keys as inputs.
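For illustration only, a user-level wrapper can approximate this interface on top of the current jit-compiled function. The sketch below assumes one internal state per random variable and that the updated PRNG keys are appended after the requested outputs, as described in the comments below (the names are hypothetical):
raw_sample_fn = aesara.function([X_at], [tau_rv, lambda_rv, beta_rv], mode="JAX").vm.jit_fn

def sample_prior(rng_key, X):
    # Split the single user-provided key into one key per random variable
    keys = jax.random.split(rng_key, 3)
    outputs = raw_sample_fn(X, *({"jax_state": k} for k in keys))
    # Drop the updated PRNG keys and return only the samples
    return outputs[:3]

keys = jax.random.split(jax.random.PRNGKey(0), 100)
samples = jax.vmap(sample_prior, in_axes=(0, None))(keys, np.ones((2,3)))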
Related to https://github.com/aesara-devs/aesara/issues/1194
To make the compiled functions truly compatible with JAX and the rest of its ecosystem I suggest the following changes: Function.__call__ should be immediately compatible with JAX; we shouldn't have to fetch the vm.jit_fn attribute. An Aesara function is supposed to manage other things, like updates of SharedVariables, but I am not sure this is necessary for JAX-compiled functions. Furthermore, if it is supposed to do something that normally cannot be done with JAX, then it would be preferable to fail during transpilation rather than output a function that cannot be composed with the rest of the JAX ecosystem;
This sounds like a new/different compilation interface and/or approach, because it basically forgoes all the Function machinery itself. We might be able to reuse the aesara.function entry-point and provide such a thing via a keyword option—or something similar.
Function.__call__ should not return a tuple when there is a single output;
Yeah, I'm not a fan of that either, and it shows up all over the place—often unnecessarily complicating our (static) typing.
The internal random state should be represented by the rng_key directly, and not a dictionary that contains the key;
I recall that having to do with our generalized RandomStateType interface, and the need for extra, sometimes unused things—like a bit_generator entry—in order for the custom JAX random state objects (i.e. the dicts) to pass as valid RandomStateTypes. This interface was shared with our work-arounds for Numba, but, once we merge #1245, that requirement may no longer be necessary. Regardless, there may be another approach to that whole connection with Aesara random state types that removes the need for a dict.
When there are several random variables in the graph, we should only require one PRNG key;
A single one that's split into others, no?
Do not return the updated PRNG keys with the functions that take PRNG keys as inputs;
That sounds like it's needed in order to match the outputs provided/expected of a RandomVariable-constructed node, or are you referring to something else?
This sounds like a new/different compilation interface and/or approach, because it basically forgoes all the Function machinery itself. We might be able to reuse the aesara.function entry-point and provide such a thing via a keyword option—or something similar.
I would indeed like to heavily refactor Function in the case of JAX-compiled functions. The introspection capabilities are nice, but there is no need for the VM-like stuff.
A first step could be to define JAXFunction within the JAX linker, which contains compilation information but can also be called like any JAX function. We can add e.g. aesara.link.jax.function to build JAXFunctions from graphs, and leave the current aesara.function unchanged.
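Usage of such an entry point could look roughly like this (a sketch only; aesara.link.jax.function and JAXFunction do not exist yet):
# Hypothetical: compile the graph straight to a plain JAX callable
logdensity_fn = aesara.link.jax.function([X_at] + list(vvs), logprob)

# Composes directly with JAX transformations: no vm.jit_fn, no 1-element tuple
grads = jax.grad(logdensity_fn)(np.ones((3,2)), 1., np.ones(2), np.ones(2), np.ones(3))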
Function.__call__ should not return a tuple when there is a single output;
Yeah, I'm not a fan of that either, and it shows up all over the place—often unnecessarily complicating our (static) typing.
Sure! That behavior comes from fgraph_to_python, right?
The internal random state should be represented by the rng_key directly, and not a dictionary that contains the key;
I recall that having to do with our generalized RandomStateType interface, and the need for extra, sometimes unused things—like a bit_generator entry—in order for the custom JAX random state objects (i.e. the dicts) to pass as valid RandomStateTypes. This interface was shared with our work-arounds for Numba, but, once we merge #1245, that requirement may no longer be necessary. Regardless, there may be another approach to that whole connection with Aesara random state types that removes the need for a dict.
I think that for now a rewrite to bypass the check that is currently performed will do. Nevertheless this may be a sign that something is not quite right on the aesara.tensor.random side.
When there are several random variables in the graph, we should only require one PRNG key;
A single one that's split into others, no?
Do not return the updated PRNG keys with the functions that take PRNG keys as inputs;
That sounds like it's needed in order to match the outputs provided/expected of a RandomVariable-constructed node, or are you referring to something else?
These two are actually related. In JAX you would typically write:
import jax

def sample(rng_key):
    key_1, key_2 = jax.random.split(rng_key)
    a = fn1(key_1)
    b = fn2(key_2)
    return a, b
Aesara instead returns the following function:
import jax

def sample(key_1, key_2):
    new_key_1, _ = jax.random.split(key_1)
    a = fn1(new_key_1)
    new_key_2, _ = jax.random.split(key_2)
    b = fn2(new_key_2)
    return a, b, new_key_1, new_key_2
which is correct, but definitely not expected by JAX users.