[BUG] Possible numerical issue with Conv1d

Open sethdford opened this issue 8 months ago • 1 comments

MLX Issue Draft: Numerical Discrepancy in mlx.nn.Conv1d vs. PyTorch within Sequential Context

Description:

When replicating a PyTorch model component involving sequential Snake activations and Conv1d layers (specifically a ResidualUnit from the Descript Audio Codec model) in MLX, we observe a numerical discrepancy between the output of mlx.nn.Conv1d and torch.nn.Conv1d, even when both are run on the CPU backend.

Crucially:

  1. The inputs fed to the corresponding Conv1d layers immediately before the divergence are numerically identical (verified via hooks and manual calculation, MAE ~0).
  2. The layer weights and biases are loaded from the same source PyTorch checkpoint and are numerically identical.
  3. A standalone test comparing only the mlx.nn.Conv1d operation against torch.nn.Conv1d using the exact same input/weight/bias tensors (saved from the failing run) passes with high precision (MAE ~1e-7).
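For readers who want to reproduce a check like item 3 without either framework installed, here is a minimal, dependency-free sketch of the layout convention such a test depends on. It assumes stride 1, no padding, no dilation, and a single unbatched example. PyTorch's Conv1d takes channels-first input (N, C_in, L) with weights shaped (C_out, C_in, K), while MLX's Conv1d takes channels-last input (N, L, C_in) with weights shaped (C_out, K, C_in). The function names and toy values are illustrative only and are not taken from the attached scripts.

```python
from typing import List

Matrix = List[List[float]]         # 2-D nested list
Tensor3 = List[List[List[float]]]  # 3-D nested list


def conv1d_ncl(x: Matrix, w: Tensor3, b: List[float]) -> Matrix:
    """PyTorch-style layout: x is (C_in, L), w is (C_out, C_in, K).

    Returns (C_out, L_out). Stride 1, no padding, no dilation.
    """
    c_out, c_in, k = len(w), len(w[0]), len(w[0][0])
    l_out = len(x[0]) - k + 1
    return [
        [
            b[o] + sum(w[o][c][j] * x[c][t + j]
                       for c in range(c_in) for j in range(k))
            for t in range(l_out)
        ]
        for o in range(c_out)
    ]


def conv1d_nlc(x: Matrix, w: Tensor3, b: List[float]) -> Matrix:
    """MLX-style layout: x is (L, C_in), w is (C_out, K, C_in).

    Returns (L_out, C_out). Stride 1, no padding, no dilation.
    """
    c_out, k, c_in = len(w), len(w[0]), len(w[0][0])
    l_out = len(x) - k + 1
    return [
        [
            b[o] + sum(w[o][j][c] * x[t + j][c]
                       for j in range(k) for c in range(c_in))
            for o in range(c_out)
        ]
        for t in range(l_out)
    ]


def transpose(m: Matrix) -> Matrix:
    """Swap the two axes of a 2-D nested list."""
    return [list(row) for row in zip(*m)]


def mae(a: Matrix, b: Matrix) -> float:
    """Mean absolute error over two equally shaped 2-D nested lists."""
    diffs = [abs(p - q) for ra, rb in zip(a, b) for p, q in zip(ra, rb)]
    return sum(diffs) / len(diffs)


# Toy data (hypothetical): 2 input channels, length 5, 3 output channels, K = 3.
x_ncl = [[0.5, -1.0, 2.0, 0.25, 1.5], [1.0, 0.0, -0.5, 3.0, -2.0]]
w_pt = [
    [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
    [[-0.1, 0.0, 0.1], [0.2, -0.2, 0.3]],
    [[1.0, -1.0, 0.5], [0.0, 0.25, -0.75]],
]
bias = [0.01, -0.02, 0.03]

# Convert to channels-last: x -> (L, C_in), each filter (C_in, K) -> (K, C_in).
x_nlc = transpose(x_ncl)
w_mlx = [transpose(w_o) for w_o in w_pt]

y_pt = conv1d_ncl(x_ncl, w_pt, bias)
y_mlx = transpose(conv1d_nlc(x_nlc, w_mlx, bias))  # back to (C_out, L_out)

mae_converted = mae(y_pt, y_mlx)
print(f"MAE after layout conversion: {mae_converted:.3e}")
```

With the input and weight axes converted consistently, the two layouts agree up to floating-point summation order. A pure-Python reference like this can also act as a neutral third opinion when two frameworks disagree, though it is far too slow for real model tensors.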

The discrepancy only appears when the mlx.nn.Conv1d is executed within the sequence of operations in the ResidualUnit (Snake -> Conv1d -> Snake -> Conv1d -> Add). The mismatch starts immediately after the first Conv1d layer in the sequence. This occurs consistently on both the Metal and CPU backends for MLX.

Environment:

  • MLX version: 0.25.0
  • Python version: 3.10.14
  • OS: macOS Sonoma 14.3 (darwin 24.3.0)
  • Hardware: Apple M2 Max
  • PyTorch version: 2.7.0 (inferred from project dependencies, not verified directly)

Reproducible Example:

  1. Code:
    • MLX ResidualUnit and Snake implementation (mlx_dac/model/layers.py - see attached/link).
    • Comparison script (compare_dac_outputs.py - see attached/link). This script loads original DAC weights, sets up hooks for PyTorch, manually computes MLX intermediates, and compares them.
    • Weight conversion script (convert_weights.py - see attached/link, if needed).
  2. Weights: Requires the original Descript Audio Codec weights (e.g., weights_44khz_8kbps_0.0.1.pth) and the MLX converted version (mlx_dac_44khz.safetensors).
  3. Input Data: A sample audio file (e.g., sample_5sec.wav).
  4. Steps:
    • Run python compare_dac_outputs.py.
    • Observe the comparison results printed for "ResUnit 0 CausalConv Output (Sliced - PT vs MLX)". Note that np.allclose returns False and MAE is significant (~0.04).
    • (Optional) Run python test_causal_conv1d_nlc.py (uses .npy files saved by the comparison script) to see the isolated convolution pass.
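The stage-by-stage comparison described in the steps above can be sketched roughly as follows. This is not the attached compare_dac_outputs.py, which is not included in the thread; the stage names, tolerances, and toy values are hypothetical. The sketch assumes each intermediate has already been brought into the same layout (for example, MLX channels-last output transposed to channels-first) and flattened to a list of floats, for instance with NumPy's `array.ravel().tolist()`.

```python
from typing import Dict, Optional, Sequence


def mean_abs_error(a: Sequence[float], b: Sequence[float]) -> float:
    """Mean absolute error between two equally long, non-empty sequences."""
    if len(a) != len(b) or not a:
        raise ValueError(f"need equal non-zero lengths, got {len(a)} and {len(b)}")
    return sum(abs(p - q) for p, q in zip(a, b)) / len(a)


def all_close(a: Sequence[float], b: Sequence[float],
              rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    """Elementwise |a - b| <= atol + rtol * |b|, the criterion numpy.allclose uses."""
    return len(a) == len(b) and all(
        abs(p - q) <= atol + rtol * abs(q) for p, q in zip(a, b)
    )


def first_divergence(ref: Dict[str, Sequence[float]],
                     test: Dict[str, Sequence[float]],
                     order: Sequence[str],
                     rtol: float = 1e-4, atol: float = 1e-5) -> Optional[str]:
    """Return the first stage (in execution order) whose outputs do not match."""
    for name in order:
        if not all_close(test[name], ref[name], rtol=rtol, atol=atol):
            return name
    return None


# Hypothetical stage names mirroring Snake -> Conv1d -> Snake -> Conv1d -> Add.
stages = ["snake1", "conv1", "snake2", "conv2", "add"]

# Toy reference values standing in for flattened PyTorch intermediates.
reference = {
    "snake1": [0.10, -0.20, 0.30],
    "conv1": [1.00, 2.00, -1.00],
    "snake2": [1.50, 2.50, -0.50],
    "conv2": [0.70, 0.80, 0.90],
    "add": [0.80, 0.60, 1.20],
}

# Toy candidate values: identical up to snake1, then off from conv1 onward.
candidate = {
    "snake1": [0.10, -0.20, 0.30],
    "conv1": [1.04, 1.96, -1.04],
    "snake2": [1.55, 2.45, -0.55],
    "conv2": [0.75, 0.78, 0.95],
    "add": [0.85, 0.58, 1.25],
}

first_bad = first_divergence(reference, candidate, stages)
conv1_mae = mean_abs_error(candidate["conv1"], reference["conv1"])
print(f"first diverging stage: {first_bad}, MAE there: {conv1_mae:.4f}")
```

Reporting only the first failing stage keeps downstream error propagation from obscuring where the mismatch is introduced. Feeding that stage's saved inputs to both frameworks in isolation is then the natural follow-up check.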

Expected Behavior:

The output of the MLX Conv1d layer within the ResidualUnit sequence should closely match the corresponding PyTorch layer's output when run on the CPU backend with identical inputs and weights, similar to how they match in the isolated test case.

Actual Behavior:

A significant numerical difference (MAE ~0.04) arises specifically after the mlx.nn.Conv1d layer within the ResidualUnit sequence, on both Metal and CPU backends, despite identical inputs/weights and passing isolated tests.

Attachments/Links:

  • mlx_dac/model/layers.py
  • compare_dac_outputs.py
  • Output log from compare_dac_outputs.py showing the mismatch.
  • (Optional) .npy files for the specific CausalConv input, weight, and bias where the isolated test passes but the sequence fails: debug_mlx_causalconv_input_manual.npy, debug_mlx_causalconv_weight.npy, debug_mlx_causalconv_bias.npy

compare_dac_outputs.py.txt

sethdford • Apr 25 '25 07:04

It seems like there are some links or code missing in order to reproduce the issue. You posted a bunch of steps and the one script compare_dac_outputs.py but a lot of stuff is missing to actually run that. Could you provide the full steps needed to reproduce the issue?

awni • Apr 25 '25 13:04

Closing as inactive. Feel free to comment with steps to repro and we can reopen if needed.

awni • May 09 '25 17:05