equinox
Getting a type error when initializing Linear module (unexpected keyword argument 'key')
This may not be the right place for minor user bugs, but I haven't been able to find any help on general forums. I'm not sure what is causing the following `TypeError`:
Sample code:

```python
import jax
import jax.numpy as jnp
import equinox as eqx


class MNISTClassifier(eqx.Module):
    model: eqx.nn.Sequential

    def __init__(self, key: jax.random.PRNGKey):
        keys = jax.random.split(key, 3)
        self.model = eqx.nn.Sequential([
            eqx.nn.Linear(784, 20, key=keys[0]),
            jax.nn.relu,
            eqx.nn.Linear(20, 20, key=keys[1]),
            jax.nn.relu,
            eqx.nn.Linear(20, 10, key=keys[2]),
            jax.nn.log_softmax,
        ])

    def __call__(self, x):
        return self.model(x)


key = jax.random.PRNGKey(0)
key, model_key = jax.random.split(key)
model = MNISTClassifier(model_key)
model(jnp.ones(784))
```
The error I get:

```
File /opt/miniconda3/envs/myenv/lib/python3.12/inspect.py:3259, in Signature.bind(self, *args, **kwargs)
   3254 def bind(self, /, *args, **kwargs):
   3255     """Get a BoundArguments object, that maps the passed `args`
   3256     and `kwargs` to the function's signature. Raises `TypeError`
   3257     if the passed arguments can not be bound.
   3258     """
-> 3259     return self._bind(args, kwargs)

File /opt/miniconda3/envs/myenv/lib/python3.12/inspect.py:3248, in Signature._bind(self, args, kwargs, partial)
   3246         arguments[kwargs_param.name] = kwargs
   3247     else:
-> 3248         raise TypeError(
   3249             'got an unexpected keyword argument {arg!r}'.format(
   3250                 arg=next(iter(kwargs))))
   3252 return self._bound_arguments_cls(self, arguments)

TypeError: got an unexpected keyword argument 'key'
```
I can post the full traceback if needed.