
Suboptimal GroupNorm Implementation on GPUs

Open avik-pal opened this issue 3 years ago • 3 comments

As observed in https://github.com/SciML/FastDEQ.jl/pull/45#issuecomment-1107662055, we get a 2x speedup by moving from GroupNorm to BatchNorm, which uses cuDNN kernels.
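
For context on why a batchnorm-style kernel is even applicable here: GroupNorm statistics can be expressed as per-channel statistics over a reshaped array by folding each (group, sample) pair into the channel dimension. A minimal sketch, assuming Julia's (W, H, C, N) layout; `groupnorm_via_reshape` is a name made up for this sketch, not Lux/NNlib API:

```julia
using Statistics

# GroupNorm via reshape: fold each group's channels into the "spatial" axes
# and each (group, sample) pair into the channel axis, so per-channel
# statistics over dims (1, 2, 3) are exactly the GroupNorm statistics.
function groupnorm_via_reshape(x::AbstractArray{T,4}, G::Int; eps=T(1f-5)) where {T}
    W, H, C, N = size(x)
    @assert C % G == 0
    xr = reshape(x, W, H, C ÷ G, G * N)   # valid: column-major, contiguous dims
    μ  = mean(xr; dims=(1, 2, 3))
    σ² = var(xr; dims=(1, 2, 3), mean=μ, corrected=false)
    y  = (xr .- μ) ./ sqrt.(σ² .+ eps)
    reshape(y, W, H, C, N)
end
```

The reshaped array has exactly the layout a per-channel batchnorm kernel expects (G * N "channels", batch size 1), so in principle GroupNorm itself could be routed through a fast batchnorm path; GroupNorm's per-channel affine parameters would then have to be applied separately.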

avik-pal avatar Apr 24 '22 13:04 avik-pal

cuDNN should support this, so it's mostly a matter of hooking up NNlib + NNlibCUDA (unless you're fine with directly calling the CUDA.jl routines here)
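
For concreteness, a hypothetical sketch of that hookup; NNlib has no `groupnorm` entry point at the time of this thread, so both method signatures below are illustrative rather than real API:

```julia
# Would live in NNlib: a generic fallback for any array type, e.g. the
# reshape-based version sketched above.
groupnorm(x::AbstractArray, G::Integer; eps=1f-5) = groupnorm_via_reshape(x, G; eps=eps)

# Would live in NNlibCUDA: an overload intercepting CuArrays, routing to a
# cuDNN call or a handwritten CUDA.jl kernel instead of the generic path.
# (`fused_gpu_groupnorm` is a placeholder name, not an existing function.)
# using CUDA
# groupnorm(x::CuArray, G::Integer; eps=1f-5) = fused_gpu_groupnorm(x, G, eps)
```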

ToucheSir avatar Apr 24 '22 19:04 ToucheSir

Actually, I don't think cuDNN supports this (at least, I couldn't figure it out from its documentation). PyTorch uses its own kernel: https://github.com/pytorch/pytorch/blob/35d4a805ebc3b6eca1bafb2d332dffa8d0c1fc54/aten/src/ATen/native/cuda/group_norm_kernel.cu

avik-pal avatar May 27 '22 06:05 avik-pal

I must've hallucinated a mention of groups in the docs for cudnnNormalizationForward* then. The PyTorch kernel is quite a beast, so unless someone's up to the task of translating it, I think we're stuck with the slower vectorized variant for now. Ideally we would figure out why https://triton-lang.org/master/getting-started/tutorials/05-layer-norm.html is so fast, tweak it to run groupnorm instead, and port it to KernelAbstractions or the like. @vchuravy is KA sufficiently high-level to handle such a translation?
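
As a rough indication of what the KA side could look like, here is a minimal sketch assuming the current KernelAbstractions API (`@kernel`, `get_backend`). It deliberately punts on the hard part: the per-group reduction uses array-level `mean`/`var` (which already run on GPU), and only the elementwise normalization is a kernel; the fused single-pass reduction that makes the Triton version fast would still have to be written. `groupnorm_ka` and `gn_scale_shift!` are names made up for this sketch.

```julia
using KernelAbstractions, Statistics

# Elementwise normalization over the reshaped array; dim 4 indexes the
# (group, sample) pairs, so μ and σ² have shape (1, 1, 1, G * N).
@kernel function gn_scale_shift!(y, @Const(x), @Const(μ), @Const(σ²), eps)
    i = @index(Global, Cartesian)
    y[i] = (x[i] - μ[1, 1, 1, i[4]]) / sqrt(σ²[1, 1, 1, i[4]] + eps)
end

function groupnorm_ka(x::AbstractArray{T,4}, G::Int; eps=T(1f-5)) where {T}
    W, H, C, N = size(x)
    xr = reshape(x, W, H, C ÷ G, G * N)   # fold (group, sample) into dim 4
    μ  = mean(xr; dims=(1, 2, 3))
    σ² = var(xr; dims=(1, 2, 3), mean=μ, corrected=false)
    y  = similar(xr)
    backend = get_backend(xr)
    gn_scale_shift!(backend)(y, xr, μ, σ², eps; ndrange=size(xr))
    KernelAbstractions.synchronize(backend)
    reshape(y, W, H, C, N)
end
```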

ToucheSir avatar May 27 '22 17:05 ToucheSir