
Getting a type error when initializing Linear module (unexpected keyword argument 'key')

Open AbhinavRao23 opened this issue 2 months ago • 1 comments

This may not be the right place for a small user bug, but I have not been able to find help on general forums. I am not sure what is causing the following type error:

Sample Code:

import equinox as eqx
import jax
import jax.numpy as jnp

class MNISTClassifier(eqx.Module):

    model: eqx.nn.Sequential

    def __init__(self, 
                 key: jax.random.PRNGKey):
        keys = jax.random.split(key, 3)
        self.model = eqx.nn.Sequential([
            eqx.nn.Linear(784, 20, key=keys[0]), 
            jax.nn.relu,
            eqx.nn.Linear(20, 20, key=keys[1]), 
            jax.nn.relu,
            eqx.nn.Linear(20, 10, key=keys[2]),
            jax.nn.log_softmax
            ]
        )

    def __call__(self, x):
        return self.model(x)

key = jax.random.PRNGKey(0)
key, model_key = jax.random.split(key)
model = MNISTClassifier(model_key)
model(jnp.ones(784))

The error I get:

File /opt/miniconda3/envs/myenv/lib/python3.12/inspect.py:3259, in Signature.bind(self, *args, **kwargs)
   3254 def bind(self, /, *args, **kwargs):
   3255     """Get a BoundArguments object, that maps the passed `args`
   3256     and `kwargs` to the function's signature.  Raises `TypeError`
   3257     if the passed arguments can not be bound.
   3258     """
-> 3259     return self._bind(args, kwargs)

File /opt/miniconda3/envs/myenv/lib/python3.12/inspect.py:3248, in Signature._bind(self, args, kwargs, partial)
   3246         arguments[kwargs_param.name] = kwargs
   3247     else:
-> 3248         raise TypeError(
   3249             'got an unexpected keyword argument {arg!r}'.format(
   3250                 arg=next(iter(kwargs))))
   3252 return self._bound_arguments_cls(self, arguments)

TypeError: got an unexpected keyword argument 'key'
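
The last frames show `inspect.Signature.bind` rejecting a `key` keyword. As a minimal, stdlib-only sketch of that failure mode (the `activation` function and `try_bind` helper are hypothetical names, not part of Equinox or of the sample code), binding `key=` to a callable whose signature only declares `x` raises a `TypeError` with the same wording:

```python
import inspect


def activation(x):
    # Hypothetical stand-in: a callable that accepts only `x`, no `key`.
    return max(x, 0.0)


def try_bind(fn, *args, **kwargs):
    """Return the TypeError message from binding the arguments to fn's
    signature, or None if they bind cleanly."""
    try:
        inspect.signature(fn).bind(*args, **kwargs)
    except TypeError as err:
        return str(err)
    return None


# Passing an undeclared keyword fails at bind time, before fn ever runs.
message = try_bind(activation, 1.0, key=0)
print(message)

# The same call without `key` binds without error.
print(try_bind(activation, 1.0))
```

This only reproduces the message. It does not by itself establish which callable in the sample code received `key=`.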

I can post the full traceback if needed.

AbhinavRao23, Apr 29 '24 22:04