LayerNorm Gradient Flow Issue in candle-nn
Summary
LayerNorm in candle-nn blocks gradient propagation during backpropagation: gradients stop at the LayerNorm, so neither its own weight and bias nor any parameters upstream of it are updated. In the minimal model below, only 33% of parameters receive gradients, which prevents the model from training correctly.
Environment
- candle-core version: 0.8.0
- candle-nn version: 0.8.0
- Platform: macOS with Metal backend (also affects CUDA)
- Rust version: 1.75+
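For reference, a minimal `Cargo.toml` dependency section matching the versions above (the backend feature flags are assumptions; adjust for your accelerator):

```toml
[dependencies]
candle-core = "0.8.0"
candle-nn = "0.8.0"
# Optional, for GPU backends:
# candle-core = { version = "0.8.0", features = ["metal"] }   # or ["cuda"]
```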
Minimal Reproducible Example
```rust
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Module, VarBuilder, VarMap};

#[test]
fn test_candle_layernorm_gradient_flow() -> Result<()> {
    // Setup
    let device = Device::Cpu; // Also fails on Metal/CUDA
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

    // Build a simple model: Linear -> LayerNorm -> Linear
    let hidden_size = 64;
    let batch_size = 4;
    let linear1 = candle_nn::linear(hidden_size, hidden_size, vb.pp("linear1"))?;
    let layer_norm = candle_nn::layer_norm(
        hidden_size,
        candle_nn::LayerNormConfig::default(),
        vb.pp("layer_norm"),
    )?;
    let linear2 = candle_nn::linear(hidden_size, hidden_size, vb.pp("linear2"))?;

    // Create input and target
    let input = Tensor::randn(0f32, 1.0, (batch_size, hidden_size), &device)?;
    let target = Tensor::randn(0f32, 1.0, (batch_size, hidden_size), &device)?;

    // Forward pass
    let x1 = linear1.forward(&input)?;
    let x_norm = layer_norm.forward(&x1)?;
    let output = linear2.forward(&x_norm)?;

    // Compute loss (MSE)
    let loss = output.sub(&target)?.sqr()?.mean_all()?;

    // Backward pass
    let grads = loss.backward()?;

    // Check gradient flow: count parameters whose gradient norm is non-negligible
    let vars = varmap.all_vars();
    let mut params_with_gradients = 0;
    for var in &vars {
        if let Some(grad) = grads.get(var) {
            let grad_norm = grad.sqr()?.sum_all()?.sqrt()?.to_scalar::<f32>()?;
            if grad_norm > 1e-8 {
                params_with_gradients += 1;
            }
        }
    }

    let gradient_flow_pct = (params_with_gradients as f32 / vars.len() as f32) * 100.0;
    println!(
        "Gradient flow: {:.1}% ({}/{} parameters)",
        gradient_flow_pct, params_with_gradients, vars.len()
    );
    assert!(
        gradient_flow_pct > 90.0,
        "Gradient flow too low: {:.1}% (expected > 90%)",
        gradient_flow_pct
    );
    Ok(())
}
```
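The pass/fail check in the test reduces to a simple counting rule over gradient norms. Here is a dependency-free sketch of that metric, with plain `Vec<f32>` slices standing in for gradient tensors (`gradient_flow_pct` is a hypothetical helper for illustration, not a candle API):

```rust
// Count how many parameters received a non-negligible gradient.
// Each entry is an Option: None means no gradient was recorded for that parameter.
fn gradient_flow_pct(grads: &[Option<Vec<f32>>]) -> f32 {
    let with_grad = grads
        .iter()
        .filter(|g| match g {
            // L2 norm above a small threshold counts as "gradient flows here".
            Some(v) => v.iter().map(|x| x * x).sum::<f32>().sqrt() > 1e-8,
            None => false,
        })
        .count();
    with_grad as f32 / grads.len() as f32 * 100.0
}

fn main() {
    // Six parameters, only two with real gradients: the failure mode reported here.
    let grads = vec![
        Some(vec![0.5, -0.2]),
        Some(vec![0.1]),
        Some(vec![0.0, 0.0]), // all-zero gradient counts as "no flow"
        None,
        None,
        None,
    ];
    let pct = gradient_flow_pct(&grads);
    println!("Gradient flow: {pct:.1}%"); // prints "Gradient flow: 33.3%"
    assert!((pct - 33.333_334).abs() < 1e-3);
}
```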
Expected Behavior
All parameters in the model should receive gradients during backpropagation. The gradient flow percentage should be close to 100%.
Actual Behavior
Only 33.3% of parameters receive gradients (2 out of 6). Specifically:
- Only the final linear layer (linear2) receives gradients for its weight and bias
- The LayerNorm parameters (weight and bias) receive NO gradients
- Parameters upstream of the LayerNorm (linear1's weight and bias) also receive no gradients, since backpropagation halts at the LayerNorm
Output
```
Gradient flow: 33.3% (2/6 parameters)
thread 'test_candle_layernorm_gradient_flow' panicked at 'Gradient flow too low: 33.3% (expected > 90%)'
```
Impact
This bug prevents models that use LayerNorm from training, since neither the normalization parameters nor any parameters upstream of the LayerNorm can be updated. It forces users to implement a custom LayerNorm with explicit broadcast operations to maintain gradient flow.
Workaround
A custom implementation that preserves gradient flow:
```rust
use candle_core::{D, Result, Tensor};
use candle_nn::{Module, VarBuilder};

pub struct CustomLayerNorm {
    weight: Tensor,
    bias: Tensor,
    eps: f64,
}

impl CustomLayerNorm {
    pub fn new(normalized_shape: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
        let weight = vb.get_with_hints(normalized_shape, "weight", candle_nn::init::ONE)?;
        let bias = vb.get_with_hints(normalized_shape, "bias", candle_nn::init::ZERO)?;
        Ok(Self { weight, bias, eps })
    }
}

impl Module for CustomLayerNorm {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Normalize over the last dimension: (x - mean) / sqrt(var + eps)
        let mean = xs.mean_keepdim(D::Minus1)?;
        let xs_centered = xs.broadcast_sub(&mean)?;
        let var = xs_centered.sqr()?.mean_keepdim(D::Minus1)?;
        let std = (var + self.eps)?.sqrt()?;
        let xs_normalized = xs_centered.broadcast_div(&std)?;
        // Key difference: explicit broadcast operations preserve gradient flow
        let xs_scaled = xs_normalized.broadcast_mul(&self.weight)?;
        xs_scaled.broadcast_add(&self.bias)
    }
}
```
Additional Notes
- The forward pass computation is correct (verified with numerical equivalence test)
- Only the backward pass gradient propagation is affected
- This issue affects all backends (CPU, Metal, CUDA)
- Using the custom implementation above results in 100% gradient flow
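The numerical-equivalence claim above can be spot-checked without candle at all. Below is a minimal sketch of the same normalize-scale-shift math on a single row of plain `f32` values (`layer_norm_row` is an illustrative helper, not part of any library):

```rust
// Layer norm over one row: y = (x - mean) / sqrt(var + eps) * w + b.
// Uses the biased variance (divide by n), matching the mean_keepdim-based
// formula in the custom implementation above.
fn layer_norm_row(x: &[f32], w: &[f32], b: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
    let inv_std = 1.0 / (var + eps).sqrt();
    x.iter()
        .zip(w.iter().zip(b.iter()))
        .map(|(v, (wi, bi))| (v - mean) * inv_std * wi + bi)
        .collect()
}

fn main() {
    let x = vec![1.0, 2.0, 3.0, 4.0];
    let w = vec![1.0; 4]; // weight initialized to ones
    let b = vec![0.0; 4]; // bias initialized to zeros
    let y = layer_norm_row(&x, &w, &b, 1e-5);

    // With unit weight and zero bias, the output has ~zero mean and ~unit variance.
    let mean: f32 = y.iter().sum::<f32>() / 4.0;
    let var: f32 = y.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / 4.0;
    assert!(mean.abs() < 1e-6);
    assert!((var - 1.0).abs() < 1e-3);
    println!("{y:?}");
}
```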