`jax.lax.dot` with `preferred_element_type` gives wrong result
The bug can be reproduced by the following code:

```python
import jax.numpy as jnp
from jax.random import PRNGKey, normal
from jax.lax import dot
from jax.config import config

config.update("jax_enable_x64", True)

A = normal(PRNGKey(0), (2, 2), dtype=jnp.float32)

print(A @ A.T)
# [[3.868021   1.0210072 ]
#  [1.0210072  0.40157342]]

print(dot(A, A.T, preferred_element_type=jnp.float32))
# [[3.868021   1.0210072 ]
#  [1.0210072  0.40157342]]

print(dot(A, A.T, preferred_element_type=jnp.float64))
# [[-4.71044037e-10  2.50204452e-08]
#  [ 0.00000000e+00  0.00000000e+00]]
```
This happens on an RTX 3090 with CUDA 11.6, cudatoolkit 11.6.0, cuDNN 8.4.1.50, under both Python 3.8.13 with jax 0.3.13 and Python 3.9.7 with jax 0.3.15 (the error occurs in both cases).
However, the problem does not occur on a Colab P100 GPU. I haven't tested other GPU types.
Thanks for the report - I suspect this may be related to a known bug in `dot_general` with `preferred_element_type` on GPU. Note that currently, we skip tests that exercise this because of a somewhat different issue: https://github.com/google/jax/blob/75d69725c384cd08b0ca55fa9e315fb9bb5830e4/tests/lax_test.py#L1129-L1131
I ran some further tests over the past few days. The bug doesn't occur on CPU, but it still occurs on an A100 GPU.
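For reference, one way to compare backends on a machine that has a GPU is to pin the same computation to the CPU. A minimal sketch, assuming jax 0.3.x where `jax.jit` still accepts a `backend` argument:

```python
import jax
import jax.numpy as jnp
from jax.config import config
from jax.lax import dot
from jax.random import PRNGKey, normal

config.update("jax_enable_x64", True)

A = normal(PRNGKey(0), (2, 2), dtype=jnp.float32)

# Pin the computation to the CPU backend for comparison.
f_cpu = jax.jit(lambda a: dot(a, a.T, preferred_element_type=jnp.float64),
                backend="cpu")
print(f_cpu(A))  # matches A @ A.T; only the GPU path goes wrong
```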
(Removed George because he's OOO, and this looks like an XLA-related bug.)
Tracked internally in b/253051564. By the way, should we use the openxla repo for such bugs?
Looking into this. It's an issue with XLA calling cuBLAS with an "unsupported" type combination. Pushing a fix soon. @AyanmoI, have you done any work here?
The wrong output was fixed last year, we just never updated the bug. Oops.
These days, we give the (admittedly slightly cryptic) error `Unexpected GEMM dtype: f32 f32 f64`, which is infinitely better than a wrong output. I think that's about as good as we can do, and in this case you should just use an f64 x f64 -> f64 matmul anyway: that will be the most performant option.
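Concretely, the suggested alternative is to cast the operands to float64 up front and do a plain matmul. A minimal sketch, continuing the repro above with x64 enabled:

```python
import jax.numpy as jnp
from jax.config import config
from jax.random import PRNGKey, normal

config.update("jax_enable_x64", True)

A = normal(PRNGKey(0), (2, 2), dtype=jnp.float32)

# Cast to float64 and use an f64 x f64 -> f64 matmul directly,
# rather than widening via preferred_element_type.
A64 = A.astype(jnp.float64)
print(A64 @ A64.T)  # correct result, in float64
```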
Hope that helps.