`kaldi.fbank` does not work with non-contiguous input when `snip_edges=False`
🐛 Describe the bug
```python
import torch
from torchaudio.compliance.kaldi import fbank

x = torch.rand(1, 16_000 * 2) * (1 << 15)
x = x[:, ::2]  # non-contiguous view
torch.testing.assert_close(fbank(x.contiguous(), snip_edges=False), fbank(x, snip_edges=False))
```
```
File ~/miniconda3/envs/vas_2.4/lib/python3.10/site-packages/torchaudio/compliance/kaldi.py:177, in _get_window(waveform, padded_window_size, window_size, window_shift, window_type, blackman_coeff, snip_edges, raw_energy, energy_floor, dither, remove_dc_offset, preemphasis_coefficient)
    174 epsilon = _get_epsilon(device, dtype)
    176 # size (m, window_size)
--> 177 strided_input = _get_strided(waveform, window_size, window_shift, snip_edges)
    179 if dither != 0.0:
    180     rand_gauss = torch.randn(strided_input.shape, device=device, dtype=dtype)

File ~/miniconda3/envs/vas_2.4/lib/python3.10/site-packages/torchaudio/compliance/kaldi.py:83, in _get_strided(waveform, window_size, window_shift, snip_edges)
     80     waveform = torch.cat((waveform[-pad:], pad_right), dim=0)
     82 sizes = (m, window_size)
---> 83 return waveform.as_strided(sizes, strides)

RuntimeError: setStorage: sizes [100, 400], strides [320, 2], storage offset 0, and itemsize 4 requiring a storage size of 129916 are out of bounds for storage of size 128480
```
I encountered this problem while implementing a batched version of `kaldi.fbank` (by the way, I'd be happy to contribute the batch support back to torchaudio if the maintainers are interested). The problem lies in the `_get_strided()` function. It first computes the output strides from the original waveform's stride:
https://github.com/pytorch/audio/blob/332760d4b300f00a0d862e3cfe1495db3b1a14f9/src/torchaudio/compliance/kaldi.py#L61
However, when `snip_edges=False`, the padding is done with `torch.cat()`, which produces a contiguous copy of `waveform` even if the input was not contiguous:
https://github.com/pytorch/audio/blob/332760d4b300f00a0d862e3cfe1495db3b1a14f9/src/torchaudio/compliance/kaldi.py#L73-L80
Hence, the strides computed from the original tensor (before padding) no longer match the memory layout of the padded tensor (after padding) that is passed to `as_strided()`.
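To see the mismatch in isolation, here is a minimal toy example (not torchaudio's code; the shapes are arbitrary) that captures a stride from a non-contiguous view and then applies it to a contiguous copy:

```python
import torch

x = torch.rand(100)[::2]             # non-contiguous view: stride (2,), backed by a storage of 100 floats
stride = x.stride(0)                  # 2, captured *before* any copy, as _get_strided does
y = torch.cat((x, torch.zeros(10)))   # cat materializes a contiguous copy: stride (1,), storage of 60 floats
try:
    # Reusing the stale stride over-reads the new, smaller storage, just like in the traceback above.
    y.as_strided((10, 10), (10 * stride, stride))
except RuntimeError as e:
    print(e)  # setStorage: ... out of bounds for storage of size ...
```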
The fix is to move the stride calculation after the padding:
```python
# padding...
strides = (window_shift * waveform.stride(0), waveform.stride(0))
sizes = (m, window_size)
return waveform.as_strided(sizes, strides)
```
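Until such a fix lands, a user-side workaround (at the cost of an extra copy) is to make the input contiguous before calling `fbank`, which is what the passing half of the repro above does:

```python
import torch
from torchaudio.compliance.kaldi import fbank

x = torch.rand(1, 16_000 * 2) * (1 << 15)
x = x[:, ::2]                                    # non-contiguous view
feats = fbank(x.contiguous(), snip_edges=False)  # copy to a contiguous tensor first
```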
Versions
PyTorch 2.4, torchaudio 2.4