Overriding `Layer.forward` unexpectedly changes the signature of `Layer.__call__` under the torch backend
In PyTorch, one typically defines what happens when a layer is called by overriding its `forward` method. In Keras, we instead override the `call` method.
I would not expect overriding `forward` to have any effect on how a `keras.Layer` is called, even under the torch backend, since that is what `call` is for. However, when the `forward` method is overridden, it appears to take priority over the overridden `Layer.call`.
Minimal Example:
import os
os.environ["KERAS_BACKEND"] = "torch"
import keras

class MyLayer(keras.Layer):
    def call(self, xz, inverse: bool = False):
        if inverse:
            return self.inverse(xz)
        return self.forward(xz)

    def forward(self, x):
        pass

    def inverse(self, z):
        pass

layer = MyLayer()
x = keras.ops.zeros((128, 2))
# TypeError: MyLayer.forward() got an unexpected keyword argument 'inverse'
layer(x, inverse=True)
My Keras version: 3.3.3
Hi @LarsKue ,
Thanks for reporting the issue. I have reproduced it and observed that with PyTorch as the backend, an overridden `Layer.forward` method takes precedence over the base `keras.Layer.call` method. Attaching a gist for reference.
I need to check with the Keras dev team whether this is intended or an oversight.
Thanks!
@haifeng-jin any thoughts on this? I'm not familiar enough with torch to know what our intended behavior is.
To make a Keras layer work with torch `Module`s and training loops, which call `forward()` for the forward pass, the base layer class implements `forward()` so that it calls `call()`. A subclass that overrides `forward()` therefore replaces that bridge to `call()`.
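To illustrate this mechanism, here is a simplified, pure-Python model of the dispatch chain. `ModuleLike` stands in for `torch.nn.Module` (calling the object invokes `forward()`), and `LayerLike` stands in for a Keras layer under the torch backend whose `forward()` defers to `call()`. All class names are hypothetical; this is a sketch of the behavior reported in this thread, not Keras's actual source.

```python
class ModuleLike:
    """Hypothetical stand-in for torch.nn.Module: calling the object runs forward()."""

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class LayerLike(ModuleLike):
    """Hypothetical stand-in for a Keras layer on the torch backend: forward() defers to call()."""

    def forward(self, *args, **kwargs):
        return self.call(*args, **kwargs)

    def call(self, *args, **kwargs):
        raise NotImplementedError


class Clashing(LayerLike):
    def call(self, xz, inverse=False):
        return self.inverse(xz) if inverse else self.forward(xz)

    # Shadows LayerLike.forward, so __call__ goes straight here and never reaches call().
    def forward(self, x):
        return "forward"

    def inverse(self, z):
        return "inverse"


class Renamed(LayerLike):
    # Same logic, but the helpers no longer collide with the forward() bridge.
    def call(self, xz, inverse=False):
        return self._inverse(xz) if inverse else self._forward(xz)

    def _forward(self, x):
        return "forward"

    def _inverse(self, z):
        return "inverse"


try:
    Clashing()(0, inverse=True)
except TypeError as err:
    print("Clashing:", err)  # the 'inverse' keyword reaches forward(), which rejects it

print("Renamed:", Renamed()(0, inverse=True))  # prints: Renamed: inverse
```

In this model, `Clashing()(0)` returns `"forward"` without ever running `call()`, which is why the overridden `forward` appears to change the layer's call signature.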
@LarsKue ,
Please let us know your specific use case and whether it can be solved by adding a `super().forward()` call to your `forward()` implementation.
Thanks
@haifeng-jin The use case is invertible networks. My current workaround is not to expose `forward` and `inverse` directly, but to use protected methods instead:
import os
os.environ["KERAS_BACKEND"] = "torch"
import keras

class MyLayer(keras.Layer):
    def call(self, xz, inverse: bool = False):
        if inverse:
            return self._inverse(xz)
        return self._forward(xz)

    def _forward(self, x):
        pass

    def _inverse(self, z):
        pass

layer = MyLayer()
x = keras.ops.zeros((128, 2))
# works now
layer(x, inverse=True)
Adding a `super().forward()` call might work, but I do not think it is a sound solution.
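One way to see why delegating to `super().forward()` is fragile for this use case: a simplified, pure-Python model (hypothetical `ModuleLike`/`LayerLike` classes standing in for `torch.nn.Module` and a torch-backend Keras layer whose `forward()` defers to `call()`) in which the subclass's `forward` simply routes to the base `forward`. This is an assumed reading of the suggestion, not Keras's actual code. The inverse path works, but since `call()` dispatches the forward direction back to `self.forward()`, the forward path loops.

```python
class ModuleLike:
    """Hypothetical stand-in for torch.nn.Module: calling the object runs forward()."""

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class LayerLike(ModuleLike):
    """Hypothetical stand-in for a torch-backend Keras layer: forward() defers to call()."""

    def forward(self, *args, **kwargs):
        return self.call(*args, **kwargs)


class DelegatingLayer(LayerLike):
    def call(self, xz, inverse=False):
        return self.inverse(xz) if inverse else self.forward(xz)

    # Assumed form of the suggestion: route everything through the base forward().
    def forward(self, x, **kwargs):
        return super().forward(x, **kwargs)

    def inverse(self, z):
        return "inverse"


print(DelegatingLayer()(0, inverse=True))  # prints: inverse

try:
    DelegatingLayer()(0)
except RecursionError:
    # call() -> forward() -> LayerLike.forward() -> call() -> ...
    print("forward path never terminates")
```

In this model, the forward transformation has nowhere to live: `forward` must act as the bridge to `call()`, so it cannot also hold the forward computation.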
Hi @LarsKue -
Could you please confirm whether this issue is resolved for you? Please feel free to close the issue if it is resolved!
Hi @sonali-kumari1, I would not say it is resolved, but the workaround is sufficient, so I am closing this.