
Fast Convolutions and Performance in NNlib

jessebett opened this issue 5 years ago · 6 comments

The convolutions provided by the FastConv package, described in their paper, considerably outperform NNlib's backends for 1D and 2D convolutions, at least on CPU.

using FastConv
using NNlib
using BenchmarkTools

# Single-channel, single-image input and a 5x5 kernel
x = randn(500, 500, 1, 1)
spatial_dims = (5, 5)
k = randn(spatial_dims..., 1, 1)

# Pad so NNlib computes the same "full" convolution as FastConv's convn
cdims = DenseConvDims(x, k; padding = spatial_dims .- 1)

fast_y = @btime convn(x, k);
# 9.582 ms (8 allocations: 1.94 MiB)

nnlib_y = @btime conv(x, k, cdims);
# 244.020 ms (33 allocations: 5.80 MiB)

nnlib_im2col_y = @btime NNlib.conv_im2col(x, k, cdims);
# 10.453 ms (50 allocations: 50.39 MiB)

isapprox(fast_y, nnlib_y, atol=1e-3)
# true
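For intuition, here is a minimal sketch of the kind of direct "full" convolution FastConv performs (illustrative only, not FastConv's actual code): for a single channel, every output element is just an accumulation of 5×5 products, so there is no need for the large im2col buffer NNlib allocates.

function direct_conv2d_full(x::AbstractMatrix{T}, k::AbstractMatrix{T}) where {T}
    H, W = size(x)
    kh, kw = size(k)
    y = zeros(T, H + kh - 1, W + kw - 1)
    # Scatter form of full convolution: each input pixel contributes to a
    # kh x kw patch of the output, weighted by the kernel.
    @inbounds for j in 1:W, i in 1:H, q in 1:kw, p in 1:kh
        y[i + p - 1, j + q - 1] += x[i, j] * k[p, q]
    end
    return y
end

# On the example above, direct_conv2d_full(x[:, :, 1, 1], k[:, :, 1, 1])
# should agree with convn(x, k) up to floating-point error.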

jessebett · Oct 22 '19 21:10

As expected, the GPU is a completely different story. FastConv can be extended to GPU arrays, which I've done in this PR to FastConv.jl.

However, FastConv requires scalar getindex operations, which slow things down considerably on the GPU. NNlib on the GPU is clearly fine, so this issue is about the CPU implementation of convolutions.

using CuArrays
using FastConv
using NNlib
using BenchmarkTools

x = randn(500, 500, 1, 1) |> cu
spatial_dims = (5, 5)
k = randn(spatial_dims..., 1, 1) |> cu

cdims = DenseConvDims(x, k; padding = spatial_dims .- 1)

# NNlib on GPU
@btime CuArrays.@sync conv(x, k, cdims);
# 224.942 μs (87 allocations: 4.05 KiB)

# FastConv on CPU
@btime convn($(collect(x)), $(collect(k)));
# 8.431 ms (8 allocations: 992.86 KiB)
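As a side check (a sketch, assuming the GPU-extended convn really does fall back to element-wise indexing), CuArrays can be told to throw on scalar getindex, which makes any such slow path obvious:

using CuArrays
CuArrays.allowscalar(false)   # scalar getindex/setindex! on a CuArray now errors
# convn(x, k)                 # would throw here if FastConv indexes the CuArray element-by-element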

jessebett · Oct 23 '19 17:10

cc @staticfloat

When I looked at ImageFiltering.jl, the main issue was that it wasn't designed to support or scale well across large channel dimensions. Not sure if FastConv has the same issue; given the speedup here it could be useful to just dispatch to it for the cases where it makes sense.

Should be quite easy to hook up for anyone interested.
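A rough sketch of what that dispatch could look like (conv_with_fallback and direct_conv2d_full are illustrative names, not NNlib API; a real version would also need to check that the padding and stride in cdims actually correspond to the full convolution that convn computes):

using NNlib

# Fall back to a direct convolution only in the regime where it wins:
# single input channel, single image, single output channel.
function conv_with_fallback(x::Array{T,4}, w::Array{T,4}, cdims::DenseConvDims) where {T}
    if size(x, 3) == 1 && size(x, 4) == 1 && size(w, 4) == 1
        y = direct_conv2d_full(x[:, :, 1, 1], w[:, :, 1, 1])   # or FastConv.convn
        return reshape(y, size(y)..., 1, 1)
    end
    return NNlib.conv(x, w, cdims)                             # existing im2col-based path
end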

MikeInnes · Oct 24 '19 11:10

@staticfloat I was gonna try comparing with multiple channels and batches, but it looks like you're right and FastConv doesn't support them (or possibly I'm using it incorrectly):

using FastConv
using NNlib
using BenchmarkTools

# 3 channels, 1 batch
x = randn(500, 500, 3, 1);
spatial_dims = (5, 5);
k = randn(spatial_dims..., 3, 1);
cdims = DenseConvDims(x, k; padding = spatial_dims .- 1);

fast_y = convn(x, k);
nnlib_y = conv(x, k, cdims);

fast_y |> size   # (504, 504, 5, 1)
nnlib_y |> size  # (504, 504, 1, 1)

jessebett · Oct 24 '19 17:10

Yeah, FastConv doesn't support multiple channels; what you're doing there is instead a 3D convolution, so the third dimension of the output grows to 3 + 3 - 1 = 5. That makes it difficult to compare apples to apples here.
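For what it's worth, one way to get something closer to apples-to-apples (a sketch, assuming convn accepts plain 2D arrays) is to emulate NNlib's dense convolution by convolving each input channel separately and summing the results:

using FastConv, NNlib

x = randn(500, 500, 3, 1)
k = randn(5, 5, 3, 1)
cdims = DenseConvDims(x, k; padding = (5, 5) .- 1)

# NNlib's dense conv sums the per-channel 2D convolutions into one output channel
fast_y  = sum(convn(x[:, :, c, 1], k[:, :, c, 1]) for c in 1:3)   # (504, 504)
nnlib_y = conv(x, k, cdims)                                       # (504, 504, 1, 1)

isapprox(fast_y, dropdims(nnlib_y; dims=(3, 4)); atol=1e-3)

This does three separate passes over the input, so it's not how a real multi-channel kernel would be written, but it at least produces comparable outputs for benchmarking.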

staticfloat · Oct 24 '19 17:10

You may be interested in https://github.com/FluxML/NNlib.jl/pull/142

staticfloat · Nov 14 '19 09:11

@staticfloat nice!

jessebett · Nov 14 '19 17:11