equinox

PReLU.__call__ doesn't take key argument

Open · SimonKoop opened this issue · 2 comments

PReLU's __call__ method doesn't accept a key argument, which makes it incompatible with Sequential. The easiest fix is to add an (ignored) key argument to __call__. However, it might be worthwhile to have Sequential inspect each layer's call signature and decide whether to pass a key. That would also make Sequential work directly with activation functions from jax.nn.
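To make the incompatibility concrete, here is a minimal pure-Python sketch. `ToySequential`, `PReLUNoKey`, `PReLUIgnoredKey` and `double` are hypothetical stand-ins, not Equinox classes; the only assumption carried over from the issue is that the container calls every layer as `layer(x, key=...)`, even when no key is given. Seeding `random.Random` stands in for `jax.random.split`.

```python
import random
from typing import Any, Callable, Optional, Sequence


class ToySequential:
    """Hypothetical container that always forwards `key=` to every layer."""

    def __init__(self, layers: Sequence[Callable[..., Any]]):
        self.layers = list(layers)

    def __call__(self, x: float, *, key: Optional[int] = None) -> float:
        # One sub-key per layer (a stand-in for jax.random.split).
        if key is None:
            keys = [None] * len(self.layers)
        else:
            rng = random.Random(key)
            keys = [rng.randrange(2**31) for _ in self.layers]
        for layer, k in zip(self.layers, keys):
            x = layer(x, key=k)  # every layer must accept `key=`
        return x


class PReLUNoKey:
    """Toy PReLU whose __call__ has no `key` parameter (the reported problem)."""

    def __init__(self, alpha: float = 0.25):
        self.alpha = alpha

    def __call__(self, x: float) -> float:
        return x if x >= 0 else self.alpha * x


class PReLUIgnoredKey(PReLUNoKey):
    """The first proposed fix: accept a `key` keyword and ignore it."""

    def __call__(self, x: float, *, key: Optional[int] = None) -> float:
        del key  # PReLU is deterministic, so the key is unused
        return super().__call__(x)


def double(x: float, *, key: Optional[int] = None) -> float:
    return 2 * x


seq_ok = ToySequential([double, PReLUIgnoredKey()])
print(seq_ok(-1.0, key=0))  # -0.5

seq_broken = ToySequential([double, PReLUNoKey()])
try:
    seq_broken(-1.0)
except TypeError as err:
    print("TypeError:", err)
```

The broken container fails even without a key, because the keyword is forwarded unconditionally (as `key=None`).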

SimonKoop · May 14 '25 15:05

If you're okay with this, I'd like to work on a pull request for the second option. Before I change anything in Sequential, though, I have a question: what's the rationale for using a sentinel instead of just None as the default state for stateless models? And if the sentinel really is needed, could it be made a public part of Equinox? That would make it much easier for third parties to write functions that handle both stateful and stateless models.
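A common reason for Python code to use a sentinel default instead of None is to tell "argument omitted" apart from "argument explicitly set to None". The thread doesn't establish whether that reason applies to Sequential, so treat this as a generic sketch. `_MISSING`, `apply_layer` and `apply_layer_none` are hypothetical names, and the "return `x` alone when stateless, `(x, state)` when stateful" convention is only assumed here for illustration.

```python
from typing import Any, Optional, Tuple, Union

# Hypothetical module-private sentinel: a unique object no caller passes by accident.
_MISSING = object()


def apply_layer(x: float, state: Any = _MISSING) -> Union[float, Tuple[float, Any]]:
    """Stateless call if `state` was omitted, stateful call otherwise."""
    y = x + 1.0
    if state is _MISSING:
        return y
    return y, state  # even an explicit `state=None` counts as "stateful"


def apply_layer_none(x: float, state: Optional[Any] = None):
    """Same idea with None: fine as long as None is never a valid state."""
    y = x + 1.0
    if state is None:
        return y
    return y, state


print(apply_layer(1.0))             # 2.0
print(apply_layer(1.0, None))       # (2.0, None)
print(apply_layer_none(1.0, None))  # 2.0  (explicit None is indistinguishable)
```

If None can never be a legitimate state value, the two versions behave identically, which fits the suggestion that the sentinel could be replaced.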

SimonKoop · May 14 '25 15:05

Take a look at eqx.nn.Lambda, which is a wrapper that exists for this purpose. I decided against inspecting signatures as 'too much magic'.
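In Equinox itself the wrapper is used as `eqx.nn.Lambda(jax.nn.relu)` inside `eqx.nn.Sequential`. The stdlib-only sketch below contrasts that wrapper approach with the rejected signature-inspection approach. `ToyLambda`, `accepts_key` and `call_layer` are hypothetical names, not Equinox internals.

```python
import inspect
from typing import Any, Callable, Optional


class ToyLambda:
    """Hypothetical Lambda-style wrapper: adapts a plain `fn(x)` to the
    `layer(x, *, key=None)` calling convention."""

    def __init__(self, fn: Callable[[Any], Any]):
        self.fn = fn

    def __call__(self, x: Any, *, key: Optional[Any] = None) -> Any:
        del key  # the wrapped function never sees the key
        return self.fn(x)


def accepts_key(fn: Callable[..., Any]) -> bool:
    """The rejected alternative: pass `key` only if the callable declares a
    keyword-capable `key` parameter or takes **kwargs."""
    try:
        params = inspect.signature(fn).parameters.values()
    except (TypeError, ValueError):
        # Some callables (e.g. certain builtins) have no introspectable signature.
        return False
    keyword_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    return any(
        (p.name == "key" and p.kind in keyword_kinds)
        or p.kind is inspect.Parameter.VAR_KEYWORD
        for p in params
    )


def call_layer(layer: Callable[..., Any], x: Any, key: Optional[Any] = None) -> Any:
    # Signature-driven dispatch: the "magic" that the wrapper makes unnecessary.
    if accepts_key(layer):
        return layer(x, key=key)
    return layer(x)


def relu(x: float) -> float:
    return max(x, 0.0)


print(ToyLambda(relu)(-3.0, key=123))   # 0.0
print(call_layer(relu, -3.0, key=123))  # 0.0
```

The wrapper keeps the calling convention explicit at the point where the layer is built. The inspection route has to guess correctly for every kind of callable, and the fallback branch for uninspectable callables is exactly the kind of edge case that makes it "magic".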

I no longer recall why there is a sentinel there instead of None, though, and I don't see any comments or git commits that explain it. Perhaps that could be changed.

patrick-kidger · May 14 '25 17:05