ChainRules.jl

Avoid NaN (co)tangents for sqrt(0)

sethaxen opened this pull request 3 years ago · Open · 5 comments

This PR fixes #576 by treating zero (co)tangents in sqrt as strong zeros.
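For illustration, here is the arithmetic that produces the NaN, written in plain Julia without any AD package. It mirrors the pullback formula ΔΩ / 2Ω used in the rules further down; the variable names are illustrative only.

```julia
# Plain-Julia sketch of a sqrt pullback of the form ΔΩ / 2Ω, evaluated at x = 0
# with a zero incoming cotangent (e.g. from the inactive branch of max).
x  = 0.0
Ω  = sqrt(x)       # 0.0
ΔΩ = 0.0           # zero cotangent
naive = ΔΩ / 2Ω    # 0.0 / 0.0 is NaN

# "Strong zero" variant, mirroring the condition used in the modified rules:
# if both the cotangent and the input are zero, return a zero instead.
strong = ifelse(iszero(ΔΩ) & iszero(x), zero(naive), naive)

println(naive)     # NaN
println(strong)    # 0.0
```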

It also partially fixes https://github.com/FluxML/Zygote.jl/issues/1101; fixing that issue entirely would require the same change to the rule for ^.

Benchmark

This simple benchmark suggests that the modified rule does not substantially slow down Zygote: the change in mean time is within the 5% noise tolerance, although memory use increases by about 22%.

julia> using Zygote, BenchmarkTools, Random, Statistics, ChainRulesCore

julia> x = zeros(1_000);

julia> y = rand(MersenneTwister(42), 1_000);

julia> f(x) = sum(x -> max(sqrt(x), 1), x)
f (generic function with 1 method)

julia> Zygote.gradient(f, x)
([NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN  …  NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN],)

julia> b1_1 = @benchmark $(Zygote.gradient)($f, $x)
BenchmarkTools.Trial: 10000 samples with 3 evaluations.
 Range (min … max):   8.663 μs … 730.817 μs  ┊ GC (min … max):  0.00% … 93.96%
 Time  (median):      9.755 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   13.060 μs ±  31.542 μs  ┊ GC (mean ± σ):  15.00% ±  6.24%

  ▅██▆▇▆▅▄▃▂▁▂▂▂▂▂▁▁▁▁▁▂▂▂▁▁                       ▁▁▁         ▂
  ███████████████████████████▇█▇▇▇▆▇▇▆▆▄▆▆▆▆▇█▇▇▆▆▆████▇▆▅▄▇▆▆ █
  8.66 μs       Histogram: log(frequency) by time        26 μs <

 Memory estimate: 71.22 KiB, allocs estimate: 31.

julia> b2_1 = @benchmark $(Zygote.gradient)($f, $y)
BenchmarkTools.Trial: 10000 samples with 3 evaluations.
 Range (min … max):   8.612 μs … 532.969 μs  ┊ GC (min … max):  0.00% … 93.85%
 Time  (median):      9.335 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   12.282 μs ±  30.078 μs  ┊ GC (mean ± σ):  16.07% ±  6.44%

  ▅██▆▆▆▅▃▂ ▁  ▂▂▂▂▁▁     ▁                                    ▂
  ██████████████████████████████▇▇▆▆▅▆▆▄▇▅▄▅▅▆▅▅▆▅▅▆▆▆▆▅▅▅▄▄▅▅ █
  8.61 μs       Histogram: log(frequency) by time      23.9 μs <

 Memory estimate: 71.22 KiB, allocs estimate: 31.

julia> function ChainRulesCore.frule((_, Δx), ::typeof(sqrt), x::Number)
           Ω = sqrt(x)
           ∂Ω = Δx / 2Ω
           return Ω, ifelse(iszero(Δx) & iszero(x), zero(∂Ω), ∂Ω)
       end

julia> function ChainRulesCore.rrule(::typeof(sqrt), x::Number)
           Ω = sqrt(x)
           function sqrt_pullback(ΔΩ)
               ∂x = ΔΩ / 2conj(Ω)
               return (
                   NoTangent(),
                   ProjectTo(x)(ifelse(iszero(ΔΩ) & iszero(x), zero(∂x), ∂x))
               )
           end
           return Ω, sqrt_pullback
       end

julia> Zygote.gradient(f, x)
([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],)

julia> b1_2 = @benchmark $(Zygote.gradient)($f, $x)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):   8.891 μs …  1.357 ms  ┊ GC (min … max):  0.00% … 96.56%
 Time  (median):      9.832 μs              ┊ GC (median):     0.00%
 Time  (mean ± σ):   12.484 μs ± 46.625 μs  ┊ GC (mean ± σ):  15.40% ±  4.09%

   ▆██▆▅▃▃▃▃▂▁ ▁▁▁                                            ▂
  ▇████████████████▇▇█▇▆▄▃▄▃▃▄▄▄▅▃▄▂▄▃▂▄▄▃▄▄▄▂▄▄▃▄▅▄▅▄▆▇▇▆▇▆▆ █
  8.89 μs      Histogram: log(frequency) by time      25.3 μs <

 Memory estimate: 86.84 KiB, allocs estimate: 31.

julia> b2_2 = @benchmark $(Zygote.gradient)($f, $y)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):   9.066 μs …  1.304 ms  ┊ GC (min … max):  0.00% … 96.15%
 Time  (median):      9.892 μs              ┊ GC (median):     0.00%
 Time  (mean ± σ):   11.970 μs ± 43.215 μs  ┊ GC (mean ± σ):  14.43% ±  3.96%

     ▁▆██▆▄▁                                                   
  ▂▃▅████████▆▅▄▄▃▃▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂▂▂▂▂▂ ▃
  9.07 μs         Histogram: frequency by time        15.7 μs <

 Memory estimate: 86.84 KiB, allocs estimate: 31.

julia> judge(mean(b1_2), mean(b1_1))
BenchmarkTools.TrialJudgement: 
  time:   -4.41% => invariant (5.00% tolerance)
  memory: +21.94% => regression (1.00% tolerance)


julia> judge(mean(b2_2), mean(b2_1))
BenchmarkTools.TrialJudgement: 
  time:   -2.54% => invariant (5.00% tolerance)
  memory: +21.94% => regression (1.00% tolerance)

sethaxen, Mar 12 '22 12:03

This seems fine. How many other functions will need this? cbrt is currently a scalar rule. Powers are their own messy thing.

mcabbott, Mar 12 '22 16:03

> This seems fine. How many other functions will need this? cbrt is currently a scalar rule. Powers are their own messy thing.

At first glance, /, \, ^, inv, sqrt, cbrt, log, log2, log10, log1p, and a bunch of inverse trig/hyperbolic functions.
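A quick plain-Julia check of why several of these are affected: the derivatives below are hand-written for illustration (they are not ChainRules' rule definitions), and each is infinite at x = 0, so multiplying it by a zero (co)tangent gives NaN.

```julia
# Hand-written derivatives (illustrative only, not taken from ChainRules)
# for some of the functions listed above, each evaluated at x = 0.
derivs = [
    "inv"   => x -> -inv(x)^2,
    "sqrt"  => x -> inv(2 * sqrt(x)),
    "cbrt"  => x -> inv(3 * cbrt(x)^2),
    "log"   => x -> inv(x),
    "log2"  => x -> inv(x * log(2)),
    "log10" => x -> inv(x * log(10)),
]

for (name, d) in derivs
    ∂ = d(0.0)          # infinite partial at x = 0
    Δ = 0.0 * ∂         # zero (co)tangent times infinite partial
    println(rpad(name, 6), "partial = ", ∂, ", 0.0 * partial = ", Δ)
end
```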

I'm reluctant to add custom frules/rrules for all of these without first checking whether making zero (co)tangents strong zeros in the @scalar_rule macro causes a significant performance decrease. Perhaps, before merging this, I should open a PR on ChainRulesCore with a benchmark.

sethaxen, Mar 12 '22 19:03
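One way to picture a strong-zero change at the @scalar_rule level is a multiplication helper that treats a zero (co)tangent as a strong zero. This sketch is hypothetical: strong_zero_mul and log_pullback are illustrative names, not ChainRulesCore API, and the actual macro change may look different.

```julia
# Hypothetical helper (not ChainRulesCore API): multiply a (co)tangent Δ by a
# partial derivative ∂, returning zero whenever Δ is zero, so that an infinite
# or NaN partial cannot turn a zero (co)tangent into NaN.
function strong_zero_mul(Δ, ∂)
    p = Δ * ∂
    # ifelse evaluates both arguments and selects one without branching.
    return ifelse(iszero(Δ), zero(p), p)
end

# Illustrative pullback for log(x), whose partial inv(x) is infinite at x = 0.
log_pullback(x, ΔΩ) = strong_zero_mul(ΔΩ, inv(x))

println(0.0 * inv(0.0))           # NaN with plain multiplication
println(log_pullback(0.0, 0.0))   # 0.0 with the strong-zero helper
println(log_pullback(2.0, 1.0))   # 0.5, unchanged for a nonzero cotangent
```

Unlike the sqrt rule in this PR, the helper checks only the (co)tangent, not the input, which is what would let one macro-level change cover every function at once; the extra cost is an iszero check and a select per partial, which is the kind of overhead a benchmark would need to measure.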

I opened a PR to ChainRulesCore that would supersede this one if merged: https://github.com/JuliaDiff/ChainRulesCore.jl/pull/551

sethaxen, Mar 13 '22 20:03

Functions like inv, log etc. are a slightly different class to sqrt, since the primal is infinite.

The motivating case for sqrt is I think something like f(x) = sqrt(x^2 + 0), which is regular at zero, and can be made to have a continuous derivative there. Is there something like that for inv, less trivial than g(x) = inv(inv(x))?

mcabbott, Mar 13 '22 22:03
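The sqrt(x^2 + 0) case can be traced by hand in forward mode. This sketch uses plain Julia with hypothetical helper names (square_frule, sqrt_frule) rather than ChainRulesCore; sqrt_frule mirrors the frule from the PR description, with a flag to switch the strong-zero check on and off.

```julia
# Hypothetical hand-written frules (not ChainRulesCore API) for f(x) = sqrt(x^2 + 0).
square_frule(x, Δx) = (x^2, 2x * Δx)

function sqrt_frule(u, Δu; strong_zero::Bool)
    Ω = sqrt(u)
    ∂Ω = Δu / 2Ω
    # With strong_zero, a zero tangent at u = 0 stays zero instead of becoming 0/0.
    return Ω, (strong_zero ? ifelse(iszero(Δu) & iszero(u), zero(∂Ω), ∂Ω) : ∂Ω)
end

x, Δx = 0.0, 1.0
u, Δu = square_frule(x, Δx)                       # u = 0.0, Δu = 0.0
_, Δf_naive  = sqrt_frule(u + 0, Δu; strong_zero=false)
_, Δf_strong = sqrt_frule(u + 0, Δu; strong_zero=true)

println(Δf_naive)    # NaN
println(Δf_strong)   # 0.0
```

This sketch only traces the forward (tangent) direction.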

> Functions like inv, log etc. are a slightly different class to sqrt, since the primal is infinite.

Is this difference important, though? There are plenty of cases where an intermediate of a well-behaved primal function can be non-finite, which introduces NaNs. Here's another one that hits users of lower-truncated normal distributions in Turing:

julia> using StatsFuns, FiniteDifferences, Zygote

julia> normcdf(0.0, 1.0, Inf)  # a constant function for all finite mu and positive sigma
1.0

julia> FiniteDifferences.grad(central_fdm(5, 1), x -> normcdf(0.0, x, Inf), 1.0)
(6.085449991639748e-14,)

julia> Zygote.gradient(x -> normcdf(0.0, x, Inf), 1.0)
(NaN,)

This happens because the derivative of erfc at an infinite argument is 0, but when that zero is pulled back through (x - mu)/sigma at x=Inf, the partial of / with respect to sigma, -(x - mu)/sigma^2, is infinite, and 0 * Inf introduces a NaN. This case is also resolved by treating zero (co)tangents as strong zeros.
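The 0 * Inf step can be reproduced in plain Julia. This is a sketch of the arithmetic only: the hand-written partial -(x - mu)/sigma^2 stands in for the rule for /, and the variable names are illustrative.

```julia
# Pullback arithmetic for z = (x - mu) / sigma at x = Inf, with respect to sigma.
x, mu, sigma = Inf, 0.0, 1.0
∂z_∂sigma = -(x - mu) / sigma^2                    # -Inf
Δz = 0.0                                           # zero cotangent arriving from erfc
naive  = Δz * ∂z_∂sigma                            # 0.0 * -Inf is NaN
strong = ifelse(iszero(Δz), zero(naive), naive)    # strong zero: stays 0.0

println(naive)    # NaN
println(strong)   # 0.0
```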

> The motivating case for sqrt is I think something like f(x) = sqrt(x^2 + 0), which is regular at zero, and can be made to have a continuous derivative there. Is there something like that for inv, less trivial than g(x) = inv(inv(x))?

Perhaps, but I don't see inv(inv(x)) (or log(exp(x))) as being any more trivial than sqrt(x^2).

sethaxen, Mar 14 '22 00:03