
[BUG] shapeless matmul isn't

Open · josharian opened this issue 3 months ago · 1 comment

Describe the bug

In some cases, a shapeless export of matmul isn't truly shapeless.

To Reproduce

Add this to python/tests/test_export_import.py

    def test_export_matmul_shapeless_mid_dim(self):
        path = os.path.join(self.test_dir, "matmul_shapeless.mlxfn")

        E, H = 64, 17
        arr = mx.arange(E * H, dtype=mx.float32).reshape((E, H)) * (1.0 / (E * H))

        def fn(x):
            return mx.matmul(x, arr)

        sample = mx.zeros((1, 40, E), dtype=mx.float32)
        mx.export_function(path, fn, sample, shapeless=True)
        imported = mx.import_function(path)

        for seq_len in (40, 248, 623):
            with self.subTest(seq_len=seq_len):
                x = mx.arange(seq_len * E, dtype=mx.float32).reshape((1, seq_len, E))
                expected = fn(x)
                (y,) = imported(x)
                self.assertEqual(y.shape, (1, seq_len, H))
                self.assertTrue(mx.allclose(y, expected))

The test shapelessly exports a matmul, re-imports it, and checks that the behavior matches the original function across different input shapes.

Expected behavior

Want: the middle dimension (40 in the sample input) to adjust shapelessly with the input.

Got:

    self.assertEqual(y.shape, (1, seq_len, H))
    AssertionError: Tuples differ: (1, 40, 17) != (1, 248, 17)

    self.assertEqual(y.shape, (1, seq_len, H))
    AssertionError: Tuples differ: (1, 40, 17) != (1, 623, 17)

Desktop (please complete the following information):

  • OS Version: macOS 15.6.1 (24G90)
  • Version: main as of 9/16/25 (3f730e77aa3d14e3d52688b8bd6a24bace500166)

Additional context

I may have a tentative fix, but my confidence in it is low. Happy to send it if the reviewer is feeling patient. :)

josharian avatar Sep 19 '25 04:09 josharian

First, just a heads up: shapeless compilation/export can fail and it often won't tell you, so it's recommended to use it carefully:

Use shapeless compilations carefully. Since compilation is not triggered when shapes change, any graphs which are conditional on the input shapes will not work as expected. Shape-dependent computations are common and sometimes subtle to detect...

In your case the actual graph looks like this:

[Image: the exported computation graph, with a flatten and unflatten inserted around the matmul]

The problem there is the unflatten, which needs to know the shape to unflatten to. It gets hardcoded as (1, 40) the first time you run the computation.
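Roughly speaking, the exported function behaves as if the unflatten target were frozen from the sample input. This is only a hypothetical sketch of the traced graph's effect, not MLX's actual export code, reusing E, H, and arr from the repro above:

    def traced_like(x):
        # x had shape (1, 40, 64) at export time
        flat = x.reshape(-1, E)         # flatten the leading dims: (1*40, 64)
        out = mx.matmul(flat, arr)      # (40, 17)
        return out.reshape(1, 40, H)    # "unflatten": the (1, 40) is baked in

Because the final reshape target was recorded at export time, any input whose middle dimension differs from 40 still comes back with shape (1, 40, 17).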

There isn't an easy fix for this... it might be possible to remove the flatten/unflatten around a matmul and use broadcast instead.

Another option for you is to squeeze the singleton dimension before you call the matmul (which will avoid inserting the flatten and unflatten in the graph):
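A minimal sketch of that workaround (my reconstruction, assuming the leading dimension is always a singleton and reusing arr from the repro above):

    def fn(x):
        # Squeeze the leading singleton so the matmul sees a 2D input;
        # no flatten/unflatten pair gets inserted into the graph.
        y = mx.matmul(mx.squeeze(x, 0), arr)   # (seq_len, H)
        return mx.expand_dims(y, 0)            # restore the batch dim: (1, seq_len, H)

With that change the flatten/unflatten should no longer appear in the exported graph, so the middle dimension should be free to vary.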

awni avatar Sep 20 '25 03:09 awni