Issue with `use_fp16=True` Leading to Type Conversion Error in `unet.py`
When setting `use_fp16=False`, the code functions correctly. However, with `use_fp16=True`, an issue arises from an unexpected type conversion in `unet.py` at line 435, where the tensor `a` is converted from `float16` to `float32`:

```python
a = a.float()
```

Prior to this line, `a` is in `float16`; after it, `a` is in `float32`. If we remove or comment out this line, the code also raises an error. It seems that keeping `a` in `float16` is essential for the `use_fp16=True` setting to work correctly, but the current implementation inadvertently converts it to `float32`, leading to downstream failures.
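For reference, here is a minimal, self-contained sketch of the failure mode. The module and names below (`TinyAttentionBlock`, `proj_out`) are hypothetical stand-ins, not the repo's actual code: force-casting the attention output to `float32` while the surrounding layers hold `float16` weights makes the next layer receive a mismatched input, and casting back to the input dtype is one possible workaround.

```python
import torch
import torch.nn as nn


class TinyAttentionBlock(nn.Module):
    """Hypothetical stand-in for the attention block around unet.py line 435."""

    def __init__(self, channels: int):
        super().__init__()
        self.proj_out = nn.Linear(channels, channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Plain scaled dot-product self-attention over the last two dims.
        scale = x.shape[-1] ** -0.5
        a = torch.softmax((x @ x.transpose(-1, -2)) * scale, dim=-1) @ x
        a = a.float()  # the problematic cast: float16 -> float32
        # Without casting back, proj_out (whose weights are float16 under
        # use_fp16=True) would receive a float32 input and raise a dtype error.
        return self.proj_out(a.type(x.dtype))


# fp16 kernels are most reliable on GPU; CPU fp16 matmul needs a recent PyTorch.
device = "cuda" if torch.cuda.is_available() else "cpu"
block = TinyAttentionBlock(8).to(device).half()  # mimic use_fp16=True
x = torch.randn(2, 4, 8, device=device, dtype=torch.float16)
print(block(x).dtype)  # torch.float16
```

Casting back with `.type(x.dtype)` mirrors how these UNets typically handle softmax (computed in `float32` for stability, then cast back to the working dtype), but whether that is the intended behavior at line 435 is exactly my question.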
Additionally, I noticed that the current code has been modified to disable flash attention. I also tried running the original version, but encountered similar errors.
Same question