
CUDA-based RMSNorm for Performance Optimization

Open RohitRathore1 opened this issue 7 months ago • 2 comments

In continuation of issue #194.

@erfanzar

The current RMSNorm implementation in rms_norm uses Pallas for TPUs and falls back to basic JAX operations for other platforms. On NVIDIA GPUs, a specialized CUDA kernel could significantly improve performance.
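
For reference, the decomposed-ops fallback amounts to roughly the following (a minimal sketch, not EasyDeL's exact code; the function name and the eps default are illustrative):

import jax
import jax.numpy as jnp

def rms_norm_reference(x, weight, eps=1e-6):
    # Compute the inverse root-mean-square over the last axis in float32, then scale.
    x32 = x.astype(jnp.float32)
    inv_rms = jax.lax.rsqrt(jnp.mean(jnp.square(x32), axis=-1, keepdims=True) + eps)
    return (x32 * inv_rms * weight).astype(x.dtype)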

Proposed Solution:

  • Implement a CUDA kernel for the RMSNorm forward and backward passes
  • Use warp-level reductions for efficient computation
  • Add fused operations to reduce memory bandwidth
  • Preserve API compatibility with the existing implementation

Expected Benefits:

  • An estimated 2-3x speedup for RMSNorm operations on NVIDIA GPUs
  • Reduced memory pressure during normalization
  • Better performance for models with many normalization layers (e.g., LLMs)
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax import custom_vjp
import cupy as cp

# CUDA kernel for RMSNorm forward pass - simplified high-level version
_rmsnorm_fwd_kernel = cp.RawKernel(r'''
extern "C" __global__ void rmsnorm_fwd(
    const float* x,          // input tensor [B, N]
    const float* weight,     // scale parameter [N]
    float* output,           // output tensor [B, N]  
    float* inv_rms,          // inverse RMS for backward pass [B, 1]
    const int B,             // batch size
    const int N,             // hidden dimension
    const float epsilon      // epsilon for numerical stability
) {
    // Each block handles one batch element
    const int b = blockIdx.x;
    if (b >= B) return;
    
    // Step 1: Calculate the sum of squares with a shared-memory tree reduction.
    // NOTE: this reduction assumes blockDim.x is a power of two; the Python
    // launch code below rounds the block size up accordingly.
    __shared__ float sdata[1024]; // Shared memory for reduction (max 1024 threads per block)
    float sum_squared = 0.0f;
    
    // Thread parallel sum of squares calculation
    for (int i = threadIdx.x; i < N; i += blockDim.x) {
        float val = x[b * N + i];
        sum_squared += val * val;
    }
    
    // Store in shared memory for reduction
    sdata[threadIdx.x] = sum_squared;
    __syncthreads();
    
    // Parallel reduction in shared memory
    for (int s = blockDim.x/2; s > 0; s >>= 1) {
        if (threadIdx.x < s) {
            sdata[threadIdx.x] += sdata[threadIdx.x + s];
        }
        __syncthreads();
    }
    
    // Calculate RMS (one thread per batch element)
    if (threadIdx.x == 0) {
        float mean_squared = sdata[0] / N;
        float rms = sqrtf(mean_squared + epsilon);
        inv_rms[b] = 1.0f / rms;
    }
    __syncthreads();
    
    // Normalize and scale (all threads)
    float rms_inv = inv_rms[b];
    for (int i = threadIdx.x; i < N; i += blockDim.x) {
        output[b * N + i] = x[b * N + i] * rms_inv * weight[i];
    }
}
''', 'rmsnorm_fwd')

# CUDA kernel for RMSNorm backward pass - simplified
_rmsnorm_bwd_kernel = cp.RawKernel(r'''
extern "C" __global__ void rmsnorm_bwd(
    const float* grad_output,    // gradient from output [B, N]
    const float* x,              // input tensor [B, N]
    const float* weight,         // scale parameter [N]
    const float* inv_rms,        // inverse RMS from forward [B, 1]
    float* grad_input,           // gradient for input [B, N]
    float* grad_weight,          // gradient for weight [N]
    const int B,                 // batch size
    const int N                  // hidden dimension
) {
    // Basic backward pass implementation
    // 1. Compute grad_weight (reduction over batch)
    // 2. Compute grad_input with proper chain rule application
    
    // Simplified here for the prototype
    extern __shared__ float sdata[]; // reserved for a block-level reduction; unused in this simplified version
    
    // Compute grad_input
    int b = blockIdx.x;
    if (b < B) {
        float rms_inv = inv_rms[b];
        
        // First part: direct gradient
        for (int i = threadIdx.x; i < N; i += blockDim.x) {
            grad_input[b * N + i] = grad_output[b * N + i] * weight[i] * rms_inv;
        }
        
        // Second part (omitted in this prototype): the indirect gradient through
        // the normalization, i.e. subtract
        //   x[i] * inv_rms^3 / N * sum_j(grad_output[j] * weight[j] * x[j]),
        // which needs another block-level reduction over j.
    }
    
    // Compute grad_weight (done only by the first block to avoid a cross-block race)
    if (blockIdx.x == 0) {
        for (int i = threadIdx.x; i < N; i += blockDim.x) {
            float sum = 0.0f;
            for (int b = 0; b < B; b++) {
                sum += grad_output[b * N + i] * x[b * N + i] * inv_rms[b];
            }
            grad_weight[i] = sum;
        }
    }
}
''', 'rmsnorm_bwd')

# JAX integration for the CUDA kernel
@partial(custom_vjp, nondiff_argnums=(2,))  # eps is a static hyperparameter, not differentiated
def rms_norm_cuda(x, weight, eps=1e-5):
    """RMSNorm with CUDA implementation and custom gradient."""
    # Forward pass implementation
    B, N = x.shape
    
    # Copy data to the GPU via the host (simple, but slow; fine for a prototype).
    # NOTE: a production integration would expose the kernel through
    # jax.pure_callback or an XLA custom call so it composes with jit/grad.
    x_dev = cp.asarray(np.asarray(x, dtype=np.float32))
    weight_dev = cp.asarray(np.asarray(weight, dtype=np.float32))
    
    # Allocate output memory
    output_dev = cp.empty_like(x_dev)
    inv_rms_dev = cp.empty((B,), dtype=cp.float32)
    
    # Launch kernel: one block per batch row. The block size must be a power
    # of two for the shared-memory reduction inside the kernel.
    threads_per_block = min(1024, 1 << (N - 1).bit_length())
    blocks_per_grid = B

    _rmsnorm_fwd_kernel(
        (blocks_per_grid,),
        (threads_per_block,),
        (x_dev, weight_dev, output_dev, inv_rms_dev,
         np.int32(B), np.int32(N), np.float32(eps)),
    )
    
    # Copy result back to CPU/JAX
    output = jnp.array(cp.asnumpy(output_dev))
    return output

# Define forward pass for VJP
def rms_norm_fwd(x, weight, eps):
    """Forward pass for custom VJP."""
    B, N = x.shape
    
    # Same implementation as above, but store context for backward
    x_dev = cp.asarray(np.asarray(x, dtype=np.float32))
    weight_dev = cp.asarray(np.asarray(weight, dtype=np.float32))
    output_dev = cp.empty_like(x_dev)
    inv_rms_dev = cp.empty((B,), dtype=cp.float32)
    
    threads_per_block = min(1024, 1 << (N - 1).bit_length())  # power of two for the reduction
    blocks_per_grid = B

    _rmsnorm_fwd_kernel(
        (blocks_per_grid,),
        (threads_per_block,),
        (x_dev, weight_dev, output_dev, inv_rms_dev,
         np.int32(B), np.int32(N), np.float32(eps)),
    )
    
    output = jnp.array(cp.asnumpy(output_dev))
    ctx = (x_dev, weight_dev, inv_rms_dev, B, N)  # eps is non-diff and is passed to bwd directly
    return output, ctx

# Define backward pass for VJP
def rms_norm_bwd(eps, ctx, grad_output):
    """Backward pass for custom VJP; the non-diff eps argument comes first."""
    x_dev, weight_dev, inv_rms_dev, B, N = ctx
    
    grad_output_dev = cp.asarray(np.asarray(grad_output, dtype=np.float32))
    grad_input_dev = cp.empty_like(x_dev)
    grad_weight_dev = cp.empty_like(weight_dev)
    
    threads_per_block = min(1024, 1 << (N - 1).bit_length())  # power of two for the reduction
    blocks_per_grid = B

    _rmsnorm_bwd_kernel(
        (blocks_per_grid,),
        (threads_per_block,),
        (grad_output_dev, x_dev, weight_dev, inv_rms_dev,
         grad_input_dev, grad_weight_dev, np.int32(B), np.int32(N)),
        shared_mem=threads_per_block * 4,  # dynamic shared memory (unused in the simplified kernel)
    )
    
    grad_input = jnp.array(cp.asnumpy(grad_input_dev))
    grad_weight = jnp.array(cp.asnumpy(grad_weight_dev))
    
    return grad_input, grad_weight  # one cotangent per differentiable argument (x, weight)

# Register the custom VJP
rms_norm_cuda.defvjp(rms_norm_fwd, rms_norm_bwd)

# Integration with EasyDeL's existing code
def integrated_rms_norm(x, weight, blocksize_x=8, eps=1e-5, prod_dtype=jnp.float32):
    """RMSNorm with platform-specific optimizations."""
    # Check if we're on GPU and can use CUDA
    platform = jax.extend.backend.get_backend().platform
    
    if platform == "gpu":
        # Check if CUDA is available
        try:
            import cupy
            return rms_norm_cuda(x, weight, eps)
        except ImportError:
            pass
    
    # Fall back to the existing implementation (EasyDeL's Pallas/JAX rms_norm)
    return original_rms_norm(x, weight, blocksize_x, eps, prod_dtype)

# Example usage in a transformer model:
def transformer_layer(x, params, training=True):
    # Apply RMSNorm before attention
    norm_x = integrated_rms_norm(x, params['prenorm_weight'])
    
    # Rest of the transformer layer (attention, MLP, residuals) omitted
    # ...
    output = norm_x  # placeholder so the sketch is self-consistent
    
    return output

This is basically a high-level prototype that demonstrates the approach and sketches the key technical elements that would need to be implemented.
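
As a quick sanity check, the forward pass of the prototype can be compared against a plain-JAX reference on random data (a hypothetical test, assuming the rms_norm_cuda definition above and a CUDA-capable GPU; the backward path would additionally need a pure_callback-style integration before jax.grad could be checked the same way):

import numpy as np
import jax.numpy as jnp

B, N, eps = 4, 1024, 1e-5
rng = np.random.default_rng(0)
x = jnp.asarray(rng.standard_normal((B, N)), dtype=jnp.float32)
w = jnp.asarray(rng.standard_normal((N,)), dtype=jnp.float32)

# Reference forward pass with plain JAX ops.
ref = x / jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps) * w

out = rms_norm_cuda(x, w, eps)
np.testing.assert_allclose(np.asarray(out), np.asarray(ref), rtol=1e-4, atol=1e-5)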

RohitRathore1 avatar May 21 '25 05:05 RohitRathore1

Thanks @RohitRathore1, this is awesome!

Sure, I'll soon integrate this into EasyDeL.

It should take about 1 or 2 days; I'll let you know when it's done.

erfanzar avatar May 23 '25 17:05 erfanzar

There is a misconception here. The RMSNorm will lower into the HLO as decomposed JAX ops, but the XLA GPU compiler will fuse those ops together, potentially along with neighboring ops as well. So you might end up with, say, a dynamically generated Triton kernel for RMSNorm+matmul with autotuned block sizes, and that will actually beat this hand-written RMSNorm CUDA example.
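
One way to check this is to jit the decomposed-op RMSNorm and look at what XLA actually compiles, or just time it (a minimal sketch; exactly which fusions or Triton kernels show up depends on the JAX/XLA version and flags):

import timeit
import jax
import jax.numpy as jnp

def rms_norm_xla(x, weight, eps=1e-5):
    # Plain decomposed ops; the XLA GPU backend fuses these into a single kernel.
    return x * jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps) * weight

x = jnp.ones((8, 4096), dtype=jnp.float32)
w = jnp.ones((4096,), dtype=jnp.float32)

fused = jax.jit(rms_norm_xla)
print(fused.lower(x, w).compile().as_text())  # optimized HLO; the norm typically appears as a single fusion

fused(x, w).block_until_ready()  # warm-up (triggers compilation)
t = timeit.timeit(lambda: fused(x, w).block_until_ready(), number=100)
print(f"fused XLA RMSNorm: {t / 100 * 1e6:.1f} us/call")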

patrick-toulme avatar Jul 07 '25 02:07 patrick-toulme