pymc
JAX backend fails for latent scan variables
Describe the issue:
I'm not sure whether this belongs here or in the pytensor repo. I'm putting it here because the minimal example I could come up with uses PyMC. If you create a Scan variable, register it without observations, and then use it in further computation, the graph fails to compile.
Reproducible code example:
import numpy as np
import pymc as pm
import pytensor
from pytensor.compile.mode import get_mode
from pymc.pytensorf import collect_default_updates

true_sigma = 0.1
true_eta = 0.25
# GRW with observation noise:
test_mu = np.random.normal(scale=true_sigma, size=100).cumsum()
test_obs = np.random.normal(loc=test_mu, scale=true_eta)

with pm.Model() as model:
    x0 = pm.Normal('x0')
    sigma = pm.HalfNormal('sigma')
    eta = pm.HalfNormal('eta')
    def step(*args):
        last_x, sigma = args
        x = pm.Normal.dist(mu=last_x, sigma=sigma)
        return x, collect_default_updates(args, [x])

    traj, updates = pytensor.scan(step,
                                  outputs_info=[x0],
                                  non_sequences=[sigma],
                                  n_steps=100,
                                  mode=get_mode('JAX'))

    model.register_rv(traj, name='traj', initval='prior')
    obs = pm.Normal('obs', mu=traj, sigma=eta, observed=test_obs)
    idata = pm.sample(nuts_sampler='numpyro')
Error message:
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/jax/_src/api_util.py:563, in shaped_abstractify(x)
562 try:
--> 563 return _shaped_abstractify_handlers[type(x)](x)
564 except KeyError:
KeyError: <class 'numpy.random._generator.Generator'>
During handling of the above exception, another exception occurred:
TypeError Traceback (most recent call last)
File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/link/utils.py:216, in streamline.<locals>.streamline_nice_errors_f()
215 for thunk, node in zip(thunks, order):
--> 216 thunk()
217 except Exception:
File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/link/basic.py:669, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
663 def thunk(
664 fgraph=self.fgraph,
665 fgraph_jit=fgraph_jit,
666 thunk_inputs=thunk_inputs,
667 thunk_outputs=thunk_outputs,
668 ):
--> 669 outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
671 for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
[... skipping hidden 6 frame]
File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/jax/_src/api_util.py:554, in _shaped_abstractify_slow(x)
553 else:
--> 554 raise TypeError(
555 f"Cannot interpret value of type {type(x)} as an abstract array; it "
556 "does not have a dtype attribute")
557 return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type,
558 named_shape=named_shape)
TypeError: Cannot interpret value of type <class 'numpy.random._generator.Generator'> as an abstract array; it does not have a dtype attribute
During handling of the above exception, another exception occurred:
TypeError Traceback (most recent call last)
File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/link/vm.py:414, in Loop.__call__(self)
411 for thunk, node, old_storage in zip_longest(
412 self.thunks, self.nodes, self.post_thunk_clear, fillvalue=()
413 ):
--> 414 thunk()
415 for old_s in old_storage:
File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/scan/op.py:1657, in Scan.make_thunk.<locals>.rval(p, i, o, n, allow_gc)
1654 def rval(
1655 p=p, i=node_input_storage, o=node_output_storage, n=node, allow_gc=allow_gc
1656 ):
-> 1657 r = p(n, [x[0] for x in i], o)
1658 for o in node.outputs:
File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/scan/op.py:1918, in Scan.perform(self, node, inputs, output_storage, params)
1917 try:
-> 1918 vm()
1919 except Exception:
File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/link/utils.py:218, in streamline.<locals>.streamline_nice_errors_f()
217 except Exception:
--> 218 raise_with_op(fgraph, node, thunk)
File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/link/utils.py:535, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
533 # Some exception need extra parameter in inputs. So forget the
534 # extra long error message in that case.
--> 535 raise exc_value.with_traceback(exc_trace)
File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/link/utils.py:216, in streamline.<locals>.streamline_nice_errors_f()
215 for thunk, node in zip(thunks, order):
--> 216 thunk()
217 except Exception:
File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/link/basic.py:669, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
663 def thunk(
664 fgraph=self.fgraph,
665 fgraph_jit=fgraph_jit,
666 thunk_inputs=thunk_inputs,
667 thunk_outputs=thunk_outputs,
668 ):
--> 669 outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
671 for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
[... skipping hidden 6 frame]
File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/jax/_src/api_util.py:554, in _shaped_abstractify_slow(x)
553 else:
--> 554 raise TypeError(
555 f"Cannot interpret value of type {type(x)} as an abstract array; it "
556 "does not have a dtype attribute")
557 return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type,
558 named_shape=named_shape)
TypeError: Cannot interpret value of type <class 'numpy.random._generator.Generator'> as an abstract array; it does not have a dtype attribute
Apply node that caused the error: normal_rv{0, (0, 0), floatX, False}(*1-<RandomGeneratorType>, TensorConstant{[]}, TensorConstant{11}, *0-<TensorType(float64, ())>, *2-<TensorType(float64, ())>)
Toposort index: 0
Inputs types: [RandomGeneratorType, TensorType(int64, (0,)), TensorType(int64, ()), TensorType(float64, ()), TensorType(float64, ())]
Inputs shapes: [(), 'No shapes', ()]
Inputs strides: [(), 'No strides', ()]
Inputs values: [array(0.), Generator(PCG64) at 0x17E245FC0, array(1.)]
Outputs clients: [['output'], ['output']]
Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3269, in run_cell_async
has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3448, in run_ast_nodes
if await self.run_code(code, result, async_=asy):
File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3508, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "/var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/ipykernel_32164/1022036140.py", line 22, in <module>
traj, updates = pytensor.scan(step,
File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/scan/basic.py", line 852, in scan
raw_inner_outputs = fn(*args)
File "/var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/ipykernel_32164/1022036140.py", line 19, in step
x = pm.Normal.dist(mu=last_x, sigma=sigma)
File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pymc/distributions/continuous.py", line 520, in dist
return super().dist([mu, sigma], **kwargs)
File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pymc/distributions/distribution.py", line 389, in dist
rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
During handling of the above exception, another exception occurred:
TypeError Traceback (most recent call last)
Cell In[2], line 30
28 model.register_rv(traj, name='traj', initval='prior')
29 obs = pm.Normal('obs', mu=traj, sigma=eta, observed=test_obs)
---> 30 idata = pm.sample(nuts_sampler='numpyro')
File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pymc/sampling/mcmc.py:564, in sample(draws, tune, chains, cores, random_seed, progressbar, step, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, model, **kwargs)
561 auto_nuts_init = False
563 initial_points = None
--> 564 step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs)
566 if nuts_sampler != "pymc":
567 if not isinstance(step, NUTS):
File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pymc/sampling/mcmc.py:203, in assign_step_methods(model, step, methods, step_kwargs)
195 selected = max(
196 methods,
197 key=lambda method, var=rv_var, has_gradient=has_gradient: method._competence(
198 var, has_gradient
199 ),
200 )
201 selected_steps[selected].append(var)
--> 203 return instantiate_steppers(model, steps, selected_steps, step_kwargs)
File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pymc/sampling/mcmc.py:116, in instantiate_steppers(model, steps, selected_steps, step_kwargs)
114 args = step_kwargs.get(step_class.name, {})
115 used_keys.add(step_class.name)
--> 116 step = step_class(vars=vars, model=model, **args)
117 steps.append(step)
119 unused_args = set(step_kwargs).difference(used_keys)
File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pymc/step_methods/hmc/nuts.py:180, in NUTS.__init__(self, vars, max_treedepth, early_max_treedepth, **kwargs)
122 def __init__(self, vars=None, max_treedepth=10, early_max_treedepth=8, **kwargs):
123 r"""Set up the No-U-Turn sampler.
124
125 Parameters
(...)
178 `pm.sample` to the desired number of tuning steps.
179 """
--> 180 super().__init__(vars, **kwargs)
182 self.max_treedepth = max_treedepth
183 self.early_max_treedepth = early_max_treedepth
File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pymc/step_methods/hmc/base_hmc.py:109, in BaseHMC.__init__(self, vars, scaling, step_scale, is_cov, model, blocked, potential, dtype, Emax, target_accept, gamma, k, t0, adapt_step_size, step_rand, **pytensor_kwargs)
107 else:
108 vars = get_value_vars_from_user_vars(vars, self._model)
--> 109 super().__init__(vars, blocked=blocked, model=self._model, dtype=dtype, **pytensor_kwargs)
111 self.adapt_step_size = adapt_step_size
112 self.Emax = Emax
File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pymc/step_methods/arraystep.py:164, in GradientSharedStep.__init__(self, vars, model, blocked, dtype, logp_dlogp_func, **pytensor_kwargs)
161 model = modelcontext(model)
163 if logp_dlogp_func is None:
--> 164 func = model.logp_dlogp_function(vars, dtype=dtype, **pytensor_kwargs)
165 else:
166 func = logp_dlogp_func
File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pymc/model.py:649, in Model.logp_dlogp_function(self, grad_vars, tempered, **kwargs)
646 costs = [self.logp()]
648 input_vars = {i for i in graph_inputs(costs) if not isinstance(i, Constant)}
--> 649 ip = self.initial_point(0)
650 extra_vars_and_values = {
651 var: ip[var.name]
652 for var in self.value_vars
653 if var in input_vars and var not in grad_vars
654 }
655 return ValueGradFunction(costs, grad_vars, extra_vars_and_values, **kwargs)
File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pymc/model.py:1133, in Model.initial_point(self, random_seed)
1120 """Computes the initial point of the model.
1121
1122 Parameters
(...)
1130 Maps names of transformed variables to numeric initial values in the transformed space.
1131 """
1132 fn = make_initial_point_fn(model=self, return_transformed=True)
-> 1133 return Point(fn(random_seed), model=self)
File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pymc/initial_point.py:169, in make_initial_point_fn.<locals>.make_seeded_function.<locals>.inner(seed, *args, **kwargs)
166 @functools.wraps(func)
167 def inner(seed, *args, **kwargs):
168 reseed_rngs(rngs, seed)
--> 169 values = func(*args, **kwargs)
170 return dict(zip(varnames, values))
File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/compile/function/types.py:970, in Function.__call__(self, *args, **kwargs)
967 t0_fn = time.perf_counter()
968 try:
969 outputs = (
--> 970 self.vm()
971 if output_subset is None
972 else self.vm(output_subset=output_subset)
973 )
974 except Exception:
975 restore_defaults()
File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/link/vm.py:418, in Loop.__call__(self)
416 old_s[0] = None
417 except Exception:
--> 418 raise_with_op(self.fgraph, node, thunk)
420 return self.perform_updates()
File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/link/utils.py:535, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
530 warnings.warn(
531 f"{exc_type} error does not allow us to add an extra error message"
532 )
533 # Some exception need extra parameter in inputs. So forget the
534 # extra long error message in that case.
--> 535 raise exc_value.with_traceback(exc_trace)
File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/link/vm.py:414, in Loop.__call__(self)
410 try:
411 for thunk, node, old_storage in zip_longest(
412 self.thunks, self.nodes, self.post_thunk_clear, fillvalue=()
413 ):
--> 414 thunk()
415 for old_s in old_storage:
416 old_s[0] = None
File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/scan/op.py:1657, in Scan.make_thunk.<locals>.rval(p, i, o, n, allow_gc)
1654 def rval(
1655 p=p, i=node_input_storage, o=node_output_storage, n=node, allow_gc=allow_gc
1656 ):
-> 1657 r = p(n, [x[0] for x in i], o)
1658 for o in node.outputs:
1659 compute_map[o][0] = True
File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/scan/op.py:1918, in Scan.perform(self, node, inputs, output_storage, params)
1915 t0_fn = time.perf_counter()
1917 try:
-> 1918 vm()
1919 except Exception:
1920 if hasattr(vm, "position_of_error"):
1921 # this is a new vm-provided function or c linker
1922 # they need this because the exception manipulation
1923 # done by raise_with_op is not implemented in C.
File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/link/utils.py:218, in streamline.<locals>.streamline_nice_errors_f()
216 thunk()
217 except Exception:
--> 218 raise_with_op(fgraph, node, thunk)
File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/link/utils.py:535, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
530 warnings.warn(
531 f"{exc_type} error does not allow us to add an extra error message"
532 )
533 # Some exception need extra parameter in inputs. So forget the
534 # extra long error message in that case.
--> 535 raise exc_value.with_traceback(exc_trace)
File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/link/utils.py:216, in streamline.<locals>.streamline_nice_errors_f()
214 try:
215 for thunk, node in zip(thunks, order):
--> 216 thunk()
217 except Exception:
218 raise_with_op(fgraph, node, thunk)
File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/link/basic.py:669, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
663 def thunk(
664 fgraph=self.fgraph,
665 fgraph_jit=fgraph_jit,
666 thunk_inputs=thunk_inputs,
667 thunk_outputs=thunk_outputs,
668 ):
--> 669 outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
671 for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
672 compute_map[o_var][0] = True
[... skipping hidden 6 frame]
File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/jax/_src/api_util.py:554, in _shaped_abstractify_slow(x)
552 dtype = dtypes.canonicalize_dtype(x.dtype, allow_opaque_dtype=True)
553 else:
--> 554 raise TypeError(
555 f"Cannot interpret value of type {type(x)} as an abstract array; it "
556 "does not have a dtype attribute")
557 return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type,
558 named_shape=named_shape)
TypeError: Cannot interpret value of type <class 'numpy.random._generator.Generator'> as an abstract array; it does not have a dtype attribute
Apply node that caused the error: normal_rv{0, (0, 0), floatX, False}(*1-<RandomGeneratorType>, TensorConstant{[]}, TensorConstant{11}, *0-<TensorType(float64, ())>, *2-<TensorType(float64, ())>)
Toposort index: 0
Inputs types: [RandomGeneratorType, TensorType(int64, (0,)), TensorType(int64, ()), TensorType(float64, ()), TensorType(float64, ())]
Inputs shapes: [(), 'No shapes', ()]
Inputs strides: [(), 'No strides', ()]
Inputs values: [array(0.), Generator(PCG64) at 0x17E245FC0, array(1.)]
Outputs clients: [['output'], ['output']]
Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3269, in run_cell_async
has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3448, in run_ast_nodes
if await self.run_code(code, result, async_=asy):
File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3508, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "/var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/ipykernel_32164/1022036140.py", line 22, in <module>
traj, updates = pytensor.scan(step,
File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/scan/basic.py", line 852, in scan
raw_inner_outputs = fn(*args)
File "/var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/ipykernel_32164/1022036140.py", line 19, in step
x = pm.Normal.dist(mu=last_x, sigma=sigma)
File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pymc/distributions/continuous.py", line 520, in dist
return super().dist([mu, sigma], **kwargs)
File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pymc/distributions/distribution.py", line 389, in dist
rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
Apply node that caused the error: for{cpu,scan_fn}(TensorConstant{100}, IncSubtensor{Set;:int64:}.0, RandomGeneratorSharedVariable(<Generator(PCG64) at 0x17E245FC0>), TensorConstant{1.0})
Toposort index: 5
Inputs types: [TensorType(int8, ()), TensorType(float64, (101,)), RandomGeneratorType, TensorType(float64, ())]
Inputs shapes: [(), (101,), 'No shapes', ()]
Inputs strides: [(), (8,), 'No strides', ()]
Inputs values: [array(100, dtype=int8), 'not shown', Generator(PCG64) at 0x17E245FC0, array(1.)]
Outputs clients: [[Subtensor{int64::}(for{cpu,scan_fn}.0, ScalarConstant{1})], []]
Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/ipykernel/zmqshell.py", line 540, in run_cell
return super().run_cell(*args, **kwargs)
File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3009, in run_cell
result = self._run_cell(
File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3064, in _run_cell
result = runner(coro)
File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
coro.send(None)
File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3269, in run_cell_async
has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3448, in run_ast_nodes
if await self.run_code(code, result, async_=asy):
File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3508, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "/var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/ipykernel_32164/1022036140.py", line 22, in <module>
traj, updates = pytensor.scan(step,
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
PyMC version information:
PyMC: 5.3.0
Pytensor: 2.11.1
Possibly related to https://github.com/pymc-devs/pymc/issues/6351
The issue is that there is no 1-to-1 map between the Scan RV and the Scan value variable (I think because the output of Scan is actually a slice).
It works just fine with the default backend. I was hoping the problem had to do with how the RNG is passed around in the graph during JAX compilation, since that's what the error complains about.
Ah, okay, that sounds like a different issue, then.
Regardless of this issue, we should definitely clean up the Scan mode handling. I think it should use whatever mode the outer function is using.
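A minimal sketch of the difference, using hypothetical `Mode` and `ScanNode` stand-ins rather than PyTensor's real classes: in the behaviour described in this thread, the mode fixed when the Scan is built decides how its inner function is compiled, whereas under this proposal the outer function's mode would.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Mode:
    """Hypothetical stand-in for a compilation mode (just a linker label)."""
    linker: str


@dataclass
class ScanNode:
    """Hypothetical Scan node; `mode` is whatever the user passed to scan()."""
    mode: Optional[Mode] = None


def inner_mode_for_current(node: ScanNode, outer: Mode) -> Mode:
    # Behaviour described in the issue: the construction-time mode wins.
    return node.mode if node.mode is not None else outer


def inner_mode_for(node: ScanNode, outer: Mode) -> Mode:
    # Proposed behaviour: always follow the outer function's mode, so a Scan
    # built with a JAX mode is still compiled for C when the outer function
    # (e.g. the one behind model.initial_point()) is compiled for C.
    return outer


node = ScanNode(mode=Mode("jax"))
outer = Mode("c")
print(inner_mode_for_current(node, outer))  # Mode(linker='jax')
print(inner_mode_for(node, outer))          # Mode(linker='c')
```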
How complex of a fix would that be?
By the way, this seems to be triggered by the initial_point function, probably because of the way it seeds things. You can trigger the failure with just model.initial_point().
How complex of a fix would that be?
I don't quite know, but it's worth a look. One option is to grab the user-provided mode and exclude rewrites that are incompatible with JAX (since we then know that we are compiling to JAX). Otherwise, we could add an optional kwarg to the dispatch function carrying the mode provided by the JITLinker.
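The second option could look roughly like the sketch below. It assumes a `functools.singledispatch` conversion function in the style of a backend dispatcher; `funcify`, `Add`, `Scan`, and `HypotheticalLinker` are illustrative names, not PyTensor's API. The linker forwards its own mode to every dispatch call, and the Scan conversion prefers that mode over the one stored on the op.

```python
from functools import singledispatch
from typing import Optional


class Add:
    """Hypothetical elementwise op."""


class Scan:
    """Hypothetical Scan op that remembers the mode it was built with."""

    def __init__(self, mode: Optional[str] = None):
        self.mode = mode


@singledispatch
def funcify(op, mode: Optional[str] = None, **kwargs):
    raise NotImplementedError(f"No conversion for {type(op).__name__}")


@funcify.register
def _(op: Add, mode: Optional[str] = None, **kwargs):
    return lambda a, b: a + b


@funcify.register
def _(op: Scan, mode: Optional[str] = None, **kwargs):
    # Prefer the mode handed down by the linker; fall back to the op's own.
    inner_mode = mode if mode is not None else op.mode
    return f"inner function compiled with {inner_mode}"


class HypotheticalLinker:
    def __init__(self, mode: str):
        self.mode = mode

    def convert(self, op):
        # Every conversion receives the mode the linker itself was given.
        return funcify(op, mode=self.mode)


print(HypotheticalLinker("jax").convert(Scan(mode="c")))
```

Because the extra argument is an optional keyword, existing dispatch implementations that accept `**kwargs` would keep working unchanged.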
Yeah, it is trying to feed a NumPy Generator as input. model.initial_point() compiles a C function, but the inner Scan function is compiled to JAX. I think the solution is indeed to fix the mode handling. This happens in: https://github.com/pymc-devs/pytensor/blob/9ae07ab03bf417bd1c703ec624f494250621e7af/pytensor/link/jax/dispatch/scan.py#L20-L23
This would also fix https://github.com/pymc-devs/pymc/issues/6697 which would be a big improvement.
An immediate workaround for your problem is to pass a valid initval to model.register_rv:
model.register_rv(traj, name='traj', initval=np.zeros(100))
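A toy illustration of why a numeric initval sidesteps the failure, assuming (hypothetically) that the initial point only has to evaluate a variable's graph when its initval is "prior". `initial_value` and `draw_from_prior` are made up for this sketch and are not PyMC functions; `draw_from_prior` stands in for compiling and running the Scan.

```python
import numpy as np


def initial_value(initval, draw_from_prior):
    """Hypothetical initval resolution: only "prior" evaluates the graph."""
    # Check the type first so comparing an array to a string never happens.
    if isinstance(initval, str) and initval == "prior":
        return draw_from_prior()
    return np.asarray(initval, dtype=float)


calls = []


def draw_from_prior():
    # Stand-in for evaluating the Scan (and compiling its inner function).
    calls.append("scan evaluated")
    rng = np.random.default_rng(0)
    return rng.normal(size=100).cumsum()


x0 = initial_value(np.zeros(100), draw_from_prior)
print(x0.shape, calls)  # (100,) []
```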
Good to know! I can try to have a look at the mode problem as well over the next couple days if you're busy with other stuff.
The error happens because model.initial_point evaluates the Scan in Python mode, which itself builds an inner function using the Scan's mode, which in this case is JAX.
So far nothing terrible, but this Scan has RNGs, which are not compatible with JAX. Usually we convert shared RNGs with a warning, but Scan does not expose them as shared variables to the JAXLinker (they get converted to NominalVariables), so no special hackery is done, and JAX receives the NumPy RNGs! The error then makes more sense:
TypeError: Cannot interpret value of type <class 'numpy.random._generator.Generator'> as an abstract array; it does not have a dtype attribute
In more recent versions of JAX, the message is slightly different:
TypeError: Error interpreting argument to <function jax_funcified_fgraph at 0x7f0b24662980> as an abstract array. The problematic value is of type <class 'numpy.random._generator.Generator'> and was passed to the function at path nominal_variable.
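The mechanism can be mimicked in plain NumPy. In this sketch, `abstractify` copies the dtype check from the traceback, and `prepare_outer_inputs` is a hypothetical stand-in for the conversion the outer linker applies to shared RNGs it can see (turning the Generator into a uint32 seed is an arbitrary choice for illustration). The inner Scan graph never goes through that step, so the raw Generator reaches the check.

```python
import warnings

import numpy as np


def abstractify(x):
    """Mimic of the check in the traceback: a traced value needs a dtype."""
    if not hasattr(x, "dtype"):
        raise TypeError(
            f"Cannot interpret value of type {type(x)} as an abstract array; "
            "it does not have a dtype attribute"
        )
    return np.shape(x), x.dtype


def prepare_outer_inputs(inputs):
    """Hypothetical outer-linker step: replace visible Generators, with a warning."""
    prepared = []
    for value in inputs:
        if isinstance(value, np.random.Generator):
            warnings.warn("Converting a numpy Generator to a seed array")
            value = np.asarray(value.integers(2**32, dtype=np.uint32))
        prepared.append(value)
    return prepared


rng = np.random.default_rng(0)

# Outer graph: the Generator is a visible shared input, so it gets converted.
outer = prepare_outer_inputs([rng, np.float64(1.0)])
print([abstractify(v) for v in outer])

# Inner Scan graph: the Generator arrives through a nominal placeholder,
# no conversion runs, and the dtype check fails.
try:
    abstractify(rng)
except TypeError as err:
    inner_error = str(err)
    print(inner_error)
```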
I think this would be fixed by https://github.com/pymc-devs/pytensor/pull/278, as the RNG is an explicit input for the purposes of the inner function created by Scan.
However, in general we shouldn't need to define custom modes, especially custom modes with different linkers internally. Choosing which rewrites get triggered perhaps makes a bit more sense, but the backend?
This wouldn't be fixed by https://github.com/pymc-devs/pytensor/pull/278, because the outputs wouldn't be NumPy Generators, so when the Scan tried to set the shared variables it would fail. We should not allow Scan in the default backend to use a JAX/PyTorch linker (Numba should be fine).
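A sketch of both points, with hypothetical names: `HypotheticalRNGShared` models a shared RNG container that only accepts NumPy Generators, so writing back a JAX-style RNG state (modelled here as a plain uint32 array) fails, and `check_inner_linker` is one possible form of the proposed guard. The linker labels in the sets are assumptions, not PyTensor's exact linker names.

```python
import numpy as np

# Hypothetical labels for the default backend and for forbidden inner linkers.
DEFAULT_BACKEND_LINKERS = {"c", "py", "vm", "cvm"}
FORBIDDEN_INNER_LINKERS = {"jax", "pytorch"}


class HypotheticalRNGShared:
    """Shared RNG container that only accepts numpy Generators."""

    def __init__(self, value):
        self.set_value(value)

    def set_value(self, value):
        if not isinstance(value, np.random.Generator):
            raise TypeError(f"Expected a numpy Generator, got {type(value).__name__}")
        self.value = value


def check_inner_linker(outer_linker: str, inner_linker: str) -> None:
    """Reject JAX/PyTorch inner functions under the default backend (Numba is fine)."""
    if outer_linker in DEFAULT_BACKEND_LINKERS and inner_linker in FORBIDDEN_INNER_LINKERS:
        raise ValueError(
            f"Scan under the {outer_linker!r} backend cannot use a {inner_linker!r} inner linker"
        )


shared_rng = HypotheticalRNGShared(np.random.default_rng(0))

# The updated RNG state coming out of a JAX inner function is not a Generator.
jax_like_state = np.zeros(2, dtype=np.uint32)
try:
    shared_rng.set_value(jax_like_state)
except TypeError as err:
    update_error = str(err)
    print(update_error)

check_inner_linker("c", "numba")  # allowed, no exception
try:
    check_inner_linker("c", "jax")
except ValueError as err:
    guard_error = str(err)
    print(guard_error)
```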
We should just deprecate the mode argument in scan altogether, no?
Still on the fence about whether we want control over the rewrites, but we can reassess later if the need arises.
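If the argument were deprecated, a common Python pattern is to keep accepting it while emitting a warning and ignoring it. The sketch below uses a hypothetical `scan_sketch` loop, not PyTensor's `scan`. It emits `FutureWarning` rather than `DeprecationWarning` because the latter is hidden by default unless triggered by code in `__main__`.

```python
import warnings
from typing import Optional


def scan_sketch(fn, n_steps: int, init, mode: Optional[str] = None):
    """Hypothetical scan-like loop with a deprecated, ignored `mode` argument."""
    if mode is not None:
        warnings.warn(
            "The `mode` argument is deprecated and ignored; the inner function "
            "uses the mode of the outer function.",
            FutureWarning,
            stacklevel=2,
        )
    outputs = []
    state = init
    for _ in range(n_steps):
        state = fn(state)
        outputs.append(state)
    return outputs


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = scan_sketch(lambda x: x + 1, n_steps=3, init=0, mode="JAX")

print(result)                        # [1, 2, 3]
print(caught[0].category.__name__)   # FutureWarning
```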