eqx.Modules cannot be passed as carry into jax.lax.scan

Open Artur-Galstyan opened this issue 6 months ago • 12 comments

Hi,

I'm not sure whether this is a bug or intended behavior, but it's not possible to pass an eqx.Module as the carry in jax.lax.scan. Here is the MVP:

import jax
import jax.numpy as jnp
import equinox as eqx

class SimpleMLP(eqx.Module):
    mlp: eqx.nn.MLP
    def __init__(self, *, key) -> None:
        self.mlp = eqx.nn.MLP(in_size=3, out_size=1, width_size=32, depth=2, key=key)
    
    def __call__(self, x):
        return self.mlp(x)

key = jax.random.PRNGKey(42)
mlp = SimpleMLP(key=key)

def rollout(mlp, xs):
    def step(carry, x):
        mlp = carry # just for understanding 
        val = mlp(x)
        carry = mlp
        return carry, [val]
    
    _, scan_out = jax.lax.scan(
        step,
        [mlp],
        xs
    )
    
    return scan_out

key, subkey = jax.random.split(key)
vals = rollout(mlp, jax.random.normal(key=subkey, shape=(200, 3)))

This leads to the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[79], line 32
     29     return scan_out
     31 key, subkey = jax.random.split(key)
---> 32 vals = rollout(mlp, jax.random.normal(key=subkey, shape=(200, 3)))

Cell In[79], line 23, in rollout(mlp, xs)
     20     carry = mlp
     21     return carry, [val]
---> 23 _, scan_out = jax.lax.scan(
     24     step,
     25     [mlp],
     26     xs
     27 )
     29 return scan_out

    [... skipping hidden 5 frame]

File ~/Workspace/jaxRL/.venv/lib/python3.11/site-packages/jax/_src/core.py:1423, in concrete_aval(x)
   1421 if hasattr(x, '__jax_array__'):
   1422   return concrete_aval(x.__jax_array__())
-> 1423 raise TypeError(f"Value {x!r} with type {type(x)} is not a valid JAX "
   1424                  "type")

TypeError: Value <jax._src.custom_derivatives.custom_jvp object at 0x10f5846d0> with type <class 'jax._src.custom_derivatives.custom_jvp'> is not a valid JAX type

On the other hand, I could just use this instead (MVP 2):

import jax
import jax.numpy as jnp
import equinox as eqx

class SimpleMLP(eqx.Module):
    mlp: eqx.nn.MLP
    def __init__(self, *, key) -> None:
        self.mlp = eqx.nn.MLP(in_size=3, out_size=1, width_size=32, depth=2, key=key)
    
    def __call__(self, x):
        return self.mlp(x)

key = jax.random.PRNGKey(42)
mlp = SimpleMLP(key=key)

def rollout(mlp, xs):
    def step(carry, x):
        val = mlp(x)
        return carry, [val]
    
    _, scan_out = jax.lax.scan(
        step,
        [],
        xs
    )
    
    return scan_out

key, subkey = jax.random.split(key)
vals = rollout(mlp, jax.random.normal(key=subkey, shape=(200, 3)))

In MVP 2, step closes over the mlp from the enclosing function instead of receiving it through the carry. This works, but there could be scenarios in which I want to update the PyTree inside the scan, and since no shapes change, I assumed that would be allowed. I also don't want to modify the "outer" mlp from inside the scan function, since it isn't an explicit argument of that function (I don't want to change global state!).

But as already mentioned, I'm not sure if this really is a bug or not.

Artur-Galstyan, Dec 29 '23 11:12