
LayerNorm Gradient Flow Issue in candle-nn

Open tymat opened this issue 5 months ago • 0 comments

Summary

LayerNorm in candle-nn does not propagate gradients to all parameters during backpropagation. In the minimal example below, only 2 of the 6 parameters (33%) receive gradients. This prevents models that use LayerNorm from training correctly.

Environment

  • candle-core version: 0.8.0
  • candle-nn version: 0.8.0
  • Platform: macOS with Metal backend (also affects CUDA)
  • Rust version: 1.75+

Minimal Reproducible Example

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Module, VarBuilder, VarMap};

#[test]
fn test_candle_layernorm_gradient_flow() -> Result<()> {
    // Setup
    let device = Device::Cpu; // Also fails on Metal/CUDA
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
    
    // Create a simple model: Linear -> LayerNorm -> Linear
    let hidden_size = 64;
    let batch_size = 4;
    
    // Build model
    let linear1 = candle_nn::linear(hidden_size, hidden_size, vb.pp("linear1"))?;
    let layer_norm = candle_nn::layer_norm(
        hidden_size, 
        candle_nn::LayerNormConfig::default(), 
        vb.pp("layer_norm")
    )?;
    let linear2 = candle_nn::linear(hidden_size, hidden_size, vb.pp("linear2"))?;
    
    // Create input and target
    let input = Tensor::randn(0f32, 1.0, (batch_size, hidden_size), &device)?;
    let target = Tensor::randn(0f32, 1.0, (batch_size, hidden_size), &device)?;
    
    // Forward pass
    let x1 = linear1.forward(&input)?;
    let x_norm = layer_norm.forward(&x1)?;
    let output = linear2.forward(&x_norm)?;
    
    // Compute loss (MSE)
    let loss = (output.sub(&target))?.sqr()?.mean_all()?;
    
    // Backward pass
    let grads = loss.backward()?;
    
    // Check gradient flow
    let vars = varmap.all_vars();
    let mut params_with_gradients = 0;
    let mut params_without_gradients = 0;
    
    for var in &vars {
        if let Some(grad) = grads.get(var) {
            let grad_norm = grad.sqr()?.sum_all()?.sqrt()?.to_scalar::<f32>()?;
            if grad_norm > 1e-8 {
                params_with_gradients += 1;
            } else {
                params_without_gradients += 1;
            }
        } else {
            params_without_gradients += 1;
        }
    }
    
    let gradient_flow_pct = (params_with_gradients as f32 / vars.len() as f32) * 100.0;
    println!("Gradient flow: {:.1}% ({}/{} parameters)", 
        gradient_flow_pct, params_with_gradients, vars.len());
    
    assert!(gradient_flow_pct > 90.0, 
        "Gradient flow too low: {:.1}% (expected > 90%)", gradient_flow_pct);
    
    Ok(())
}

Expected Behavior

All parameters in the model should receive gradients during backpropagation. The gradient flow percentage should be close to 100%.
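To make "should receive gradients" concrete: for a LayerNorm output y = weight · x̂ + bias applied row by row, with upstream gradient g = ∂L/∂y, the standard results are ∂L/∂bias_j = Σ_i g_ij and ∂L/∂weight_j = Σ_i g_ij · x̂_ij. Neither is zero for a generic upstream gradient, so a correct backward pass should populate both. The std-only Rust sketch below is our own reference code, not candle's API; the names `normalize_rows`, `layer_norm_forward`, `layer_norm_param_grads` and `LnGrads` are hypothetical. It uses f64 and the biased variance (as in the workaround below) and could serve as ground truth when checking what `loss.backward()` returns.

```rust
/// Analytic gradients of a loss w.r.t. the LayerNorm weight and bias.
pub struct LnGrads {
    pub d_weight: Vec<f64>,
    pub d_bias: Vec<f64>,
}

/// Normalize each row of a row-major (rows x cols) matrix:
/// x_hat = (x - mean) / sqrt(var + eps), using the biased variance.
pub fn normalize_rows(x: &[f64], cols: usize, eps: f64) -> Vec<f64> {
    let mut out = Vec::with_capacity(x.len());
    for row in x.chunks(cols) {
        let n = cols as f64;
        let mean = row.iter().sum::<f64>() / n;
        let var = row.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n;
        let inv_std = 1.0 / (var + eps).sqrt();
        out.extend(row.iter().map(|v| (v - mean) * inv_std));
    }
    out
}

/// y = weight * x_hat + bias, with weight and bias broadcast over rows.
pub fn layer_norm_forward(x: &[f64], weight: &[f64], bias: &[f64], eps: f64) -> Vec<f64> {
    let cols = weight.len();
    let x_hat = normalize_rows(x, cols, eps);
    x_hat
        .iter()
        .enumerate()
        .map(|(k, xh)| weight[k % cols] * xh + bias[k % cols])
        .collect()
}

/// Given g = dL/dy (same layout as x):
/// dL/dbias_j = sum_i g_ij, dL/dweight_j = sum_i g_ij * x_hat_ij.
pub fn layer_norm_param_grads(x: &[f64], g: &[f64], cols: usize, eps: f64) -> LnGrads {
    let x_hat = normalize_rows(x, cols, eps);
    let mut d_weight = vec![0.0; cols];
    let mut d_bias = vec![0.0; cols];
    for (k, (&gk, &xh)) in g.iter().zip(&x_hat).enumerate() {
        d_weight[k % cols] += gk * xh;
        d_bias[k % cols] += gk;
    }
    LnGrads { d_weight, d_bias }
}
```

A loss that is linear in the outputs, L = Σ c_k · y_k, gives g = c. This makes the formulas easy to confirm with finite differences.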

Actual Behavior

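Only 33.3% of parameters receive gradients (2 out of 6). Specifically:

  • The linear2 parameters (weight and bias), downstream of the LayerNorm, receive gradients
  • LayerNorm parameters (weight and bias) do NOT receive gradients
  • As the 2/6 count implies, the linear1 parameters (weight and bias), upstream of the LayerNorm, do not receive gradients either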

Output

Gradient flow: 33.3% (2/6 parameters)
thread 'test_candle_layernorm_gradient_flow' panicked at 'Gradient flow too low: 33.3% (expected > 90%)'

Impact

This bug prevents models using LayerNorm from training properly, as the normalization parameters cannot be updated. This forces users to implement custom LayerNorm layers with explicit broadcast operations to maintain gradient flow.
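For reference, the sketch below gives the standard LayerNorm backward pass with the biased variance used in the workaround. Here N is the normalized size, γ and β are the weight and bias, and g is the upstream gradient. The last line is the gradient with respect to the input. It is the only path by which gradients can reach layers placed before the normalization, such as linear1 in the example.

```latex
\begin{aligned}
\mu_i &= \tfrac{1}{N}\sum_{k} x_{ik}, \qquad
\sigma_i^2 = \tfrac{1}{N}\sum_{k} (x_{ik}-\mu_i)^2, \qquad
\hat{x}_{ij} = \frac{x_{ij}-\mu_i}{\sqrt{\sigma_i^2+\varepsilon}} \\
y_{ij} &= \gamma_j\,\hat{x}_{ij} + \beta_j, \qquad
g_{ij} = \frac{\partial L}{\partial y_{ij}} \\
\frac{\partial L}{\partial x_{ij}} &= \frac{1}{\sqrt{\sigma_i^2+\varepsilon}}
\Big( \gamma_j g_{ij} - \tfrac{1}{N}\sum_k \gamma_k g_{ik}
- \hat{x}_{ij}\,\tfrac{1}{N}\sum_k \gamma_k g_{ik}\,\hat{x}_{ik} \Big)
\end{aligned}
```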

Workaround

A custom implementation that preserves gradient flow:

use candle_core::{Module, Result, Tensor, D};
use candle_nn::VarBuilder;

pub struct CustomLayerNorm {
    weight: Tensor,
    bias: Tensor,
    eps: f64,
}

impl CustomLayerNorm {
    pub fn new(normalized_shape: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
        let weight = vb.get_with_hints(normalized_shape, "weight", candle_nn::init::ONES)?;
        let bias = vb.get_with_hints(normalized_shape, "bias", candle_nn::init::ZEROS)?;
        Ok(Self { weight, bias, eps })
    }
}

impl Module for CustomLayerNorm {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mean = xs.mean_keepdim(D::Minus1)?;
        let xs_centered = xs.broadcast_sub(&mean)?;
        let var = xs_centered.sqr()?.mean_keepdim(D::Minus1)?;
        let std = (var + self.eps)?.sqrt()?;
        let xs_normalized = xs_centered.broadcast_div(&std)?;
        
        // Key difference: weight and bias are applied with ordinary broadcast
        // tensor ops, which autograd tracks, so gradients reach both parameters
        let xs_scaled = xs_normalized.broadcast_mul(&self.weight)?;
        xs_scaled.broadcast_add(&self.bias)
    }
}

Additional Notes

  • The forward pass computation is correct (verified with numerical equivalence test)
  • Only the backward pass gradient propagation is affected
  • This issue affects all backends (CPU, Metal, CUDA)
  • Using the custom implementation above results in 100% gradient flow
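The test above counts parameters with nonzero gradients. That shows gradients flow, but not that their values are right. A central finite-difference check can confirm correctness as well. The std-only Rust sketch below is a generic helper; the names `finite_diff_grad` and `max_abs_diff` are hypothetical and not part of candle.

```rust
/// Central finite-difference estimate of the gradient of `f` at `x`:
/// grad_j ~= (f(x + h*e_j) - f(x - h*e_j)) / (2h).
pub fn finite_diff_grad<F: Fn(&[f64]) -> f64>(f: F, x: &[f64], h: f64) -> Vec<f64> {
    let mut probe = x.to_vec();
    (0..x.len())
        .map(|j| {
            probe[j] = x[j] + h;
            let up = f(probe.as_slice());
            probe[j] = x[j] - h;
            let down = f(probe.as_slice());
            probe[j] = x[j]; // restore this coordinate before the next one
            (up - down) / (2.0 * h)
        })
        .collect()
}

/// Largest absolute difference between two gradient vectors.
pub fn max_abs_diff(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| (x - y).abs()).fold(0.0, f64::max)
}
```

To apply this to the model in the report, `f` would rerun the forward pass and loss with one flattened parameter vector perturbed. The result would then be compared against the corresponding gradient from `loss.backward()`. That wiring is left out of this sketch.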

tymat, Jun 28 '25 18:06