Make the addition of thunks use `add!!`
This PR uses `add!!` to add thunks, which should be safe provided the result of unthunking is always an array we are free to mutate. It also adds `add!!` methods that try to add any pair of thunks in the quickest way.
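To make the idea concrete, here is a minimal sketch in plain Julia. It is a toy model and not the code in this PR: `ToyThunk`, `toy_unthunk`, `add_copy` and `add_reuse` are hypothetical names, and it assumes the thunk's closure always builds a new array.

```julia
# Toy stand-in for a thunk; all names here are hypothetical.
struct ToyThunk{F}
    f::F
end

# Running the deferred computation yields an array nobody else holds
# (assumption: the closure always builds a new array).
toy_unthunk(t::ToyThunk) = t.f()

# Out-of-place addition: one array for the thunk's result, one for the sum.
add_copy(t::ToyThunk, x::AbstractArray) = toy_unthunk(t) + x

# In-place addition: reuse the thunk's result as the output buffer.
function add_reuse(t::ToyThunk, x::AbstractArray)
    y = toy_unthunk(t)   # the only new array
    y .+= x              # overwrites y; x is left untouched
    return y
end

x = [1.0 2.0; 3.0 4.0]
x_before = copy(x)
t = ToyThunk(() -> -x)   # plays the role of `@thunk -x`

out_copy = add_copy(t, x)
out_reuse = add_reuse(t, x)
```

Both functions return the same values; the second just skips the extra array for the sum, which is the saving the benchmarks below are measuring.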
```julia
julia> using ChainRulesCore, ChainRules, BenchmarkTools

julia> const m100 = rand(128, 100);

julia> th = @thunk -m100;

julia> @btime unthunk($th);  # 100kb
  min 2.521 μs, mean 13.264 μs (2 allocations, 100.05 KiB)

julia> @btime $th + m100;
  min 8.347 μs, mean 27.694 μs (4 allocations, 200.09 KiB)  # before
  min 12.083 μs, mean 19.496 μs (2 allocations, 100.05 KiB)  # after

julia> @btime m100 + $th;
  min 8.486 μs, mean 32.362 μs (4 allocations, 200.09 KiB)  # before
  min 12.000 μs, mean 18.081 μs (2 allocations, 100.05 KiB)  # after

julia> @btime $th + $th;
  min 11.708 μs, mean 53.798 μs (6 allocations, 300.14 KiB)  # before
  min 15.042 μs, mean 42.060 μs (4 allocations, 200.09 KiB)  # after

julia> ith = rrule(sum, m100)[2](1.0)[2];

julia> @btime unthunk($ith.val);
  min 1.965 μs, mean 13.766 μs (3 allocations, 100.06 KiB)

julia> @btime $th + $ith;
  min 11.417 μs, mean 74.748 μs (7 allocations, 300.16 KiB)  # before
  min 11.750 μs, mean 32.771 μs (2 allocations, 100.05 KiB)  # after

julia> @btime $ith + $th;
  min 11.500 μs, mean 57.503 μs (7 allocations, 300.16 KiB)  # before
  min 11.875 μs, mean 34.882 μs (2 allocations, 100.05 KiB)  # after

julia> @btime $ith + $ith;
  min 10.167 μs, mean 58.167 μs (8 allocations, 300.17 KiB)  # before
  min 10.125 μs, mean 25.396 μs (3 allocations, 100.06 KiB)  # after
```
Closes #529.
Edit: also closes #297, which I didn't see. That issue proposes the same rule: we should be allowed to mutate any array that comes from `unthunk`.
~~Does not fix the bug that unthunk(ith) seems to make 2 copies, not sure why.~~ (This was just my use of an integer by mistake.)
Codecov Report
Merging #539 (c0c319a) into main (fbb4936) will decrease coverage by 1.27%. The diff coverage is 33.33%.

```diff
@@            Coverage Diff             @@
##             main     #539      +/-   ##
==========================================
- Coverage   93.15%   91.87%   -1.28%
==========================================
  Files          15       15
  Lines         891      911      +20
==========================================
+ Hits          830      837       +7
- Misses         61       74      +13
```

| Impacted Files | Coverage Δ | |
|---|---|---|
| src/tangent_arithmetic.jl | 94.25% <0.00%> (-2.22%) | :arrow_down: |
| src/accumulation.jl | 76.00% <38.88%> (-21.23%) | :arrow_down: |
| src/projection.jl | 97.34% <0.00%> (+0.04%) | :arrow_up: |
Accumulating N terms naively needs 2N-1 allocations: N to make the terms and then N-1 to add them pairwise.
If each term makes a thunk, this PR would reduce that to N. The benefit comes mostly from `@thunk` marking an array as safe to overwrite. It would still apply if you ignored the docs' advice and used `@thunk` even when there is no computation to defer.
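The counting above can be checked with a toy bookkeeping model. This is only a sketch: `fresh!`, `accumulate_naive` and `accumulate_reuse` are hypothetical names, nothing here comes from ChainRulesCore, and allocations are tallied by hand in a counter rather than measured.

```julia
# Toy bookkeeping model; none of these names come from ChainRulesCore.

# Every new term is created here, so `counter` tallies its allocation.
function fresh!(counter, n)
    counter[] += 1
    return ones(n)
end

# Naive: N terms, plus N-1 out-of-place sums that each build a new array.
function accumulate_naive(N, n)
    counter = Ref(0)
    acc = fresh!(counter, n)
    for _ in 2:N
        acc = acc + fresh!(counter, n)
        counter[] += 1               # the `+` above allocated its result
    end
    return acc, counter[]
end

# Reuse: every term is freshly made, so the running total can be overwritten.
function accumulate_reuse(N, n)
    counter = Ref(0)
    acc = fresh!(counter, n)
    for _ in 2:N
        acc .+= fresh!(counter, n)   # in place, no new array for the sum
    end
    return acc, counter[]
end

N = 5
total_naive, allocs_naive = accumulate_naive(N, 3)
total_reuse, allocs_reuse = accumulate_reuse(N, 3)
```

With `N = 5` the model counts 2N-1 arrays for the naive loop and N for the reusing loop, while both totals agree.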
If each term makes an `InplaceableThunk`, this PR would reduce it to N-1. Only the first addition improves; after that you have array + ithunk, and this `+` does not know where the array came from. So all the `InplaceableThunk` machinery doesn't seem to save you much.
Ideally, with in-place thunks you should get down to 1 allocation, with every later term updating it. But array + ithunk can't know that the array didn't come from the pullback of `+`, in which case it may be shared and is not safe to overwrite. We need some marker.
Perhaps the obvious thing to do is to use a thunk as the marker. We can make thunk + thunk return another thunk, purely to mark the array as safe to write into. It could be a `Thunk`, or it could be some new `TrivialThunk` which doesn't bother with the closure. Anything downstream must already call `unthunk` if it wants an array, so that would still work.
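A rough sketch of how such a marker could behave, in plain Julia. Everything here is an assumption for illustration: `Owned` plays the role of the proposed `TrivialThunk`, `ToyInplace` stands in for `InplaceableThunk`, and `toy_unthunk`, `accum` and `sum_grad` are hypothetical names, not ChainRulesCore API.

```julia
# Hypothetical stand-in for the proposed TrivialThunk: it only marks the
# wrapped array as safe to overwrite.
struct Owned{A<:AbstractArray}
    val::A
end

# Hypothetical stand-in for InplaceableThunk: `add!` accumulates into an
# existing array, `make` builds the gradient from scratch.
struct ToyInplace{F,G}
    add!::F
    make::G
end

toy_unthunk(o::Owned) = o.val
toy_unthunk(t::ToyInplace) = t.make()

# Two in-place thunks: materialise one, update it with the other, keep the mark.
accum(a::ToyInplace, b::ToyInplace) = Owned(b.add!(toy_unthunk(a)))

# Marked array + in-place thunk: safe to update without a new array.
accum(a::Owned, b::ToyInplace) = Owned(b.add!(a.val))

# Unmarked array + in-place thunk: the array may be shared, so copy first.
accum(a::AbstractArray, b::ToyInplace) = Owned(b.add!(copy(a)))

# Gradient of `sum` for a length-n input: an array of ones.
sum_grad(n) = ToyInplace(dx -> (dx .+= 1.0), () -> ones(n))

# Three terms: only the first `make` builds an array, the rest update it.
total = accum(accum(sum_grad(3), sum_grad(3)), sum_grad(3))

# An unmarked array is copied rather than overwritten.
shared = zeros(3)
result = accum(shared, sum_grad(3))
```

In this model downstream code still calls `toy_unthunk` to get a plain array, so the marker costs one wrapper struct and no closure.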