Use `NNlib.bias_act!`
Uses https://github.com/FluxML/NNlib.jl/pull/457 to speed Flux up and save memory: up to half the memory of a forward pass. The largest savings in the gradient come with large batch sizes, and with activation functions like `identity`, `relu`, and `tanh`, whose input need not be stored for the backward pass.
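For context, `bias_act!(σ, x, b)` computes `σ.(x .+ b)` while (when safe) overwriting the buffer already allocated for `x`, rather than broadcasting into a fresh array. A minimal sketch of the idea as it applies to a `Dense` layer's forward pass, assuming the `bias_act!(σ, x, b)` signature from the linked NNlib PR (array sizes here are illustrative):

```julia
using NNlib  # provides relu and bias_act!

W = randn(Float32, 120, 256)
b = randn(Float32, 120)
x = randn(Float32, 256, 128)

y = W * x                                # the matrix product allocates a buffer
z_old = relu.(y .+ b)                    # unfused: the broadcast allocates a second array
z_new = NNlib.bias_act!(relu, W * x, b)  # fused: reuses the product's buffer in place

z_old ≈ z_new  # true
```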
```julia
julia> lenet = Chain(  # from the model zoo
           Conv((5, 5), 1=>6, relu),
           MaxPool((2, 2)),
           Conv((5, 5), 6=>16, relu),
           MaxPool((2, 2)),
           Flux.flatten,
           Dense(256 => 120, relu),
           Dense(120 => 84, relu),
           Dense(84 => 10),
       );
```
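The benchmarks below are on the CPU, and assume `using Flux, BenchmarkTools`: `rand32` is Flux's `Float32` array constructor, and `@btime` is presumably BenchmarkTools' benchmarking macro.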
```julia
julia> img = rand32(28, 28, 1, 128);

julia> @btime $lenet($img);
  min 867.875 μs, mean 1.434 ms (160 allocations, 5.60 MiB)  # before
  min 831.500 μs, mean 1.100 ms (149 allocations, 3.31 MiB)  # after

julia> @btime gradient(m -> sum(abs2, m($img)), $lenet);
  min 7.128 ms, mean 10.280 ms (567 allocations, 14.19 MiB)  # before
  min 6.296 ms, mean 6.930 ms (546 allocations, 9.61 MiB)    # after
```
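The activation functions singled out above are exactly those whose derivative can be recomputed from the output alone, which is why their input need not be kept for the backward pass. A sketch of that rule, with an illustrative helper name (not an NNlib API):

```julia
using NNlib: relu

# Derivative of y = σ(x) written in terms of the output y alone;
# `deriv_from_output` is a hypothetical name for illustration.
deriv_from_output(::typeof(identity), y) = one(y)
deriv_from_output(::typeof(relu), y)     = ifelse(y > 0, one(y), zero(y))
deriv_from_output(::typeof(tanh), y)     = 1 - y^2  # tanh'(x) = 1 - tanh(x)^2
```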
Closes #2151, which I forgot about.