[Migration] Torchaudio Complex Tensor Support and Migration

Open mthrok opened this issue 3 years ago • 4 comments

Torchaudio Complex Tensor Support and Migration

Overview

torchaudio has been expressing complex numbers as real tensors with an extra trailing dimension for the real part and the imaginary part. (We will refer to this format as the "pseudo complex type".)

waveform = torch.randn(...)
spectrogram = torchaudio.functional.spectrogram(waveform, ..., power=None)

# the last dimension represents real and imaginary parts of complex tensors.
print(spectrogram.dtype, spectrogram.shape)
>>> torch.float32, torch.Size([... ,2])

PyTorch 1.6 introduced complex Tensor types, such as torch.complex64 (torch.cfloat) and torch.complex128 (torch.cdouble). (We will refer to this format as the "native complex type".)

The native complex type comes with handy methods for complex operations such as abs, angle and magphase. (Please refer to the official documentation for the details.)
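
As a minimal sketch of the two formats, the snippet below converts a small hand-written tensor between the pseudo and native complex types with torch.view_as_complex and torch.view_as_real. The tensor values and variable names are illustrative assumptions, not part of torchaudio.

```python
import torch

# Pseudo complex type: a real tensor whose last dimension holds (real, imag).
pseudo = torch.tensor([[3.0, 4.0], [0.0, -2.0]])  # shape (2, 2), dtype float32

# Native complex type: the same values viewed as complex64.
native = torch.view_as_complex(pseudo)  # shape (2,), dtype complex64

# Native complex tensors expose complex operations as methods.
magnitude = native.abs()   # |3+4j| and |-2j|
phase = native.angle()     # atan2(imag, real) for each element

# The conversion is reversible.
roundtrip = torch.view_as_real(native)
```

Both conversions return views, so no data is copied in either direction.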

Over the next few releases, we plan to migrate torchaudio's functions and transforms to the native complex type. This issue describes the planned approaches, work, changes and timeline. If you have a question, a concern or a suggestion, feel free to leave a comment.

Migration Stages

We will perform the migration in multiple stages. At this moment, the completion of the later migration stages is not yet tied to specific releases.

✅ Stage 0 (~ 0.8)

Up to release 0.8, torchaudio exclusively used pseudo complex type in its user-facing APIs. With PyTorch 1.7, PyTorch started adopting native complex type and migrating to the torch.fft namespace. Because of this, torchaudio already uses native complex type in some implementations (F.vad, T.Vad, kaldi.spectrogram and kaldi.fbank), but all the user-facing APIs still use pseudo complex type.

✅ Stage 1 (Add support for native complex type and deprecate pseudo complex type)

Completed: PyTorch 1.9 / torchaudio 0.9

Library code change

In this stage, torchaudio will support both pseudo complex type and native complex type. This means that

  • Functions that accept complex input should be able to handle both pseudo and native complex types.
  • Functions that return complex values can return both pseudo and native complex types.
    • For real-to-complex functions, a new argument return_complex will be added so that users can switch the behavior.

Test code update

In addition to the above library code changes, we are going to add a set of tests to make sure that native complex types work in common use cases. This includes:

  • Gradient check
  • JIT
  • Performance
  • Distributed training
  • nn.Module compatibility
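
The first two items could look roughly like the sketch below. It assumes a hypothetical function power_spectrum that uses the native complex type internally; the function, the input shape and the use of torch.jit.trace (rather than the stricter torch.jit.script) are illustrative choices, not torchaudio's actual test code.

```python
import torch


def power_spectrum(pseudo: torch.Tensor) -> torch.Tensor:
    """Hypothetical function under test: |z|^2 computed via the native complex type."""
    return torch.view_as_complex(pseudo).abs().pow(2)


torch.manual_seed(0)

# Gradient check: gradcheck compares analytical and numerical gradients and
# expects double precision inputs with requires_grad=True.
x = torch.randn(4, 2, dtype=torch.float64, requires_grad=True)
grad_ok = torch.autograd.gradcheck(power_spectrum, (x,))

# JIT check: the traced function should match eager mode on the same input.
example = x.detach()
traced = torch.jit.trace(power_spectrum, (example,))
jit_ok = torch.allclose(traced(example), power_spectrum(example))
```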

✅ Stage 2 (Switch to native complex type by default)

Completed: main branch. To be released as part of PyTorch 1.10 / torchaudio 0.10

The default value for return_complex is changed to True.

👉 Stage 3 (Remove the support for pseudo complex type)

In this stage, we will remove the support for pseudo complex type.

  • Functions that work with complex types should handle native complex types exclusively.
  • Passing pseudo complex type results in an error.
  • The return_complex argument added in Stage 1 is deprecated and eventually removed.
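
One way to picture the Stage 3 behavior is an input check like the following sketch. The helper name require_native_complex and the error message are hypothetical; the actual checks and messages in torchaudio may differ.

```python
import torch


def require_native_complex(spec: torch.Tensor) -> torch.Tensor:
    """Hypothetical input check; the name and message are illustrative only."""
    if not spec.is_complex():
        raise ValueError(
            "Expected a native complex tensor (e.g. torch.cfloat), "
            f"but got dtype {spec.dtype}. "
            "Convert pseudo complex input with torch.view_as_complex."
        )
    return spec


# Native complex input passes through unchanged.
native = torch.view_as_complex(torch.randn(4, 2))
checked = require_native_complex(native)

# Pseudo complex input (a real tensor of shape (..., 2)) is rejected.
try:
    require_native_complex(torch.randn(4, 2))
    rejected = False
except ValueError:
    rejected = True
```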

Affected Functions

The following figure illustrates the functions that handle complex values and their dependencies.

[Figure: Screen Shot 2021-03-02 at 10 59 30, showing the functions that handle complex values and their dependencies]

Utility functions

F.angle, F.complex_norm, T.ComplexNorm, F.magphase

These functions are deprecated in Stage 1 and will be removed in Stage 3. For F.angle, native complex tensors provide the angle() method. For F.complex_norm / T.ComplexNorm, the equivalent computation can be performed with abs().pow(n). F.magphase is a convenience function that calls F.angle and F.complex_norm; therefore, it is deprecated as well.
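
To illustrate why the native methods are drop-in replacements, the sketch below computes the norm and the angle by hand on the pseudo complex layout and compares them with abs().pow(n) and angle(). The hand-written formulas are an assumption made for illustration and are not copied from torchaudio's implementation.

```python
import torch

torch.manual_seed(0)
pseudo = torch.randn(3, 5, 2)            # pseudo complex type, shape (..., 2)
native = torch.view_as_complex(pseudo)   # native complex type, shape (...)
n = 2.0                                  # exponent applied to the magnitude

# Norm and angle written out by hand on the pseudo complex layout.
norm_pseudo = pseudo.pow(2).sum(-1).pow(0.5 * n)
angle_pseudo = torch.atan2(pseudo[..., 1], pseudo[..., 0])

# The same quantities with native complex tensor methods.
norm_native = native.abs().pow(n)
angle_native = native.angle()
```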

Real to real functions

F.griffinlim, T.GriffinLim

Complex values appear only inside these functions, so we can change the internals without disturbing downstream users.

Complex to complex functions

F.phase_vocoder, T.TimeStretch

When adding support for native complex type, we can keep the interface change simple as follows:

  • If the input is pseudo complex type, return pseudo complex type
  • If the input is native complex type, return native complex type
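
The rule above can be sketched as a thin wrapper that converts at the boundary and returns the same type it received. The function apply_complex_op and its operation (rotating every phase by 90 degrees) are hypothetical stand-ins, not the phase vocoder itself.

```python
import torch


def apply_complex_op(spec: torch.Tensor) -> torch.Tensor:
    """Hypothetical complex-to-complex function: rotate every phase by 90 degrees."""
    is_pseudo = not spec.is_complex()
    if is_pseudo:
        # Pseudo complex input, shape (..., 2): convert to native for the computation.
        spec = torch.view_as_complex(spec)
    out = spec * 1j
    # Return the same representation that was passed in.
    return torch.view_as_real(out) if is_pseudo else out


pseudo_in = torch.tensor([[1.0, 0.0]])                            # 1 + 0j
pseudo_out = apply_complex_op(pseudo_in)                          # stays pseudo complex
native_out = apply_complex_op(torch.view_as_complex(pseudo_in))   # stays native complex
```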

Real to complex functions

F.spectrogram, T.Spectrogram

These functions return either a real-valued Tensor (power, energy) or a complex-valued Tensor (frequency representation), depending on the power argument. When power is None, these functions return a complex-valued Tensor. In this case, users have the option to receive the result in pseudo complex type or native complex type. The return_complex argument will be added for this choice. If return_complex is True, then native complex type is returned. See #1009 for the discussion.
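
The interaction between power and return_complex might be sketched as below. spectrogram_sketch is a hypothetical stand-in built directly on torch.stft; its defaults, window and warning text are assumptions and do not reproduce torchaudio's implementation.

```python
import warnings
from typing import Optional

import torch


def spectrogram_sketch(
    waveform: torch.Tensor,
    n_fft: int = 16,
    power: Optional[float] = None,
    return_complex: bool = False,
) -> torch.Tensor:
    """Hypothetical real-to-complex function; not torchaudio's code."""
    window = torch.hann_window(n_fft)
    spec = torch.stft(waveform, n_fft=n_fft, window=window, return_complex=True)
    if power is not None:
        # Real-valued output: magnitude for power=1, power spectrum for power=2.
        return spec.abs().pow(power)
    if return_complex:
        return spec  # native complex type
    warnings.warn("Returning pseudo complex type is deprecated; pass return_complex=True.")
    return torch.view_as_real(spec)  # pseudo complex type, shape (..., 2)


torch.manual_seed(0)
waveform = torch.randn(256)
power_spec = spectrogram_sketch(waveform, power=2.0)        # real-valued
native = spectrogram_sketch(waveform, return_complex=True)  # native complex
pseudo = spectrogram_sketch(waveform)                       # pseudo complex, emits the warning
```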

Timeline

| Class / function | Type | Phase 1 (PyTorch 1.9 / torchaudio 0.9) | Phase 2 (PyTorch 1.10 / torchaudio 0.10) | Phase 3 (TBD) |
|---|---|---|---|---|
| F.angle, F.complex_norm, F.magphase, T.ComplexNorm | C->R, utility | Deprecated | Deprecated | Removed |
| F.griffinlim, T.GriffinLim | R->R | Adopts native complex internally | No change | No change |
| F.phase_vocoder, T.TimeStretch | C->C | Support for native complex type is added. Support for pseudo complex type is deprecated. The function returns the same type as the input (native for native, pseudo for pseudo). | Same as Phase 1 | Support for pseudo complex type is removed. Only handles native complex type. |
| F.spectrogram, T.Spectrogram | R->C (when power=None) | Argument return_complex is added (default value is False). When the return value is complex-valued (power=None), the type of the returned Tensor can be switched with return_complex. | The default value of return_complex is changed to True. | The return_complex argument is deprecated. |

Migration steps

F.angle, F.complex_norm, F.magphase and T.ComplexNorm

~0.8

spectrogram = ...  # Tensor with pseudo complex type (shape == (..., 2))
angle = F.angle(spectrogram)
magnitude = F.complex_norm(spectrogram, power=1)
power = F.complex_norm(spectrogram, power=2)
norm = F.complex_norm(spectrogram, power=norm)
magnitude, phase = F.magphase(spectrogram, n)

0.9~

spectrogram = ...  # Tensor with pseudo complex type (shape == (..., 2))
spectrogram = torch.view_as_complex(spectrogram)  # convert to native complex type
angle = spectrogram.angle()
magnitude = spectrogram.abs()
power = spectrogram.abs().pow(2)
norm = spectrogram.abs().pow(norm)
magnitude, phase = spectrogram.abs().pow(n), spectrogram.angle()

F.phase_vocoder, T.TimeStretch

~0.8

spec = ... # pseudo complex (..., 2)

## If using functional form
spec = F.phase_vocoder(spec, ...)
## else using transform
transform = T.TimeStretch(...)
spec = transform(spec)

0.9~

spec = ... # pseudo complex (..., 2)
# Convert to native complex type
spec = torch.view_as_complex(spec)

# Perform the operation
## If using functional form
spec = F.phase_vocoder(spec, ...)
## else using transform
transform = T.TimeStretch(...)
spec = transform(spec)

# Convert back to pseudo complex type
# (If your downstream code still expects pseudo complex type)
spec = torch.view_as_real(spec)

F.spectrogram, T.Spectrogram

~0.8

spec = F.spectrogram(waveform, ..., power=None)

transform = T.Spectrogram(..., power=None)
spec = transform(waveform)  # pseudo complex (..., 2)

0.9~

spec = F.spectrogram(waveform, ..., power=None, return_complex=True)

transform = T.Spectrogram(..., power=None, return_complex=True)
spec = transform(waveform)  # native complex
# If your downstream code still expects pseudo complex type
spec = torch.view_as_real(spec)

PRs - TODO (@mthrok)

Migration

Phase 1

Code Change
  • [x] F.griffinlim #1368
  • [x] F.phase_vocoder, T.TimeStretch #1410 ~#758~
  • [x] F.spectrogram, T.Spectrogram, ~T.MelSpectrogram~ #1366 ~#1009~
Add deprecation Warnings
  • [x] F.angle, F.complex_norm, T.ComplexNorm, F.phase_vocoder, T.TimeStretch, F.spectrogram, T.Spectrogram #1445 ~#1431~
  • [x] F.magphase #1492

Phase 2

Change the default value of return_complex to True.
  • [x] F.spectrogram, T.Spectrogram #1549
Update the deprecation warnings to indicate the version of removal.
  • [x] F.angle, F.complex_norm, F.magphase, F.phase_vocoder, T.TimeStretch, T.ComplexNorm #1553

Phase 3

Remove the support for pseudo complex type.
  • [x] Remove F.magphase #1934
  • [x] Remove F.angle #1935
  • [x] Remove F.complex_norm and T.ComplexNorm #1942
  • [x] Remove support for pseudo complex type from F.phase_vocoder and T.TimeStretch #1957
  • [x] Remove support for pseudo complex type from F.spectrogram and T.Spectrogram #1958
    • [x] Deprecate return_complex argument from F.spectrogram and T.Spectrogram

Surrounding work

Conjugate input tests

  • [ ] After https://github.com/pytorch/pytorch/pull/54987, check if things work when the input complex Tensor is conjugate.
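
A conjugate-input check could look like the sketch below. It assumes a PyTorch version in which conj() returns a lazy view and is_conj() / resolve_conj() are available; angle() is used only as an example of an operation under test.

```python
import torch

z = torch.tensor([1 + 2j, 3 - 4j])

# Lazy conjugation: the result is a view with the conjugate bit set.
z_conj = z.conj()
lazy = z_conj.is_conj()

# Materialize the conjugate into a regular complex tensor for comparison.
materialized = z_conj.resolve_conj()

# The operation under test should agree on the lazy view and the materialized copy.
same = torch.allclose(z_conj.angle(), materialized.angle())
```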

Autograd tests

  • [x] #1421 T.GriffinLim
    • [ ] [Optional] F.griffinlim
  • [x] #1420 T.TimeStretch, F.phase_vocoder (NOTE: They are not differentiable around zero.)
  • [x] #1340 T.Spectrogram, T.MelSpectrogram
    • [x] Investigate why T.Spectrogram is not deterministic on its backward pass. [error -> issue pytorch/#54093]
    • [ ] [Optional] F.spectrogram

Ensuring TorchScript support

Check that all the functionals/transforms are covered by the TorchScript consistency test, and add any that are missing.

  • [x] #1446 Add step to dump TorchScripted object in TorchScript test.

Benchmark

  • F.phase_vocoder #1410: 6x speedup on CPU, not much difference on GPU.

cc @anjali411 @vincentqb

mthrok, Mar 02 '21 16:03

Please, before going forward with the deprecation, note that the complex32 format is poorly supported on CUDA. You cannot do a complex32 product, stft and istft allow power-of-2 kernels only, and there is a long list of other drawbacks.

I had to rewrite my whole code back to pseudo complex to be able to work with half precision...

JuanFMontesinos, Mar 03 '22 17:03

Cannot do complex32 product, stft and istft allows power2 kernels only and a large list of drawbacks.

Hi @JuanFMontesinos

Thanks for letting us know.

TorchAudio is not tested on fp16 or complex32, and they are not part of the officially supported types. So we did not realize that pseudo complex could be used as a workaround for complex32.

Unfortunately, we have already wrapped up release v0.11 (scheduled to be out in about one week). PyTorch removed the complex32 type, and torchaudio removed the support for pseudo complex type. So I assume it will be unusable for you. I can try reverting the pseudo complex support if that's the best course of action. (However, it is known that some operations with pseudo complex have accuracy issues, so it's not the best workaround, which is one reason we wanted to migrate to native complex type.)

The treatment of complex32 indeed needs improvement, and there is an issue created for this in PyTorch: https://github.com/pytorch/pytorch/issues/71680. The ideal outcome is that PyTorch core adds complex32 support soon, but I am not sure that can happen quickly. However, your voice matters a lot here, so would you be willing to list the features that are most relevant for you in https://github.com/pytorch/pytorch/issues/71680? That way, the PyTorch core team can prioritize them if they decide to work on this.

mthrok, Mar 05 '22 15:03

There were some typos here, so the migration made the program crash.

AttributeError: 'Tensor' object has no attribute 'power'

The correct version should be

power = spectrogram.abs().pow(2)
norm = spectrogram.abs().pow(norm)
magnitude, phase = spectrogram.abs().pow(n), spectrogram.angle()

jeffeuxMartin, Mar 15 '22 13:03

@jeffeuxMartin thanks for the report. Fixed it.

mthrok, Mar 15 '22 16:03