
[Question] Why does DiT-XL/2 take 119 GFLOPs to generate 256x256 images?

void-main opened this issue 11 months ago

Hi guys, I wonder why it takes 119 GFLOPs for DiT-XL/2 to generate 256x256 images. According to my calculation, it should be about 228 GFLOPs; can anyone please kindly point out where I went wrong? Thanks.

Let's consider only the left part of the DiT block and ignore the MLP part (since it only adds 2 * params(MLP) FLOPs). Here's the math:

  • just like other transformers, the parameter count of the DiT transformer should be layers * 12 * hidden_size * hidden_size = 28 * 12 * 1152 * 1152 = 445,906,944, which is about 445M
  • for 256x256 images, the latent tensor after the VAE should be [32, 32, 4], and after patchify (p=2) the tensor becomes [32/2, 32/2, 1152] = [16, 16, 1152], i.e. 256 tokens with hidden_size=1152
  • as with other transformers, the total FLOPs should be roughly seq_len * 2 * params, where seq_len = 256, params = 445M, and the 2 accounts for the multiply and the add in each matmul; that gives 256 * 2 * 445,906,944 ≈ 228 GFLOPs instead of 119 GFLOPs.

And this result (228 GFLOPs) is roughly 2x the reported 119 GFLOPs. May I ask what I am missing? Thanks again for any help!
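For anyone who wants to reproduce the arithmetic, here is a minimal Python sketch of the estimate above. The 12 * hidden_size^2 parameter rule and the seq_len * 2 * params FLOP rule are the rough approximations from this post, not the exact counting code used by the paper:

```python
# Back-of-the-envelope estimate for DiT-XL/2 at 256x256, following the
# approximations in this post (not the exact counting code).

layers = 28          # DiT-XL depth
hidden_size = 1152   # DiT-XL width
latent_size = 32     # 256x256 image -> 32x32x4 latent after the VAE
patch_size = 2       # the "/2" in XL/2

# Standard transformer estimate: 12 * d^2 parameters per block
# (4*d^2 for attention QKV + output projection, 8*d^2 for the 4x MLP).
params = layers * 12 * hidden_size ** 2
seq_len = (latent_size // patch_size) ** 2   # 256 tokens

# seq_len * 2 * params, counting a multiply-add pair as 2 FLOPs
flops = seq_len * 2 * params

print(f"params ~ {params / 1e6:.0f}M")       # ~446M
print(f"tokens = {seq_len}")                 # 256
print(f"FLOPs  ~ {flops / 1e9:.0f}G")        # ~228G, about 2x the reported 119G
```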

void-main avatar Feb 26 '24 13:02 void-main

@wpeebles @s9xie could you please kindly take a look, thank you very much!

void-main avatar Feb 26 '24 13:02 void-main

Issue #14 may help you. @void-main

ictzyqq avatar Feb 27 '24 07:02 ictzyqq

Thanks @ictzyqq, that indeed answered my question.

In short, the paper counts FLOPs as MACs, so I should drop the factor of 2 (multiply and add) from the seq_len * 2 * params formula.
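For completeness, here is the same rough estimate with the factor of 2 removed, i.e. counting multiply-accumulates (MACs). Because it still uses the approximate 12 * hidden_size^2 rule from above, it lands near, rather than exactly at, the reported 119 GFLOPs; the remainder plausibly comes from terms the rough estimate ignores, such as the attention score matmuls and the conditioning/final layers:

```python
# Same rough estimate, but counting multiply-accumulates (MACs) instead of
# FLOPs, i.e. dropping the factor of 2 for the separate multiply and add.

params = 28 * 12 * 1152 ** 2        # ~446M, rough DiT-XL parameter estimate
seq_len = (32 // 2) ** 2            # 256 latent tokens for a 256x256 image

macs = seq_len * params             # seq_len * params, no factor of 2
print(f"~{macs / 1e9:.0f} GMACs")   # ~114G, in the ballpark of the reported 119
```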

void-main avatar Feb 27 '24 08:02 void-main