
Einsum can be slower than matmul on CPU

Open romanngg opened this issue 5 years ago • 14 comments

An example where np.einsum is slower than manual matmul/transpositions on CPU and GPU. (The example in #1966 works equally fast for me, but this one is consistently slower.)

https://colab.research.google.com/gist/romanngg/e63834765d00497e315455867a52eae1/einsum_is_slow.ipynb

import jax.numpy as np
import jax.random as random
from jax import jit  # originally `from jax.api import jit`; jax.api has since been removed

a = random.normal(random.PRNGKey(1), (100, 20, 20, 3))
b = random.normal(random.PRNGKey(2), (200, 20, 20, 3))

@jit
def matmul(a, b):
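  # Move the spatial axes (x, y) to the front so matmul batches over them,
  # contract the channel axis c, then transpose the result back to (n, m, x, y).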
  return np.transpose(np.matmul(np.transpose(a, axes=(1, 2, 0, 3)), np.transpose(b, axes=(1, 2, 3, 0))), axes=(2, 3, 0, 1))

@jit
def einsum(a, b):
  return np.einsum('nxyc,mxyc->nmxy', a, b, optimize=True)

np.sum(np.abs(einsum(a, b) - matmul(a, b)))

%timeit einsum(a, b).block_until_ready()

%timeit matmul(a, b).block_until_ready()

Also note that if you run it on CPU, the difference between the two methods' outputs becomes non-zero: DeviceArray(0.01003271, dtype=float32). Not sure how concerning that is.

romanngg commented Feb 04 '20

FYI, I have revisited the example on the following backends:

  1. CPU: einsum is slow AND wrong: https://colab.research.google.com/gist/romanngg/48fb8d4d3a3fb5da9be84d8d1fb862ad/einsum_is_wrong_and_slow_cpu.ipynb
  2. GPU: einsum is slow: https://colab.research.google.com/gist/romanngg/dd1e2adbda90749f140012f1b9342353/einsum_is_slow_gpu.ipynb
  3. TPU: einsum is OK! https://colab.research.google.com/gist/romanngg/635b467426bd9ead276cc6f9216ed03d/einsum_is_ok_tpu.ipynb

Will file bugs against XLA:CPU and XLA:GPU!
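For anyone re-running these notebooks, here is a quick sketch (using jax.default_backend() and jax.devices() from the current public API) to confirm which backend the runtime actually picked:

import jax

# Which XLA backend jit-compiled code runs on by default.
print(jax.default_backend())  # e.g. 'cpu', 'gpu', or 'tpu'
# The concrete devices visible to JAX.
print(jax.devices())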

romanngg commented Jun 25 '20

@romanngg

Curious about progress. Also, the difference on CPU is quite small (0.01 after summing over all elements). That's imprecision, not an error.
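For instance, a minimal sketch (reusing a, b, and the jitted einsum/matmul from the original example) that compares the outputs with a relative tolerance instead of a summed absolute difference:

import jax.numpy as np

out_e = einsum(a, b)
out_m = matmul(a, b)

# The mean relative error stays tiny even when the summed absolute error looks nonzero.
print(np.mean(np.abs(out_e - out_m) / (np.abs(out_m) + 1e-6)))
# Tolerance-based check in the spirit of "imprecision, not an error".
print(np.allclose(out_e, out_m, rtol=1e-5, atol=1e-5))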

arogozhnikov commented Aug 23 '20

Haven't heard anything back yet

romanngg commented Aug 24 '20

I think "wrong" is an overstatement here. In floating point arithmetic, two different ways of computing the same results are not guaranteed to exactly agree. That is particular true for heavily optimized routines such as matrix multiplication.

Note, for example, that NumPy's einsum gives an even more different result here:

>>> np.sum(np.abs(np.einsum('nxyc,mxyc->nmxy', a, b) - jnp.einsum('nxyc,mxyc->nmxy', a, b)))
0.3189874

If you look at the implementations of these two functions (matmul vs. einsum), even though they calculate the same thing in principle, they calculate it differently:

>>> jax.make_jaxpr(einsum)(a, b)
{ lambda  ; a b.
  let c = xla_call[ backend=None
                    call_jaxpr={ lambda  ; a b.
                                 let c = xla_call[ backend=None
                                                   call_jaxpr={ lambda  ; a b.
                                                                let c = dot_general[ dimension_numbers=(((3,), (3,)), ((1, 2), (1, 2)))
                                                                                     precision=None ] b a
                                                                    d = transpose[ permutation=(3, 2, 0, 1) ] c
                                                                in (d,) }
                                                   device=None
                                                   donated_invars=(False, False)
                                                   name=_einsum ] a b
                                 in (c,) }
                    device=None
                    donated_invars=(False, False)
                    name=einsum ] a b
  in (c,) }

>>> jax.make_jaxpr(matmul)(a, b)
{ lambda  ; a b.
  let c = xla_call[ backend=None
                    call_jaxpr={ lambda  ; a b.
                                 let c = transpose[ permutation=(1, 2, 0, 3) ] a
                                     d = transpose[ permutation=(1, 2, 3, 0) ] b
                                     e = dot_general[ dimension_numbers=(((3,), (2,)), ((0, 1), (0, 1)))
                                                      precision=None ] c d
                                     f = transpose[ permutation=(2, 3, 0, 1) ] e
                                 in (f,) }
                    device=None
                    donated_invars=(False, False)
                    name=matmul ] a b
  in (c,) }

XLA is usually pretty good about picking a good way to implement matrix multiplication, but it's not making the best choice here on CPU/GPU without your manual transposes. Those are definitely good opportunities for further improvement.
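In the meantime, a possible workaround (a sketch, not an official recommendation; the name fast_nmxy is just illustrative) is to hand XLA the contraction with explicit batch dimensions via lax.dot_general, mirroring the dimension_numbers from the manual-matmul jaxpr above:

import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def fast_nmxy(a, b):
  # a: (n, x, y, c), b: (m, x, y, c) -> output: (n, m, x, y)
  at = jnp.transpose(a, (1, 2, 0, 3))   # (x, y, n, c)
  bt = jnp.transpose(b, (1, 2, 3, 0))   # (x, y, c, m)
  # Contract c with c, batch over (x, y) -- the same
  # dimension_numbers=(((3,), (2,)), ((0, 1), (0, 1))) as in the matmul jaxpr.
  out = lax.dot_general(at, bt, (((3,), (2,)), ((0, 1), (0, 1))))
  return jnp.transpose(out, (2, 3, 0, 1))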

shoyer commented Aug 24 '20

To be clear, the issue here isn't that einsum itself is slow, which as you can see generates quite reasonable code. This is an indictment of XLA's DotGeneral (which again, usually does pretty well).
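To see what XLA actually receives for each variant, one can dump the lowered program (a sketch assuming the @jit-decorated einsum/matmul from the original example and the .lower(...).as_text() staging API available in recent JAX releases):

# einsum and matmul are the @jit-decorated functions from the original example.
print(einsum.lower(a, b).as_text())   # program handed to XLA for the einsum version
print(matmul.lower(a, b).as_text())   # program handed to XLA for the manual-matmul version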

shoyer commented Aug 24 '20

Ran some benchmarks on several libraries, time in seconds on a Ryzen 3900X

reduction      |  size 64             |  size 128            |  size 256
               |  numpy  torch  jax   |  numpy  torch  jax   |  numpy  torch   jax
ij,ijkl->kl    |  0.00   0.01   0.02  |  0.05   0.05   0.06  |  0.75   0.76    0.70
ji,ijkl->kl    |  0.00   0.00   0.02  |  0.05   0.05   0.06  |  0.75   0.76    0.70
jk,ijkl->il    |  0.00   0.01   0.03  |  0.05   0.20   0.26  |  0.77   2.75    2.78
kj,ijkl->il    |  0.01   0.01   0.04  |  0.16   0.20   0.22  |  2.33   2.80    2.87
ik,ijkl->jl    |  0.01   0.02   0.03  |  0.15   0.19   0.22  |  2.35   2.84    2.94
ki,ijkl->jl    |  0.01   0.01   0.04  |  0.15   0.20   0.22  |  2.34   2.78    3.07
li,ijkl->jk    |  0.01   0.01   0.04  |  0.29   0.23   0.26  |  4.61   10.14   9.66
il,ijkl->jk    |  0.00   0.02   0.04  |  0.06   0.22   0.28  |  0.82   9.98    9.86
lj,ijkl->ik    |  0.01   0.01   0.04  |  0.30   0.23   0.27  |  4.66   9.62    10.02
jl,ijkl->ik    |  0.00   0.02   0.04  |  0.06   0.22   0.25  |  0.90   10.07   10.08
lk,ijkl->ij    |  0.01   0.00   0.07  |  0.28   0.04   0.82  |  4.27   0.76    59.20
kl,ijkl->ij    |  0.00   0.00   0.06  |  0.06   0.05   0.83  |  0.90   0.76    62.52
(dtype float32 throughout)
np.__version__='1.21.5'
torch.__version__='1.10.2'
jax.__version__='0.2.21'
jaxlib.__version__='0.1.76'
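For reference, a hedged sketch of the kind of per-reduction timing loop behind numbers like these (names such as time_fn are illustrative, not the actual script; note the block_until_ready() on the JAX side so async dispatch doesn't skew the result):

import time
import numpy as onp
import torch
import jax.numpy as jnp

def time_fn(fn, repeats=3):
  fn()  # warm-up (and JIT compile, where applicable)
  start = time.perf_counter()
  for _ in range(repeats):
    fn()
  return (time.perf_counter() - start) / repeats

n = 64
a_np = onp.random.randn(n, n).astype(onp.float32)
b_np = onp.random.randn(n, n, n, n).astype(onp.float32)
a_j, b_j = jnp.asarray(a_np), jnp.asarray(b_np)
a_t, b_t = torch.from_numpy(a_np), torch.from_numpy(b_np)

print('numpy', time_fn(lambda: onp.einsum('lk,ijkl->ij', a_np, b_np)))
print('torch', time_fn(lambda: torch.einsum('lk,ijkl->ij', a_t, b_t)))
print('jax  ', time_fn(lambda: jnp.einsum('lk,ijkl->ij', a_j, b_j).block_until_ready()))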

randolf-scholz commented Feb 14 '22

I just re-ran the code in the OP on an A100 GCP machine, and got this:

[screenshot: einsum vs. matmul timings from the original example on an A100]

The jaxprs mainly differ in how the inputs are transposed:

[screenshot: the two jaxprs, differing mainly in how the inputs are transposed]

Seems like we want an XLA:GPU improvement here... I'll bring it up in a meeting.

mattjj commented Sep 14 '22

Hi, are there any updates on this? I got here because I was getting very slow performance on my code using einsum.

LucasAlegre commented Nov 10 '22

@LucasAlegre Please provide a reproduction (e.g., microbenchmark that demonstrates the problem).

hawkinsp commented Nov 10 '22

Since it has been almost 6 months since the last reply, @mattjj, have there been any updates on this matter?

johnypark commented Mar 04 '23

I can no longer reproduce the GPU results at head. The CPU results still apply.

hawkinsp commented Nov 06 '23

I don't know whether this issue still matters in 2024, but I am sharing my results; here x and y are 3D tensors:

[screenshots: einsum and matmul timings for the 3D-tensor case]

qnixsynapse commented Mar 17 '24

Note: unless you add .block_until_ready(), the timing numbers aren't measuring what you think they are measuring. See https://jax.readthedocs.io/en/latest/async_dispatch.html
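Concretely, a minimal sketch of the pitfall (shapes and the function name f are illustrative): JAX dispatches work asynchronously, so timing without the sync measures only dispatch, not the computation.

import jax
import jax.numpy as jnp

x = jnp.ones((2048, 2048))
f = jax.jit(lambda x: jnp.einsum('ij,jk->ik', x, x))
f(x).block_until_ready()          # compile and warm up first

# Misleading: returns as soon as the work is dispatched to the device.
# %timeit f(x)
# Correct: waits for the device computation to actually finish.
# %timeit f(x).block_until_ready()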

hawkinsp commented Mar 17 '24

Updated!!

qnixsynapse commented Mar 17 '24