
Use `NNlib.bias_act!`

mcabbott opened this issue 2 years ago.

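Uses `NNlib.bias_act!` from https://github.com/FluxML/NNlib.jl/pull/457, which adds the bias and applies the activation function in place. This speeds things up and saves memory, using up to half the memory for a forward pass. The largest savings in the gradient should come with large batch sizes, and with activation functions such as `identity`, `relu`, and `tanh`, whose derivative can be computed from the output, so their input need not be stored.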
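To illustrate where the savings come from, here is a minimal sketch in plain Julia (no Flux or NNlib required). The names `bias_act_sketch!` and `tanh_pullback_sketch` are hypothetical, and this is not how NNlib implements `bias_act!`. It only shows the two ideas: the forward pass overwrites its input instead of allocating a new array, and for an activation such as `tanh` the gradient can be computed from the output alone, so the pre-activation need not be kept.

```julia
# Hypothetical illustration only; not NNlib's implementation of bias_act!.

# Forward: compute σ.(x .+ b) and write the result back into x.
# The fused broadcast allocates no temporary array.
function bias_act_sketch!(σ, x::AbstractMatrix, b::AbstractVector)
    x .= σ.(x .+ b)   # b is broadcast along the batch (second) dimension
    return x
end

# Backward for σ = tanh: d tanh(z)/dz = 1 - tanh(z)^2 = 1 - y^2,
# so only the output y is needed, not the input x.
# (For relu the derivative is likewise y > 0, and for identity it is 1.)
function tanh_pullback_sketch(y::AbstractMatrix, dy::AbstractMatrix)
    dx = dy .* (1 .- y .^ 2)       # gradient w.r.t. the pre-activation x .+ b
    db = vec(sum(dx; dims = 2))    # bias gradient: sum over the batch
    return dx, db
end

x = [0.5 -1.0; 2.0 0.0]            # 2 features × 2 samples
b = [0.1, -0.2]
y = bias_act_sketch!(tanh, x, b)   # y === x: no new output array
dx, db = tanh_pullback_sketch(y, ones(size(y)))
```

In this sketch the caller gives up the original contents of `x`, which is only acceptable when nothing else needs the pre-activation values afterwards.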

```julia
julia> using Flux, BenchmarkTools  # @btime comes from BenchmarkTools

julia> lenet = Chain(  # from the model zoo
           Conv((5, 5), 1=>6, relu),
           MaxPool((2, 2)),
           Conv((5, 5), 6=>16, relu),
           MaxPool((2, 2)),
           Flux.flatten,
           Dense(256 => 120, relu),
           Dense(120 => 84, relu),
           Dense(84 => 10),
       );

julia> img = rand32(28, 28, 1, 128);

julia> @btime $lenet($img);
  min 867.875 μs, mean 1.434 ms (160 allocations, 5.60 MiB)  # before
  min 831.500 μs, mean 1.100 ms (149 allocations, 3.31 MiB)  # after

julia> @btime gradient(m -> sum(abs2, m($img)), $lenet);
  min 7.128 ms, mean 10.280 ms (567 allocations, 14.19 MiB)  # before
  min 6.296 ms, mean 6.930 ms (546 allocations, 9.61 MiB)  # after
```
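As a quick check on these figures, the relative memory saving can be computed directly from the reported allocations. The numbers below are copied from the benchmark output; nothing here re-runs the benchmark, and `saving` is just an illustrative helper.

```julia
# Fraction of memory saved, from the "before" and "after" MiB figures.
saving(before, after) = 1 - after / before

forward_saving  = saving(5.60, 3.31)    # forward pass, about 0.41
gradient_saving = saving(14.19, 9.61)   # gradient, about 0.32
```

So on this model the forward pass allocates roughly 41% less memory and the gradient roughly 32% less.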

Closes #2151, which I had forgotten about.

mcabbott, Sep 04 '23 22:09