pymc
BUG: Regression in jax translation from 5.12 -> 5.13
Describe the issue:
I have a few models where I have to do some rather complex tensor manipulation, and moving from 5.12 to 5.13 quite a few of them broke down with JAX errors.
As the models themselves are big and unwieldy, I have tried to re-create the same issue with a toy example. As you can see, it needs to be quite convoluted to elicit the error (requiring a model dimension, a call to pt.concatenate and pt.set_subtensor), but I do run into it with more complex actual use cases as well.
I have managed to work around it in some cases by avoiding pt.concatenate and instead just creating an empty tensor and setting its parts via set_subtensor, but I have one model where even that runs into issues. So it would be very nice if it worked like it used to before :)
The facts of the case:
- Toy example works with 5.12
- Toy example fails with 5.13.1
- Toy example works if using normal sampler instead of numpyro_nuts
Reproducible code example:
import pymc as pm
from pymc.sampling import jax as pm_jax
import pytensor.tensor as pt
import numpy as np

obs = np.array([
    [1,0,1,0,1,0],
    [0,1,1,0,1,0],
])
ns = [3,3]

with pm.Model() as model:
    model.add_coord('mw',range(6))
    odds = pt.zeros( (len(ns),model.dim_lengths['mw']) )
    modds = pm.Normal('N',shape=(len(ns),model.dim_lengths['mw']//2 - 1))
    modds = pt.concatenate([pt.ones_like(modds[:,:1]),modds[:,:]],axis=1)
    odds = pt.set_subtensor(odds[:,[0,2,4]],modds)
    pm.Multinomial('ov',p=pm.math.softmax(odds), n=ns, observed = obs)
    pm_jax.sample_numpyro_nuts()
Error message:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/home/velochy/salk/salk_internal_package/experiments.ipynb Cell 1 line 2
20 #odds = pt.set_subtensor(odds[:,[5,3,1]],modds)
22 pm.Multinomial('ov',p=pm.math.softmax(odds), n=ns, observed = obs)
---> 24 pm_jax.sample_numpyro_nuts()
File ~/anaconda3/envs/salk/lib/python3.12/site-packages/pymc/sampling/jax.py:567, in sample_jax_nuts(draws, tune, chains, target_accept, random_seed, initvals, jitter, model, var_names, nuts_kwargs, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, postprocessing_chunks, idata_kwargs, compute_convergence_checks, nuts_sampler)
564 raise ValueError(f"{nuts_sampler=} not recognized")
566 tic1 = datetime.now()
--> 567 raw_mcmc_samples, sample_stats, library = sampler_fn(
568 model=model,
569 target_accept=target_accept,
570 tune=tune,
571 draws=draws,
572 chains=chains,
573 chain_method=chain_method,
574 progressbar=progressbar,
575 random_seed=random_seed,
576 initial_points=initial_points,
577 nuts_kwargs=nuts_kwargs,
578 )
579 tic2 = datetime.now()
581 jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
File ~/anaconda3/envs/salk/lib/python3.12/site-packages/pymc/sampling/jax.py:484, in _sample_numpyro_nuts(model, target_accept, tune, draws, chains, chain_method, progressbar, random_seed, initial_points, nuts_kwargs)
481 if chains > 1:
482 map_seed = jax.random.split(map_seed, chains)
--> 484 pmap_numpyro.run(
485 map_seed,
486 init_params=initial_points,
487 extra_fields=(
488 "num_steps",
489 "potential_energy",
490 "energy",
491 "adapt_state.step_size",
492 "accept_prob",
493 "diverging",
494 ),
495 )
497 raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)
498 sample_stats = _numpyro_stats_to_dict(pmap_numpyro)
File ~/anaconda3/envs/salk/lib/python3.12/site-packages/numpyro/infer/mcmc.py:650, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
648 states, last_state = _laxmap(partial_map_fn, map_args)
649 elif self.chain_method == "parallel":
--> 650 states, last_state = pmap(partial_map_fn)(map_args)
651 else:
652 assert self.chain_method == "vectorized"
[... skipping hidden 12 frame]
File ~/anaconda3/envs/salk/lib/python3.12/site-packages/numpyro/infer/mcmc.py:426, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
424 # Check if _sample_fn is None, then we need to initialize the sampler.
425 if init_state is None or (getattr(self.sampler, "_sample_fn", None) is None):
--> 426 new_init_state = self.sampler.init(
427 rng_key,
428 self.num_warmup,
429 init_params,
430 model_args=args,
431 model_kwargs=kwargs,
432 )
433 init_state = new_init_state if init_state is None else init_state
434 sample_fn, postprocess_fn = self._get_cached_fns()
File ~/anaconda3/envs/salk/lib/python3.12/site-packages/numpyro/infer/hmc.py:783, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
763 hmc_init_fn = lambda init_params, rng_key: self._init_fn( # noqa: E731
764 init_params,
765 num_warmup=num_warmup,
(...)
780 rng_key=rng_key,
781 )
782 if is_prng_key(rng_key):
--> 783 init_state = hmc_init_fn(init_params, rng_key)
784 else:
785 # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
786 # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth,
787 # wa_steps because those variables do not depend on traced args: init_params, rng_key.
788 init_state = vmap(hmc_init_fn)(init_params, rng_key)
File ~/anaconda3/envs/salk/lib/python3.12/site-packages/numpyro/infer/hmc.py:763, in HMC.init.<locals>.<lambda>(init_params, rng_key)
760 dense_mass = [tuple(sorted(z))] if dense_mass else []
761 assert isinstance(dense_mass, list)
--> 763 hmc_init_fn = lambda init_params, rng_key: self._init_fn( # noqa: E731
764 init_params,
765 num_warmup=num_warmup,
766 step_size=self._step_size,
767 num_steps=self._num_steps,
768 inverse_mass_matrix=inverse_mass_matrix,
769 adapt_step_size=self._adapt_step_size,
770 adapt_mass_matrix=self._adapt_mass_matrix,
771 dense_mass=dense_mass,
772 target_accept_prob=self._target_accept_prob,
773 trajectory_length=self._trajectory_length,
774 max_tree_depth=self._max_tree_depth,
775 find_heuristic_step_size=self._find_heuristic_step_size,
776 forward_mode_differentiation=self._forward_mode_differentiation,
777 regularize_mass_matrix=self._regularize_mass_matrix,
778 model_args=model_args,
779 model_kwargs=model_kwargs,
780 rng_key=rng_key,
781 )
782 if is_prng_key(rng_key):
783 init_state = hmc_init_fn(init_params, rng_key)
File ~/anaconda3/envs/salk/lib/python3.12/site-packages/numpyro/infer/hmc.py:336, in hmc.<locals>.init_kernel(init_params, num_warmup, step_size, inverse_mass_matrix, adapt_step_size, adapt_mass_matrix, dense_mass, target_accept_prob, num_steps, trajectory_length, max_tree_depth, find_heuristic_step_size, forward_mode_differentiation, regularize_mass_matrix, model_args, model_kwargs, rng_key)
334 r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum)
335 vv_init, vv_update = velocity_verlet(pe_fn, kinetic_fn, forward_mode_ad)
--> 336 vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad)
337 energy = vv_state.potential_energy + kinetic_fn(
338 wa_state.inverse_mass_matrix, vv_state.r
339 )
340 zero_int = jnp.array(0, dtype=jnp.result_type(int))
File ~/anaconda3/envs/salk/lib/python3.12/site-packages/numpyro/infer/hmc_util.py:282, in velocity_verlet.<locals>.init_fn(z, r, potential_energy, z_grad)
274 """
275 :param z: Position of the particle.
276 :param r: Momentum of the particle.
(...)
279 :return: initial state for the integrator.
280 """
281 if potential_energy is None or z_grad is None:
--> 282 potential_energy, z_grad = _value_and_grad(
283 potential_fn, z, forward_mode_differentiation
284 )
285 return IntegratorState(z, r, potential_energy, z_grad)
File ~/anaconda3/envs/salk/lib/python3.12/site-packages/numpyro/infer/hmc_util.py:250, in _value_and_grad(f, x, forward_mode_differentiation)
248 return out, grads
249 else:
--> 250 return value_and_grad(f, has_aux=False)(x)
[... skipping hidden 8 frame]
File ~/anaconda3/envs/salk/lib/python3.12/site-packages/pymc/sampling/jax.py:156, in get_jaxified_logp.<locals>.logp_fn_wrap(x)
155 def logp_fn_wrap(x):
--> 156 return logp_fn(*x)[0]
File /tmp/tmpfrmeiqr6:11, in jax_funcified_fgraph(N)
9 tensor_variable_3 = elemwise_fn_2(tensor_variable_2, tensor_constant_1)
10 # Alloc([[1.]], 2, Sub.0)
---> 11 tensor_variable_4 = alloc(tensor_constant_2, tensor_constant_3, tensor_variable_3)
12 # Join(1, Alloc.0, N)
13 tensor_variable_5 = join(tensor_constant_4, tensor_variable_4, N)
File ~/anaconda3/envs/salk/lib/python3.12/site-packages/pytensor/link/jax/dispatch/tensor_basic.py:47, in jax_funcify_Alloc.<locals>.alloc(x, *shape)
46 def alloc(x, *shape):
---> 47 res = jnp.broadcast_to(x, shape)
48 Alloc._check_runtime_broadcast(node, jnp.asarray(x), res.shape)
49 return res
File ~/anaconda3/envs/salk/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:1222, in broadcast_to(array, shape)
1218 @util.implements(np.broadcast_to, lax_description="""\
1219 The JAX version does not necessarily return a view of the input.
1220 """)
1221 def broadcast_to(array: ArrayLike, shape: DimSize | Shape) -> Array:
-> 1222 return util._broadcast_to(array, shape)
File ~/anaconda3/envs/salk/lib/python3.12/site-packages/jax/_src/numpy/util.py:417, in _broadcast_to(arr, shape)
415 shape = (shape,)
416 # check that shape is concrete
--> 417 shape = core.canonicalize_shape(shape) # type: ignore[arg-type]
418 arr_shape = np.shape(arr)
419 if core.definitely_equal_shape(arr_shape, shape):
File ~/anaconda3/envs/salk/lib/python3.12/site-packages/jax/_src/core.py:2117, in canonicalize_shape(shape, context)
2115 except TypeError:
2116 pass
-> 2117 raise _invalid_shape_error(shape, context)
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (2, Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function _single_chain_mcmc at /home/velochy/anaconda3/envs/salk/lib/python3.12/site-packages/numpyro/infer/mcmc.py:422 for pmap. This value became a tracer due to JAX operations on these lines:
operation a:bool[] = lt b c
from line /tmp/tmpfrmeiqr6:5:24 (jax_funcified_fgraph)
operation a:i64[] = pjit[
name=_where
jaxpr={ lambda ; b:bool[] c:i64[] d:i64[]. let
e:i64[] = select_n b d c
in (e,) }
] f g h
from line /tmp/tmpfrmeiqr6:7:24 (jax_funcified_fgraph)
operation a:i64[] = sub b c
from line /tmp/tmpfrmeiqr6:9:24 (jax_funcified_fgraph)
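The TypeError above comes from jnp.broadcast_to receiving a shape entry that is a traced value rather than a concrete integer. The following is a minimal sketch of that failure mode in plain JAX, independent of PyMC. The helper name `alloc` and the shape `(2, n)` are chosen here only to mirror the failing frame in the traceback; they are not the PyTensor implementation.

```python
import jax
import jax.numpy as jnp


def alloc(x, n):
    # Broadcast a constant to shape (2, n), similar to the failing Alloc call.
    return jnp.broadcast_to(x, (2, n))


x = jnp.ones((1, 1))

# Under jit, n becomes a tracer, so the requested shape is not concrete.
try:
    jax.jit(alloc)(x, 3)
    traced_shape_ok = True
except TypeError:
    traced_shape_ok = False

# Declaring n static keeps it a plain Python int while tracing.
static_result = jax.jit(alloc, static_argnums=1)(x, 3)
```

In the PyMC case the shape entry is derived from a dimension length inside the compiled graph, so it cannot simply be marked static from the outside; it has to become a constant in the graph itself.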
PyMC version information:
Fails on 5.13.1
Context for the issue:
No response
Does using this helper first fix the problem? https://github.com/pymc-devs/pymc/blob/4bc84391893f1face230ed64241a339d4d9dbf62/pymc/model/transform/optimization.py#L23
How would one use it?
from pymc.model.transform.optimization import freeze_dims_and_data

with pm.Model() as model:
    ...

frozen_model = freeze_dims_and_data(model)
with frozen_model:
    pm.sample(nuts_sampler="numpyro")
Yes that works. Both for the toy example as well as the real model. Any downsides to doing that?
No
Well there seems to be one. It is now throwing errors if I add initvals to the models. Any workarounds for that?
Can you provide a minimum working example?
import pymc as pm
from pymc.sampling import jax as pm_jax
import pytensor.tensor as pt
import numpy as np
from pymc.model.transform.optimization import freeze_dims_and_data

obs = np.array([
    [1,0,1,0,1,0],
    [0,1,1,0,1,0],
])
ns = [3,3]

with pm.Model() as model:
    model.add_coord('mw',range(6))
    odds = pt.zeros( (len(ns),model.dim_lengths['mw']) )
    modds = pm.Normal('N',shape=(len(ns),model.dim_lengths['mw']//2 - 1),initval=obs[:,:2])
    modds = pt.concatenate([pt.ones_like(modds[:,:1]),modds[:,:]],axis=1)
    odds = pt.set_subtensor(odds[:,[0,2,4]],modds)
    #odds = pt.set_subtensor(odds[:,[5,3,1]],modds)
    pm.Multinomial('ov',p=pm.math.softmax(odds), n=ns, observed = obs)

frozen_model = freeze_dims_and_data(model)
with frozen_model:
    idata = pm_jax.sample_numpyro_nuts()
throws
NotImplementedError Traceback (most recent call last)
/home/velochy/salk/salk_internal_package/experiments.ipynb Cell 1 line 2
21 #odds = pt.set_subtensor(odds[:,[5,3,1]],modds)
23 pm.Multinomial('ov',p=pm.math.softmax(odds), n=ns, observed = obs)
---> 25 frozen_model = freeze_dims_and_data(model)
26 with frozen_model:
27 idata = pm_jax.sample_numpyro_nuts()
File ~/anaconda3/envs/salk/lib/python3.12/site-packages/pymc/model/transform/optimization.py:34, in freeze_dims_and_data(model)
23 def freeze_dims_and_data(model: Model) -> Model:
24 """Recreate a Model with fixed RV dimensions and Data values.
25
26 The dimensions of the pre-existing RVs will no longer follow changes to the coordinates.
(...)
32 are more restrictive about dynamic shapes such as JAX.
33 """
---> 34 fg, memo = fgraph_from_model(model)
36 # Replace mutable dim lengths and data by constants
37 frozen_vars = {
38 memo[dim_length]: constant(
39 dim_length.get_value(), name=dim_length.name, dtype=dim_length.type.dtype
(...)
42 if isinstance(dim_length, SharedVariable)
43 }
File ~/anaconda3/envs/salk/lib/python3.12/site-packages/pymc/model/fgraph.py:154, in fgraph_from_model(model, inlined_views)
132 """Convert Model to FunctionGraph.
133
134 See: model_from_fgraph
(...)
150 A dictionary mapping original model variables to the equivalent nodes in the fgraph.
151 """
153 if any(v is not None for v in model.rvs_to_initial_values.values()):
--> 154 raise NotImplementedError("Cannot convert models with non-default initial_values")
156 if model.parent is not None:
157 raise ValueError(
158 "Nested sub-models cannot be converted to fgraph. Convert the parent model instead"
159 )
NotImplementedError: Cannot convert models with non-default initial_values
Right we don't support custom initial values on the model transformations. You should be able to specify them when calling pm.sample instead.
Or specify them after freezing the model, with model.set_initval or something
OK. Maybe it makes sense to deprecate the initval parameter on the RVs then?
I think there was resistance to that, but it would be my preference
Um. Resistance to removing something that no longer works? Or do I misunderstand something?
It works, just not for model transformations like freeze_dims_and_data
That's why it's a NotImplementedError, not a ValueError or something like that.
And the JAX sampler only works with frozen dims if we want more complex pytensor manipulations? So by implication the JAX sampler only works for complex models if you don't use initval with RVs. Is this by design?
Most models should sample fine in JAX without frozen dims, but we expect some hiccups like the one you found. Hence why that helper was added. That helper does not work with custom initvals, but you can pass custom initvals directly to the sampler anyway.
It's not a final solution, but everyone should be able to do their thing right now.
Your model wouldn't have worked before the changes with mutable dims anyway.
You can change this line:
pt.zeros( (len(ns),model.dim_lengths['mw']) )
To:
pt.zeros( (len(ns), len(model.coords['mw'])) )
That will freeze the second dimension, instead of linking it to the mutable mw dim_length, which is the change from 5.12 to 5.13 that's hitting you.
Ok. I guess I have my answers, and you are right, I have all the tools needed to make it work. Thank you for your thorough answers @ricardoV94
The out-of-sample model pattern breaks in a JAX workflow due to the issues reported here. For example:
import pymc as pm
from pymc.model.transform.optimization import freeze_dims_and_data

with pm.Model(coords={'a':[1]}) as m:
    mu = pm.Normal('mu', dims=['a'])
    x = pm.Normal('x', mu=mu, sigma=1, dims=['a'])
    idata = pm.sample(nuts_sampler='numpyro')

with pm.Model(coords={'a':[1]}) as new_m:
    mu = pm.Flat('mu', dims=['a'])
    x = pm.Normal('x', mu, sigma=2, dims=['a'])

frozen_new_m = freeze_dims_and_data(new_m)
with frozen_new_m:
    idata_pred = pm.sample_posterior_predictive(idata, var_names=['x'],
                                                predictions=True,
                                                compile_kwargs={'mode':'JAX'})
I guess there are some automatic initvals being silently set when we use pm.sample_posterior_predictive in this way, which unexpectedly breaks freeze_dims_and_data.