
Enable FP8 on rocm for gfx12

Open ShadowElf37 opened this issue 10 months ago • 3 comments

Feature Idea

gfx12 (RX 9060, 9070) brings support for fp8 on AMD consumer cards. This functionality is currently disabled, I think by comfy/model_management.py, if not other parts of the codebase:

def supports_fp8_compute(device=None):
    if not is_nvidia():
        return False  # rejects every non-NVIDIA device, including AMD cards with fp8 hardware

fp8 should be enabled on these platforms, and on any other AMD cards that support it.

Existing Solutions

No response

Other

No response

ShadowElf37 avatar May 22 '25 16:05 ShadowElf37

You can now use --supports-fp8-compute to make that function return True.

I'm very curious if using that argument + selecting the fp8_e4m3fn_fast dtype in the Load Diffusion Model node actually works and gives a speed boost on the 9070.

comfyanonymous avatar May 23 '25 21:05 comfyanonymous

> You can now use --supports-fp8-compute to make that function return True.
>
> I'm very curious if using that argument + selecting the fp8_e4m3fn_fast dtype in the Load Diffusion Model node actually works and gives a speed boost on the 9070.

I tried just that, and for some reason it upcasts to fp16.

harakiru avatar May 24 '25 00:05 harakiru

I am seeing the same behavior: model weight dtype torch.float8_e4m3fn, manual cast: torch.bfloat16

ShadowElf37 avatar May 24 '25 15:05 ShadowElf37

> I'm very curious if using that argument + selecting the fp8_e4m3fn_fast dtype in the Load Diffusion Model node actually works and gives a speed boost on the 9070.

Yes, fp8 ops are supported and are faster.

ROCm 6.4.1 using this Comfy version https://github.com/comfyanonymous/ComfyUI/pull/8289

10 iters / flux1-dev-fp8.safetensors / model weight dtype torch.float8_e4m3fn, manual cast: torch.bfloat16

| Command | Run | Sampling | Total Prompt Time |
| --- | --- | --- | --- |
| `--use-pytorch-cross-attention --fast` | 1st Run | 4.84 s/it | 68.22 sec |
| `--use-pytorch-cross-attention --fast` | 2nd Run | 3.27 s/it | 37.50 sec |
| `--use-pytorch-cross-attention --fast --supports-fp8-compute` | 1st Run | 1.87 s/it | 34.41 sec |
| `--use-pytorch-cross-attention --fast --supports-fp8-compute` | 2nd Run | 1.86 s/it | 21.84 sec |
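For a rough read of the table, comparing the warmed-up (2nd) runs:

```python
# Quick arithmetic on the 2nd-run numbers from the table above.
baseline_s_per_it = 3.27  # --fast only
fp8_s_per_it = 1.86       # --fast --supports-fp8-compute
speedup = baseline_s_per_it / fp8_s_per_it
print(f"fp8 per-iteration speedup: {speedup:.2f}x")  # roughly 1.76x
```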

kasper93 avatar May 29 '25 01:05 kasper93

For people who do get a speedup, can you post your full ComfyUI log? I want to see which PyTorch version, arch, ROCm, etc. you are using so I can enable it by default.

The manual cast bfloat16 is for the other, non-fp8 ops, so you can ignore it.

comfyanonymous avatar May 30 '25 21:05 comfyanonymous

I have tested on ROCm 6.4.1, gfx1201 (9070xt). On Windows, WSL2 and Linux with similar results.

Windows with pytorch 2.7.0a0+git3f903c3; WSL2 and Linux with pytorch 2.6.0+rocm6.4.1.git1ded221d.

As for FP8 support, it seems to be supported on gfx942, gfx950 and gfx1201; you could copy this logic for detection: https://github.com/ROCm/pytorch/blob/0a6e1d6b9bf78d690a812e4334939e7701bfa794/torch/testing/_internal/common_cuda.py#L89-L103

I don't have access to MI hardware, but I believe it should work there too, since those are their enterprise offerings.
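The detection logic linked above could be sketched roughly like this (a minimal, standalone version; the function name and arch list are assumptions based on that file, and in real code the arch string would come from `torch.cuda.get_device_properties(device).gcnArchName`):

```python
# Hypothetical standalone sketch of the arch check from ROCm's pytorch fork.
FP8_ROCM_ARCHES = ("gfx942", "gfx950", "gfx1201")

def rocm_arch_supports_fp8(gcn_arch_name: str) -> bool:
    # gcnArchName can carry feature flags, e.g. "gfx942:sramecc+:xnack-",
    # so compare only the base architecture name.
    base = gcn_arch_name.split(":")[0]
    return base in FP8_ROCM_ARCHES

print(rocm_arch_supports_fp8("gfx1201"))  # True
print(rocm_arch_supports_fp8("gfx1100"))  # False
```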

kasper93 avatar Jun 06 '25 16:06 kasper93

https://github.com/comfyanonymous/ComfyUI/commit/97755eed46ccb797cb14a692a4c2931ebf3ad60c

This should enable it by default for gfx1201.

comfyanonymous avatar Jun 08 '25 18:06 comfyanonymous

Dammit, that is with ROCm 7.0 final. What is going on now?

"Exception during fp8 op: Float8_e4m3fn is only supported for ROCm 6.5 and above"

B4rr3l-Rid3r avatar Sep 16 '25 12:09 B4rr3l-Rid3r