jax icon indicating copy to clipboard operation
jax copied to clipboard

GPU not partitioning RNG with jax_threefry_partitionable=True (vmap/scan layers)

Open dlwh opened this issue 4 months ago • 4 comments

Description

Mostly minimized. This shouldn't OOM on the 4 A6000s I'm using (and the unminimized version works fine on TPU).

import jax
import jax.random as jrandom

from jax.sharding import PartitionSpec as P, Mesh, NamedSharding

from jax.lax import with_sharding_constraint

UP_PSPEC = P("data", None)
DOWN_PSPEC = P(None, "data")
VMAP_UP_PSPEC = P(None, "data", None)
VMAP_DOWN_PSPEC = P(None, "data", None)

MLP = 14336
EMBED = 4096
LAYERS = 48


def linear_init(dims, key, pspec):
    return with_sharding_constraint(jrandom.normal(key, dims), pspec) * 0.02


def init_module(key):
    k_fc, k_up_proj, k_down_proj = jrandom.split(key, 3)
    gate_proj = linear_init((MLP, EMBED), k_fc, UP_PSPEC)

    return gate_proj


def init_stack(key):
    return jax.vmap(init_module, spmd_axis_name=None)(jrandom.split(key, LAYERS))


DEFAULT_JAX_CONFIG = {
    "jax_threefry_partitionable": True,
}

def main():
    for k, v in DEFAULT_JAX_CONFIG.items():
        jax.config.update(k, v)

    mesh = Mesh(jax.devices("gpu"), ("data",))

    with mesh:
        key = jrandom.PRNGKey(0)
        jit_init_stack = jax.jit(init_stack,
                                 out_shardings=(NamedSharding(mesh, VMAP_UP_PSPEC)))

        jit_init_stack(key)

main()
Peak buffers:
	Buffer 1:
		Size: 10.50GiB
		Operator: op_name="jit(init_stack)/jit(main)/vmap(jit(_normal))/jit(_normal_real)/jit(_uniform)/threefry2x32" source_file="/sailhome/dlwh/vmap_rand_oom.py" source_line=19
		XLA Label: custom-call
		Shape: u32[48,14336,4096]
		==========================

	Buffer 2:
		Size: 10.50GiB
		Operator: op_name="jit(init_stack)/jit(main)/vmap(jit(_normal))/jit(_normal_real)/jit(_uniform)/threefry2x32" source_file="/sailhome/dlwh/vmap_rand_oom.py" source_line=19
		XLA Label: custom-call
		Shape: u32[48,14336,4096]
		==========================

	Buffer 3:
		Size: 10.50GiB
		Operator: op_name="jit(init_stack)/jit(main)/vmap(jit(_normal))/jit(_normal_real)/jit(_uniform)/threefry2x32" source_file="/sailhome/dlwh/vmap_rand_oom.py" source_line=19
		XLA Label: fusion
		Shape: u32[48,14336,4096]
		==========================

cc @froystig @mattjj

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.23
jaxlib: 0.4.23
numpy:  1.26.2
python: 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0]
jax.devices (4 total, 4 local): [cuda(id=0) cuda(id=1) cuda(id=2) cuda(id=3)]
process_count: 1

$ nvidia-smi
Mon Feb 19 09:37:48 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA RTX A6000               On  | 00000000:01:00.0 Off |                  Off |
| 30%   32C    P2              24W / 300W |    270MiB / 49140MiB |      1%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA RTX A6000               On  | 00000000:25:00.0 Off |                  Off |
| 30%   31C    P2              20W / 300W |    270MiB / 49140MiB |      2%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA RTX A6000               On  | 00000000:41:00.0 Off |                  Off |
| 30%   33C    P2              31W / 300W |    270MiB / 49140MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA RTX A6000               On  | 00000000:61:00.0 Off |                  Off |
| 30%   30C    P2              30W / 300W |    270MiB / 49140MiB |      2%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A   1699541      C   python                                      262MiB |
|    1   N/A  N/A   1699541      C   python                                      262MiB |
|    2   N/A  N/A   1699541      C   python                                      262MiB |
|    3   N/A  N/A   1699541      C   python                                      262MiB |
+---------------------------------------------------------------------------------------+

dlwh avatar Feb 19 '24 17:02 dlwh

(Changing the scan over layers to a list comprehension works around the issue.)

dlwh avatar Feb 19 '24 17:02 dlwh

Ruled out vmap as the culprit — the same failure occurs without it:

import jax
import jax.random as jrandom

from jax.sharding import PartitionSpec as P, Mesh, NamedSharding

from jax.lax import with_sharding_constraint

UP_PSPEC = P("data", None)
DOWN_PSPEC = P(None, "data")
VMAP_UP_PSPEC = P(None, "data", None)
VMAP_DOWN_PSPEC = P(None, "data", None)

MLP = 14336
EMBED = 4096
LAYERS = 48


def linear_init(dims, key, pspec):
    return with_sharding_constraint(jrandom.normal(key, dims), pspec) * 0.02


def init_module(key):
    k_fc, k_up_proj, k_down_proj = jrandom.split(key, 3)
    gate_proj = linear_init((LAYERS, MLP, EMBED), k_fc, VMAP_UP_PSPEC)

    return gate_proj


def init_stack(key):
    #return jax.vmap(init_module, spmd_axis_name=None)(jrandom.split(key, LAYERS))
    #return [init_module(k) for k in jrandom.split(key, LAYERS)]
    return init_module(key)


DEFAULT_JAX_CONFIG = {
    "jax_threefry_partitionable": True,
}

def main():
    for k, v in DEFAULT_JAX_CONFIG.items():
        jax.config.update(k, v)

    mesh = Mesh(jax.devices("gpu"), ("data",))

    with mesh:
        key = jrandom.PRNGKey(0)
        jit_init_stack = jax.jit(init_stack)

        jit_init_stack(key)

main()

dlwh avatar Feb 19 '24 19:02 dlwh

I think it might be the same issue as https://github.com/google/jax/issues/19893

hr0nix avatar Feb 20 '24 18:02 hr0nix

seems like it

dlwh avatar Feb 20 '24 18:02 dlwh