pymc
BUG: Forward sampling with dims fails when mode="JAX"
Describe the issue:
Shapes aren't being correctly set on variables when using coords with the JAX backend. I guess this is a consequence of coords being mutable by default, and it could be addressed by using freeze_dims_and_data as in #7263. If this is the case, perhaps we should check for the mode='JAX' compile_kwarg in forward samplers and raise early with a more informative error?
Reproducible code example:
```python
import pymc as pm

# Fails
with pm.Model(coords={'a': ['1']}) as m:
    x = pm.Normal('x', dims=['a'])
    pm.sample_prior_predictive(compile_kwargs={'mode': 'JAX'})

# Works
with pm.Model() as m:
    x = pm.Normal('x', shape=(1,))
    pm.sample_prior_predictive(compile_kwargs={'mode': 'JAX'})
```
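For reference, a sketch of the freeze_dims_and_data workaround mentioned above; this assumes the helper added in #7263 is importable from pymc.model.transform.optimization:

```python
import pymc as pm
from pymc.model.transform.optimization import freeze_dims_and_data  # helper from #7263

with pm.Model(coords={'a': ['1']}) as m:
    x = pm.Normal('x', dims=['a'])

# Freezing dims (and data) replaces the shared dim-length variables with constants,
# so the compiled JAX function only sees static shapes.
with freeze_dims_and_data(m):
    pm.sample_prior_predictive(compile_kwargs={'mode': 'JAX'})
```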
Error message:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/utils.py:196, in streamline.<locals>.streamline_default_f()
193 for thunk, node, old_storage in zip(
194 thunks, order, post_thunk_old_storage
195 ):
--> 196 thunk()
197 for old_s in old_storage:
File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/basic.py:660, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
654 def thunk(
655 fgraph=self.fgraph,
656 fgraph_jit=fgraph_jit,
657 thunk_inputs=thunk_inputs,
658 thunk_outputs=thunk_outputs,
659 ):
--> 660 outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
662 for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
[... skipping hidden 11 frame]
File /var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/tmpjm3tlztx:5, in jax_funcified_fgraph(random_generator_shared_variable, a)
4 # normal_rv{0, (0, 0), floatX, False}(RandomGeneratorSharedVariable(<Generator(PCG64) at 0x29A781FC0>), JAXShapeTuple.0, 11, 0, 1.0)
----> 5 variable, x = sample_fn(random_generator_shared_variable, tensor_variable, tensor_constant, tensor_constant_1, tensor_constant_2)
6 return x, variable
File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/jax/dispatch/random.py:106, in jax_funcify_RandomVariable.<locals>.sample_fn(rng, size, dtype, *parameters)
105 def sample_fn(rng, size, dtype, *parameters):
--> 106 return jax_sample_fn(op)(rng, size, out_dtype, *parameters)
File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/jax/dispatch/random.py:164, in jax_sample_fn_loc_scale.<locals>.sample_fn(rng, size, dtype, *parameters)
163 loc, scale = parameters
--> 164 sample = loc + jax_op(sampling_key, size, dtype) * scale
165 rng["jax_state"] = rng_key
File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/jax/_src/random.py:710, in normal(key, shape, dtype)
709 dtype = dtypes.canonicalize_dtype(dtype)
--> 710 shape = core.as_named_shape(shape)
711 return _normal(key, shape, dtype)
[... skipping hidden 2 frame]
File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/jax/_src/core.py:2142, in canonicalize_shape(shape, context)
2141 pass
-> 2142 raise _invalid_shape_error(shape, context)
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=1/0)>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function jax_funcified_fgraph at /var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/tmpjm3tlztx:1 for jit. This concrete value was not available in Python because it depends on the value of the argument a.
During handling of the above exception, another exception occurred:
TypeError Traceback (most recent call last)
Cell In[19], line 1
----> 1 pm.draw(x, mode='JAX')
File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pymc/sampling/forward.py:314, in draw(vars, draws, random_seed, **kwargs)
311 draw_fn = compile_pymc(inputs=[], outputs=vars, random_seed=random_seed, **kwargs)
313 if draws == 1:
--> 314 return draw_fn()
316 # Single variable output
317 if not isinstance(vars, list | tuple):
File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/compile/function/types.py:970, in Function.__call__(self, *args, **kwargs)
967 t0_fn = time.perf_counter()
968 try:
969 outputs = (
--> 970 self.vm()
971 if output_subset is None
972 else self.vm(output_subset=output_subset)
973 )
974 except Exception:
975 restore_defaults()
File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/utils.py:200, in streamline.<locals>.streamline_default_f()
198 old_s[0] = None
199 except Exception:
--> 200 raise_with_op(fgraph, node, thunk)
File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/utils.py:523, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
518 warnings.warn(
519 f"{exc_type} error does not allow us to add an extra error message"
520 )
521 # Some exception need extra parameter in inputs. So forget the
522 # extra long error message in that case.
--> 523 raise exc_value.with_traceback(exc_trace)
File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/utils.py:196, in streamline.<locals>.streamline_default_f()
192 try:
193 for thunk, node, old_storage in zip(
194 thunks, order, post_thunk_old_storage
195 ):
--> 196 thunk()
197 for old_s in old_storage:
198 old_s[0] = None
File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/basic.py:660, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
654 def thunk(
655 fgraph=self.fgraph,
656 fgraph_jit=fgraph_jit,
657 thunk_inputs=thunk_inputs,
658 thunk_outputs=thunk_outputs,
659 ):
--> 660 outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
662 for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
663 compute_map[o_var][0] = True
[... skipping hidden 11 frame]
File /var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/tmpjm3tlztx:5, in jax_funcified_fgraph(random_generator_shared_variable, a)
3 tensor_variable = shape_tuple_fn(a)
4 # normal_rv{0, (0, 0), floatX, False}(RandomGeneratorSharedVariable(<Generator(PCG64) at 0x29A781FC0>), JAXShapeTuple.0, 11, 0, 1.0)
----> 5 variable, x = sample_fn(random_generator_shared_variable, tensor_variable, tensor_constant, tensor_constant_1, tensor_constant_2)
6 return x, variable
File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/jax/dispatch/random.py:106, in jax_funcify_RandomVariable.<locals>.sample_fn(rng, size, dtype, *parameters)
105 def sample_fn(rng, size, dtype, *parameters):
--> 106 return jax_sample_fn(op)(rng, size, out_dtype, *parameters)
File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pytensor/link/jax/dispatch/random.py:164, in jax_sample_fn_loc_scale.<locals>.sample_fn(rng, size, dtype, *parameters)
162 rng_key, sampling_key = jax.random.split(rng_key, 2)
163 loc, scale = parameters
--> 164 sample = loc + jax_op(sampling_key, size, dtype) * scale
165 rng["jax_state"] = rng_key
166 return (rng, sample)
File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/jax/_src/random.py:710, in normal(key, shape, dtype)
707 raise ValueError(f"dtype argument to `normal` must be a float or complex dtype, "
708 f"got {dtype}")
709 dtype = dtypes.canonicalize_dtype(dtype)
--> 710 shape = core.as_named_shape(shape)
711 return _normal(key, shape, dtype)
[... skipping hidden 2 frame]
File ~/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/jax/_src/core.py:2142, in canonicalize_shape(shape, context)
2140 except TypeError:
2141 pass
-> 2142 raise _invalid_shape_error(shape, context)
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=1/0)>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function jax_funcified_fgraph at /var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/tmpjm3tlztx:1 for jit. This concrete value was not available in Python because it depends on the value of the argument a.
Apply node that caused the error: normal_rv{0, (0, 0), floatX, False}(RandomGeneratorSharedVariable(<Generator(PCG64) at 0x29A781FC0>), JAXShapeTuple.0, 11, 0, 1.0)
Toposort index: 1
Inputs types: [RandomGeneratorType, TensorType(int64, shape=(1,)), TensorType(int64, shape=()), TensorType(int8, shape=()), TensorType(float32, shape=())]
Inputs shapes: ['No shapes', ()]
Inputs strides: ['No strides', ()]
Inputs values: [{'bit_generator': 1, 'state': {'state': 5504079417979030970, 'inc': 4407794720271215875}, 'has_uint32': 0, 'uinteger': 0, 'jax_state': array([1281518353, 2620247482], dtype=uint32)}, array(1)]
Outputs clients: [['output'], ['output']]
Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
File "/Users/jessegrabowski/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
coro.send(None)
File "/Users/jessegrabowski/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3334, in run_cell_async
has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
File "/Users/jessegrabowski/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3517, in run_ast_nodes
if await self.run_code(code, result, async_=asy):
File "/Users/jessegrabowski/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "/var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/ipykernel_14410/452414321.py", line 2, in <module>
x = pm.Normal('x', dims=['a'])
File "/Users/jessegrabowski/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pymc/distributions/distribution.py", line 554, in __new__
rv_out = cls.dist(*args, **kwargs)
File "/Users/jessegrabowski/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pymc/distributions/continuous.py", line 511, in dist
return super().dist([mu, sigma], **kwargs)
File "/Users/jessegrabowski/mambaforge/envs/pymc-experimental/lib/python3.12/site-packages/pymc/distributions/distribution.py", line 633, in dist
rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
PyMC version information:
Context for the issue:
No response
Shapes are being set correctly; it's just that, as you said, they are mutable and JAX simply doesn't support that, at least not without static_argnums. When we write x = pm.Normal("x", dims="trial"), PyMC generates x = pt.random.normal(size=trial_length), where trial_length is a shared scalar variable.
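For illustration, a minimal sketch of what that looks like on the model graph, assuming Model.dim_lengths is where the shared length variables live:

```python
import pymc as pm

with pm.Model(coords={"trial": range(3)}) as m:
    x = pm.Normal("x", dims="trial")

# The dim length is a shared scalar variable rather than a constant, so the RV's
# size is symbolic; under jit, JAX refuses to use a traced value as a concrete shape.
trial_length = m.dim_lengths["trial"]
print(trial_length.eval())  # 3, but only known at runtime as far as the graph is concerned
```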
> If this is the case, perhaps we should check for the mode='JAX' compile_kwarg in forward samplers and raise early with a more informative error?
That may be a bit cumbersome? You don't want to raise if the shapes are constant, which happens after freeze_dims_and_data or if the user specified shape in addition to dims, and introspecting the graph to assess which case applies could be messy/costly.
I guess I don't understand why the shared variable dims aren't replaced before forward sampling. At that point the shapes should all be known.
> I guess I don't understand why the shared variable dims aren't replaced before forward sampling. At that point the shapes should all be known.
We never did that. We do it for MCMC sampling in the JAX samplers, because that's a specific code path.
We could do that here as well, although I think the explicit freeze approach is better. I'm thinking of reintroducing caching, and then being able to compile a function that works for multiple dim lengths becomes very useful: https://github.com/pymc-devs/pymc/discussions/7177
Well it wasn't necessary before, because shapes induced by coords were fixed by default
> Well it wasn't necessary before, because shapes induced by coords were fixed by default
Yes, although even earlier they were also mutable by default.
The JAX backend for forward sampling is still a niche use case; I wouldn't say we were officially supporting it yet.
I'm also reluctant to put too much work into working around JAX's inflexibility.
One idea would be to go into PyTensor and be more clever about static_argnums: we could do some cheap checks to see whether a variable is used directly as the shape of an Op (RVs, Alloc) and mark that variable as a static_argnum when compiling the jitted function. This only works for scalar variables, and we usually use 0d arrays, but maybe that's something we could work around.
That is a more general QoL improvement as well?
Yes, I agree that we should be using static_argnums. That would be using JAX's own work-around for the static shape problem, so the functions we produce that way would be no worse than a native JAX solution, which seems fine to me.
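To make the static_argnums idea concrete, here is a JAX-only sketch of the mechanism being proposed (no PyTensor involved, just the underlying JAX behaviour):

```python
import jax

def draw(key, n):
    # n is used directly as a shape, so it must be concrete under jit
    return jax.random.normal(key, shape=(n,))

key = jax.random.PRNGKey(0)

# jax.jit(draw)(key, 3) fails with the same "Shapes must be 1D sequences of
# concrete values" error as above, because n is a traced value.

# Marking n as static makes jit specialize on each concrete value, which it
# can then use as a shape (at the cost of recompiling per distinct length).
draw_jit = jax.jit(draw, static_argnums=1)
print(draw_jit(key, 3).shape)  # (3,)
```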