[BUG] Possible numerical issue with Conv1d
MLX Issue Draft: Numerical Discrepancy in mlx.nn.Conv1d vs. PyTorch within Sequential Context
Description:
When replicating a PyTorch model component involving sequential Snake activations and Conv1d layers (specifically a ResidualUnit from the Descript Audio Codec model) in MLX, we observe a numerical discrepancy between the output of mlx.nn.Conv1d and torch.nn.Conv1d, even when both are run on the CPU backend.
Crucially:
- The inputs fed to the corresponding Conv1d layers immediately before the divergence are numerically identical (verified via hooks and manual calculation, MAE ~0).
- The layer weights and biases are loaded from the same source PyTorch checkpoint and are numerically identical.
- A standalone test comparing only the mlx.nn.Conv1d operation against torch.nn.Conv1d, using the exact same input/weight/bias tensors (saved from the failing run), passes with high precision (MAE ~1e-7).
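An isolated comparison like the one in the list above only holds if the tensors are converted between layouts correctly. MLX convolutions are channels-last: the input is (N, L, C_in) and the weight is (C_out, K, C_in). PyTorch is channels-first: the input is (N, C_in, L) and the weight is (C_out, C_in, K). The NumPy sketch below uses random data rather than the DAC tensors. It implements the same dilated convolution in both layouts and reports the kind of metrics quoted in this issue (MAE, max error, allclose). The helper names conv1d_ncl, conv1d_nlc and compare are hypothetical and are not part of either library.

```python
import numpy as np


def conv1d_ncl(x, w, b, dilation=1, padding=0):
    """Reference 1-D convolution in PyTorch layout.

    x: (N, C_in, L), w: (C_out, C_in, K), b: (C_out,) -> (N, C_out, L_out)
    """
    x = np.pad(x, ((0, 0), (0, 0), (padding, padding)))
    c_out, _, k = w.shape
    l_out = x.shape[2] - dilation * (k - 1)
    y = np.zeros((x.shape[0], c_out, l_out))
    for j in range(k):
        # Kernel tap j reads the input shifted by j * dilation samples.
        y += np.einsum("oi,nil->nol", w[:, :, j], x[:, :, j * dilation : j * dilation + l_out])
    return y + b[None, :, None]


def conv1d_nlc(x, w, b, dilation=1, padding=0):
    """The same convolution in MLX layout.

    x: (N, L, C_in), w: (C_out, K, C_in), b: (C_out,) -> (N, L_out, C_out)
    """
    x = np.pad(x, ((0, 0), (padding, padding), (0, 0)))
    c_out, k, _ = w.shape
    l_out = x.shape[1] - dilation * (k - 1)
    y = np.zeros((x.shape[0], l_out, c_out))
    for j in range(k):
        y += np.einsum("nli,oi->nlo", x[:, j * dilation : j * dilation + l_out, :], w[:, j, :])
    return y + b


def compare(a, b, atol=1e-5):
    """MAE, max absolute error and np.allclose, as quoted in this issue."""
    diff = np.abs(a - b)
    return {"mae": float(diff.mean()), "max": float(diff.max()),
            "allclose": bool(np.allclose(a, b, atol=atol))}


rng = np.random.default_rng(0)
x_pt = rng.standard_normal((1, 4, 32))   # PyTorch input (N, C_in, L)
w_pt = rng.standard_normal((3, 4, 7))    # PyTorch weight (C_out, C_in, K)
bias = rng.standard_normal(3)
y_pt = conv1d_ncl(x_pt, w_pt, bias, dilation=3, padding=9)

# PyTorch -> MLX: input (N, C, L) -> (N, L, C); weight (O, I, K) -> (O, K, I).
x_mlx = x_pt.transpose(0, 2, 1)
w_mlx = w_pt.transpose(0, 2, 1)
y_mlx = conv1d_nlc(x_mlx, w_mlx, bias, dilation=3, padding=9)
good = compare(y_pt, y_mlx.transpose(0, 2, 1))

# A shape-compatible but wrong conversion: reshape reorders memory instead of permuting axes.
w_bad = w_pt.reshape(3, 7, 4)
bad = compare(y_pt, conv1d_nlc(x_mlx, w_bad, bias, dilation=3, padding=9).transpose(0, 2, 1))
print(good, bad["allclose"])
```

A correct transpose agrees to rounding error. The reshaped weight has the same shape but produces an O(1) error, which is far larger than float32 noise.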
The discrepancy appears only when mlx.nn.Conv1d is executed within the sequence of operations in the ResidualUnit (Snake -> Conv1d -> Snake -> Conv1d -> Add). The mismatch begins immediately after the first Conv1d layer in the sequence and occurs consistently on both the Metal and CPU backends of MLX.
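A framework-independent reference for the whole sequence can help determine which side diverges first. Each framework's intermediates can be compared against it, rather than only against each other. The NumPy sketch below rests on several assumptions:

- Snake is defined as x + sin²(αx) / (α + 1e-9), with one α per channel.
- The first convolution uses kernel size 7 with non-causal "same" padding of 3 · dilation.
- Any weight normalization from the checkpoint has already been folded into plain weights.

A causal port would pad only on the left. The function and parameter names (snake, residual_unit, alpha1, w1, etc.) are hypothetical.

```python
import numpy as np


def snake(x, alpha):
    """Snake activation x + sin^2(alpha * x) / (alpha + 1e-9) on NLC input, alpha of shape (C,)."""
    a = alpha[None, None, :]
    return x + (1.0 / (a + 1e-9)) * np.sin(a * x) ** 2


def conv1d_nlc(x, w, b, dilation=1, padding=0):
    """Dilated 1-D convolution, x: (N, L, C_in), w: (C_out, K, C_in), b: (C_out,)."""
    x = np.pad(x, ((0, 0), (padding, padding), (0, 0)))
    c_out, k, _ = w.shape
    l_out = x.shape[1] - dilation * (k - 1)
    y = np.zeros((x.shape[0], l_out, c_out))
    for j in range(k):
        y += np.einsum("nli,oi->nlo", x[:, j * dilation : j * dilation + l_out, :], w[:, j, :])
    return y + b


def residual_unit(x, p, dilation=1):
    """Snake -> Conv1d(k=7, dilated) -> Snake -> Conv1d(k=1) -> Add; returns every intermediate."""
    acts = {"snake1": snake(x, p["alpha1"])}
    acts["conv1"] = conv1d_nlc(acts["snake1"], p["w1"], p["b1"], dilation, padding=3 * dilation)
    acts["snake2"] = snake(acts["conv1"], p["alpha2"])
    acts["conv2"] = conv1d_nlc(acts["snake2"], p["w2"], p["b2"])
    # If the convolutions shortened the sequence, center-crop the skip path to match.
    crop = (x.shape[1] - acts["conv2"].shape[1]) // 2
    skip = x[:, crop : x.shape[1] - crop, :] if crop > 0 else x
    acts["out"] = skip + acts["conv2"]
    return acts


rng = np.random.default_rng(0)
C = 8
params = {
    "alpha1": np.ones(C), "alpha2": np.ones(C),
    "w1": 0.1 * rng.standard_normal((C, 7, C)), "b1": np.zeros(C),
    "w2": 0.1 * rng.standard_normal((C, 1, C)), "b2": np.zeros(C),
}
x = rng.standard_normal((1, 64, C))
acts = residual_unit(x, params, dilation=3)
for name, a in acts.items():
    print(name, a.shape)
```

Feeding the real weights and input through this reference lets each named intermediate be checked against the PyTorch hook outputs (transposed to NLC) and against the MLX intermediates separately. Whichever framework departs from the reference at "conv1" is the one to investigate.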
Environment:
- MLX version: 0.25.0
- Python version: 3.10.14
- OS: macOS Sonoma 14.3 (darwin 24.3.0)
- Hardware: Apple M2 Max
- PyTorch version: 2.7.0 (Inferred from project dependencies/common usage)
Reproducible Example:
- Code:
  - MLX ResidualUnit and Snake implementation (mlx_dac/model/layers.py, see attached/link).
  - Comparison script (compare_dac_outputs.py, see attached/link). This script loads the original DAC weights, sets up hooks for PyTorch, manually computes the MLX intermediates, and compares them.
  - Weight conversion script (convert_weights.py, see attached/link; if needed).
- Weights: Requires the original Descript Audio Codec weights (e.g., weights_44khz_8kbps_0.0.1.pth) and the MLX-converted version (mlx_dac_44khz.safetensors).
- Input Data: A sample audio file (e.g., sample_5sec.wav).
- Steps:
  - Run python compare_dac_outputs.py.
  - Observe the comparison results printed for "ResUnit 0 CausalConv Output (Sliced - PT vs MLX)". Note that np.allclose returns False and the MAE is significant (~0.04).
  - (Optional) Run python test_causal_conv1d_nlc.py (which uses the .npy files saved by the comparison script) to see the isolated convolution pass.
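The failing comparison is a sliced causal-convolution output, so it matters which axis is sliced. Time is the last axis in PyTorch's NCL layout but axis 1 in MLX's NLC layout. The NumPy sketch below assumes "causal" means left-padding by dilation * (K - 1). It checks that this matches PyTorch-style symmetric padding followed by dropping the last dilation * (K - 1) time steps. The helper names are hypothetical and the data is random.

```python
import numpy as np


def conv1d_nlc(x, w, b, dilation=1, pad_left=0, pad_right=0):
    """Dilated 1-D convolution, x: (N, L, C_in), w: (C_out, K, C_in), b: (C_out,)."""
    x = np.pad(x, ((0, 0), (pad_left, pad_right), (0, 0)))
    c_out, k, _ = w.shape
    l_out = x.shape[1] - dilation * (k - 1)
    y = np.zeros((x.shape[0], l_out, c_out))
    for j in range(k):
        y += np.einsum("nli,oi->nlo", x[:, j * dilation : j * dilation + l_out, :], w[:, j, :])
    return y + b


def causal_conv1d_nlc(x, w, b, dilation=1):
    """Left-pad only, so output step t depends on input steps <= t."""
    p = dilation * (w.shape[1] - 1)
    return conv1d_nlc(x, w, b, dilation, pad_left=p)


def causal_by_slicing(x, w, b, dilation=1):
    """Pad both sides (PyTorch-style padding=p), then drop the last p time steps."""
    p = dilation * (w.shape[1] - 1)
    y = conv1d_nlc(x, w, b, dilation, pad_left=p, pad_right=p)
    # Time is axis 1 in NLC; in NCL code the equivalent slice is on the last axis.
    return y[:, : y.shape[1] - p, :]


rng = np.random.default_rng(1)
x = rng.standard_normal((1, 40, 4))
w = rng.standard_normal((6, 7, 4))
b = rng.standard_normal(6)
y_left = causal_conv1d_nlc(x, w, b, dilation=2)
y_slice = causal_by_slicing(x, w, b, dilation=2)

# Perturbing input step 20 must leave outputs before step 20 untouched.
x2 = x.copy()
x2[:, 20, :] += 1.0
y_pert = causal_conv1d_nlc(x2, w, b, dilation=2)
print(y_left.shape, np.abs(y_left - y_slice).max())
```

Applying a slice written for NCL (y[..., :-p]) to an NLC tensor would remove channels instead of time steps. An error like that can produce a mismatch only in the sequential path while the isolated convolution still passes.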
Expected Behavior:
The output of the MLX Conv1d layer within the ResidualUnit sequence should closely match the corresponding PyTorch layer's output when run on the CPU backend with identical inputs and weights, similar to how they match in the isolated test case.
Actual Behavior:
A significant numerical difference (MAE ~0.04) arises specifically after the mlx.nn.Conv1d layer within the ResidualUnit sequence, on both Metal and CPU backends, despite identical inputs/weights and passing isolated tests.
Attachments/Links:
- mlx_dac/model/layers.py
- compare_dac_outputs.py
- Output log from compare_dac_outputs.py showing the mismatch.
- (Optional) .npy files for the specific CausalConv input/weights/bias where the isolated test passes but the sequence fails: debug_mlx_causalconv_input_manual.npy, debug_mlx_causalconv_weight.npy, debug_mlx_causalconv_bias.npy
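For someone without the DAC checkpoint to rerun the isolated test, the three attached tensors only need to round-trip through .npy unchanged. The NumPy sketch below assumes the tensors have already been converted to NumPy arrays, for example with np.array(mlx_array) or torch_tensor.detach().cpu().numpy(). The file names come from the attachment list. The helper names save_debug_tensors and load_debug_tensors, and the temporary directory, are hypothetical.

```python
import os
import tempfile

import numpy as np

# File names as listed in the attachments.
DEBUG_FILES = {
    "input": "debug_mlx_causalconv_input_manual.npy",
    "weight": "debug_mlx_causalconv_weight.npy",
    "bias": "debug_mlx_causalconv_bias.npy",
}


def save_debug_tensors(directory, tensors):
    """Write each NumPy tensor to its .npy file; np.save keeps dtype and shape."""
    os.makedirs(directory, exist_ok=True)
    for key, fname in DEBUG_FILES.items():
        np.save(os.path.join(directory, fname), np.asarray(tensors[key]))


def load_debug_tensors(directory):
    """Read the three tensors back into a dict keyed by 'input', 'weight', 'bias'."""
    return {key: np.load(os.path.join(directory, fname)) for key, fname in DEBUG_FILES.items()}


# Usage with random stand-ins for the real tensors (NLC input, MLX-layout weight).
rng = np.random.default_rng(0)
saved = {
    "input": rng.standard_normal((1, 64, 8)).astype(np.float32),
    "weight": rng.standard_normal((8, 7, 8)).astype(np.float32),
    "bias": rng.standard_normal(8).astype(np.float32),
}
with tempfile.TemporaryDirectory() as d:
    save_debug_tensors(d, saved)
    loaded = load_debug_tensors(d)
print({k: (v.shape, v.dtype) for k, v in loaded.items()})
```

Stating the shape and dtype of each file in the issue makes the intended layout (NLC versus NCL) unambiguous to whoever loads them.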
It seems like there are some links or code missing in order to reproduce the issue. You posted a bunch of steps and the one script compare_dac_outputs.py but a lot of stuff is missing to actually run that. Could you provide the full steps needed to reproduce the issue?
Closing as inactive. Feel free to comment with steps to repro and we can reopen if needed.