equinox
eqx.Modules cannot be passed as carry into jax.lax.scan
Hi,
I'm not sure whether this is really a bug or intended behavior, but it's not possible to pass an eqx.Module as the carry in a jax.lax.scan. Here is the MVP:
import jax
import jax.numpy as jnp
import equinox as eqx


class SimpleMLP(eqx.Module):
    mlp: eqx.nn.MLP

    def __init__(self, *, key) -> None:
        self.mlp = eqx.nn.MLP(in_size=3, out_size=1, width_size=32, depth=2, key=key)

    def __call__(self, x):
        return self.mlp(x)


key = jax.random.PRNGKey(42)
mlp = SimpleMLP(key=key)


def rollout(mlp, xs):
    def step(carry, x):
        mlp = carry  # just for understanding
        val = mlp(x)
        carry = mlp
        return carry, [val]

    _, scan_out = jax.lax.scan(
        step,
        [mlp],
        xs
    )

    return scan_out


key, subkey = jax.random.split(key)
vals = rollout(mlp, jax.random.normal(key=subkey, shape=(200, 3)))
This leads to the following error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[79], line 32
29 return scan_out
31 key, subkey = jax.random.split(key)
---> 32 vals = rollout(mlp, jax.random.normal(key=subkey, shape=(200, 3)))
Cell In[79], line 23, in rollout(mlp, xs)
20 carry = mlp
21 return carry, [val]
---> 23 _, scan_out = jax.lax.scan(
24 step,
25 [mlp],
26 xs
27 )
29 return scan_out
[... skipping hidden 5 frame]
File ~/Workspace/jaxRL/.venv/lib/python3.11/site-packages/jax/_src/core.py:1423, in concrete_aval(x)
1421 if hasattr(x, '__jax_array__'):
1422 return concrete_aval(x.__jax_array__())
-> 1423 raise TypeError(f"Value {x!r} with type {type(x)} is not a valid JAX "
1424 "type")
TypeError: Value <jax._src.custom_derivatives.custom_jvp object at 0x10f5846d0> with type <class 'jax._src.custom_derivatives.custom_jvp'> is not a valid JAX type
On the other hand, I could just use this (MVP 2):
import jax
import jax.numpy as jnp
import equinox as eqx


class SimpleMLP(eqx.Module):
    mlp: eqx.nn.MLP

    def __init__(self, *, key) -> None:
        self.mlp = eqx.nn.MLP(in_size=3, out_size=1, width_size=32, depth=2, key=key)

    def __call__(self, x):
        return self.mlp(x)


key = jax.random.PRNGKey(42)
mlp = SimpleMLP(key=key)


def rollout(mlp, xs):
    def step(carry, x):
        val = mlp(x)
        return carry, [val]

    _, scan_out = jax.lax.scan(
        step,
        [],
        xs
    )

    return scan_out


key, subkey = jax.random.split(key)
vals = rollout(mlp, jax.random.normal(key=subkey, shape=(200, 3)))
In MVP 2, I use the mlp from the enclosing function inside the scan body instead of passing it as the carry. While this works, there could be scenarios in which I update the PyTree inside the scan, and since no shapes change, I assumed that would be allowed. I don't necessarily want to modify the "outer" mlp from inside the scan function, since it is not directly part of that function (I don't want to change global state!).
But as already mentioned, I'm not sure whether this really is a bug or not.