Stacking layers with `vmap` and running the forward pass with `scan` results in a loss of precision in the XLA backend
I followed the instructions in the tutorial#scan-over-layers to build a multi-layer network by creating the layers with `nnx.vmap` and running the forward pass with `nnx.scan`. However, doing so results in a loss of precision in the XLA backend.
System information
- OS Platform: Ubuntu 20.04
- Flax, jax, jaxlib versions:
flax 0.10.4 pypi_0 pypi
jax 0.5.2 pypi_0 pypi
jax-cuda12-pjrt 0.5.1 pypi_0 pypi
jax-cuda12-plugin 0.5.1 pypi_0 pypi
jaxlib 0.5.1 pypi_0 pypi
optax 0.2.4 pypi_0 pypi
orbax-checkpoint 0.11.8 pypi_0 pypi
- Python version: 3.11.11
- GPU: RTX 3090, 25.31G and NVIDIA A100 80GB PCIe
Problem you have encountered:
Here is a minimal example that reproduces the error:
```python
import os

import jax
from flax import nnx

# os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"

BATCH_SIZE = 2
SEQ_LEN = 2
FEATURES = 16


class AttentionLayer(nnx.Module):
    def __init__(self, d_model, rngs):
        self.attention = nnx.MultiHeadAttention(
            num_heads=8,
            in_features=d_model,
            rngs=rngs
        )
        self.linear1 = nnx.Linear(in_features=d_model, out_features=d_model, rngs=rngs)

    def __call__(self, x):
        x = self.attention(x, decode=False)
        x = self.linear1(x)
        return x


def foo(x, layer_keys):
    # Create one AttentionLayer per key; vmap stacks their parameters along axis 0.
    @nnx.vmap(in_axes=0, out_axes=0)
    def create_layer(key):
        layer_rngs = nnx.Rngs(key)
        return AttentionLayer(FEATURES, layer_rngs)

    model = create_layer(layer_keys)

    # Scan over the stacked layers, carrying the activations x from layer to layer.
    @nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry)
    def apply_layers(layer, x):
        return layer(x)

    return apply_layers(model, x)


def bar(x, layer_keys):
    # Reference version: build each layer separately and loop in Python.
    layers = [AttentionLayer(FEATURES, nnx.Rngs(key)) for key in layer_keys]
    for layer in layers:
        x = layer(x)
    return x


key = jax.random.PRNGKey(0)
layer_keys = jax.random.split(key, 2)  # 2 layers
x = jax.random.normal(jax.random.PRNGKey(0), (BATCH_SIZE, SEQ_LEN, FEATURES))

foo(x, layer_keys)  # errors
bar(x, layer_keys)  # works
```
Running `foo` produces the following errors.
Error on the 3090 machine:
```
E0316 21:11:08.907722 2675508 buffer_comparator.cc:156] Difference at 6: 8.18504, expected 7.16491
E0316 21:11:08.907797 2675508 buffer_comparator.cc:156] Difference at 7: 10.2058, expected 8.91315
E0316 21:11:08.907806 2675508 buffer_comparator.cc:156] Difference at 8: 8.30671, expected 6.65029
E0316 21:11:08.907811 2675508 buffer_comparator.cc:156] Difference at 9: 9.57833, expected 8.51971
E0316 21:11:08.907816 2675508 buffer_comparator.cc:156] Difference at 11: 12.3298, expected 10.7088
E0316 21:11:08.907827 2675508 buffer_comparator.cc:156] Difference at 15: 6.00732, expected 5.25697
E0316 21:11:08.907832 2675508 buffer_comparator.cc:156] Difference at 22: 8.97186, expected 7.9519
E0316 21:11:08.907838 2675508 buffer_comparator.cc:156] Difference at 24: 9.59525, expected 7.93915
E0316 21:11:08.907845 2675508 buffer_comparator.cc:156] Difference at 27: 13.3396, expected 11.7196
E0316 21:11:08.907852 2675508 buffer_comparator.cc:156] Difference at 38: 8.77498, expected 7.75514
2025-03-16 21:11:08.907867: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
```
Error on the A100 machine:
```
E0317 05:01:29.217242 712317 buffer_comparator.cc:156] Difference at 6: 8.18504, expected 7.16615
E0317 05:01:29.217324 712317 buffer_comparator.cc:156] Difference at 7: 10.2058, expected 8.91452
E0317 05:01:29.217329 712317 buffer_comparator.cc:156] Difference at 8: 8.30671, expected 6.651
E0317 05:01:29.217332 712317 buffer_comparator.cc:156] Difference at 9: 9.57833, expected 8.51998
E0317 05:01:29.217335 712317 buffer_comparator.cc:156] Difference at 11: 12.3298, expected 10.7096
E0317 05:01:29.217339 712317 buffer_comparator.cc:156] Difference at 15: 6.00732, expected 5.25718
E0317 05:01:29.217342 712317 buffer_comparator.cc:156] Difference at 22: 8.97186, expected 7.95259
E0317 05:01:29.217345 712317 buffer_comparator.cc:156] Difference at 24: 9.59525, expected 7.9386
E0317 05:01:29.217348 712317 buffer_comparator.cc:156] Difference at 27: 13.3396, expected 11.7191
E0317 05:01:29.217351 712317 buffer_comparator.cc:156] Difference at 38: 8.77498, expected 7.75621
2025-03-17 05:01:29.217365: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
```
The function `bar` should be equivalent to `foo`, yet it runs without any errors.
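To illustrate why the two versions should compute the same thing, here is a simplified pure-JAX analogue. It is not how Flax implements `nnx.vmap`/`nnx.scan`: a hypothetical dense layer (`init_layer`/`apply_layer`) stands in for `AttentionLayer`, `jax.vmap` builds parameters stacked along a leading layer axis, and `jax.lax.scan` applies them in order while carrying the activations. The layer definition and the tolerances are assumptions made for the sketch.

```python
import jax
import jax.numpy as jnp

NUM_LAYERS, FEATURES = 2, 16


def init_layer(key):
    # Hypothetical stand-in for one layer: a single dense weight and bias.
    wkey, bkey = jax.random.split(key)
    w = jax.random.normal(wkey, (FEATURES, FEATURES)) / jnp.sqrt(FEATURES)
    b = jax.random.normal(bkey, (FEATURES,))
    return w, b


def apply_layer(params, x):
    w, b = params
    return jnp.tanh(x @ w + b)


keys = jax.random.split(jax.random.PRNGKey(0), NUM_LAYERS)
x = jax.random.normal(jax.random.PRNGKey(1), (2, 2, FEATURES))

# "foo"-style: vmap the initializer so every parameter gets a leading layer
# axis, then scan over that axis while carrying the activations.
stacked = jax.vmap(init_layer)(keys)


def scan_body(carry, params):
    return apply_layer(params, carry), None


y_scan, _ = jax.lax.scan(scan_body, x, stacked)

# "bar"-style: initialize each layer on its own and loop in Python.
y_loop = x
for k in keys:
    y_loop = apply_layer(init_layer(k), y_loop)

# Compare with a tolerance: the two paths may round slightly differently.
print(jnp.allclose(y_scan, y_loop, rtol=1e-5, atol=1e-5))
```

The scan version compiles one loop body that is reused for every layer, while the Python loop unrolls the layers into separate operations; mathematically they apply the same layers in the same order.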
Disabling XLA autotuning with `os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"` prevents the error, so the problem may be related to how the XLA autotuner handles the program produced by `vmap` and `scan`.
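The workaround above overwrites any `XLA_FLAGS` that were already set. A slightly more careful variant appends the flag instead. This is a minimal sketch; `append_xla_flag` is a hypothetical helper, and it assumes, as is generally the case, that XLA reads `XLA_FLAGS` only once when the backend initializes, so it must run before `import jax` (or at least before the first JAX computation).

```python
import os


def append_xla_flag(flag: str) -> str:
    """Append `flag` to XLA_FLAGS without discarding flags that are already set."""
    existing = os.environ.get("XLA_FLAGS", "")
    if flag not in existing.split():
        os.environ["XLA_FLAGS"] = f"{existing} {flag}".strip()
    return os.environ["XLA_FLAGS"]


# Run this before JAX initializes its GPU backend, i.e. before `import jax`.
append_xla_flag("--xla_gpu_autotune_level=0")

# import jax  # import only after the flag has been set
```

Calling the helper twice with the same flag leaves `XLA_FLAGS` unchanged, so it is safe to use at the top of several entry-point scripts.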
Colab reproducibility
I cannot reproduce the error on Colab. On a Colab CPU runtime the outputs of the two functions differ slightly, while on a Colab T4 GPU runtime they are identical. Code:
```python
x1 = foo(x, layer_keys)  # errors
x2 = bar(x, layer_keys)  # works
print(x1 - x2)
```
Output on CPU:
```
[[[-1.1920929e-07  1.1920929e-07  1.1920929e-07  2.9802322e-08
    1.1920929e-07  1.1920929e-07  0.0000000e+00  0.0000000e+00
   -1.1920929e-07  2.9802322e-07 -1.4901161e-08 -1.3411045e-07
   -5.9604645e-08 -1.1920929e-07 -1.1920929e-07  1.1920929e-07]
  [ 2.9802322e-08 -5.9604645e-08 -1.1920929e-07  2.0861626e-07
   -2.3841858e-07 -1.1920929e-07 -1.7881393e-07  0.0000000e+00
   -1.1920929e-07  1.7881393e-07  5.9604645e-08  2.9802322e-08
   -4.4703484e-08  0.0000000e+00 -1.1920929e-07  4.7683716e-07]]

 [[ 1.7695129e-07 -6.7055225e-08  8.6612999e-08 -1.2665987e-07
    1.1920929e-07  5.9604645e-08 -7.4505806e-09  1.1548400e-07
    1.4901161e-07  0.0000000e+00  1.1920929e-07  7.4505806e-08
    2.9802322e-08  5.9604645e-08 -1.1920929e-07 -1.7881393e-07]
  [ 2.2351742e-08  2.9802322e-08  9.6857548e-08 -1.0058284e-07
    0.0000000e+00  0.0000000e+00 -1.4901161e-08  1.4901161e-08
    1.4901161e-07  3.7252903e-08  8.9406967e-08 -5.9604645e-08
    0.0000000e+00  1.7881393e-07 -1.7881393e-07 -1.7881393e-07]]]
```
Output on T4:
```
[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]]
```
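The CPU differences are on the order of float32 machine epsilon (1.1920929e-07, i.e. 2^-23), which most likely reflects a different order of floating-point operations between the scanned and looped versions rather than a real bug. By contrast, the value pairs in the GPU autotuner log differ by more than 10%. A tolerance-based comparison separates the two cases. This is a minimal NumPy sketch; the helper name `outputs_match` and the `rtol`/`atol` values are illustrative assumptions, not values used by Flax or JAX.

```python
import numpy as np


def outputs_match(x1, x2, rtol=1e-5, atol=1e-6):
    """Return True if two arrays agree within float32-scale tolerances."""
    # Exact equality is too strict: two mathematically equivalent programs can
    # round differently by about one unit in the last place (ulp).
    return np.allclose(np.asarray(x1), np.asarray(x2), rtol=rtol, atol=atol)


eps32 = np.finfo(np.float32).eps  # float32 machine epsilon, 2**-23

# Differences of about one ulp, like the CPU output above, pass the check.
a = np.ones((2, 2, 16), dtype=np.float32)
b = a + eps32
print(outputs_match(a, b))  # True

# A pair taken from the autotuner log ("8.18504, expected 7.16491") does not.
print(outputs_match(np.float32(8.18504), np.float32(7.16491)))  # False
```

`numpy.testing.assert_allclose` performs the same comparison but raises an error that reports the mismatching elements, which can be more convenient inside a test suite.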
Hi @HeavyCrab, this sounds like a JAX issue; could you please report it to the JAX repo? There doesn't seem to be much we can do from the Flax side here. Sorry for the inconvenience.
@cgarciae OK, I have reported this to the JAX repo.
We can close this issue as unrelated to Flax.