
Adding `jnp.identity` and `jnp.matmul` raises XLA warning and affects performance.


Description

Under specific circumstances, my jitted function emits the following warning: W external/xla/xla/service/cpu/onednn_matmul.cc:172] [Perf]: MatMul reference implementation being executed.

>>> import jax
>>> from jax import numpy as jnp

>>> @jax.jit
... def func_with_warning(y):
...     return jnp.identity(y.shape[-1]) + jnp.matmul(y, y)

>>> func_with_warning(jnp.ones((2, 100, 100))).shape
(2, 100, 100)
2024-02-20 02:31:15.805823: W external/xla/xla/service/cpu/onednn_matmul.cc:172] [Perf]: MatMul reference implementation being executed

The warning is only raised for this specific setup; changing any one of several details makes it go away. For example, a batch dimension of size 1 works fine, even if the trailing two dimensions are made much larger.

>>> func_with_warning(jnp.ones((1, 1000, 1000))).shape
(1, 1000, 1000)
<no warning>

Replacing the identity matrix with a matrix of ones works fine.

>>> @jax.jit
... def fine_func(y):
...     return jnp.ones((y.shape[-1], y.shape[-1])) + jnp.matmul(y, y)

>>> fine_func(jnp.ones((2, 100, 100))).shape
(2, 100, 100)
<no warning>

Summing the identity matrix before the addition (so only a scalar is added to the matmul result) works fine.

>>> @jax.jit
... def fine_func(y):
...     return jnp.identity(y.shape[-1]).sum() + jnp.matmul(y, y)

>>> fine_func(jnp.ones((2, 100, 100))).shape
(2, 100, 100)
<no warning>

Subtracting rather than adding works fine.

>>> @jax.jit
... def fine_func(y):
...     x = jnp.identity(y.shape[-1])
...     return x - jnp.matmul(y, y)

>>> fine_func(jnp.ones((2, 100, 100))).shape
(2, 100, 100)
<no warning>

This does indeed seem to affect performance: calling the jitted function on the full batch is far slower than looping over the batch in Python.

>>> batch = jnp.ones((10, 100, 100))
>>> # Run the jitted function on the batch.
>>> %timeit func_with_warning(batch).block_until_ready()
81.4 ms ± 3.43 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
>>> # List comprehension in Python.
>>> %timeit [func_with_warning(y).block_until_ready() for y in batch]
345 µs ± 7.57 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
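
The original post does not inspect the compiled program, but as a sketch (not part of the report itself) one can check which kernel XLA picked by dumping the optimized HLO of the jitted function. Whether the matmul appears as a plain dot or as a oneDNN custom call is build-dependent, so the strings to grep for below are an assumption.

>>> # Dump the optimized HLO and look at the matmul-related instructions.
>>> # The exact marker of the oneDNN fast path may differ between jaxlib builds.
>>> hlo_text = func_with_warning.lower(batch).compile().as_text()
>>> [line.strip() for line in hlo_text.splitlines()
...  if "dot" in line or "custom-call" in line]
<build-dependent list of dot / custom-call instructions>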

We can even create a weird function that runs much faster by using double negation.

>>> @jax.jit
... def weird_func(y):
...     return jnp.identity(y.shape[-1]) - jnp.matmul(-y, y)

>>> weird_func(jnp.ones((2, 100, 100))).shape
(2, 100, 100)
<no warning>

>>> batch = jnp.ones((10, 100, 100))
>>> %timeit weird_func(batch).block_until_ready()
277 µs ± 14.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
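
As an aside (a sketch not taken from the notebook), the same comparison can be reproduced with the standard-library timeit module, which is handy when running the repro as a plain script rather than in IPython.

>>> import timeit
>>> # Warm up both functions so compilation time is excluded, then time
>>> # 100 executions of each on the same batch.
>>> _ = func_with_warning(batch).block_until_ready()
>>> _ = weird_func(batch).block_until_ready()
>>> timeit.timeit(lambda: func_with_warning(batch).block_until_ready(), number=100)
<build-dependent time in seconds>
>>> timeit.timeit(lambda: weird_func(batch).block_until_ready(), number=100)
<build-dependent time in seconds>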

A notebook for reproducing the above is here.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.24
jaxlib: 0.4.24
numpy:  1.26.4
python: 3.10.10 (main, Mar  3 2023, 16:31:35) [GCC 9.4.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1

tillahoffmann avatar Feb 20 '24 02:02 tillahoffmann

It turns out I cannot reproduce this on an M1 MacBook Pro, but it is reproducible on GitHub Actions. See https://github.com/tillahoffmann/google-jax-19885 and an example GitHub Actions run here.
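
As a hypothetical follow-up (not part of the linked repository), one could guard against the warning reappearing in CI with a test along these lines. It assumes the oneDNN warning is written to the process's stderr file descriptor, which pytest's capfd fixture captures.

# test_onednn_warning.py (sketch; the file name and test are hypothetical)
import jax
from jax import numpy as jnp


@jax.jit
def func_with_warning(y):
    return jnp.identity(y.shape[-1]) + jnp.matmul(y, y)


def test_no_reference_matmul_warning(capfd):
    # Trigger compilation and execution, then check the captured stderr.
    func_with_warning(jnp.ones((2, 100, 100))).block_until_ready()
    _, err = capfd.readouterr()
    assert "MatMul reference implementation" not in err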

tillahoffmann avatar Feb 23 '24 16:02 tillahoffmann

Hi @tillahoffmann, it appears this issue has been resolved in the latest JAX versions. I ran the code above on Colab with CPU, TPU, and GPU backends using JAX v0.4.26 and v0.4.27, and it executed without any warnings.

Below is the output of the code when running on CPU:

>>> import jax
>>> from jax import numpy as jnp
>>> print(jax.__version__)

>>> @jax.jit
... def func_with_warning(y):
...     return jnp.identity(y.shape[-1]) + jnp.matmul(y, y)

>>> func_with_warning(jnp.ones((2, 100, 100))).shape

Output:

0.4.26
(2, 100, 100)

I've included a Gist here for your reference.

selamw1 avatar May 08 '24 00:05 selamw1

Thanks for following up!

jakevdp avatar May 08 '24 01:05 jakevdp

That's great, thank you @selamw1!

tillahoffmann avatar May 10 '24 17:05 tillahoffmann