ChainRules.jl
Current rules for sqrt produce NaN for zero primal and (co)tangents
This only happens when the (co)tangent is 0.
julia> using ChainRules
julia> ChainRules.frule((ChainRules.ZeroTangent(), 0.0), sqrt, 0.0)
(0.0, NaN)
julia> ChainRules.rrule(sqrt, 0.0)[2](0.0)
(ChainRulesCore.NoTangent(), NaN)
I suggest we adopt the convention that the produced (co)tangent in this case should also be 0. This is supported by finite differences:
julia> using FiniteDifferences
julia> jvp(central_fdm(5, 1), sqrt, (0.0, 0.0))
0.0
julia> j′vp(central_fdm(5, 1), x -> sqrt(clamp(x, 0, Inf)), 0.0, 0.0)
(0.0,)
julia> j′vp(central_fdm(5, 1), sqrt ∘ abs, 0.0, 0.0)
(0.0,)
So instead of using @scalar_rule we would explicitly define the frule and rrule.
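A minimal sketch of what such explicit rules could look like, assuming real inputs and plain numeric (co)tangents. To avoid overwriting the rules ChainRules.jl already defines for sqrt, it uses a hypothetical stand-in function `mysqrt`; that name and the exact form of the rules are illustrative, not the actual patch.

```julia
using ChainRulesCore

# Hypothetical stand-in for sqrt, so this sketch does not overwrite
# the existing rules for Base.sqrt.
mysqrt(x::Real) = sqrt(x)

function ChainRulesCore.frule((_, Δx), ::typeof(mysqrt), x::Real)
    Ω = mysqrt(x)
    ∂Ω = Δx / (2Ω)  # NaN when both Δx and Ω are zero
    # Strong zero: a zero tangent stays zero even where the derivative is infinite.
    return Ω, ifelse(iszero(Δx), zero(∂Ω), ∂Ω)
end

function ChainRulesCore.rrule(::typeof(mysqrt), x::Real)
    Ω = mysqrt(x)
    function mysqrt_pullback(ΔΩ)
        ∂x = ΔΩ / (2Ω)  # NaN when both ΔΩ and Ω are zero
        # Same convention for the cotangent.
        return NoTangent(), ifelse(iszero(ΔΩ), zero(∂x), ∂x)
    end
    return Ω, mysqrt_pullback
end
```

With these definitions, `frule((NoTangent(), 0.0), mysqrt, 0.0)` and `rrule(mysqrt, 0.0)[2](0.0)` give a zero (co)tangent instead of NaN, while a nonzero (co)tangent at `x = 0.0` still produces Inf.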
So the proposal is to always treat zero tangent (or cotangent) as a strong zero. I think that makes sense.
The rule for x^p already treats Δp being zero strongly, here:
https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/Base/fastmath_able.jl#L171
The logic was that quite often the p is really a constant, and its zero would then turn some otherwise correct infinite derivatives with respect to x into NaNs.
x^0.5 behaves just like sqrt(x) with respect to Δx:
julia> ChainRules.frule((ChainRules.ZeroTangent(), 0.01, 0.0), ^, 0.0, 0.5)
(0.0, Inf)
julia> ChainRules.frule((ChainRules.ZeroTangent(), 0.0, 0.0), ^, 0.0, 0.5)
(0.0, NaN)
julia> ChainRules.rrule(^, 0.0, 0.5)[2](0.01)[2]
Inf
julia> ChainRules.rrule(^, 0.0, 0.5)[2](0.0)[2]
NaN
Treating zero tangents as strong could be done globally in @scalar_rule, which would mean one more ifelse in every rule. I wonder if that's expensive.
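To make the extra ifelse concrete, here is a rough sketch of the kind of multiplication a global strong-zero convention would need. The helper `strong_mul` and the function `∂sqrt` are hypothetical names for illustration only; this is not how @scalar_rule is currently implemented.

```julia
# Hypothetical helper: multiply a (co)tangent Δ by a partial derivative ∂,
# treating Δ == 0 as a strong zero, so 0 * Inf and 0 * NaN both give 0.
strong_mul(Δ::Real, ∂::Real) = ifelse(iszero(Δ), zero(Δ * ∂), Δ * ∂)

# Partial derivative of sqrt, which is Inf at x == 0.
∂sqrt(x::Real) = inv(2 * sqrt(x))

weak   = 0.0 * ∂sqrt(0.0)              # ordinary multiplication: NaN
strong = strong_mul(0.0, ∂sqrt(0.0))   # strong zero: 0.0
kept   = strong_mul(0.01, ∂sqrt(0.0))  # nonzero Δ keeps the infinite derivative: Inf
```

Because ifelse evaluates both arguments, the product is always computed and only the selection is added, so the cost per rule would be one comparison and one select. Whether that is measurable in practice would need benchmarking.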
Edit: as David points out here: https://github.com/JuliaDiff/ForwardDiff.jl/issues/547#issuecomment-1016528436 this is something very close to re-inventing ForwardDiff's nan-safe mode. That had some speed penalty; I also see more branches than seem necessary.
Edit': This also affects anything using derivatives_given_output, which is still marked experimental. I guess it might be a reason to re-think it; perhaps the multiplication ought to happen inside the function, so that functions like sqrt can handle these cases carefully.
Since this particular issue in multiple ADs has been noticed by several users over the last few months, it would be good to push a fix soon. While it might take more discussion (and benchmarking) to decide on an equivalent to a NaN-safe mode for ChainRules, I think we're in agreement that we need something like this specifically for sqrt, and we can do this now.