pymc icon indicating copy to clipboard operation
pymc copied to clipboard

Sampling with `sampling_jax.sample_blackjax_nuts` fails for `tune < 20`

Open bherwerth opened this issue 2 years ago • 0 comments

Please provide a minimal, self-contained, and reproducible example.

# %%
import pymc as pm
import numpy as np
from functools import partial  # NOTE(review): not used anywhere in this reproducer

from pymc import sampling_jax
import pytest  # NOTE(review): not used anywhere in this reproducer

rstate = np.random.default_rng(seed=0)  # fixed seed so the synthetic data are reproducible
observed = rstate.normal(size=10)  # 10 synthetic observations used as the observed data
with pm.Model() as model:
    mu = pm.Normal("mu", mu=0, sigma=10)  # prior on the mean of the observations
    sigma = pm.Gamma("sigma", alpha=1, beta=1)  # prior on the standard deviation of the observations
    y = pm.Normal("y", mu=mu, sigma=sigma, observed=observed)  # likelihood tying mu/sigma to the data

    # pm.sample(10, tune=19, cores=1) # works
    # sampling_jax.sample_numpyro_nuts(10, tune=19) # works
    # sampling_jax.sample_blackjax_nuts(10, tune=20) # works
    sampling_jax.sample_blackjax_nuts(10, tune=19) # fails: ValueError raised from jax.lax.scan during adaptation (see traceback)

Please provide the full traceback.

Complete error traceback
Traceback (most recent call last):
  File "/Users/benedikt/code/pymc/nb_examples/number_of_samples.py", line 19, in <module>
    sampling_jax.sample_blackjax_nuts(10, tune=19) # fails
  File "/Users/benedikt/code/pymc/pymc/sampling_jax.py", line 331, in sample_blackjax_nuts
    states, _ = map_fn(get_posterior_samples)(keys, init_params)
  File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/_src/api.py", line 2144, in cache_miss
    out_tree, out_flat = f_pmapped_(*args, **kwargs)
  File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/_src/api.py", line 2020, in pmap_f
    out = pxla.xla_pmap(
  File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/core.py", line 1939, in bind
    return map_bind(self, fun, *args, **params)
  File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/core.py", line 1971, in map_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/core.py", line 1942, in process
    return trace.process_map(self, fun, tracers, params)
  File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/core.py", line 683, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 829, in xla_pmap_impl
    compiled_fun, fingerprint = parallel_callable(
  File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/linear_util.py", line 286, in memoized_fun
    ans = call(fun, *args)
  File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 857, in parallel_callable
    pmap_computation = lower_parallel_callable(
  File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/_src/profiler.py", line 312, in wrapper
    return func(*args, **kwargs)
  File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 1030, in lower_parallel_callable
    jaxpr, consts, replicas, parts, shards = stage_parallel_callable(
  File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 937, in stage_parallel_callable
    jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
  File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/_src/profiler.py", line 312, in wrapper
    return func(*args, **kwargs)
  File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1946, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1892, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/_src/api.py", line 522, in cache_miss
    out_flat = xla.xla_call(
  File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/core.py", line 1836, in bind
    return call_bind(self, fun, *args, **params)
  File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/core.py", line 1852, in call_bind
    outs = top_trace.process_call(primitive, fun_, tracers, params)
  File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1626, in process_call
    jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(f, self.main)
  File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1920, in trace_to_subjaxpr_dynamic2
    ans = fun.call_wrapped(*in_tracers_)
  File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/benedikt/code/pymc/pymc/sampling_jax.py", line 196, in _blackjax_inference_loop
    last_state, kernel, _ = adapt.run(seed, init_position)
  File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/blackjax/kernels.py", line 577, in run
    last_state, warmup_chain = jax.lax.scan(
  File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py", line 213, in scan
    raise ValueError(msg.format(', '.join(str(x.shape[0]) for x in xs_flat)))
jax._src.traceback_util.UnfilteredStackTrace: ValueError: scan got values with different leading axis sizes: 19, 19, 18.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/benedikt/code/pymc/nb_examples/number_of_samples.py", line 19, in <module>
    sampling_jax.sample_blackjax_nuts(10, tune=19) # fails
  File "/Users/benedikt/code/pymc/pymc/sampling_jax.py", line 331, in sample_blackjax_nuts
    states, _ = map_fn(get_posterior_samples)(keys, init_params)
  File "/Users/benedikt/code/pymc/pymc/sampling_jax.py", line 196, in _blackjax_inference_loop
    last_state, kernel, _ = adapt.run(seed, init_position)
  File "/Users/benedikt/miniforge3/envs/pymc-dev/lib/python3.10/site-packages/blackjax/kernels.py", line 577, in run
    last_state, warmup_chain = jax.lax.scan(
ValueError: scan got values with different leading axis sizes: 19, 19, 18.

Please provide any additional information below.

Versions and main components

  • PyMC/PyMC3 Version: commit 2583b7f834f2012bc4a462a576b956cfca722a7c
  • Aesara/Theano Version: aesara=2.7.5
  • Python Version: 3.10.5
  • Operating system: macOS 10.15.7
  • How did you install PyMC/PyMC3: env with conda, pymc with pip from git
  • Other:
    • jax=0.3.14=pyhd8ed1ab_1
    • jaxlib=0.3.14=cpu_py310h52bc2dc_0
      • blackjax==0.8.2
      • jaxopt==0.4.3

bherwerth avatar Jul 15 '22 20:07 bherwerth