num-complex
Enable fused multiply-subtract in complex mul_add
Fused multiply-add was added for complex numbers in #37, based on the Julia implementation:
    muladd(z::Complex, w::Complex, x::Complex) =
        Complex(muladd(real(z), real(w), real(x)) - imag(z)*imag(w), # TODO: use mulsub given #15985
                muladd(real(z), imag(w), muladd(imag(z), real(w), imag(x))))
The Julia code was later changed in https://github.com/JuliaLang/julia/pull/36140 to:
    muladd(z::Complex, w::Complex, x::Complex) =
        Complex(muladd(real(z), real(w), -muladd(imag(z), imag(w), -real(x))),
                muladd(real(z), imag(w), muladd(imag(z), real(w), imag(x))))
The reason for the change was that LLVM can fold the negations into fused multiply-subtract instructions.
I tested this in Rust with the function:
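    use std::ops::Neg;

    use num_complex::Complex;
    use num_traits::{MulAdd, Num};

    fn mul_add_new<T: Clone + Num + MulAdd<Output = T> + Neg<Output = T>>(
        this: Complex<T>,
        other: Complex<T>,
        add: Complex<T>,
    ) -> Complex<T> {
        // re = this.re * other.re - (this.im * other.im - add.re)
        let re = this.re.clone().mul_add(
            other.re.clone(),
            -this.im.clone().mul_add(other.im.clone(), -add.re),
        );
        // im = this.re * other.im + (this.im * other.re + add.im)
        let im = this.re.mul_add(other.im, this.im.mul_add(other.re, add.im));
        Complex::new(re, im)
    }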
When compiling with fma enabled in release mode, it is optimized in the same way to use fused multiply-subtract.
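To make the arithmetic of this formulation easy to check by hand, here is a minimal sketch specialized to f64 and using only the standard library's `f64::mul_add`. The helper name `mul_add_f64` and the `(re, im)` tuple representation are assumptions made for illustration; they are not part of num-complex.

```rust
/// Hypothetical helper: computes z * w + x for complex numbers given as
/// (re, im) pairs, using the nested mul_add formulation from this issue.
fn mul_add_f64(z: (f64, f64), w: (f64, f64), x: (f64, f64)) -> (f64, f64) {
    // re = z.re * w.re - (z.im * w.im - x.re)
    // The two negations are what LLVM is expected to fold into
    // fused multiply-subtract when fma is enabled.
    let re = z.0.mul_add(w.0, -z.1.mul_add(w.1, -x.0));
    // im = z.re * w.im + (z.im * w.re + x.im)
    let im = z.0.mul_add(w.1, z.1.mul_add(w.0, x.1));
    (re, im)
}
```

For example, `mul_add_f64((1.0, 2.0), (3.0, 4.0), (5.0, 6.0))` evaluates (1 + 2i)(3 + 4i) + (5 + 6i), whose real part is 3 - 8 + 5 and whose imaginary part is 4 + 6 + 6.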
A potential drawback is that this requires an additional Neg bound. Replacing the negation with a subtraction from zero is not an option, because the optimization then no longer applies.
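One likely explanation, which is an assumption on my part and not something verified against the generated code, is that subtraction from zero is not equivalent to negation under IEEE 754: the two differ for a positive zero input, so the compiler cannot treat `0 - x` as a plain sign flip without relaxed floating-point flags. The sketch below shows the difference with two hypothetical helpers:

```rust
/// Hypothetical helper: "negation" written as subtraction from zero.
fn neg_via_sub(x: f64) -> f64 {
    0.0 - x
}

/// Hypothetical helper: true negation, which only flips the sign bit.
fn neg_direct(x: f64) -> f64 {
    -x
}
```

With an input of `0.0`, `neg_direct` yields negative zero while `neg_via_sub` yields positive zero, which can be observed with `f64::is_sign_negative`.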