FFT frequency bins obtained by `torch.linspace` in `torchaudio.functional.melscale_fbanks`
🐛 Describe the bug
I notice that in https://github.com/pytorch/audio/blob/main/src/torchaudio/functional/functional.py#L561-L568,
if norm is not None and norm != "slaney":
raise ValueError('norm must be one of None or "slaney"')
# freq bins
all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)
# calculate mel freq bins
`torch.linspace` is used to obtain the FFT frequency bins. This is inconsistent with the implementation in librosa, which uses
# Center freqs of each FFT bin
fftfreqs = fft_frequencies(sr=sr, n_fft=n_fft)
The above is equivalent to
np.fft.rfftfreq(n=n_fft, d=1.0 / sr)
The counterpart in PyTorch is
torch.fft.rfftfreq(n=n_fft, d=1.0 / sample_rate)
As a result, the current code can generate frequency bins with a different spacing from librosa's. For example:
>>> sr = sample_rate = 16000
>>> n_fft = 321
>>> n_freqs = n_fft // 2 + 1
>>> all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)
>>> print(all_freqs)
tensor([ 0., 50., 100., 150., 200., 250., 300., 350., 400., 450.,
500., 550., 600., 650., 700., 750., 800., 850., 900., 950.,
1000., 1050., 1100., 1150., 1200., 1250., 1300., 1350., 1400., 1450.,
1500., 1550., 1600., 1650., 1700., 1750., 1800., 1850., 1900., 1950.,
2000., 2050., 2100., 2150., 2200., 2250., 2300., 2350., 2400., 2450.,
2500., 2550., 2600., 2650., 2700., 2750., 2800., 2850., 2900., 2950.,
3000., 3050., 3100., 3150., 3200., 3250., 3300., 3350., 3400., 3450.,
3500., 3550., 3600., 3650., 3700., 3750., 3800., 3850., 3900., 3950.,
4000., 4050., 4100., 4150., 4200., 4250., 4300., 4350., 4400., 4450.,
4500., 4550., 4600., 4650., 4700., 4750., 4800., 4850., 4900., 4950.,
5000., 5050., 5100., 5150., 5200., 5250., 5300., 5350., 5400., 5450.,
5500., 5550., 5600., 5650., 5700., 5750., 5800., 5850., 5900., 5950.,
6000., 6050., 6100., 6150., 6200., 6250., 6300., 6350., 6400., 6450.,
6500., 6550., 6600., 6650., 6700., 6750., 6800., 6850., 6900., 6950.,
7000., 7050., 7100., 7150., 7200., 7250., 7300., 7350., 7400., 7450.,
7500., 7550., 7600., 7650., 7700., 7750., 7800., 7850., 7900., 7950.,
8000.])
>>> print(np.fft.rfftfreq(n=n_fft, d=1.0 / sr))
[ 0. 49.84423676 99.68847352 149.53271028 199.37694704
249.2211838 299.06542056 348.90965732 398.75389408 448.59813084
498.4423676 548.28660436 598.13084112 647.97507788 697.81931464
747.6635514 797.50778816 847.35202492 897.19626168 947.04049844
996.8847352 1046.72897196 1096.57320872 1146.41744548 1196.26168224
1246.105919 1295.95015576 1345.79439252 1395.63862928 1445.48286604
1495.3271028 1545.17133956 1595.01557632 1644.85981308 1694.70404984
1744.5482866 1794.39252336 1844.23676012 1894.08099688 1943.92523364
1993.7694704 2043.61370717 2093.45794393 2143.30218069 2193.14641745
2242.99065421 2292.83489097 2342.67912773 2392.52336449 2442.36760125
2492.21183801 2542.05607477 2591.90031153 2641.74454829 2691.58878505
2741.43302181 2791.27725857 2841.12149533 2890.96573209 2940.80996885
2990.65420561 3040.49844237 3090.34267913 3140.18691589 3190.03115265
3239.87538941 3289.71962617 3339.56386293 3389.40809969 3439.25233645
3489.09657321 3538.94080997 3588.78504673 3638.62928349 3688.47352025
3738.31775701 3788.16199377 3838.00623053 3887.85046729 3937.69470405
3987.53894081 4037.38317757 4087.22741433 4137.07165109 4186.91588785
4236.76012461 4286.60436137 4336.44859813 4386.29283489 4436.13707165
4485.98130841 4535.82554517 4585.66978193 4635.51401869 4685.35825545
4735.20249221 4785.04672897 4834.89096573 4884.73520249 4934.57943925
4984.42367601 5034.26791277 5084.11214953 5133.95638629 5183.80062305
5233.64485981 5283.48909657 5333.33333333 5383.17757009 5433.02180685
5482.86604361 5532.71028037 5582.55451713 5632.39875389 5682.24299065
5732.08722741 5781.93146417 5831.77570093 5881.61993769 5931.46417445
5981.30841121 6031.15264798 6080.99688474 6130.8411215 6180.68535826
6230.52959502 6280.37383178 6330.21806854 6380.0623053 6429.90654206
6479.75077882 6529.59501558 6579.43925234 6629.2834891 6679.12772586
6728.97196262 6778.81619938 6828.66043614 6878.5046729 6928.34890966
6978.19314642 7028.03738318 7077.88161994 7127.7258567 7177.57009346
7227.41433022 7277.25856698 7327.10280374 7376.9470405 7426.79127726
7476.63551402 7526.47975078 7576.32398754 7626.1682243 7676.01246106
7725.85669782 7775.70093458 7825.54517134 7875.3894081 7925.23364486
7975.07788162]
>>> print(torch.fft.rfftfreq(n=n_fft, d=1.0 / sample_rate))
tensor([ 0.0000, 49.8442, 99.6885, 149.5327, 199.3770, 249.2212,
299.0654, 348.9097, 398.7539, 448.5981, 498.4424, 548.2866,
598.1309, 647.9751, 697.8193, 747.6636, 797.5078, 847.3521,
897.1963, 947.0405, 996.8848, 1046.7290, 1096.5732, 1146.4175,
1196.2617, 1246.1060, 1295.9502, 1345.7944, 1395.6387, 1445.4829,
1495.3271, 1545.1714, 1595.0156, 1644.8599, 1694.7041, 1744.5483,
1794.3926, 1844.2368, 1894.0811, 1943.9253, 1993.7695, 2043.6138,
2093.4580, 2143.3022, 2193.1465, 2242.9907, 2292.8350, 2342.6792,
2392.5234, 2442.3677, 2492.2119, 2542.0562, 2591.9004, 2641.7446,
2691.5889, 2741.4331, 2791.2773, 2841.1216, 2890.9658, 2940.8101,
2990.6543, 3040.4985, 3090.3428, 3140.1870, 3190.0312, 3239.8755,
3289.7197, 3339.5640, 3389.4082, 3439.2524, 3489.0967, 3538.9409,
3588.7852, 3638.6294, 3688.4736, 3738.3179, 3788.1621, 3838.0063,
3887.8506, 3937.6948, 3987.5391, 4037.3833, 4087.2275, 4137.0718,
4186.9160, 4236.7603, 4286.6045, 4336.4487, 4386.2930, 4436.1372,
4485.9814, 4535.8257, 4585.6699, 4635.5142, 4685.3584, 4735.2026,
4785.0469, 4834.8911, 4884.7354, 4934.5796, 4984.4238, 5034.2681,
5084.1123, 5133.9565, 5183.8008, 5233.6450, 5283.4893, 5333.3335,
5383.1777, 5433.0220, 5482.8662, 5532.7104, 5582.5547, 5632.3989,
5682.2432, 5732.0874, 5781.9316, 5831.7759, 5881.6201, 5931.4644,
5981.3086, 6031.1528, 6080.9971, 6130.8413, 6180.6855, 6230.5298,
6280.3740, 6330.2183, 6380.0625, 6429.9067, 6479.7510, 6529.5952,
6579.4395, 6629.2837, 6679.1279, 6728.9722, 6778.8164, 6828.6606,
6878.5049, 6928.3491, 6978.1934, 7028.0376, 7077.8818, 7127.7261,
7177.5703, 7227.4146, 7277.2588, 7327.1030, 7376.9473, 7426.7915,
7476.6357, 7526.4800, 7576.3242, 7626.1685, 7676.0127, 7725.8569,
7775.7012, 7825.5454, 7875.3896, 7925.2339, 7975.0781])
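The spacing difference above comes down to simple arithmetic, which the following minimal plain-Python sketch tries to make explicit. The helper names `linspace_bins` and `rfftfreq_bins` are hypothetical stand-ins introduced here for illustration; they are not torchaudio or librosa functions and only reproduce the formulas behind the two grids.

```python
# Hypothetical helpers (not part of torchaudio or librosa): plain-Python
# versions of the two ways of placing the FFT frequency bins.

def linspace_bins(sample_rate, n_fft):
    """n_freqs points evenly spaced on [0, sample_rate // 2] (linspace-style grid)."""
    n_freqs = n_fft // 2 + 1
    step = (sample_rate // 2) / (n_freqs - 1)
    return [k * step for k in range(n_freqs)]


def rfftfreq_bins(sample_rate, n_fft):
    """Bin k placed at k * sample_rate / n_fft (rfftfreq-style grid)."""
    n_freqs = n_fft // 2 + 1
    return [k * sample_rate / n_fft for k in range(n_freqs)]


for n_fft in (16, 321):
    a = linspace_bins(16000, n_fft)
    b = rfftfreq_bins(16000, n_fft)
    # Print the bin spacing and the last bin of each grid.
    print(n_fft, a[1] - a[0], b[1] - b[0], a[-1], b[-1])
```

For `n_fft = 16` both grids step by 1000 Hz and end at 8000 Hz. For `n_fft = 321` the linspace-style grid steps by 50 Hz and ends at 8000 Hz, while the rfftfreq-style grid steps by 16000 / 321 Hz and ends just below the Nyquist frequency.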
I also checked the impact of this line of code on the resulting mel-scale conversion matrix in `torchaudio.functional.melscale_fbanks`.
After replacing the current one with
# freq bins
all_freqs = torch.fft.rfftfreq(n=n_fft, d=1.0 / sample_rate)
the results are more consistent with librosa's output.
>>> sr = 16000
>>> n_fft = 321
>>> hop_length = 160
>>> n_mels = 120
>>> torch.random.manual_seed(0)
>>> x = torch.randn(16000)
>>> print(x)
tensor([-1.1258, -1.1524, -0.2506, ..., 0.1347, 0.7971, 0.2949])
################################################################
# Original code
################################################################
>>> out1_th = torchaudio.transforms.MelSpectrogram(sample_rate=sr, hop_length=hop_length, n_fft=n_fft, n_mels=n_mels, pad_mode="constant", norm="slaney", mel_scale="slaney")(x)
>>> out1_ref = librosa.feature.melspectrogram(y=x.numpy(), sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels)
>>> np.testing.assert_allclose(out1_th.numpy(), out1_ref, atol=3e-5)
AssertionError:
Not equal to tolerance rtol=1e-07, atol=3e-05
Mismatched elements: 11999 / 12000 (100%)
Max absolute difference: 3.4721794
Max relative difference: 12117.599
x: array([[0.000000e+00, 0.000000e+00, 0.000000e+00, ..., 0.000000e+00,
0.000000e+00, 0.000000e+00],
[1.446359e+00, 1.172625e-01, 1.025011e+01, ..., 3.962148e+00,...
y: array([[7.713332e-04, 6.253528e-05, 5.466319e-03, ..., 2.112987e-03,
3.167379e-03, 2.537756e-03],
[1.453905e+00, 1.178743e-01, 1.030360e+01, ..., 3.982821e+00,...
>>> out2_th = torchaudio.transforms.MelSpectrogram(sample_rate=sr, hop_length=hop_length, n_fft=n_fft, n_mels=n_mels, pad_mode="constant")(x)
>>> out2_ref = librosa.feature.melspectrogram(y=x.numpy(), sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, norm=None, htk=True)
>>> np.testing.assert_allclose(out2_th.numpy(), out2_ref, atol=3e-5)
AssertionError:
Not equal to tolerance rtol=1e-07, atol=3e-05
Mismatched elements: 11400 / 12000 (95%)
Max absolute difference: 282.45605
Max relative difference: 2.562869
x: array([[ 0. , 0. , 0. , ..., 0. , 0. ,
0. ],
[ 0. , 0. , 0. , ..., 0. , 0. ,...
y: array([[ 0. , 0. , 0. , ..., 0. , 0. ,
0. ],
[ 0. , 0. , 0. , ..., 0. , 0. ,...
################################################################
# After replacing `torch.linspace` with `torch.fft.rfftfreq`
################################################################
>>> out1_th = torchaudio.transforms.MelSpectrogram(sample_rate=sr, hop_length=hop_length, n_fft=n_fft, n_mels=n_mels, pad_mode="constant", norm="slaney", mel_scale="slaney")(x)
>>> out1_ref = librosa.feature.melspectrogram(y=x.numpy(), sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels)
>>> np.testing.assert_allclose(out1_th.numpy(), out1_ref, atol=3e-5)
AssertionError:
Not equal to tolerance rtol=1e-07, atol=3e-05
Mismatched elements: 639 / 12000 (5.33%)
Max absolute difference: 0.00011086
Max relative difference: 0.00052703
x: array([[7.710888e-04, 6.251546e-05, 5.464583e-03, ..., 2.112317e-03,
3.166373e-03, 2.536952e-03],
[1.453906e+00, 1.178743e-01, 1.030360e+01, ..., 3.982822e+00,...
y: array([[7.713332e-04, 6.253528e-05, 5.466319e-03, ..., 2.112987e-03,
3.167379e-03, 2.537756e-03],
[1.453905e+00, 1.178743e-01, 1.030360e+01, ..., 3.982821e+00,...
>>> out2_th = torchaudio.transforms.MelSpectrogram(sample_rate=sr, hop_length=hop_length, n_fft=n_fft, n_mels=n_mels, pad_mode="constant")(x)
>>> out2_ref = librosa.feature.melspectrogram(y=x.numpy(), sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, norm=None, htk=True)
>>> np.testing.assert_allclose(out2_th.numpy(), out2_ref, atol=3e-5)
AssertionError:
Not equal to tolerance rtol=1e-07, atol=3e-05
Mismatched elements: 9897 / 12000 (82.5%)
Max absolute difference: 0.02191162
Max relative difference: 0.00012979
x: array([[ 0. , 0. , 0. , ..., 0. , 0. ,
0. ],
[ 0. , 0. , 0. , ..., 0. , 0. ,...
y: array([[ 0. , 0. , 0. , ..., 0. , 0. ,
0. ],
[ 0. , 0. , 0. , ..., 0. , 0. ,...
Versions
PyTorch version: 2.0.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 12.6.4 (x86_64)
GCC version: Could not collect
Clang version: 14.0.0 (clang-1400.0.29.202)
CMake version: version 3.27.0
Libc version: N/A

Python version: 3.9.17 (main, Jun 20 2023, 17:04:52) [Clang 14.0.0 (clang-1400.0.29.202)] (64-bit runtime)
Python platform: macOS-12.6.4-x86_64-i386-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU: Intel(R) Core(TM) i7-8750H CPU @ 2.20GHz

Versions of relevant libraries:
[pip3] numpy==1.23.5
[pip3] torch==2.0.1
[pip3] torchaudio==2.0.2
I don't understand the issue. I've got the same result from the two functions you say are inconsistent:
In [19]: import torch
...: import librosa
...:
...: # INPUT:
...:
...: sr = 16000
...: n_fft = 16
...:
...: # PROCESS:
...: n_freqs = 1 + n_fft // 2
...:
...: freqs_torch = torch.linspace(0, sr // 2, n_freqs)
...:
...: freqs_librosa = librosa.fft_frequencies(sr=sr, n_fft=n_fft)
In [20]: freqs_librosa
Out[20]: array([ 0., 1000., 2000., 3000., 4000., 5000., 6000., 7000., 8000.])
In [21]: freqs_torch
Out[21]: tensor([ 0., 1000., 2000., 3000., 4000., 5000., 6000., 7000., 8000.])
Yes, they are the same in your example, where `n_fft = 16`. But that is not always the case. In my example, with `n_fft = 321` (as used in the DNSMOS computation), the outputs differ:
In [1]: import torch
...: import librosa
...:
...: # INPUT:
...:
...: sr = 16000
...: n_fft = 321
...:
...: # PROCESS:
...: n_freqs = 1 + n_fft // 2
...:
...: freqs_torch = torch.linspace(0, sr // 2, n_freqs)
...:
...: freqs_librosa = librosa.fft_frequencies(sr=sr, n_fft=n_fft)
In [2]: freqs_librosa
Out[2]:
array([ 0. , 49.84423676, 99.68847352, 149.53271028,
199.37694704, 249.2211838 , 299.06542056, 348.90965732,
398.75389408, 448.59813084, 498.4423676 , 548.28660436,
598.13084112, 647.97507788, 697.81931464, 747.6635514 ,
797.50778816, 847.35202492, 897.19626168, 947.04049844,
996.8847352 , 1046.72897196, 1096.57320872, 1146.41744548,
1196.26168224, 1246.105919 , 1295.95015576, 1345.79439252,
1395.63862928, 1445.48286604, 1495.3271028 , 1545.17133956,
1595.01557632, 1644.85981308, 1694.70404984, 1744.5482866 ,
1794.39252336, 1844.23676012, 1894.08099688, 1943.92523364,
1993.7694704 , 2043.61370717, 2093.45794393, 2143.30218069,
2193.14641745, 2242.99065421, 2292.83489097, 2342.67912773,
2392.52336449, 2442.36760125, 2492.21183801, 2542.05607477,
2591.90031153, 2641.74454829, 2691.58878505, 2741.43302181,
2791.27725857, 2841.12149533, 2890.96573209, 2940.80996885,
2990.65420561, 3040.49844237, 3090.34267913, 3140.18691589,
3190.03115265, 3239.87538941, 3289.71962617, 3339.56386293,
3389.40809969, 3439.25233645, 3489.09657321, 3538.94080997,
3588.78504673, 3638.62928349, 3688.47352025, 3738.31775701,
3788.16199377, 3838.00623053, 3887.85046729, 3937.69470405,
3987.53894081, 4037.38317757, 4087.22741433, 4137.07165109,
4186.91588785, 4236.76012461, 4286.60436137, 4336.44859813,
4386.29283489, 4436.13707165, 4485.98130841, 4535.82554517,
4585.66978193, 4635.51401869, 4685.35825545, 4735.20249221,
4785.04672897, 4834.89096573, 4884.73520249, 4934.57943925,
4984.42367601, 5034.26791277, 5084.11214953, 5133.95638629,
5183.80062305, 5233.64485981, 5283.48909657, 5333.33333333,
5383.17757009, 5433.02180685, 5482.86604361, 5532.71028037,
5582.55451713, 5632.39875389, 5682.24299065, 5732.08722741,
5781.93146417, 5831.77570093, 5881.61993769, 5931.46417445,
5981.30841121, 6031.15264798, 6080.99688474, 6130.8411215 ,
6180.68535826, 6230.52959502, 6280.37383178, 6330.21806854,
6380.0623053 , 6429.90654206, 6479.75077882, 6529.59501558,
6579.43925234, 6629.2834891 , 6679.12772586, 6728.97196262,
6778.81619938, 6828.66043614, 6878.5046729 , 6928.34890966,
6978.19314642, 7028.03738318, 7077.88161994, 7127.7258567 ,
7177.57009346, 7227.41433022, 7277.25856698, 7327.10280374,
7376.9470405 , 7426.79127726, 7476.63551402, 7526.47975078,
7576.32398754, 7626.1682243 , 7676.01246106, 7725.85669782,
7775.70093458, 7825.54517134, 7875.3894081 , 7925.23364486,
7975.07788162])
In [3]: freqs_torch
Out[3]:
tensor([ 0., 50., 100., 150., 200., 250., 300., 350., 400., 450.,
500., 550., 600., 650., 700., 750., 800., 850., 900., 950.,
1000., 1050., 1100., 1150., 1200., 1250., 1300., 1350., 1400., 1450.,
1500., 1550., 1600., 1650., 1700., 1750., 1800., 1850., 1900., 1950.,
2000., 2050., 2100., 2150., 2200., 2250., 2300., 2350., 2400., 2450.,
2500., 2550., 2600., 2650., 2700., 2750., 2800., 2850., 2900., 2950.,
3000., 3050., 3100., 3150., 3200., 3250., 3300., 3350., 3400., 3450.,
3500., 3550., 3600., 3650., 3700., 3750., 3800., 3850., 3900., 3950.,
4000., 4050., 4100., 4150., 4200., 4250., 4300., 4350., 4400., 4450.,
4500., 4550., 4600., 4650., 4700., 4750., 4800., 4850., 4900., 4950.,
5000., 5050., 5100., 5150., 5200., 5250., 5300., 5350., 5400., 5450.,
5500., 5550., 5600., 5650., 5700., 5750., 5800., 5850., 5900., 5950.,
6000., 6050., 6100., 6150., 6200., 6250., 6300., 6350., 6400., 6450.,
6500., 6550., 6600., 6650., 6700., 6750., 6800., 6850., 6900., 6950.,
7000., 7050., 7100., 7150., 7200., 7250., 7300., 7350., 7400., 7450.,
7500., 7550., 7600., 7650., 7700., 7750., 7800., 7850., 7900., 7950.,
8000.])
I see, you are right: they behave differently if `n_fft` is odd. I've actually never used an odd `n_fft`; it's quite non-standard, and I suggest avoiding it if possible. I think something like this should work for both even and odd `n_fft`:
In [41]: import torch
...: import librosa
...:
...: # INPUT:
...: sr = 16000
...: n_fft = 17
...:
...: # PROCESS:
...: n_freqs = 1 + n_fft // 2
...:
...: # freqs_torch = torch.linspace(0, sr // 2, n_freqs)
...: freqs_torch = torch.linspace(0, (n_freqs-1) * sr / n_fft, n_freqs)
...:
...: freqs_librosa = librosa.fft_frequencies(sr=sr, n_fft=n_fft)
In [42]: freqs_librosa
Out[42]:
array([ 0. , 941.17647059, 1882.35294118, 2823.52941176,
3764.70588235, 4705.88235294, 5647.05882353, 6588.23529412,
7529.41176471])
In [43]: freqs_torch
Out[43]:
tensor([ 0.0000, 941.1765, 1882.3529, 2823.5293, 3764.7058, 4705.8823,
5647.0586, 6588.2354, 7529.4116])
BTW, I don't work on this project.
Yes, you're right. `freqs_torch = torch.linspace(0, (n_freqs-1) * sr / n_fft, n_freqs)` is basically what `torch.fft.rfftfreq` is doing.
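As a rough sanity check of that equivalence, here is a small plain-Python sketch that compares the endpoint-corrected linspace grid with the `k * sr / n_fft` grid for several FFT sizes. The function names are hypothetical and exist only for this illustration; the sketch checks the arithmetic in double precision and does not call torch, so float32 rounding such as the small differences visible in the tensor output above is not reproduced.

```python
import math

# Hypothetical helpers for illustration only; they mirror the formulas
# discussed in this thread rather than any library implementation.

def corrected_linspace_bins(sample_rate, n_fft):
    """n_freqs points evenly spaced on [0, (n_freqs - 1) * sample_rate / n_fft]."""
    n_freqs = n_fft // 2 + 1
    stop = (n_freqs - 1) * sample_rate / n_fft
    step = stop / max(n_freqs - 1, 1)  # guard against n_freqs == 1
    return [k * step for k in range(n_freqs)]


def rfftfreq_bins(sample_rate, n_fft):
    """Bin k placed at k * sample_rate / n_fft (rfftfreq-style grid)."""
    return [k * sample_rate / n_fft for k in range(n_fft // 2 + 1)]


def grids_match(sample_rate, n_fft):
    """True if both grids agree element-wise up to floating-point tolerance."""
    a = corrected_linspace_bins(sample_rate, n_fft)
    b = rfftfreq_bins(sample_rate, n_fft)
    return len(a) == len(b) and all(
        math.isclose(x, y, rel_tol=1e-9, abs_tol=1e-9) for x, y in zip(a, b)
    )


for n_fft in (16, 17, 321):
    print(n_fft, grids_match(16000, n_fft))
```

Under these assumptions the two grids agree for even and odd `n_fft` alike, which is consistent with the comparison against librosa shown above for `n_fft = 17`.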