NNlib.jl
Add gradients for `conv_bias_act`, and a similar `dense_bias_act`
This aims to add gradient definitions for the existing `conv_bias_act`. That part is, however, very much a WIP, and I don't recommend anyone try to read it just yet.
It also adds an analogous `dense_bias_act`, which is closer to done. What this gains you over `σ.(w*x .+ b)` is memory savings: Zygote will by default un-fuse the broadcast, allocating 3 arrays on the forward pass, but we can often over-write the result of `w*x`, saving 2 copies. This should happen on both CPU and GPU. There is one more copy you could save on the reverse pass, bringing you to 1/2 of the original memory use, but only if you were sure the pullback would be called exactly once. That isn't true for, say, `Zygote.jacobian`, and I don't think there's a way to know when it would be safe. So, inside Zygote, we save 1/3 rather than 1/2.
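To make the allocation counting concrete, here is a minimal forward-pass illustration (a sketch, not the code in this PR), using `σ = tanh` and the same sizes as the benchmark further down:

```julia
σ = tanh
w, b = rand(Float32, 100, 100), rand(Float32, 100)
x = rand(Float32, 100, 100)

# Un-fused, roughly what Zygote sees for σ.(w*x .+ b): three full-size arrays.
y1 = w * x          # array 1
y2 = y1 .+ b        # array 2
y3 = σ.(y2)         # array 3

# What dense_bias_act can do instead: allocate w*x once, then overwrite it.
y = w * x           # array 1, and the only one
y .= σ.(y .+ b)     # fused broadcast writes back into y, no further full-size arrays
```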
I say "often" because over-writing w*x only works when the gradient of σ can be written in terms of its output, without saving its input. That's true for tanh and relu and some others, which are ~~explicitly whitelisted here as INPLACE_ACTS. Surely a more extensible method for that could be invented.~~ now handled using https://github.com/JuliaDiff/ChainRulesCore.jl/pull/453 .
This was written before seeing https://github.com/FluxML/NNlibCPU.jl/pull/1, but they may work well together -- for instance, the `dense!` function there could (after we adjust signatures a little) simply overload a function here, providing a fast path when that package is loaded. Likewise it could overload `conv_bias_act!` to run a fused activation-and-convolution on the CPU, a bit like the existing NNlibCUDA routine. (From a first glance it looks like `dense!` has a trait for deciding which functions are in-place-safe, which is good.) Again, not fully baked, but opened now to start the discussion.
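To illustrate the kind of hook being discussed, here is a hypothetical sketch -- the name `dense_bias_act!` and its signature are assumptions for exposition, not the actual interface of this PR or of NNlibCPU:

```julia
using LinearAlgebra: mul!

# Generic in-place fallback that NNlib could own; a package like NNlibCPU could
# then overload this one function with a hand-tuned multithreaded kernel.
function dense_bias_act!(y, σ::F, w, x, b) where {F}
    mul!(y, w, x)       # y = w*x, written into the caller's buffer
    y .= σ.(y .+ b)     # fused bias + activation, still in place
    return y
end

# In NNlibCPU, roughly:
#   NNlib.dense_bias_act!(y, σ, w, x, b) = some_multithreaded_kernel!(y, σ, w, x, b)
```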
Best to leave that discussion over on https://github.com/FluxML/NNlibCPU.jl/pull/1.
We have wanted to move to fused conv for some cuDNN fast paths too, but I'm unsure whether it's a good idea to have all these different versions.
Is the original message up top still accurate? It looks like the implementation is there. What help is necessary to get this through?
My memory is that this basically worked, but the performance was disappointing due to https://github.com/JuliaLang/julia/issues/43153. Writing back into the same `x` (when safe) saved memory but not time, unless you pirated Base methods as suggested there. (Which it looks like I didn't do on this branch?)
Edit: OK, I've updated things. I think the most honest benchmark looks like this: it shows a serious improvement from `tanh_fast`, and one copy saved by `bias_act!`, while now avoiding the serious slowdown. But it's still slower than ideal; why 71 allocations?
```julia
julia> w, b = rand(Float32, 100, 100), rand(Float32, 100); x = rand(Float32, size(w)...);

julia> @btime gradient((w,x,b) -> sum(abs2, dense_bias_act(tanh, w, x, b)), wr[], $x, $b) setup=(wr=Ref(randn(Float32,100,100))) evals=1;
  min 44.792 μs, mean 79.901 μs (71 allocations, 198.37 KiB)

julia> @btime gradient((w,x,b) -> sum(abs2, tanh.((w * x) .+ b)), wr[], $x, $b) setup=(wr=Ref(randn(Float32,100,100))) evals=1;
  min 114.583 μs, mean 158.989 μs (39 allocations, 275.25 KiB)

julia> @btime gradient((w,x,b) -> sum(abs2, tanh_fast.((w * x) .+ b)), wr[], $x, $b) setup=(wr=Ref(randn(Float32,100,100))) evals=1;
  min 40.125 μs, mean 75.140 μs (39 allocations, 275.25 KiB)
```
It would be worthwhile to benchmark this on other computers (this is an M1 + Apple's BLAS), and on GPUs, and for conv... and ideally in bigger examples; whatever happened to https://github.com/FluxML/FluxBench.jl?
Rebased at https://github.com/mcabbott/NNlib.jl/tree/bias_act_22 after squashing, but its own tests fail.