[BUG] conv_general differences between gpu, cpu and linux build

Open acsweet opened this issue 8 months ago • 4 comments

Describe the bug

There are noticeable differences between conv_general results on GPU and CPU, and the Linux build appears to overflow.

To Reproduce

There are two code snippets: the first runs several comparisons between GPU and CPU, and the second runs on Linux. For the Linux version I compared against JAX, but the MLX result seems to overflow, yielding very large numbers or NaNs.

import mlx.core as mx

key = mx.random.key(0)

inputs = mx.random.normal((2, 8, 8, 8, 3), dtype=mx.float32, key=key)
kernel = mx.random.normal((2, 3, 3, 3, 3), dtype=mx.float32, key=key)
strides = (2, 2, 2)
mlx_padding = ([0, 0, 0], [1, 1, 1])
dilation_rate = (1, 1, 1)
groups = 1

result = mx.conv_general(
    inputs,
    kernel,
    stride=strides,
    padding=mlx_padding,
    kernel_dilation=dilation_rate,
    input_dilation=1,
    groups=groups,
    flip=False,
)

result_cpu = mx.conv_general(
    inputs,
    kernel,
    stride=strides,
    padding=mlx_padding,
    kernel_dilation=dilation_rate,
    input_dilation=1,
    groups=groups,
    flip=False,
    stream=mx.cpu
)

result_diff = result - result_cpu
print(f'(conv3d) max_diff: {mx.max(result_diff)}')
print(f'(conv3d) total_absolute_diff: {mx.sum(mx.abs(result_diff))}')
# (conv3d) max_diff: 15.351866722106934
# (conv3d) total_absolute_diff: 544.01220703125

inputs = mx.random.normal((2, 10, 10, 3), dtype=mx.float32, key=key)
kernel = mx.random.normal((2, 2, 2, 3), dtype=mx.float32, key=key)
strides = (1, 2)
mlx_padding = ([0, 0], [1, 0])
dilation_rate = (1, 1)
groups = 1

result = mx.conv_general(
    inputs,
    kernel,
    stride=strides,
    padding=mlx_padding,
    kernel_dilation=dilation_rate,
    input_dilation=1,
    groups=groups,
    flip=False,
)

result_cpu = mx.conv_general(
    inputs,
    kernel,
    stride=strides,
    padding=mlx_padding,
    kernel_dilation=dilation_rate,
    input_dilation=1,
    groups=groups,
    flip=False,
    stream=mx.cpu
)

result_diff = result - result_cpu
print(f'(conv2d) max_diff: {mx.max(result_diff)}')
print(f'(conv2d) total_absolute_diff: {mx.sum(mx.abs(result_diff))}')
# (conv2d) max_diff: 2.334087371826172
# (conv2d) total_absolute_diff: 19.29632568359375
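A note on the metric: the GPU/CPU comparisons above print mx.max(result_diff), which is the largest signed difference, so a large negative mismatch would not show up in max_diff (total_absolute_diff is unaffected). The plain-Python sketch below, using a hypothetical helper diff_stats, illustrates the effect; with MLX arrays, mx.max(mx.abs(result_diff)) gives the maximum absolute difference instead.

```python
def diff_stats(a, b):
    """Return (signed max, max absolute, total absolute) difference of two sequences."""
    diffs = [x - y for x, y in zip(a, b)]
    signed_max = max(diffs)
    abs_diffs = [abs(d) for d in diffs]
    return signed_max, max(abs_diffs), sum(abs_diffs)


gpu = [1.0, 2.0, 3.0]
cpu = [1.0, 5.0, 3.0]  # second element disagrees by -3.0
signed_max, max_abs, total_abs = diff_stats(gpu, cpu)
print(signed_max, max_abs, total_abs)  # 0.0 3.0 3.0
```

The signed maximum reports 0.0 even though one element is off by 3.0.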

Linux comparison with JAX (the two should give similar results as long as the convolution is handled correctly):


from jax import lax
import jax
import jax.numpy as jnp

import mlx.core as mx

key = jax.random.key(42)
inputs = jax.random.normal(key, shape=(2, 3, 8, 8, 8), dtype=jnp.float32)
kernel = jax.random.normal(key, shape=(3, 3, 3, 3, 2), dtype=jnp.float32)
strides = (2, 2, 2)
padding = 'same'
dilation_rate = (1, 1, 1)
dimension_numbers = lax.ConvDimensionNumbers(
        lhs_spec=(0, 1, 2, 3, 4), rhs_spec=(4, 3, 0, 1, 2), out_spec=(0, 1, 2, 3, 4)
    )
feature_group_count = 1

result_jax = jax.lax.conv_general_dilated(
        inputs,
        kernel,
        strides,
        padding,
        rhs_dilation=dilation_rate,
        dimension_numbers=dimension_numbers,
        feature_group_count=feature_group_count,
    )

inputs = mx.array(inputs)
kernel = mx.array(kernel)
inputs = inputs.transpose(0, *range(2, inputs.ndim), 1) # mlx expects channels_last
kernel = kernel.transpose(-1, *range(kernel.ndim - 2), -2) # mlx expects kernel with (out_channels, spatial..., in_channels)

strides = (2, 2, 2)
mlx_padding = ([0, 0, 0], [1, 1, 1])
dilation_rate = (1, 1, 1)
groups = 1

result_mlx = mx.conv_general(
            inputs,
            kernel,
            stride=strides,
            padding=mlx_padding,
            kernel_dilation=dilation_rate,
            input_dilation=1,
            groups=groups,
            flip=False,
        )
# move channels first to match jax
result_mlx = result_mlx.transpose(0, -1, *range(1, result_mlx.ndim - 1))

print('result jax', result_jax.shape, result_jax.dtype)
print('result mlx', result_mlx.shape, result_mlx.dtype)

abs_diff = jnp.abs(jnp.array(result_mlx) - result_jax)
max_abs_diff = jnp.max(abs_diff)
total_abs_diff = jnp.sum(abs_diff)
print('max_abs_diff', max_abs_diff)
print('total_abs_diff', total_abs_diff)

# on intel i9-9980HK, Debian Linux 12
# result jax (2, 2, 4, 4, 4) float32
# result mlx (2, 2, 4, 4, 4) mlx.core.float32
# max_abs_diff 2.0947965e+32
# total_abs_diff 1.2051802e+33
# OR 
# max_abs_diff nan
# total_abs_diff nan

# on M4 Pro, Sequoia 15.2
# result jax (2, 2, 4, 4, 4) float32
# result mlx (2, 2, 4, 4, 4) mlx.core.float32
# max_abs_diff 9.536743e-06
# total_abs_diff 0.00023447536
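The explicit mlx_padding used in the comparison can be derived from JAX's 'same' padding. The sketch below assumes XLA-style SAME semantics (output size ceil(n / stride), with any odd leftover padding placed on the high side); same_padding and to_mlx_padding are hypothetical helpers, not MLX or JAX APIs. For an axis of length 8 with kernel 3 and stride 2 it gives (0, 1), i.e. the asymmetric low/high padding the thread identifies as the trigger. The padding in the conv2d snippet also follows this rule.

```python
import math


def same_padding(in_size, kernel_size, stride, dilation=1):
    """Return (lo, hi) explicit padding matching XLA-style 'SAME' padding on one axis."""
    effective_k = (kernel_size - 1) * dilation + 1  # dilated kernel extent
    out_size = math.ceil(in_size / stride)
    total = max((out_size - 1) * stride + effective_k - in_size, 0)
    lo = total // 2  # any odd leftover goes to the high side
    return lo, total - lo


def to_mlx_padding(spatial, kernel_spatial, strides):
    """Build a ([lo, ...], [hi, ...]) pair in the form passed to mx.conv_general above."""
    pads = [same_padding(n, k, s) for n, k, s in zip(spatial, kernel_spatial, strides)]
    return [lo for lo, _ in pads], [hi for _, hi in pads]


# conv3d case from the issue: 8x8x8 input, 3x3x3 kernel, stride 2
print(to_mlx_padding((8, 8, 8), (3, 3, 3), (2, 2, 2)))  # ([0, 0, 0], [1, 1, 1])
# conv2d case from the issue: 10x10 input, 2x2 kernel, strides (1, 2)
print(to_mlx_padding((10, 10), (2, 2), (1, 2)))  # ([0, 0], [1, 0])
```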

Expected behavior

I would expect no overflow (or excessively large values) on the Linux build, and a much smaller error margin (if any) between CPU and GPU in MLX. Please let me know, though, if there are any limitations or reasons behind this!

Desktop:

  • MLX Version: 0.24.1
  • OS
    • Mac: M4 Pro, Sequoia 15.2
    • Linux: Intel i9-9980HK, Debian Linux 12

Additional context

This is for the automated tests of the MLX backend for Keras (https://github.com/keras-team/keras/issues/19571).

acsweet, Apr 02 '25 23:04

Thanks for reporting! Yes, it does seem like a bug related to the padding. If the padding is set to all 0s or all 1s (the same low and high padding), the results are identical between CPU and GPU.

Judging from the NaNs on Linux, I would say the CPU conv is reading out of bounds with the given padding parameter.

@jagrit06

angeloskath, Apr 03 '25 09:04

It looks like a bug in how padding_hi is handled in the loop bounds of the slow_conv_Nd functions on the CPU. I will have a fix for this soon!

jagrit06, Apr 03 '25 16:04
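To illustrate what correct handling of padding_hi in the loop bounds means, here is a naive 1-D cross-correlation (no kernel flip, matching flip=False) that never materializes the padding: the output length accounts for both padding_lo and padding_hi, and any tap that falls in the padded region is skipped instead of being read from memory. This is an illustrative sketch with a hypothetical helper name, not MLX's slow_conv_Nd code.

```python
def conv1d_ref(x, w, stride=1, pad_lo=0, pad_hi=0):
    """Naive 1-D cross-correlation (no kernel flip) with asymmetric zero padding.

    The padding is never materialized: taps that land in the low or high
    padding region are skipped, which is equivalent to multiplying by zero.
    """
    n, k = len(x), len(w)
    padded_len = n + pad_lo + pad_hi
    if padded_len < k:
        return []  # kernel does not fit even once
    # The output length must include pad_hi, not just pad_lo.
    out_len = (padded_len - k) // stride + 1
    out = []
    for o in range(out_len):
        acc = 0.0
        for j in range(k):
            i = o * stride + j - pad_lo  # position in the unpadded input
            if 0 <= i < n:  # i < 0 is low padding, i >= n is high padding
                acc += x[i] * w[j]
        out.append(acc)
    return out


# One spatial axis of the conv3d case: length 8, kernel 3, stride 2, padding (0, 1).
x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
w = [1.0, 1.0, 1.0]
print(conv1d_ref(x, w, stride=2, pad_lo=0, pad_hi=1))  # [6.0, 12.0, 18.0, 15.0]
```

The last output window covers x[6], x[7] and one zero from the high-side padding; a loop that reads index 8 from memory instead of skipping it would be the kind of out-of-bounds read suspected above.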

If no one has started on this, I'd like to take it. I believe the convolution operator needs to support padding_lo and padding_hi. @jagrit06 @awni

aturker1, Apr 11 '25 12:04

Ah, I see. conv_general only supports zero padding, so it doesn't need to take padding_hi into account 😅

aturker1, Apr 12 '25 00:04
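Because conv_general only zero-pads, an independent check is to materialize the zeros and then run a valid (unpadded) convolution. The NumPy sketch below does that for the channels-last conv2d case from the issue, accumulating in float64 so it can serve as a higher-precision reference for both the CPU and GPU results. conv2d_nhwc_ref is a hypothetical helper; it assumes the kernel layout (out_channels, kh, kw, in_channels) described in the issue and flip=False semantics (cross-correlation).

```python
import numpy as np


def conv2d_nhwc_ref(x, w, stride=(1, 1), pad_lo=(0, 0), pad_hi=(0, 0)):
    """Float64 reference 2-D cross-correlation (flip=False) with zero padding.

    x: (N, H, W, C_in), channels-last input.
    w: (C_out, KH, KW, C_in), the kernel layout used with mx.conv_general in the issue.
    """
    x = np.asarray(x, dtype=np.float64)
    w = np.asarray(w, dtype=np.float64)
    # Materialize the zero padding (np.pad defaults to mode="constant", value 0).
    x = np.pad(x, ((0, 0), (pad_lo[0], pad_hi[0]), (pad_lo[1], pad_hi[1]), (0, 0)))
    n, h, wd, _ = x.shape
    c_out, kh, kw, _ = w.shape
    sh, sw = stride
    out_h = (h - kh) // sh + 1
    out_w = (wd - kw) // sw + 1
    out = np.zeros((n, out_h, out_w, c_out))
    for i in range(out_h):
        for j in range(out_w):
            patch = x[:, i * sh:i * sh + kh, j * sw:j * sw + kw, :]  # (N, KH, KW, C_in)
            out[:, i, j, :] = np.einsum("nhwc,ohwc->no", patch, w)
    return out


# Shapes and settings of the conv2d case from the issue, with NumPy random data.
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 10, 10, 3)).astype(np.float32)
w = rng.standard_normal((2, 2, 2, 3)).astype(np.float32)
ref = conv2d_nhwc_ref(x, w, stride=(1, 2), pad_lo=(0, 0), pad_hi=(1, 0))
print(ref.shape)  # (2, 10, 5, 2)
```

To compare against MLX, passing np.array(inputs) and np.array(kernel) from the conv2d snippet should give a reference for np.array(result) and np.array(result_cpu).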

Thank you! This might be safe to close!

Results on Linux with a build of the latest MLX release (Intel i9-9980HK, Ubuntu 24.04.2 LTS):

(conv3d) max_diff: 0.0
(conv3d) total_absolute_diff: 0.0
(conv2d) max_diff: 0.0
(conv2d) total_absolute_diff: 0.0

Results on Mac (MLX 0.25.2, M4 Pro, macOS 15.2):

(conv3d) max_diff: 2.86102294921875e-06
(conv3d) total_absolute_diff: 5.8397650718688965e-05
(conv2d) max_diff: 0.0
(conv2d) total_absolute_diff: 0.0

Are the precision discrepancies for conv3d on Mac within reason?

acsweet, May 14 '25 18:05

Those results look within a reasonable numerical tolerance to me. Thanks for checking! Will close this.

awni, May 14 '25 18:05
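One way to make "within a reasonable numerical tolerance" concrete for the conv3d numbers is the standard worst-case error bound for a floating-point dot product: |fl(x . y) - x . y| <= gamma_n * sum(|x_i * y_i|), with gamma_n = n*u / (1 - n*u) and u = 2**-24 for float32. Assuming both backends compute each conv3d output as a float32 sum of n = 3*3*3*3 = 81 products (possibly in different orders), their results can differ by up to twice this bound. The magnitude of 50 below is an illustrative assumption for standard-normal inputs and weights (81 terms times E|xy| = 2/pi, about 0.64), not a measured value, and if a backend uses a different algorithm altogether the bound does not strictly apply.

```python
FLOAT32_UNIT_ROUNDOFF = 2.0 ** -24  # u for float32 with round-to-nearest


def dot_error_bound(n_terms, sum_abs_products, u=FLOAT32_UNIT_ROUNDOFF):
    """Worst-case |computed - exact| for a length-n dot product (gamma_n bound)."""
    nu = n_terms * u
    if nu >= 1:
        raise ValueError("bound is only valid while n * u < 1")
    gamma_n = nu / (1 - nu)
    return gamma_n * sum_abs_products


n = 3 * 3 * 3 * 3  # kernel depth * height * width * in_channels for the conv3d case
assumed_magnitude = 50.0  # illustrative sum of |x_i * w_i|, not measured
max_backend_gap = 2 * dot_error_bound(n, assumed_magnitude)
observed = 2.86102294921875e-06  # conv3d max_diff reported on the Mac above
print(f"bound ~ {max_backend_gap:.2e}, observed {observed:.2e}")
print(observed <= max_backend_gap)  # True
```

Under these assumptions the bound is roughly two orders of magnitude above the observed difference, consistent with it being ordinary float32 rounding rather than a bug.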