pytorch_wavelets

Precision mismatch when using this library with automatic mixed precision (AMP)

Open: Sylence8 opened this issue 5 months ago • 2 comments

After enabling AMP, I used the library's `DWTForward` and `DWTInverse` classes. In `dwt/lowlevel.py`, for example at line 356 (`lo = sfb1d(low, lh, h0_col, h1_col, mode=mode, dim=2)`), the tensor `low` is float16 while the other inputs remain float32. This type mismatch surfaces during the backward pass with the following error:

```
 y = F.conv_transpose2d(lo, g0, stride=s, padding=pad, groups=C) + \
[rank0]: RuntimeError: expected scalar type Half but found Float
```
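The failure can be reproduced without the library. The sketch below uses hypothetical stand-in tensors (`lo` as a float16 activation like the one AMP produces, `g0` as a float32 filter like the library's buffers; shapes, stride, and the `synthesis_step` helper are assumptions). It shows that a transposed convolution does not promote mixed dtypes, and that casting the activation to the filter's dtype makes the call valid.

```python
import torch
import torch.nn.functional as F

# Hypothetical stand-ins: a half-precision activation (as produced under
# AMP) and a float32 filter (like the library's filter buffers).
lo = torch.randn(1, 2, 8, 8, dtype=torch.float16)
g0 = torch.randn(2, 1, 4, 1, dtype=torch.float32)


def synthesis_step(x, w):
    # Depthwise transposed convolution along the height axis, shaped like
    # the failing call in the traceback (stride chosen arbitrarily here).
    return F.conv_transpose2d(x, w, stride=(2, 1), groups=x.shape[1])


try:
    synthesis_step(lo, g0)
    mismatch_raised = False
except RuntimeError as err:
    # Convolutions do not promote mixed dtypes; they raise instead.
    mismatch_raised = True
    print("mismatch:", err)

# Casting the activation to the filter's dtype makes the call valid.
y = synthesis_step(lo.to(g0.dtype), g0)
print(y.dtype, tuple(y.shape))
```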

Is there a recommended way to resolve this issue, or is there any plan to update the library to address it?🤔

Sylence8, Jul 04 '25 14:07