Avoid NaN-propagation in scalar rules
As proposed in https://github.com/JuliaDiff/ChainRules.jl/issues/576#issuecomment-1016443581, this PR makes zero (co)tangents behave as strong zeros in rules defined by @scalar_rule, so that regardless of the value of the partial, if an input (co)tangent is zero, then its product with the partial is also zero.
Before:
```julia
julia> frule((NoTangent(), 0.0), sqrt, 0.0)
(0.0, NaN)
```
This PR:
```julia
julia> frule((NoTangent(), 0.0), sqrt, 0.0)
(0.0, 0.0)
```
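The strong-zero semantics can be illustrated in plain Julia. The helper `strong_zero_mul` below is hypothetical (it is not the code this PR adds to `@scalar_rule`); it only shows the intended behavior: when the incoming (co)tangent is zero, the product is zero even if the partial is `Inf` or `NaN`, as happens for the derivative of `sqrt` at `0.0`.

```julia
# Hypothetical helper, not ChainRulesCore's implementation: multiply a
# (co)tangent ẋ by a partial ∂, treating a zero ẋ as a strong zero.
function strong_zero_mul(ẋ, ∂)
    y = ẋ * ∂
    # zero(y) keeps the promoted result type, so a NaN from 0 * Inf becomes 0.0
    return ifelse(iszero(ẋ), zero(y), y)
end

# Partial derivative of sqrt; it is Inf at 0.0
d_sqrt(x) = inv(2 * sqrt(x))

0.0 * d_sqrt(0.0)                  # ordinary multiplication: NaN
strong_zero_mul(0.0, d_sqrt(0.0))  # strong zero: 0.0
strong_zero_mul(2.0, d_sqrt(4.0))  # nonzero tangents are unaffected: 0.5
```

`ifelse` evaluates both arguments and selects one without branching, which keeps the sketch cheap in tight loops; it makes no claim about how `@scalar_rule` itself is written.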
This feature is similar to ForwardDiff's NaN-safe mode, which the docs note is 5-10% slower in their benchmarks. However, this benchmark doesn't indicate a consistent performance decrease:
```julia
using ChainRules, ChainRulesCore, BenchmarkTools, Random
using Statistics  # provides `mean` for the benchmark results

# At the origin Ω = 0, so z = Inf and every partial z * a is Inf * 0 = NaN
myhypot(a, b, c) = hypot(a, b, c)
@scalar_rule myhypot(a::Real, b::Real, c::Real) @setup(z = inv(Ω)) (z * a, z * b, z * c)

x = rand(MersenneTwister(42), 1000)

# Config supporting both modes; its "AD" just dispatches back to rrule
struct MyRuleConfig <: RuleConfig{Union{HasForwardsMode,HasReverseMode}} end
function ChainRulesCore.rrule_via_ad(cfg::MyRuleConfig, f, args...; kwargs...)
    return rrule(cfg, f, args...; kwargs...)
end

# Push the tangent ẋ forward through f at x
jvp(f, x, ẋ) = frule(MyRuleConfig(), (NoTangent(), ẋ), f, x)

# Pull the cotangent ȳ back through f at x...; drop the cotangent for f itself
function j′vp(f, ȳ, x...)
    y, back = rrule(MyRuleConfig(), f, x...)
    return map(unthunk, Base.tail(back(ȳ)))
end

# `$(Ref(v))[]` interpolates v so the benchmark is not constant-folded
suite = BenchmarkGroup()
suite["jvp"] = BenchmarkGroup()
suite["jvp"]["inv"] = @benchmarkable jvp(inv, $(Ref(0.0))[], $(Ref(0.0))[])
suite["jvp"]["sqrt"] = @benchmarkable jvp(sqrt, $(Ref(0.0))[], $(Ref(0.0))[])
suite["jvp"]["cbrt"] = @benchmarkable jvp(cbrt, $(Ref(0.0))[], $(Ref(0.0))[])
suite["jvp"]["log"] = @benchmarkable jvp(log, $(Ref(0.0))[], $(Ref(0.0))[])
suite["jvp"]["log2"] = @benchmarkable jvp(log2, $(Ref(0.0))[], $(Ref(0.0))[])
suite["jvp"]["log10"] = @benchmarkable jvp(log10, $(Ref(0.0))[], $(Ref(0.0))[])
suite["jvp"]["log1p"] = @benchmarkable jvp(log1p, $(Ref(-1.0))[], $(Ref(0.0))[])
suite["j′vp"] = BenchmarkGroup()
suite["j′vp"]["inv"] = @benchmarkable j′vp(sum, $(Ref(0.0))[], inv, $x)
suite["j′vp"]["sqrt"] = @benchmarkable j′vp(sum, $(Ref(0.0))[], sqrt, $x)
suite["j′vp"]["cbrt"] = @benchmarkable j′vp(sum, $(Ref(0.0))[], cbrt, $x)
suite["j′vp"]["log"] = @benchmarkable j′vp(sum, $(Ref(0.0))[], log, $x)
suite["j′vp"]["log2"] = @benchmarkable j′vp(sum, $(Ref(0.0))[], log2, $x)
suite["j′vp"]["log10"] = @benchmarkable j′vp(sum, $(Ref(0.0))[], log10, $x)
suite["j′vp"]["log1p"] = @benchmarkable j′vp(sum, $(Ref(-1.0))[], log1p, $x)
suite["j′vp"]["myhypot"] =
    @benchmarkable j′vp(myhypot, $(Ref(0.0))[], $(Ref(0.0))[], $(Ref(0.0))[], $(Ref(0.0))[])

tune!(suite)
results = run(suite);
mean(results)
```
Before:
```
2-element BenchmarkTools.BenchmarkGroup:
  tags: []
  "jvp" => 7-element BenchmarkTools.BenchmarkGroup:
      tags: []
      "cbrt" => TrialEstimate(2.442 ns)
      "log" => TrialEstimate(2.221 ns)
      "sqrt" => TrialEstimate(2.361 ns)
      "log2" => TrialEstimate(2.453 ns)
      "log1p" => TrialEstimate(2.572 ns)
      "log10" => TrialEstimate(2.675 ns)
      "inv" => TrialEstimate(1.243 ns)
  "j′vp" => 8-element BenchmarkTools.BenchmarkGroup:
      tags: []
      "cbrt" => TrialEstimate(11.483 μs)
      "log" => TrialEstimate(9.677 μs)
      "sqrt" => TrialEstimate(6.867 μs)
      "log2" => TrialEstimate(9.708 μs)
      "log1p" => TrialEstimate(12.228 μs)
      "log10" => TrialEstimate(11.265 μs)
      "myhypot" => TrialEstimate(5.421 ns)
      "inv" => TrialEstimate(5.240 μs)
```
This PR:
```
2-element BenchmarkTools.BenchmarkGroup:
  tags: []
  "jvp" => 7-element BenchmarkTools.BenchmarkGroup:
      tags: []
      "cbrt" => TrialEstimate(2.574 ns)
      "log" => TrialEstimate(2.683 ns)
      "sqrt" => TrialEstimate(1.477 ns)
      "log2" => TrialEstimate(2.475 ns)
      "log1p" => TrialEstimate(2.933 ns)
      "log10" => TrialEstimate(2.964 ns)
      "inv" => TrialEstimate(1.245 ns)
  "j′vp" => 8-element BenchmarkTools.BenchmarkGroup:
      tags: []
      "cbrt" => TrialEstimate(11.344 μs)
      "log" => TrialEstimate(9.320 μs)
      "sqrt" => TrialEstimate(6.841 μs)
      "log2" => TrialEstimate(9.516 μs)
      "log1p" => TrialEstimate(12.917 μs)
      "log10" => TrialEstimate(9.932 μs)
      "myhypot" => TrialEstimate(6.075 ns)
      "inv" => TrialEstimate(4.947 μs)
```
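To put a number on "no consistent performance decrease", one can take per-benchmark ratios of the mean times above (this PR / before) and their geometric mean. The snippet only redoes that arithmetic on the figures copied from the two runs; `ratios` and `geomean_ratio` are helpers introduced for illustration.

```julia
# Mean times copied from the two runs (ns for jvp; μs for j′vp except myhypot,
# which is in ns). Only per-key ratios are used, so the units cancel.
jvp_before = Dict("cbrt" => 2.442, "log" => 2.221, "sqrt" => 2.361, "log2" => 2.453,
                  "log1p" => 2.572, "log10" => 2.675, "inv" => 1.243)
jvp_after  = Dict("cbrt" => 2.574, "log" => 2.683, "sqrt" => 1.477, "log2" => 2.475,
                  "log1p" => 2.933, "log10" => 2.964, "inv" => 1.245)
vjp_before = Dict("cbrt" => 11.483, "log" => 9.677, "sqrt" => 6.867, "log2" => 9.708,
                  "log1p" => 12.228, "log10" => 11.265, "myhypot" => 5.421, "inv" => 5.240)
vjp_after  = Dict("cbrt" => 11.344, "log" => 9.320, "sqrt" => 6.841, "log2" => 9.516,
                  "log1p" => 12.917, "log10" => 9.932, "myhypot" => 6.075, "inv" => 4.947)

# Per-benchmark ratio after/before (> 1 means this PR is slower)
ratios(before, after) = Dict(k => after[k] / before[k] for k in keys(before))

# Geometric mean of the ratios, the usual way to average relative timings
geomean_ratio(before, after) = exp(sum(log, values(ratios(before, after))) / length(before))

geomean_ratio(jvp_before, jvp_after)  # ≈ 1.00
geomean_ratio(vjp_before, vjp_after)  # ≈ 0.99
```

Individual ratios range from about 0.63 (jvp `sqrt`) to about 1.21 (jvp `log`), in both directions, which at nanosecond scales looks more like noise than a systematic slowdown.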
The current tests fail because:
- zero(::NotImplemented) throws a NotImplementedException
- zero(NoTangent()) is a ZeroTangent(), so this change breaks inferrability. This causes the ChainRules integration tests to fail for copysign and ldexp, for which one of the partials is NoTangent().
Sorry for the slow reply
I haven't looked at the code yet, but the general idea that the zero from @scalar_rule should be a strong zero is correct.
And indeed it used to be a ZeroTangent(), but we changed it because this caused type widening.
(These issues would be resolved if we had https://github.com/JuliaLang/julia/issues/38241)
> zero(::NotImplemented) throws a NotImplementedException

This feels like it should be a ZeroTangent(), as it is one of the things that is in fact safe to do even if the tangent wasn't implemented, since we don't care about it.

> zero(NoTangent()) is a ZeroTangent()

This could well be changed so that zero(NoTangent()) isa NoTangent.
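The type widening can be reproduced with toy types. `ToyNoTangent`, `ToyZeroTangent`, `ToyNoTangent2` and `toy_strong_mul` are hypothetical stand-ins, not ChainRulesCore's types: when `zero` of the "no tangent" value returns a different type (as zero(NoTangent()) currently does), the strong-zero branch makes the inferred return type a `Union`; when `zero` returns the same type (the suggestion above), inference stays concrete.

```julia
# Hypothetical stand-ins for NoTangent / ZeroTangent
struct ToyNoTangent end
struct ToyZeroTangent end
struct ToyNoTangent2 end

# A real tangent times a "no tangent" partial is "no tangent"
Base.:*(::Float64, t::ToyNoTangent) = t
Base.:*(::Float64, t::ToyNoTangent2) = t

# Mirrors zero(NoTangent()) being a ZeroTangent(): zero changes the type
Base.zero(::ToyNoTangent) = ToyZeroTangent()
# Mirrors the suggestion zero(NoTangent()) isa NoTangent: zero keeps the type
Base.zero(t::ToyNoTangent2) = t

# Strong-zero product; the branch taken depends on the runtime value of ẋ
function toy_strong_mul(ẋ::Float64, ∂)
    y = ẋ * ∂
    return iszero(ẋ) ? zero(y) : y
end

only(Base.return_types(toy_strong_mul, (Float64, ToyNoTangent)))   # a Union of the two toy types
only(Base.return_types(toy_strong_mul, (Float64, ToyNoTangent2)))  # ToyNoTangent2
```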
I just ran into the sqrt issue. Do you think you'll be able to finish and merge this PR soonish, @sethaxen? Or would you like some help here?
@devmotion thanks for the reminder; this slipped off the end of my to-do list. I'll prioritize finishing this in the next few days. I'll let you know when it's ready for a final review.
Codecov Report
Base: 93.16% // Head: 93.17% // Increases project coverage by +0.01% :tada:
Coverage data is based on head (9983b71) compared to base (9c8fcd2). Patch coverage: 100.00% of modified lines in pull request are covered.
```diff
@@            Coverage Diff             @@
##             main     #551      +/-   ##
==========================================
+ Coverage   93.16%   93.17%   +0.01%
==========================================
  Files          15       15
  Lines         907      909       +2
==========================================
+ Hits          845      847       +2
  Misses         62       62
```
| Impacted Files | Coverage Δ | |
|---|---|---|
| src/rule_definition_tools.jl | 96.85% <100.00%> (ø) | |
| src/tangent_types/abstract_zero.jl | 96.29% <100.00%> (+0.29%) | :arrow_up: |
> zero(::NotImplemented) throws a NotImplementedException
>
> This feels like it should be a ZeroTangent(), as it is one of the things that is in fact safe to do even if the tangent wasn't implemented, since we don't care about it.

It seems it was changed at some point, but maybe that should be reverted: it seems to cause type inference issues, since the return type can't be inferred if NotImplemented is involved.
Edit: I think I misread the logs; it seems the type inference issues are actually caused by zero(::NoTangent) = ZeroTangent()?
Edit 2: Just noticed that this was already discussed above (e.g. https://github.com/JuliaDiff/ChainRulesCore.jl/pull/551#issuecomment-1066185554). Sorry for the noise; I guess I should not have commented on my phone without checking the PR carefully.
@devmotion this should be ready for review now.
So it seems that in the end it is not necessary to change zero(::NotImplemented) and zero(::Type{<:NotImplemented})? Maybe, to be sure, add tests with a @scalar_rule where one partial is NotImplemented (similar to some of the rules in SpecialFunctions) and test them with ZeroTangent, NoTangent, and 0.0 (similar to the tests for suminv)?
It's not necessary in the sense that we're keeping the old behavior for NotImplemented. The right thing to do is probably @oxinabox's suggestion in https://github.com/JuliaDiff/ChainRulesCore.jl/pull/551#issuecomment-1129121382 (making zero(::NotImplemented) = ZeroTangent()), but this causes the inferred types of scalar rules with non-implemented (co)tangents to be type unions. See e.g. https://github.com/JuliaDiff/ChainRulesCore.jl/actions/runs/3255105837/jobs/5344073780#step:6:204.
Or do you think that's acceptable for these rules?
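The trade-off can be sketched with toy types (hypothetical stand-ins, not ChainRulesCore's `NotImplemented` or `ZeroTangent`). If `zero` of a not-implemented partial always throws, Julia infers that call as `Union{}`, so the zero branch drops out and the result type stays concrete; if `zero` instead returns a zero-tangent type, the inferred result becomes a `Union`.

```julia
# Hypothetical stand-ins for NotImplemented / ZeroTangent
struct ToyNotImplemented end
struct ToyNotImplemented2 end
struct ToyZero end

# A real tangent times a not-implemented partial stays not-implemented
Base.:*(::Float64, t::ToyNotImplemented) = t
Base.:*(::Float64, t::ToyNotImplemented2) = t

# Mirrors zero(::NotImplemented) throwing: this method is inferred as Union{}
Base.zero(::ToyNotImplemented) = throw(ArgumentError("zero of a not-implemented tangent"))
# Mirrors the suggestion zero(::NotImplemented) = ZeroTangent()
Base.zero(::ToyNotImplemented2) = ToyZero()

# Strong-zero product; the branch taken depends on the runtime value of ẋ
function toy_strong_mul(ẋ::Float64, ∂)
    y = ẋ * ∂
    return iszero(ẋ) ? zero(y) : y
end

only(Base.return_types(toy_strong_mul, (Float64, ToyNotImplemented)))   # ToyNotImplemented
only(Base.return_types(toy_strong_mul, (Float64, ToyNotImplemented2)))  # a Union of ToyNotImplemented2 and ToyZero
```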