Inconsistent gradients of sum(abs, _) between GPU and CPU (NaNs for zero input only on GPU)
Bug description
I've experienced the following inconsistency between GPU and CPU gradient computation for sum(abs, _).
julia> using Zygote, CUDA
julia> rl, cplx = [0.0f0], [0.0f0 + 0.0f0im]
(Float32[0.0], ComplexF32[0.0f0 + 0.0f0im])
julia> l1(x) = sum(abs, x)
l1 (generic function with 1 method)
julia> Zygote.gradient(l1, rl)
(Float32[0.0],)
julia> Zygote.gradient(l1, cplx)
(ComplexF32[0.0f0 + 0.0f0im],)
julia> Zygote.gradient(l1, cu(rl))
(Float32[1.0],)
julia> Zygote.gradient(l1, cu(cplx))
(ComplexF32[NaN32 + NaN32*im],)
The last one is particularly problematic, as it introduces NaN values into the gradient that may be hard to trace back in a more complex model.
Slack discussion
On Slack, @mcabbott explained to me the most likely cause for this:
- on GPU, in the backward pass Zygote converts `sum(abs, x)` to `sum(abs.(x))`, and the broadcasting part is differentiated via ForwardDiff
- ForwardDiff is responsible for the different values of the gradient:
julia> using ForwardDiff

julia> abs(ForwardDiff.Dual(0,1))
Dual{Nothing}(0,1)
julia> abs(ForwardDiff.Dual(0,1) + 0im)
Dual{Nothing}(0.0,NaN)
- even though DiffRules has a rule for `abs` (used for real inputs), for complex inputs the computation passes through `hypot`, and the DiffRules rule for the derivative of `hypot` at `(0, 0)` gives `NaN` (see the snippet after this list)
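For intuition (this is my paraphrase of the rule, not a quote of the DiffRules source): the partial derivative of `hypot(x, y)` with respect to `x` has the form `x / hypot(x, y)`, so at the origin it evaluates a 0/0:

julia> 0.0f0 / hypot(0.0f0, 0.0f0)  # shape of the hypot partial, evaluated at (0, 0)
NaN32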
Not sure what the best fix is here. If DiffRules is open to it, maybe the easiest option is to fix their `hypot` derivative rule?
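In the meantime, a possible user-level stopgap (just a sketch, untested on GPU; `safe_abs` is a name I made up here, and it hard-codes the zero-subgradient convention the CPU path uses) would be to guard the zero case so the broadcasted ForwardDiff pass never evaluates the `hypot` partial at `(0, 0)`:

using Zygote, CUDA

# Hypothetical helper: return a hard zero at the origin instead of calling abs,
# so the dual-number path never hits the 0/0 in the hypot partial.
safe_abs(x) = iszero(x) ? zero(real(x)) : abs(x)

l1_safe(x) = sum(safe_abs, x)

# Hoped-for behaviour (untested): Zygote.gradient(l1_safe, cu(cplx)) returns
# (ComplexF32[0.0f0 + 0.0f0im],), matching the CPU result above.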
Version info
I'm on Julia 1.10.5, in a fresh environment with
(jl_FHvUua) pkg> st
Status `/tmp/jl_FHvUua/Project.toml`
[052768ef] CUDA v5.5.2
[e88e6eb3] Zygote v0.6.71
Any chance that https://github.com/JuliaDiff/ForwardDiff.jl/pull/669 solves this?
Somehow it doesn't... Unless I messed something up, I checked out https://github.com/JuliaDiff/ForwardDiff.jl/pull/669 (manually changing the ForwardDiff version number to 0.10) and still get
julia> Zygote.gradient(l1, cu(cplx))
(ComplexF32[NaN32 + NaN32*im],)
which is weird, because `hypot` itself indeed differentiates just fine:
julia> f(x) = hypot(x, 0, 0)
f (generic function with 1 method)
julia> ForwardDiff.derivative(f, 0.0)
1.0
julia> ForwardDiff.derivative(f, -0.0)
-1.0
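If it helps to narrow this down, the scalar call that (as far as I understand the broadcast machinery above) Zygote makes for a complex input can be reproduced without CUDA; the dual seeds below are my guess at what the forward pass uses, but checking this on the PR branch should show whether its change is on the relevant code path at all:

using ForwardDiff

# A complex number whose real and imaginary parts carry separate dual seeds,
# mimicking (to my understanding) what the broadcasted forward pass feeds to abs.
z = complex(ForwardDiff.Dual(0.0f0, 1.0f0, 0.0f0),
            ForwardDiff.Dual(0.0f0, 0.0f0, 1.0f0))

abs(z)  # if the partials here are still NaN, the NaN is independent of CUDA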