ChainRules.jl
Avoid NaN (co)tangents for sqrt(0)
This PR fixes #576 by treating zero (co)tangents in sqrt as strong zeros.
It also partially fixes https://github.com/FluxML/Zygote.jl/issues/1101, but to fix that issue entirely we would need to do the same thing to the rule for ^.
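For context on the failure mode: the partial of sqrt at zero is 1/(2*sqrt(0)) = Inf, so pulling back even a zero cotangent evaluates 0/0 there. This is easy to reproduce in isolation:
julia> 0.0 / (2 * sqrt(0.0))  # a zero cotangent pulled back through sqrt at 0
NaN
Treating the zero (co)tangent as a strong zero short-circuits this division.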
Benchmark
This simple benchmark indicates that the performance cost of the modified rule in Zygote is modest. At x = 0, max(sqrt(x), 1) takes the constant branch, so a zero cotangent is pulled back through sqrt, which is exactly the case the fix targets.
julia> using Zygote, BenchmarkTools, Random
julia> x = zeros(1_000);
julia> y = rand(MersenneTwister(42), 1_000);
julia> f(x) = sum(x -> max(sqrt(x), 1), x)
f (generic function with 1 method)
julia> Zygote.gradient(f, x)
([NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN … NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN],)
julia> b1_1 = @benchmark $(Zygote.gradient)($f, $x)
BenchmarkTools.Trial: 10000 samples with 3 evaluations.
Range (min … max): 8.663 μs … 730.817 μs ┊ GC (min … max): 0.00% … 93.96%
Time (median): 9.755 μs ┊ GC (median): 0.00%
Time (mean ± σ): 13.060 μs ± 31.542 μs ┊ GC (mean ± σ): 15.00% ± 6.24%
▅██▆▇▆▅▄▃▂▁▂▂▂▂▂▁▁▁▁▁▂▂▂▁▁ ▁▁▁ ▂
███████████████████████████▇█▇▇▇▆▇▇▆▆▄▆▆▆▆▇█▇▇▆▆▆████▇▆▅▄▇▆▆ █
8.66 μs Histogram: log(frequency) by time 26 μs <
Memory estimate: 71.22 KiB, allocs estimate: 31.
julia> b2_1 = @benchmark $(Zygote.gradient)($f, $y)
BenchmarkTools.Trial: 10000 samples with 3 evaluations.
Range (min … max): 8.612 μs … 532.969 μs ┊ GC (min … max): 0.00% … 93.85%
Time (median): 9.335 μs ┊ GC (median): 0.00%
Time (mean ± σ): 12.282 μs ± 30.078 μs ┊ GC (mean ± σ): 16.07% ± 6.44%
▅██▆▆▆▅▃▂ ▁ ▂▂▂▂▁▁ ▁ ▂
██████████████████████████████▇▇▆▆▅▆▆▄▇▅▄▅▅▆▅▅▆▅▅▆▆▆▆▅▅▅▄▄▅▅ █
8.61 μs Histogram: log(frequency) by time 23.9 μs <
Memory estimate: 71.22 KiB, allocs estimate: 31.
julia> using ChainRulesCore

julia> function ChainRulesCore.frule((_, Δx), ::typeof(sqrt), x::Number)
           Ω = sqrt(x)
           ∂Ω = Δx / 2Ω
           # a zero tangent at x = 0 becomes a strong zero instead of 0/0 = NaN
           return Ω, ifelse(iszero(Δx) & iszero(x), zero(∂Ω), ∂Ω)
       end

julia> function ChainRulesCore.rrule(::typeof(sqrt), x::Number)
           Ω = sqrt(x)
           function sqrt_pullback(ΔΩ)
               ∂x = ΔΩ / 2conj(Ω)
               return (
                   NoTangent(),
                   # likewise, a zero cotangent at x = 0 becomes a strong zero
                   ProjectTo(x)(ifelse(iszero(ΔΩ) & iszero(x), zero(∂x), ∂x)),
               )
           end
           return Ω, sqrt_pullback
       end
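For reference (not part of the benchmark), here is how I would expect the redefined rules to behave at zero, given the definitions above:
julia> ChainRulesCore.frule((NoTangent(), 0.0), sqrt, 0.0)  # zero tangent at zero is now a strong zero
(0.0, 0.0)

julia> ChainRulesCore.frule((NoTangent(), 1.0), sqrt, 0.0)  # a nonzero tangent still propagates Inf
(0.0, Inf)

julia> Ω, sqrt_pullback = ChainRulesCore.rrule(sqrt, 0.0);

julia> sqrt_pullback(0.0)  # zero cotangent at zero: 0.0 instead of NaN
(NoTangent(), 0.0)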
julia> Zygote.gradient(f, x)
([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],)
julia> b1_2 = @benchmark $(Zygote.gradient)($f, $x)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 8.891 μs … 1.357 ms ┊ GC (min … max): 0.00% … 96.56%
Time (median): 9.832 μs ┊ GC (median): 0.00%
Time (mean ± σ): 12.484 μs ± 46.625 μs ┊ GC (mean ± σ): 15.40% ± 4.09%
▆██▆▅▃▃▃▃▂▁ ▁▁▁ ▂
▇████████████████▇▇█▇▆▄▃▄▃▃▄▄▄▅▃▄▂▄▃▂▄▄▃▄▄▄▂▄▄▃▄▅▄▅▄▆▇▇▆▇▆▆ █
8.89 μs Histogram: log(frequency) by time 25.3 μs <
Memory estimate: 86.84 KiB, allocs estimate: 31.
julia> b2_2 = @benchmark $(Zygote.gradient)($f, $y)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 9.066 μs … 1.304 ms ┊ GC (min … max): 0.00% … 96.15%
Time (median): 9.892 μs ┊ GC (median): 0.00%
Time (mean ± σ): 11.970 μs ± 43.215 μs ┊ GC (mean ± σ): 14.43% ± 3.96%
▁▆██▆▄▁
▂▃▅████████▆▅▄▄▃▃▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂▂▂▂▂▂ ▃
9.07 μs Histogram: frequency by time 15.7 μs <
Memory estimate: 86.84 KiB, allocs estimate: 31.
julia> judge(mean(b1_2), mean(b1_1))
BenchmarkTools.TrialJudgement:
time: -4.41% => invariant (5.00% tolerance)
memory: +21.94% => regression (1.00% tolerance)
julia> judge(mean(b2_2), mean(b2_1))
BenchmarkTools.TrialJudgement:
time: -2.54% => invariant (5.00% tolerance)
memory: +21.94% => regression (1.00% tolerance)
This seems fine. How many other functions will need this? cbrt is currently a scalar rule. Powers are their own messy thing.
At first glance, /, \, ^, inv, sqrt, cbrt, log, log2, log10, log1p, and a bunch of inverse trig/hyperbolic functions.
I'm reluctant to add custom frules/rrules for all of these without first checking whether making zero (co)tangents strong zeros in the @scalar_rule macro causes a significant performance decrease, so perhaps before merging this I should open a PR on ChainRulesCore with a benchmark.
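For concreteness, the macro-level change could amount to wrapping every propagated (co)tangent in a helper along these lines. The name strong_zero_mul is hypothetical, and this is just the pattern, not actual ChainRulesCore code:
julia> strong_zero_mul(Δ, ∂) = ifelse(iszero(Δ), zero(Δ * ∂), Δ * ∂)  # hypothetical helper: zero (co)tangents win over non-finite partials
strong_zero_mul (generic function with 1 method)

julia> strong_zero_mul(0.0, Inf)  # plain 0.0 * Inf would be NaN
0.0
Note this checks only the (co)tangent, unlike the rules above, which also require iszero(x).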
I opened a PR to ChainRulesCore that would supersede this one if merged: https://github.com/JuliaDiff/ChainRulesCore.jl/pull/551
Functions like inv, log etc. are a slightly different class to sqrt, since the primal is infinite.
The motivating case for sqrt is I think something like f(x) = sqrt(x^2 + 0), which is regular at zero, and can be made to have a continuous derivative there. Is there something like that for inv, less trivial than g(x) = inv(inv(x))?
Is this difference important, though? There are plenty of cases where an intermediate in a well-behaved primal function can be non-finite, resulting in the introduction of NaNs. Here's another one that hits users of lower-truncated normal distributions in Turing:
julia> using StatsFuns, FiniteDifferences, Zygote
julia> normcdf(0.0, 1.0, Inf) # a constant function for all finite values of mu and sigma
1.0
julia> FiniteDifferences.grad(central_fdm(5, 1), x -> normcdf(0.0, x, Inf), 1.0)
(6.085449991639748e-14,)
julia> Zygote.gradient(x -> normcdf(0.0, x, Inf), 1.0)
(NaN,)
This happens because the gradient of erfc at Inf is 0, but when that zero gets pulled back through (x - mu)/sigma for x = Inf, it meets an infinite partial for /, so a NaN is introduced. This case is also resolved by treating zero (co)tangents as strong zeros.
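The offending product is visible in isolation: the pullback of (x - mu)/sigma with respect to sigma carries a partial proportional to (x - mu)/sigma^2, which is infinite for x = Inf, so erfc's zero cotangent multiplies an infinite value:
julia> 0.0 * -Inf  # zero cotangent times the infinite partial of /
NaN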
The motivating case for sqrt is I think something like f(x) = sqrt(x^2 + 0), which is regular at zero, and can be made to have a continuous derivative there. Is there something like that for inv, less trivial than g(x) = inv(inv(x))?
Perhaps, but I don't see inv(inv(x)) (or log(exp(x))) as being any more trivial than sqrt(x^2).
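For what it's worth, I would expect sqrt(x^2) to still give NaN even with the fixed sqrt rule loaded, because the seed cotangent 1.0 is nonzero when it reaches sqrt, so the Inf partial survives and only meets the zero partial 2x = 0 inside the pullback for ^:
julia> Zygote.gradient(x -> sqrt(x^2), 0.0)
(NaN,)
That is the sense in which the rule for ^ needs the same treatment.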