BLAS rule for nrm2 returns NaN at [0.0]
https://github.com/EnzymeAD/Enzyme/blob/f513bdfc8c1fc09dea3f905c7ddcb9d9c91a45f7/enzyme/Enzyme/BlasDerivatives.td#L237-L243
Compare with the ChainRules rule:
https://github.com/JuliaDiff/ChainRules.jl/blob/0923b1e7281b8d1418205a8798cc653c5cbd8552/src/rulesets/LinearAlgebra/blas.jl#L38-L47
And from Julia:
using Enzyme
using LinearAlgebra
x = [5.0]
dx = [1.0]
autodiff(Forward, LinearAlgebra.BLAS.nrm2, Duplicated(x, dx)) # 0.447214 (expected 1.0)
autodiff(Forward, LinearAlgebra.generic_norm2, Duplicated(x, dx)) # 1.0
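For reference, the forward-mode tangent of the 2-norm follows directly from the chain rule. For the one-element example x = [5.0], dx = [1.0] it evaluates to 1.0, matching generic_norm2. The BLAS result 0.447214 is numerically 1/√5, i.e. off from the expected value by a factor of √5; reading this as a stray square root is only an inference from the numbers.

```latex
\frac{d}{dt}\Big|_{t=0} \|x + t\,\dot{x}\|_2 = \frac{x^\top \dot{x}}{\|x\|_2}

x = [5],\ \dot{x} = [1]: \quad \frac{5 \cdot 1}{5} = 1,
\qquad \frac{1}{\sqrt{5}} \approx 0.447214
```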
Notably, the zero handling is also different.
dx = [1.0]
x = [0.0]
autodiff(Forward, LinearAlgebra.BLAS.nrm2, Duplicated(x, dx)) # NaN
autodiff(Forward, LinearAlgebra.generic_norm2, Duplicated(x, dx)) # 1.0 -- Same as ChainRules
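Why NaN is what the unguarded formula produces: the 2-norm is not differentiable at x = 0, and the usual tangent expression becomes 0/0, which IEEE floating point evaluates to NaN. The one-sided directional derivative at 0 does exist and equals the norm of the tangent, which for dx = [1.0] is 1.0, the same value generic_norm2 returns here. This thread does not establish that generic_norm2 or ChainRules arrive at 1.0 by this route.

```latex
\frac{x^\top \dot{x}}{\|x\|_2}\Big|_{x=0} = \frac{0}{0} \;\to\; \mathrm{NaN}

\lim_{t \to 0^+} \frac{\|0 + t\,\dot{x}\|_2 - \|0\|_2}{t} = \|\dot{x}\|_2
```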
I also used
function mynorm(X)
    # naive Euclidean norm: square root of the sum of squares
    acc = zero(eltype(X))
    for x in X
        acc += x^2
    end
    return sqrt(acc)
end
as a sanity check, and it agrees with generic_norm2 except for the x = [0.0] case.
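Differentiating mynorm by the chain rule shows why x = [0.0] is the problematic point: the tangent of the final sqrt has √acc in its denominator, so at x = 0 the unguarded expression is 0/0. Any finite answer there comes from a convention (for example, a special case in the AD tool's sqrt rule), not from the chain rule alone; what Enzyme's sqrt rule does at zero is not stated in this thread.

```latex
\dot{\mathrm{acc}} = \sum_i 2\,x_i\,\dot{x}_i, \qquad
\frac{d}{dt}\sqrt{\mathrm{acc}} = \frac{\dot{\mathrm{acc}}}{2\sqrt{\mathrm{acc}}}

x = [0],\ \dot{x} = [1]: \quad \dot{\mathrm{acc}} = 0, \quad
\frac{0}{2\sqrt{0}} = \frac{0}{0}
```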
Oh, that's a silly bug; we should remove the bsqrt, and then it should presumably work.
The last remaining issue is that the case norm(x) == 0.0 currently leads to NaN:
dx = [1.0]
x = [0.0]
autodiff(Forward, LinearAlgebra.BLAS.nrm2, Duplicated(x, dx)) # NaN
autodiff(Forward, LinearAlgebra.generic_norm2, Duplicated(x, dx)) # 1.0 -- Same as ChainRules
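A minimal sketch of a zero-guarded forward rule, written as plain Julia so it runs without Enzyme. `nrm2_tangent` is a hypothetical name, not an Enzyme, ChainRules, or LinearAlgebra API, and the value returned at norm(x) == 0 (the one-sided directional derivative, the norm of dx) is an assumed convention, chosen because it reproduces the 1.0 from generic_norm2 for dx = [1.0]. It is not taken from either library's source.

```julia
# Hypothetical helper (not part of Enzyme, ChainRules, or LinearAlgebra):
# forward-mode tangent of the Euclidean norm of a real vector.
function nrm2_tangent(x::AbstractVector{<:Real}, dx::AbstractVector{<:Real})
    n = sqrt(sum(abs2, x))   # ‖x‖₂ without the overflow scaling BLAS nrm2 uses
    if iszero(n)
        # Assumed convention at x = 0: one-sided directional derivative ‖dx‖₂
        return sqrt(sum(abs2, dx))
    end
    # Regular case: (xᵀ dx) / ‖x‖₂
    return sum(xi * dxi for (xi, dxi) in zip(x, dx)) / n
end

nrm2_tangent([5.0], [1.0])            # 1.0
nrm2_tangent([0.0], [1.0])            # 1.0
nrm2_tangent([3.0, 4.0], [1.0, 0.0])  # 0.6
```

This convention is not linear in dx (nrm2_tangent([0.0], [-1.0]) is also 1.0), whereas returning zero at x = 0, a valid subgradient choice, keeps the tangent linear; which one the rule should use is a design decision.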