
Implement Weight Normalization, addressing issue #1888

Open cavit99 opened this issue 9 months ago • 11 comments

Proposed changes

This PR implements weight normalization for MLX, addressing issue #1888. Weight normalization is a reparameterization technique that decouples the magnitude of a weight tensor from its direction, improving the conditioning of the optimization problem and making optimization more efficient. It is particularly important for audio processing, among other applications.

Key Features

  • Core C++ implementation of mx.weight_norm with optimized paths for different dimensions
  • Python module weight_norm.py with user-friendly API and layer wrappers
  • Proper handling of MLX's channel ordering differences from PyTorch
  • Workaround for the linalg::norm 2-axes limitation
  • Convenience classes for common layer types (Linear, Conv1d, Conv2d)
  • Comprehensive test suite validating mathematical properties and cross-framework compatibility

Implementation Details

Core C++ Implementation

The core weight_norm operation is implemented with three different paths based on the number of axes to normalize over:

  1. Direct path for 1-2 axes using optimized linalg::norm kernels
  2. Reshape-based approach for >2 axes (see the Python sketch after this list), which:
    • Identifies dimensions to keep vs. normalize
    • Handles special cases: normalizing all dims, keeping one dim, keeping multiple dims
    • Reshapes appropriately to leverage the optimized 2D norm kernel
    • Reshapes results back for broadcasting
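
To illustrate the idea, here is a minimal Python sketch of the reshape-based path (the actual implementation is in C++; norm_over_axes is a hypothetical helper, not part of the PR):

import mlx.core as mx

def norm_over_axes(v, axes):
    """Sketch of the >2-axis path: flatten the normalized axes into one,
    use the optimized 2D norm kernel, then reshape for broadcasting."""
    axes = [a % v.ndim for a in axes]
    keep = [a for a in range(v.ndim) if a not in axes]

    # Move the kept axes to the front and flatten the normalized axes into one.
    flat = mx.reshape(mx.transpose(v, keep + axes),
                      [v.shape[a] for a in keep] + [-1])

    # Single norm over the flattened axis (the optimized 2D case).
    n = mx.linalg.norm(flat, axis=-1)

    # Restore a shape that broadcasts against the original v.
    return mx.reshape(n, [1 if a in axes else v.shape[a] for a in range(v.ndim)])

# Conv2d-style weight (out, kh, kw, in): normalize over everything but axis 0.
v = mx.random.normal((32, 3, 3, 16))
print(norm_over_axes(v, [1, 2, 3]).shape)  # (32, 1, 1, 1)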

Python Layer

The Python implementation (a simplified sketch follows this list):

  • Provides a weight_norm function that wraps MLX modules
  • Handles dimension ordering differences for different layer types
  • Computes initial g parameter as the norm of the original weight
  • Overrides the module's forward pass to apply weight normalization on-the-fly
  • Includes convenience classes (WeightNormLinear, WeightNormConv1d, WeightNormConv2d)
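
Roughly, the wrapper behaves like the following simplified, hypothetical sketch (not the PR's actual weight_norm code, which also handles per-layer axis choices, conv layouts, and parameter bookkeeping):

import mlx.core as mx
import mlx.nn as nn

class WeightNormWrapper(nn.Module):
    """Sketch: reparameterize the wrapped module's weight as g * v / ||v||."""

    def __init__(self, module: nn.Module, axes=(1,)):
        super().__init__()
        self.inner = module
        self.axes = list(axes)
        # g starts as the norm of the original weight, so the wrapped module is
        # numerically identical to the unwrapped one at initialization.
        self.weight_v = module.weight
        self.weight_g = mx.linalg.norm(module.weight, axis=self.axes, keepdims=True)

    def __call__(self, x):
        # Recompute the effective weight from g and v on every forward pass.
        # (A real implementation would also drop the inner module's original
        # weight from the trainable parameters.)
        norm_v = mx.linalg.norm(self.weight_v, axis=self.axes, keepdims=True)
        self.inner.weight = self.weight_g * self.weight_v / norm_v
        return self.inner(x)

# Wrap an existing Linear layer: its weight is (out, in), so normalize over axis 1.
layer = WeightNormWrapper(nn.Linear(10, 20))
y = layer(mx.random.normal((4, 10)))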

Testing and Verification

Testing and verification cover three areas:

1. Mathematical Property Tests

  • Verify that the normalized weights have the correct norm (equals g)
  • Confirm that the direction of normalized weights matches v
  • Validate that changing g correctly scales the weight norms
  • Test edge cases like normalizing over all dimensions
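
For instance, the first two properties above can be checked along these lines (assuming the PR's proposed mx.weight_norm API):

import mlx.core as mx

# Conv1d-style weight: (out_channels, kernel, in_channels)
v = mx.random.normal((8, 3, 4))
g = mx.random.uniform(low=0.1, high=1.0, shape=(8, 1, 1))  # positive magnitudes

w = mx.weight_norm(v, g, axes=[1, 2])

# Norm property: the per-output-channel norm of w equals g.
w_norm = mx.linalg.norm(w, axis=[1, 2], keepdims=True)
assert mx.allclose(w_norm, g, atol=1e-5).item()

# Direction property: w points in the same direction as v.
v_norm = mx.linalg.norm(v, axis=[1, 2], keepdims=True)
assert mx.allclose(w / w_norm, v / v_norm, atol=1e-5).item()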

2. Cross-Framework Verification

  • Compare against PyTorch's weight normalization
  • Test both independent implementations and direct weight transfer
  • Document expected differences between frameworks and how to achieve exact equivalence
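
A sketch of the comparison described above, for a Linear layer (it uses torch's legacy torch.nn.utils.weight_norm, which exposes weight_g / weight_v directly, and assumes the PR's proposed mx.weight_norm API):

import numpy as np
import torch
import mlx.core as mx

# PyTorch reference: weight-normalized Linear layer (norm per output row, dim=0).
lin = torch.nn.utils.weight_norm(torch.nn.Linear(6, 4, bias=False), dim=0)
w_torch = lin.weight.detach().numpy()

# Transfer g and v and rebuild the weight with the proposed MLX op.
g = mx.array(lin.weight_g.detach().numpy())  # shape (4, 1)
v = mx.array(lin.weight_v.detach().numpy())  # shape (4, 6)
w_mlx = mx.weight_norm(v, g, axes=[1])

assert np.allclose(np.array(w_mlx), w_torch, atol=1e-6)

For conv layers the transfer additionally requires the channel-order transpose noted in the key features.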

3. Performance Benchmarks

Benchmarks on an Apple M3 Max show MLX outperforming PyTorch MPS:

  • Linear layers (1 axis): 4.90x-5.26x speedup
  • Conv1d layers (2 axes): 1.46x-2.05x speedup
  • Conv2d layers (3 axes): 1.50x-1.76x speedup
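
For reference, a minimal MLX timing harness of the kind used for such measurements might look like this (an illustrative sketch, not the PR's benchmark code):

import time
import mlx.core as mx

def bench(fn, warmup=5, iters=100):
    """Average wall-clock time per call, forcing evaluation of MLX's lazy graph."""
    for _ in range(warmup):
        mx.eval(fn())
    start = time.perf_counter()
    for _ in range(iters):
        mx.eval(fn())
    return (time.perf_counter() - start) / iters

# Linear-style weight: normalize over a single axis.
v = mx.random.normal((4096, 4096))
g = mx.random.normal((4096, 1))
ms = bench(lambda: mx.weight_norm(v, g, axes=[1])) * 1e3
print(f"mx.weight_norm over 1 axis: {ms:.3f} ms/call")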

Usage Examples

Core API

import mlx.core as mx

# Create parameters
v = mx.random.normal((64, 3, 3))  # Direction tensor
g = mx.random.normal((64, 1, 1))  # Magnitude tensor

# Apply weight normalization
w = mx.weight_norm(v, g, axes=[1, 2])

Module API

import mlx.nn as nn
from mlx.nn.layers.weight_norm import weight_norm

# Apply to existing layer
linear = nn.Linear(10, 20)
linear_wn = weight_norm(linear)

# Use convenience class
conv1d_wn = nn.WeightNormConv1d(16, 32, kernel_size=3)

Resolves #1888.

Checklist

  • [x] I have read the CONTRIBUTING document
  • [x] I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • [x] I have added tests that prove my fix is effective or that my feature works
  • [x] I have updated the necessary documentation (if needed)

cavit99 avatar Mar 04 '25 01:03 cavit99

Thanks a lot @cavit99, this is great work!

One tiny nit:

  • Could we change the weight naming to weight_g and weight_v? Makes it easier to map from torch and to remember.

Blaizzy avatar Mar 04 '25 18:03 Blaizzy

As far as I can tell from initial testing, this PR does address my issues.

(screenshot attached)

The only difference is that I preferred using torch's channels-first layout for loading and transposing the weight at runtime, because Kokoro has a lot of transpose operations (~35) and I wanted to avoid bugs.

import mlx.core as mx
import mlx.nn as nn

# (weight_norm and apply_conv are helpers elsewhere in the codebase, not shown here)


class WeightNormConv1D(nn.Module):
    """Conv1d with weight normalization"""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        encode: bool = False,
    ):
        super().__init__()

        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups

        # Initialize weight magnitude (g) and direction (v) vectors
        self.weight_g = mx.ones((out_channels, 1, 1))  # Scalar magnitude per output channel
        self.weight_v = mx.ones(
            (out_channels, kernel_size, in_channels)
        )  # Direction vectors

        self.bias = mx.zeros(in_channels if encode else out_channels) if bias else None

    def __call__(self, x, conv):

        weight = weight_norm(self.weight_v, self.weight_g, dim=0)

        if self.bias is not None:
            bias = self.bias.reshape(1, 1, -1)
        try:
            ...
            # Input is channels last, need to transpose weight
            return apply_conv(x, weight.T)
        except Exception as e:
            print(f"Error: {e}")
            print(f"x.shape: {x.shape}, weight.shape: {weight.shape}")
            raise e
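
For context, the layout difference being worked around here: a torch Conv1d weight has shape (out_channels, in_channels, kernel_size), while MLX's Conv1d expects (out_channels, kernel_size, in_channels). A hypothetical one-time conversion when loading torch weights:

import numpy as np
import torch
import mlx.core as mx

torch_conv = torch.nn.Conv1d(16, 32, kernel_size=3)
w_torch = torch_conv.weight.detach().numpy()        # (32, 16, 3) = (out, in, kernel)
w_mlx = mx.array(np.transpose(w_torch, (0, 2, 1)))  # (32, 3, 16) = (out, kernel, in)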

Blaizzy avatar Mar 04 '25 18:03 Blaizzy

Thanks a lot @cavit99, this is great work!

One tiny nit:

  • Could we change the weight naming to weight_g and weight_v? Makes it easier to map from torch and to remember.

agreed from my side, so I pushed that change to the PR, thank you

cavit99 avatar Mar 04 '25 18:03 cavit99

Perfect! 🤩

Now we wait for @awni :)

Blaizzy avatar Mar 04 '25 19:03 Blaizzy

he's gonna look and say "meh, maybe if you stick it in normalization.py", isn't he

cavit99 avatar Mar 04 '25 19:03 cavit99

I'm not certain about including this as WeightNorm, as I thought WeightNorm is not used so much anymore... thoughts?

Either way we should not make free functions in C++ and Python for this. It should just be a layer in mlx.nn. In the same way we don't have mx.batch_norm but we have nn.BatchNorm etc.

awni avatar Mar 17 '25 17:03 awni

I gather it's making a comeback for realtime audio use cases, vocoders, and TTS, because it's lighter than batch norm and stability and convergence are critical there. Other than @Blaizzy's mlx-audio, which also implemented weight norm manually, I see torch's weight_norm being used in, for example, spark-tts, Nvidia NeMo, and coqui tts, and it is still maintained in

https://pytorch.org/docs/stable/generated/torch.nn.utils.parametrizations.weight_norm.html#torch.nn.utils.parametrizations.weight_norm

Regarding the layer, of course you're right. If you agree on its usefulness, I will happily refactor it fully into mlx.nn as it should have been from the start.

cavit99 avatar Mar 17 '25 18:03 cavit99

Regarding the layer, of course you're right. If you agree on its usefulness, I will happily refactor it fully into mlx.nn as it should have been from the start.

That would be great!

awni avatar Mar 17 '25 18:03 awni

@cavit99 are you planning on coming back to this PR?

awni avatar Apr 25 '25 13:04 awni

@cavit99 I'm porting SparkTTS and this would be a great addition :)

Turns out 90% of the audio models we ported in MLX-Audio use a WeightNorm layer, either in the model or in the codec.

Blaizzy avatar Apr 28 '25 13:04 Blaizzy

sorry I will do this by EOD tomorrow

cavit99 avatar Apr 28 '25 13:04 cavit99

I'm going to close this as inactive. We're open to revisiting the addition of weight norm in the future.

awni avatar Jul 02 '25 13:07 awni