Repository: pymc
Sampling with `sampling_jax.sample_blackjax_nuts` fails for `tune < 20`
Please provide a minimal, self-contained, and reproducible example.
# %%
import pymc as pm
import numpy as np
from functools import partial
from pymc import sampling_jax
import pytest
# Reproducer: on this small model, PyMC's blackjax NUTS backend raises during
# warm-up when tune=19. According to this report, tune=20 and the other
# samplers listed below work.
rstate = np.random.default_rng(seed=0)
observed = rstate.normal(size=10)

with pm.Model() as model:
    mu = pm.Normal("mu", mu=0, sigma=10)
    sigma = pm.Gamma("sigma", alpha=1, beta=1)
    y = pm.Normal("y", mu=mu, sigma=sigma, observed=observed)

    # None of the sampler calls below pass a `model=` argument, so they rely
    # on the enclosing model context and must stay indented under `with`.
    # pm.sample(10, tune=19, cores=1)  # works
    # sampling_jax.sample_numpyro_nuts(10, tune=19)  # works
    # sampling_jax.sample_blackjax_nuts(10, tune=20)  # works

    # Fails with "ValueError: scan got values with different leading axis
    # sizes: 19, 19, 18". The error is raised by jax.lax.scan inside
    # blackjax's warm-up `run` (blackjax/kernels.py), which is called from
    # pymc's `_blackjax_inference_loop` via `adapt.run(...)`.
    # NOTE(review): presumably one of the per-step warm-up sequences built for
    # fewer than 20 tuning steps is one entry short. Confirm against
    # blackjax's window-adaptation schedule.
    sampling_jax.sample_blackjax_nuts(10, tune=19)  # fails
Please provide the full traceback.
Complete error traceback
Traceback (most recent call last):
File "/Users/benedikt/code/pymc/nb_examples/number_of_samples.py", line 19, in <module>
sampling_jax.sample_blackjax_nuts(10, tune=19) # fails
File "/Users/benedikt/code/pymc/pymc/sampling_jax.py", line 331, in sample_blackjax_nuts
states, _ = map_fn(get_posterior_samples)(keys, init_params)
File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/_src/api.py", line 2144, in cache_miss
out_tree, out_flat = f_pmapped_(*args, **kwargs)
File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/_src/api.py", line 2020, in pmap_f
out = pxla.xla_pmap(
File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/core.py", line 1939, in bind
return map_bind(self, fun, *args, **params)
File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/core.py", line 1971, in map_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/core.py", line 1942, in process
return trace.process_map(self, fun, tracers, params)
File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/core.py", line 683, in process_call
return primitive.impl(f, *tracers, **params)
File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 829, in xla_pmap_impl
compiled_fun, fingerprint = parallel_callable(
File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/linear_util.py", line 286, in memoized_fun
ans = call(fun, *args)
File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 857, in parallel_callable
pmap_computation = lower_parallel_callable(
File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/_src/profiler.py", line 312, in wrapper
return func(*args, **kwargs)
File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 1030, in lower_parallel_callable
jaxpr, consts, replicas, parts, shards = stage_parallel_callable(
File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 937, in stage_parallel_callable
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/_src/profiler.py", line 312, in wrapper
return func(*args, **kwargs)
File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1946, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1892, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/_src/api.py", line 522, in cache_miss
out_flat = xla.xla_call(
File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/core.py", line 1836, in bind
return call_bind(self, fun, *args, **params)
File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/core.py", line 1852, in call_bind
outs = top_trace.process_call(primitive, fun_, tracers, params)
File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1626, in process_call
jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(f, self.main)
File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1920, in trace_to_subjaxpr_dynamic2
ans = fun.call_wrapped(*in_tracers_)
File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/benedikt/code/pymc/pymc/sampling_jax.py", line 196, in _blackjax_inference_loop
last_state, kernel, _ = adapt.run(seed, init_position)
File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/blackjax/kernels.py", line 577, in run
last_state, warmup_chain = jax.lax.scan(
File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py", line 213, in scan
raise ValueError(msg.format(', '.join(str(x.shape[0]) for x in xs_flat)))
jax._src.traceback_util.UnfilteredStackTrace: ValueError: scan got values with different leading axis sizes: 19, 19, 18.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/Users/benedikt/code/pymc/nb_examples/number_of_samples.py", line 19, in <module>
sampling_jax.sample_blackjax_nuts(10, tune=19) # fails
File "/Users/benedikt/code/pymc/pymc/sampling_jax.py", line 331, in sample_blackjax_nuts
states, _ = map_fn(get_posterior_samples)(keys, init_params)
File "/Users/benedikt/code/pymc/pymc/sampling_jax.py", line 196, in _blackjax_inference_loop
last_state, kernel, _ = adapt.run(seed, init_position)
File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/blackjax/kernels.py", line 577, in run
last_state, warmup_chain = jax.lax.scan(
ValueError: scan got values with different leading axis sizes: 19, 19, 18.
Please provide any additional information below.
Versions and main components
- PyMC/PyMC3 Version: commit 2583b7f834f2012bc4a462a576b956cfca722a7c
- Aesara/Theano Version: aesara=2.7.5
- Python Version: 3.10.5
- Operating system: macOS 10.15.7
- How did you install PyMC/PyMC3: env with conda, pymc with pip from git
- Other:
- jax=0.3.14=pyhd8ed1ab_1
- jaxlib=0.3.14=cpu_py310h52bc2dc_0
- blackjax==0.8.2
- jaxopt==0.4.3