equinox
optim.init fails unless the model is filtered with eqx.is_inexact_array
Why does this fail?
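```python
import optax
import equinox as eqx
from jax.random import PRNGKey

m = eqx.nn.MLP(1, 1, 2, 1, key=PRNGKey(0))  # in_size, out_size, width_size, depth
optim = optax.adamw(1e-3)
optim.init(eqx.filter(m, eqx.is_inexact_array))  # this does not fail
optim.init(m)  # this fails
```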
In the BERT example, on the other hand, the model doesn't need to be filtered first. Why the difference?