probability
probability copied to clipboard
inference gym / numpyro compatibility
I'm looking to use Inference Gym targets in numpyro, but I'm running into issues I believe because there are numpy arrays in the Inference Gym model init, which causes tracer conversion errors in numpyro/jax
Any ideas how to get around this? I cant tell how the inference_gym.using_jax module works, but I was hoping that it would change the arrays to be initialized as jax arrays and not numpy
import jax
import numpyro
from numpyro import distributions as dist
from numpyro.infer import MCMC, NUTS
from inference_gym import using_jax as gym
class Banana(dist.Distribution):
arg_constraints = {"ndims": dist.constraints.positive_integer, "curvature": dist.constraints.real}
support = dist.constraints.real_vector
pytree_data_fields = ("ndims", "curvature")
def __init__(self, ndims, curvature):
self.ndims = ndims
self.curvature = curvature
self.gym_dist = gym.targets.Banana(ndims=ndims, curvature=curvature)
super().__init__(event_shape=(ndims,))
def sample(self, key, sample_shape=()):
return self.gym_dist.sample(seed=key, sample_shape=sample_shape)
def log_prob(self, value):
return self.gym_dist._unnormalized_log_prob(value)
samples = Banana(ndims=3, curvature=0.03).sample(jax.random.PRNGKey(0), sample_shape=(100,))
def model(X):
curvature = numpyro.sample("curvature", dist.Beta(1,30))
return numpyro.sample("obs", Banana(ndims=3, curvature=curvature), obs=X)
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), X=samples)
Here's the full traceback
Traceback
File "/var/folders/1p/v01fvg3j1cz1hzv988ygj8m00000gn/T/ipykernel_28885/3778384098.py", line 37, in
mcmc.run(jax.random.PRNGKey(0), X=samples) File "/.venv/lib/python3.11/site-packages/numpyro/infer/mcmc.py", line 702, in run states_flat, last_state = partial_map_fn(map_args) ^^^^^^^^^^^^^^^^^^^^^^^^ File "/.venv/lib/python3.11/site-packages/numpyro/infer/mcmc.py", line 465, in _single_chain_mcmc new_init_state = self.sampler.init( ^^^^^^^^^^^^^^^^^^ File "/.venv/lib/python3.11/site-packages/numpyro/infer/hmc.py", line 749, in init init_params = self._init_state( ^^^^^^^^^^^^^^^^^ File "/.venv/lib/python3.11/site-packages/numpyro/infer/hmc.py", line 693, in _init_state ) = initialize_model( ^^^^^^^^^^^^^^^^^ File "/.venv/lib/python3.11/site-packages/numpyro/infer/util.py", line 713, in initialize_model (init_params, pe, grad), is_valid = find_valid_initial_params( ^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/.venv/lib/python3.11/site-packages/numpyro/infer/util.py", line 447, in find_valid_initial_params (init_params, pe, z_grad), is_valid = _find_valid_params( ^^^^^^^^^^^^^^^^^^^ File "/.venv/lib/python3.11/site-packages/numpyro/infer/util.py", line 433, in _find_valid_params _, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state) ^^^^^^^^^^^^^^^^^^^ File "/.venv/lib/python3.11/site-packages/numpyro/infer/util.py", line 417, in body_fn pe, z_grad = value_and_grad(potential_fn)(params) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/.venv/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback return fun(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^ File "/.venv/lib/python3.11/site-packages/jax/_src/api.py", line 468, in value_and_grad_f ans, vjp_py = _vjp(f_partial, *dyn_args) ^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/.venv/lib/python3.11/site-packages/jax/_src/api.py", line 1975, in _vjp out_primals, vjp = ad.vjp(flat_fun, primals_flat) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/.venv/lib/python3.11/site-packages/jax/_src/interpreters/ad.py", line 252, in vjp out_primals, pvals, jaxpr, consts = linearize(traceable, *primals) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/.venv/lib/python3.11/site-packages/jax/_src/interpreters/ad.py", line 237, in linearize jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/.venv/lib/python3.11/site-packages/jax/_src/profiler.py", line 333, in wrapper return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ File "/.venv/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py", line 574, in trace_to_jaxpr_nounits jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals) ^^^^^^^^^^^^^^^^^^^^^^^ File "/.venv/lib/python3.11/site-packages/jax/_src/linear_util.py", line 192, in call_wrapped return self.f_transformed(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/.venv/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py", line 587, in trace_to_subjaxpr_nounits out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits( ^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/.venv/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py", line 616, in _trace_to_subjaxpr_nounits ans = f(*in_args) ^^^^^^^^^^^ File "/.venv/lib/python3.11/site-packages/jax/_src/api_util.py", line 72, in flatten_fun ans = f(*py_args, **py_kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^ File "/.venv/lib/python3.11/site-packages/jax/_src/interpreters/ad.py", line 78, in jvpfun out_primals, out_tangents = f(tag, primals, tangents) ^^^^^^^^^^^^^^^^^^^^^^^^^ File "/.venv/lib/python3.11/site-packages/jax/_src/interpreters/ad.py", line 115, in jvp_subtrace ans = f(*in_tracers) ^^^^^^^^^^^^^^ File "/.venv/lib/python3.11/site-packages/jax/_src/api_util.py", line 88, in flatten_fun_nokwargs ans = f(*py_args) ^^^^^^^^^^^ File "/.venv/lib/python3.11/site-packages/jax/_src/api_util.py", line 292, in _argnums_partial return fun(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ File "/.venv/lib/python3.11/site-packages/numpyro/infer/util.py", line 299, in potential_energy log_joint, model_trace = log_density( ^^^^^^^^^^^^^ File "/.venv/lib/python3.11/site-packages/numpyro/infer/util.py", line 70, in log_density model_trace = trace(model).get_trace(*model_args, **model_kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/.venv/lib/python3.11/site-packages/numpyro/handlers.py", line 186, in get_trace self(*args, **kwargs) File "/.venv/lib/python3.11/site-packages/numpyro/primitives.py", line 105, in call return self.fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^ File "/.venv/lib/python3.11/site-packages/numpyro/primitives.py", line 105, in call return self.fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^ File "/.venv/lib/python3.11/site-packages/numpyro/primitives.py", line 105, in call return self.fn(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^ [Previous line repeated 3 more times] File "/var/folders/1p/v01fvg3j1cz1hzv988ygj8m00000gn/T/ipykernel_28885/3778384098.py", line 31, in model return numpyro.sample("obs", Banana(ndims=3, curvature=curvature), obs=X) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/.venv/lib/python3.11/site-packages/numpyro/distributions/distribution.py", line 100, in call return super().call(*args, kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/var/folders/1p/v01fvg3j1cz1hzv988ygj8m00000gn/T/ipykernel_28885/3778384098.py", line 17, in init self.gym_dist = gym.targets.Banana(ndims=ndims, curvature=curvature) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/.venv/lib/python3.11/site-packages/inference_gym/targets/banana.py", line 116, in init [10.] + [np.sqrt(1. + 2 * curvature2 * 10.**4)] + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/.venv/lib/python3.11/site-packages/jax/_src/core.py", line 692, in array raise TracerArrayConversionError(self)
While this particular issue is fixable in principle (see below), the core issue is that this Inference Gym targets are not intended to be used this way. The targets are high level constructs that can do IO and other things not compatible with jitted computation in their initializer. They're not "distributions" in the sense of being building blocks to constructing larger probabilistic models.
If you want a local fix, edit the /.venv/lib/python3.11/site-packages/inference_gym/targets/banana.py source code on that line to look like:
ground_truth_standard_deviation=tf.constant(
[10.] + [tf.sqrt(1. + 2 * curvature**2 * 10.**4)] +
[1.] * (ndims - 2)),
That's just a local workaround. It's unclear what the proper solution would look like because there's a lot of corner cases.