Lux.jl
Suboptimal GroupNorm Implementation on GPUs
As observed in https://github.com/SciML/FastDEQ.jl/pull/45#issuecomment-1107662055, we get a 2x speedup by moving from GroupNorm to BatchNorm, which uses cuDNN kernels.
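For reference, a rough micro-benchmark along those lines can be put together with Lux's own layers. This is only a sketch: the array sizes, the group count, and the `fmap(cu, ...)` device transfer are illustrative choices, not the setup from the linked PR.

```julia
using Lux, CUDA, Random, BenchmarkTools
using Functors: fmap

rng = Random.default_rng()
x = CUDA.rand(Float32, 32, 32, 64, 16)   # (W, H, C, N)

gn = GroupNorm(64, 8)                    # 64 channels split into 8 groups
bn = BatchNorm(64)

ps_gn, st_gn = fmap(cu, Lux.setup(rng, gn))
ps_bn, st_bn = fmap(cu, Lux.setup(rng, bn))

# BatchNorm dispatches to the cuDNN kernel via NNlibCUDA; GroupNorm goes through
# the generic vectorized code path.
@btime CUDA.@sync first($gn($x, $ps_gn, $st_gn))
@btime CUDA.@sync first($bn($x, $ps_bn, $st_bn))
```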
cuDNN should support this, so it's mostly a matter of hooking up NNlib + NNlibCUDA (unless you're fine with directly calling the CUDA.jl routines here)
Actually, I don't think cuDNN supports this (at least I couldn't figure it out from its documentation). PyTorch uses its own kernel: https://github.com/pytorch/pytorch/blob/35d4a805ebc3b6eca1bafb2d332dffa8d0c1fc54/aten/src/ATen/native/cuda/group_norm_kernel.cu
I must've hallucinated a mention of groups in the docs for cudnnNormalizationForward* then. The PyTorch kernel is quite a beast, so unless someone's up to the task of translating it, I think we're stuck with the slower vectorized variant for now. Ideally we would figure out why https://triton-lang.org/master/getting-started/tutorials/05-layer-norm.html is so fast, tweak it to run groupnorm instead, and port it to KernelAbstractions or the like. @vchuravy, is KA sufficiently high-level to handle such a translation?
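For context, the "slower vectorized variant" amounts to a reshape-plus-reduction formulation along these lines. This is a minimal sketch only; the function name `naive_groupnorm`, the (W, H, C, N) layout, and the keyword defaults are assumptions for illustration, not Lux's internal API.

```julia
using Statistics

# Naive GroupNorm over a (W, H, C, N) array: fold channels into (C ÷ G, G),
# reduce over (W, H, C ÷ G) per group and sample, then apply the per-channel affine.
# Every broadcast/reduction below is a separate generic GPU kernel launch, which is
# why this lags behind a single fused kernel like PyTorch's.
function naive_groupnorm(x::AbstractArray{T,4}, γ, β, groups::Integer; ϵ=T(1e-5)) where {T}
    W, H, C, N = size(x)
    @assert C % groups == 0
    xr = reshape(x, W, H, C ÷ groups, groups, N)
    μ  = mean(xr; dims=(1, 2, 3))
    σ² = mean(abs2, xr .- μ; dims=(1, 2, 3))        # biased variance, as usual for GN
    x̂  = reshape((xr .- μ) ./ sqrt.(σ² .+ ϵ), W, H, C, N)
    return x̂ .* reshape(γ, 1, 1, C, 1) .+ reshape(β, 1, 1, C, 1)
end

# e.g. naive_groupnorm(CUDA.rand(Float32, 32, 32, 64, 16),
#                      CUDA.ones(Float32, 64), CUDA.zeros(Float32, 64), 8)
```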