
jax_threefry_partitionable + rematerialization don't seem to work together in distributed training

Open hr0nix opened this issue 8 months ago • 13 comments

Description

I have a transformer model where each transformer block is rematerialized. The model is distributed over multiple devices using jit. Each transformer block has dropout enabled.

To prevent the RNG implementation from inserting synchronization operations, I'm also enabling jax_threefry_partitionable, as suggested in the docs.
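For reference, this is a minimal sketch of turning the flag on from Python before any random numbers are drawn. The 0.9 keep probability and the (8, 16) shape are arbitrary placeholders, not values from the issue.

```python
import jax
import jax.numpy as jnp

# Opt in to the partitionable threefry implementation. Set it before tracing
# any jitted function that draws random numbers, so the setting is picked up.
jax.config.update("jax_threefry_partitionable", True)

key = jax.random.PRNGKey(0)

# A dropout-style keep mask (placeholder shape and keep probability).
keep_mask = jax.random.bernoulli(key, p=0.9, shape=(8, 16))
print(keep_mask.shape, keep_mask.dtype)
```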

The problem is that jax_threefry_partitionable doesn't seem to play nicely with rematerialization. As soon as I enable dropout, I get a GPU OOM, because JAX decides to keep huge arrays of RNG keys, one per activation-tensor element, for every transformer block, even though the blocks are rematerialized. It should be possible for JAX to reconstruct this key array from a single key during rematerialization, but it doesn't seem to do so.
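To make the setup concrete, here is a stripped-down pure-JAX sketch of a rematerialized block with dropout. It uses no Flax and no sharding, and the function names, the RATE constant and all shapes are hypothetical. jax.checkpoint with the nothing_saveable policy asks JAX to recompute the block, including the dropout mask, from its inputs during the backward pass. jax.ad_checkpoint.print_saved_residuals then lists what JAX itself keeps for the backward pass. It cannot show which buffers XLA keeps alive after compilation, and that is the level the OOM reports in this thread come from.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import print_saved_residuals

RATE = 0.1  # hypothetical dropout rate


def block(w1, w2, x, key):
    """One MLP block with dropout applied to the hidden activations."""
    h = jax.nn.relu(x @ w1)
    keep = 1.0 - RATE
    mask = jax.random.bernoulli(key, p=keep, shape=h.shape)
    h = jnp.where(mask, h / keep, 0.0)
    return h @ w2


# Save nothing from the forward pass; the backward pass recomputes the block,
# dropout mask included, from (w1, w2, x, key).
remat_block = jax.checkpoint(block, policy=jax.checkpoint_policies.nothing_saveable)


def loss(w1, w2, x, key):
    return jnp.mean(remat_block(w1, w2, x, key) ** 2)


k1, k2, k3, dropout_key = jax.random.split(jax.random.PRNGKey(0), 4)
w1 = jax.random.normal(k1, (16, 64))
w2 = jax.random.normal(k2, (64, 16))
x = jax.random.normal(k3, (8, 16))

# Print the JAX-level residuals; the dropout key is closed over as a constant.
print_saved_residuals(lambda w1, w2, x: loss(w1, w2, x, dropout_key), w1, w2, x)

# Gradients with respect to the two weight matrices.
grads = jax.grad(loss, argnums=(0, 1))(w1, w2, x, dropout_key)
```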

I'm happy to provide a reproduction if you can confirm that this is unexpected behavior. If it isn't, can you please suggest a workaround? Currently it doesn't seem possible to efficiently train large models with dropout.

A relevant discussion, including an example OOM error message, is here: https://github.com/google/flax/discussions/3090

What jax/jaxlib version are you using?

0.4.14

Which accelerator(s) are you using?

GPU

Additional system info

python3.10

NVIDIA GPU info

Fri Oct  6 14:22:06 2023
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12             Driver Version: 535.104.12   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H100 80GB HBM3          On  | 00000000:8A:00.0 Off |                    0 |
| N/A   35C    P0              72W / 700W |      4MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  | 00000000:8B:00.0 Off |                    0 |
| N/A   31C    P0              71W / 700W |      4MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA H100 80GB HBM3          On  | 00000000:8C:00.0 Off |                    0 |
| N/A   31C    P0              72W / 700W |      4MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA H100 80GB HBM3          On  | 00000000:8D:00.0 Off |                    0 |
| N/A   35C    P0              74W / 700W |      4MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   4  NVIDIA H100 80GB HBM3          On  | 00000000:9C:00.0 Off |                    0 |
| N/A   36C    P0              76W / 700W |      4MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   5  NVIDIA H100 80GB HBM3          On  | 00000000:9D:00.0 Off |                    0 |
| N/A   31C    P0              72W / 700W |      4MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   6  NVIDIA H100 80GB HBM3          On  | 00000000:9E:00.0 Off |                    0 |
| N/A   31C    P0              71W / 700W |      4MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   7  NVIDIA H100 80GB HBM3          On  | 00000000:9F:00.0 Off |                    0 |
| N/A   35C    P0              73W / 700W |      4MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+

hr0nix, Oct 06 '23 14:10

I'm struggling with the same issue.

Karina1997, Oct 13 '23 12:10

Also relevant for me; it would be great to have it solved.

minotru, Oct 16 '23 10:10

@jakevdp Hey, sorry for mentioning you directly, but this issue hasn't received any attention for several weeks. Can someone from the jax team please take a look? Thanks!

hr0nix, Oct 30 '23 12:10

@froystig Hey, sorry for mentioning you directly, but can someone take a look at this issue? It's a big blocker for me.

hr0nix, Nov 05 '23 16:11

I've made a repro for this bug. It turns out it has nothing to do with jax_threefry_partitionable; it's perfectly reproducible without it.

The repro was made for an A100 80GB, so tensor shapes might need to be adjusted for a GPU with a different amount of memory.

- ./repro.py fails, because without rematerialization it needs ~122.75 GB of GPU RAM.
- ./repro.py --remat works perfectly fine, because with remat it needs just 63 GB of GPU RAM.
- ./repro.py --remat --dropout-rate 0.1 OOMs again, requiring ~118 GB of GPU RAM.

From looking at the peak buffers, it becomes clear that the dropout mask is not being rematerialized: tensors corresponding to the full dropout masks of different layers are occupying memory.

Peak buffers:
        Buffer 1:
                Size: 512.00MiB
                Operator: op_name="jit(train_step)/jit(main)/jvp(Model)/Block_5._apply_block/Block_5/Dropout_0/jit(_bernoulli)/jit(_uniform)/threefry2x32" source_file="/papyrax/./tools/repro.py" source_line=24
                XLA Label: custom-call
                Shape: u32[134217728]
                ==========================
        Buffer 2:
                Size: 512.00MiB
                Operator: op_name="jit(train_step)/jit(main)/jvp(Model)/Block_4._apply_block/Block_4/Dropout_0/jit(_bernoulli)/jit(_uniform)/threefry2x32" source_file="/papyrax/./tools/repro.py" source_line=24
                XLA Label: custom-call
                Shape: u32[134217728]
                ==========================

        Buffer 3:
                Size: 512.00MiB
                Operator: op_name="jit(train_step)/jit(main)/jvp(Model)/Block_4._apply_block/Block_4/Dropout_0/jit(_bernoulli)/jit(_uniform)/threefry2x32" source_file="/papyrax/./tools/repro.py" source_line=24
                XLA Label: custom-call
                Shape: u32[134217728]
                ==========================

        Buffer 4:
                Size: 512.00MiB
                Operator: op_name="jit(train_step)/jit(main)/jvp(Model)/Block_3._apply_block/Block_3/Dropout_0/jit(_bernoulli)/jit(_uniform)/threefry2x32" source_file="/papyrax/./tools/repro.py" source_line=24
                XLA Label: custom-call
                Shape: u32[134217728]
                ==========================

        Buffer 5:
                Size: 512.00MiB
                Operator: op_name="jit(train_step)/jit(main)/jvp(Model)/Block_3._apply_block/Block_3/Dropout_0/jit(_bernoulli)/jit(_uniform)/threefry2x32" source_file="/papyrax/./tools/repro.py" source_line=24
                XLA Label: custom-call
                Shape: u32[134217728]
                ==========================
...
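As a rough sanity check on the sizes in the peak-buffer report, here is a back-of-the-envelope calculation using the repro script's defaults (--dim 1024, --batch-size 8192, --num-layers 64, and scale = 32 in Block). One assumption is loudly flagged: the threefry2x32 custom call is taken to return two uint32 outputs that each hold half of the random words. That would be consistent with each block appearing twice in the report, but it is an interpretation, not something stated in the thread.

```python
# Repro defaults, taken from the click options and Block.scale in the script.
batch_size = 8192
dim = 1024
scale = 32
num_layers = 64

BYTES_PER_U32 = 4

# Dropout acts on the (batch_size, dim * scale) hidden activations, and
# bernoulli -> uniform consumes one 32-bit random word per element.
words_per_mask = batch_size * dim * scale

# Assumption: threefry2x32 emits two uint32 outputs, each with half the words.
words_per_output = words_per_mask // 2
mib_per_output = words_per_output * BYTES_PER_U32 / 2**20

# Random bits per block, and for all layers if every block's bits stayed alive.
gib_per_block = words_per_mask * BYTES_PER_U32 / 2**30
gib_all_layers = gib_per_block * num_layers

print(words_per_output, mib_per_output, gib_per_block, gib_all_layers)
# prints: 134217728 512.0 1.0 64.0
```

The first two numbers match the u32[134217728] shape and 512.00MiB size of each buffer in the report.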

Repro code:

import functools

import click
import flax
import flax.linen as nn
import flax.training.train_state
import jax
import jax.numpy as jnp
import optax


class Dropout(nn.Module):
    rate: float

    @nn.compact
    def __call__(self, inputs, rng):
        if self.rate == 0.0:
            return inputs

        if self.rate == 1.0:
            return jnp.zeros_like(inputs)

        keep_prob = 1.0 - self.rate
        mask = jax.random.bernoulli(rng, p=keep_prob, shape=inputs.shape)
        return jax.lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))


class Block(nn.Module):
    dim: int
    dropout_rate: float

    @nn.compact
    def __call__(self, input, rng):
        scale = 32  # We want large memory consumption without remat
        emb = nn.Dense(features=self.dim * scale)(input)
        emb = nn.relu(emb)
        emb = Dropout(rate=self.dropout_rate)(emb, rng)
        emb = nn.Dense(features=self.dim)(emb)
        return emb


class Model(nn.Module):
    dim: int
    dropout_rate: float
    num_layers: int
    remat: bool

    @nn.compact
    def __call__(self, input, rng):
        def _apply_block(block, block_input, rng):
            return block(block_input, rng)

        if self.remat:
            _apply_block = nn.checkpoint(
                _apply_block,
                policy=jax.checkpoint_policies.nothing_saveable,
                prevent_cse=True,
            )

        emb = input
        for _ in range(self.num_layers):
            rng, block_rng = jax.random.split(rng)
            block = Block(dim=self.dim, dropout_rate=self.dropout_rate)
            emb = _apply_block(block, emb, block_rng)

        return emb


def loss_fn(params, train_state, batch, rng):
    outputs = train_state.apply_fn(params, batch, rng)
    return jnp.mean(outputs * outputs)


@functools.partial(jax.jit, donate_argnames=("train_state",))
def train_step(train_state, batch, rng):
    grad_fn = jax.grad(loss_fn)
    grad = grad_fn(train_state.params, train_state, batch, rng)
    train_state = train_state.apply_gradients(grads=grad)
    return train_state


def make_batch(batch_size, dim):
    return jnp.zeros(shape=(batch_size, dim), dtype=jnp.float32)


@click.command()
@click.option("--dim", type=int, default=1024)
@click.option("--batch-size", type=int, default=8192)
@click.option("--dropout-rate", type=float, default=0.0)
@click.option("--num-layers", type=int, default=64)
@click.option("--remat", is_flag=True)
def main(
    dim: int,
    batch_size: int,
    dropout_rate: float,
    num_layers: int,
    remat: bool,
):
    model = Model(
        dim=dim, dropout_rate=dropout_rate, num_layers=num_layers, remat=remat
    )
    batch = make_batch(batch_size=batch_size, dim=dim)
    rng = jax.random.PRNGKey(0)
    params = model.init({"params": rng}, batch, rng)
    optimizer = optax.adam(learning_rate=1e-3)
    train_state = flax.training.train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=optimizer
    )
    train_state = train_step(train_state, batch, rng)


if __name__ == "__main__":
    main()

hr0nix, Nov 05 '23 20:11

Thanks for the repro!

froystig, Nov 06 '23 18:11

@hr0nix sorry that this slipped through the cracks. Thanks for the pings, everyone.

Can you check that this repros with jaxlib 0.4.20? IIRC there was one GPU-specific remat fix that happened recently, though I don't have a link to it at the moment. EDIT: https://github.com/openxla/xla/pull/6527

mattjj, Nov 06 '23 18:11

Thanks for the pointer!

Unfortunately, it looks like the problem is still present with jaxlib==0.4.20

hr0nix, Nov 06 '23 19:11

Thanks for checking.

I think our next step is to try to repro on TPU, to see if it's GPU-specific. We can do that on our end.

mattjj, Nov 07 '23 12:11

Hey, any updates on this?

hr0nix, Nov 14 '23 21:11

@mattjj @froystig Happy new year, gentlemen! Do you think 2024 is the year this bug finally gets fixed? ;-)

hr0nix, Jan 02 '24 14:01

Ping!

hr0nix, Feb 26 '24 11:02

Hmm, looks like using jax_default_prng_impl=rbg fixes this issue.

hr0nix, Feb 26 '24 14:02

> Hmm, looks like using jax_default_prng_impl=rbg fixes this issue.

Thanks, this is a useful additional bit of info. This is still in our queue, but we haven't dug in yet.

I understood your most recent comment to mean that you have a workaround. Is that right? At large scales, jax_default_prng_impl=rbg can be a good idea to try anyway, as it can drastically speed up compilation times.
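Here is a minimal sketch of switching to rbg from Python. The same option can presumably also be set through the JAX_DEFAULT_PRNG_IMPL=rbg environment variable, following JAX's usual mapping of config names to upper-case environment variables. The shapes and keep probability are placeholders. With this setting, raw keys from jax.random.PRNGKey use the rbg key layout, which is four uint32 words instead of threefry's two.

```python
import jax
import jax.numpy as jnp

# Make rbg the default PRNG implementation for all subsequently created keys.
jax.config.update("jax_default_prng_impl", "rbg")

key = jax.random.PRNGKey(0)  # raw uint32 key data; rbg keys have shape (4,)
dropout_key, other_key = jax.random.split(key)

# A dropout-style keep mask drawn with the rbg implementation (placeholder shape).
keep = jax.random.bernoulli(dropout_key, p=0.9, shape=(8, 16))
print(key.shape, keep.shape)
```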

froystig, Mar 03 '24 05:03

> I understood your most recent comment to mean that you have a workaround. Is that right?

Looks like it. Interestingly, it also seems to fix another rng-related issue: https://github.com/google/jax/issues/19893

Btw, can you elaborate a bit on how the RNG implementation works when keys are sharded? E.g., does it require any additional communication?

hr0nix, Mar 03 '24 15:03

On GPU, for a fixed key, I do not expect that sharded number generation under rbg would require communication. E.g. I expect the following to print False:

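import jax
import jax.numpy as jnp

# This comment is about rbg, so select it as the default PRNG implementation.
jax.config.update('jax_default_prng_impl', 'rbg')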

@jax.jit
def f(key, x):
  numbers = jax.random.uniform(key, x.shape)
  return x + numbers

key = jax.random.key(42)
x_sharding = jax.sharding.PositionalSharding(jax.devices())
x = jax.device_put(jnp.arange(24.), x_sharding)

f_exe = f.lower(key, x).compile()
print('Communicating?', 'collective-permute' in f_exe.as_text())

(and the same if we check for other collectives in the HLO.)

Meanwhile I also expect the output sharding of f(key, x) to be, e.g.:

PositionalSharding([{GPU 0} {GPU 1}], shape=(2,))

when jax.devices() is a list of two GPUs.
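The check above only looks for collective-permute. Here is a sketch that extends it to several collective op names in the compiled HLO and also prints the output sharding. It makes a few assumptions of its own. It uses Mesh/NamedSharding instead of PositionalSharding. It forces 8 simulated CPU devices via XLA_FLAGS, which only takes effect if JAX has not been imported yet. The COLLECTIVES list is a hand-picked set of HLO op names. The expectation of no communication was stated for GPU, so on simulated CPU devices the report may differ.

```python
import os

# Assumption: simulate 8 host devices when no accelerators are present.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

jax.config.update("jax_default_prng_impl", "rbg")

# Hand-picked HLO collective op names to search for (async variants such as
# all-gather-start also match because of the substring check).
COLLECTIVES = ("all-reduce", "all-gather", "collective-permute",
               "all-to-all", "reduce-scatter")


@jax.jit
def f(key, x):
    numbers = jax.random.uniform(key, x.shape)
    return x + numbers


def find_collectives(hlo_text):
    """Map each collective op name to whether it appears in the HLO text."""
    return {op: op in hlo_text for op in COLLECTIVES}


devices = jax.devices()
mesh = Mesh(np.array(devices), ("d",))
# Length is a multiple of the device count so the 1-D sharding divides evenly.
x = jax.device_put(jnp.arange(8.0 * len(devices)),
                   NamedSharding(mesh, PartitionSpec("d")))
key = jax.random.key(42)

hlo = f.lower(key, x).compile().as_text()
report = find_collectives(hlo)
print(report)

y = f(key, x)
print(y.sharding)
```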

Your comment however asks "when keys are sharded." Do you mean that you are sharding a computation that vmaps a random number generation operation over a batch of keys (in the form of a sharded key array)? If so, then there's a current unrelated issue to watch specifically regarding vmap of rbg over keys, covered by #19085. The workaround there is not to vmap number generation over keys, but instead to hoist the generation step: draw the entire batch of random numbers from a single key outside of the vmapped function, and pass that in.
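Here is a minimal sketch of the hoisting workaround described above. The helper names per_example and add_noise and the (4, 3) batch are hypothetical. The first pattern vmaps number generation over a batch of keys, which is the pattern #19085 concerns. The second draws the whole batch of random numbers from a single key outside the vmapped function and passes it in. The snippet runs under whichever PRNG implementation is configured. The two outputs are not expected to be numerically equal, because they consume different random streams.

```python
import jax
import jax.numpy as jnp

batch = jnp.ones((4, 3))
key_vmapped, key_hoisted = jax.random.split(jax.random.key(0))


# Pattern to watch under rbg: number generation vmapped over a batch of keys.
def per_example(example_key, x):
    return x + jax.random.normal(example_key, x.shape)


keys = jax.random.split(key_vmapped, batch.shape[0])
out_vmapped_keys = jax.vmap(per_example)(keys, batch)


# Hoisted workaround: per-example code only consumes pre-drawn numbers.
def add_noise(x, noise):
    return x + noise


# Draw the entire batch of random numbers from a single key, outside vmap.
noise = jax.random.normal(key_hoisted, batch.shape)
out_hoisted = jax.vmap(add_noise)(batch, noise)
```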

froystig, Mar 05 '24 03:03