
Bfloat16 matrix multiply is slower than FP16 on A100

sbodenstein opened this issue on Aug 02, 2022

Using an A100 Colab:

import jax
import jax.numpy as jnp

print(jnp.array(1).device().device_kind)

@jax.jit
def f(x, y):
  return jnp.einsum('bqc,bkc->bqk', x, y)

x_bfloat = jnp.ones((384 * 4, 384, 16), dtype=jnp.bfloat16)
x_float = jnp.ones((384 * 4, 384, 16), dtype=jnp.float16)

# Warmup
_ = f(x_bfloat, x_bfloat)
_ = f(x_float, x_float)

%timeit f(x_float, x_float).block_until_ready()
%timeit f(x_bfloat, x_bfloat).block_until_ready()

Gives

A100-SXM4-40GB
1000 loops, best of 5: 651 µs per loop
1000 loops, best of 5: 912 µs per loop

I would expect these to be the same speed.

sbodenstein avatar Aug 02 '22 14:08 sbodenstein
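
For reference, the einsum contraction 'bqc,bkc->bqk' above is just a batched matrix multiply with the second operand transposed on its last two axes. A minimal equivalent sketch using jnp.matmul (illustrative only, not part of the original report; f_matmul is a hypothetical name):

import jax
import jax.numpy as jnp

@jax.jit
def f_matmul(x, y):
  # (b, q, c) @ (b, c, k) -> (b, q, k): the same contraction as the einsum above
  return jnp.matmul(x, jnp.swapaxes(y, -1, -2))

Both forms contract over the last axis and run as a batched GEMM, which is why the dtype-dependent timing difference points at the underlying cuBLAS kernels rather than at the einsum itself.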

Which version of cuBLAS/cuDNN and/or CUDA did you use? Was it a PCIe A100 or an HBM (SXM) A100?

nouiz avatar Sep 01 '22 14:09 nouiz
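
A minimal way to report the versions being asked about (a sketch, assuming a Colab-style runtime; the cuBLAS/cuDNN builds come from the installed jaxlib wheel):

import jax
import jaxlib

print("jax:", jax.__version__)
print("jaxlib:", jaxlib.__version__)
print("device:", jax.devices()[0].device_kind)  # e.g. "NVIDIA A100-SXM4-40GB"

On Colab, nvidia-smi additionally reports the driver and CUDA runtime version.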

I'm able to reproduce this with a recent CUDA. I also see from your output that you are using an HBM (SXM) A100.

NVIDIA A100-SXM4-40GB
566 µs ± 125 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
731 µs ± 69.8 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

nouiz avatar Sep 01 '22 14:09 nouiz

This will be fixed in a future cublas release.

nouiz avatar Sep 09 '22 13:09 nouiz

@nouiz That's great to hear, I'm running into the same issue. About when do you think the fix will be available?

danijar avatar Sep 17 '22 11:09 danijar

@nouiz: Any updates about which cuBLAS update will contain the fix?

sbodenstein avatar Oct 19 '22 13:10 sbodenstein

This is still being worked on. I can't say for sure which version will have the fix, but we are trying to get it fixed "soon".

nouiz avatar Oct 19 '22 14:10 nouiz

@nouiz: did this end up being fixed by CUDA 12?

sbodenstein avatar Feb 03 '23 14:02 sbodenstein

cuBLAS from CUDA 12 already has this fix. We do not yet have a public cuDNN release for CUDA 12, so JAX can't easily be built with CUDA 12 yet. This is coming soon.

nouiz avatar Feb 08 '23 17:02 nouiz

The new cuDNN 8.8 release is out, so everything should now be in place for a full CUDA 12 stack. https://docs.nvidia.com/deeplearning/cudnn/release-notes/index.html#rel-880

nouiz avatar Feb 09 '23 19:02 nouiz
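
For anyone verifying the fix after upgrading to a CUDA 12 build, rerunning the original comparison should now show roughly equal times. A sketch that reuses f, x_bfloat and x_float from the first post:

import timeit

# With the cuBLAS fix, the fp16 and bf16 timings are expected to be roughly equal.
fp16_s = timeit.timeit(lambda: f(x_float, x_float).block_until_ready(), number=1000)
bf16_s = timeit.timeit(lambda: f(x_bfloat, x_bfloat).block_until_ready(), number=1000)
print(f"fp16: {fp16_s / 1000 * 1e6:.0f} µs/loop, bf16: {bf16_s / 1000 * 1e6:.0f} µs/loop")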

@nouiz can we close this issue?

mattjj avatar Oct 31 '23 03:10 mattjj

yes.

nouiz avatar Oct 31 '23 15:10 nouiz