flux
Inference on TPUs instead of GPUs.
Hi folks! Our AI Hypercomputer team ported the Flux inference implementation to MaxDiffusion and successfully ran both the Flux-dev and Flux-schnell models on Google TPUs.
Running tests on 1024 x 1024 images with flash attention and bfloat16 gave the following results:
| Model | Accelerator | Sharding Strategy | Batch Size | Steps | Time (s) |
|---|---|---|---|---|---|
| Flux-dev | v4-8 | DDP | 4 | 28 | 23 |
| Flux-schnell | v4-8 | DDP | 4 | 4 | 2.2 |
| Flux-dev | v6e-4 | DDP | 4 | 28 | 5.5 |
| Flux-schnell | v6e-4 | DDP | 4 | 4 | 0.8 |
| Flux-schnell | v6e-4 | FSDP | 4 | 4 | 1.2 |
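For readers unfamiliar with the sharding strategies in the table: under DDP the batch is split across devices while the model weights are replicated on each one, whereas FSDP also shards the weights themselves. Below is a minimal JAX sketch of the DDP layout, assuming a placeholder `denoise_step` function standing in for one Flux denoising step (the real model lives in MaxDiffusion; everything here is illustrative):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D device mesh with a single "data" axis (DDP-style layout).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# DDP: shard the batch (leading) axis across devices; replicate parameters.
batch_sharding = NamedSharding(mesh, P("data"))
replicated = NamedSharding(mesh, P())

@jax.jit
def denoise_step(latents, scale):
    # Hypothetical stand-in for one Flux denoising step; the real model
    # applies the full transformer here.
    return latents * scale

batch = 4  # matches the batch size used in the table above
latents = jax.device_put(jnp.ones((batch, 64, 64, 16)), batch_sharding)
scale = jax.device_put(jnp.asarray(0.5, dtype=jnp.bfloat16), replicated)

out = denoise_step(latents, scale)
print(out.shape)  # (4, 64, 64, 16)
```

An FSDP layout would differ only in the parameter specs: instead of `P()`, weight arrays get a `PartitionSpec` that names a mesh axis on one of their dimensions, trading the extra all-gather traffic (visible in the Flux-schnell FSDP row above) for lower per-device memory.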
We'd appreciate it if you could give us feedback on these results and on our overall approach.