DiffSynth-Studio
H100 GPU - FP8 inference acceleration
Thanks for the great work!
I followed the Qwen-Image FP8 inference example linked below (recommended for faster inference) to compare FP8 against bfloat16: ./accelerate/Qwen-Image-FP8.py
But with FP8 (torch.float8_e4m3fn), inference is slower than bfloat16.
FP8 takes 42 s per image on a single-GPU instance with num_inference_steps=40, while bfloat16 takes 33 s under the same settings.
Setup:
- H100 GPU, which supports FP8
- Flash Attention 3 is available
How to improve FP8's inference speed?
Hi, same problem here.