consistency_models

Issue with `use_fp16=True` Leading to Type Conversion Error in `unet.py`

Open DuoLi1999 opened this issue 1 year ago • 1 comment

When setting `use_fp16=False`, the code functions correctly. However, with `use_fp16=True`, an issue arises due to an unexpected type conversion in `unet.py` (line 435).

The problem occurs at line 435, where the tensor `a` is converted from float16 to float32:

```python
a = a.float()
```

Before this line, `a` is float16; after it, `a` is float32. If we remove or comment out this line, the code still encounters an error. It seems that keeping `a` in float16 is essential for the `use_fp16=True` setting to work correctly, but the current implementation inadvertently upcasts it to float32, leading to issues.
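For context, here is a minimal standalone sketch of this failure mode (my own illustration, not code from the repository): once a tensor is upcast to float32, any subsequent layer whose weights are still float16 raises a dtype-mismatch error, which is consistent with `a` needing to stay in float16.

```python
import torch
import torch.nn as nn

# Hypothetical reproduction with assumed shapes and names, not the repo's code.
# (Run on GPU if your PyTorch build lacks CPU fp16 kernels.)
proj = nn.Linear(4, 4).half()                # an fp16 layer, as with use_fp16=True
a = torch.randn(2, 4, dtype=torch.float16)   # tensor starts out in fp16

a = a.float()                                # the upcast in question: a is now fp32
try:
    proj(a)                                  # fp16 weights vs. fp32 input
except RuntimeError as e:
    print(e)                                 # dtype-mismatch error (message varies)

# One possible workaround (an assumption, not the maintainers' fix): keep the
# fp32 math where it is needed for stability, then cast back to the input dtype.
a = a.to(proj.weight.dtype)
out = proj(a)                                # dtypes match again; runs in fp16
```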

Additionally, I've noticed that the current code has been modified to disable flash attention. I also tried running the original version, but encountered similar errors.

DuoLi1999 avatar Dec 18 '23 11:12 DuoLi1999

Same question

songtianhui avatar Feb 08 '24 14:02 songtianhui