ChainRulesCore.jl

Avoid NaN-propagation in scalar rules

Open • sethaxen opened this pull request 3 years ago • 21 comments

As proposed in https://github.com/JuliaDiff/ChainRules.jl/issues/576#issuecomment-1016443581, this PR makes zero (co)tangents behave as strong zeros in rules defined by @scalar_rule: if an input (co)tangent is zero, its product with the partial is also zero, regardless of the partial's value (even if it is Inf or NaN).

Before:

julia> frule((NoTangent(), 0.0), sqrt, 0.0)
(0.0, NaN)

This PR:

julia> frule((NoTangent(), 0.0), sqrt, 0.0)
(0.0, 0.0)
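The NaN in the "Before" output comes from IEEE 754 arithmetic: the partial of sqrt at 0.0 is inv(2 * sqrt(0.0)) == Inf, and 0.0 * Inf is NaN. The Base-only sketch below contrasts a plain product with a strong-zero product. The helper names (sqrt_partial, weak_pushforward, strong_pushforward) are hypothetical and are not the code this PR adds to @scalar_rule.

```julia
# Partial derivative of sqrt at x: d/dx sqrt(x) = 1 / (2 * sqrt(x)).
# Hypothetical helper, for illustration only.
sqrt_partial(x) = inv(2 * sqrt(x))

# Ordinary ("weak zero") product: 0.0 * Inf and 0.0 * NaN are both NaN.
weak_pushforward(ẋ, ∂) = ẋ * ∂

# Strong-zero product (hypothetical helper): if the incoming tangent is zero,
# return a zero of the product's type, whatever the partial is (Inf, NaN, ...).
function strong_pushforward(ẋ, ∂)
    p = ẋ * ∂
    return ifelse(iszero(ẋ), zero(p), p)
end

∂0 = sqrt_partial(0.0)                      # Inf
weak_pushforward(0.0, ∂0)                   # NaN, like the "Before" output
strong_pushforward(0.0, ∂0)                 # 0.0, like the "This PR" output
strong_pushforward(2.0, sqrt_partial(4.0))  # 0.5: nonzero tangents are unaffected
```

Using ifelse keeps the sketch free of control flow (both arguments are computed, then one is selected); a ternary would give the same results here.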

This feature is similar to ForwardDiff's NaN-safe mode, which the docs note is 5-10% slower in their benchmarks. However, this benchmark doesn't indicate a consistent performance decrease:

using ChainRules, ChainRulesCore, BenchmarkTools, Random

myhypot(a, b, c) = hypot(a, b, c)
@scalar_rule myhypot(a::Real, b::Real, c::Real) @setup(z = inv(Ω)) (z * a, z * b, z * c)
x = rand(MersenneTwister(42), 1000)

struct MyRuleConfig <: RuleConfig{Union{HasForwardsMode,HasReverseMode}} end
function ChainRulesCore.rrule_via_ad(cfg::MyRuleConfig, f, args...; kwargs...)
    return rrule(cfg, f, args...; kwargs...)
end

jvp(f, x, ẋ) = frule(MyRuleConfig(), (NoTangent(), ẋ), f, x)

function j′vp(f, ȳ, x...)
    y, back = rrule(MyRuleConfig(), f, x...)
    return map(unthunk, Base.tail(back(ȳ)))
end

suite = BenchmarkGroup()
suite["jvp"] = BenchmarkGroup()
suite["jvp"]["inv"] = @benchmarkable jvp(inv, $(Ref(0.0))[], $(Ref(0.0))[])
suite["jvp"]["sqrt"] = @benchmarkable jvp(sqrt, $(Ref(0.0))[], $(Ref(0.0))[])
suite["jvp"]["cbrt"] = @benchmarkable jvp(cbrt, $(Ref(0.0))[], $(Ref(0.0))[])
suite["jvp"]["log"] = @benchmarkable jvp(log, $(Ref(0.0))[], $(Ref(0.0))[])
suite["jvp"]["log2"] = @benchmarkable jvp(log2, $(Ref(0.0))[], $(Ref(0.0))[])
suite["jvp"]["log10"] = @benchmarkable jvp(log10, $(Ref(0.0))[], $(Ref(0.0))[])
suite["jvp"]["log1p"] = @benchmarkable jvp(log1p, $(Ref(-1.0))[], $(Ref(0.0))[])

suite["j′vp"] = BenchmarkGroup()
suite["j′vp"]["inv"] = @benchmarkable j′vp(sum, $(Ref(0.0))[], inv, $x)
suite["j′vp"]["sqrt"] = @benchmarkable j′vp(sum, $(Ref(0.0))[], sqrt, $x)
suite["j′vp"]["cbrt"] = @benchmarkable j′vp(sum, $(Ref(0.0))[], cbrt, $x)
suite["j′vp"]["log"] = @benchmarkable j′vp(sum, $(Ref(0.0))[], log, $x)
suite["j′vp"]["log2"] = @benchmarkable j′vp(sum, $(Ref(0.0))[], log2, $x)
suite["j′vp"]["log10"] = @benchmarkable j′vp(sum, $(Ref(0.0))[], log10, $x)
suite["j′vp"]["log1p"] = @benchmarkable j′vp(sum, $(Ref(-1.0))[], log1p, $x)
suite["j′vp"]["myhypot"] =
    @benchmarkable j′vp(myhypot, $(Ref(0.0))[], $(Ref(0.0))[], $(Ref(0.0))[], $(Ref(0.0))[])

tune!(suite)
results = run(suite);
mean(results)

Before:

2-element BenchmarkTools.BenchmarkGroup:
  tags: []
  "jvp" => 7-element BenchmarkTools.BenchmarkGroup:
	  tags: []
	  "cbrt" => TrialEstimate(2.442 ns)
	  "log" => TrialEstimate(2.221 ns)
	  "sqrt" => TrialEstimate(2.361 ns)
	  "log2" => TrialEstimate(2.453 ns)
	  "log1p" => TrialEstimate(2.572 ns)
	  "log10" => TrialEstimate(2.675 ns)
	  "inv" => TrialEstimate(1.243 ns)
  "j′vp" => 8-element BenchmarkTools.BenchmarkGroup:
	  tags: []
	  "cbrt" => TrialEstimate(11.483 μs)
	  "log" => TrialEstimate(9.677 μs)
	  "sqrt" => TrialEstimate(6.867 μs)
	  "log2" => TrialEstimate(9.708 μs)
	  "log1p" => TrialEstimate(12.228 μs)
	  "log10" => TrialEstimate(11.265 μs)
	  "myhypot" => TrialEstimate(5.421 ns)
	  "inv" => TrialEstimate(5.240 μs)

This PR:

2-element BenchmarkTools.BenchmarkGroup:
  tags: []
  "jvp" => 7-element BenchmarkTools.BenchmarkGroup:
	  tags: []
	  "cbrt" => TrialEstimate(2.574 ns)
	  "log" => TrialEstimate(2.683 ns)
	  "sqrt" => TrialEstimate(1.477 ns)
	  "log2" => TrialEstimate(2.475 ns)
	  "log1p" => TrialEstimate(2.933 ns)
	  "log10" => TrialEstimate(2.964 ns)
	  "inv" => TrialEstimate(1.245 ns)
  "j′vp" => 8-element BenchmarkTools.BenchmarkGroup:
	  tags: []
	  "cbrt" => TrialEstimate(11.344 μs)
	  "log" => TrialEstimate(9.320 μs)
	  "sqrt" => TrialEstimate(6.841 μs)
	  "log2" => TrialEstimate(9.516 μs)
	  "log1p" => TrialEstimate(12.917 μs)
	  "log10" => TrialEstimate(9.932 μs)
	  "myhypot" => TrialEstimate(6.075 ns)
	  "inv" => TrialEstimate(4.947 μs)
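As a rough way to compare the two tables, the sketch below divides each "This PR" mean by the matching "Before" mean, using only the numbers reported above (a ratio above 1 means slower on that run). The Dict layout and key names are illustrative choices, not part of the benchmark script.

```julia
# Mean times copied from the "Before" and "This PR" tables above.
# jvp entries are in ns; j′vp entries are in μs, except myhypot (ns).
# Each ratio divides two times for the same benchmark, so the units cancel.
before = Dict(
    "jvp/cbrt" => 2.442, "jvp/log" => 2.221, "jvp/sqrt" => 2.361,
    "jvp/log2" => 2.453, "jvp/log1p" => 2.572, "jvp/log10" => 2.675,
    "jvp/inv" => 1.243,
    "j′vp/cbrt" => 11.483, "j′vp/log" => 9.677, "j′vp/sqrt" => 6.867,
    "j′vp/log2" => 9.708, "j′vp/log1p" => 12.228, "j′vp/log10" => 11.265,
    "j′vp/myhypot" => 5.421, "j′vp/inv" => 5.240,
)
after = Dict(
    "jvp/cbrt" => 2.574, "jvp/log" => 2.683, "jvp/sqrt" => 1.477,
    "jvp/log2" => 2.475, "jvp/log1p" => 2.933, "jvp/log10" => 2.964,
    "jvp/inv" => 1.245,
    "j′vp/cbrt" => 11.344, "j′vp/log" => 9.320, "j′vp/sqrt" => 6.841,
    "j′vp/log2" => 9.516, "j′vp/log1p" => 12.917, "j′vp/log10" => 9.932,
    "j′vp/myhypot" => 6.075, "j′vp/inv" => 4.947,
)

# after / before per benchmark: > 1 means this PR was slower on that run.
ratios = Dict(k => after[k] / before[k] for k in keys(before))

for k in sort(collect(keys(ratios)))
    println(rpad(k, 14), round(ratios[k]; digits = 3))
end

# Count how many benchmarks got faster vs. slower.
nfaster = count(<(1), values(ratios))
nslower = count(>(1), values(ratios))
println("faster: ", nfaster, ", slower: ", nslower)
```

On these single runs, 7 benchmarks got faster and 8 got slower, which fits the observation that there is no consistent performance decrease.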

sethaxen, Mar 13 '22 20:03

The current tests fail because:

  1. zero(::NotImplemented) throws a NotImplementedException.
  2. zero(NoTangent()) is a ZeroTangent(), so a rule whose partial is NoTangent() can return either NoTangent() or ZeroTangent() depending on whether the input (co)tangent is zero. This breaks inferrability and causes the ChainRules integration tests to fail for copysign and ldexp, for which one of the partials is NoTangent().
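Point 2 can be reproduced with a minimal sketch using hypothetical stand-in types (FakeNoTangent and FakeZeroTangent are not the real ChainRulesCore types): when zero of a NoTangent-like partial returns a different singleton type, a strong-zero product can return either type, and inference reports a Union.

```julia
# Hypothetical stand-ins for ChainRulesCore's NoTangent and ZeroTangent;
# these are NOT the real types, just enough to reproduce the inference issue.
struct FakeNoTangent end
struct FakeZeroTangent end

# A numeric (co)tangent times a "no tangent" partial stays "no tangent".
Base.:*(::Real, ::FakeNoTangent) = FakeNoTangent()
# Mirrors the behaviour described above: zero of a NoTangent is a ZeroTangent.
Base.zero(::FakeNoTangent) = FakeZeroTangent()

# Strong-zero product (hypothetical helper): a zero (co)tangent returns zero(product).
function strong_mul_fake(ẋ, ∂)
    p = ẋ * ∂
    return ifelse(iszero(ẋ), zero(p), p)
end

strong_mul_fake(1.0, FakeNoTangent())  # FakeNoTangent()
strong_mul_fake(0.0, FakeNoTangent())  # FakeZeroTangent()

# Inference can only conclude a Union for this call signature.
rt = only(Base.return_types(strong_mul_fake, (Float64, FakeNoTangent)))
println(rt)
```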

sethaxen, Mar 13 '22 21:03

Sorry for the slow reply.

I haven't looked at the code yet, but the general idea that the zero from @scalar_rule should be a strong zero is correct. Indeed, it used to be a ZeroTangent(), but we changed that because of the type widening it caused. (These issues would be resolved if we had https://github.com/JuliaLang/julia/issues/38241.)

zero(::NotImplemented) throws a NotImplementedException

This feels like it should be a ZeroTangent(), as it is one of the things that is in fact safe to do even if the tangent wasn't implemented, since we don't care about it.

zero(NoTangent()) is a ZeroTangent()

This could well be changed so that zero(NoTangent()) isa NoTangent.
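A minimal sketch of why that change would help inference, using a hypothetical stand-in type (SelfZeroNoTangent is not ChainRulesCore's NoTangent): if zero returns the same singleton type, both results of a strong-zero product share one concrete type, so no Union is inferred.

```julia
# Hypothetical stand-in for NoTangent (NOT ChainRulesCore's type) whose zero
# is itself, as suggested above.
struct SelfZeroNoTangent end

# A numeric (co)tangent times a "no tangent" partial stays "no tangent".
Base.:*(::Real, ::SelfZeroNoTangent) = SelfZeroNoTangent()
# Proposed behaviour: zero of a NoTangent is still a NoTangent.
Base.zero(::SelfZeroNoTangent) = SelfZeroNoTangent()

# Strong-zero product (hypothetical helper): a zero (co)tangent returns zero(product).
function strong_mul_selfzero(ẋ, ∂)
    p = ẋ * ∂
    return ifelse(iszero(ẋ), zero(p), p)
end

# Both possible results now have the same concrete type, so no Union is inferred.
rt = only(Base.return_types(strong_mul_selfzero, (Float64, SelfZeroNoTangent)))
println(rt)
```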

oxinabox, May 17 '22 17:05

I just ran into the sqrt issue. Do you think you'll be able to finish and merge this PR soonish, @sethaxen? Or would you like some help here?

devmotion, Oct 13 '22 14:10

@devmotion thanks for the reminder; this slipped off the end of my to-do list. I'll prioritize finishing this in the next few days. I'll let you know when it's ready for a final review.

sethaxen, Oct 13 '22 14:10

Codecov Report

Base: 93.16% // Head: 93.17% // Increases project coverage by +0.01%

Coverage data is based on head (9983b71) compared to base (9c8fcd2). Patch coverage: 100.00% of modified lines in pull request are covered.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #551      +/-   ##
==========================================
+ Coverage   93.16%   93.17%   +0.01%     
==========================================
  Files          15       15              
  Lines         907      909       +2     
==========================================
+ Hits          845      847       +2     
  Misses         62       62              
Impacted Files | Coverage Δ
src/rule_definition_tools.jl | 96.85% <100.00%> (ø)
src/tangent_types/abstract_zero.jl | 96.29% <100.00%> (+0.29%)

codecov-commenter, Oct 13 '22 14:10

zero(::NotImplemented) throws a NotImplementedException

This feels like it should be a ZeroTangent(), as it is one of the things that is in fact safe to do even if the tangent wasn't implemented, since we don't care about it.

It seems this was changed at some point, but maybe it should be reverted. It appears to cause type inference issues, since the return type can no longer be inferred when NotImplemented is involved.

Edit: I think I misread the logs; it seems the type inference issues are actually caused by zero(::NoTangent) = ZeroTangent()?

Edit2: Just noticed that this was already discussed above (eg https://github.com/JuliaDiff/ChainRulesCore.jl/pull/551#issuecomment-1066185554). Sorry for the noise, I guess I should not have commented on my phone without checking the PR carefully.

devmotion, Oct 13 '22 15:10

@devmotion this should be ready for review now.

sethaxen, Oct 15 '22 11:10

So it seems that, in the end, it is not necessary to change zero(::NotImplemented) and zero(::Type{<:NotImplemented})? Maybe, to be sure, add tests with a @scalar_rule where one partial is NotImplemented (similar to some of the rules in SpecialFunctions) and test them with ZeroTangent, NoTangent, and 0.0 (similar to the tests for suminv)?

It's not necessary in the sense that we're keeping the old behavior for NotImplemented. The right thing to do is probably @oxinabox's suggestion in https://github.com/JuliaDiff/ChainRulesCore.jl/pull/551#issuecomment-1129121382 (making zero(::NotImplemented) = ZeroTangent()), but this causes the inferred types of scalar rules with non-implemented (co)tangents to be type unions. See e.g. https://github.com/JuliaDiff/ChainRulesCore.jl/actions/runs/3255105837/jobs/5344073780#step:6:204.

Or do you think that's acceptable for these rules?

sethaxen, Oct 15 '22 22:10