ChainRules.jl
Configured rule for `maximum(f, xs)`
This uses the RuleConfig{>:HasReverseMode} story to call back into AD to write a rule for maximum(f, xs).
It's much simplified from the first attempt:
- On Julia 1.7+, for a total reduction, it calls (y, i) = findmax(f, xs), and then uses rrule_via_ad(f, xs[i]).
- Otherwise, it just calls broadcasting.
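A minimal sketch of the fast path, assuming Julia 1.7+ and using a hand-written derivative df in place of the pullback that rrule_via_ad would return. The name maximum_with_grad is hypothetical, and this is not the rule in this PR; it returns the gradient for a cotangent of 1:

```julia
# Hypothetical sketch of the fast path (Julia 1.7+): `df` is a hand-written derivative
# standing in for the pullback from rrule_via_ad; the result is the gradient for dy = 1.
function maximum_with_grad(f, df, xs::AbstractArray)
    y, i = findmax(f, xs)  # one pass: y == f(xs[i]); ties go to the first maximiser
    dxs = zero(xs)         # every other entry gets a zero gradient
    dxs[i] = df(xs[i])     # a single derivative evaluation, at the maximiser only
    return y, dxs
end

A = [1.0 4.0; 9.0 2.0]
y, dxs = maximum_with_grad(sqrt, x -> 1 / (2 * sqrt(x)), A)
# y == 3.0, and dxs is zero except dxs[2, 1] ≈ 1/6
```

Nothing but the index i survives the forward pass, which is where the memory saving over broadcasting comes from.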
Fast case, before & after:
```julia
julia> @btime gradient(x -> sum(maximum(sqrt, x)), $(rand(30,30)));
min 2.908 ms, mean 4.031 ms (21816 allocations, 9.29 MiB) # before
min 14.875 μs, mean 17.546 μs (52 allocations, 8.92 KiB) # after

julia> @btime gradient(x -> sum(maximum(sqrt.(x))), $(rand(30,30)));
min 17.500 μs, mean 25.453 μs (46 allocations, 36.83 KiB) # just broadcasting, to compare
```
Before this PR, gradient(x -> sum(maximum(sqrt, x, dims=1)), (rand(30,30))) gives an error with Zygote. After, it is the same speed as broadcasting.
What doesn't seem easy now is testing the broadcast path.
First attempt
This attempt also calls back into AD with rrule_via_ad, but it only needs one such call, rather than one for every element. That means it ends up calling f roughly N^2 + 1 times for an N×N matrix (or N^2 + N with dims). This is much more efficient than calling it via AD all N^2 times, saving the pullbacks somewhere, and then using just one of them. It is not always faster than Zygote's current broadcasting (which uses ForwardDiff), but it uses much less memory:
```julia
julia> @btime gradient(x -> sum(maximum(sqrt, x)), $(rand(30,30)));
9.625 μs (73 allocations: 9.11 KiB) # this PR
9.333 μs (66 allocations: 8.95 KiB) # this PR, with rrule instead of rrule_via_ad

julia> @btime gradient(x -> sum(maximum(sqrt, x, dims=1)), $(rand(30,30)));
10.125 μs (34 allocations: 13.92 KiB) # this PR, take 1
15.208 μs (33 allocations: 29.31 KiB) # this PR, with mask allowing multiple maxima
17.166 μs (33 allocations: 29.31 KiB) # with rrule instead of rrule_via_ad

julia> @btime gradient(x -> sum(maximum(sqrt.(x))), $(rand(30,30)));
8.833 μs (48 allocations: 36.98 KiB) # broadcasting with Duals

julia> @btime maximum(sqrt, $(rand(30,30))); # forward pass
1.438 μs (0 allocations: 0 bytes)
```
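To make the call-count trade-off concrete, here is a toy sketch with no AD package, assuming Julia 1.7+ for findmax(f, xs). The helpers counted_sqrt (a call counter) and sqrt_rrule (a hand-written stand-in for rrule_via_ad that runs f once and returns a pullback closure) are hypothetical. Strategy A keeps a pullback for every element; strategy B does a plain findmax pass and one extra call at the maximiser:

```julia
# Hypothetical helpers, for illustration only.
const CALLS = Ref(0)
counted_sqrt(x) = (CALLS[] += 1; sqrt(x))                     # counts every call of f
sqrt_rrule(x) = (y = counted_sqrt(x); (y, dy -> dy / (2y)))  # value and pullback

x = rand(30, 30)

# Strategy A: call via AD on every element and keep all the pullbacks.
CALLS[] = 0
stored = map(sqrt_rrule, x)            # 900 (value, pullback) pairs kept alive
y_a, i_a = findmax(first, stored)
grad_a = zero(x)
grad_a[i_a] = last(stored[i_a])(1.0)   # only one of the stored pullbacks is used
calls_a = CALLS[]                      # N^2 calls

# Strategy B: plain forward pass, then one AD call at the maximiser.
CALLS[] = 0
y_b, i_b = findmax(counted_sqrt, x)    # N^2 calls, nothing stored
_, back = sqrt_rrule(x[i_b])           # one extra call
grad_b = zero(x)
grad_b[i_b] = back(1.0)
calls_b = CALLS[]                      # N^2 + 1 calls
```

Both give the same gradient; B pays one extra call of f in exchange for not materialising 900 pullbacks.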
If this is OK, then perhaps the sum(f, x) rule from #441 should also consider calling f more times. There's a commit here doing that, which cuts the memory use by quite a bit. Perhaps there are functions f for which calling twice would be slower? Perhaps writing sum(f, x) vs. sum(f.(x)) is how you emphasise that you care more about memory? ~~(It may make sense to remove this & discuss sum in another thread.)~~ [Now removed here.]
```julia
julia> @btime gradient(x -> sum(sqrt, x), $(rand(30,30)));
4.173 μs (16 allocations: 50.02 KiB) # before
1.954 μs (2 allocations: 7.20 KiB) # after

julia> @btime gradient(x -> sum(sum(sqrt, x, dims=1)), $(rand(30,30)));
10.625 μs (42 allocations: 51.47 KiB) # before
2.704 μs (18 allocations: 8.20 KiB) # after

# Compare broadcasting:
julia> @btime gradient(x -> sum(sqrt.(x)), $(rand(30,30)));
2.616 μs (10 allocations: 28.70 KiB)

julia> @btime gradient(x -> sum(sum(sqrt.(x), dims=1)), $(rand(30,30)));
3.542 μs (26 allocations: 36.81 KiB)

# Forward only:
julia> @btime sum(sqrt, x) setup=(x=$(rand(30,30)));
833.333 ns (0 allocations: 0 bytes)

julia> @btime sum(sqrt.(x)) setup=(x=$(rand(30,30)));
873.544 ns (1 allocation: 7.19 KiB)
```
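One way to read "calling f more times" for sum(f, x), sketched under assumptions: the forward pass is a plain sum(f, x), and the backward pass calls an rrule stand-in again for each element and applies its pullback immediately, so f runs twice per element but no pullbacks are stored between the passes. Both sqrt_rrule and sum_with_pullback are hypothetical; this is not the rule from #441:

```julia
# Hypothetical stand-in for rrule_via_ad: runs sqrt once and returns a pullback.
sqrt_rrule(x) = (y = sqrt(x); (y, dy -> dy / (2y)))

# Hypothetical sketch: the forward pass keeps only `x`; the backward pass re-runs
# the rrule per element and uses each pullback straight away.
function sum_with_pullback(f, f_rrule, x::AbstractArray)
    y = sum(f, x)                                        # no intermediate array
    pullback(dy) = map(xi -> last(f_rrule(xi))(dy), x)   # f runs again per element
    return y, pullback
end

y, back = sum_with_pullback(sqrt, sqrt_rrule, [1.0, 4.0, 9.0])
# y == 6.0, back(1.0) ≈ [0.5, 0.25, 1/6]
```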
All WIP, needs more careful testing, etc.
First attempt
With a more expensive function:
```julia
julia> @btime gradient(x -> sum(maximum(log∘exp, x)), $(rand(30,30)));
34.791 μs (162 allocations: 11.11 KiB)

julia> @btime gradient(x -> sum(maximum(log∘exp, x, dims=1)), $(rand(30,30)));
326.292 μs (2615 allocations: 87.55 KiB)

julia> @btime gradient(x -> sum(maximum((log∘exp).(x))), $(rand(30,30)));
22.333 μs (48 allocations: 36.86 KiB)

julia> @btime gradient(x -> sum(maximum((log∘exp).(x), dims=1)), $(rand(30,30)));
16.250 μs (13 allocations: 36.72 KiB)

# without AD:
julia> @btime maximum(log∘exp, $(rand(30,30)));
13.000 μs (0 allocations: 0 bytes)

julia> @btime maximum(log∘exp, $(rand(30,30)), dims=1);
15.500 μs (4 allocations: 416 bytes)

julia> @btime findmax(log∘exp, $(rand(30,30)));
15.334 μs (0 allocations: 0 bytes)
```
The dims=1 case is very slow, because (1) it takes a second complete (N^2) pass to find the indices at which f attains the maximum, since there is no findmax(sqrt, rand(3,3), dims=1), and (2) it needs N calls to rrule_via_ad, which (like Zygote's generic broadcasting) doesn't infer for log∘exp.
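To illustrate where the second pass comes from, here is a toy version of the dims=1 strategy, with a hand-written df standing in for rrule_via_ad and a hypothetical name maximum_dims1_with_grad. The forward pass computes the column maxima, a second pass re-evaluates f to locate each maximiser, and then one derivative is taken per column:

```julia
# Hypothetical sketch of the dims=1 case without a findmax(f, x; dims=1) method.
function maximum_dims1_with_grad(f, df, x::AbstractMatrix)
    y = maximum(f, x; dims=1)   # forward pass: 1×N row of column maxima
    dx = zero(x)
    for j in axes(x, 2)
        # second pass: re-evaluate f down column j until it matches the maximum
        i = findfirst(k -> f(x[k, j]) == y[1, j], axes(x, 1))
        dx[i, j] = df(x[i, j])  # one derivative per column, N in total
    end
    return y, dx
end

A = [1.0 4.0; 9.0 2.0]
y, dx = maximum_dims1_with_grad(sqrt, x -> 1 / (2 * sqrt(x)), A)
# y == [3.0 2.0]; dx is nonzero only at (2, 1) and (1, 2)
```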
The broadcasted one uses dual numbers, which is much quicker. Note BTW that there is no chunk mode in play here -- it always evaluates f exactly 900 times.
I'm not so sure why the complete reduction is slower than broadcasting here, but it's much closer, and uses 3x less memory.
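For comparison, a toy sketch of forward-mode broadcasting with dual numbers: each element is evaluated once, carrying its derivative, and the pullback of maximum then picks out one stored partial. Dual here is a minimal hypothetical type, not ForwardDiff.Dual, and only sqrt is defined for it:

```julia
# Hypothetical toy dual number: a value paired with its derivative.
struct Dual
    val::Float64
    der::Float64
end
Base.sqrt(d::Dual) = (s = sqrt(d.val); Dual(s, d.der / (2s)))  # chain rule for sqrt

x = [1.0, 9.0, 4.0]
duals = sqrt.(Dual.(x, 1.0))        # f evaluated exactly length(x) times
vals = map(d -> d.val, duals)       # forward result, sqrt.(x)
partials = map(d -> d.der, duals)   # f'(x) for every element, all stored
y, i = findmax(vals)                # pullback of maximum: only the argmax matters
grad = zeros(length(x))
grad[i] = partials[i]
```

Every partial is computed and kept even though only one is used, which matches the larger memory figures for the broadcast versions above.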
Diffractor, BTW, does not see this rule. It does see #480, but broadcast times are variable:
```julia
julia> @btime Diffractor.gradient(x -> maximum(sqrt, x), $(rand(30,30)));
ERROR: TypeError: in typeassert, expected Int64, got a value of type Nothing
...
[8] (::Diffractor.∂⃖recurse{1})(::typeof(Base._mapreduce), ::typeof(sqrt), ::typeof(max), ::IndexLinear, ::Matrix{Float64})

julia> @btime gradient(x -> maximum(sqrt.(x)), $(rand(30,30)));
11.417 μs (12 allocations: 64.33 KiB) # Zygote 8.833 μs (48 allocations: 36.98 KiB)

julia> @btime gradient(x -> maximum((log∘exp).(x)), $(rand(30,30)));
2.155 ms (17143 allocations: 586.41 KiB) # Zygote 22.333 μs (48 allocations: 36.86 KiB)
```
This has been much simplified. For the case of a complete reduction only, maximum(f, x), this saves the position of the maximum, and calls rrule_via_ad(f, x[i]) once. This saves memory compared to broadcasting, but in the end not much time -- it might still not be worth the complication:
```julia
julia> @btime gradient(x -> sum(maximum(sqrt, x)), $(rand(30,30))); # this PR + Zygote + Julia 1.8
min 8.625 μs, mean 10.906 μs (52 allocations, 8.92 KiB. GC mean 13.94%)

julia> @btime gradient(x -> sum(maximum(sqrt.(x))), $(rand(30,30)));
min 10.041 μs, mean 16.087 μs (49 allocations, 36.88 KiB. GC mean 20.75%)

julia> @btime gradient(x -> sum(maximum(log∘exp, x)), $(rand(30,30))); # with a more expensive function:
min 20.208 μs, mean 22.335 μs (116 allocations, 10.88 KiB. GC mean 5.22%)

julia> @btime gradient(x -> sum(maximum((log∘exp).(x))), $(rand(30,30)));
min 19.291 μs, mean 25.757 μs (49 allocations, 36.88 KiB. GC mean 13.03%)

julia> @btime maximum(log∘exp, $(rand(30,30)));
min 8.958 μs, mean 9.128 μs (0 allocations)
```
That means it calls f in total N+1 times, where N = length(x). If f is stateful, then as far as I know the result of maximum(f, x) is already ill-defined, since no order of evaluation is guaranteed. If f closes over something, that will get a gradient contribution from only one entry, which should be fine.
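A small worked check of the closure remark, using a hypothetical make_f: for f = x -> a*x with a > 0, the derivative of maximum(f, xs) with respect to the captured a is just xs[i] at the maximiser, which a central finite difference confirms (Julia 1.7+ for findmax(f, xs)):

```julia
make_f(a) = x -> a * x        # hypothetical: returns a closure capturing `a`

a = 2.0
xs = [1.0, 5.0, 3.0]
y, i = findmax(make_f(a), xs) # y == 10.0 at i == 2
da = xs[i]                    # ∂(a * xs[i])/∂a; the other entries contribute nothing

h = 1e-6                      # central finite difference in `a`
da_fd = (maximum(make_f(a + h), xs) - maximum(make_f(a - h), xs)) / (2h)
```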
Instead of using rrule_via_ad, this would be a good use case for derivatives_given_output when that's defined.
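A sketch of why a derivative-from-output rule would fit here: findmax already returns y = f(xs[i]), so if f'(x) can be computed from (y, x), f needs no extra call and the total drops from N+1 to N. The function deriv_from_output is a hypothetical stand-in defined only for sqrt and exp; it is not ChainRulesCore's derivatives_given_output:

```julia
# Hypothetical derivative-from-output rules (not the ChainRulesCore API).
deriv_from_output(::typeof(sqrt), y, x) = 1 / (2y)   # d/dx sqrt(x) == 1/(2 sqrt(x))
deriv_from_output(::typeof(exp), y, x) = y           # d/dx exp(x) == exp(x)

function maximum_grad_from_output(f, xs::AbstractArray)
    y, i = findmax(f, xs)    # the only calls to f: one per element (Julia 1.7+)
    dxs = zero(xs)
    dxs[i] = deriv_from_output(f, y, xs[i])   # reuses y, no further call to f
    return y, dxs
end

y, g = maximum_grad_from_output(exp, [0.0, 1.0, 0.5])
# y == exp(1.0), g == [0.0, exp(1.0), 0.0]
```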
For cases with dims, it just calls broadcasting. Earlier commits tried to handle this, but it gets complicated, and the saving is less clear. This case is not so easy to test.
On Julia 1.6 and below, the method findmax(f, x) which the fast path needs doesn't exist, so it always calls broadcasting.
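A sketch of the version gate, with a hypothetical helper argmax_of: on Julia 1.7+ it uses findmax(f, xs) directly, and on older versions it materialises f.(xs) first, much as the broadcasting fallback does:

```julia
# Hypothetical helper: value and index of the maximum of f over xs.
function argmax_of(f, xs::AbstractArray)
    if VERSION >= v"1.7"
        return findmax(f, xs)     # fast path: no intermediate array
    else
        return findmax(f.(xs))    # fallback: allocates f.(xs), then scans it
    end
end

argmax_of(sqrt, [1.0, 9.0, 4.0])  # (3.0, 2)
```

The branch is checked at run time, so the findmax(f, xs) call is never reached on versions where that method is missing.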
Status here is as in (edited) first message above.
Perhaps the broadcast path can be easily tested using https://github.com/JuliaDiff/ChainRulesTestUtils.jl/pull/243 once that's available.