fft-conv-pytorch
Speed of depth-wise convolution
Dear Author:
Thank you for your contribution and for sharing this work!
The FFTConv2d layer is much faster than torch.nn.Conv2d in the normal case. However, when I use depthwise convolution by passing groups=channel
to the constructor, it becomes very slow, even slower than the normal convolution. How can I improve it?
The test case follows the README.md, with batch_size=8, kernel_size=31, and padding=15.
Normal convolution result:

```
Conv2d:    0.16026759147644043
FFTConv2d: 0.0806875228881836
```

Depthwise convolution result:

```
Conv2d:    0.02997612953186035
FFTConv2d: 0.11628437042236328
```
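As a side note, these numbers come from timing a single forward pass, which on GPU can be dominated by one-off costs (kernel selection, cuFFT plan creation). A sketch of a more robust measurement, using a warmup phase and averaging over many iterations (the `bench` helper below is my own, not part of the library):

```python
import time

import torch


def bench(fn, x, n_warmup: int = 10, n_iters: int = 50) -> float:
    """Return the average per-call runtime of fn(x) in seconds."""
    # Warmup: the first calls pay one-off costs (kernel selection,
    # cuFFT plan creation) that should not count toward the timing.
    for _ in range(n_warmup):
        fn(x)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # CUDA ops are async; flush the queue
    start = time.time()
    for _ in range(n_iters):
        fn(x)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.time() - start) / n_iters
```

With this helper the comparison becomes e.g. `bench(conv, x)` vs. `bench(fftconv, x)`, which should give more stable numbers than a single timed call.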
My test code is:
```python
import time

import torch
import torch.nn as nn
from fft_conv_pytorch import FFTConv2d


def speed_test(
    batch_size: int = 8,
    channel: int = 4,
    input_size: int = 512,
    kernel_size: int = 31,
    depthwise: bool = True,
):
    if torch.cuda.is_available():
        x = torch.randn(batch_size, channel, input_size, input_size).cuda()
        if depthwise:
            conv = nn.Conv2d(channel, channel, kernel_size, padding=kernel_size // 2,
                             bias=False, groups=channel).cuda()
            fftconv = FFTConv2d(channel, channel, kernel_size, padding=kernel_size // 2,
                                bias=False, groups=channel).cuda()
        else:
            conv = nn.Conv2d(channel, channel, kernel_size, padding=kernel_size // 2,
                             bias=False).cuda()
            fftconv = FFTConv2d(channel, channel, kernel_size, padding=kernel_size // 2,
                                bias=False).cuda()
        fftconv.load_state_dict(conv.state_dict())

        print("time:")
        torch.cuda.synchronize()
        start = time.time()
        y = conv(x)
        torch.cuda.synchronize()
        end = time.time()
        print("Conv2d: ", end - start)

        torch.cuda.synchronize()
        start2 = time.time()
        y2 = fftconv(x)
        torch.cuda.synchronize()
        end2 = time.time()
        print("FFTConv2d: ", end2 - start2)
        print("difference between FFTConv and Conv:", ((y2 - y) ** 2).mean())
```
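For what it's worth, in the depthwise case the frequency-domain product does not need any grouped matrix multiplication at all: with one filter per channel it reduces to a plain elementwise multiply. Below is a minimal sketch of that idea (this is my own illustration, not the library's implementation), using the fact that cross-correlation in the spatial domain corresponds to multiplying by the conjugate kernel spectrum:

```python
import torch
import torch.nn.functional as F


def fft_depthwise_conv2d(x: torch.Tensor, weight: torch.Tensor, padding: int) -> torch.Tensor:
    """Depthwise 2D convolution via FFT.

    x:      (B, C, H, W) input
    weight: (C, 1, kH, kW) depthwise kernel, as stored by nn.Conv2d(groups=C)
    """
    kh, kw = weight.shape[-2:]
    x = F.pad(x, [padding, padding, padding, padding])
    out_h = x.shape[-2] - kh + 1
    out_w = x.shape[-1] - kw + 1

    # FFT of the input and of the kernel, zero-padded to the input size.
    xf = torch.fft.rfft2(x)
    wf = torch.fft.rfft2(weight.squeeze(1), s=x.shape[-2:])  # (C, H, W//2+1)

    # nn.Conv2d computes cross-correlation, which in the frequency domain
    # is multiplication by the conjugate spectrum. One multiply per channel,
    # broadcast over the batch -- no grouped matmul needed.
    yf = xf * wf.conj()
    y = torch.fft.irfft2(yf, s=x.shape[-2:])

    # Crop to the valid (non-circular) region.
    return y[..., :out_h, :out_w]
```

I don't know whether this maps directly onto the library's internals, but if the grouped path currently loops over groups or builds a block-diagonal product, a specialized elementwise branch for groups == in_channels might close the gap with cuDNN's dedicated depthwise kernels.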