aesara
Errors with `jax.random.normal`'s `shape` parameter in version 0.2.26
Version 0.2.26 of JAX is causing new "omnistaging"-like errors in `tests.link.test_jax:test_random_stats`. Here is the relevant portion of the traceback:
```
...
aesara/link/basic.py:705: in thunk
    outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
/tmp/user/1000/tmprconxvmr:3: in jax_funcified_fgraph
    auto_41, auto_42 = random_variable(auto_36, auto_37, auto_38, auto_39, auto_40)
aesara/link/jax/dispatch.py:1054: in random_variable
    data = getattr(jax.random, name)(key=prng, shape=size)
../../../../apps/anaconda3/envs/aesara-3.7/lib/python3.7/site-packages/jax/_src/random.py:522: in normal
    shape = core.as_named_shape(shape)
E   TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)>,).
E   If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
...
```
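The `static_argnums` suggestion in the error message can be sketched as follows. This assumes the shape is available as a hashable tuple of Python ints when the function is called; `sample_normal` is a hypothetical name, not part of Aesara or JAX. An `np.ndarray` or `DeviceArray` is unhashable, so it cannot be passed as a static argument as-is.

```python
import functools

import jax


# `shape` (argument index 1) is marked static, so JAX treats it as a
# compile-time constant instead of tracing it. Static arguments must be
# hashable, so a tuple of Python ints is passed rather than an array.
@functools.partial(jax.jit, static_argnums=1)
def sample_normal(prng_key, shape):
    return jax.random.normal(prng_key, shape)


key = jax.random.PRNGKey(0)
samples = sample_normal(key, (1000,))
print(samples.shape)  # (1000,)
```

Each distinct static `shape` triggers a separate compilation, which is usually acceptable when only a few shapes are used.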
Oddly, in this case the actual shape value is `DeviceArray([10000], dtype=int64)`. Converting it back to a non-`DeviceArray` (e.g. with `jax.device_get`) makes the failing code work again.
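A minimal sketch of that workaround, assuming the shape is known before the function is traced: move it to the host with `jax.device_get`, then turn it into a tuple of Python ints, which is the form `jax.random.normal` documents for its `shape` argument. The name `static_shape` is illustrative only.

```python
import jax
import jax.numpy as jnp

shape = jnp.array([1000])

# Pull the shape back to the host (as a NumPy array) and convert it to a
# tuple of Python ints *before* tracing, so jax.random.normal only ever
# sees concrete integers, never tracers.
static_shape = tuple(int(d) for d in jax.device_get(shape))


def jax_funcified(prng_key):
    return jax.random.normal(prng_key, static_shape)


key = jax.random.PRNGKey(0)
out = jax.jit(jax_funcified)(key)
print(out.shape)  # (1000,)
```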
The error can be reproduced with the following:
```python
import jax

shape = jax.numpy.array([1000])  # shape stored as a JAX array
def jax_funcified(prng_key):
    return jax.random.normal(prng_key, shape)

key = jax.random.PRNGKey(0)
jax.jit(jax_funcified)(key)
# TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>,).
# If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
```
But the following, which stores the shape as a NumPy array, works:
```python
import jax
import numpy as np

shape = np.array([10])  # shape stored as a NumPy array
def jax_funcified(prng_key):
    return jax.random.normal(prng_key, shape)

key = jax.random.PRNGKey(0)
print(jax.jit(jax_funcified)(key))
# [-0.3721109 0.26423115 -0.18252768 -0.7368197 -0.44030377 -0.1521442
# -0.67135346 -0.5908641 0.73168886 0.5673026 ]
```
One way to fix this is to convert `np.ndarray`s to `DeviceArray`s only when they do not represent a shape.
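A minimal sketch of that idea follows. `convert_input` and its `is_shape` flag are hypothetical; this is not Aesara's JAX linker code and not necessarily what #1284 does. It assumes the caller knows which inputs are shapes (for example, the `size` passed as `shape=size` in `random_variable` in the traceback above).

```python
import jax
import jax.numpy as jnp
import numpy as np


def convert_input(value, is_shape=False):
    """Hypothetical helper: prepare a NumPy input for a JAX-compiled graph.

    Shape-like inputs stay on the host as a tuple of Python ints, so they
    remain concrete under `jax.jit`; every other `np.ndarray` becomes a
    JAX array.
    """
    if is_shape:
        return tuple(int(d) for d in np.asarray(value))
    if isinstance(value, np.ndarray):
        return jnp.asarray(value)
    return value


size = convert_input(np.array([10]), is_shape=True)  # -> (10,)
loc = convert_input(np.array(0.0))  # -> JAX scalar array


def jax_funcified(prng_key):
    return loc + jax.random.normal(prng_key, size)


key = jax.random.PRNGKey(0)
out = jax.jit(jax_funcified)(key)
print(size, out.shape)  # (10,) (10,)
```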
This seems to be related to #182.
Solved in #1284.