
Implement L-BFGS-B optimizer and update InverseMelScale

Open nateanl opened this issue 2 years ago • 10 comments

🚀 The feature

To increase the speed of the InverseMelScale module, the SGD optimization can be replaced with torch.linalg.lstsq.

Motivation, pitch

The current InverseMelScale module applies an SGD optimizer and estimates the spectrogram in a for loop. It is much slower than librosa (see https://github.com/pytorch/audio/issues/2594), which uses a non-negative least squares (NNLS) algorithm. Another issue is that the module cannot run in torch's inference mode (see https://github.com/pytorch/audio/issues/1902): gradients are disabled in inference mode, which blocks the SGD optimization.

torch.linalg.lstsq is much faster than SGD. Another benefit is that it runs without gradients, so the module can be used in torch inference mode.
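As a rough sketch of the proposal (the filter-bank values and shapes below are stand-ins, not torchaudio's actual melscale_fbanks output), the whole inversion reduces to a single least-squares solve, and because no gradients are involved it also runs under torch.inference_mode():

```python
import torch

torch.manual_seed(0)
n_freqs, n_mels, n_frames = 65, 16, 10

# Stand-in for torchaudio.functional.melscale_fbanks(); shape (n_freqs, n_mels).
fb = torch.rand(n_freqs, n_mels)

# Forward mel scale: melspec = fb.T @ specgram, shape (n_mels, n_frames).
specgram = torch.rand(n_freqs, n_frames)
melspec = fb.T @ specgram

# Invert with one least-squares solve: argmin_X ||fb.T @ X - melspec||_F.
# On CPU the default driver handles this underdetermined system by
# returning a minimum-norm solution.
with torch.inference_mode():
    estimate = torch.linalg.lstsq(fb.T, melspec).solution
    estimate = estimate.clamp(min=0.0)  # lstsq itself is unconstrained
```

The final clamp is needed precisely because lstsq has no non-negativity constraint, which is the limitation discussed later in this thread.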

Alternatives

Another solution is to use the same L-BFGS-B optimizer as librosa. However, there is no implementation of it in PyTorch yet.

Additional context

No response

nateanl avatar Aug 23 '22 18:08 nateanl

One of the aspects brought up during the introduction of InverseMelScale was numerical stability. https://github.com/pytorch/audio/pull/366

What's the stability of torch.linalg.lstsq?

mthrok avatar Aug 23 '22 23:08 mthrok

According to torch.linalg.lstsq's doc page:

    This function computes X = A.pinverse() @ B in a faster and more numerically
    stable way than performing the computations separately.

It should be more stable than the SVD-based pinverse.

nateanl avatar Aug 23 '22 23:08 nateanl

The issue with torch.linalg.lstsq is that it has no non-negativity constraint, so the generated spectrogram is not the global optimum of the constrained problem. The best solution would be to implement the L-BFGS-B optimizer in PyTorch, which is faster than the current iterative optimization and can run in inference mode.
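For reference, librosa's approach solves this box-constrained least-squares problem through SciPy's L-BFGS-B. A minimal sketch of that formulation with stand-in data (not librosa's actual code): the only constraint needed is the lower bound of zero on every spectrogram entry.

```python
import numpy as np
from scipy.optimize import minimize

rng = np.random.default_rng(0)
n_freqs, n_mels, n_frames = 20, 8, 5
fb = rng.random((n_mels, n_freqs))        # stand-in mel filter bank
S_true = rng.random((n_freqs, n_frames))  # non-negative ground truth
M = fb @ S_true                           # observed mel spectrogram

def objective(x):
    """Return 0.5 * ||fb @ S - M||_F^2 and its analytic gradient."""
    S = x.reshape(n_freqs, n_frames)
    diff = fb @ S - M
    loss = 0.5 * np.sum(diff ** 2)
    grad = fb.T @ diff
    return loss, grad.ravel()

# L-BFGS-B with a lower bound of 0 on every variable.
res = minimize(objective, np.zeros(n_freqs * n_frames), jac=True,
               method="L-BFGS-B", bounds=[(0, None)] * (n_freqs * n_frames))
S_est = res.x.reshape(n_freqs, n_frames)
```

Because a non-negative exact solution exists here, the residual of `fb @ S_est` against `M` comes out near zero without any clamping.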

nateanl avatar Aug 30 '22 14:08 nateanl

It seems this is the SciPy implementation:

https://github.com/scipy/scipy/blob/v1.9.3/scipy/optimize/_lbfgsb_py.py#L210-L399

mthrok avatar Dec 04 '22 02:12 mthrok

@xiaohui-zhang Interested?

mthrok avatar Dec 13 '22 19:12 mthrok

I have been using an implementation where I compute the pseudo-inverse of the filter banks. It works nicely and gives results similar to lstsq. An upside of the pseudo-inverse is that we can cache the inverted filter banks and speed up the computation.

A downside of this approach is that it introduces some noticeable artifacts, though my tests with librosa show the artifacts are present there as well.

Here is the algorithm:

from typing import Optional 

import torch
from einops import rearrange, repeat
from torch import Tensor, nn
from torchaudio import functional as F

class InverseMelScale(nn.Module):
    def __init__(
        self,
        sample_rate: int,
        n_fft: int,
        n_mels: int,
        f_min: float = 0.0,
        f_max: Optional[float] = None,
        norm: Optional[str] = None,
        mel_scale: str = "htk"
    ) -> None: 
        super().__init__()

        # Compute the inverse filter banks using the pseudo inverse
        f_max = f_max or float(sample_rate // 2)
        fb = F.melscale_fbanks(
            (n_fft // 2 + 1), f_min, f_max, n_mels, sample_rate, norm, mel_scale
        )
        # Using pseudo-inverse is faster than calculating the least-squares in each
        # forward pass and experiments show that they converge to the same solution
        self.register_buffer("fb", torch.linalg.pinv(fb))

    def forward(self, melspec: Tensor) -> Tensor:
        # Flatten the melspec except for the frequency and time dimension
        shape = melspec.shape
        melspec = rearrange(melspec, "... f t -> (...) f t")

        # Expand the filter banks to match the melspec
        fb = repeat(self.fb, "f m -> n m f", n=melspec.shape[0])

        # Synthesize the STFT specgram using the filter banks
        specgram = fb @ melspec
        # Ensure non-negative solution
        specgram = torch.clamp(specgram, min=0.)

        # Unflatten the specgram (*, freq, time)
        specgram = specgram.view(shape[:-2] + (fb.shape[-2], shape[-1]))

        return specgram

Kinyugo avatar Dec 14 '22 15:12 Kinyugo

@nateanl How does this look?

mthrok avatar Jan 10 '23 18:01 mthrok

Hi @Kinyugo, thanks for the proposal. Like torch.linalg.lstsq, torch.linalg.pinv doesn't guarantee that the output is non-negative, hence the inverted output is still not optimal. We need to use a non-negative optimization algorithm and get rid of torch.clamp to reach the optimal solution.
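For comparison, the classic non-negative least-squares solution (outside PyTorch) can be obtained frame by frame with scipy.optimize.nnls; the sizes and filter-bank values below are illustrative only:

```python
import numpy as np
from scipy.optimize import nnls

rng = np.random.default_rng(0)
n_freqs, n_mels, n_frames = 20, 8, 5
fb = rng.random((n_mels, n_freqs))       # stand-in mel filter bank
M = fb @ rng.random((n_freqs, n_frames)) # observed mel spectrogram

# Solve each time frame as an independent NNLS problem:
#   min ||fb @ s - M[:, t]||  subject to  s >= 0
S_est = np.stack([nnls(fb, M[:, t])[0] for t in range(n_frames)], axis=1)
```

Non-negativity is enforced by the solver itself, so no clamp is needed afterwards; the trade-off is a per-column Python loop rather than one batched solve.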

nateanl avatar Jan 10 '23 18:01 nateanl

Hi @nateanl, using L-BFGS-B would be a great solution. I am not sure whether there is an implementation of it in PyTorch. It could be useful to run tests comparing the proposal against the current solution that uses gradient descent. If the results are at least comparable, the proposal could work as a temporary solution. I find that the GD one takes a long time to run, especially on CPU.
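On the PyTorch side, torch.optim.LBFGS does exist, but it lacks the box constraints of L-BFGS-B. One hedged workaround (a sketch on toy data, not a drop-in replacement for InverseMelScale) is to reparametrize the spectrogram through softplus so the iterate stays non-negative without explicit bounds:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_freqs, n_mels, n_frames = 20, 8, 5
fb = torch.rand(n_mels, n_freqs)        # stand-in mel filter bank
M = fb @ torch.rand(n_freqs, n_frames)  # target mel spectrogram

# Optimize an unconstrained X; softplus(X) is always >= 0.
X = torch.zeros(n_freqs, n_frames, requires_grad=True)
optimizer = torch.optim.LBFGS([X], max_iter=200)

def closure():
    optimizer.zero_grad()
    loss = ((fb @ F.softplus(X) - M) ** 2).mean()
    loss.backward()
    return loss

initial = closure().item()
optimizer.step(closure)

S_est = F.softplus(X).detach()
final = ((fb @ S_est - M) ** 2).mean().item()
```

Unlike true L-BFGS-B this changes the optimization geometry (the bound is enforced by the parametrization, not by the line search), but it does yield a non-negative estimate while still using PyTorch's built-in L-BFGS.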

Kinyugo avatar Jan 10 '23 19:01 Kinyugo

With the following setup:

  • 44100Hz, 16bit, 2m 16s .wav file
  • Ryzen 5800X
  • n_fft = 2048 / n_mel = 128

Approx. Result:

  • @Kinyugo's sample InverseMelScale: 5s
  • librosa.feature.inverse.mel_to_audio: 10s
  • torchaudio.transforms.InverseMelScale: Ctrl+C after 1 hour, didn't complete

[Image: comparison of source / librosa reconstruction / torchaudio reconstruction spectrograms]

I agree there are significant artifacts, both audible and visible, in the resulting spectrum. But if the stock method doesn't complete in a meaningful amount of time, I personally think it's nice to have a 'worse yet faster' alternative in torchaudio.

Are there any updates or plans regarding this issue?

jupiterbjy avatar Apr 18 '23 18:04 jupiterbjy