
Reshape operation in logp graph not supported in JAX backend

Open markgoodhead opened this issue 2 years ago • 10 comments

Description of your problem

Please provide a minimal, self-contained, and reproducible example.

import pymc as pm
import pymc.sampling_jax
import numpy as np
import pandas as pd
from aesara import shared, tensor as at
from patsy import dmatrix

rng = np.random.default_rng(0)
size = 2_000
x1 = rng.normal(size=size)
x2 = rng.normal(size=size)
data = pd.DataFrame(
    {
        "x1": x1,
        "x2": x2,
        "y": rng.normal(loc=x1+x2, size=size)
    }
)
features = ["x1", "x2"]
DEGREES = 3
N_KNOT = 7
df = N_KNOT + DEGREES + 1
mat_str = ""
mat_str_end = " - 1"
mat_str_middle = " + "
np_features = data[features].values
for feature in features:
    mat_str += f"bs({feature}, df={df}, degree={DEGREES}){mat_str_middle}"
mat_str = mat_str[:-2] + mat_str_end
basis = dmatrix(mat_str, {feature: np_features[:, i] for i, feature in enumerate(features)})
# stack the per-feature spline bases into shape (n_obs, n_features, df)
dmat_data = np.asarray(basis).reshape(np_features.shape[0], np_features.shape[1], -1)
dmat = shared(dmat_data)
with pm.Model() as model:
    mutable_data = pm.MutableData("data", np_features)
    HALFNORMAL_SCALE = 1. / np.sqrt(1. - 2. / np.pi)
    mu = pm.Normal('mu_grw', 0., 1., shape=dmat.shape[1])
    delta = pm.Normal('delta_grw', 0., 0.1/2.5, shape=(dmat.shape[1], dmat.shape[2]))
    sigma = pm.HalfNormal('sigma_grw', 0.1 * HALFNORMAL_SCALE, shape=dmat.shape[1])
    grw = pm.Deterministic('grw', mu[:, None] + sigma[:, None] * delta.cumsum(axis=1))
    f = at.tensordot(dmat, grw)
    y = pm.MutableData("y", data["y"])
    eps = pm.HalfNormal("eps", sigma=1)
    normal = pm.Normal("normal", mu=f, sigma=eps, observed=y)
    results = pm.sample()
    #results = pm.sampling_jax.sample_blackjax_nuts(chain_method="vectorized")

Please provide the full traceback.

Complete error traceback
With pm.sample(): works as expected (with Ricardo's fix, see below).

For numpyro/blackjax:

Traceback (most recent call last):
    results = pm.sampling_jax.sample_blackjax_nuts(chain_method="vectorized")
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/pymc/sampling_jax.py", line 335, in sample_blackjax_nuts
    states, _ = map_fn(get_posterior_samples)(keys, init_params)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/api.py", line 1485, in vmap_f
    out_flat = batching.batch(
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/api.py", line 473, in cache_miss
    out_flat = xla.xla_call(
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/core.py", line 1765, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/core.py", line 1781, in call_bind
    outs = top_trace.process_call(primitive, fun_, tracers, params)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/interpreters/batching.py", line 226, in process_call
    vals_out = call_primitive.bind(f_, *vals, **params)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/core.py", line 1765, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/core.py", line 1781, in call_bind
    outs = top_trace.process_call(primitive, fun_, tracers, params)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/core.py", line 678, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/dispatch.py", line 182, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/linear_util.py", line 285, in memoized_fun
    ans = call(fun, *args)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/dispatch.py", line 230, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars, False,
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/dispatch.py", line 272, in lower_xla_callable
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1893, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1865, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/pymc/sampling_jax.py", line 199, in _blackjax_inference_loop
    last_state, kernel, _ = adapt.run(seed, init_position)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/blackjax/kernels.py", line 567, in run
    init_state = algorithm.init(position, logprob_fn, logprob_grad_fn)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/blackjax/mcmc/hmc.py", line 78, in init
    potential_energy, potential_energy_grad = jax.value_and_grad(potential_fn)(
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/api.py", line 995, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/api.py", line 2457, in _vjp
    out_primal, out_vjp = ad.vjp(
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/interpreters/ad.py", line 130, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/interpreters/ad.py", line 119, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 616, in trace_to_jaxpr_nounits
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/blackjax/mcmc/hmc.py", line 70, in potential_fn
    return -logprob_fn(x)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/pymc/sampling_jax.py", line 109, in logp_fn_wrap
    return logp_fn(*x)[0]
  File "/tmp/tmpwzpirvop", line 32, in jax_funcified_fgraph
    auto_129685 = reshape(auto_131175, auto_130149)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/aesara/link/jax/dispatch.py", line 731, in reshape
    return jnp.reshape(x, shape)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 727, in reshape
    return a.reshape(newshape, order=order)  # forward to method for ndarrays
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 745, in _reshape
    newshape = _compute_newshape(a, args[0] if len(args) == 1 else args)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 738, in _compute_newshape
    newshape = core.canonicalize_shape(newshape if iterable else [newshape])
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/core.py", line 1651, in canonicalize_shape
    raise _invalid_shape_error(shape, context)
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Shapes must be 1D sequences of concrete values of integer type, got [22].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
    results = pm.sampling_jax.sample_blackjax_nuts(chain_method="vectorized")
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/pymc/sampling_jax.py", line 335, in sample_blackjax_nuts
    states, _ = map_fn(get_posterior_samples)(keys, init_params)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/pymc/sampling_jax.py", line 199, in _blackjax_inference_loop
    last_state, kernel, _ = adapt.run(seed, init_position)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/blackjax/kernels.py", line 567, in run
    init_state = algorithm.init(position, logprob_fn, logprob_grad_fn)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/blackjax/mcmc/hmc.py", line 78, in init
    potential_energy, potential_energy_grad = jax.value_and_grad(potential_fn)(
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/blackjax/mcmc/hmc.py", line 70, in potential_fn
    return -logprob_fn(x)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/pymc/sampling_jax.py", line 109, in logp_fn_wrap
    return logp_fn(*x)[0]
  File "/tmp/tmpwzpirvop", line 32, in jax_funcified_fgraph
    auto_129685 = reshape(auto_131175, auto_130149)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/aesara/link/jax/dispatch.py", line 731, in reshape
    return jnp.reshape(x, shape)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 727, in reshape
    return a.reshape(newshape, order=order)  # forward to method for ndarrays
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 745, in _reshape
    newshape = _compute_newshape(a, args[0] if len(args) == 1 else args)
  File "/home/mark/anaconda3/envs/mark1/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 738, in _compute_newshape
    newshape = core.canonicalize_shape(newshape if iterable else [newshape])
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [22].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.

Please provide any additional information below.

Versions and main components

  • PyMC/PyMC3 Version: 4.0.1
  • Aesara/Theano Version: 2.7.3
  • Python Version: 3.9
  • Operating system: Linux
  • How did you install PyMC/PyMC3: (conda/pip) pip

markgoodhead avatar Jun 25 '22 16:06 markgoodhead

This should be fixed in version 4.0.1, which we released a couple of days ago. Mind giving it a try?

ricardoV94 avatar Jun 25 '22 17:06 ricardoV94

Just upgraded, and I got the same errors as before for both pm.sample() and the JAX samplers.

markgoodhead avatar Jun 25 '22 18:06 markgoodhead

Your model is misspecified: eps is allowed to be negative, which leads to 0 probability when sampling starts. You can change the eps prior to a positive-only distribution like HalfNormal, or transform it through pm.math.exp, for example.
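
For example, a minimal sketch of the two options (the eps_raw name is just illustrative):

eps = pm.HalfNormal("eps", sigma=1)  # option 1: positive by construction

# option 2: keep an unconstrained prior and map it to the positive reals
# eps_raw = pm.Normal("eps_raw", 0.0, 1.0)
# eps = pm.Deterministic("eps", pm.math.exp(eps_raw))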

ricardoV94 avatar Jun 25 '22 18:06 ricardoV94

Ah, apologies: that was a bad modification from my real model to the reproduction script. Thanks! With that change (original post edited), pm.sample() works fine and the JAX samplers give the same error.

markgoodhead avatar Jun 25 '22 18:06 markgoodhead

Seems like your model is using reshape at some point, which can't be safely jitted. This is a common issue with the JAX backend. Do you get the same error with sample_numpyro_nuts? I expect so.
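
A minimal standalone illustration of the underlying JAX constraint (this is not PyMC code, just jax):

import jax
import jax.numpy as jnp

@jax.jit
def traced_reshape(x, shape):
    # under jit, `shape` arrives as a traced array, not concrete integers
    return jnp.reshape(x, shape)

# traced_reshape(jnp.arange(6), jnp.array([2, 3]))
# -> TypeError: Shapes must be 1D sequences of concrete values of integer type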

ricardoV94 avatar Jun 25 '22 19:06 ricardoV94

We should probably patch this upstream, but in the meantime you can try to run this snippet before you sample with JAX:

import jax
from aesara.graph import Constant
from aesara.link.jax.dispatch import jax_funcify
from aesara.tensor.shape import Reshape


@jax_funcify.register(Reshape)
def jax_funcify_Reshape(op, node, **kwargs):
    shape = node.inputs[1]
    if isinstance(shape, Constant):
        # The target shape is a graph constant: close over its concrete
        # values so JAX never sees a traced array as the shape argument.
        constant_shape = shape.data

        def reshape(x, _):
            return jax.numpy.reshape(x, constant_shape)

    else:

        def reshape(x, shape):
            return jax.numpy.reshape(x, shape)

    return reshape
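
To be explicit about usage: run the snippet once, e.g. right after your imports, and then sample as usual. It re-registers the Aesara-to-JAX dispatch for Reshape so that, when the target shape is a graph Constant, the concrete numbers in shape.data are baked in directly rather than passed as a traced shape argument.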

ricardoV94 avatar Jun 25 '22 19:06 ricardoV94

> We should probably patch this upstream, but in the meantime you can try to run this snippet before you sample with JAX: [snippet quoted above]

Ah, thanks! Which part of the model should I run that on? I tried rewriting the model to remove what I thought were the implicit reshapes, but I got the same error:

import pymc as pm
import pymc.sampling_jax
import numpy as np
import pandas as pd
from aesara import shared, tensor as at
from patsy import dmatrix

rng = np.random.default_rng(0)
size = 2_000
x1 = rng.normal(size=size)
x2 = rng.normal(size=size)
data = pd.DataFrame(
    {
        "x1": x1,
        "x2": x2,
        "y": rng.normal(loc=x1+x2, size=size)
    }
)
features = ["x1", "x2"]
DEGREES = 3
N_KNOT = 7
df = N_KNOT + DEGREES + 1
mat_str = ""
mat_str_end = " - 1"
mat_str_middle = " + "
np_features = data[features].values
for feature in features:
    mat_str += f"bs({feature}, df={df}, degree={DEGREES}){mat_str_middle}"
mat_str = mat_str[:-2] + mat_str_end
basis = dmatrix(mat_str, {feature: np_features[:, i] for i, feature in enumerate(features)})
dmat_data = np.asarray(basis).reshape(np_features.shape[0], np_features.shape[1], -1)
dmat = shared(dmat_data)
with pm.Model() as model:
    mutable_data = pm.MutableData("data", np_features)
    HALFNORMAL_SCALE = 1. / np.sqrt(1. - 2. / np.pi)
    mu = pm.Normal('mu_grw', 0., 1., shape=(dmat.shape[1], 1))
    delta = pm.Normal('delta_grw', 0., 0.1/2.5, shape=(dmat.shape[1], dmat.shape[2]))
    sigma = pm.HalfNormal('sigma_grw', 0.1 * HALFNORMAL_SCALE, shape=(dmat.shape[1], 1))
    grw = pm.Deterministic('grw', mu + sigma * delta.cumsum(axis=1))
    f = at.tensordot(dmat, grw)
    y = pm.MutableData("y", data["y"])
    eps = pm.HalfNormal("eps", sigma=1)
    normal = pm.Normal("normal", mu=f, sigma=eps, observed=y)
    results = pm.sampling_jax.sample_blackjax_nuts(chain_method="vectorized")

The only reshape remaining that I can see in that code is in the dmat_data = np.asarray(basis).reshape(np_features.shape[0], np_features.shape[1], -1) line, but that's on pure numpy before it's passed into aesara via shared(), so I'm not sure how it would be picked up by JAX. I presume I'm missing something here?

markgoodhead avatar Jun 25 '22 19:06 markgoodhead

You can put the code snippet anywhere before you call sample. The reshape may be added when defining the logp or dlogp without ever being present in the original model.
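
If you want to check, one way (a debugging sketch, assuming the model from the example above is in scope) is to print the compiled logp graph and look for Reshape nodes:

import aesara

aesara.dprint(model.logp())  # any Reshape ops will appear in this printout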

ricardoV94 avatar Jun 25 '22 20:06 ricardoV94

the tensordot call is probably what is introducing the reshape operation in your graph: https://github.com/aesara-devs/aesara/blob/7393b7441601eaad98bc0cb494aa8fba2ea4bf6a/aesara/tensor/math.py#L2213
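
Roughly speaking (a NumPy sketch of the idea, not Aesara's actual implementation), tensordot contracts axes by reshaping both operands into matrices and taking a dot product:

import numpy as np

a = np.random.rand(5, 3, 4)  # stands in for dmat
b = np.random.rand(3, 4)     # stands in for grw
out = a.reshape(5, 12) @ b.reshape(12)  # these reshapes are what end up in the graph
assert np.allclose(out, np.tensordot(a, b, axes=2))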

Anyway, I think the code snippet I shared should free you of the errors.

ricardoV94 avatar Jun 26 '22 05:06 ricardoV94

Yes I can confirm that code snippet fixes things - thanks very much @ricardoV94!

markgoodhead avatar Jun 26 '22 08:06 markgoodhead

Should be fixed by https://github.com/aesara-devs/aesara/pull/1111

ricardoV94 avatar Aug 12 '22 10:08 ricardoV94

The fix is included in the latest aesara releases, which we have depended on since https://github.com/pymc-devs/pymc/pull/6059

ricardoV94 avatar Aug 25 '22 09:08 ricardoV94