[BUG] conv_general differences between GPU, CPU, and the Linux build
Describe the bug There are noticeable differences between conv_general on GPU and CPU, and what appears to be overflow on the Linux build.
To Reproduce There are two code snippets: one with various tests comparing GPU and CPU, and a second run on Linux. For the Linux version I compared against JAX, but the MLX result seems to overflow or yield very large numbers or NaNs.
import mlx.core as mx
key = mx.random.key(0)
inputs = mx.random.normal((2, 8, 8, 8, 3), dtype=mx.float32, key=key)
kernel = mx.random.normal((2, 3, 3, 3, 3), dtype=mx.float32, key=key)
strides = (2, 2, 2)
mlx_padding = ([0, 0, 0], [1, 1, 1])
dilation_rate = (1, 1, 1)
groups = 1
result = mx.conv_general(
    inputs,
    kernel,
    stride=strides,
    padding=mlx_padding,
    kernel_dilation=dilation_rate,
    input_dilation=1,
    groups=groups,
    flip=False,
)
result_cpu = mx.conv_general(
    inputs,
    kernel,
    stride=strides,
    padding=mlx_padding,
    kernel_dilation=dilation_rate,
    input_dilation=1,
    groups=groups,
    flip=False,
    stream=mx.cpu,
)
result_diff = result - result_cpu
print(f'(conv3d) max_diff: {mx.max(result_diff)}')
print(f'(conv3d) total_absolute_diff: {mx.sum(mx.abs(result_diff))}')
# (conv3d) max_diff: 15.351866722106934
# (conv3d) total_absolute_diff: 544.01220703125
inputs = mx.random.normal((2, 10, 10, 3), dtype=mx.float32, key=key)
kernel = mx.random.normal((2, 2, 2, 3), dtype=mx.float32, key=key)
strides = (1, 2)
mlx_padding = ([0, 0], [1, 0])
dilation_rate = (1, 1)
groups = 1
result = mx.conv_general(
    inputs,
    kernel,
    stride=strides,
    padding=mlx_padding,
    kernel_dilation=dilation_rate,
    input_dilation=1,
    groups=groups,
    flip=False,
)
result_cpu = mx.conv_general(
    inputs,
    kernel,
    stride=strides,
    padding=mlx_padding,
    kernel_dilation=dilation_rate,
    input_dilation=1,
    groups=groups,
    flip=False,
    stream=mx.cpu,
)
result_diff = result - result_cpu
print(f'(conv2d) max_diff: {mx.max(result_diff)}')
print(f'(conv2d) total_absolute_diff: {mx.sum(mx.abs(result_diff))}')
# (conv2d) max_diff: 2.334087371826172
# (conv2d) total_absolute_diff: 19.29632568359375
Linux comparison with JAX (the two should yield similar results as long as the convolution is handled correctly).
from jax import lax
import jax
import jax.numpy as jnp
import mlx.core as mx
key = jax.random.key(42)
inputs = jax.random.normal(key, shape=(2, 3, 8, 8, 8), dtype=jnp.float32)
kernel = jax.random.normal(key, shape=(3, 3, 3, 3, 2), dtype=jnp.float32)
strides = (2, 2, 2)
padding = 'same'
dilation_rate = (1, 1, 1)
dimension_numbers = lax.ConvDimensionNumbers(
    lhs_spec=(0, 1, 2, 3, 4), rhs_spec=(4, 3, 0, 1, 2), out_spec=(0, 1, 2, 3, 4)
)
feature_group_count = 1
result_jax = jax.lax.conv_general_dilated(
    inputs,
    kernel,
    strides,
    padding,
    rhs_dilation=dilation_rate,
    dimension_numbers=dimension_numbers,
    feature_group_count=feature_group_count,
)
inputs = mx.array(inputs)
kernel = mx.array(kernel)
inputs = inputs.transpose(0, *range(2, inputs.ndim), 1) # mlx expects channels_last
kernel = kernel.transpose(-1, *range(kernel.ndim - 2), -2) # mlx expects kernel with (out_channels, spatial..., in_channels)
strides = (2, 2, 2)
mlx_padding = ([0, 0, 0], [1, 1, 1])
dilation_rate = (1, 1, 1)
groups = 1
result_mlx = mx.conv_general(
    inputs,
    kernel,
    stride=strides,
    padding=mlx_padding,
    kernel_dilation=dilation_rate,
    input_dilation=1,
    groups=groups,
    flip=False,
)
# move channels first to match jax
result_mlx = result_mlx.transpose(0, -1, *range(1, result_mlx.ndim - 1))
print('result jax', result_jax.shape, result_jax.dtype)
print('result mlx', result_mlx.shape, result_mlx.dtype)
abs_diff = jnp.abs(jnp.array(result_mlx) - result_jax)
max_abs_diff = jnp.max(abs_diff)
total_abs_diff = jnp.sum(abs_diff)
print('max_abs_diff', max_abs_diff)
print('total_abs_diff', total_abs_diff)
# on intel i9-9980HK, Debian Linux 12
# result jax (2, 2, 4, 4, 4) float32
# result mlx (2, 2, 4, 4, 4) mlx.core.float32
# max_abs_diff 2.0947965e+32
# total_abs_diff 1.2051802e+33
# OR
# max_abs_diff nan
# total_abs_diff nan
# on M4 Pro, Sequoia 15.2
# result jax (2, 2, 4, 4, 4) float32
# result mlx (2, 2, 4, 4, 4) mlx.core.float32
# max_abs_diff 9.536743e-06
# total_abs_diff 0.00023447536
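For reference, the asymmetric MLX padding used above is what JAX's 'same' padding resolves to for these shapes. A minimal sketch of the standard SAME-padding arithmetic (illustrative helper, not MLX or JAX library code):

```python
def same_padding(in_size: int, kernel: int, stride: int) -> tuple[int, int]:
    # TensorFlow/JAX-style SAME padding: output size is ceil(in_size / stride).
    out_size = -(-in_size // stride)  # ceiling division
    total = max((out_size - 1) * stride + kernel - in_size, 0)
    return total // 2, total - total // 2  # (padding_lo, padding_hi)

# For the conv3d case above: size 8, kernel 3, stride 2 -> (0, 1) per axis,
# which is why mlx_padding = ([0, 0, 0], [1, 1, 1]) on the MLX side.
print(same_padding(8, 3, 2))  # (0, 1)
```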
Expected behavior I would expect no overflow (or excessively large numbers) on the Linux build, and a much smaller margin of error (if any) between CPU and GPU with MLX. Please let me know if there are any limitations or known reasons for this, though!
Desktop (please complete the following information):
- MLX Version: 0.24.1
- OS
- Mac: M4 Pro, Sequoia 15.2
- Linux: Intel i9-9980HK, Debian Linux 12
Additional context This is for the automated tests for the MLX backend for Keras (https://github.com/keras-team/keras/issues/19571).
Thanks for reporting! Yes, this does indeed look like a bug related to the padding. If the padding is set to all 0s or all 1s, the results are identical between CPU and GPU.
Judging from the NaNs on Linux, I would say the CPU conv is reading out of bounds with the given padding parameters.
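For instance, a quick check along these lines (a sketch reusing the conv3d repro above; the symmetric padding values are only illustrative) should show the two streams agreeing once padding_lo equals padding_hi:

```python
import mlx.core as mx

key = mx.random.key(0)
inputs = mx.random.normal((2, 8, 8, 8, 3), dtype=mx.float32, key=key)
kernel = mx.random.normal((2, 3, 3, 3, 3), dtype=mx.float32, key=key)

# Symmetric padding: the same values for padding_lo and padding_hi.
symmetric_padding = ([1, 1, 1], [1, 1, 1])

result_gpu = mx.conv_general(inputs, kernel, stride=(2, 2, 2), padding=symmetric_padding)
result_cpu = mx.conv_general(
    inputs, kernel, stride=(2, 2, 2), padding=symmetric_padding, stream=mx.cpu
)

# The mismatch only shows up with asymmetric padding such as ([0, 0, 0], [1, 1, 1]).
print(mx.max(mx.abs(result_gpu - result_cpu)))  # expected ~0
```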
@jagrit06
It looks like a bug with how padding_hi is handled in the loop bounds of the slow_conv_Nd functions on the CPU; I will have a fix for this soon!
If no one has started on this, I'd like to take it. I believe the convolution operator needs to support padding_lo and padding_hi. @jagrit06 @awni
Ah, I see. conv_general only supports zero padding, so it doesn't need to take padding_hi into account 😅
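To make that concrete: the padding argument is plain zero padding of the input on each side, so it should be equivalent to padding the input explicitly with mx.pad and convolving with no padding. A rough sketch reusing the 2D repro shapes (the pad widths are just the ([0, 0], [1, 0]) case spelled out per axis):

```python
import mlx.core as mx

key = mx.random.key(0)
x = mx.random.normal((2, 10, 10, 3), dtype=mx.float32, key=key)
w = mx.random.normal((2, 2, 2, 3), dtype=mx.float32, key=key)

# Asymmetric zero padding via the padding argument: (padding_lo, padding_hi).
y1 = mx.conv_general(x, w, stride=(1, 2), padding=([0, 0], [1, 0]))

# The same padding applied explicitly along the spatial axes
# (no padding on the batch or channel axes), then an unpadded conv.
x_padded = mx.pad(x, [(0, 0), (0, 1), (0, 0), (0, 0)])
y2 = mx.conv_general(x_padded, w, stride=(1, 2), padding=0)

print(mx.max(mx.abs(y1 - y2)))  # expected 0 once padding_hi is handled correctly
```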
Thank you! This might be safe to close!
Results on Linux with a build of the latest MLX release (Intel i9-9980HK, Ubuntu 24.04.2 LTS):
(conv3d) max_diff: 0.0
(conv3d) total_absolute_diff: 0.0
(conv2d) max_diff: 0.0
(conv2d) total_absolute_diff: 0.0
Results on Mac (MLX 0.25.2, M4 Pro, macOS 15.2):
(conv3d) max_diff: 2.86102294921875e-06
(conv3d) total_absolute_diff: 5.8397650718688965e-05
(conv2d) max_diff: 0.0
(conv2d) total_absolute_diff: 0.0
Are the Mac precision discrepancies on conv3d within reason?
Those results look within a reasonable numerical tolerance to me. Thanks for checking! Will close this.
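For anyone comparing outputs like this, a tolerance-based check (a sketch; the rtol/atol values are illustrative rather than an official recommendation) treats float32 rounding noise at that scale as equal:

```python
import mlx.core as mx

key = mx.random.key(0)
inputs = mx.random.normal((2, 8, 8, 8, 3), dtype=mx.float32, key=key)
kernel = mx.random.normal((2, 3, 3, 3, 3), dtype=mx.float32, key=key)
padding = ([0, 0, 0], [1, 1, 1])

result_gpu = mx.conv_general(inputs, kernel, stride=(2, 2, 2), padding=padding)
result_cpu = mx.conv_general(inputs, kernel, stride=(2, 2, 2), padding=padding, stream=mx.cpu)

# Differences on the order of 1e-6 in float32 are ordinary accumulation noise;
# allclose with float32-scale tolerances reports the two results as equal.
print(mx.allclose(result_gpu, result_cpu, rtol=1e-5, atol=1e-6))
```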