Large numerical error when using vmap and bfloat16/tensorfloat32 matmul precision, only on A100 GPU
Description
I've found that when vmapping a simple MLP forward function, numerical accuracy decreases significantly when using either bfloat16 or tensorfloat32 matmul precision. The issue seems to occur only on the A100 GPU on Colab, not on CPU or on V100/T4 GPUs. This notebook demonstrates the issue: https://colab.research.google.com/gist/rhacking/c4a285ee4b6931c3123c71ae8cd9e490/matmul_precision_error.ipynb. Some loss of precision is to be expected when vmapping or lowering the matmul precision, but the error here is larger than I would expect (and I would expect the T4 to behave the same way).
Code to reproduce:
import jax
from jax import vmap, jit
from jax.nn import initializers as inits
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
def gen_mlp(layers, rng_key):
    # He-uniform init; the (-1, -2) positional args set in_axis=-1, out_axis=-2,
    # matching the (out_dim, in_dim) weight shapes created below.
    init = inits.he_uniform(-1, -2)
    weight_key, gamma_key, mu_key = jax.random.split(rng_key, 3)
    weight_keys = jax.random.split(weight_key, len(layers) - 1)
    weights = [init(k, (l2, l1)) for k, l1, l2 in zip(weight_keys, layers[:-1], layers[1:])]
    biases = [jnp.zeros(l) for l in layers[1:]]
    gamma_keys = jax.random.split(gamma_key, len(layers) - 2)
    gammas = [jax.random.normal(k, shape=(l,)) * 0.5 for k, l in zip(gamma_keys, layers[1:-1])]
    mu_keys = jax.random.split(mu_key, len(layers) - 2)
    mus = [jax.random.uniform(k, shape=(l,), minval=-jnp.pi, maxval=jnp.pi) for k, l in zip(mu_keys, layers[1:-1])]

    def forward(x, params):
        # MLP with Gaussian-windowed sine activations on the hidden layers.
        weights, biases, gammas, mus = params
        x_in = x
        for W, b, gamma, mu in zip(weights[:-1], biases[:-1], gammas, mus):
            x = W @ x + b
            x = jnp.exp(-gamma**2 * jnp.square(x - mu)) * jnp.sin(x * 16.0)
        x = weights[-1] @ x + biases[-1]
        return x.mean()

    return forward, (weights, biases, gammas, mus)
forward, params = gen_mlp([2, 48, 48, 48, 1], jax.random.PRNGKey(42))
xx, yy = np.meshgrid(np.linspace(-1, 1, 64), np.linspace(-1, 1, 64))
xy = np.stack([xx.ravel(), yy.ravel()], axis=1)
data = []
true = forward(xy[0], params)  # reference value: un-vmapped forward at the default matmul precision
fn = jax.jit(vmap(forward, (0, None)))
for n in range(1, 10):
    with jax.default_matmul_precision('bfloat16'):
        data.append({'n': n, 'val': float(fn(xy[:n], params).ravel()[0]), 'precision_type': 'bfloat16'})
    with jax.default_matmul_precision('tensorfloat32'):
        data.append({'n': n, 'val': float(fn(xy[:n], params).ravel()[0]), 'precision_type': 'tensorfloat32'})
    with jax.default_matmul_precision('float32'):
        data.append({'n': n, 'val': float(fn(xy[:n], params).ravel()[0]), 'precision_type': 'float32'})
df = pd.DataFrame(data)
df['err'] = np.abs(df['val'] - true)
sns.lineplot(data=df, x='n', y='val', hue='precision_type')
plt.show()
sns.lineplot(data=df, x='n', y='err', hue='precision_type')
plt.show()
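As a sanity check (this snippet is not in the notebook, and the array names A, X, and batched_matvec are just illustrative), a stripped-down comparison of a bare vmapped matvec against a float32 baseline might help show whether the discrepancy comes from the precision setting alone or from something specific to the vmapped MLP:

import jax
import jax.numpy as jnp

# Sketch: compare a plain vmapped matvec under each matmul precision against a
# float32 baseline. Shapes roughly mirror the MLP above (48 features, small batch).
key_a, key_x = jax.random.split(jax.random.PRNGKey(0))
A = jax.random.normal(key_a, (48, 48))
X = jax.random.normal(key_x, (8, 48))

batched_matvec = jax.jit(jax.vmap(lambda x: A @ x))

with jax.default_matmul_precision('float32'):
    ref = batched_matvec(X)

for prec in ('bfloat16', 'tensorfloat32', 'float32'):
    with jax.default_matmul_precision(prec):
        out = batched_matvec(X)
    print(prec, float(jnp.max(jnp.abs(out - ref))))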
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.23
jaxlib: 0.4.23
numpy: 1.25.2
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
$ nvidia-smi
Mon Feb 26 09:30:56 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100-SXM4-40GB          Off | 00000000:00:04.0 Off |                    0 |
| N/A   31C    P0              50W / 400W |    425MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes:                                                                             |
|  GPU   GI   CI        PID   Type   Process name                             GPU Memory |
|        ID   ID                                                              Usage      |
|=========================================================================================|
+-----------------------------------------------------------------------------------------+