jax
Einsum can be slower than matmul on CPU
An example where JAX's `einsum` is slower than a manual matmul with transpositions, on both CPU and GPU. (The example in #1966 runs equally fast for me, but this one is consistently slower.)
https://colab.research.google.com/gist/romanngg/e63834765d00497e315455867a52eae1/einsum_is_slow.ipynb
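```python
import jax.numpy as jnp
import jax.random as random
from jax import jit  # `from jax.api import jit` no longer works in recent JAX releases

a = random.normal(random.PRNGKey(1), (100, 20, 20, 3))  # (n, x, y, c)
b = random.normal(random.PRNGKey(2), (200, 20, 20, 3))  # (m, x, y, c)

@jit
def matmul(a, b):
    # Move the batch axes (x, y) to the front, compute (n, c) @ (c, m),
    # then move (n, m) back to the front.
    return jnp.transpose(
        jnp.matmul(jnp.transpose(a, axes=(1, 2, 0, 3)), jnp.transpose(b, axes=(1, 2, 3, 0))),
        axes=(2, 3, 0, 1))

@jit
def einsum(a, b):
    return jnp.einsum('nxyc,mxyc->nmxy', a, b, optimize=True)

# Total absolute difference between the two results.
print(jnp.sum(jnp.abs(einsum(a, b) - matmul(a, b))))

# IPython/Jupyter magics (run these in a notebook such as Colab):
%timeit einsum(a, b).block_until_ready()
%timeit matmul(a, b).block_until_ready()
```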
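`%timeit` only works inside IPython/Jupyter. A minimal sketch for timing the same two functions from a plain Python script, assuming the standard-library `timeit` module is acceptable; the helper name `bench` is hypothetical. Each function is called once before timing so that tracing and compilation are excluded, and every timed call ends with `.block_until_ready()` so the measurement includes the computation itself.

```python
import timeit

import jax
import jax.numpy as jnp

# Same shapes as the example above: a is (n, x, y, c), b is (m, x, y, c).
a = jax.random.normal(jax.random.PRNGKey(1), (100, 20, 20, 3))
b = jax.random.normal(jax.random.PRNGKey(2), (200, 20, 20, 3))


@jax.jit
def einsum(a, b):
    return jnp.einsum('nxyc,mxyc->nmxy', a, b)


@jax.jit
def matmul(a, b):
    lhs = jnp.transpose(a, (1, 2, 0, 3))  # (x, y, n, c)
    rhs = jnp.transpose(b, (1, 2, 3, 0))  # (x, y, c, m)
    return jnp.transpose(lhs @ rhs, (2, 3, 0, 1))  # (x, y, n, m) -> (n, m, x, y)


def bench(fn, *args, number=10, repeat=5):
    """Best average seconds per call, excluding compilation (hypothetical helper)."""
    fn(*args).block_until_ready()  # warm-up: trace and compile once
    timer = timeit.Timer(lambda: fn(*args).block_until_ready())
    return min(timer.repeat(repeat=repeat, number=number)) / number


t_einsum = bench(einsum, a, b)
t_matmul = bench(matmul, a, b)
print(f"einsum: {t_einsum * 1e3:.2f} ms, matmul: {t_matmul * 1e3:.2f} ms")
```

Taking the minimum over repeats is a common way to reduce noise from other processes; it reports the best case rather than the average.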
Also note that on CPU the outputs of the two methods differ, with a total absolute difference of `DeviceArray(0.01003271, dtype=float32)`. I'm not sure how concerning that is.
FYI, I have re-run the example above on:
- CPU: einsum is slow AND wrong: https://colab.research.google.com/gist/romanngg/48fb8d4d3a3fb5da9be84d8d1fb862ad/einsum_is_wrong_and_slow_cpu.ipynb
- GPU: einsum is slow: https://colab.research.google.com/gist/romanngg/dd1e2adbda90749f140012f1b9342353/einsum_is_slow_gpu.ipynb
- TPU: einsum is OK! https://colab.research.google.com/gist/romanngg/635b467426bd9ead276cc6f9216ed03d/einsum_is_ok_tpu.ipynb
Will file bugs against XLA:CPU and XLA:GPU!
@romanngg
Curious about progress. Also, the difference on CPU is quite small (0.01 after summing over all elements). That's imprecision, not an error.
Haven't heard anything back yet
I think "wrong" is an overstatement here. In floating-point arithmetic, two different ways of computing the same result are not guaranteed to agree exactly. That is particularly true for heavily optimized routines such as matrix multiplication.
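A small illustration of the point above. Regrouping the same additions changes the rounded result even for Python floats, so a check between two implementations is better expressed with a tolerance (`jnp.allclose`) than with exact equality or a sum of absolute differences. The tolerances below are illustrative assumptions for float32, not values recommended by JAX.

```python
import jax
import jax.numpy as jnp

# Floating-point addition is not associative: same terms, different grouping.
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)
print(left, right, left == right)  # 0.6000000000000001 0.6 False

# The two implementations of the same contraction from the original report.
a = jax.random.normal(jax.random.PRNGKey(1), (100, 20, 20, 3))
b = jax.random.normal(jax.random.PRNGKey(2), (200, 20, 20, 3))
via_einsum = jnp.einsum('nxyc,mxyc->nmxy', a, b)
via_matmul = jnp.transpose(
    jnp.transpose(a, (1, 2, 0, 3)) @ jnp.transpose(b, (1, 2, 3, 0)), (2, 3, 0, 1))

# Compare elementwise with a tolerance instead of summing absolute differences.
max_abs_diff = float(jnp.max(jnp.abs(via_einsum - via_matmul)))
print("max |diff|:", max_abs_diff)
print("allclose:", bool(jnp.allclose(via_einsum, via_matmul, rtol=1e-5, atol=1e-5)))
```

A sum over all 8 million output elements can look large even when every individual element agrees to within a few units in the last place, which is why the per-element maximum is the more informative number here.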
Note, for example, that NumPy's einsum differs by even more here:
```
>>> np.sum(np.abs(np.einsum('nxyc,mxyc->nmxy', a, b) - jnp.einsum('nxyc,mxyc->nmxy', a, b)))
0.3189874
```
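One way to tell how far each float32 result is from the exact answer is to compare both against a float64 reference computed with NumPy. A minimal sketch reusing the shapes from the original report; it measures the error of each library rather than assuming which one is more accurate, since that may vary with platform and version.

```python
import jax
import jax.numpy as jnp
import numpy as np

spec = 'nxyc,mxyc->nmxy'
# float32 inputs, as in the thread, converted to NumPy arrays.
a = np.asarray(jax.random.normal(jax.random.PRNGKey(1), (100, 20, 20, 3)))
b = np.asarray(jax.random.normal(jax.random.PRNGKey(2), (200, 20, 20, 3)))

# Higher-precision reference: the same contraction carried out in float64.
reference = np.einsum(spec, a.astype(np.float64), b.astype(np.float64))

# Largest elementwise deviation of each float32 result from the reference.
err_numpy = np.max(np.abs(np.einsum(spec, a, b) - reference))
err_jax = np.max(np.abs(np.asarray(jnp.einsum(spec, a, b)) - reference))
print(f"max abs error vs float64: numpy={err_numpy:.3g}, jax={err_jax:.3g}")
```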
If you look at how these two functions (matmul vs. einsum) are implemented, you'll see that although they compute the same thing in principle, they compute it differently:
```
>>> jax.make_jaxpr(einsum)(a, b)
{ lambda ; a b.
  let c = xla_call[ backend=None
        call_jaxpr={ lambda ; a b.
          let c = xla_call[ backend=None
                call_jaxpr={ lambda ; a b.
                  let c = dot_general[ dimension_numbers=(((3,), (3,)), ((1, 2), (1, 2)))
                                       precision=None ] b a
                      d = transpose[ permutation=(3, 2, 0, 1) ] c
                  in (d,) }
                device=None
                donated_invars=(False, False)
                name=_einsum ] a b
          in (c,) }
        device=None
        donated_invars=(False, False)
        name=einsum ] a b
  in (c,) }
>>> jax.make_jaxpr(matmul)(a, b)
{ lambda ; a b.
  let c = xla_call[ backend=None
        call_jaxpr={ lambda ; a b.
          let c = transpose[ permutation=(1, 2, 0, 3) ] a
              d = transpose[ permutation=(1, 2, 3, 0) ] b
              e = dot_general[ dimension_numbers=(((3,), (2,)), ((0, 1), (0, 1)))
                               precision=None ] c d
              f = transpose[ permutation=(2, 3, 0, 1) ] e
          in (f,) }
        device=None
        donated_invars=(False, False)
        name=matmul ] a b
  in (c,) }
```
XLA is usually pretty good about picking a good way to implement matrix multiplication, but it's not making the best choice here on CPU/GPU without your manual transposes. Those are definitely good opportunities for further improvement.
To be clear, the issue here isn't einsum itself, which, as you can see, generates quite reasonable code. This is an indictment of XLA's DotGeneral (which, again, usually does pretty well).
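To reproduce the two programs in the jaxprs above directly, one can call `jax.lax.dot_general` with the same `dimension_numbers` they show. A sketch, assuming current JAX versions still accept this `((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))` form (the jaxpr that newer versions print for `jnp.einsum` may differ); the function names `einsum_style` and `matmul_style` are illustrative. `dot_general` places batch dimensions first, then the free dimensions of the left operand, then those of the right, which is why both versions end with a transpose.

```python
import jax
import jax.numpy as jnp
from jax import lax

a = jax.random.normal(jax.random.PRNGKey(1), (100, 20, 20, 3))  # (n, x, y, c)
b = jax.random.normal(jax.random.PRNGKey(2), (200, 20, 20, 3))  # (m, x, y, c)


@jax.jit
def einsum_style(a, b):
    # Operands in the order shown in the einsum jaxpr: b, then a.
    # Contract c (axis 3 of both); batch over x, y (axes 1, 2 of both).
    out = lax.dot_general(b, a, dimension_numbers=(((3,), (3,)), ((1, 2), (1, 2))))
    # Output layout is (batch..., lhs free, rhs free) = (x, y, m, n).
    return jnp.transpose(out, (3, 2, 0, 1))  # -> (n, m, x, y)


@jax.jit
def matmul_style(a, b):
    lhs = jnp.transpose(a, (1, 2, 0, 3))  # (x, y, n, c)
    rhs = jnp.transpose(b, (1, 2, 3, 0))  # (x, y, c, m)
    # Contract c (axis 3 of lhs, axis 2 of rhs); batch over axes (0, 1) of both.
    out = lax.dot_general(lhs, rhs, dimension_numbers=(((3,), (2,)), ((0, 1), (0, 1))))
    return jnp.transpose(out, (2, 3, 0, 1))  # (x, y, n, m) -> (n, m, x, y)


expected = jnp.einsum('nxyc,mxyc->nmxy', a, b)
for name, fn in [('einsum_style', einsum_style), ('matmul_style', matmul_style)]:
    result = fn(a, b)
    print(name, result.shape, float(jnp.max(jnp.abs(result - expected))))

# Inspect the program JAX builds for one of them.
print(jax.make_jaxpr(matmul_style)(a, b))
```

The only difference between the two is where the batch and contracting axes sit in memory when `dot_general` runs, which is exactly the layout choice the comment above attributes to XLA.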
I ran some benchmarks on several libraries (times in seconds, on a Ryzen 3900X):
| reduction (float32) | numpy, size 64 | torch, size 64 | jax, size 64 | numpy, size 128 | torch, size 128 | jax, size 128 | numpy, size 256 | torch, size 256 | jax, size 256 |
|---|---|---|---|---|---|---|---|---|---|
| `ij,ijkl->kl` | 0.00 | 0.01 | 0.02 | 0.05 | 0.05 | 0.06 | 0.75 | 0.76 | 0.70 |
| `ji,ijkl->kl` | 0.00 | 0.00 | 0.02 | 0.05 | 0.05 | 0.06 | 0.75 | 0.76 | 0.70 |
| `jk,ijkl->il` | 0.00 | 0.01 | 0.03 | 0.05 | 0.20 | 0.26 | 0.77 | 2.75 | 2.78 |
| `kj,ijkl->il` | 0.01 | 0.01 | 0.04 | 0.16 | 0.20 | 0.22 | 2.33 | 2.80 | 2.87 |
| `ik,ijkl->jl` | 0.01 | 0.02 | 0.03 | 0.15 | 0.19 | 0.22 | 2.35 | 2.84 | 2.94 |
| `ki,ijkl->jl` | 0.01 | 0.01 | 0.04 | 0.15 | 0.20 | 0.22 | 2.34 | 2.78 | 3.07 |
| `li,ijkl->jk` | 0.01 | 0.01 | 0.04 | 0.29 | 0.23 | 0.26 | 4.61 | 10.14 | 9.66 |
| `il,ijkl->jk` | 0.00 | 0.02 | 0.04 | 0.06 | 0.22 | 0.28 | 0.82 | 9.98 | 9.86 |
| `lj,ijkl->ik` | 0.01 | 0.01 | 0.04 | 0.30 | 0.23 | 0.27 | 4.66 | 9.62 | 10.02 |
| `jl,ijkl->ik` | 0.00 | 0.02 | 0.04 | 0.06 | 0.22 | 0.25 | 0.90 | 10.07 | 10.08 |
| `lk,ijkl->ij` | 0.01 | 0.00 | 0.07 | 0.28 | 0.04 | 0.82 | 4.27 | 0.76 | 59.20 |
| `kl,ijkl->ij` | 0.00 | 0.00 | 0.06 | 0.06 | 0.05 | 0.83 | 0.90 | 0.76 | 62.52 |
```
np.__version__='1.21.5'
torch.__version__='1.10.2'
jax.__version__='0.2.21'
jaxlib.__version__='0.1.76'
```
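In the last two rows of the table, JAX is far slower than NumPy and PyTorch. By analogy with the manual transposes in the original report, one possible workaround is to express `kl,ijkl->ij` (and `lk,ijkl->ij`, after transposing `x`) as a single matrix-vector product over the flattened trailing axes. This is a hedged sketch: the function names are hypothetical, and whether the rewrite is actually faster depends on the XLA version and hardware; it is only checked for correctness here.

```python
import jax
import jax.numpy as jnp


@jax.jit
def kl_ijkl_to_ij(x, y):
    """out[i, j] = sum_{k,l} x[k, l] * y[i, j, k, l], computed as one matvec."""
    ni, nj, nk, nl = y.shape
    # Rows index (i, j), columns index (k, l); x flattens to the same (k, l) order.
    return (y.reshape(ni * nj, nk * nl) @ x.reshape(nk * nl)).reshape(ni, nj)


@jax.jit
def lk_ijkl_to_ij(x, y):
    """out[i, j] = sum_{k,l} x[l, k] * y[i, j, k, l]: transpose x, then reuse the above."""
    return kl_ijkl_to_ij(x.T, y)


size = 16  # kept small so the correctness check runs quickly
kx, ky = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (size, size))
y = jax.random.normal(ky, (size,) * 4)

for spec, fn in [('kl,ijkl->ij', kl_ijkl_to_ij), ('lk,ijkl->ij', lk_ijkl_to_ij)]:
    expected = jnp.einsum(spec, x, y)
    print(spec, bool(jnp.allclose(fn(x, y), expected, rtol=1e-4, atol=1e-4)))
```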
I just re-ran the code in the OP on an A100 GCP machine, and got this:
The jaxprs mainly differ in how the inputs are transposed:
Seems like we want an XLA:GPU improvement here... I'll bring it up in a meeting.
Hi, are there any updates on this? I ended up here because my code using einsum was running very slowly.
@LucasAlegre Please provide a reproduction (e.g., microbenchmark that demonstrates the problem).
Since it has been almost 6 months since the last reply, @mattjj, have there been any updates on this matter?
I can no longer reproduce the GPU results at head. The CPU results still apply.
I don't know whether this issue still matters in 2024, but here are my results, where x and y are 3D tensors:
Note: unless you add .block_until_ready(), the timing numbers aren't measuring what you think they are measuring. See https://jax.readthedocs.io/en/latest/async_dispatch.html
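A minimal sketch of what the note above means. A jitted call can return as soon as the work has been dispatched, so timing only the call may leave out most of the computation; `block_until_ready()` waits for the result. How large the gap is depends on the backend and JAX version, so the sketch only prints both numbers rather than assuming one.

```python
import time

import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (1000, 1000))
f = jax.jit(lambda m: m @ m)
f(x).block_until_ready()  # compile once so neither timing includes compilation

start = time.perf_counter()
y = f(x)  # may return before the result has been computed
dispatch_only = time.perf_counter() - start
y.block_until_ready()  # let the pending work finish before the next timing

start = time.perf_counter()
f(x).block_until_ready()  # waits until the result is actually available
with_result = time.perf_counter() - start

print(f"without block_until_ready: {dispatch_only * 1e3:.3f} ms")
print(f"with block_until_ready:    {with_result * 1e3:.3f} ms")
```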
Updated!!