CUDA-based RMSNorm for Performance Optimization
In continuation of issue #194.
@erfanzar
The current RMSNorm implementation in rms_norm uses Pallas for TPUs and falls back to basic JAX operations for other platforms. On NVIDIA GPUs, a specialized CUDA kernel could significantly improve performance.
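For reference, the decomposed-ops fallback computes, in essence, the following (a minimal sketch of the math only, not the actual EasyDeL code; parameter handling and dtypes are simplified):

# Minimal sketch of what the non-Pallas fallback computes (illustrative only).
import jax
import jax.numpy as jnp

def rms_norm_reference(x, weight, eps=1e-5):
    # Scale each row by the inverse RMS of its elements, then apply the weight.
    inv_rms = jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
    return x * inv_rms * weight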
Proposed Solution:
- implement a CUDA kernel for the RMSNorm forward and backward passes
- use warp-level reductions for efficient computation (a sketch follows this list)
- add fused operations to reduce memory bandwidth
- preserve API compatibility with the existing implementation
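As a rough illustration of the warp-level reduction idea, a tuned kernel could use a helper like the one below in place of a shared-memory tree reduction. The name and structure are illustrative only; this helper is not used by the prototype further down.

# Illustrative warp-level partial-sum helper (not part of the prototype below).
_warp_reduce_src = r'''
__inline__ __device__ float warp_reduce_sum(float val) {
    // Each step adds the value held by the lane `offset` positions higher;
    // after five steps, lane 0 holds the sum over the 32 lanes of the warp.
    for (int offset = 16; offset > 0; offset >>= 1) {
        val += __shfl_down_sync(0xffffffff, val, offset);
    }
    return val;
}
'''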
Expected Benefits:
- 2-3x speedup for RMSNorm operations on NVIDIA GPUs
- reduced memory pressure during normalization operations
- better performance for models with many normalization layers (e.g., LLMs)
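The 2-3x figure is an estimate; a rough harness like the one below (a hypothetical helper, not part of the prototype) could be used to measure it against the pure-JAX reference above.

# Hypothetical timing helper; `fn` can be the prototype's rms_norm_cuda below
# or the pure-JAX reference sketched above.
import time

def benchmark_rmsnorm(fn, x, weight, iters=100):
    fn(x, weight).block_until_ready()   # warm-up (compilation / kernel load)
    start = time.perf_counter()
    for _ in range(iters):
        out = fn(x, weight)
    out.block_until_ready()             # wait for async dispatch to finish
    return (time.perf_counter() - start) / iters

The high-level prototype: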
import os
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
from jax import custom_vjp
import cupy as cp
# CUDA kernel for RMSNorm forward pass - simplified high-level version
_rmsnorm_fwd_kernel = cp.RawKernel(r'''
extern "C" __global__ void rmsnorm_fwd(
const float* x, // input tensor [B, N]
const float* weight, // scale parameter [N]
float* output, // output tensor [B, N]
    float* inv_rms,          // inverse RMS for backward pass [B]
const int B, // batch size
const int N, // hidden dimension
const float epsilon // epsilon for numerical stability
) {
// Each block handles one batch element
const int b = blockIdx.x;
if (b >= B) return;
// Step 1: Calculate sum of squares using parallel reduction
__shared__ float sdata[1024]; // Shared memory for reduction
float sum_squared = 0.0f;
// Thread parallel sum of squares calculation
for (int i = threadIdx.x; i < N; i += blockDim.x) {
float val = x[b * N + i];
sum_squared += val * val;
}
// Store in shared memory for reduction
sdata[threadIdx.x] = sum_squared;
__syncthreads();
// Parallel reduction in shared memory
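    // NOTE: this simple tree reduction assumes blockDim.x is a power of two;
    // with the min(1024, N) launch used below, partial sums can be dropped otherwise.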
for (int s = blockDim.x/2; s > 0; s >>= 1) {
if (threadIdx.x < s) {
sdata[threadIdx.x] += sdata[threadIdx.x + s];
}
__syncthreads();
}
// Calculate RMS (one thread per batch element)
if (threadIdx.x == 0) {
float mean_squared = sdata[0] / N;
float rms = sqrt(mean_squared + epsilon);
inv_rms[b] = 1.0f / rms;
}
__syncthreads();
// Normalize and scale (all threads)
float rms_inv = inv_rms[b];
for (int i = threadIdx.x; i < N; i += blockDim.x) {
output[b * N + i] = x[b * N + i] * rms_inv * weight[i];
}
}
''', 'rmsnorm_fwd')
# CUDA kernel for RMSNorm backward pass - simplified
_rmsnorm_bwd_kernel = cp.RawKernel(r'''
extern "C" __global__ void rmsnorm_bwd(
const float* grad_output, // gradient from output [B, N]
const float* x, // input tensor [B, N]
const float* weight, // scale parameter [N]
    const float* inv_rms,    // inverse RMS from forward [B]
float* grad_input, // gradient for input [B, N]
float* grad_weight, // gradient for weight [N]
const int B, // batch size
const int N // hidden dimension
) {
// Basic backward pass implementation
// 1. Compute grad_weight (reduction over batch)
// 2. Compute grad_input with proper chain rule application
// Simplified here for the prototype
extern __shared__ float sdata[];
// Compute grad_input
int b = blockIdx.x;
if (b < B) {
float rms_inv = inv_rms[b];
// First part: direct gradient
for (int i = threadIdx.x; i < N; i += blockDim.x) {
grad_input[b * N + i] = grad_output[b * N + i] * weight[i] * rms_inv;
}
// Second part: indirect gradient through normalization
// Simplified implementation for prototype
}
    // Compute grad_weight (only done by the first block)
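    // For simplicity, grad_weight is computed by a single block with a serial
    // loop over the batch; a production kernel would use atomics or a second
    // reduction kernel instead.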
if (blockIdx.x == 0) {
for (int i = threadIdx.x; i < N; i += blockDim.x) {
float sum = 0.0f;
for (int b = 0; b < B; b++) {
sum += grad_output[b * N + i] * x[b * N + i] * inv_rms[b];
}
grad_weight[i] = sum;
}
}
}
''', 'rmsnorm_bwd')
# JAX integration for the CUDA kernel
# eps is a static hyperparameter; marking it non-differentiable means the
# backward rule only has to return cotangents for x and weight.
@partial(custom_vjp, nondiff_argnums=(2,))
def rms_norm_cuda(x, weight, eps=1e-5):
"""RMSNorm with CUDA implementation and custom gradient."""
# Forward pass implementation
B, N = x.shape
# Copy data to GPU
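    # NOTE: assumes the JAX array can be handed to CuPy (DLPack /
    # __cuda_array_interface__); under jit or autodiff tracing, x is a tracer,
    # so a real integration would wrap this in jax.pure_callback or the JAX FFI.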
x_dev = cp.asarray(x)
weight_dev = cp.asarray(weight)
# Allocate output memory
output_dev = cp.empty_like(x_dev)
inv_rms_dev = cp.empty((B,), dtype=cp.float32)
# Launch kernel
threads_per_block = min(1024, N)
blocks_per_grid = B
_rmsnorm_fwd_kernel(
grid=(blocks_per_grid,),
block=(threads_per_block,),
        args=(x_dev, weight_dev, output_dev, inv_rms_dev,
              cp.int32(B), cp.int32(N), cp.float32(eps))  # scalars must match the kernel's C types
)
# Copy result back to CPU/JAX
output = jnp.array(cp.asnumpy(output_dev))
return output
# Define forward pass for VJP
def rms_norm_fwd(x, weight, eps):
"""Forward pass for custom VJP."""
B, N = x.shape
# Same implementation as above, but store context for backward
x_dev = cp.asarray(x)
weight_dev = cp.asarray(weight)
output_dev = cp.empty_like(x_dev)
inv_rms_dev = cp.empty((B,), dtype=cp.float32)
threads_per_block = min(1024, N)
blocks_per_grid = B
_rmsnorm_fwd_kernel(
grid=(blocks_per_grid,),
block=(threads_per_block,),
        args=(x_dev, weight_dev, output_dev, inv_rms_dev,
              cp.int32(B), cp.int32(N), cp.float32(eps))  # scalars must match the kernel's C types
)
output = jnp.array(cp.asnumpy(output_dev))
    ctx = (x_dev, weight_dev, inv_rms_dev, B, N)
return output, ctx
# Define backward pass for VJP
def rms_norm_bwd(eps, ctx, grad_output):
    """Backward pass for custom VJP; eps arrives first as the non-diff argument."""
    x_dev, weight_dev, inv_rms_dev, B, N = ctx
grad_output_dev = cp.asarray(grad_output)
grad_input_dev = cp.empty_like(x_dev)
grad_weight_dev = cp.empty_like(weight_dev)
threads_per_block = min(1024, N)
blocks_per_grid = B
_rmsnorm_bwd_kernel(
grid=(blocks_per_grid,),
block=(threads_per_block,),
        args=(grad_output_dev, x_dev, weight_dev, inv_rms_dev,
              grad_input_dev, grad_weight_dev, cp.int32(B), cp.int32(N)),  # match the kernel's C scalar types
shared_mem=threads_per_block * 4
)
grad_input = jnp.array(cp.asnumpy(grad_input_dev))
grad_weight = jnp.array(cp.asnumpy(grad_weight_dev))
    # Cotangents for (x, weight) only; eps is marked non-differentiable above.
    return grad_input, grad_weight
# Register the custom VJP
rms_norm_cuda.defvjp(rms_norm_fwd, rms_norm_bwd)
# Integration with EasyDeL's existing code
def integrated_rms_norm(x, weight, blocksize_x=8, eps=1e-5, prod_dtype=jnp.float32):
"""RMSNorm with platform-specific optimizations."""
# Check if we're on GPU and can use CUDA
platform = jax.extend.backend.get_backend().platform
if platform == "gpu":
# Check if CUDA is available
try:
import cupy
return rms_norm_cuda(x, weight, eps)
except ImportError:
pass
# Fallback to original implementation
return original_rms_norm(x, weight, blocksize_x, eps, prod_dtype)
# Example usage in a transformer model:
def transformer_layer(x, params, training=True):
# Apply RMSNorm before attention
norm_x = integrated_rms_norm(x, params['prenorm_weight'])
# Rest of transformer layer operations
# ...
return output
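A hypothetical sanity check for the forward pass (assumes a CUDA-capable GPU with CuPy installed; shapes and tolerances are arbitrary):

# Hypothetical sanity check: compare the CUDA forward pass against the
# decomposed-ops formula. Requires a CUDA GPU with CuPy installed.
def _check_forward():
    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, (4, 1024), dtype=jnp.float32)
    w = jnp.ones((1024,), dtype=jnp.float32)
    ref = x * jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + 1e-5) * w
    out = rms_norm_cuda(x, w)
    np.testing.assert_allclose(np.asarray(out), np.asarray(ref), rtol=1e-4, atol=1e-5)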
This is a high-level prototype that demonstrates the approach and the key technical elements that would be implemented.
Thanks @RohitRathore1, this is awesome!
Sure, I'll integrate this into EasyDeL soon.
It should take about 1 or 2 days; I'll let you know when it's done.
There is a misconception here. The RMSNorm will lower into HLO as decomposed JAX ops, but the XLA GPU compiler will fuse those ops together, potentially along with neighboring ops. So you might end up with, say, a dynamically generated RMSNorm+Matmul Triton kernel with autotuned block sizes, and that will actually beat this RMSNorm CUDA example.
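A minimal way to see this in practice (a sketch, assuming a GPU backend): jit the decomposed-ops RMSNorm and inspect the optimized module XLA produces.

# Sketch: inspect what the XLA GPU compiler does with the decomposed JAX ops
# (same formula as the reference sketch near the top of this issue).
import jax
import jax.numpy as jnp

def rms_norm_ref(x, w, eps=1e-5):
    return x * jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps) * w

x = jnp.ones((8, 4096), dtype=jnp.float32)
w = jnp.ones((4096,), dtype=jnp.float32)
compiled = jax.jit(rms_norm_ref).lower(x, w).compile()
print(compiled.as_text())  # optimized HLO; on GPU the normalization ops should show up fused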