Overriding `Layer.forward` unexpectedly changes the signature of `Layer.__call__` under torch backend

Open LarsKue opened this issue 1 year ago • 5 comments

In torch, one typically defines what happens when a module is called by overriding the forward method. In Keras, we override the call method instead.

I would not expect overriding forward to have any effect on how a keras.Layer is called, even under the torch backend, since this is the purpose of call. However, it seems that when forward is overridden, it takes priority over the overridden Layer.call.

Minimal Example:

import os
os.environ["KERAS_BACKEND"] = "torch"

import keras


class MyLayer(keras.Layer):
    def call(self, xz, inverse: bool = False):
        if inverse:
            return self.inverse(xz)
        return self.forward(xz)

    def forward(self, x):
        pass

    def inverse(self, z):
        pass


layer = MyLayer()

x = keras.ops.zeros((128, 2))

# TypeError: MyLayer.forward() got an unexpected keyword argument 'inverse'
layer(x, inverse=True)

My keras version: 3.3.3

LarsKue avatar May 19 '24 12:05 LarsKue

Hi @LarsKue ,

Thanks for reporting the issue. I have reproduced it and observed that, with PyTorch as the backend, an overridden Layer.forward method takes precedence over the keras.Layer.call method. Attaching gist for reference.

I need to check with the Keras dev team whether this is intended or an oversight.

Thanks!

SuryanarayanaY avatar May 20 '24 04:05 SuryanarayanaY

@haifeng-jin any thoughts on this? I'm not familiar enough with torch to know what our intended behavior is.

mattdangerw avatar May 23 '24 00:05 mattdangerw

To make a Keras layer work with torch Modules and training loops, which call the forward() function for the forward pass, we let the forward() function in the base layer class call the call() function.

haifeng-jin avatar May 23 '24 16:05 haifeng-jin
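
The dispatch chain described in the comment above can be sketched in plain Python. This is a simplified model under stated assumptions, not the actual Keras or torch source: `ModuleLike`, `LayerLike`, `Broken`, and `Renamed` are hypothetical stand-ins, and the arithmetic in the method bodies is a placeholder. It only illustrates why a subclass that overrides forward() cuts call() out of the chain.

```python
# Simplified, hypothetical model of the dispatch chain described above.
# These classes are stand-ins, not the real torch or Keras implementations.

class ModuleLike:
    """Stand-in for torch.nn.Module: calling the object runs forward()."""

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class LayerLike(ModuleLike):
    """Stand-in for keras.Layer on the torch backend: forward() runs call()."""

    def forward(self, *args, **kwargs):
        return self.call(*args, **kwargs)

    def call(self, *args, **kwargs):
        raise NotImplementedError


class Broken(LayerLike):
    """Overrides forward(), so the base forward() -> call() link is lost."""

    def call(self, xz, inverse=False):
        return self.inverse(xz) if inverse else self.forward(xz)

    def forward(self, x):
        return x + 1  # placeholder computation

    def inverse(self, z):
        return z - 1  # placeholder computation


class Renamed(LayerLike):
    """Same layer with private helper names, leaving forward() untouched."""

    def call(self, xz, inverse=False):
        return self._inverse(xz) if inverse else self._forward(xz)

    def _forward(self, x):
        return x + 1

    def _inverse(self, z):
        return z - 1


broken_error = None
try:
    Broken()(1, inverse=True)
except TypeError as err:
    # forward() receives inverse=True directly; call() is never reached.
    broken_error = str(err)

# Here the base forward() still routes the keyword argument into call().
renamed_result = Renamed()(1, inverse=True)
```

In this model, `Broken` fails with a TypeError about the `inverse` keyword, while `Renamed` reaches call() as intended.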

@LarsKue ,

Please let us know your specific use case and whether it can be solved by adding a super().forward() call to your forward() implementation.

Thanks

haifeng-jin avatar May 24 '24 18:05 haifeng-jin

@haifeng-jin The use case is invertible networks. My current work-around is not to expose forward and inverse directly, but to use protected methods:

import os
os.environ["KERAS_BACKEND"] = "torch"

import keras


class MyLayer(keras.Layer):
    def call(self, xz, inverse: bool = False):
        if inverse:
            return self._inverse(xz)
        return self._forward(xz)

    def _forward(self, x):
        pass

    def _inverse(self, z):
        pass


layer = MyLayer()

x = keras.ops.zeros((128, 2))

# works now
layer(x, inverse=True)

Adding a super().forward() might work, but I do not think this is a sound solution.

LarsKue avatar May 28 '24 09:05 LarsKue
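
One possible pitfall of the super().forward() suggestion can be seen in a toy model. This is a sketch under assumptions and not necessarily the concern meant in the comment above: `ModuleLike` and `LayerLike` are hypothetical stand-ins for torch.nn.Module and keras.Layer, and `Invertible` is an illustrative subclass. If call() dispatches to self.forward() while forward() delegates back to the base class, the forward path re-enters itself.

```python
# Toy model only; ModuleLike and LayerLike are hypothetical stand-ins,
# not the real torch.nn.Module or keras.Layer.

class ModuleLike:
    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class LayerLike(ModuleLike):
    def forward(self, *args, **kwargs):
        return self.call(*args, **kwargs)


class Invertible(LayerLike):
    def call(self, xz, inverse=False):
        if inverse:
            return self.inverse(xz)
        return self.forward(xz)  # re-enters forward() below

    def forward(self, x, **kwargs):
        # Delegating to the base class restores the forward() -> call() link,
        # but call() then dispatches back to this same method.
        return super().forward(x, **kwargs)

    def inverse(self, z):
        return -z  # placeholder computation


# The inverse path returns normally in this model.
inverse_result = Invertible()(2.0, inverse=True)

# The forward path loops between forward() and call() until Python gives up.
try:
    Invertible()(2.0)
    forward_recursed = False
except RecursionError:
    forward_recursed = True
```

In this model the two methods would have to be decoupled anyway, for example by moving the actual computation into a separate helper, which is effectively what the renamed `_forward` workaround does.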

Hi @LarsKue -

Could you please confirm whether this issue is resolved for you? Please feel free to close the issue if it is.

sonali-kumari1 avatar Feb 03 '25 10:02 sonali-kumari1

Hi @sonali-kumari1 , I would not say it is resolved, but the workaround is sufficient. Thus, closing this.

LarsKue avatar Feb 03 '25 10:02 LarsKue