ChainRules.jl

Current rules for sqrt produce NaN for zero primal and (co)tangents

Open sethaxen opened this issue 3 years ago • 2 comments

This only happens when the (co)tangent is 0.

julia> using ChainRules

julia> ChainRules.frule((ChainRules.ZeroTangent(), 0.0), sqrt, 0.0)
(0.0, NaN)

julia> ChainRules.rrule(sqrt, 0.0)[2](0.0)
(ChainRulesCore.NoTangent(), NaN)

I suggest we adopt the convention that the produced (co)tangent in this case should also be 0. This is supported by finite differences:

julia> using FiniteDifferences

julia> jvp(central_fdm(5, 1), sqrt, (0.0, 0.0))
0.0

julia> j′vp(central_fdm(5, 1), x -> sqrt(clamp(x, 0, Inf)), 0.0, 0.0)
(0.0,)

julia> j′vp(central_fdm(5, 1), sqrt ∘ abs, 0.0, 0.0)
(0.0,)

So instead of using @scalar_rule we would explicitly define the frule and rrule.
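A minimal sketch of what such explicit rules might look like. It assumes real scalar inputs and plain numeric (co)tangents, and it defines the rules on a hypothetical stand-in `mysqrt` so that it does not overwrite the rules ChainRules already provides for `sqrt`. It illustrates the strong-zero convention only and is not the implementation that was merged:

```julia
using ChainRulesCore

# Hypothetical stand-in for sqrt, used only so this sketch does not
# clash with the rules ChainRules already defines for sqrt.
mysqrt(x::Real) = sqrt(x)

function ChainRulesCore.frule((_, Δx), ::typeof(mysqrt), x::Real)
    Ω = mysqrt(x)
    # Strong zero: a zero input tangent always gives a zero output tangent,
    # even when the derivative 1 / (2Ω) is infinite at x == 0.
    ∂Ω = iszero(Δx) ? zero(Ω) : Δx / (2Ω)
    return Ω, ∂Ω
end

function ChainRulesCore.rrule(::typeof(mysqrt), x::Real)
    Ω = mysqrt(x)
    function mysqrt_pullback(ΔΩ)
        # Same convention for the cotangent.
        ∂x = iszero(ΔΩ) ? zero(Ω) : ΔΩ / (2Ω)
        return NoTangent(), ∂x
    end
    return Ω, mysqrt_pullback
end

# Zero (co)tangent at x == 0 gives 0.0 rather than NaN:
frule((NoTangent(), 0.0), mysqrt, 0.0)    # tangent component is 0.0
rrule(mysqrt, 0.0)[2](0.0)                # cotangent component is 0.0
# A nonzero (co)tangent still sees the infinite derivative:
frule((NoTangent(), 0.01), mysqrt, 0.0)   # tangent component is Inf
```

The branch is on the (co)tangent, not on the primal, so derivatives at `x > 0` are unchanged.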

sethaxen, Jan 19 '22 11:01

So the proposal is to always treat zero tangent (or cotangent) as a strong zero. I think that makes sense.

The rule for x^p already treats a zero Δp strongly, here: https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/Base/fastmath_able.jl#L171. The logic was that quite often p is really a constant, and its zero tangent would then turn some otherwise correct infinite derivatives with respect to x into NaNs.

x^0.5 behaves just like sqrt(x) with respect to Δx:

julia> ChainRules.frule((ChainRules.ZeroTangent(), 0.01, 0.0), ^, 0.0, 0.5)
(0.0, Inf)

julia> ChainRules.frule((ChainRules.ZeroTangent(), 0.0, 0.0), ^, 0.0, 0.5)
(0.0, NaN)

julia> ChainRules.rrule(^, 0.0, 0.5)[2](0.01)[2]
Inf

julia> ChainRules.rrule(^, 0.0, 0.5)[2](0.0)[2]
NaN

Treating a zero tangent as strong could be done globally in @scalar_rule, which would mean one more ifelse in every rule. I wonder if that's expensive.
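To make the cost concrete, here is a rough sketch of the extra `ifelse` such a global change would add to each partial-times-tangent product. The helper name `strong_mul` is hypothetical and is not part of ChainRules or of `@scalar_rule`:

```julia
# Hypothetical helper: multiply a partial derivative ∂ by a (co)tangent Δ,
# treating a zero Δ as a strong zero so that Inf * 0 gives 0 instead of NaN.
# ifelse evaluates both arguments, so no real branch is needed for scalars.
strong_mul(∂, Δ) = ifelse(iszero(Δ), zero(∂ * Δ), ∂ * Δ)

strong_mul(Inf, 0.0)    # 0.0, where plain Inf * 0.0 is NaN
strong_mul(Inf, 0.01)   # Inf, the infinite derivative is preserved
strong_mul(0.5, 2.0)    # 1.0, ordinary products are unchanged
```

Whether this one extra comparison and select per product is measurable in practice would need benchmarking.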

Edit: as David points out here: https://github.com/JuliaDiff/ForwardDiff.jl/issues/547#issuecomment-1016528436, this is something very close to re-inventing ForwardDiff's nan-safe mode. That had some speed penalty; I also see more branches than seem necessary.

Edit': This also affects anything using derivatives_given_output, which is still marked experimental. I guess it might be a reason to re-think it; perhaps the multiplication ought to happen inside the function, so that (for functions like sqrt) it can do careful things.

mcabbott, Jan 19 '22 13:01

Since this particular issue in multiple ADs has been noticed by several users over the last few months, it would be good to push a fix soon. While it might take more discussion (and benchmarking) to decide on an equivalent to a NaN-safe mode for ChainRules, I think we're in agreement that we need something like this specifically for sqrt, and we can do this now.

sethaxen, Mar 05 '22 19:03