equinox
Mypy fails to resolve Lambda layer __init__ parameters.
When using the Lambda layer in Equinox, mypy does not seem to resolve the parameters of its __init__ method correctly. Instantiating the layer, e.g. eqx.nn.Lambda(jax.nn.relu), produces a 'too many arguments' error.
Could this be an issue with mypy?
equinox: 0.11.2, mypy: 1.7.1, OS: macOS
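import equinox as eqx
import jax


class Model(eqx.Module):
    # Field declaration must match the attribute assigned in __init__.
    layers: eqx.nn.Sequential

    def __init__(self, key: jax.Array) -> None:
        # The original used custom Keys/next_key helpers (not shown); a plain key split is used here.
        key1, key2 = jax.random.split(key)
        self.layers = eqx.nn.Sequential(
            [
                eqx.nn.Linear(90, 300, key=key1),
                eqx.nn.Lambda(jax.nn.relu),  # mypy reports 'too many arguments' on this line
                eqx.nn.Dropout(0.2),
                eqx.nn.Linear(300, 200, key=key2),
            ]
        )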