
`kaldi.fbank` does not work with non-contiguous input when `snip_edges=False`

Status: Open. gau-nernst opened this issue 11 months ago (0 comments).

🐛 Describe the bug

```python
from torchaudio.compliance.kaldi import fbank
import torch

x = torch.rand(1, 16_000 * 2) * (1 << 15)
x = x[:, ::2]
torch.testing.assert_close(fbank(x.contiguous(), snip_edges=False), fbank(x, snip_edges=False))
```

Running this fails with the following (truncated) traceback:

```text
File ~/miniconda3/envs/vas_2.4/lib/python3.10/site-packages/torchaudio/compliance/kaldi.py:177, in _get_window(waveform, padded_window_size, window_size, window_shift, window_type, blackman_coeff, snip_edges, raw_energy, energy_floor, dither, remove_dc_offset, preemphasis_coefficient)
    174 epsilon = _get_epsilon(device, dtype)
    176 # size (m, window_size)
--> 177 strided_input = _get_strided(waveform, window_size, window_shift, snip_edges)
    179 if dither != 0.0:
    180     rand_gauss = torch.randn(strided_input.shape, device=device, dtype=dtype)

File ~/miniconda3/envs/vas_2.4/lib/python3.10/site-packages/torchaudio/compliance/kaldi.py:83, in _get_strided(waveform, window_size, window_shift, snip_edges)
     80         waveform = torch.cat((waveform[-pad:], pad_right), dim=0)
     82 sizes = (m, window_size)
---> 83 return waveform.as_strided(sizes, strides)

RuntimeError: setStorage: sizes [100, 400], strides [320, 2], storage offset 0, and itemsize 4 requiring a storage size of 129916 are out of bounds for storage of size 128480
```

I encountered this problem while implementing a batched version of `kaldi.fbank` (by the way, I'm also willing to contribute my batch support back to torchaudio if the maintainers are interested). The problem lies in the `_get_strided()` function, which computes the strides from the original waveform before any padding:

https://github.com/pytorch/audio/blob/332760d4b300f00a0d862e3cfe1495db3b1a14f9/src/torchaudio/compliance/kaldi.py#L61

However, when `snip_edges=False`, the waveform is padded with `torch.cat()`, which copies it into a new tensor. That copy is contiguous even if the input was not:

https://github.com/pytorch/audio/blob/332760d4b300f00a0d862e3cfe1495db3b1a14f9/src/torchaudio/compliance/kaldi.py#L73-L80
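The layout change can be seen in isolation with plain `torch`. This sketch assumes, as the traceback suggests, that `_get_strided()` receives a 1-D view of one channel (like `x[0]` below). The padding here is a simplified symmetric reflect pad of 120 samples built with `torch.cat()`, not the exact padding in `kaldi.py`:

```python
import torch

# Same non-contiguous input as the reproduction: every other sample of a longer signal.
x = torch.rand(1, 16_000 * 2)[:, ::2]
waveform = x[0]  # 1-D view of the single channel
print(waveform.is_contiguous(), waveform.stride())  # False (2,)

# Simplified reflect-style padding via torch.cat (pad width is illustrative).
pad = 120
reversed_waveform = torch.flip(waveform, [0])
padded = torch.cat((reversed_waveform[-pad:], waveform, reversed_waveform[:pad]), dim=0)
print(padded.is_contiguous(), padded.stride())  # True (1,)
```

Any stride read from `waveform` before this point describes a tensor that no longer exists once `padded` replaces it.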

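Hence, the strides computed before padding no longer match the padded waveform. For the stride-2 input above, `as_strided()` is called with strides `[320, 2]`, although the padded, contiguous copy needs `[160, 1]`, so the requested view runs past the end of its storage.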
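The sizes in the error message can be checked directly. With fbank's default 25 ms / 10 ms framing at 16 kHz (`window_size=400`, `window_shift=160`) and 16,000 input samples, `snip_edges=False` gives 100 frames. The sketch below uses a hypothetical helper, `required_elements`, to compute how many storage elements `as_strided()` needs with the stale stride-2 strides versus strides taken from a contiguous tensor. It then tries both against a buffer of 32,120 float32 elements, the 128,480 bytes reported in the error:

```python
import torch

window_size, window_shift = 400, 160  # fbank defaults at 16 kHz: 25 ms frames, 10 ms shift
num_samples = 16_000
m = (num_samples + window_shift // 2) // window_shift  # frame count when snip_edges=False
sizes = (m, window_size)


def required_elements(sizes, strides):
    """Hypothetical helper: last storage index that as_strided touches, plus one."""
    return sum((size - 1) * stride for size, stride in zip(sizes, strides)) + 1


stale = (window_shift * 2, 2)  # strides read from the stride-2 input before padding
fresh = (window_shift * 1, 1)  # strides read from the contiguous padded copy

print(m, required_elements(sizes, stale), required_elements(sizes, fresh))  # 100 32479 16240

padded = torch.zeros(32_120)  # 128480 bytes / 4 bytes per float32
frames = padded.as_strided(sizes, fresh)  # fits inside the buffer
try:
    padded.as_strided(sizes, stale)  # needs 32479 elements, more than the buffer holds
    overflowed = False
except RuntimeError:
    overflowed = True
print(frames.shape, overflowed)
```

32,479 elements times 4 bytes is exactly the 129,916 bytes in the error message.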

The fix is to move the stride calculation after the padding, so that the strides come from the tensor that `as_strided()` actually views:

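```python
# ... padding for snip_edges=False, unchanged ...

# Compute strides from the (possibly padded, hence contiguous) waveform.
strides = (window_shift * waveform.stride(0), waveform.stride(0))
sizes = (m, window_size)
return waveform.as_strided(sizes, strides)
```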
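A self-contained sketch of the proposed change follows. `get_strided_fixed` is a hypothetical standalone copy of the private `_get_strided()`. Its padding branch is written to mirror the linked lines of `kaldi.py`, but treat it as an illustration rather than the upstream code. The only intended difference is that `strides` is computed after padding:

```python
import torch
from torch import Tensor


def get_strided_fixed(waveform: Tensor, window_size: int, window_shift: int, snip_edges: bool) -> Tensor:
    """Hypothetical patched copy of _get_strided(): strides are taken after padding."""
    assert waveform.dim() == 1
    num_samples = waveform.size(0)

    if snip_edges:
        if num_samples < window_size:
            return torch.empty((0, 0), dtype=waveform.dtype, device=waveform.device)
        m = 1 + (num_samples - window_size) // window_shift
    else:
        reversed_waveform = torch.flip(waveform, [0])
        m = (num_samples + (window_shift // 2)) // window_shift
        pad = window_size // 2 - window_shift // 2
        pad_right = reversed_waveform
        if pad > 0:
            pad_left = reversed_waveform[-pad:]
            waveform = torch.cat((pad_left, waveform, pad_right), dim=0)
        else:
            # Negative pad: trim the front of the waveform instead.
            waveform = torch.cat((waveform[-pad:], pad_right), dim=0)

    # Strides now describe the tensor that as_strided() actually views.
    strides = (window_shift * waveform.stride(0), waveform.stride(0))
    sizes = (m, window_size)
    return waveform.as_strided(sizes, strides)


# Contiguous and non-contiguous inputs should now produce identical frames.
x = torch.rand(1, 16_000 * 2)[:, ::2]
for snip_edges in (False, True):
    expected = get_strided_fixed(x[0].contiguous(), 400, 160, snip_edges)
    actual = get_strided_fixed(x[0], 400, 160, snip_edges)
    torch.testing.assert_close(actual, expected)
```

With `snip_edges=True` nothing is copied, so the stride-2 strides remain correct for the original view; only the padded branch needed the strides to be recomputed.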

Versions

PyTorch 2.4, torchaudio 2.4

gau-nernst, Nov 25 '24 02:11