dm_pix
Gaussian blur CPU performance
I have been doing some experiments with PIX, since it allows computing image augmentations on the GPU, in contrast to torchvision, which computes on the CPU and requires multiple workers to avoid bottlenecks. When running some very simple timeit examples, I observed very high times when performing a Gaussian blur on the CPU. I created a simple Colab notebook to demonstrate these experiments. I even tested transferring the image to CPU before performing the blur, but it doesn't seem to make any difference. I was wondering if this is intended and I should not rely on CPU computations at all, or if something is yet to be optimized for CPU computation.
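Roughly, the experiments boil down to something like the following minimal sketch (the shapes and the explicit CPU transfer via jax.device_put are illustrative, not the exact notebook code):

import jax
import jax.numpy as jnp
import dm_pix

image = jax.random.uniform(jax.random.PRNGKey(0), (512, 512, 3))
blur = jax.jit(lambda x: dm_pix.gaussian_blur(x, sigma=1.0, kernel_size=5))

# Explicitly place the input on the CPU backend before blurring.
image_cpu = jax.device_put(image, jax.devices("cpu")[0])

blur(image_cpu).block_until_ready()  # warm up (triggers compilation)
%timeit blur(image_cpu).block_until_ready()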
Hi @mgoulao, thanks for reaching out! Yeah, indeed I've also tested this, and it does not perform well on CPU. Transferring the image to CPU only helps a little; it's a gain of a few µs over a several-ms operation. This is not technically intended: the goal we try to achieve with PIX is to have implementations that perform well on TPUs/GPUs, accepting what we get as a result of this when running on CPUs. This doesn't mean, of course, that we don't want/have to improve the CPU implementation as well 😄 Feel free to submit a PR with any optimisation for CPU!
Hi @mgoulao,
I made a JAX package for stencil computations, kernex, that can be used to compute the Gaussian blur. I checked the performance of kernex vs dm_pix on Colab CPU using the following code. It seems that the kernex-backed convolution is faster on the CPU for this specific function.
Hope this helps. Best.
# !pip install dm_pix
# !pip install kernex

import jax
import jax.numpy as jnp
import kernex as kex
import dm_pix
import numpy.testing as npt


def gaussian_blur(image, sigma, kernel_size):
    # Normalised 2-D Gaussian kernel from the outer product of the
    # 1-D weights exp(-x^2 / (2 * sigma^2)).
    x = jnp.linspace(-(kernel_size - 1) / 2.0, (kernel_size - 1) / 2.0, kernel_size)
    w = jnp.exp(-0.5 * jnp.square(x) / jnp.square(sigma))
    w = jnp.outer(w, w)
    w = w / w.sum()

    @kex.kmap(kernel_size=(kernel_size, kernel_size), padding="same")
    def conv(x):
        return jnp.sum(x * w)

    return conv(image)


sigma = 1.0
kernel_size = 5
gaussian_blur_pix = jax.jit(lambda x: dm_pix.gaussian_blur(x, sigma, kernel_size))
gaussian_blur_kex = jax.jit(lambda x: gaussian_blur(x, sigma, kernel_size))

x = jax.random.uniform(jax.random.PRNGKey(0), (512, 512))
xx = jnp.expand_dims(x, axis=2)  # dm_pix expects a trailing channel axis
npt.assert_allclose(gaussian_blur_pix(xx)[:, :, 0], gaussian_blur_kex(x), atol=1e-5)

# warm up
gaussian_blur_pix(xx)
gaussian_blur_kex(x)

%timeit gaussian_blur_pix(xx).block_until_ready()
%timeit gaussian_blur_kex(x).block_until_ready()
111 ms ± 40 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
11.1 ms ± 3.61 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
On Colab GPU, it seems that kernex performs a bit better:
324 µs ± 111 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
200 µs ± 4.07 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Thanks for reporting this as well! I'm a bit short of time at the moment, and I'm afraid that will be the case for the whole of October. I'll try to have a look asap, or at the beginning of November at the latest. In the meantime, if you come up with a better implementation that also performs well on CPU without extra dependencies, feel free to submit a PR! 🚀
Noted, thanks. I will try to contribute in the coming days. Best.
Hey,
I implemented dm_pix.gaussian_blur with no extra dependencies in this Colab, where you can find testing and benchmarking against the current depthwise-based implementation (a rough sketch of the depthwise idea follows the tables below). On Colab CPU I'm getting the following speed-ups, based on the timeit average time for the jitted version of both implementations.
# average time ratio pix/kex for 3x3 kernel
# (64, 64, 1): 12.17
# (128, 128, 1): 14.70
# (256, 256, 1): 17.38
# (512, 512, 1): 16.37
# (64, 64, 32): 62.64
# (128, 128, 32): 44.88
# (256, 256, 32): 36.19
# (512, 512, 32): 36.60
# (64, 64, 64): 42.34
# (128, 128, 64): 80.46
# (256, 256, 64): 57.42
# (512, 512, 64): 54.94
For GPU, the speed-up ratios are:
# average time ratio pix/kex for 3x3 kernel
# (64, 64, 1): 1.76
# (128, 128, 1): 1.87
# (256, 256, 1): 1.82
# (512, 512, 1): 1.98
# (64, 64, 32): 1.81
# (128, 128, 32): 2.67
# (256, 256, 32): 2.72
# (512, 512, 32): 5.24
# (64, 64, 64): 2.96
# (128, 128, 64): 1.78
# (256, 256, 64): 3.22
# (512, 512, 64): 8.81
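For reference (a minimal sketch of the depthwise idea the benchmark compares against, not the exact Colab code), the blur can be expressed as two 1-D depthwise convolutions with jax.lax.conv_general_dilated, one per spatial axis:

import jax
import jax.numpy as jnp

def separable_gaussian_blur(image, sigma, kernel_size):
    # image: (H, W, C). 1-D Gaussian weights exp(-x^2 / (2 * sigma^2)), normalised.
    x = jnp.arange(kernel_size, dtype=image.dtype) - (kernel_size - 1) / 2.0
    w = jnp.exp(-0.5 * jnp.square(x) / jnp.square(sigma))
    w = w / w.sum()

    channels = image.shape[-1]
    nchw = image[jnp.newaxis].transpose(0, 3, 1, 2)  # (H, W, C) -> (1, C, H, W)
    # Depthwise: one (O=C, I=1, kh, kw) kernel slice per channel,
    # selected via feature_group_count=C.
    kernel_h = jnp.tile(w.reshape(1, 1, kernel_size, 1), (channels, 1, 1, 1))
    kernel_w = jnp.tile(w.reshape(1, 1, 1, kernel_size), (channels, 1, 1, 1))
    out = jax.lax.conv_general_dilated(nchw, kernel_h, window_strides=(1, 1),
                                       padding="SAME", feature_group_count=channels)
    out = jax.lax.conv_general_dilated(out, kernel_w, window_strides=(1, 1),
                                       padding="SAME", feature_group_count=channels)
    return out[0].transpose(1, 2, 0)  # back to (H, W, C)

Two 1-D passes touch O(2k) pixels per output instead of O(k^2) for a full 2-D kernel, which is why the separable form is the usual starting point.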
Let me know if it's suitable for a PR. Best.
Thanks @ASEM000, I'll have a look at it as soon as I can; unfortunately, that will probably be at the end of the month 😭
I just skimmed through the code, so I haven't checked the implementation details. When you say kex there, do you mean the new implementation, the one without kex or any extra dependency? Is that right?
Yes, you are right; sorry for the typo.
That's ok. Skimming through, it looks good, but please let's resume this EOM so I have more time to look into the code and give proper advice for submitting a PR 😄
I'm finally back. I'll try to look into this asap!
Hello, any updates or feedback?
Additionally, I implemented a Gaussian filter based on FFT depthwise convolution, which should be faster for large kernels: https://github.com/ASEM000/serket/blob/main/serket/nn/blur.py. Let me know if you are interested, and I can provide a version with no extra dependencies.
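The idea, as a minimal sketch rather than the serket code: by the convolution theorem, the blur becomes an element-wise product in the frequency domain, so the per-pixel cost stops growing with the kernel size.

import jax.numpy as jnp

def fft_gaussian_blur(image, sigma, kernel_size):
    # image: (H, W). Build a normalised 2-D Gaussian kernel padded to the image shape.
    x = jnp.arange(kernel_size) - (kernel_size - 1) / 2.0
    w = jnp.exp(-0.5 * jnp.square(x) / jnp.square(sigma))
    w = jnp.outer(w, w)
    w = w / w.sum()
    kernel = jnp.zeros(image.shape).at[:kernel_size, :kernel_size].set(w)
    # Roll the kernel centre to the origin, then multiply the spectra.
    shift = -(kernel_size // 2)
    kernel = jnp.roll(kernel, shift=(shift, shift), axis=(0, 1))
    # Note: FFT convolution is circular, so edges wrap; real code would pad first.
    return jnp.fft.irfft2(jnp.fft.rfft2(image) * jnp.fft.rfft2(kernel), s=image.shape)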
Hey @ASEM000, I have not forgotten about this 😄 I've been quite busy and should finally be back to normal work regime, I'll try to look at all this asap 🚀