Enable FP8 on rocm for gfx12
Feature Idea
gfx12 (9060, 9070) brings support for fp8 on AMD consumer cards. This functionality is currently hard-disabled, I think in comfy/model_management.py, if not in other parts of the codebase as well:
```python
def supports_fp8_compute(device=None):
    if not is_nvidia():
        return False
```
fp8 should be enabled on these platforms, and on other AMD cards that support it.
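One way the guard could be relaxed, sketched as a pure helper so the idea is concrete. This is illustrative only, not ComfyUI's actual code; on ROCm builds of PyTorch, `torch.cuda.get_device_properties(dev).gcnArchName` is what reports strings like `gfx1201`, and the prefix check below covers only the gfx12 parts mentioned above:

```python
# Illustrative sketch only -- not ComfyUI's real supports_fp8_compute().
# The "gfx12" prefix covers the RDNA4 consumer parts (e.g. 9060/9070).

def amd_arch_supports_fp8(gcn_arch_name: str) -> bool:
    """Check whether a ROCm gcnArchName string names an fp8-capable gfx12 GPU."""
    # gcnArchName can carry feature suffixes, e.g. "gfx1200:sramecc+:xnack-",
    # so keep only the base "gfxNNNN" token before comparing.
    base = gcn_arch_name.split(":")[0]
    return base.startswith("gfx12")
```

A check like this would let `gfx1201` through while leaving older consumer parts such as `gfx1030` disabled.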
Existing Solutions
No response
Other
No response
You can now use `--supports-fp8-compute` to make that function return True.
I'm very curious if using that argument + selecting the `fp8_e4m3fn_fast` dtype in the Load Diffusion Model node actually works and gives a speed boost on the 9070.
> You can now use `--supports-fp8-compute` to make that function return True.
> I'm very curious if using that argument + selecting the `fp8_e4m3fn_fast` dtype in the Load Diffusion Model node actually works and gives a speed boost on the 9070.
I tried just that, and for some reason it upcasts to fp16.
I am seeing the same behavior: `model weight dtype torch.float8_e4m3fn, manual cast: torch.bfloat16`
> I'm very curious if using that argument + selecting the `fp8_e4m3fn_fast` dtype in the Load Diffusion Model node actually works and gives a speed boost on the 9070.
Yes, fp8 ops are supported and are faster.
ROCm 6.4.1, using the ComfyUI version from https://github.com/comfyanonymous/ComfyUI/pull/8289
10 iters / `flux1-dev-fp8.safetensors` / `model weight dtype torch.float8_e4m3fn, manual cast: torch.bfloat16`
| Command | Run | Sampling | Total Prompt Time |
|---|---|---|---|
| `--use-pytorch-cross-attention --fast` | 1st Run | 4.84s/it | 68.22 sec |
| `--use-pytorch-cross-attention --fast` | 2nd Run | 3.27s/it | 37.50 sec |
| `--use-pytorch-cross-attention --fast --supports-fp8-compute` | 1st Run | 1.87s/it | 34.41 sec |
| `--use-pytorch-cross-attention --fast --supports-fp8-compute` | 2nd Run | 1.86s/it | 21.84 sec |
For people who do get a speedup: can you post your full ComfyUI log so I can see which pytorch version, arch, ROCm, etc. you are using, so I can enable it by default?
The manual cast bfloat16 is for the other non-fp8 ops, so you can ignore it.
I have tested on ROCm 6.4.1 with gfx1201 (9070 XT), on Windows, WSL2, and Linux, with similar results.
Windows with pytorch `2.7.0a0+git3f903c3`; WSL2 and Linux with pytorch `2.6.0+rocm6.4.1.git1ded221d`.
As for FP8 support, it seems to be supported on gfx942, gfx950, and gfx1201; you could copy this logic for detection: https://github.com/ROCm/pytorch/blob/0a6e1d6b9bf78d690a812e4334939e7701bfa794/torch/testing/_internal/common_cuda.py#L89-L103
I don't have access to MI hardware, but I believe it should work there too, since those are AMD's enterprise offerings.
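A sketch of what copying that detection might look like, kept as a pure string check so it is easy to follow. The helper name is made up; the arch list is the one from the linked `common_cuda.py`, and the `gcnArchName` string it matches against is what ROCm PyTorch reports via `torch.cuda.get_device_properties`:

```python
# Illustrative sketch of the linked ROCm/pytorch fp8 gating, not a drop-in patch.
FP8_CAPABLE_ROCM_ARCHES = ("gfx942", "gfx950", "gfx1201")

def rocm_arch_supports_fp8(gcn_arch_name: str) -> bool:
    """True if the reported gcnArchName is one of the fp8-capable targets."""
    # Strip feature suffixes like ":sramecc+:xnack-" before matching.
    return gcn_arch_name.split(":")[0] in FP8_CAPABLE_ROCM_ARCHES
```

A check like this could slot in where `supports_fp8_compute` currently returns False for all non-NVIDIA devices.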
https://github.com/comfyanonymous/ComfyUI/commit/97755eed46ccb797cb14a692a4c2931ebf3ad60c
This should enable it by default for gfx1201.
Dammit, that is with ROCm 7.0 final. What is going on now?
`Exception during fp8 op: Float8_e4m3fn is only supported for ROCm 6.5 and above`