
layernorm/rmsnorm is slow

Open pillow37 opened this issue 1 year ago • 1 comments

Hi, I use LayerNorm and RMSNorm in my training pipeline on an A100, and the PyTorch profiler showed that these functions were quite slow. E.g., I measured via time.time() just for the RMSNorm:

  • average time using from flash_attn.ops.triton.layer_norm import RMSNorm: 0.00023508071899414062 s
  • average time using a naive RMSNorm implementation via pytorch: 6.4849853515625e-05 s

The profiler also indicated that there was work done on the CPU, which was somewhat confusing to me.

Do you know what the issue could be?
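For context, a "naive RMSNorm implementation via pytorch" in the sense described above might look like the following sketch (the class name and hyperparameters are illustrative assumptions, not the questioner's actual code):

```python
import torch

class NaiveRMSNorm(torch.nn.Module):
    """Plain-PyTorch RMSNorm: x * weight / sqrt(mean(x^2) + eps)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Root-mean-square over the last dimension, then scale.
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * rms * self.weight
```

Note that timing such an op with time.time() on a GPU is misleading: CUDA kernels launch asynchronously, so the wall-clock measurement can return before the kernel has actually run.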

pillow37 avatar Jul 24 '24 19:07 pillow37

https://pytorch.org/tutorials/recipes/recipes/benchmark.html Please don't use time.time()

tridao avatar Jul 24 '24 20:07 tridao