fft-conv-pytorch

Speed of depth-wise convolution

Open · lim1011 opened this issue · 0 comments

Dear Author:

Thank you for this work and for sharing it!

The FFTConv2d layer is roughly twice as fast as torch.nn.Conv2d for standard convolution. However, when I switch to depthwise convolution by passing groups=channel to the constructor, it becomes very slow: slower than nn.Conv2d with the same groups setting, and even slower than FFTConv2d on the standard (non-grouped) convolution. How can I improve it?
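To make the question concrete, here is what a depthwise FFT convolution has to compute. This is a minimal sketch written directly with `torch.fft`, not fft-conv-pytorch's implementation; the helper name `depthwise_fft_conv2d` is hypothetical, and it assumes stride 1, no dilation, no bias, and a weight of shape `(C, 1, kH, kW)` as used by `nn.Conv2d(C, C, k, groups=C)`. With one filter per channel, the frequency-domain step is a per-channel elementwise product with no sum over input channels, unlike the standard case where each output channel sums products over all input channels. Whether such a path beats cuDNN's depthwise kernel on a given GPU would have to be measured.

```python
import torch
import torch.nn.functional as F


def depthwise_fft_conv2d(x: torch.Tensor, weight: torch.Tensor, padding: int) -> torch.Tensor:
    """Depthwise (groups == channels) 2-D convolution computed with real FFTs.

    x: (B, C, H, W); weight: (C, 1, kH, kW), the layout nn.Conv2d(groups=C) uses.
    Hypothetical helper for illustration, not part of fft-conv-pytorch.
    """
    k_h, k_w = weight.shape[-2:]
    xp = F.pad(x, (padding, padding, padding, padding))  # zero-pad like nn.Conv2d
    h, w = xp.shape[-2:]
    x_f = torch.fft.rfft2(xp)                      # (B, C, h, w // 2 + 1)
    w_f = torch.fft.rfft2(weight[:, 0], s=(h, w))  # kernel zero-padded to (h, w)
    # nn.Conv2d computes cross-correlation, i.e. multiply by the conjugate spectrum.
    # Depthwise: per-channel product broadcast over the batch, no sum over channels.
    y = torch.fft.irfft2(x_f * w_f.conj(), s=(h, w))
    # Outputs past h - kH (or w - kW) wrap around circularly; keep the valid region.
    return y[..., : h - k_h + 1, : w - k_w + 1]
```

Comparing against `F.conv2d(..., groups=C)` in double precision is a quick way to confirm that the cross-correlation convention (conjugated kernel spectrum) and the final crop are right.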

The test case follows the benchmark in README.md, with batch_size=8, kernel_size=31, padding=15.
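A rough operation count helps show why `groups` helps `nn.Conv2d` much more than FFT convolution. The sketch below is a hypothetical cost model, not a profile of either library: it assumes the input is zero-padded to (H + 2p) x (W + 2p), every kernel slice is transformed at that size, a real FFT of N points costs about 2.5 N log2 N FLOPs, and a complex multiply-add costs about 8 FLOPs. The helper names `direct_conv_macs` and `fft_conv_flops` are illustrative, and channel=4, input_size=512 are taken from the defaults of the test code.

```python
import math


def direct_conv_macs(b: int, c_in: int, c_out: int, h: int, w: int, k: int, groups: int) -> int:
    """Multiply-accumulates for a stride-1, 'same'-padded direct convolution."""
    # Every output element sums (c_in // groups) * k * k products.
    return b * c_out * h * w * (c_in // groups) * k * k


def fft_conv_flops(b: int, c_in: int, c_out: int, h: int, w: int, k: int, groups: int) -> float:
    """Rough real-FLOP estimate for FFT convolution (hypothetical cost model)."""
    p = k // 2
    hp, wp = h + 2 * p, w + 2 * p       # assumed FFT size: the padded input
    n = hp * wp
    rfft_cost = 2.5 * n * math.log2(n)  # about half the classic 5 N log2 N complex-FFT count
    # One forward FFT per input map, one per kernel slice, one inverse per output map.
    n_transforms = b * c_in + c_out * (c_in // groups) + b * c_out
    bins = hp * (wp // 2 + 1)           # half-spectrum size of a real 2-D FFT
    # Complex multiply-add (~8 real FLOPs) for every (batch, out, in-per-group) triple.
    pointwise = 8 * b * c_out * (c_in // groups) * bins
    return n_transforms * rfft_cost + pointwise


cfg = dict(b=8, c_in=4, c_out=4, h=512, w=512, k=31)
direct_ratio = direct_conv_macs(**cfg, groups=1) / direct_conv_macs(**cfg, groups=4)
fft_ratio = fft_conv_flops(**cfg, groups=1) / fft_conv_flops(**cfg, groups=4)
print(f"groups=4 cuts direct-conv work by {direct_ratio:.2f}x, FFT-conv work by {fft_ratio:.2f}x")
```

For this configuration, grouping divides the direct-convolution work by exactly 4 (the input channels per group), while the FFT estimate drops only by about 1.3x, because the forward and inverse transforms of the input and output maps do not shrink with `groups`. Under this model FFT convolution still needs fewer FLOPs than the direct depthwise convolution, so operation counts alone do not explain the measured slowdown; memory traffic for the padded complex spectra and any one-time setup inside a single timed call are not modeled.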

Standard convolution results (seconds):

Conv2d:  0.16026759147644043
FFTConv2d:  0.0806875228881836

Depthwise convolution results (seconds):

Conv2d:  0.02997612953186035
FFTConv2d:  0.11628437042236328
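Expressed as speedups of FFTConv2d over nn.Conv2d, these timings give roughly 2x for standard convolution and about 0.26x (almost 4x slower) for depthwise. A small check of that arithmetic, using only the reported numbers; the `speedup` helper is illustrative:

```python
# Reported single-call timings from the two runs (seconds).
normal = {"Conv2d": 0.16026759147644043, "FFTConv2d": 0.0806875228881836}
depthwise = {"Conv2d": 0.02997612953186035, "FFTConv2d": 0.11628437042236328}


def speedup(t: dict) -> float:
    # Values above 1 mean FFTConv2d finished faster than nn.Conv2d.
    return t["Conv2d"] / t["FFTConv2d"]


print(f"standard: {speedup(normal):.2f}x, depthwise: {speedup(depthwise):.2f}x")
print(f"nn.Conv2d gets {normal['Conv2d'] / depthwise['Conv2d']:.1f}x faster with groups=channel")
print(f"FFTConv2d gets {depthwise['FFTConv2d'] / normal['FFTConv2d']:.1f}x slower with groups=channel")
```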

My test code is:

import time

import torch
import torch.nn as nn
from fft_conv_pytorch import FFTConv2d


def speed_test(
    batch_size: int = 8,
    channel: int = 4,
    input_size: int = 512,
    kernel_size: int = 31,
    depthwise: bool = True,
):
    if torch.cuda.is_available():
        x = torch.randn(batch_size, channel, input_size, input_size).cuda()

        if depthwise:
            # groups=channel: each filter sees exactly one input channel
            conv = nn.Conv2d(channel, channel, kernel_size, padding=kernel_size // 2, bias=False, groups=channel).cuda()
            fftconv = FFTConv2d(channel, channel, kernel_size, padding=kernel_size // 2, bias=False, groups=channel).cuda()
        else:
            conv = nn.Conv2d(channel, channel, kernel_size, padding=kernel_size // 2, bias=False).cuda()
            fftconv = FFTConv2d(channel, channel, kernel_size, padding=kernel_size // 2, bias=False).cuda()
        # copy the weights so both layers compute the same convolution
        fftconv.load_state_dict(conv.state_dict())

        print("time:")

        # Each layer is timed over a single call, so any one-time setup cost on
        # that call (e.g. cuDNN initialization, cuFFT plan creation) is included.
        torch.cuda.synchronize()
        start = time.time()
        y = conv(x)
        torch.cuda.synchronize()
        end = time.time()
        print("Conv2d: ", end - start)

        torch.cuda.synchronize()
        start2 = time.time()
        y2 = fftconv(x)
        torch.cuda.synchronize()
        end2 = time.time()
        print("FFTConv2d: ", end2 - start2)

        # mean squared error between the two outputs
        print("difference between FFTConv and Conv:", ((y2 - y) ** 2).mean().item())

lim1011, Aug 28 '24 09:08
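Because the test code times a single call per layer, its numbers can include one-time setup and launch overhead. A common refinement is to run a few warm-up calls and report the median of several timed calls. The helper below is a hedged sketch of that pattern; `time_module` and its `warmup`/`iters` defaults are illustrative choices, not part of fft-conv-pytorch. It accepts CPU or CUDA tensors and synchronizes only when the input is on the GPU.

```python
import time

import torch
import torch.nn as nn


def time_module(module: nn.Module, x: torch.Tensor, warmup: int = 3, iters: int = 20) -> float:
    """Median wall-clock seconds for one forward pass of module(x)."""
    # CUDA kernels launch asynchronously, so wait for them before reading the clock.
    sync = torch.cuda.synchronize if x.is_cuda else (lambda: None)
    with torch.no_grad():
        # Warm-up calls absorb one-time costs (e.g. cuDNN initialization,
        # cuFFT plan creation) that a single timed call would include.
        for _ in range(warmup):
            module(x)
        sync()
        times = []
        for _ in range(iters):
            start = time.perf_counter()
            module(x)
            sync()
            times.append(time.perf_counter() - start)
    times.sort()
    return times[len(times) // 2]
```

With `conv`, `fftconv`, and `x` built as in `speed_test`, one would compare `time_module(conv, x)` against `time_module(fftconv, x)` for both `depthwise=True` and `depthwise=False`. PyTorch's `torch.utils.benchmark.Timer` performs warm-ups and synchronizes CUDA as well, if a library tool is preferred.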