Add gradients for `conv_bias_act`, and a similar `dense_bias_act`

Open mcabbott opened this issue 4 years ago • 4 comments

This aims to add gradient definitions for the existing conv_bias_act. That is, however, very much WIP, and I don't recommend anyone try to read it just yet.

It also adds an analogous dense_bias_act, which is closer to done. What this gains you over σ.(w*x .+ b) is memory savings. Zygote will by default un-fuse the broadcast, allocating 3 arrays on the forward pass, but in fact we can often over-write the result of w*x, saving 2 copies. This should happen both on CPU and GPU. There is one more copy you could save on the reverse pass, bringing you to 1/2 the memory usage of before, but only if you were sure that the pullback would only be called once. That isn't true for say Zygote.jacobian, and I don't think there's a way to know when it will be safe. So we save 1/3 not 1/2, when inside Zygote.
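To make the forward-pass saving concrete, here is a minimal sketch of the over-writing idea. It is not the code from this PR: `dense_bias_act_sketch` is a hypothetical name, and the sketch assumes `w * x` returns a freshly allocated array that is safe to mutate.

```julia
using LinearAlgebra

# Hypothetical helper for illustration only, not NNlib's dense_bias_act.
# The matrix product allocates one array; the bias addition and the
# activation are then fused into a single broadcast that writes back
# into that same array instead of allocating two more.
function dense_bias_act_sketch(σ, w, x, b)
    y = w * x           # the only array allocated here
    y .= σ.(y .+ b)     # fused broadcast, over-writing y in place
    return y
end

w = [1.0 2.0; 3.0 4.0]
x = [0.5 -1.0; 0.25 2.0]
b = [0.1, -0.2]

y_fused = dense_bias_act_sketch(tanh, w, x, b)
y_plain = tanh.(w * x .+ b)   # the unfused reference expression
```

Both expressions should agree numerically; the difference is only in how many intermediate arrays an AD system has to keep around.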

I say "often" because over-writing w*x only works when the gradient of σ can be written in terms of its output, without saving its input. That's true for tanh and relu and some others, which are ~~explicitly whitelisted here as INPLACE_ACTS. Surely a more extensible method for that could be invented.~~ now handled using https://github.com/JuliaDiff/ChainRulesCore.jl/pull/453 .
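As an illustration of "written in terms of its output", here is a hand-written pullback sketch. All names (`deriv_from_output`, `dense_act_with_pullback`, and the local `relu`) are hypothetical and defined here only for the example; this is neither the PR's code nor the ChainRulesCore mechanism linked above.

```julia
using LinearAlgebra

# Hypothetical helpers: the derivative of σ expressed through y = σ(z),
# so the pre-activation z never has to be stored.
deriv_from_output(::typeof(tanh), y) = 1 - y^2
relu(z) = max(z, zero(z))   # local definition so the sketch needs no packages
deriv_from_output(::typeof(relu), y) = y > 0 ? one(y) : zero(y)

function dense_act_with_pullback(σ, w, x, b)
    y = w * x
    y .= σ.(y .+ b)                      # over-write the result of w * x
    function pullback(dy)
        # A fresh dz is allocated on every call and y is left untouched,
        # so calling the pullback more than once stays safe.
        dz = dy .* deriv_from_output.(σ, y)
        dw = dz * x'
        dx = w' * dz
        db = vec(sum(dz; dims=2))
        return dw, dx, db
    end
    return y, pullback
end

w = [0.1 -0.2; 0.3 0.4]
x = [1.0 2.0; -1.0 0.5]
b = [0.05, -0.1]

y, back = dense_act_with_pullback(tanh, w, x, b)
dw, dx, db = back(2 .* y)   # gradient of sum(abs2, y) with respect to y is 2y
```

Over-writing `y` inside the pullback would save the remaining copy, but that is exactly the step that breaks when the pullback is called a second time.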

This was written before seeing https://github.com/FluxML/NNlibCPU.jl/pull/1 . But they may work well together -- for instance the function dense! there could (after we adjust signatures a little) simply overload a function here, providing a fast path when that package is loaded. Likewise it can overload conv_bias_act! to run a fused activation-and-convolution on the CPU, a bit like the existing NNlibCUDA routine. (From a first glance it looks like dense! has a trait for deciding which functions are in-place-safe, which is good.) Again, not fully baked, but opened now to start discussing.

mcabbott avatar Aug 09 '21 02:08 mcabbott

Best to let some discussion happen over on https://github.com/FluxML/NNlibCPU.jl/pull/1

We have wanted to move to fused conv for some cudnn fast paths too, but I'm unsure whether it's a good idea to have all these different versions.

DhairyaLGandhi avatar Aug 09 '21 09:08 DhairyaLGandhi

Is the original message up top still accurate? It looks like the implementation is there. What help is necessary to get this through?

darsnack avatar Jan 12 '22 23:01 darsnack

My memory is that this basically worked, but the performance was disappointing due to https://github.com/JuliaLang/julia/issues/43153 . Writing back into the same x (when safe) saved memory but not time, unless you pirated Base things as suggested there. (Which it looks like I didn't do on this branch?)

Edit: OK, I've updated things. I think the most honest benchmark looks like this. It shows a serious improvement from tanh_fast. bias_act! saves one copy and now avoids a serious slowdown, but it is still slower than ideal. Why 71 allocations?

julia> w, b = rand(Float32, 100, 100), rand(Float32, 100); x = rand(Float32, size(w)...);

julia> @btime gradient((w,x,b) -> sum(abs2, dense_bias_act(tanh, w, x, b)), wr[], $x, $b)  setup=(wr=Ref(randn(Float32,100,100))) evals=1;
  min 44.792 μs, mean 79.901 μs (71 allocations, 198.37 KiB)

julia> @btime gradient((w,x,b) -> sum(abs2, tanh.((w * x) .+ b)), wr[], $x, $b)  setup=(wr=Ref(randn(Float32,100,100))) evals=1;
  min 114.583 μs, mean 158.989 μs (39 allocations, 275.25 KiB)

julia> @btime gradient((w,x,b) -> sum(abs2, tanh_fast.((w * x) .+ b)), wr[], $x, $b)  setup=(wr=Ref(randn(Float32,100,100))) evals=1;
  min 40.125 μs, mean 75.140 μs (39 allocations, 275.25 KiB)

It would be worthwhile to benchmark on other computers (this is an M1 with Apple's BLAS), on GPUs, and with conv... and ideally in bigger examples. Whatever happened to https://github.com/FluxML/FluxBench.jl ?

mcabbott avatar Jan 13 '22 02:01 mcabbott

Rebased at https://github.com/mcabbott/NNlib.jl/tree/bias_act_22 after squashing, but its own tests fail.

mcabbott avatar Sep 02 '22 02:09 mcabbott