
Feature Request: Add Weight Normalization Support (weight_norm)

Open Blaizzy opened this issue 10 months ago • 13 comments

MLX currently lacks built-in support for weight normalization, which is a crucial feature for various deep learning architectures, particularly in audio processing and generative models. Weight normalization is a reparameterization technique that decouples the magnitude and direction of weight vectors, often leading to better conditioning and faster convergence.

Current Situation:

  • No built-in equivalent to PyTorch's torch.nn.utils.weight_norm
  • Users need to implement custom solutions, which may not be optimal or consistent

Proposed Solution: I've developed a reference implementation that could serve as a starting point:

import mlx.core as mx
import numpy as np
from typing import Optional, List, Union, Tuple

def compute_norm(x: mx.array, 
                p: int, 
                dim: Optional[Union[int, List[int]]] = None, 
                keepdim: bool = False) -> mx.array:
    """
    Compute the p-norm of a tensor along specified dimensions.
    
    Args:
        x: Input array
        p: Order of the norm (1 or 2)
        dim: Dimension(s) along which to compute the norm
        keepdim: Whether to keep the reduced dimensions
    
    Returns:
        MLX array containing the computed norm
    """
    if p not in [1, 2]:
        raise ValueError("Only p-norms with p of 1 or 2 are supported")
    
    # Handle dimension input
    if dim is None:
        dim = tuple(range(x.ndim))
    elif isinstance(dim, int):
        dim = (dim,)
    
    if p == 1:
        # L1 norm
        return mx.sum(mx.abs(x), axis=dim, keepdims=keepdim)
    else:
        # L2 norm
        return mx.sqrt(mx.sum(x * x, axis=dim, keepdims=keepdim))

def weight_norm(weight_v: mx.array, 
                weight_g: mx.array, 
                dim: Optional[int] = None) -> mx.array:
    """
    Applies weight normalization to the input tensor.
    
    Weight normalization reparameterizes weight vectors in a neural network 
    as a magnitude scalar times a direction vector: w = g * v/||v||
    
    Args:
        weight_v: Weight direction tensor (v)
        weight_g: Weight magnitude tensor (g)
        dim: Dimension along which to normalize. If None or -1, the norm is taken
            over all dimensions
    
    Returns:
        Normalized weight tensor
    """
    rank = len(weight_v.shape)
    
    if dim is not None:
        # Adjust negative dim
        if dim < -1:
            dim += rank
            
        # Create list of axes to normalize over
        axes = list(range(rank))
        if dim != -1:
            axes.remove(dim)
    else:
        # Default behavior: normalize over all dimensions
        axes = list(range(rank))
    
    # Compute L2 norm of v along specified axes
    norm_v = compute_norm(weight_v, p=2, dim=axes, keepdim=True)
    
    # Normalize and scale by g: w = g * (v / ||v||)
    normalized_weight = weight_v / (norm_v + 1e-7)  # Add epsilon for numerical stability
    return normalized_weight * weight_g

# Example usage:
def test_weight_norm():
    # Create sample tensors
    v = mx.random.normal((64, 3, 3))  # Direction tensor
    g = mx.random.normal((64, 1, 1))  # Magnitude tensor
    
    # Apply weight normalization
    w = weight_norm(v, g, dim=0)
    
    # Verify shape
    assert w.shape == v.shape
    
    # Verify norm along specified dimension
    norm_w = compute_norm(w, p=2, dim=[1, 2], keepdim=True)
    mx.eval(norm_w)  # Force computation
    
    return w, norm_w

if __name__ == "__main__":
    normalized_weight, w_norm = test_weight_norm()

Blaizzy avatar Feb 19 '25 23:02 Blaizzy

Is the reference code you posted working as expected? If not, what's the issue with it?

awni avatar Feb 19 '25 23:02 awni

It works well for Linear, but I can't get it to match the torch output for Conv layers, even if I override the weights with the MLX ones.

The Conv layers are my focus.

Blaizzy avatar Feb 20 '25 00:02 Blaizzy

Example usage with conv1d:

MLX

import mlx.core as mx
import mlx.nn as nn
import numpy as np
import torch
import torch.nn as torch_nn
from torch.nn.utils import weight_norm as torch_weight_norm
from typing import Optional, List, Union

# Set seeds for reproducibility
mx.random.seed(42)
torch.manual_seed(42)

def compute_norm(x: mx.array,
                p: int,
                dim: Optional[Union[int, List[int]]] = None,
                keepdim: bool = False) -> mx.array:
    """
    Compute the p-norm of a tensor along specified dimensions.

    Args:
        x: Input array
        p: Order of the norm (1 or 2)
        dim: Dimension(s) along which to compute the norm
        keepdim: Whether to keep the reduced dimensions

    Returns:
        MLX array containing the computed norm
    """
    if p not in [1, 2]:
        raise ValueError("Only p-norms with p of 1 or 2 are supported")

    # Handle dimension input
    if dim is None:
        dim = tuple(range(x.ndim))
    elif isinstance(dim, int):
        dim = (dim,)

    if p == 1:
        # L1 norm
        return mx.sum(mx.abs(x), axis=dim, keepdims=keepdim)
    else:
        # L2 norm
        return mx.sqrt(mx.sum(x * x, axis=dim, keepdims=keepdim))

def weight_norm(weight_v: mx.array,
                weight_g: mx.array,
                dim: Optional[int] = None) -> mx.array:
    """
    Applies weight normalization to the input tensor.

    Weight normalization reparameterizes weight vectors in a neural network
    as a magnitude scalar times a direction vector: w = g * v/||v||

    Args:
        weight_v: Weight direction tensor (v)
        weight_g: Weight magnitude tensor (g)
        dim: Dimension along which to normalize. If None or -1, the norm is taken
            over all dimensions

    Returns:
        Normalized weight tensor
    """
    rank = len(weight_v.shape)

    if dim is not None:
        # Adjust negative dim
        if dim < -1:
            dim += rank

        # Create list of axes to normalize over
        axes = list(range(rank))
        if dim != -1:
            axes.remove(dim)
    else:
        # Default behavior: normalize over all dimensions
        axes = list(range(rank))

    # Compute L2 norm of v along specified axes
    norm_v = compute_norm(weight_v, p=2, dim=axes, keepdim=True)

    # Normalize and scale by g: w = g * (v / ||v||)
    normalized_weight = weight_v / (norm_v + 1e-7)  # Add epsilon for numerical stability
    return normalized_weight * weight_g


class WeightNormConv1d(nn.Module):
    """Conv1d layer with weight normalization"""
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int,
                 stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1,
                 bias: bool = True, dim: int = 0, transpose_weight_g: bool = False):
        super().__init__()

        # Initialize weight parameters
        weight_shape_g = (out_channels, 1, 1) if not transpose_weight_g else (in_channels, 1, 1)
        weight_shape_v = (out_channels, in_channels, kernel_size)

        # Store parameters
        self.weight_g = mx.random.normal(weight_shape_g)
        self.weight_v = mx.random.normal(weight_shape_v)
        self.bias = mx.zeros(out_channels) if bias else None
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.dim = dim


        # Compute normalized weight
        self.weight = weight_norm(self.weight_v, self.weight_g, dim=0)


    def __call__(self, x: mx.array) -> mx.array:
        # Apply conv1d
        out = mx.conv1d(x, self.weight, stride=self.stride, padding=self.padding,
                       dilation=self.dilation, groups=self.groups)
        if self.bias is not None:
            out = out + self.bias.reshape(1, 1, -1)
        return out

# Example usage:
layer = WeightNormConv1d(in_channels=20, out_channels=40, kernel_size=3, padding=1)
x = mx.random.normal((16, 20, 3))  # batch_size=16, channels=20, length=3
output = layer(x).swapaxes(1, 2)
print(f"Output shape: {output.shape}")  # Should be (16, 40, 3)
print(f"Weight_g shape: {layer.weight_g.shape}")  # Should be (40, 1, 1)
print(f"Weight_v shape: {layer.weight_v.shape}")  # Should be (40, 20, 3)

torch

m = torch_weight_norm(torch_nn.Conv1d(20, 40, kernel_size=3, padding=1, bias=False), name='weight')
torch_x = torch.from_dlpack(x)
torch_output = m(torch_x)
print(f"Output shape: {torch_output.shape}")  # Should be (16, 40, 3)
print(m)
print(m.weight_g.size())
print(m.weight_v.size())

print(np.allclose(output, torch_output.detach().numpy(), rtol=1e-3, atol=1e-3), output.sum().tolist(), torch_output.detach().numpy().sum())

>> (False, 23.648128509521484, 18.668007)

Blaizzy avatar Feb 20 '25 00:02 Blaizzy

@Blaizzy Inspired by your recent mlx-audio work, I looked into this issue based on our X convo.

I just submitted a PR that might address the problems you were facing; if merged, it could hopefully let you utilise weight_norm natively with a good performance boost. Here's a summary (beyond what's in the PR):

  1. Dimension Ordering: MLX uses channel-last format while PyTorch uses channel-first. For Conv1d, the weight shapes are:

    • PyTorch: [out_channels, in_channels, kernel_size]
    • MLX: [out_channels, kernel_size, in_channels]
      This ordering difference needs special handling when normalizing, which I think may be missed in your above script.
  2. linalg::norm Limitation: MLX's linalg::norm can only handle up to 2 axes. For Conv2d weights (with 3 axes to normalize), I implemented a reshape-based approach that:

    • Identifies dimensions to keep vs. normalize
    • Reshapes the tensor to use the optimized 2D norm kernel
    • Reshapes back for proper broadcasting
  3. Module Wrapper vs. Custom Layer: Rather than implementing a custom layer from scratch, I used a module wrapper approach that applies weight normalization to existing MLX layers. This:

    • Works with all of MLX's optimized layer implementations
    • Maintains compatibility with MLX's dimension ordering
    • Shows better performance (my benchmarks show a 1.5-5x speedup over PyTorch MPS on realistic audio implementations like yours; I haven't benchmarked against your unbound Python version, which uses mx.sum and mx.sqrt instead of mx.linalg.norm, but I suspect there would be improvements there too)
  4. Testing Approach: When comparing with PyTorch, I found two important insights:

    • Independent Implementations: Using common seeds still shows expected differences (up to ~5.0), which seems normal between frameworks
    • Direct Weight Transfer: Exact equivalence (differences < 1e-5) can be achieved when the weights are properly transposed between frameworks; a sketch of this check follows below. I tested both approaches thoroughly, confirming that the mathematical properties are preserved even when the numeric values differ slightly. This explains why direct output comparison might fail even when both implementations are correct. I do recommend checking out my test script in particular, as it was a good amount of learning for me: test_weight_norm.py
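For concreteness, here is a minimal sketch of that direct-weight-transfer check; it is illustrative only (the shapes, tolerances, and test input are assumptions, not code from the PR), but it shows how reusing PyTorch's already-normalized weight in MLX's channel-last layout lets the outputs be compared directly:

import numpy as np
import torch
import torch.nn as torch_nn
from torch.nn.utils import weight_norm as torch_weight_norm
import mlx.core as mx

# Hypothetical check (not from the PR): transfer the normalized PyTorch weight to MLX.
torch_conv = torch_weight_norm(torch_nn.Conv1d(20, 40, kernel_size=3, padding=1, bias=False))
w_torch = torch_conv.weight.detach().numpy()            # (out_channels, in_channels, kernel_size)
w_mlx = mx.array(w_torch).transpose(0, 2, 1)            # (out_channels, kernel_size, in_channels)

x_torch = torch.randn(16, 20, 50)                       # (batch, channels, length)
x_mlx = mx.array(x_torch.numpy()).transpose(0, 2, 1)    # (batch, length, channels)

y_torch = torch_conv(x_torch).detach().numpy()
y_mlx = mx.conv1d(x_mlx, w_mlx, stride=1, padding=1).transpose(0, 2, 1)

print(np.allclose(np.array(y_mlx), y_torch, rtol=1e-4, atol=1e-5))  # expected: True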

My PR includes both core API and module wrapper implementations, along with convenience classes similar to PyTorch's. Hope this helps, and feel free to check out the implementation if it's merged, or adopt it as you need regardless.

cavit99 avatar Mar 04 '25 02:03 cavit99

Thank you very much @cavit99, indeed it does solve it and matches the torch design! ❤️

MLX: [out_channels, kernel_size, in_channels] This ordering difference needs special handling when normalizing, which I think may be missed in your above script

For a bit of context, I didn't update my implementation here because I was focused on the MLX releases, but I did reach out to Awni privately to let him know I had found the solution and was going to send a PR after I had rested a bit.

My updated solution that shipped with mlx-audio handles it and matches the torch implementation as I specified here: https://github.com/ml-explore/mlx/pull/1921#issuecomment-2698522025

linalg::norm Limitation: MLX's linalg::norm can only handle up to 2 axes. For Conv2d weights (with 3 axes to normalize)...

I like your approach here! I followed the torch implementation of normalization.

https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset9.py#L5738 https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset9.py#L3276

Finally, the models I'm implementing at the moment only use weight-normalized 1D convs, so I didn't think beyond 1D, but it's an absolute joy knowing that it will now work with 2D out-of-the-box. How about 3D?

Blaizzy avatar Mar 04 '25 18:03 Blaizzy

Cool! Sorry for a bit of duplicate work then, but at least it's got bindings too now!

The weight norm should work with 3D and any higher dims.

See:

# From weight_norm.py
elif "Conv" in self.wn_module_type and dim == 0:
    weight_flat = mx.reshape(weight, (weight.shape[0], -1))
    self.weight_g = mx.linalg.norm(weight_flat, axis=1, keepdims=True)
    g_shape = [weight.shape[0]] + [1] * (weight.ndim - 1)
    self.weight_g = mx.reshape(self.weight_g, g_shape)

This code handles convolutions of any dimensionality by:

  1. Reshaping the weight tensor to a 2D matrix (output channels × flattened everything else)
  2. Computing the norm along the flattened dimension
  3. Reshaping the result back to match the original tensor's dimensions

The C++ implementation in ops.cpp explicitly handles higher dimensions:

// If we have more than 2 axes, use the reshape approach
if (norm_axes.size() > 2) {
  // Common case: keep one dimension (e.g., output channels)
  int keep_dim = keep_axes[0];
  std::vector<int> reshape_dims = {v.shape()[keep_dim], -1};
  array v_reshaped = reshape(v, reshape_dims, s);
  
  // Use the 2D norm kernel which is optimized
  array v_norm = linalg::norm(v_reshaped, std::vector<int>{1}, true, s);
}

The Metal shaders are optimized for 1D/2D data. By reshaping to a 2D tensor, with output channels as one dimension and everything else flattened into the second, we achieve the same mathematical result while using the optimized kernels.
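As a quick sanity check on that equivalence, here is a hedged sketch (not code from the PR) comparing the reshape-based norm with a direct multi-axis reduction for a Conv2d-style weight in MLX's channel-last layout:

import mlx.core as mx

# Conv2d-style weight in MLX layout: (out_channels, kh, kw, in_channels)
w = mx.random.normal((40, 3, 3, 20))

# Direct multi-axis L2 norm per output channel
direct = mx.sqrt(mx.sum(w * w, axis=(1, 2, 3), keepdims=True))

# Reshape-based norm: flatten everything but the output channels, then use the 2D kernel
flat = mx.linalg.norm(mx.reshape(w, (w.shape[0], -1)), axis=1, keepdims=True)
flat = mx.reshape(flat, (w.shape[0], 1, 1, 1))

print(mx.allclose(direct, flat).item())  # expected: True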

cavit99 avatar Mar 04 '25 19:03 cavit99

No worries, I think you did a great job and gave it the necessary attention!

Blaizzy avatar Mar 04 '25 19:03 Blaizzy

Thank you very much once again!

Blaizzy avatar Mar 04 '25 19:03 Blaizzy

Hi @Blaizzy, is this merged? For inference on a weight-normed Conv2D layer, isn't it as simple as the below?

weight_v_norm = mx.sqrt(mx.sum(self.weight_v ** 2, axis=(1, 2, 3), keepdims=True))
weight = self.weight_g * self.weight_v / weight_v_norm
y = mx.conv2d(x, weight, stride=self.stride, padding=self.padding)
if self.bias is not None:
    y = y + self.bias

I'm not quite sure how backprop works in MLX, but if we can register weight_g and weight_v as parameters, then the same layer can also be trained in a model training loop without any changes to the model itself. Maybe @awni can say whether I'm missing something major here, since I do feel like it should not be this easy to implement something that appears to be pretty complex from the above discussion. I'm just going by the mathematical definition.

bitanath avatar Aug 13 '25 06:08 bitanath

What you have looks right to me.. though for fast inference it probably makes sense to precompute the normalized weight since it's not changing.

awni avatar Aug 13 '25 10:08 awni

So basically just a hook in a custom layer? Can we simply override Conv2D to include this? Would it also backprop correctly? I have tested inference on one of my models, replacing nn.utils.weight_norm from PyTorch with this call, and it seems to work fine. However, I do not know how it would work for training.

def _get_weight(self):
    weight_v_norm = mx.sqrt(mx.sum(self.weight_v ** 2, axis=(1, 2, 3), keepdims=True))
    weight = self.weight_g * self.weight_v / weight_v_norm
    return weight

def __call__(self, x):
    weight = self._get_weight()
    y = mx.conv2d(x, weight, stride=self.stride, padding=self.padding)
    if self.bias is not None:
        y = y + self.bias
    return y
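To make the parameter-registration part concrete, here is a hedged sketch of one way a complete layer along these lines could look; the class name, shapes, and initialization are assumptions for illustration, not code from this thread or from a merged MLX API. Since mx.array attributes of an nn.Module are collected as parameters, weight_v and weight_g would receive gradients in an ordinary training loop.

import mlx.core as mx
import mlx.nn as nn

class WeightNormConv2d(nn.Module):
    """Hypothetical Conv2d with weight normalization (MLX channel-last weight layout)."""

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, bias=True):
        super().__init__()
        # MLX Conv2d weight layout: (out_channels, kh, kw, in_channels)
        v = mx.random.normal((out_channels, kernel_size, kernel_size, in_channels))
        self.weight_v = v
        # Initialize g to ||v|| so the initial effective weight equals v
        self.weight_g = mx.sqrt(mx.sum(v * v, axis=(1, 2, 3), keepdims=True))
        self.bias = mx.zeros((out_channels,)) if bias else None
        self.stride = stride
        self.padding = padding

    def __call__(self, x):
        # Recompute w = g * v / ||v|| on every call so gradients flow to both g and v
        norm_v = mx.sqrt(mx.sum(self.weight_v ** 2, axis=(1, 2, 3), keepdims=True))
        weight = self.weight_g * self.weight_v / norm_v
        y = mx.conv2d(x, weight, stride=self.stride, padding=self.padding)
        if self.bias is not None:
            y = y + self.bias
        return y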

bitanath avatar Aug 14 '25 03:08 bitanath

That is also fine.. the hook will get called every time. If you want to precompute the weight for inference, you would want to set the weight-normalized weights just once up front, then use them every time.
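A minimal sketch of that precomputation (assuming a layer like the one above with weight_v/weight_g attributes; the names are illustrative, not an MLX API):

import mlx.core as mx

# Precompute the normalized weight once for inference and reuse it on every call
norm_v = mx.sqrt(mx.sum(layer.weight_v ** 2, axis=(1, 2, 3), keepdims=True))
layer.weight = layer.weight_g * layer.weight_v / norm_v
mx.eval(layer.weight)  # materialize once; __call__ can then use layer.weight directly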

awni avatar Aug 14 '25 11:08 awni

Thanks for your response @awni! MLX is quite clean and easy to implement models in. Definitely hoping for a torch-hub-like model zoo with common ImageNet-pretrained backbones for convolutional models/vision transformers.

bitanath avatar Aug 14 '25 15:08 bitanath