Reshape operation in logp graph not supported in JAX backend
Description of your problem
Please provide a minimal, self-contained, and reproducible example.
import pymc as pm
import pymc.sampling_jax
import numpy as np
import pandas as pd
from aesara import shared, tensor as at
from patsy import dmatrix

rng = np.random.default_rng(0)
size = 2_000
x1 = rng.normal(size=size)
x2 = rng.normal(size=size)
data = pd.DataFrame(
    {
        "x1": x1,
        "x2": x2,
        "y": rng.normal(loc=x1 + x2, size=size),
    }
)

features = ["x1", "x2"]
DEGREES = 3
N_KNOT = 7
df = N_KNOT + DEGREES + 1

# Build a patsy formula with one B-spline basis per feature
mat_str = ""
mat_str_end = " - 1"
mat_str_middle = " + "
np_features = data[features].values
for feature in features:
    mat_str += f"bs({feature}, df={df}, degree={DEGREES}){mat_str_middle}"
mat_str = mat_str[:-2] + mat_str_end
basis = dmatrix(mat_str, {feature: np_features[:, i] for i, feature in enumerate(features)})
dmat_data = np.asarray(basis).reshape(np_features.shape[0], np_features.shape[1], -1)
dmat = shared(dmat_data)

with pm.Model() as model:
    mutable_data = pm.MutableData("data", np_features)
    HALFNORMAL_SCALE = 1. / np.sqrt(1. - 2. / np.pi)
    # Gaussian random walk over the spline coefficients, one walk per feature
    mu = pm.Normal('mu_grw', 0., 1., shape=dmat.shape[1])
    delta = pm.Normal('delta_grw', 0., 0.1/2.5, shape=(dmat.shape[1], dmat.shape[2]))
    sigma = pm.HalfNormal('sigma_grw', 0.1 * HALFNORMAL_SCALE, shape=dmat.shape[1])
    grw = pm.Deterministic('grw', mu[:, None] + sigma[:, None] * delta.cumsum(axis=1))
    f = at.tensordot(dmat, grw)
    y = pm.MutableData("y", data["y"])
    eps = pm.HalfNormal("eps", sigma=1)
    normal = pm.Normal("normal", mu=f, sigma=eps, observed=y)
    results = pm.sample()
    # results = pm.sampling_jax.sample_blackjax_nuts(chain_method="vectorized")
Please provide the full traceback.
Complete error traceback
With pm.sample(): works as expected (with Ricardo's fix).
For numpyro/blackjax:
Traceback (most recent call last):
results = pm.sampling_jax.sample_blackjax_nuts(chain_method="vectorized")
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/pymc/sampling_jax.py", line 335, in sample_blackjax_nuts
states, _ = map_fn(get_posterior_samples)(keys, init_params)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/api.py", line 1485, in vmap_f
out_flat = batching.batch(
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/api.py", line 473, in cache_miss
out_flat = xla.xla_call(
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/core.py", line 1765, in bind
return call_bind(self, fun, *args, **params)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/core.py", line 1781, in call_bind
outs = top_trace.process_call(primitive, fun_, tracers, params)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/interpreters/batching.py", line 226, in process_call
vals_out = call_primitive.bind(f_, *vals, **params)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/core.py", line 1765, in bind
return call_bind(self, fun, *args, **params)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/core.py", line 1781, in call_bind
outs = top_trace.process_call(primitive, fun_, tracers, params)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/core.py", line 678, in process_call
return primitive.impl(f, *tracers, **params)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/dispatch.py", line 182, in _xla_call_impl
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/linear_util.py", line 285, in memoized_fun
ans = call(fun, *args)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/dispatch.py", line 230, in _xla_callable_uncached
return lower_xla_callable(fun, device, backend, name, donated_invars, False,
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/dispatch.py", line 272, in lower_xla_callable
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1893, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1865, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/pymc/sampling_jax.py", line 199, in _blackjax_inference_loop
last_state, kernel, _ = adapt.run(seed, init_position)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/blackjax/kernels.py", line 567, in run
init_state = algorithm.init(position, logprob_fn, logprob_grad_fn)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/blackjax/mcmc/hmc.py", line 78, in init
potential_energy, potential_energy_grad = jax.value_and_grad(potential_fn)(
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/api.py", line 995, in value_and_grad_f
ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/api.py", line 2457, in _vjp
out_primal, out_vjp = ad.vjp(
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/interpreters/ad.py", line 130, in vjp
out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/interpreters/ad.py", line 119, in linearize
jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 616, in trace_to_jaxpr_nounits
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/blackjax/mcmc/hmc.py", line 70, in potential_fn
return -logprob_fn(x)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/pymc/sampling_jax.py", line 109, in logp_fn_wrap
return logp_fn(*x)[0]
File "/tmp/tmpwzpirvop", line 32, in jax_funcified_fgraph
auto_129685 = reshape(auto_131175, auto_130149)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/aesara/link/jax/dispatch.py", line 731, in reshape
return jnp.reshape(x, shape)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 727, in reshape
return a.reshape(newshape, order=order) # forward to method for ndarrays
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 745, in _reshape
newshape = _compute_newshape(a, args[0] if len(args) == 1 else args)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 738, in _compute_newshape
newshape = core.canonicalize_shape(newshape if iterable else [newshape])
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/core.py", line 1651, in canonicalize_shape
raise _invalid_shape_error(shape, context)
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Shapes must be 1D sequences of concrete values of integer type, got [22].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
results = pm.sampling_jax.sample_blackjax_nuts(chain_method="vectorized")
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/pymc/sampling_jax.py", line 335, in sample_blackjax_nuts
states, _ = map_fn(get_posterior_samples)(keys, init_params)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/pymc/sampling_jax.py", line 199, in _blackjax_inference_loop
last_state, kernel, _ = adapt.run(seed, init_position)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/blackjax/kernels.py", line 567, in run
init_state = algorithm.init(position, logprob_fn, logprob_grad_fn)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/blackjax/mcmc/hmc.py", line 78, in init
potential_energy, potential_energy_grad = jax.value_and_grad(potential_fn)(
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/blackjax/mcmc/hmc.py", line 70, in potential_fn
return -logprob_fn(x)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/pymc/sampling_jax.py", line 109, in logp_fn_wrap
return logp_fn(*x)[0]
File "/tmp/tmpwzpirvop", line 32, in jax_funcified_fgraph
auto_129685 = reshape(auto_131175, auto_130149)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/aesara/link/jax/dispatch.py", line 731, in reshape
return jnp.reshape(x, shape)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 727, in reshape
return a.reshape(newshape, order=order) # forward to method for ndarrays
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 745, in _reshape
newshape = _compute_newshape(a, args[0] if len(args) == 1 else args)
File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 738, in _compute_newshape
newshape = core.canonicalize_shape(newshape if iterable else [newshape])
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [22].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
Please provide any additional information below.
Versions and main components
- PyMC/PyMC3 Version: 4.0.1
- Aesara/Theano Version: 2.7.3
- Python Version: 3.9
- Operating system: Linux
- How did you install PyMC/PyMC3: pip
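For context on the error itself: the TypeError at the bottom of the traceback reflects a general JAX restriction that, under jit, shape arguments must be concrete Python integers rather than traced values. A minimal JAX-only sketch of that restriction (not from this thread; the function f is made up for illustration):

import jax
import jax.numpy as jnp

def f(x, n):
    # under jit, n becomes a tracer, so the target shape is not concrete
    return jnp.reshape(x, (n,))

x = jnp.arange(4.0)
# jax.jit(f)(x, 4)  # TypeError: Shapes must be 1D sequences of concrete values ...

# marking the shape argument static keeps it a concrete Python int at trace time
print(jax.jit(f, static_argnums=1)(x, 4))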
This should be fixed in version 4.0.1 which we released a couple of days ago. Mind giving it a try?
Just upgraded, and I got the same errors as before for both pm.sample() and the jax samplers.
Your model is misspecified, namely the sigma=eps is allowed to be negative, which leads to 0 probability when sampling starts. You can change the eps prior to a positive-only distribution like HalfNormal, or transform it through pm.math.exp, for example.
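Both suggestions, as a minimal sketch (variable names follow the reproduction script):

# option 1: positive-only prior
eps = pm.HalfNormal("eps", sigma=1)

# option 2: unconstrained prior mapped through exp
log_eps = pm.Normal("log_eps", 0., 1.)
eps = pm.Deterministic("eps", pm.math.exp(log_eps))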
Ah apologies, that's a bad modification from my real model to the reproduction script - thanks! With that change (original post edited) pm.sample() works fine and the jax samplers give the same error.
Seems like your model is using reshape at some point, which can't be safely jitted. This is a common issue with the JAX backend. Do you get the same error with sample_numpyro_nuts? I expect so.
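(That check only swaps the sampler call; a sketch, assuming the same model context as the reproduction script:)

with model:
    results = pm.sampling_jax.sample_numpyro_nuts(chain_method="vectorized")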
We should probably patch this upstream, but in the meantime you can try to run this snippet before you sample with JAX:
import jax
from aesara.graph import Constant
from aesara.link.jax.dispatch import jax_funcify
from aesara.tensor.shape import Reshape

@jax_funcify.register(Reshape)
def jax_funcify_Reshape(op, node, **kwargs):
    shape = node.inputs[1]
    if isinstance(shape, Constant):
        constant_shape = shape.data

        def reshape(x, _):
            return jax.numpy.reshape(x, constant_shape)

    else:

        def reshape(x, shape):
            return jax.numpy.reshape(x, shape)

    return reshape
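As a usage sketch (assuming the reproduction script above): run the override once, anywhere after the imports, then sample as before.

# with the Reshape override registered, sample as usual
with model:
    results = pm.sampling_jax.sample_blackjax_nuts(chain_method="vectorized")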
Ah thanks! Which part of the model should I run that on? I tried rewriting the model to remove what I thought would be the implicit reshapes, but got the same error:
import pymc as pm
import pymc.sampling_jax
import numpy as np
import pandas as pd
from aesara import shared, tensor as at
from patsy import dmatrix

rng = np.random.default_rng(0)
size = 2_000
x1 = rng.normal(size=size)
x2 = rng.normal(size=size)
data = pd.DataFrame(
    {
        "x1": x1,
        "x2": x2,
        "y": rng.normal(loc=x1 + x2, size=size),
    }
)

features = ["x1", "x2"]
DEGREES = 3
N_KNOT = 7
df = N_KNOT + DEGREES + 1
mat_str = ""
mat_str_end = " - 1"
mat_str_middle = " + "
np_features = data[features].values
for feature in features:
    mat_str += f"bs({feature}, df={df}, degree={DEGREES}){mat_str_middle}"
mat_str = mat_str[:-2] + mat_str_end
basis = dmatrix(mat_str, {feature: np_features[:, i] for i, feature in enumerate(features)})
dmat_data = np.asarray(basis).reshape(np_features.shape[0], np_features.shape[1], -1)
dmat = shared(dmat_data)

with pm.Model() as model:
    mutable_data = pm.MutableData("data", np_features)
    HALFNORMAL_SCALE = 1. / np.sqrt(1. - 2. / np.pi)
    # same model as above, but mu and sigma carry an explicit trailing
    # dimension instead of being broadcast with [:, None]
    mu = pm.Normal('mu_grw', 0., 1., shape=(dmat.shape[1], 1))
    delta = pm.Normal('delta_grw', 0., 0.1/2.5, shape=(dmat.shape[1], dmat.shape[2]))
    sigma = pm.HalfNormal('sigma_grw', 0.1 * HALFNORMAL_SCALE, shape=(dmat.shape[1], 1))
    grw = pm.Deterministic('grw', mu + sigma * delta.cumsum(axis=1))
    f = at.tensordot(dmat, grw)
    y = pm.MutableData("y", data["y"])
    eps = pm.HalfNormal("eps", sigma=1)
    normal = pm.Normal("normal", mu=f, sigma=eps, observed=y)
    results = pm.sampling_jax.sample_blackjax_nuts(chain_method="vectorized")
The only reshape remaining I can see in that code is in the dmat_data = np.asarray(basis).reshape(np_features.shape[0], np_features.shape[1], -1) line, but that's on pure numpy before it's passed into aesara via shared(), so I'm not sure how that would be getting 'picked up' by jax? I presume I'm missing something here?
You can put the code snippet anywhere before you call sample. The reshape may be added when defining the logp or dlogp without ever being present in the original model.
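One way to see this for yourself (a sketch, assuming the PyMC 4.x API where model.logp() returns the joint log-probability graph) is to scan the logp graph's ancestors for Reshape ops:

from aesara.graph.basic import ancestors
from aesara.tensor.shape import Reshape

logp = model.logp()
reshape_vars = [
    var
    for var in ancestors([logp])
    if var.owner is not None and isinstance(var.owner.op, Reshape)
]
print(len(reshape_vars))  # non-zero even though the model code never calls reshape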
The tensordot call is probably what is introducing the reshape operation in your graph: https://github.com/aesara-devs/aesara/blob/7393b7441601eaad98bc0cb494aa8fba2ea4bf6a/aesara/tensor/math.py#L2213
Anyway, I think the code snippet I shared should free you of the errors.
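(A quick way to confirm the tensordot claim in isolation, not from the thread: build the same contraction on symbolic inputs and print the resulting graph.)

import aesara
from aesara import tensor as at

a = at.tensor3("a")  # stands in for dmat
b = at.matrix("b")   # stands in for grw
out = at.tensordot(a, b)  # default axes=2, the same contraction as f in the model
aesara.dprint(out)  # the printed graph contains Reshape nodes inserted by tensordot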
Yes, I can confirm that code snippet fixes things - thanks very much @ricardoV94!
Should be fixed by https://github.com/aesara-devs/aesara/pull/1111
The fix is included in the latest aesara releases, which we are depending on since https://github.com/pymc-devs/pymc/pull/6059