vq-vae-2-pytorch
Error in code for PixelSnail (Proper masking)
First of all, thanks for the implementation!
My question is about proper masking inside the model:
- shift_down and shift_right at the beginning of the PixelSnail module have already taken care of masking the input, so there is no cheating (no direct connection from input to output),
but what is the point of masking the current pixel in all the subsequent layers, i.e. in the gated residual blocks (causal convolutions) and in the causal attention inside PixelBlock?
P.S. I think the current implementation only works because there are residual connections, and moreover these residual connections certainly break the masking of the current pixel (it was only preserved at the very beginning).
For comparison: a simple PixelCNN uses a type A mask only in the first layer (shown here for a 3x3 kernel; a code sketch of both mask types follows this list):

1 1 1
1 0 0
0 0 0

and from layer 1 onwards it uses a type B mask:

1 1 1
1 1 0
0 0 0
- Could you also clarify why we mask the first (top-left) pixel in the causal attention mask (I am referring to the start_mask variable)?
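For concreteness, here is a minimal sketch of the standard PixelCNN type A / type B masked convolution mentioned above (generic PixelCNN-style code, not this repo's implementation; the class name is mine):

```python
import torch
from torch import nn

class MaskedConv2d(nn.Conv2d):
    """Standard PixelCNN masked convolution.
    Type 'A' (first layer) also masks the centre ("current") pixel;
    type 'B' (subsequent layers) keeps it."""

    def __init__(self, mask_type, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert mask_type in ('A', 'B')
        _, _, kh, kw = self.weight.shape
        mask = torch.ones(kh, kw)
        # zero the centre row from the centre pixel (type A) or just after it (type B)
        mask[kh // 2, kw // 2 + (mask_type == 'B'):] = 0
        mask[kh // 2 + 1:] = 0  # zero every row below the centre
        self.register_buffer('mask', mask[None, None])

    def forward(self, x):
        self.weight.data *= self.mask  # re-apply in case an optimiser step changed the weights
        return super().forward(x)

# For a 3x3 kernel this gives exactly the two masks above:
# type A -> 1 1 1 / 1 0 0 / 0 0 0, type B -> 1 1 1 / 1 1 0 / 0 0 0
conv_a = MaskedConv2d('A', 1, 16, kernel_size=3, padding=1)
conv_b = MaskedConv2d('B', 16, 16, kernel_size=3, padding=1)
```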
I'll try to answer this one after spending some time thinking through the code. First, the masking must be done throughout every layer (a regular convolution would pick up "future" pixels). Whatever argument you use to convince yourself that the mask is needed on the first layer, just run the same argument on the next layer and you should see the need for masking there as well.
A different mask is always applied to the first layer than to the others, since in the first layer we do not want to include the "current" pixel in the output feature map. In subsequent layers the spatial location of the current pixel can be included (since the first layer has already removed it):
First-layer mask:

1 1 1
1 0 0   (the middle element is the "current" pixel)
0 0 0

Subsequent layers:

1 1 1
1 1 0
0 0 0
In the original pixelSNAIL code they avoid using masking altogether (but still capture the same effect by shifting the image around and using different kernel sizes in every layer).
In this implementation it's sort of a hybrid, where masking is used in combination with the shift-and-crop approach. It took me a fair while to get my head around it, but I have some code to visualize it if that's any help.
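For intuition, the shift-and-crop trick looks roughly like this (a minimal sketch of the idea, not this repo's exact functions): pad the feature map on one side and crop the opposite side, so that a plain unmasked convolution applied afterwards can only ever see pixels above (or to the left of) the current location.

```python
import torch.nn.functional as F

def down_shift(x):
    """Shift the feature map down one row: row i of the output is row i-1
    of the input, so a plain conv afterwards only sees rows strictly above."""
    return F.pad(x, (0, 0, 1, 0))[:, :, :-1, :]  # pad one row on top, crop the bottom row

def right_shift(x):
    """Same trick horizontally: shift right one column, crop the rightmost column."""
    return F.pad(x, (1, 0, 0, 0))[:, :, :, :-1]
```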
As a quick aside regarding the masking in this code...
https://github.com/rosinality/vq-vae-2-pytorch/blob/e851d8170709cbe0cdc9521a52f5e0516ffece0c/pixelsnail.py#L115
I feel like this should be self.causal+1 so that the mask is
1 1 1 1 1 1 1 1 0
which appears to preserve causality through the model better.
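To make the off-by-one concrete, here's a toy sketch (my own helper, not the repo's code) of zeroing the last kernel row from index self.causal versus self.causal + 1; which variant is correct depends on how self.causal relates to the centre column of the kernel:

```python
import torch

def last_row_mask(width, start):
    """Mask for the last row of a causal conv kernel:
    keep columns [0, start), zero everything from `start` onwards."""
    mask = torch.ones(width)
    mask[start:] = 0
    return mask

width = 9
causal = width // 2  # assuming self.causal indexes the centre column

print(last_row_mask(width, causal).tolist())      # centre ("current") pixel masked out
print(last_row_mask(width, causal + 1).tolist())  # centre pixel kept
```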
I don't believe the residual connections break the masking in any way, since the masking is used in every convolutional layer with kernel size > 1. This can actually be checked directly: create a 32x32 tensor of all zeros, place a 1 at a single position, and run it through the model's forward pass. Any non-zero entries in the output are locations influenced by your chosen pixel; you should find none of these in the pixel's "past", and all of them in its "future" (where past/future are above-left and below-right of the current pixel).
Note that to get a clean result from the check above you need to switch off the bias terms in each layer.
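A rough version of that check might look like this (a sketch with hypothetical names and shapes; it assumes a single-channel input and a purely convolutional model with no normalisation layers):

```python
import torch

@torch.no_grad()
def forward_influence_check(model, size=32, probe=(16, 16)):
    """Feed a one-hot image through `model` and report every output position
    that the probed input pixel influences."""
    # zero every bias so that "no influence" shows up as an exactly-zero output
    for m in model.modules():
        if getattr(m, 'bias', None) is not None:
            m.bias.data.zero_()

    x = torch.zeros(1, 1, size, size)
    x[0, 0, probe[0], probe[1]] = 1.0
    out = model(x)

    influenced = (out.abs().sum(dim=1)[0] != 0).nonzero()
    r, c = probe
    # anything at or before the probe in raster order is a causality leak
    leaks = [(int(i), int(j)) for i, j in influenced if i < r or (i == r and j <= c)]
    print('positions leaking into the past/current pixel:', leaks)
```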
- Haven't got around to thinking about this one yet...
I also believe, based on some checks, that the "blind spot" issue originally raised in the gated PixelCNN paper is actually present in this model (despite the use of a horizontal/vertical split kernel in the first layer).
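One way to probe for that blind spot is the converse of the check above: backprop from a single output position and see which input pixels it can actually see (again a hypothetical sketch, assuming a differentiable single-channel model):

```python
import torch

def receptive_field(model, size=32, target=(16, 16)):
    """Backprop from one output position to the input and return a boolean
    map of which input pixels can influence it."""
    x = torch.randn(1, 1, size, size, requires_grad=True)
    out = model(x)
    out[0, :, target[0], target[1]].sum().backward()
    return x.grad[0, 0] != 0
```

If the returned map is False anywhere in the region above-and-to-the-right of the target pixel (still in its raster-order past), that region is a blind spot.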