
Large numerical error when using vmap and bfloat16/tensorfloat32 matmul precision, only on A100 GPU

Status: Open. Opened by rhacking.

Description

When I vmap a simple MLP forward function, numerical accuracy seems to drop significantly under either bfloat16 or tensorfloat32 matmul precision. I only see this on the Colab A100 GPU, not on CPU or on V100/T4 GPUs. This notebook demonstrates the issue: https://colab.research.google.com/gist/rhacking/c4a285ee4b6931c3123c71ae8cd9e490/matmul_precision_error.ipynb. I expect some loss of precision from vmapping or from lowering the matmul precision. However, the error is larger than I would expect, and I would also expect the T4 to behave the same way.
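To get a rough sense of how much error reduced precision alone introduces, here is a minimal CPU sketch (not from the notebook). It emulates a bfloat16 matmul by rounding both inputs to bfloat16 and then multiplying and accumulating in float32. That rounding model is an assumption, and the actual A100 kernels may round or accumulate differently. Tensorfloat32 keeps 10 explicit mantissa bits against bfloat16's 7, so its error should typically be smaller. The matrix sizes mirror the 48-wide hidden layers. `round_to_bf16`, `emulated_bf16_matmul` and `rel_err` are hypothetical helpers.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical helpers for illustration only.
def round_to_bf16(x):
    # Round float32 values to the nearest bfloat16, then widen back to float32.
    return jnp.asarray(x, dtype=jnp.float32).astype(jnp.bfloat16).astype(jnp.float32)

def emulated_bf16_matmul(a, b):
    # Assumed model: bf16-rounded inputs, float32 products and accumulation.
    return jnp.matmul(round_to_bf16(a), round_to_bf16(b),
                      precision=jax.lax.Precision.HIGHEST)

rng = np.random.default_rng(0)
a = rng.standard_normal((48, 48)).astype(np.float32)  # like a hidden-layer weight
b = rng.standard_normal((48, 16)).astype(np.float32)  # like a batch of 16 activations

ref = a.astype(np.float64) @ b.astype(np.float64)  # float64 reference
f32 = np.asarray(jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST), dtype=np.float64)
bf16 = np.asarray(emulated_bf16_matmul(a, b), dtype=np.float64)

def rel_err(x):
    # Max absolute error, normalized by the largest reference entry.
    return np.abs(x - ref).max() / np.abs(ref).max()

print(f"float32 matmul error:    {rel_err(f32):.1e}")
print(f"bf16-input matmul error: {rel_err(bf16):.1e}")
```

bfloat16 has a unit roundoff of 2^-8 (about 3.9e-3), so the emulated error should come out orders of magnitude larger than the float32 one, even for a single layer. The forward function below also applies sin(x*16.0), whose slope reaches 16. Pre-activation errors can therefore be amplified at each of the three hidden layers.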

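Code to reproduce. It evaluates the first grid point, xy[0], inside batches of size n = 1 to 9 under each matmul precision, and compares each result with a single unbatched call to forward: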

import jax
from jax import vmap, jit
from jax.nn import initializers as inits
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

def gen_mlp(layers, rng_key):
  init = inits.he_uniform(-1, -2)
  weight_key, gamma_key, mu_key = jax.random.split(rng_key, 3)
  
  weight_keys = jax.random.split(weight_key, len(layers)-1)
  weights = [init(k, (l2, l1)) for k, l1, l2 in zip(weight_keys, layers[:-1], layers[1:])]
  biases = [jnp.zeros(l) for l in layers[1:]]

  gamma_keys = jax.random.split(gamma_key, len(layers)-2)
  gammas = [jax.random.normal(k, shape=(l, ))*0.5 for k, l in zip(gamma_keys, layers[1:-1])]

  mu_keys = jax.random.split(mu_key, len(layers)-2)
  mus = [jax.random.uniform(k, shape=(l, ), minval=-jnp.pi, maxval=jnp.pi) for k, l in zip(mu_keys, layers[1:-1])]

  def forward(x, params):
    weights, biases, gammas, mus = params
    x_in = x
    for W, b, gamma, mu in zip(weights[:-1], biases[:-1], gammas, mus):
      x = W @ x + b
      x = jnp.exp(-gamma**2 * jnp.square(x-mu)) * jnp.sin(x*16.0)
    x = weights[-1] @ x + biases[-1]
    return x.mean()

  return forward, (weights, biases, gammas, mus)

forward, params = gen_mlp([2, 48, 48, 48, 1], jax.random.PRNGKey(42))

xx, yy = np.meshgrid(np.linspace(-1, 1, 64), np.linspace(-1, 1, 64))
xy = np.stack([xx.ravel(), yy.ravel()], axis=1)

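data = []
# Reference value: a single unbatched call on the first grid point,
# made outside any precision context (i.e. at the default matmul precision).
true = forward(xy[0], params)
fn = jax.jit(vmap(forward, (0, None)))
# Evaluate batches xy[:n] of growing size and keep only the output for xy[0],
# so every row is the same point computed with a different batch size.
for n in range(1, 10):
  with jax.default_matmul_precision('bfloat16'):
    data.append({'n': n, 'val': float(fn(xy[:n], params).ravel()[0]), 'precision_type': 'bfloat16'})
  with jax.default_matmul_precision('tensorfloat32'):
    data.append({'n': n, 'val': float(fn(xy[:n], params).ravel()[0]), 'precision_type': 'tensorfloat32'})
  with jax.default_matmul_precision('float32'):
    data.append({'n': n, 'val': float(fn(xy[:n], params).ravel()[0]), 'precision_type': 'float32'})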

df = pd.DataFrame(data)
df['err'] = np.abs(df['val'] - true)
sns.lineplot(data=df, x='n', y='val', hue='precision_type')
plt.show()
sns.lineplot(data=df, x='n', y='err', hue='precision_type')
plt.show()
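Here is one possible reason why the batch size matters. This is an assumption, not something verified in this issue. Without vmap, each W @ x in forward is a matrix-vector product. With vmap(forward, (0, None)), it becomes a single matrix-matrix product over the batch, which XLA may send to a different GPU kernel (for example, a tensor-core GEMM that honors the reduced precision). The sketch below uses jax.make_jaxpr to compare the output shapes of the dot_general operations in both cases. `layer` and `dot_output_shapes` are hypothetical helpers, and the shapes follow the first layer of the [2, 48, 48, 48, 1] MLP.

```python
import jax
import jax.numpy as jnp

def layer(x, W):
    # Same pattern as forward(): one linear map applied to a single input vector.
    return W @ x

def dot_output_shapes(jaxpr):
    # Collect output shapes of every dot_general, descending into nested jaxprs
    # (e.g. the body of a jitted jnp.matmul).
    shapes = []
    for eqn in jaxpr.eqns:
        if eqn.primitive.name == "dot_general":
            shapes.append(tuple(eqn.outvars[0].aval.shape))
        for param in eqn.params.values():
            sub = getattr(param, "jaxpr", param)  # ClosedJaxpr -> Jaxpr
            if hasattr(sub, "eqns"):
                shapes.extend(dot_output_shapes(sub))
    return shapes

W = jnp.ones((48, 2), dtype=jnp.float32)
x = jnp.ones((2,), dtype=jnp.float32)
xs = jnp.ones((5, 2), dtype=jnp.float32)  # a batch of 5 inputs

single = jax.make_jaxpr(layer)(x, W).jaxpr
batched = jax.make_jaxpr(jax.vmap(layer, (0, None)))(xs, W).jaxpr

print("unbatched dot_general outputs:", dot_output_shapes(single))
print("vmapped dot_general outputs:  ", dot_output_shapes(batched))
```

The unbatched trace contains a 1-D (matrix-vector) dot_general, while the vmapped trace contains a 2-D (matrix-matrix) one. Whether the A100 actually runs these through kernels with different rounding would still have to be confirmed from the compiled HLO or a profiler trace.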

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.23
jaxlib: 0.4.23
numpy:  1.25.2
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1

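The block above matches the output format of jax.print_environment_info(). When comparing GPUs, it may also help to log the device kind for each run. TF32 tensor cores first appeared with NVIDIA's Ampere generation, which includes the A100. The V100 (Volta) and T4 (Turing) have no TF32 or bfloat16 tensor-core support, so the reduced-precision settings plausibly fall back to float32 math on those GPUs. That fallback is an assumption about XLA's behavior, not something this issue verifies. `describe_devices` is a hypothetical helper.

```python
import jax

def describe_devices():
    # Hypothetical helper: (platform, device kind) for every visible device,
    # where device_kind is the hardware model name reported by the backend.
    return [(d.platform, d.device_kind) for d in jax.devices()]

# Prints jax/jaxlib/numpy/python versions, the device list, and nvidia-smi output when available.
jax.print_environment_info()

for platform, kind in describe_devices():
    print(f"{platform}: {kind}")
```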
$ nvidia-smi
Mon Feb 26 09:30:56 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100-SXM4-40GB          Off | 00000000:00:04.0 Off |                    0 |
| N/A   31C    P0              50W / 400W |    425MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+

rhacking · Feb 26 '24 09:02