
FFT frequency bins obtained by `torch.linspace` in `torchaudio.functional.melscale_fbanks`

Status: Open • Emrys365 opened this issue 1 year ago • 4 comments

🐛 Describe the bug

I noticed the following in https://github.com/pytorch/audio/blob/main/src/torchaudio/functional/functional.py#L561-L568:


    if norm is not None and norm != "slaney":
        raise ValueError('norm must be one of None or "slaney"')

    # freq bins
    all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)

    # calculate mel freq bins

`torch.linspace` is used to obtain the FFT frequency bins. This is inconsistent with librosa's implementation, which uses

    # Center freqs of each FFT bin
    fftfreqs = fft_frequencies(sr=sr, n_fft=n_fft)

The above is equivalent to

    np.fft.rfftfreq(n=n_fft, d=1.0 / sr)

The counterpart in PyTorch is

    torch.fft.rfftfreq(n=n_fft, d=1.0 / sample_rate)
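Written out for $N$ = `n_fft` and $f_s$ = `sample_rate`, with $k = 0, \dots, \lfloor N/2 \rfloor$, the two grids differ only in their step size. The second line plugs in the values from the example below ($f_s = 16000$, $N = 321$):

$$
f_k^{\mathrm{rfft}} = \frac{k\,f_s}{N},
\qquad
f_k^{\mathrm{lin}} = \frac{k\,\lfloor f_s/2 \rfloor}{\lfloor N/2 \rfloor}
$$

$$
\frac{f_s}{N} = \frac{16000}{321} \approx 49.844\ \text{Hz},
\qquad
\frac{\lfloor f_s/2 \rfloor}{\lfloor N/2 \rfloor} = \frac{8000}{160} = 50\ \text{Hz}
$$

When $N$ and $f_s$ are both even the two steps are equal and the grids coincide. For odd $N$, the `linspace` grid is stretched so that its last point lands exactly on $\lfloor f_s/2 \rfloor$, whereas the last real-FFT bin is $\lfloor N/2 \rfloor f_s / N$, i.e. $160 \cdot 16000 / 321 \approx 7975.08$ Hz here.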

As a result, the current code can produce frequency bins whose spacing differs from librosa's. For example:

>>> sr = sample_rate = 16000
>>> n_fft = 321
>>> n_freqs = n_fft // 2 + 1
>>> all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)
>>> print(all_freqs)
tensor([   0.,   50.,  100.,  150.,  200.,  250.,  300.,  350.,  400.,  450.,
         500.,  550.,  600.,  650.,  700.,  750.,  800.,  850.,  900.,  950.,
        1000., 1050., 1100., 1150., 1200., 1250., 1300., 1350., 1400., 1450.,
        1500., 1550., 1600., 1650., 1700., 1750., 1800., 1850., 1900., 1950.,
        2000., 2050., 2100., 2150., 2200., 2250., 2300., 2350., 2400., 2450.,
        2500., 2550., 2600., 2650., 2700., 2750., 2800., 2850., 2900., 2950.,
        3000., 3050., 3100., 3150., 3200., 3250., 3300., 3350., 3400., 3450.,
        3500., 3550., 3600., 3650., 3700., 3750., 3800., 3850., 3900., 3950.,
        4000., 4050., 4100., 4150., 4200., 4250., 4300., 4350., 4400., 4450.,
        4500., 4550., 4600., 4650., 4700., 4750., 4800., 4850., 4900., 4950.,
        5000., 5050., 5100., 5150., 5200., 5250., 5300., 5350., 5400., 5450.,
        5500., 5550., 5600., 5650., 5700., 5750., 5800., 5850., 5900., 5950.,
        6000., 6050., 6100., 6150., 6200., 6250., 6300., 6350., 6400., 6450.,
        6500., 6550., 6600., 6650., 6700., 6750., 6800., 6850., 6900., 6950.,
        7000., 7050., 7100., 7150., 7200., 7250., 7300., 7350., 7400., 7450.,
        7500., 7550., 7600., 7650., 7700., 7750., 7800., 7850., 7900., 7950.,
        8000.])

>>> print(np.fft.rfftfreq(n=n_fft, d=1.0 / sr))
[   0.           49.84423676   99.68847352  149.53271028  199.37694704
  249.2211838   299.06542056  348.90965732  398.75389408  448.59813084
  498.4423676   548.28660436  598.13084112  647.97507788  697.81931464
  747.6635514   797.50778816  847.35202492  897.19626168  947.04049844
  996.8847352  1046.72897196 1096.57320872 1146.41744548 1196.26168224
 1246.105919   1295.95015576 1345.79439252 1395.63862928 1445.48286604
 1495.3271028  1545.17133956 1595.01557632 1644.85981308 1694.70404984
 1744.5482866  1794.39252336 1844.23676012 1894.08099688 1943.92523364
 1993.7694704  2043.61370717 2093.45794393 2143.30218069 2193.14641745
 2242.99065421 2292.83489097 2342.67912773 2392.52336449 2442.36760125
 2492.21183801 2542.05607477 2591.90031153 2641.74454829 2691.58878505
 2741.43302181 2791.27725857 2841.12149533 2890.96573209 2940.80996885
 2990.65420561 3040.49844237 3090.34267913 3140.18691589 3190.03115265
 3239.87538941 3289.71962617 3339.56386293 3389.40809969 3439.25233645
 3489.09657321 3538.94080997 3588.78504673 3638.62928349 3688.47352025
 3738.31775701 3788.16199377 3838.00623053 3887.85046729 3937.69470405
 3987.53894081 4037.38317757 4087.22741433 4137.07165109 4186.91588785
 4236.76012461 4286.60436137 4336.44859813 4386.29283489 4436.13707165
 4485.98130841 4535.82554517 4585.66978193 4635.51401869 4685.35825545
 4735.20249221 4785.04672897 4834.89096573 4884.73520249 4934.57943925
 4984.42367601 5034.26791277 5084.11214953 5133.95638629 5183.80062305
 5233.64485981 5283.48909657 5333.33333333 5383.17757009 5433.02180685
 5482.86604361 5532.71028037 5582.55451713 5632.39875389 5682.24299065
 5732.08722741 5781.93146417 5831.77570093 5881.61993769 5931.46417445
 5981.30841121 6031.15264798 6080.99688474 6130.8411215  6180.68535826
 6230.52959502 6280.37383178 6330.21806854 6380.0623053  6429.90654206
 6479.75077882 6529.59501558 6579.43925234 6629.2834891  6679.12772586
 6728.97196262 6778.81619938 6828.66043614 6878.5046729  6928.34890966
 6978.19314642 7028.03738318 7077.88161994 7127.7258567  7177.57009346
 7227.41433022 7277.25856698 7327.10280374 7376.9470405  7426.79127726
 7476.63551402 7526.47975078 7576.32398754 7626.1682243  7676.01246106
 7725.85669782 7775.70093458 7825.54517134 7875.3894081  7925.23364486
 7975.07788162]

>>> print(torch.fft.rfftfreq(n=n_fft, d=1.0 / sample_rate))
tensor([   0.0000,   49.8442,   99.6885,  149.5327,  199.3770,  249.2212,
         299.0654,  348.9097,  398.7539,  448.5981,  498.4424,  548.2866,
         598.1309,  647.9751,  697.8193,  747.6636,  797.5078,  847.3521,
         897.1963,  947.0405,  996.8848, 1046.7290, 1096.5732, 1146.4175,
        1196.2617, 1246.1060, 1295.9502, 1345.7944, 1395.6387, 1445.4829,
        1495.3271, 1545.1714, 1595.0156, 1644.8599, 1694.7041, 1744.5483,
        1794.3926, 1844.2368, 1894.0811, 1943.9253, 1993.7695, 2043.6138,
        2093.4580, 2143.3022, 2193.1465, 2242.9907, 2292.8350, 2342.6792,
        2392.5234, 2442.3677, 2492.2119, 2542.0562, 2591.9004, 2641.7446,
        2691.5889, 2741.4331, 2791.2773, 2841.1216, 2890.9658, 2940.8101,
        2990.6543, 3040.4985, 3090.3428, 3140.1870, 3190.0312, 3239.8755,
        3289.7197, 3339.5640, 3389.4082, 3439.2524, 3489.0967, 3538.9409,
        3588.7852, 3638.6294, 3688.4736, 3738.3179, 3788.1621, 3838.0063,
        3887.8506, 3937.6948, 3987.5391, 4037.3833, 4087.2275, 4137.0718,
        4186.9160, 4236.7603, 4286.6045, 4336.4487, 4386.2930, 4436.1372,
        4485.9814, 4535.8257, 4585.6699, 4635.5142, 4685.3584, 4735.2026,
        4785.0469, 4834.8911, 4884.7354, 4934.5796, 4984.4238, 5034.2681,
        5084.1123, 5133.9565, 5183.8008, 5233.6450, 5283.4893, 5333.3335,
        5383.1777, 5433.0220, 5482.8662, 5532.7104, 5582.5547, 5632.3989,
        5682.2432, 5732.0874, 5781.9316, 5831.7759, 5881.6201, 5931.4644,
        5981.3086, 6031.1528, 6080.9971, 6130.8413, 6180.6855, 6230.5298,
        6280.3740, 6330.2183, 6380.0625, 6429.9067, 6479.7510, 6529.5952,
        6579.4395, 6629.2837, 6679.1279, 6728.9722, 6778.8164, 6828.6606,
        6878.5049, 6928.3491, 6978.1934, 7028.0376, 7077.8818, 7127.7261,
        7177.5703, 7227.4146, 7277.2588, 7327.1030, 7376.9473, 7426.7915,
        7476.6357, 7526.4800, 7576.3242, 7626.1685, 7676.0127, 7725.8569,
        7775.7012, 7825.5454, 7875.3896, 7925.2339, 7975.0781])
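The same comparison can be reduced to a small check that contrasts an even and an odd `n_fft` at 16 kHz. This is a sketch in NumPy so it runs without PyTorch; `np.linspace` stands in for `torch.linspace` (same endpoints and spacing, up to float32 rounding), and the helper names `linspace_bins` and `rfft_bins` are illustrative, not torchaudio functions.

```python
import numpy as np

def linspace_bins(sample_rate: int, n_fft: int) -> np.ndarray:
    """Grid built like the current torchaudio line: 0 .. sample_rate // 2 inclusive."""
    n_freqs = n_fft // 2 + 1
    return np.linspace(0, sample_rate // 2, n_freqs)

def rfft_bins(sample_rate: int, n_fft: int) -> np.ndarray:
    """Real-FFT bin centres, k * sample_rate / n_fft for k = 0 .. n_fft // 2."""
    return np.fft.rfftfreq(n=n_fft, d=1.0 / sample_rate)

for n_fft in (320, 321):
    lin, ref = linspace_bins(16000, n_fft), rfft_bins(16000, n_fft)
    # Both sizes give 161 bins, but only the even size gives identical values.
    print(n_fft, len(lin), np.allclose(lin, ref), lin[1], ref[1], lin[-1], ref[-1])
```

Note that 320 and 321 produce the same number of bins, so the bin count alone cannot reveal which spacing is correct.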

I also checked the impact of this line of code on the resulting mel-scale conversion matrix in torchaudio.functional.melscale_fbanks.

After replacing the current line with

    # freq bins
    all_freqs = torch.fft.rfftfreq(n=n_fft, d=1.0 / sample_rate)

the results are more consistent with librosa's output.

>>> sr = 16000
>>> n_fft = 321
>>> hop_length = 160
>>> n_mels = 120
>>> torch.random.manual_seed(0)
>>> x = torch.randn(16000)
>>> print(x)
tensor([-1.1258, -1.1524, -0.2506,  ...,  0.1347,  0.7971,  0.2949])

################################################################
# Original code
################################################################
>>> out1_th = torchaudio.transforms.MelSpectrogram(sample_rate=sr, hop_length=hop_length, n_fft=n_fft, n_mels=n_mels, pad_mode="constant", norm="slaney", mel_scale="slaney")(x)
>>> out1_ref = librosa.feature.melspectrogram(y=x.numpy(), sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels)
>>> np.testing.assert_allclose(out1_th.numpy(), out1_ref, atol=3e-5)
AssertionError: 
Not equal to tolerance rtol=1e-07, atol=3e-05

Mismatched elements: 11999 / 12000 (100%)
Max absolute difference: 3.4721794
Max relative difference: 12117.599
 x: array([[0.000000e+00, 0.000000e+00, 0.000000e+00, ..., 0.000000e+00,
        0.000000e+00, 0.000000e+00],
       [1.446359e+00, 1.172625e-01, 1.025011e+01, ..., 3.962148e+00,...
 y: array([[7.713332e-04, 6.253528e-05, 5.466319e-03, ..., 2.112987e-03,
        3.167379e-03, 2.537756e-03],
       [1.453905e+00, 1.178743e-01, 1.030360e+01, ..., 3.982821e+00,...


>>> out2_th = torchaudio.transforms.MelSpectrogram(sample_rate=sr, hop_length=hop_length, n_fft=n_fft, n_mels=n_mels, pad_mode="constant")(x)
>>> out2_ref = librosa.feature.melspectrogram(y=x.numpy(), sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, norm=None, htk=True)
>>> np.testing.assert_allclose(out2_th.numpy(), out2_ref, atol=3e-5)
AssertionError: 
Not equal to tolerance rtol=1e-07, atol=3e-05

Mismatched elements: 11400 / 12000 (95%)
Max absolute difference: 282.45605
Max relative difference: 2.562869
 x: array([[  0.      ,   0.      ,   0.      , ...,   0.      ,   0.      ,
          0.      ],
       [  0.      ,   0.      ,   0.      , ...,   0.      ,   0.      ,...
 y: array([[  0.      ,   0.      ,   0.      , ...,   0.      ,   0.      ,
          0.      ],
       [  0.      ,   0.      ,   0.      , ...,   0.      ,   0.      ,...


################################################################
# After replacing with `torch.fft.rfftfreq`
################################################################
>>> out1_th = torchaudio.transforms.MelSpectrogram(sample_rate=sr, hop_length=hop_length, n_fft=n_fft, n_mels=n_mels, pad_mode="constant", norm="slaney", mel_scale="slaney")(x)
>>> out1_ref = librosa.feature.melspectrogram(y=x.numpy(), sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels)
>>> np.testing.assert_allclose(out1_th.numpy(), out1_ref, atol=3e-5)
AssertionError: 
Not equal to tolerance rtol=1e-07, atol=3e-05

Mismatched elements: 639 / 12000 (5.33%)
Max absolute difference: 0.00011086
Max relative difference: 0.00052703
 x: array([[7.710888e-04, 6.251546e-05, 5.464583e-03, ..., 2.112317e-03,
        3.166373e-03, 2.536952e-03],
       [1.453906e+00, 1.178743e-01, 1.030360e+01, ..., 3.982822e+00,...
 y: array([[7.713332e-04, 6.253528e-05, 5.466319e-03, ..., 2.112987e-03,
        3.167379e-03, 2.537756e-03],
       [1.453905e+00, 1.178743e-01, 1.030360e+01, ..., 3.982821e+00,...

>>> out2_th = torchaudio.transforms.MelSpectrogram(sample_rate=sr, hop_length=hop_length, n_fft=n_fft, n_mels=n_mels, pad_mode="constant")(x)
>>> out2_ref = librosa.feature.melspectrogram(y=x.numpy(), sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, norm=None, htk=True)
>>> np.testing.assert_allclose(out2_th.numpy(), out2_ref, atol=3e-5)
AssertionError: 
Not equal to tolerance rtol=1e-07, atol=3e-05

Mismatched elements: 9897 / 12000 (82.5%)
Max absolute difference: 0.02191162
Max relative difference: 0.00012979
 x: array([[  0.      ,   0.      ,   0.      , ...,   0.      ,   0.      ,
          0.      ],
       [  0.      ,   0.      ,   0.      , ...,   0.      ,   0.      ,...
 y: array([[  0.      ,   0.      ,   0.      , ...,   0.      ,   0.      ,
          0.      ],
       [  0.      ,   0.      ,   0.      , ...,   0.      ,   0.      ,...
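One practical detail for the proposed replacement: `melscale_fbanks` receives `n_freqs` rather than `n_fft`, and `n_fft = 320` and `n_fft = 321` both yield `n_freqs = 161`, so `n_fft` (or the bin spacing) has to reach the function somehow. The sketch below assumes a hypothetical function `mel_fbanks_from_nfft` that takes `n_fft` directly. It is written in NumPy, uses the HTK mel formula with no normalization (the configuration of the `htk=True, norm=None` comparison above), and builds each triangle as max(0, min(rising edge, falling edge)). It illustrates the idea and is not torchaudio's implementation.

```python
import numpy as np

def hz_to_mel_htk(f):
    # HTK mel scale: 2595 * log10(1 + f / 700)
    return 2595.0 * np.log10(1.0 + np.asarray(f, dtype=float) / 700.0)

def mel_to_hz_htk(m):
    # Inverse of hz_to_mel_htk
    return 700.0 * (10.0 ** (np.asarray(m, dtype=float) / 2595.0) - 1.0)

def mel_fbanks_from_nfft(n_fft, f_min, f_max, n_mels, sample_rate):
    """Hypothetical helper: (n_fft // 2 + 1, n_mels) triangular filterbank, HTK mel, no norm."""
    # Evaluate the filters at the true real-FFT bin centres, k * sample_rate / n_fft.
    all_freqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sample_rate)
    # n_mels + 2 edge points, equally spaced on the mel scale, converted back to Hz.
    m_pts = np.linspace(hz_to_mel_htk(f_min), hz_to_mel_htk(f_max), n_mels + 2)
    f_pts = mel_to_hz_htk(m_pts)
    f_diff = np.diff(f_pts)                           # widths between adjacent edges
    slopes = f_pts[None, :] - all_freqs[:, None]      # shape (n_freqs, n_mels + 2)
    # Rises from 0 at the left edge f_pts[i] to 1 at the centre f_pts[i + 1].
    rising = -slopes[:, :-2] / f_diff[:-1]
    # Falls from 1 at the centre f_pts[i + 1] to 0 at the right edge f_pts[i + 2].
    falling = slopes[:, 2:] / f_diff[1:]
    return np.maximum(0.0, np.minimum(rising, falling))

fb = mel_fbanks_from_nfft(n_fft=321, f_min=0.0, f_max=8000.0, n_mels=120, sample_rate=16000)
print(fb.shape)
```

With `n_fft=320` the call returns a matrix of the same shape but with different values, which is why the filterbank cannot be made exact from `n_freqs` alone.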

Versions

    PyTorch version: 2.0.1
    Is debug build: False
    CUDA used to build PyTorch: None
    ROCM used to build PyTorch: N/A

    OS: macOS 12.6.4 (x86_64)
    GCC version: Could not collect
    Clang version: 14.0.0 (clang-1400.0.29.202)
    CMake version: version 3.27.0
    Libc version: N/A

    Python version: 3.9.17 (main, Jun 20 2023, 17:04:52) [Clang 14.0.0 (clang-1400.0.29.202)] (64-bit runtime)
    Python platform: macOS-12.6.4-x86_64-i386-64bit
    Is CUDA available: False
    CUDA runtime version: No CUDA
    CUDA_MODULE_LOADING set to: N/A
    GPU models and configuration: No CUDA
    Nvidia driver version: No CUDA
    cuDNN version: No CUDA
    HIP runtime version: N/A
    MIOpen runtime version: N/A
    Is XNNPACK available: True

    CPU:
    Intel(R) Core(TM) i7-8750H CPU @ 2.20GHz

    Versions of relevant libraries:
    [pip3] numpy==1.23.5
    [pip3] torch==2.0.1
    [pip3] torchaudio==2.0.2

Emrys365, Oct 16 '23 23:10

I don't understand the issue. I've got the same result from the two functions you say are inconsistent:

In [19]: import torch
    ...: import librosa
    ...: 
    ...: # INPUT:
    ...: 
    ...: sr = 16000
    ...: n_fft = 16
    ...: 
    ...: # PROCESS:
    ...: n_freqs = 1 + n_fft // 2
    ...: 
    ...: freqs_torch = torch.linspace(0, sr // 2, n_freqs)
    ...: 
    ...: freqs_librosa = librosa.fft_frequencies(sr=sr, n_fft=n_fft)

In [20]: freqs_librosa
Out[20]: array([   0., 1000., 2000., 3000., 4000., 5000., 6000., 7000., 8000.])

In [21]: freqs_torch
Out[21]: tensor([   0., 1000., 2000., 3000., 4000., 5000., 6000., 7000., 8000.])

felipeespic, Oct 18 '23 00:10

Yes, they are the same in your example with n_fft = 16, but that is not always the case. In my example, with n_fft = 321 (as used in the DNSMOS computation), the outputs differ:

In [1]: import torch
    ...: import librosa
    ...: 
    ...: # INPUT:
    ...: 
    ...: sr = 16000
    ...: n_fft = 321
    ...: 
    ...: # PROCESS:
    ...: n_freqs = 1 + n_fft // 2
    ...: 
    ...: freqs_torch = torch.linspace(0, sr // 2, n_freqs)
    ...: 
    ...: freqs_librosa = librosa.fft_frequencies(sr=sr, n_fft=n_fft)

In [2]: freqs_librosa
Out[2]:
array([   0.        ,   49.84423676,   99.68847352,  149.53271028,
        199.37694704,  249.2211838 ,  299.06542056,  348.90965732,
        398.75389408,  448.59813084,  498.4423676 ,  548.28660436,
        598.13084112,  647.97507788,  697.81931464,  747.6635514 ,
        797.50778816,  847.35202492,  897.19626168,  947.04049844,
        996.8847352 , 1046.72897196, 1096.57320872, 1146.41744548,
       1196.26168224, 1246.105919  , 1295.95015576, 1345.79439252,
       1395.63862928, 1445.48286604, 1495.3271028 , 1545.17133956,
       1595.01557632, 1644.85981308, 1694.70404984, 1744.5482866 ,
       1794.39252336, 1844.23676012, 1894.08099688, 1943.92523364,
       1993.7694704 , 2043.61370717, 2093.45794393, 2143.30218069,
       2193.14641745, 2242.99065421, 2292.83489097, 2342.67912773,
       2392.52336449, 2442.36760125, 2492.21183801, 2542.05607477,
       2591.90031153, 2641.74454829, 2691.58878505, 2741.43302181,
       2791.27725857, 2841.12149533, 2890.96573209, 2940.80996885,
       2990.65420561, 3040.49844237, 3090.34267913, 3140.18691589,
       3190.03115265, 3239.87538941, 3289.71962617, 3339.56386293,
       3389.40809969, 3439.25233645, 3489.09657321, 3538.94080997,
       3588.78504673, 3638.62928349, 3688.47352025, 3738.31775701,
       3788.16199377, 3838.00623053, 3887.85046729, 3937.69470405,
       3987.53894081, 4037.38317757, 4087.22741433, 4137.07165109,
       4186.91588785, 4236.76012461, 4286.60436137, 4336.44859813,
       4386.29283489, 4436.13707165, 4485.98130841, 4535.82554517,
       4585.66978193, 4635.51401869, 4685.35825545, 4735.20249221,
       4785.04672897, 4834.89096573, 4884.73520249, 4934.57943925,
       4984.42367601, 5034.26791277, 5084.11214953, 5133.95638629,
       5183.80062305, 5233.64485981, 5283.48909657, 5333.33333333,
       5383.17757009, 5433.02180685, 5482.86604361, 5532.71028037,
       5582.55451713, 5632.39875389, 5682.24299065, 5732.08722741,
       5781.93146417, 5831.77570093, 5881.61993769, 5931.46417445,
       5981.30841121, 6031.15264798, 6080.99688474, 6130.8411215 ,
       6180.68535826, 6230.52959502, 6280.37383178, 6330.21806854,
       6380.0623053 , 6429.90654206, 6479.75077882, 6529.59501558,
       6579.43925234, 6629.2834891 , 6679.12772586, 6728.97196262,
       6778.81619938, 6828.66043614, 6878.5046729 , 6928.34890966,
       6978.19314642, 7028.03738318, 7077.88161994, 7127.7258567 ,
       7177.57009346, 7227.41433022, 7277.25856698, 7327.10280374,
       7376.9470405 , 7426.79127726, 7476.63551402, 7526.47975078,
       7576.32398754, 7626.1682243 , 7676.01246106, 7725.85669782,
       7775.70093458, 7825.54517134, 7875.3894081 , 7925.23364486,
       7975.07788162])

In [3]: freqs_torch
Out[3]:
tensor([   0.,   50.,  100.,  150.,  200.,  250.,  300.,  350.,  400.,  450.,
         500.,  550.,  600.,  650.,  700.,  750.,  800.,  850.,  900.,  950.,
        1000., 1050., 1100., 1150., 1200., 1250., 1300., 1350., 1400., 1450.,
        1500., 1550., 1600., 1650., 1700., 1750., 1800., 1850., 1900., 1950.,
        2000., 2050., 2100., 2150., 2200., 2250., 2300., 2350., 2400., 2450.,
        2500., 2550., 2600., 2650., 2700., 2750., 2800., 2850., 2900., 2950.,
        3000., 3050., 3100., 3150., 3200., 3250., 3300., 3350., 3400., 3450.,
        3500., 3550., 3600., 3650., 3700., 3750., 3800., 3850., 3900., 3950.,
        4000., 4050., 4100., 4150., 4200., 4250., 4300., 4350., 4400., 4450.,
        4500., 4550., 4600., 4650., 4700., 4750., 4800., 4850., 4900., 4950.,
        5000., 5050., 5100., 5150., 5200., 5250., 5300., 5350., 5400., 5450.,
        5500., 5550., 5600., 5650., 5700., 5750., 5800., 5850., 5900., 5950.,
        6000., 6050., 6100., 6150., 6200., 6250., 6300., 6350., 6400., 6450.,
        6500., 6550., 6600., 6650., 6700., 6750., 6800., 6850., 6900., 6950.,
        7000., 7050., 7100., 7150., 7200., 7250., 7300., 7350., 7400., 7450.,
        7500., 7550., 7600., 7650., 7700., 7750., 7800., 7850., 7900., 7950.,
        8000.])

Emrys365, Oct 18 '23 02:10

I see, you're right: they behave differently when n_fft is odd. I've actually never used an odd n_fft; it's quite non-standard, and I'd suggest avoiding it if possible. Something like this may work for both even and odd n_fft:

In [41]: import torch
    ...: import librosa
    ...: 
    ...: # INPUT:
    ...: sr = 16000
    ...: n_fft = 17
    ...: 
    ...: # PROCESS:
    ...: n_freqs = 1 + n_fft // 2
    ...: 
    ...: # freqs_torch = torch.linspace(0, sr // 2, n_freqs)
    ...: freqs_torch = torch.linspace(0, (n_freqs-1) * sr / n_fft, n_freqs)
    ...: 
    ...: freqs_librosa = librosa.fft_frequencies(sr=sr, n_fft=n_fft)

In [42]: freqs_librosa
Out[42]: 
array([   0.        ,  941.17647059, 1882.35294118, 2823.52941176,
       3764.70588235, 4705.88235294, 5647.05882353, 6588.23529412,
       7529.41176471])

In [43]: freqs_torch
Out[43]: 
tensor([   0.0000,  941.1765, 1882.3529, 2823.5293, 3764.7058, 4705.8823,
        5647.0586, 6588.2354, 7529.4116])

BTW, I don't work on this project.

felipeespic, Oct 18 '23 09:10

Yes, you're right. `freqs_torch = torch.linspace(0, (n_freqs-1) * sr / n_fft, n_freqs)` is essentially what `torch.fft.rfftfreq` does.

Emrys365, Oct 18 '23 13:10
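The `linspace` workaround from the thread can be checked against `rfftfreq` for both parities of `n_fft` and a few sample rates. A minimal NumPy sketch; the helper name `linspace_workaround` is illustrative, and the only change from the current line is the end point, `(n_freqs - 1) * sr / n_fft` instead of `sr // 2`.

```python
import numpy as np

def linspace_workaround(sample_rate, n_fft):
    """Variant from the thread: end at the last real-FFT bin instead of sample_rate // 2."""
    n_freqs = 1 + n_fft // 2
    return np.linspace(0, (n_freqs - 1) * sample_rate / n_fft, n_freqs)

# Compare against the reference bin centres for even and odd sizes alike.
for sr in (16000, 22050, 44100):
    for n_fft in range(2, 1025):
        ref = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)
        assert np.allclose(linspace_workaround(sr, n_fft), ref), (sr, n_fft)
print("linspace workaround matches rfftfreq for all tested sizes")
```

The small differences between `Out[43]` and `Out[42]` above (e.g. 2823.5293 vs 2823.52941176) come from `torch.linspace` returning float32 while NumPy uses float64; they are rounding, not a grid mismatch.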