Biquad functions seemingly not compatible with autocast when using bfloat16

Open · pokepress opened this issue 8 months ago · 0 comments

🐛 Describe the bug

I tried to use highpass_biquad to compute a loss inside an autocast block:

self.floatFormat = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
...
with torch.autocast(device_type="cuda", dtype=self.floatFormat):
    ...
    losses = self._get_losses(hr_reprs, pr_reprs)

def _get_losses(self, hr, pr):
    ...
    # highpass_biquad(waveform, sample_rate, cutoff_freq, Q): cutoff at hr_sr / 10, Q = 1.5
    pr_highpass = torchaudio.functional.highpass_biquad(pr_time, self.args.experiment.hr_sr, self.args.experiment.hr_sr/10, 1.5)
    hr_highpass = torchaudio.functional.highpass_biquad(hr_time, self.args.experiment.hr_sr, self.args.experiment.hr_sr/10, 1.5)

However, it fails when bfloat16 is in use:

  File "d:\...\.venv\lib\site-packages\torchaudio\functional\filtering.py", line 922, in highpass_biquad
    return biquad(waveform, b0, b1, b2, a0, a1, a2)
  File "d:\...\.venv\lib\site-packages\torchaudio\functional\filtering.py", line 327, in biquad
    output_waveform = lfilter(
  File "d:\...\.venv\lib\site-packages\torchaudio\functional\filtering.py", line 1059, in lfilter
    output = _lfilter(waveform, a_coeffs, b_coeffs)
  File "d:\...\.venv\lib\site-packages\torch\_ops.py", line 1123, in __call__
    return self._op(*args, **(kwargs or {}))
RuntimeError: Expected (in.dtype() == torch::kFloat32 || in.dtype() == torch::kFloat64) && (a_flipped.dtype() == torch::kFloat32 || a_flipped.dtype() == torch::kFloat64) && (padded_out.dtype() == torch::kFloat32 || padded_out.dtype() == torch::kFloat64) to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)

The error means the native _lfilter op requires all three tensors it checks (in, a_flipped and padded_out) to be float32 or float64, and at least one of them isn't. My understanding is that autocast is supposed to handle this itself, but for some reason it doesn't. I did a little debugging, and it looks like the code in filtering.py is treating the data as 32-bit floats in this case, so I can only assume the data isn't passed correctly to the native code(?) itself. I worked around this by using float32 whenever this particular loss is active.

This is on 2.6.0/0.21.0, and I had the same problem on 2.4.1. OS: Windows 10.

Versions

I couldn't get the last line of the version-collection code to work. Sorry.

pokepress, Feb 17 '25 02:02