Implement Weight Normalization, addressing issue #1888
Proposed changes
This PR implements weight normalization for MLX, addressing issue #1888. Weight normalization is a reparameterization technique that decouples the magnitude of a weight tensor from its direction, making optimization more efficient by improving the conditioning of the optimization problem. It is particularly important for audio processing, among other applications.
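Concretely, weight normalization reparameterizes a weight tensor as `w = g * v / ||v||`, where `v` carries the direction and `g` the magnitude, with the norm taken over the axes being normalized.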
Key Features
- Core C++ implementation of `mx.weight_norm` with optimized paths for different dimensions
- Python module `weight_norm.py` with a user-friendly API and layer wrappers
- Proper handling of MLX's channel-ordering differences from PyTorch
- Workaround for the `linalg::norm` 2-axes limitation
- Convenience classes for common layer types (Linear, Conv1d, Conv2d)
- Comprehensive test suite validating mathematical properties and cross-framework compatibility
Implementation Details
Core C++ Implementation
The core `weight_norm` operation is implemented with three different paths based on the number of axes to normalize over (see the sketch after this list):

- Direct path for 1-2 axes using the optimized `linalg::norm` kernels
- Reshape-based approach for >2 axes, which:
  - Identifies dimensions to keep vs. normalize
  - Handles special cases: normalizing all dims, keeping one dim, keeping multiple dims
  - Reshapes appropriately to leverage the optimized 2D norm kernel
  - Reshapes results back for broadcasting
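For reference, the >2-axis path can be expressed in Python roughly as follows. This is a minimal sketch of the reshape trick, not the actual C++ code; the helper name `weight_norm_reshape` and the exact edge-case handling are illustrative only.

```python
import math
import mlx.core as mx

def weight_norm_reshape(v, g, axes):
    # Illustrative sketch: normalize v over `axes` by flattening those axes
    # so the optimized 2D norm kernel can be reused.
    axes = [a % v.ndim for a in axes]
    keep = [a for a in range(v.ndim) if a not in axes]

    if not keep:
        # Special case: normalizing over every dimension -> a single scalar norm.
        return g * v / mx.linalg.norm(v.reshape(-1))

    # Move kept axes to the front, flatten the normalized axes into one.
    rows = math.prod(v.shape[a] for a in keep)
    v_2d = mx.transpose(v, keep + axes).reshape(rows, -1)
    norms = mx.linalg.norm(v_2d, axis=1)  # one norm per kept "row"

    # Reshape the norms so they broadcast against v in its original layout.
    bshape = [v.shape[a] if a in keep else 1 for a in range(v.ndim)]
    return g * v / mx.reshape(norms, bshape)
```

For one or two axes the op can pass the axes straight to `linalg::norm`, which is why those cases get a separate fast path.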
Python Layer
The Python implementation (see the wrapper sketch after this list):

- Provides a `weight_norm` function that wraps MLX modules
- Handles dimension-ordering differences for different layer types
- Computes the initial `g` parameter as the norm of the original weight
- Overrides the module's forward pass to apply weight normalization on the fly
- Includes convenience classes (`WeightNormLinear`, `WeightNormConv1d`, `WeightNormConv2d`)
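For illustration, the wrapping approach could look roughly like the sketch below. This is not the code in `weight_norm.py`; the class name `WeightNormWrapper` is hypothetical, and it normalizes over all axes except the output (first) axis, which matches the common `dim=0` convention.

```python
import mlx.core as mx
import mlx.nn as nn

def _row_norms(w):
    # Norm over every axis except the first (output) axis, shaped for broadcasting.
    flat = w.reshape(w.shape[0], -1)
    norms = mx.linalg.norm(flat, axis=1)
    return mx.reshape(norms, [w.shape[0]] + [1] * (w.ndim - 1))

class WeightNormWrapper(nn.Module):
    """Hypothetical sketch: reparameterize a wrapped module's weight as g * v / ||v||."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.inner = module
        # v starts as a copy of the original weight, g as its per-row norm,
        # so the wrapped layer initially computes exactly the same function.
        self.weight_v = module.weight
        self.weight_g = _row_norms(module.weight)

    def __call__(self, x):
        # Recompute the effective weight from (g, v) on every forward pass.
        self.inner.weight = self.weight_g * self.weight_v / _row_norms(self.weight_v)
        return self.inner(x)

# Example: wrap an nn.Linear so its weight is weight-normalized.
layer = WeightNormWrapper(nn.Linear(10, 20))
y = layer(mx.random.normal((4, 10)))
```

The actual `weight_norm` function in the PR also handles per-layer axis choices (e.g. Conv weights), which this sketch omits.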
Testing and Verification
Testing follows a comprehensive two-pronged approach (a short sketch illustrating both checks appears after the benchmark figures below):
1. Mathematical Property Tests
- Verify that the normalized weights have the correct norm (equals g)
- Confirm that the direction of normalized weights matches v
- Validate that changing g correctly scales the weight norms
- Test edge cases like normalizing over all dimensions
2. Cross-Framework Verification
- Compare against PyTorch's weight normalization
- Test both independent implementations and direct weight transfer
- Document expected differences between frameworks and how to achieve exact equivalence
3. Performance Benchmarks
Benchmarks on an Apple M3 Max show MLX outperforming PyTorch MPS:
- Linear layers (1 axis): 4.90x-5.26x speedup
- Conv1d layers (2 axes): 1.46x-2.05x speedup
- Conv2d layers (3 axes): 1.50x-1.76x speedup
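As an illustration of checks 1 and 2 above, a compressed sketch might look like the following. It assumes the `mx.weight_norm(v, g, axes=...)` signature shown in the usage examples below and computes the PyTorch reference directly from the definition rather than through `torch.nn.utils.parametrizations.weight_norm`.

```python
import numpy as np
import torch
import mlx.core as mx

# Linear-style weight of shape (out, in), normalized over the input axis.
v_np = np.random.randn(20, 10).astype(np.float32)
g_np = np.abs(np.random.randn(20, 1)).astype(np.float32)

w = mx.weight_norm(mx.array(v_np), mx.array(g_np), axes=[1])

# 1. Mathematical property: the per-row norm of w should equal g.
w_norms = mx.linalg.norm(w, axis=1)
assert mx.allclose(w_norms, mx.array(g_np).reshape(-1), atol=1e-5)

# 2. Cross-framework check: w = g * v / ||v|| computed with PyTorch.
v_t, g_t = torch.tensor(v_np), torch.tensor(g_np)
w_torch = g_t * v_t / torch.linalg.vector_norm(v_t, dim=1, keepdim=True)
assert np.allclose(np.array(w), w_torch.numpy(), atol=1e-5)
```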
Usage Examples
Core API
```python
import mlx.core as mx

# Create parameters
v = mx.random.normal((64, 3, 3))  # Direction tensor
g = mx.random.normal((64, 1, 1))  # Magnitude tensor

# Apply weight normalization
w = mx.weight_norm(v, g, axes=[1, 2])
```
Module API
```python
import mlx.nn as nn
from mlx.nn.layers.weight_norm import weight_norm

# Apply to an existing layer
linear = nn.Linear(10, 20)
linear_wn = weight_norm(linear)

# Use a convenience class
conv1d_wn = nn.WeightNormConv1d(16, 32, kernel_size=3)
```
Resolves #1888.
Checklist
- [x] I have read the CONTRIBUTING document
- [x] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes
- [x] I have added tests that prove my fix is effective or that my feature works
- [x] I have updated the necessary documentation (if needed)
Thanks a lot @cavit99, this is great work!
One tiny nit:
- Could we change the weight naming to `weight_g` and `weight_v`? That makes it easier to map from torch and to remember.
As far as I can tell from initial testing, this PR does address my issues.
The only difference is that I preferred using torch's channel-first layout for loading and transposing the weight at run time, because Kokoro has a lot of transpose operations (~35) and I wanted to avoid bugs.
```python
class WeightNormConv1D(nn.Module):
    """Conv1d with weight normalization"""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        encode: bool = False,
    ):
        super().__init__()
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        # Initialize weight magnitude (g) and direction (v) vectors
        self.weight_g = mx.ones((out_channels, 1, 1))  # Scalar magnitude per output channel
        self.weight_v = mx.ones(
            (out_channels, kernel_size, in_channels)
        )  # Direction vectors
        self.bias = mx.zeros(in_channels if encode else out_channels) if bias else None

    def __call__(self, x, conv):
        weight = weight_norm(self.weight_v, self.weight_g, dim=0)
        if self.bias is not None:
            bias = self.bias.reshape(1, 1, -1)
        try:
            ...
            # Input is channels-last, need to transpose the weight
            return apply_conv(x, weight.T)
        except Exception as e:
            print(f"Error: {e}")
            print(f"x.shape: {x.shape}, weight.shape: {weight.shape}")
            raise e
```
> Thanks a lot @cavit99, this is great work!
> One tiny nit:
> - Could we change the weight naming to `weight_g` and `weight_v`? That makes it easier to map from torch and to remember.
Agreed from my side, so I pushed that change to the PR. Thank you!
Perfect! 🤩
Now we wait for @awni :)
he's gonna look and say meh, maybe if you stick it in normalization.py isn't he
I'm not certain about including this as WeightNorm, as I thought WeightNorm is not used so much anymore. Thoughts?
Either way we should not make free functions in C++ and Python for this. It should just be a layer in mlx.nn. In the same way we don't have mx.batch_norm but we have nn.BatchNorm etc.
I gather it's making a comeback with real-time audio use cases (vocoders, TTS), because it is lighter than batch norm and useful where stability and convergence are critical. Besides @Blaizzy's mlx-audio, which also implemented weight norm manually, I see torch's weight_norm being used in, for example, spark-tts, Nvidia NeMo, and coqui tts, and it is still maintained at
https://pytorch.org/docs/stable/generated/torch.nn.utils.parametrizations.weight_norm.html#torch.nn.utils.parametrizations.weight_norm
Regarding the layer, of course you're right. If you agree on the usefulness, I'll happily refactor it fully into mlx.nn, as it should have been from the start.
> Regarding the layer, of course you're right. If you agree on the usefulness, I'll happily refactor it fully into mlx.nn, as it should have been from the start.
That would be great!
@cavit99 are you planning on coming back to this PR?
@cavit99 I'm porting SparkTTS and this would be a great addition :)
Turns out 90% of the audio models we ported in MLX-Audio use a WeightNorm layer either in the model or the codec.
sorry I will do this by EOD tomorrow
I'm going to close this as inactive. We're open to revisiting the addition of weight norm in the future.