equinox
equinox copied to clipboard
better bug hint when writing a simple neural network in equinox
Hi,
Thanks for the nice package. I am new to Equinox. I attempted to write a simple MLP, but it failed with an error. From the error message, I am not sure how I should revise my code.
import jax
import jax.numpy as jnp
import equinox as eqx
# Minimal multi-layer perceptron built on eqx.Module. `hidden_dims` is a list of
# sizes; N entries produce N - 1 Linear layers.
# NOTE(review): no fields are declared at class level. The traceback below shows
# Equinox's __setattr__ rejecting `self.layers = ...`; presumably the attributes
# must first be declared as annotated class fields (e.g. `layers: list`,
# `activation: Callable`) -- confirm against the Equinox Module documentation.
class MLPeqx(eqx.Module):
def __init__(self, hidden_dims):
super().__init__()
# One PRNG key per Linear layer, all split from the fixed seed 0.
tmp_key = jax.random.split(jax.random.PRNGKey(0), len(hidden_dims) - 1)
# Layer i is built from the consecutive pair (hidden_dims[i], hidden_dims[i + 1]).
# This is the assignment that raises "Cannot set attribute layers".
self.layers = [eqx.nn.Linear(hidden_dims[i], hidden_dims[i + 1], key=tmp_key[i]) for i in range(len(hidden_dims) - 1)]
self.activation = jax.nn.relu
def __call__(self, x):
# Apply the activation after every layer except the last one.
for i in range(len(self.layers) - 1):
x = self.activation(self.layers[i](x))
# The last layer's output is returned without an activation.
x = self.layers[-1](x)
return x
# Six sizes give five Linear layers; this is the call that fails in the traceback below.
MLP = MLPeqx(hidden_dims=[1,2,4,4,2,1])
The error I got:
Traceback (most recent call last):
File "xxxxxx/misc/test1.py", line 18, in <module>
MLP = MLPeqx(hidden_dims=[1,2,4,4,2,1])
File "xxxxxx/python3.9/site-packages/equinox/_module.py", line 548, in __call__
self = super(_ModuleMeta, initable_cls).__call__(*args, **kwargs)
File "xxxxxx/python3.9/site-packages/equinox/_better_abstract.py", line 226, in __call__
self = super().__call__(*args, **kwargs)
File "xxxxxx/python3.9/site-packages/equinox/_module.py", line 376, in __init__
init(self, *args, **kwargs)
File "xxxxxx/misc/test1.py", line 9, in __init__
self.layers = [eqx.nn.Linear(hidden_dims[i], hidden_dims[i + 1], key=tmp_key[i]) for i in range(len(hidden_dims) - 1)]
File "xxxxxx/python3.9/site-packages/equinox/_module.py", line 811, in __setattr__
raise AttributeError(f"Cannot set attribute {name}")
AttributeError: Cannot set attribute layers
What did I miss?