
Gaussian blur CPU performance

Open mgoulao opened this issue 2 years ago • 12 comments

I have been doing some experiments with PIX, since it allows computing image augmentations on the GPU, in contrast to torchvision, which computes them on the CPU and requires multiple workers to avoid bottlenecks. When running some very simple timeit examples I observed very high times for a Gaussian blur on the CPU. I created a simple Colab notebook to demonstrate these experiments. I even tested transferring the image to the CPU before performing the blur, but it doesn't seem to make any difference. I was wondering if this is intended and I should not rely on CPU computation at all, or if something is yet to be optimized for the CPU.
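For concreteness, here is a minimal sketch of the kind of timing I ran (the Colab notebook has the actual cells; the image shape, loop count, and the CPU transfer via jax.device_put below are only illustrative):

import timeit

import jax
import dm_pix

image = jax.random.uniform(jax.random.PRNGKey(0), (512, 512, 3))

blur = jax.jit(lambda x: dm_pix.gaussian_blur(x, sigma=1.0, kernel_size=5))
blur(image).block_until_ready()  # compile before timing

# Default backend (GPU on a GPU runtime).
print(timeit.timeit(lambda: blur(image).block_until_ready(), number=100))

# Same call with the input committed to the CPU device.
image_cpu = jax.device_put(image, jax.devices("cpu")[0])
blur(image_cpu).block_until_ready()  # recompiles for the CPU backend
print(timeit.timeit(lambda: blur(image_cpu).block_until_ready(), number=100))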

mgoulao avatar Sep 10 '22 10:09 mgoulao

Hi @mgoulao, thanks for reaching out! Yeah, indeed, I've also tested this and it is not performing quite well on CPU. Transferring the image to CPU only helps a little; it's a gain of a few µs over a several-ms operation. This is not technically intended: the goal we try to achieve with PIX is to have implementations that perform well on TPUs/GPUs, taking what we get as a result of this when running on CPUs. This doesn't mean, of course, that we don't want/have to improve the CPU implementation as well 😄 Feel free to submit a PR with any optimisation for CPU!

claudiofantacci avatar Sep 12 '22 08:09 claudiofantacci

Hi @mgoulao

I made a JAX package for stencil computation that can be used to compute the Gaussian blur.

I checked the performance of kernex vs dm_pix on Colab CPU using the following code. It seems that the kernex-backed convolution is faster on the CPU for this specific function.

Hope this helps. Best.

# !pip install dm_pix
# !pip install kernex

import jax
import jax.numpy as jnp
import kernex as kex
import dm_pix
import numpy.testing as npt


def gaussian_blur(image, sigma, kernel_size):
    # 1-D profile -> 2-D kernel via outer product, normalised to sum to one.
    # Note: the rsqrt(sigma) term coincides with the usual
    # exp(-x**2 / (2 * sigma**2)) weighting only for sigma = 1, which is what
    # the check below uses.
    x = jnp.linspace(-(kernel_size - 1) / 2.0, (kernel_size - 1) / 2.0, kernel_size)
    w = jnp.exp(-0.5 * jnp.square(x) * jax.lax.rsqrt(sigma))
    w = jnp.outer(w, w)
    w = w / w.sum()

    # kernex maps the weighted window sum over the image (stencil computation).
    @kex.kmap(kernel_size=(kernel_size, kernel_size), padding="same")
    def conv(x):
        return jnp.sum(x * w)

    return conv(image)


sigma = 1.0
kernel_size = 5

gaussian_blur_pix = jax.jit(lambda x: dm_pix.gaussian_blur(x, sigma, kernel_size))
gaussian_blur_kex = jax.jit(lambda x: gaussian_blur(x, sigma, kernel_size))

x = jax.random.uniform(jax.random.PRNGKey(0), (512, 512))
xx = jnp.expand_dims(x, axis=2)  # dm_pix expects a trailing channel axis
npt.assert_allclose(gaussian_blur_pix(xx)[:, :, 0], gaussian_blur_kex(x), atol=1e-5)

# warm up (compile both before timing)
gaussian_blur_pix(xx)
gaussian_blur_kex(x)

%timeit gaussian_blur_pix(xx).block_until_ready()
%timeit gaussian_blur_kex(x).block_until_ready()

111 ms ± 40 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
11.1 ms ± 3.61 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)

On Colab GPU it seems that kernex performs a bit better:

324 µs ± 111 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
200 µs ± 4.07 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

ASEM000 avatar Oct 02 '22 19:10 ASEM000

Thanks for reporting this as well! I'm a bit short of time at the moment, and for the whole of October I'm afraid. I'll try to have a look ASAP, or at the beginning of November at the latest. In the meantime, if you come up with a better implementation that also works well on CPU, without extra dependencies, feel free to submit a PR! 🚀

claudiofantacci avatar Oct 03 '22 18:10 claudiofantacci

Noted, thanks. I will try to contribute in the coming days. Best.

ASEM000 avatar Oct 03 '22 18:10 ASEM000

Hey,

I implemented dm_pix.gaussian_blur with no extra dependencies in this Colab.

There you can find testing and benchmarking against the depthwise-convolution-based implementation. On Colab CPU I'm getting the following speed-up, based on the timeit average time for the jitted version of both implementations (a rough sketch of the idea follows after the GPU numbers below).

# average time ratio pix/kex for 3x3 kernel
# (64, 64, 1):     12.17
# (128, 128, 1):   14.70
# (256, 256, 1):   17.38
# (512, 512, 1):   16.37
# (64, 64, 32):    62.64
# (128, 128, 32):  44.88
# (256, 256, 32):  36.19
# (512, 512, 32):  36.60
# (64, 64, 64):    42.34
# (128, 128, 64):  80.46
# (256, 256, 64):  57.42
# (512, 512, 64):  54.94

For the Colab GPU, the speed-up ratio is:

# average time ratio pix/kex for 3x3 kernel
# (64, 64, 1):     1.76
# (128, 128, 1):   1.87
# (256, 256, 1):   1.82
# (512, 512, 1):   1.98
# (64, 64, 32):    1.81
# (128, 128, 32):  2.67
# (256, 256, 32):  2.72
# (512, 512, 32):  5.24
# (64, 64, 64):    2.96
# (128, 128, 64):  1.78
# (256, 256, 64):  3.22
# (512, 512, 64):  8.81
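For a rough idea of the approach: it is the same stencil-style windowed sum as the kernex snippet above, but written in plain JAX. The sketch below is a simplified shift-and-add formulation for a single-channel 2-D image with zero ("same") padding, not the exact code in the Colab, which is vectorised over channels and tested:

import jax.numpy as jnp


def gaussian_blur_stencil(image, sigma, kernel_size):
    # Standard separable Gaussian weights, normalised to sum to one
    # (assumes an odd kernel_size).
    x = jnp.linspace(-(kernel_size - 1) / 2.0, (kernel_size - 1) / 2.0, kernel_size)
    kernel = jnp.exp(-0.5 * jnp.square(x) / jnp.square(sigma))
    kernel = jnp.outer(kernel, kernel)
    kernel = kernel / kernel.sum()

    # Zero-pad so the output keeps the input's spatial shape.
    pad = kernel_size // 2
    padded = jnp.pad(image, pad)
    height, width = image.shape

    # Shift-and-add: accumulate shifted views of the padded image weighted by
    # the kernel entries; the Python loops unroll under jit because
    # kernel_size is a static Python int.
    out = jnp.zeros_like(image)
    for i in range(kernel_size):
        for j in range(kernel_size):
            out = out + kernel[i, j] * padded[i:i + height, j:j + width]
    return out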

Let me know if it's suitable for a PR.

Best.

ASEM000 avatar Oct 06 '22 10:10 ASEM000

Thanks @ASEM000, I'll have a look at it as soon as I can, unfortunately that will probably be end of month 😭

claudiofantacci avatar Oct 06 '22 13:10 claudiofantacci

I just skimmed through the code, so I haven't checked the implementation details yet. When you say kex there, you mean the new implementation, i.e. the one without kernex or any extra dependency. Is this right?

claudiofantacci avatar Oct 06 '22 13:10 claudiofantacci

Yes, you are right; sorry for the typo.

ASEM000 avatar Oct 06 '22 14:10 ASEM000

That's ok. Skimming through, it looks good, but please let's resume this at the end of the month so I have more time to look into the code and give proper advice for submitting a PR 😄

claudiofantacci avatar Oct 06 '22 15:10 claudiofantacci

I'm finally back. I'll try to look into this asap!

claudiofantacci avatar Nov 01 '22 09:11 claudiofantacci

Hello, any updates or feedback?

Additionally, I implemented a Gaussian filter based on FFT depthwise convolution, which should be faster for large kernels: https://github.com/ASEM000/serket/blob/main/serket/nn/blur.py. Let me know if you are interested, so I can provide a version with no extra dependencies; a rough sketch of the frequency-domain idea is below.
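The sketch is a simplified single-channel version with circular (wrap-around) boundaries, just to illustrate why the cost stops depending on the kernel size; it is not the serket code, which handles padding and channels properly:

import jax.numpy as jnp


def fft_gaussian_blur(image, sigma, kernel_size):
    # Standard normalised 2-D Gaussian kernel.
    x = jnp.linspace(-(kernel_size - 1) / 2.0, (kernel_size - 1) / 2.0, kernel_size)
    kernel = jnp.exp(-0.5 * jnp.square(x) / jnp.square(sigma))
    kernel = jnp.outer(kernel, kernel)
    kernel = kernel / kernel.sum()

    # Embed the kernel in an image-sized array and roll it so its centre sits
    # at the origin; pointwise multiplication of the FFTs then performs a
    # circular convolution whose cost depends on the image size only.
    padded_kernel = jnp.zeros(image.shape).at[:kernel_size, :kernel_size].set(kernel)
    padded_kernel = jnp.roll(
        padded_kernel,
        shift=(-(kernel_size // 2), -(kernel_size // 2)),
        axis=(0, 1),
    )
    return jnp.fft.irfft2(jnp.fft.rfft2(image) * jnp.fft.rfft2(padded_kernel), s=image.shape)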

ASEM000 avatar Dec 09 '22 05:12 ASEM000

Hey @ASEM000, I have not forgotten about this 😄 I've been quite busy and should finally be back to my normal work regime; I'll try to look at all this asap 🚀

claudiofantacci avatar Dec 20 '22 06:12 claudiofantacci