Depthwise convolutions produce a large number of allocations
Depthwise convolutions, which are currently implemented as a standard Conv layer with the number of groups equal to the number of input channels, seem to produce a very large number of allocations compared to the old DepthwiseConv layer. To confirm this, I restored the DepthwiseConv layer removed in #1921 and compared the performance to the current implementation.
Running the code below shows the following:
- Standard Convolution: 42.935 ms (45 allocations: 143.94 MiB)
- Grouped Convolution: 4.326 ms (22555 allocations: 146.40 MiB)
- Depthwise Convolution: 9.008 ms (30 allocations: 143.94 MiB)
Conv with groups=1024 produces around 750 times as many allocations as DepthwiseConv. The result is that CNN architectures which rely on depthwise convolutions produce hundreds of thousands of allocations compared to only a few thousand for comparably sized models with standard convolutions. Is there any reason for this discrepancy?
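To see how the allocation count depends on the number of groups, a smaller sweep than the full reproducer below can be used. This is only an illustrative sketch (the channel count and input size here are chosen arbitrarily, not taken from the benchmark above):

```julia
# Illustrative sweep: vary only `groups` and compare the allocation counts
# reported by @btime. Sizes are arbitrary and much smaller than the reproducer below.
using Flux, BenchmarkTools

x = rand(Float32, 28, 28, 64, 1)
for g in (1, 8, 64)
    c = Conv((3, 3), 64 => 64, relu; pad=SamePad(), groups=g)
    print("groups = $g: ")
    @btime $c($x)   # prints time, allocation count, and memory
end
```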
Note: I am testing this on Julia 1.10 with Flux v0.14.22.
```julia
using Flux, BenchmarkTools

# Restored DepthwiseConv layer (removed in #1921).
struct DepthwiseConv{N,M,F,A,V}
    σ::F
    weight::A
    bias::V
    stride::NTuple{N,Int}
    pad::NTuple{M,Int}
    dilation::NTuple{N,Int}
end

# For reference: the constructor Flux currently uses, which just builds a grouped Conv.
# It takes the same positional arguments as the old-style constructor further down,
# which therefore replaces it.
function DepthwiseConv(k::NTuple{<:Any,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
                       stride = 1, pad = 0, dilation = 1, bias = true, init = Flux.glorot_uniform)
    Conv(k, ch, σ; groups=ch.first, stride, pad, dilation, bias, init)
end

function DepthwiseConv(w::AbstractArray{T,N}, bias = true, σ = identity;
                       stride = 1, pad = 0, dilation = 1) where {T,N}
    stride = Flux.expand(Val(N-2), stride)
    dilation = Flux.expand(Val(N-2), dilation)
    pad = Flux.calc_padding(DepthwiseConv, pad, size(w)[1:N-2], dilation, stride)
    b = Flux.create_bias(w, bias, prod(size(w)[N-1:end]))
    return DepthwiseConv(σ, w, b, stride, pad, dilation)
end

function DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
                       init = Flux.glorot_uniform, stride = 1, pad = 0, dilation = 1,
                       bias = true) where N
    @assert ch[2] % ch[1] == 0 "Output channels must be integer multiple of input channels"
    weight = depthwiseconvfilter(k, ch, init = init)
    return DepthwiseConv(weight, bias, σ; stride, pad, dilation)
end

Flux.@functor DepthwiseConv

# Weight layout of the old layer: (k..., cout ÷ cin, cin)
depthwiseconvfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
                    init = Flux.glorot_uniform) where N = init(filter..., div(ch[2], ch[1]), ch[1])

function (c::DepthwiseConv)(x)
    σ = NNlib.fast_act(c.σ, x)
    cdims = DepthwiseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
    σ.(depthwiseconv(x, c.weight, cdims) .+ Flux.conv_reshape_bias(c))
end

x = rand(Float32, 28, 28, 1024, 1)

conv = Conv((3,3), 1024 => 1024, relu, pad=SamePad())
grouped_conv = Conv((3,3), 1024 => 1024, relu, pad=SamePad(), groups=1024)
depthwise_conv = DepthwiseConv((3,3), 1024 => 1024, relu, pad=SamePad())

@btime conv(x);            # standard convolution
@btime grouped_conv(x);    # Conv with groups = input channels (current DepthwiseConv)
@btime depthwise_conv(x);  # restored DepthwiseConv layer
```
I see I'm blamed in https://github.com/FluxML/Flux.jl/pull/1921 for suggesting that change, although I've forgotten why.
With the code above, I see numbers similar to yours: grouped_conv is faster, but it makes many small allocations:
```julia
julia> depthwise_conv.weight .= grouped_conv.weight;

julia> y1 = @btime grouped_conv(x);
  4.430 ms (23577 allocations: 201.13 MiB)

julia> y2 = @btime depthwise_conv(x);
  15.584 ms (27 allocations: 199.06 MiB)

julia> y1 ≈ y2
true
```
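My guess (an assumption on my part, I haven't profiled NNlib here) is that the grouped path does some work per group, and anything that processes groups one at a time, whether that's slicing out buffers or spawning a task per group, will naturally produce a number of small allocations proportional to `groups`. A toy version of that pattern, purely for illustration and not NNlib's actual code:

```julia
# Toy illustration only -- NOT NNlib's implementation. Slicing the input and weights
# per group and convolving each slice separately allocates a handful of arrays per
# group, so with groups=1024 the small allocations add up quickly.
# Assumes the grouped-Conv weight layout: size(w) == (k, k, cin ÷ groups, cout).
using NNlib

function grouped_conv_by_slicing(x, w, groups; pad=1)
    cin  = size(x, 3) ÷ groups          # input channels per group
    cout = size(w, 4) ÷ groups          # output channels per group
    ys = map(1:groups) do g
        xg = x[:, :, (g-1)*cin+1:g*cin, :]     # copy: one allocation per group
        wg = w[:, :, :, (g-1)*cout+1:g*cout]   # copy: one allocation per group
        conv(xg, wg; pad=pad)                  # plus a fresh output buffer per group
    end
    cat(ys...; dims=3)                         # concatenate along the channel dimension
end
```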
Repeating the benchmarks of #1921 today: Flux.DepthwiseConv (now a grouped Conv) makes many small allocations, more than were seen in #1921, although even then it was an increase over the old layer:
```julia
julia> x = randn(Float32, 128, 128, 32, 32);

julia> dconv1 = Flux.DepthwiseConv((3,3), 32 => 64)  # using groups, after #1921
Conv((3, 3), 32 => 64, groups=32)  # 640 parameters

julia> z1 = @btime $dconv1($x);
  38.161 ms (1236 allocations: 370.29 MiB)

julia> dconv2 = DepthwiseConv((3,3), 32 => 64);  # using the code above

julia> copyto!(dconv2.weight, dconv1.weight);  # 3×3×2×32 from 3×3×1×64

julia> z2 = @btime $dconv2($x);
  45.090 ms (42 allocations: 370.16 MiB)

julia> z1 ≈ z2
true

julia> Threads.nthreads()
4
```
I think the NNlib CPU conv code is still in need of some care: more and more layers of multi-threading have been added over time, and some of them probably ought to be pruned.
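For a sense of scale (an illustration, not a measurement of NNlib itself): every spawned task carries some allocation overhead, so nested per-group or per-tile threading can account for a lot of small allocations on its own:

```julia
# Illustration only: each spawned task allocates, so code that spawns one task per
# group (or per tile, per batch slice, ...) accumulates small allocations roughly
# proportional to the number of tasks.
using BenchmarkTools
using Base.Threads: @spawn

function spawn_n_tasks(n)
    tasks = [@spawn nothing for _ in 1:n]
    foreach(wait, tasks)
end

@btime spawn_n_tasks(1024);  # allocation count grows roughly linearly with n
```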