`jax.lax.dot` with `preferred_element_type` gives wrong result

Open ChenAo-Phys opened this issue 3 years ago • 3 comments

The bug can be reproduced with the following code:

```python
import jax.numpy as jnp
from jax.random import PRNGKey, normal
from jax.lax import dot
from jax.config import config
config.update("jax_enable_x64", True)

A = normal(PRNGKey(0), (2, 2), dtype=jnp.float32)
print(A @ A.T)
#[[3.868021   1.0210072 ]
# [1.0210072  0.40157342]]

print(dot(A, A.T, preferred_element_type=jnp.float32))
#[[3.868021   1.0210072 ]
# [1.0210072  0.40157342]]

print(dot(A, A.T, preferred_element_type=jnp.float64))
#[[-4.71044037e-10  2.50204452e-08]
# [ 0.00000000e+00  0.00000000e+00]]
```

It happens on an RTX 3090 with CUDA 11.6, cudatoolkit 11.6.0, cuDNN 8.4.1.50, and either Python 3.8.13 with jax 0.3.13 or Python 3.9.7 with jax 0.3.15 (the error occurs in both cases).

However, this problem doesn't happen on a Colab P100 GPU. I haven't tested other types of GPU.

ChenAo-Phys avatar Jul 30 '22 08:07 ChenAo-Phys

Thanks for the report - I suspect this may be related to a known bug in dot_general with preferred_element_type on GPU. Note that currently, we skip tests that exercise this because of a somewhat different issue: https://github.com/google/jax/blob/75d69725c384cd08b0ca55fa9e315fb9bb5830e4/tests/lax_test.py#L1129-L1131

jakevdp avatar Aug 01 '22 22:08 jakevdp

I ran some further tests over the past few days. This bug doesn't happen on CPU, but it still happens on an A100 GPU.

ChenAo-Phys avatar Aug 03 '22 16:08 ChenAo-Phys

(Removed George because he's OOO, and this looks like an XLA-related bug.)

mattjj avatar Aug 11 '22 17:08 mattjj

Tracked internally in b/253051564. Should we use the openxla repo btw for such bugs?

cheshire avatar Oct 11 '22 18:10 cheshire

Looking into this. The issue is XLA calling cuBLAS with an "unsupported" type combination. Pushing a fix soon. @AyanmoI have you done any work here?

SandSnip3r avatar Oct 20 '22 21:10 SandSnip3r

The wrong output was fixed last year; we just never updated the bug. Oops.

These days, we give the (admittedly slightly cryptic) error `Unexpected GEMM dtype: f32 f32 f64`, which is infinitely better than a wrong output. I think that's about as good as we can do, and in this case you should just use an f64xf64->f64 matmul anyway: that will be the most performant option.

Hope that helps.

hawkinsp avatar Nov 17 '23 21:11 hawkinsp
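
For readers landing here, the workaround suggested in the last comment might look roughly like the sketch below. It is only an illustration under stated assumptions: the inputs are cast to float64 up front so the matmul is f64xf64->f64, instead of asking for a float64 result from float32 operands via `preferred_element_type`. The variable names (`A64`, `out`, `ref`) are made up for this example, and the snippet uses the `jax.config.update` spelling rather than the older `from jax.config import config` import from the original report.

```python
import jax

# float64 must be enabled before any arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax import lax

# Same float32 input as in the original report.
A = jax.random.normal(jax.random.PRNGKey(0), (2, 2), dtype=jnp.float32)

# Cast the operands to float64 first, then do a plain f64 x f64 -> f64 matmul.
# This avoids the mixed f32 x f32 -> f64 combination discussed in this thread.
A64 = A.astype(jnp.float64)
out = lax.dot(A64, A64.T)

# float32 reference, for comparison only.
ref = A @ A.T

print(out.dtype)
print(out)
```

The result should agree with the float32 product `A @ A.T` up to float32 rounding, since the inputs carry only float32 precision to begin with.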