flash-attention
layernorm/rmsnorm is slow
Hi, I use LayerNorm and RMSNorm in my training pipeline on an A100, and the PyTorch profiler showed that these functions were quite slow. For example, I measured the RMSNorm alone with time.time():
- average time using `from flash_attn.ops.triton.layer_norm import RMSNorm`: 0.00023508071899414062 s (~0.235 ms)
- average time using a naive RMSNorm implementation in plain PyTorch: 6.4849853515625e-05 s (~0.065 ms)
The profiler also indicated that some of the work was happening on the CPU, which was somewhat confusing to me.
Do you know what the issue could be?
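For reference, this is roughly what I mean by a naive RMSNorm in plain PyTorch (a sketch; the exact eps and dtype handling in my pipeline may differ):

```python
import torch
import torch.nn as nn

class NaiveRMSNorm(nn.Module):
    """Plain-PyTorch RMSNorm: scale x by the reciprocal of its root mean square."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Accumulate in fp32 for stability, then cast back to the input dtype.
        dtype = x.dtype
        x = x.float()
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x * rms).to(dtype) * self.weight
```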
Please don't use time.time(). CUDA kernels launch asynchronously, so wall-clock timing without synchronization mostly measures the CPU-side launch overhead, not the kernel itself. Use the PyTorch benchmark utilities instead: https://pytorch.org/tutorials/recipes/recipes/benchmark.html
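A minimal sketch of that recipe applied here (shapes are made up; `rmsnorm` is a stand-in for whichever implementation you are timing). `torch.utils.benchmark.Timer` handles warmup and CUDA synchronization for you:

```python
import torch
from torch.utils import benchmark

# Hypothetical shapes; adjust to your model.
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(4, 512, 1024, device=device)
weight = torch.ones(1024, device=device)

def rmsnorm(x, weight, eps=1e-6):
    # Stand-in for the op under test (flash-attn Triton kernel or naive PyTorch).
    rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return x * rms * weight

# Timer synchronizes the GPU around each measurement, unlike a bare
# time.time() wrapped around an asynchronous kernel launch.
t = benchmark.Timer(
    stmt="rmsnorm(x, weight)",
    globals={"rmsnorm": rmsnorm, "x": x, "weight": weight},
)
print(t.timeit(100))
```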