pytorch_wavelets
A precision mismatch issue occurred when integrating this library with automatic mixed precision.
After enabling AMP, I used the DWTForward and DWTInverse classes provided by the library.
This led to a situation in dwt/lowlevel.py where, for example, at line 356:
lo = sfb1d(low, lh, h0_col, h1_col, mode=mode, dim=2)
the tensor low has float16 precision, while the other inputs remain float32.
This causes a type mismatch during backward, resulting in the following error:
y = F.conv_transpose2d(lo, g0, stride=s, padding=pad, groups=C) + \
[rank0]: RuntimeError: expected scalar type Half but found Float
Is there a recommended way to resolve this issue, or is there any plan to update the library to address it?🤔
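One possible workaround, not confirmed by the library maintainers, is to run the wavelet modules outside autocast and cast their inputs to float32, so every tensor reaching the filter bank has the same dtype. The sketch below assumes that approach. `ForceFloat32` and `to_float32` are hypothetical helper names, not part of pytorch_wavelets, and the `device_type` argument is an assumption about where training runs.

```python
import torch
import torch.nn as nn


def to_float32(obj):
    """Recursively cast floating-point tensors in nested tuples/lists to float32."""
    if torch.is_tensor(obj):
        return obj.float() if obj.is_floating_point() else obj
    if isinstance(obj, (list, tuple)):
        return type(obj)(to_float32(o) for o in obj)
    return obj


class ForceFloat32(nn.Module):
    """Hypothetical helper: run the wrapped module in float32 with autocast disabled."""

    def __init__(self, module, device_type="cuda"):
        super().__init__()
        self.module = module
        self.device_type = device_type  # assumption: "cuda" for typical AMP training

    def forward(self, *args):
        # Cast any half-precision inputs (including nested coefficient lists) up
        # to float32, then turn autocast off so ops inside are not re-cast.
        args = to_float32(args)
        with torch.autocast(device_type=self.device_type, enabled=False):
            return self.module(*args)
```

With pytorch_wavelets this would be used as `xfm = ForceFloat32(DWTForward(J=1, wave='db1', mode='zero'))` and `ifm = ForceFloat32(DWTInverse(wave='db1', mode='zero'))`, then called as `yl, yh = xfm(x)` and `ifm((yl, yh))` inside the autocast region. The outputs are float32, so the wavelet transform itself gains nothing from mixed precision; the rest of the model still does.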