FLAC save/load is broken with in-memory buffers and `sox_io` backend
🐛 Describe the bug
from io import BytesIO
import torch
import torchaudio
torch.manual_seed(0)
torchaudio.set_audio_backend("sox_io")
sr = 16000
N = sr  # in case you can't reproduce, try increasing N, for some values it works, for others doesn't
x = torch.rand(1, N)  # "audio"
print(x[0, :20])
# This works OK.
fmt = "wav"
f = BytesIO()
torchaudio.save(f, x, sample_rate=sr, format=fmt)
f.seek(0)
x_rec, _ = torchaudio.load(f, format=fmt)
assert x.shape == x_rec.shape
print("WAV works OK")
# This breaks.
fmt = "flac"
f = BytesIO()
torchaudio.save(f, x, sample_rate=sr, format=fmt)
f.seek(0)
x_rec, _ = torchaudio.load(f, format=fmt)
assert x.shape == x_rec.shape, f"FLAC fails {x.shape=}, {x_rec.shape=}"
The output is:
tensor([0.4963, 0.7682, 0.0885, 0.1320, 0.3074, 0.6341, 0.4901, 0.8964, 0.4556,
        0.6323, 0.3489, 0.4017, 0.0223, 0.1689, 0.2939, 0.5185, 0.6977, 0.8000,
        0.1610, 0.2823])
WAV works OK
Traceback (most recent call last):
  File "/Users/pzelasko/Library/Application Support/JetBrains/PyCharm2022.2/scratches/scratch_6.py", line 26, in <module>
    assert x.shape == x_rec.shape, f"FLAC fails {x.shape=}, {x_rec.shape=}"
AssertionError: FLAC fails x.shape=torch.Size([1, 16000]), x_rec.shape=torch.Size([1, 8192])
With the sox_io backend, FLAC fails to encode or decode the whole audio from in-memory buffers such as BytesIO. It breaks silently, so unless users validate that they get the expected audio size, they will think everything is OK. I found this issue when using Lhotse, which validates the shapes against metadata.
I can reproduce this issue across multiple PyTorch versions (I think I first noticed it in 1.10, and it has been there ever since, but it could have been introduced before 1.10), Python versions (3.8 and 3.10), and on both Linux and macOS. Sorry for the late report... I have been working around it by using the soundfile backend for FLAC, which works fine.
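Since the failure is silent, a length check after every load is the only way to notice it. Below is a minimal sketch of such a check. `check_num_frames` is a hypothetical helper (not part of torchaudio or Lhotse), and it assumes the waveform has time as its last dimension, as in the snippet above.

```python
def check_num_frames(shape, expected_num_frames):
    """Raise ValueError if a decoded waveform has an unexpected length.

    `shape` is the waveform shape with time as the last dimension,
    e.g. `x_rec.shape` from the reproduction script (channels, frames).
    """
    actual = shape[-1]
    if actual != expected_num_frames:
        raise ValueError(
            f"decoded {actual} frames, expected {expected_num_frames}"
        )
    return actual


# Shapes taken from the traceback above.
check_num_frames((1, 16000), 16000)  # WAV round trip: passes

try:
    check_num_frames((1, 8192), 16000)  # FLAC round trip: truncated
except ValueError as err:
    print(err)
```

In the reproduction script this would be called as `check_num_frames(x_rec.shape, x.shape[-1])` right after `torchaudio.load`.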
Versions
$ python collect_env.py 
Collecting environment information...
PyTorch version: 1.12.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 12.5.1 (arm64)
GCC version: Could not collect
Clang version: 13.1.6 (clang-1316.0.21.2.5)
CMake version: version 3.24.1
Libc version: N/A
Python version: 3.10.4 (main, Mar 31 2022, 03:37:37) [Clang 12.0.0 ] (64-bit runtime)
Python platform: macOS-12.5.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.22.3
[pip3] torch==1.12.0
[pip3] torchaudio==0.12.0
[pip3] torchvision==0.13.0
[conda] numpy                     1.22.3          py310hdb36b11_0  
[conda] numpy-base                1.22.3          py310h5e3e9f0_0  
[conda] pytorch                   1.12.0                 py3.10_0    pytorch
[conda] torchaudio                0.12.0                py310_cpu    pytorch
[conda] torchvision               0.13.0                py310_cpu    pytorch
Hi @pzelasko - Thanks for the report. This seems to be an issue on the loading side.
@pytorch/team-audio-core Can someone look into why this test did not catch this?
Duplicate of https://github.com/pytorch/audio/issues/2356
As a fix, I propose delegating file-like object support to the FFmpeg-based implementation. I switched torchaudio.load to torchaudio.io._compat.load_audio_fileobj and it worked fine: the loaded tensor matches.
https://github.com/pytorch/audio/issues/2356#issuecomment-1241424579
Cool, thanks, I’ll see how I can add this workaround in Lhotse.
Note: torchaudio.io._compat was moved to torchaudio._backend.ffmpeg in #3518, but now that we default to per-call backend dispatching, please use load(backend="ffmpeg") instead.
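For anyone landing here later, a minimal sketch of what that call might look like with an in-memory buffer. It assumes a torchaudio release where `load` accepts the `backend` argument (2.1 or later, as far as I know) and an installation with FFmpeg support. `load_from_buffer` is a hypothetical helper name, not part of torchaudio.

```python
from io import BytesIO


def load_from_buffer(buf, fmt, backend="ffmpeg"):
    """Load audio from a file-like object with an explicitly chosen backend.

    Hypothetical helper: rewinds the buffer, then forwards to
    torchaudio.load with per-call backend selection.
    """
    # Imported lazily so this module can be imported without torchaudio.
    import torchaudio

    buf.seek(0)  # a buffer that was just written to is positioned at its end
    return torchaudio.load(buf, format=fmt, backend=backend)
```

With the reproduction script above, the FLAC branch would become `x_rec, _ = load_from_buffer(f, "flac")`. It is still worth asserting `x.shape == x_rec.shape` afterwards.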