candle
inefficient implementation of gelu for fp16
I'm running the dinov2 example on CPU on a Cortex-A76 machine, except that I've converted it to fp16. Looking at its perf profile, a large share of the time is spent in scalar numeric operations.
Tracking this down, I found that this was due to the gelu implementation:
#[inline(always)]
fn f16(v: f16) -> f16 {
    f16::from_f32_const(0.5)
        * v
        * (f16::ONE
            + f16::tanh(
                (f16::from_f32_const(2.0) / f16::PI).sqrt()
                    * v
                    * (f16::ONE + f16::from_f32_const(0.044715) * v * v),
            ))
}
It computes sqrt(2/pi) on every call, and across the whole DinoV2 run this accounts for 8.8% of cycles. I can't say whether the f32 or f64 impls behave differently, but given the challenges of float math I doubt that they do.
Interesting, I would have hoped that the compiler would use constant propagation so as not to recompute this on every call, but rustc has quite a few shortcomings when it comes to float const functions that may prevent this from happening. Would you mind replacing this with the hardcoded value and checking whether it has an impact? (If it does, happy to get a PR for the change.)
fn crit(c: &mut Criterion) {
    let arg = 20.0;
    c.bench_function("gelu_f16", |b| {
        b.iter(|| Gelu::f16(black_box(f16::from_f32(arg))))
    });
}
leads to:
current implementation: 2.2172ns
below implementation: 1.6855ns
on my M1 mac.
I would expect that constant propagation is trickier on f16 as the compiler has to inline more to understand that it's safe.
This is the solution I used.
f16::from_f32_const((2.0 / PI).sqrt())
which does work with constant propagation. I can confirm that constant propagation does work on f32 and f64, validated on godbolt.
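For anyone who wants to poke at the idea without pulling in the half crate, here is a rough std-only sketch in f32. It is not the code from the PR: `gelu_recompute`, `gelu_hoisted` and `SQRT_TWO_OVER_PI` are made-up names for illustration, and f32 stands in for f16 purely so the snippet is self-contained. It only shows that hoisting sqrt(2/pi) into a constant leaves the results unchanged up to rounding; it says nothing about timings.

```rust
use std::f32::consts::{FRAC_1_SQRT_2, FRAC_2_SQRT_PI, PI};

/// sqrt(2/pi), assembled from std constants: (2/sqrt(pi)) * (1/sqrt(2)).
/// Hypothetical name, used only in this sketch.
const SQRT_TWO_OVER_PI: f32 = FRAC_2_SQRT_PI * FRAC_1_SQRT_2;

/// Tanh-approximation GELU that evaluates sqrt(2/pi) inside the function
/// body, mirroring the structure of the f16 version quoted above.
fn gelu_recompute(v: f32) -> f32 {
    0.5 * v * (1.0 + f32::tanh((2.0 / PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
}

/// Same formula, but with sqrt(2/pi) taken from a precomputed constant.
fn gelu_hoisted(v: f32) -> f32 {
    0.5 * v * (1.0 + f32::tanh(SQRT_TWO_OVER_PI * v * (1.0 + 0.044715 * v * v)))
}

fn main() {
    // Print both variants side by side for a few sample inputs.
    for &x in &[-3.0f32, -1.0, 0.0, 0.5, 2.0] {
        let a = gelu_recompute(x);
        let b = gelu_hoisted(x);
        println!("x = {x:>5}: recompute = {a:.6}, hoisted = {b:.6}");
    }
}
```

In f32 the compiler may well fold both versions to the same code (as noted above for f32 and f64), so the hoisted form mainly matters for types like f16 where the folding does not happen on its own.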
Neat, would be great if you can make a PR for this (ideally with a comment explaining why it's useful). If you don't have the time I could take a stab at it.
Made a PR for this, #2008