ResBlock with inplace relu?

Open rfeinman opened this issue 4 years ago • 5 comments

I noticed in the ResBlock class of vqvae.py that you put a ReLU activation at the start of the residual stack, and no ReLU at the end:

import torch.nn as nn


class ResBlock(nn.Module):
    def __init__(self, in_channel, channel):
        super().__init__()

        self.conv = nn.Sequential(
            nn.ReLU(inplace=True),  # in-place: overwrites the tensor passed to forward
            nn.Conv2d(in_channel, channel, 3, padding=1),
            nn.ReLU(inplace=True),  # in-place on an intermediate tensor
            nn.Conv2d(channel, in_channel, 1),
        )

    def forward(self, input):
        out = self.conv(input)
        out += input

        return out

I think the point is to pass forward the pre-activation values from the previous layer as the identity (the skip connection), rather than the post-activation values. However, since inplace=True is set on the first ReLU, `input` is overwritten inside self.conv, so you actually end up with the latter. Is that intentional, or a bug?

rfeinman avatar Jan 22 '21 14:01 rfeinman

Oh... it is definitely a bug. Thank you! Fixed at ef5f67c.

rosinality avatar Jan 23 '21 01:01 rosinality

@rfeinman How does this affect the result? Can you please explain in detail?

SURABHI-GUPTA avatar Jan 23 '21 13:01 SURABHI-GUPTA

@SURABHI-GUPTA Assume that `input` is the pre-ReLU activation from the previous layer. The intended computation is the following:

output = self.conv(input)
output += input

However, the way the code was written, the variable input was modified in place by the first ReLU when self.conv(input) was called. So you were actually getting

output = self.conv(input)
output += F.relu(input)  # the skip now carries post-activation values
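
You can check the mutation directly (a minimal sketch using the old ResBlock definition quoted above; the shapes are arbitrary):

import torch

block = ResBlock(in_channel=8, channel=4)  # the pre-fix version quoted above
x = torch.randn(1, 8, 16, 16)
x_copy = x.clone()  # keep a reference copy of the input

out = block(x)
print(torch.equal(x, x_copy))          # False: the first in-place ReLU overwrote x
print(torch.equal(x, x_copy.relu()))   # True: x now holds relu(original input)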

But this has been fixed now with the new commit.
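
For later readers: a minimal fix is to avoid mutating the block's input, e.g. by making the first ReLU non-inplace (just a sketch; see commit ef5f67c for the actual change):

import torch.nn as nn

class ResBlock(nn.Module):
    def __init__(self, in_channel, channel):
        super().__init__()

        self.conv = nn.Sequential(
            nn.ReLU(),  # not in-place, so `input` is left intact for the skip
            nn.Conv2d(in_channel, channel, 3, padding=1),
            nn.ReLU(inplace=True),  # safe: operates on an intermediate tensor
            nn.Conv2d(channel, in_channel, 1),
        )

    def forward(self, input):
        out = self.conv(input)
        out += input  # the skip carries the pre-activation values, as intended

        return out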

rfeinman avatar Jan 23 '21 17:01 rfeinman

@rfeinman @SURABHI-GUPTA I am also trying to figure out what the implications of this extra operation could be. The accidental ReLU effectively turns the identity into a post-activation skip connection. The only degradation I can see is in training performance (speed).

Any further insights?

fostiropoulos avatar Jan 23 '21 22:01 fostiropoulos

@fostiropoulos I can't speak to whether it is better to pass forward the pre-activation or the post-activation values for the skip connection. Popular residual architectures for classification (e.g. the original ResNet) do the latter; this repository chose the former.
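
For concreteness, the two conventions look roughly like this (a schematic sketch, not the exact code of either model):

import torch.nn as nn
import torch.nn.functional as F

class PostActBlock(nn.Module):
    # ResNet-v1 style: activation after the addition,
    # so the next block's skip carries post-activation values
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        out = self.conv2(F.relu(self.conv1(x)))
        return F.relu(out + x)

class PreActBlock(nn.Module):
    # pre-activation style (as in this repo): no activation after the addition,
    # so the skip carries pre-activation values
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        out = self.conv2(F.relu(self.conv1(F.relu(x))))
        return out + x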

rfeinman avatar Mar 24 '21 15:03 rfeinman