
Wav2vec2 output is affected by zero-padding

[Open] JackPfizer opened this issue 3 years ago · 11 comments

🐛 Describe the bug

I've found that the output of the wav2vec2 pipeline model is bugged: it changes depending on the zero-padding used in batch preprocessing. A simple example is as follows:

import types
from typing import Optional, Tuple

import torch
import torchaudio as ta
from torch import Tensor, nn

torch.manual_seed(1)
model = ta.pipelines.WAV2VEC2_BASE.get_model()

# Run a single short sequence on its own (no padding).
N1 = 11000
dummy_data1 = torch.randn([1, N1])
output1 = model(dummy_data1, lengths=torch.tensor([N1]))

# Batch the same sequence with a longer one, zero-padding it to N2 samples.
N2 = 22000
dummy_data2 = torch.randn([1, N2])
dummy_data = torch.zeros([2, N2])
dummy_data[0, :N1] = dummy_data1
dummy_data[1, :N2] = dummy_data2
output2 = model(dummy_data, lengths=torch.tensor([N1, N2]))

# Compare the valid frames of the short sequence in both runs; they should match.
frames1 = output1[1][0]
print(torch.norm(output1[0][0, :frames1] - output2[0][0, :frames1]))

This prints tensor(68.1875, grad_fn=<CopyBackwards>), and changing the value of N2 changes it further. I've traced the source to the group-norm layer after the first convolution in the feature extractor: it computes its statistics across the whole sequence, zero padding included. To amend this, I've created a masked group norm that normalises over the actual sequence only; a small standalone demonstration of the padding sensitivity follows, then the fix.
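The effect can be reproduced with nn.GroupNorm alone, independent of wav2vec2 (a minimal sketch; shapes are arbitrary):

import torch
from torch import nn

torch.manual_seed(0)
gn = nn.GroupNorm(num_groups=2, num_channels=4)

x = torch.randn(1, 4, 10)                                # one unpadded sequence
x_padded = torch.cat([x, torch.zeros(1, 4, 6)], dim=-1)  # same sequence, zero-padded

# The zeros enter the per-group mean and variance, so even the valid
# frames of the padded run come out different.
print(torch.norm(gn(x) - gn(x_padded)[..., :10]))        # clearly non-zero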

def lengths_to_mask(lengths, max_len=None, dtype=None):
    """
    Converts a "lengths" tensor to its binary mask representation.

    Based on: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397

    :param lengths: 1-D tensor of sequence lengths, shape ``[N]``
    :returns: mask of shape ``[N, max_len]``; if max_len is None, max_len = max(lengths)
    """
    assert len(lengths.shape) == 1, 'Length shape should be 1 dimensional.'
    max_len = max_len or lengths.max().item()
    mask = (torch.arange(max_len, device=lengths.device, dtype=lengths.dtype)
            .expand(len(lengths), max_len) < lengths.unsqueeze(1))
    if dtype is not None:
        mask = torch.as_tensor(mask, dtype=dtype, device=lengths.device)
    return mask
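A quick sanity check of the helper (values chosen arbitrarily):

print(lengths_to_mask(torch.tensor([2, 4]), max_len=5))
# tensor([[ True,  True, False, False, False],
#         [ True,  True,  True,  True, False]])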

class MaskedGroupNorm(nn.GroupNorm):
    """
    Masked version of Group normalization.
    
    Based on: https://github.com/ptrblck/pytorch_misc/blob/20e8ea93bd458b88f921a87e2d4001a4eb753a02/batch_norm_manual.py
    
    Receives a 1-D tensor of sequence lengths per batch element
    along with the regular input for masking.
    
    Check pytorch's GroupNorm implementation for argument details.
    """
    def __init__(self, num_groups, num_features, eps=1e-5, affine=True):
        super(MaskedGroupNorm, self).__init__(
            num_groups,
            num_features,
            eps,
            affine
        )

    def forward(self, inp, lengths):
        
        # We turn the mask into a sort of P(inp): equal probability for every
        # unmasked element of the tensor and zero probability for masked ones,
        # so that weighted sums give masked means directly.

        assert inp.shape[1] % self.num_groups == 0, 'Feature size not divisible by groups'

        mask = lengths_to_mask(lengths, max_len=inp.shape[-1], dtype=inp.dtype)
        # Each group averages over (channels_per_group * valid_length) elements.
        ave_mask = mask / lengths[:, None] / (inp.shape[-2] / self.num_groups)
        ave_mask = ave_mask.unsqueeze(1)

        # Here lies the trick. Using Var(X) = E[X^2] - E[X]^2 as the biased
        # variance, we do not need any tensor shape manipulation:
        # mean = E[X] is simply the sum-product of our "probability" mask with the input...
        inp = inp * mask.unsqueeze(1)  # mask out padding, e.g. bleed left over from the conv
        inp_r = inp.reshape([inp.shape[0], self.num_groups, -1, inp.shape[-1]])
        ave_mask = ave_mask.unsqueeze(2)
        mean = (ave_mask * inp_r).sum([2, 3])
        # ...whereas Var(X) is derived directly from the formula above. This is
        # numerically equivalent to the biased sample variance.
        var = (ave_mask * inp_r ** 2).sum([2, 3]) - mean ** 2

        inp_r = (inp_r - mean[:, :, None, None]) / torch.sqrt(var[:, :, None, None] + self.eps)
        out = inp_r.reshape(inp.shape)
        if self.affine:
            out = out * self.weight[None, :, None] + self.bias[None, :, None]
        return out * mask.unsqueeze(1)
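As a sanity check of the class itself (my own sketch, assuming the definitions above): with no padding, MaskedGroupNorm should reduce to plain nn.GroupNorm.

torch.manual_seed(0)
gn = nn.GroupNorm(2, 4)
mgn = MaskedGroupNorm(2, 4)
mgn.load_state_dict(gn.state_dict())  # share the affine parameters

x = torch.randn(3, 4, 12)
lengths = torch.full((3,), 12)  # every sequence uses all 12 frames

# With nothing masked, the weighted statistics equal the ordinary ones.
print(torch.norm(gn(x) - mgn(x, lengths)))  # ~0 up to float error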

def masked_conv_forward(
    self,
    x: Tensor,
    length: Optional[Tensor],
) -> Tuple[Tensor, Optional[Tensor]]:
    """
    Replacement for the forward method of wav2vec2's feature_extractor.conv_layers[0],
    which otherwise behaves differently when there is extra zero padding.
    Args:
        x (Tensor): Shape: ``[batch, in_channels, in_frame]``.
        length (Tensor or None, optional): Shape ``[batch, ]``.
    Returns:
        Tensor: Shape ``[batch, out_channels, out_frames]``.
        Optional[Tensor]: Shape ``[batch, ]``.
    """
    x = self.conv(x)

    if length is not None:
        # Standard conv output-length formula: floor((L - kernel) / stride) + 1.
        length = torch.div(length - self.kernel_size, self.stride, rounding_mode="floor") + 1
        # When input length is 0, the resulting length can be negative, so clamp it.
        length = torch.max(torch.zeros_like(length), length)

    if self.layer_norm is not None:
        if isinstance(self.layer_norm, MaskedGroupNorm):
            x = self.layer_norm(x, length)  # masked path: pass the lengths through
        else:
            x = self.layer_norm(x)
    x = nn.functional.gelu(x)

    return x, length

This can be added to the model by swapping the masked version in for the pre-existing group-norm layer, copying over the group-norm parameters from the pretrained model, and patching the forward method of the first conv layer:

conv0 = model.feature_extractor.conv_layers[0]
prior_params = vars(conv0.layer_norm)  # keep the pretrained parameters and settings
conv0.layer_norm = MaskedGroupNorm(conv0.layer_norm.num_groups,
                                   conv0.layer_norm.num_channels)
conv0.layer_norm.__dict__.update(prior_params)
conv0.forward = types.MethodType(masked_conv_forward, conv0)

output1 = model(dummy_data1, lengths=torch.tensor([N1]))
output2 = model(dummy_data, lengths=torch.tensor([N1, N2]))

print(torch.norm(output1[0][0, :frames1] - output2[0][0, :frames1]))

This now gives tensor(5.6603e-05, grad_fn=<CopyBackwards>): the batched, zero-padded output matches the unpadded one up to numerical noise.

Versions

Collecting environment information...
PyTorch version: 1.10.0
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.3 LTS (x86_64)
GCC version: (GCC) 10.3.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.9.7 (default, Sep 16 2021, 13:09:58) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.4.0-97-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 2070
Nvidia driver version: 510.47.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.20.3
[pip3] torch==1.10.0
[pip3] torchaudio==0.10.0
[pip3] torchvision==0.11.1
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.3.1 h2bc3f7f_2
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py39h7f8727e_0
[conda] mkl_fft 1.3.1 py39hd3c417c_0
[conda] mkl_random 1.2.2 py39h51133e4_0
[conda] numpy 1.19.5 pypi_0 pypi
[conda] numpy-base 1.20.3 py39h74d4b33_0
[conda] pytorch 1.10.0 py3.9_cuda11.3_cudnn8.2.0_0 pytorch
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torchaudio 0.10.0 py39_cu113 pytorch
[conda] torchvision 0.11.1 py39_cu113 pytorch

JackPfizer · Feb 16 '22

Just to check that MaskedGroupNorm is working, I tested the original output1 against the patched model's prediction:

# output1 comes from the unmodified torchaudio pipeline; output1_1 comes
# from the patched model, both on the same unpadded input.
output1_1 = model(dummy_data1, lengths=torch.tensor([N1]))

print(torch.norm(output1[0][0, :frames1] - output1_1[0][0, :frames1]))

This gives tensor(6.2804e-05, grad_fn=<CopyBackwards>). So it's not perfect, but it is reasonably close.
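Equivalently, with an explicit elementwise tolerance (atol chosen loosely):

print(torch.allclose(output1[0][0, :frames1], output1_1[0][0, :frames1], atol=1e-4))  # expected: True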

JackPfizer · Feb 16 '22

Hi @JackRAHealth

Thanks for the report. Let me look into it.

mthrok · Feb 16 '22

The given analysis seems to be correct, and the proper solution would be to implement normalization that is aware of masking.

There are tests for batch consistency, but they only use samples of similar lengths, so this effect was not caught (see the sketch after the link for the kind of test that would).

https://github.com/pytorch/audio/blob/cbf1b8392341c61ea1db9daf81556b207e2cf9eb/test/torchaudio_unittest/models/wav2vec2/model_test.py#L139-L147
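A hypothetical test along these lines, batching very different lengths, would have caught it (a sketch only; the function name and tolerance are assumptions, not the actual test-suite API):

def test_batch_consistency_mixed_lengths(model):
    torch.manual_seed(0)
    lengths = [11000, 22000]
    waves = [torch.randn(1, n) for n in lengths]

    # Zero-pad everything to the longest sequence and run as one batch.
    padded = torch.zeros(len(lengths), max(lengths))
    for i, (w, n) in enumerate(zip(waves, lengths)):
        padded[i, :n] = w
    batched, _ = model(padded, lengths=torch.tensor(lengths))

    # Each sequence run alone must match its valid frames in the batch.
    for i, (w, n) in enumerate(zip(waves, lengths)):
        single, single_len = model(w, lengths=torch.tensor([n]))
        frames = single_len[0]
        assert torch.allclose(single[0, :frames], batched[i, :frames], atol=1e-4)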

We need to update two modules (torch.nn.GroupNorm and LayerNorm) so that they support masking.

https://github.com/pytorch/audio/blob/cbf1b8392341c61ea1db9daf81556b207e2cf9eb/torchaudio/models/wav2vec2/components.py#L520-L530
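For the LayerNorm variant the masking should be simpler: in the feature extractor it is applied per frame over the channel dimension, so padded frames do not contaminate the statistics of valid frames and only need to be zeroed afterwards. A rough sketch, reusing lengths_to_mask from above (not the eventual torchaudio implementation):

class MaskedLayerNorm(nn.LayerNorm):
    def forward(self, inp, lengths):
        # inp: [batch, channels, time]; normalise over channels per frame,
        # then zero the padded frames so they cannot leak into later convs.
        mask = lengths_to_mask(lengths, max_len=inp.shape[-1], dtype=inp.dtype)
        out = super().forward(inp.transpose(-2, -1)).transpose(-2, -1)
        return out * mask.unsqueeze(1)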

mthrok · Feb 21 '22

I am un-assigning myself, as this turned out to require more resources than I have at the moment. If anyone is interested in resolving it, let us know (please discuss the details here before making a PR).

mthrok · Feb 21 '22

good pickup @JackRAHealth, keeping it 55th street

anicolson · Feb 21 '22

This issue seems very serious, as the underlying problem comes from PyTorch's own nn.GroupNorm; its scope goes beyond Torchaudio.

As far as I know, Hugging Face's Wav2Vec 2.0 model is also implemented with nn.GroupNorm. That means recent works may have reported skewed numbers by using the Torchaudio or Hugging Face Wav2Vec 2.0 model as the baseline.

kss2517 · Jul 22 '22

> As far as I know, Hugging Face's Wav2Vec 2.0 model is also implemented with nn.GroupNorm. That means recent works may have reported skewed numbers by using the Torchaudio or Hugging Face Wav2Vec 2.0 model as the baseline.

The original fairseq implementation also uses nn.GroupNorm at its core, so this has been an issue from the very beginning. We are thinking that adopting NestedTensor is one way to solve this (rough sketch below).
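For illustration, the direction might look like this with the torch.nested namespace of later PyTorch releases (a sketch of the idea only; operator coverage for nested tensors is still limited, so this is not working wav2vec2 code):

import torch

# Hypothetical: batch two utterances of different lengths with no padding at all.
wave1 = torch.randn(11000)
wave2 = torch.randn(22000)
batch = torch.nested.nested_tensor([wave1, wave2])

# Each entry keeps its true length, so there are no padded zeros
# to contaminate any normalisation statistics in the first place.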

cc @cpuhrsch

mthrok · Jul 22 '22

cc @jbschlosser who is the TL for NestedTensor

cpuhrsch · Jul 22 '22