Bfloat16 matrix multiply is slower than FP16 on A100
Using an A100 Colab:
import jax
import jax.numpy as jnp

# Confirm which accelerator the default device is.
print(jnp.array(1).device().device_kind)

@jax.jit
def f(x, y):
    # Batched matmul: (B, Q, C) x (B, K, C) -> (B, Q, K).
    return jnp.einsum('bqc,bkc->bqk', x, y)

x_bfloat = jnp.ones((384 * 4, 384, 16), dtype=jnp.bfloat16)
x_float = jnp.ones((384 * 4, 384, 16), dtype=jnp.float16)

# Warmup: trigger compilation before timing.
_ = f(x_bfloat, x_bfloat)
_ = f(x_float, x_float)

%timeit f(x_float, x_float).block_until_ready()
%timeit f(x_bfloat, x_bfloat).block_until_ready()
Gives
A100-SXM4-40GB
1000 loops, best of 5: 651 µs per loop
1000 loops, best of 5: 912 µs per loop
I would expect these to run at the same speed, since the A100's tensor cores have the same peak throughput for bfloat16 and float16.
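To see which kernels XLA actually picks for each dtype, the compiled HLO can be dumped from the jitted function. This is only a sketch: it assumes a JAX version with the .lower()/.compile() ahead-of-time API, and the custom-call names in the dump vary by jaxlib/XLA version.

# Sketch: dump the compiled HLO to see which cuBLAS GEMM each dtype lowers to.
print(f.lower(x_float, x_float).compile().as_text())
print(f.lower(x_bfloat, x_bfloat).compile().as_text())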
Which version of cuBLAS/cuDNN and/or CUDA did you use? Was it a PCIe A100 or an SXM (HBM) A100?
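(For reference, one way to collect that information from a Colab runtime is sketched below; the exact reporting calls are my own suggestion, not part of the original report.)

import subprocess
import jax, jaxlib

# Report the jax/jaxlib builds in use; nvidia-smi shows the driver,
# the CUDA version it supports, and the GPU form factor (SXM vs. PCIe).
print("jax:", jax.__version__)
print("jaxlib:", jaxlib.__version__)
print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)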
I'm able to reproduce this with a recent CUDA build. I can also see from your output that you're using an SXM (HBM) A100.
NVIDIA A100-SXM4-40GB
566 µs ± 125 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
731 µs ± 69.8 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
This will be fixed in a future cublas release.
@nouiz That's great to hear, I'm running into the same issue. About when do you think the fix will be available?
@nouiz: Any updates about which cuBLAS update will contain the fix?
This is still being worked on. I can't say for sure which version will have the fix, but we try to have it fixed "soon".
@nouiz: did this end up being fixed by CUDA 12?
The cuBLAS shipped with CUDA 12 already has this fix. We don't yet have a public cuDNN release for CUDA 12, so JAX can't easily be built against CUDA 12 yet. This is coming soon.
The new cuDNN 8.8 release is out, so a full CUDA 12 stack should now work: https://docs.nvidia.com/deeplearning/cudnn/release-notes/index.html#rel-880
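With that stack installed, re-running the benchmark from the top of this issue should show bfloat16 and float16 at roughly the same speed. A minimal check, assuming a JAX release new enough to provide jax.print_environment_info():

import jax

# Prints jax/jaxlib versions plus nvidia-smi output for the runtime.
jax.print_environment_info()

# Then re-run the original benchmark:
# %timeit f(x_float, x_float).block_until_ready()
# %timeit f(x_bfloat, x_bfloat).block_until_ready()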
@nouiz can we close this issue?
yes.