
nt.linearize induces a CUDA_ERROR_OUT_OF_MEMORY error

Open RylanSchaeffer opened this issue 2 years ago • 7 comments

We're trying to fine-tune a linearized Vision Transformer by adapting code from https://github.com/google-research/vision_transformer/blob/main/vit_jax.ipynb.

We're running into a really puzzling problem: we can load a model and train it, and after linearizing it we can still train the original (pre-linearized) model. However, when we try to train the linearized model, we get:

RuntimeError: Internal: Failed to load in-memory CUBIN: CUDA_ERROR_OUT_OF_MEMORY: out of memory

The error occurs regardless of whether we use 1 GPU or several, and with both a large batch size (512) and a small one (1).

We manually verified that neither a forward pass nor a backward pass raises an error. We suspect the error might arise from the following code (although we could be wrong!):

Their code:

def make_update_fn(*, apply_fn, accum_steps, lr_fn):
  """Returns update step for data parallel training."""

  def update_fn(opt, step, batch, rng):

    _, new_rng = jax.random.split(rng)
    # Bind the rng key to the device id (which is unique across hosts)
    # Note: This is only used for multi-host training (i.e. multiple computers
    # each with multiple accelerators).
    dropout_rng = jax.random.fold_in(rng, jax.lax.axis_index('batch'))

    def cross_entropy_loss(*, logits, labels):
      logp = jax.nn.log_softmax(logits)
      return -jnp.mean(jnp.sum(logp * labels, axis=1))

    def loss_fn(params, images, labels):
      logits = apply_fn(
          dict(params=params),
          rngs=dict(dropout=dropout_rng),
          inputs=images,
          train=True)
      return cross_entropy_loss(logits=logits, labels=labels)

    l, g = utils.accumulate_gradient(
        jax.value_and_grad(loss_fn), opt.target, batch['image'], batch['label'],
        accum_steps)
    g = jax.tree_map(lambda x: jax.lax.pmean(x, axis_name='batch'), g)
    l = jax.lax.pmean(l, axis_name='batch')

    opt = opt.apply_gradient(g, learning_rate=lr_fn(step))
    return opt, l, new_rng

  return jax.pmap(update_fn, axis_name='batch', donate_argnums=(0,))

That function is then called via:

# Check out train.make_update_fn in the editor on the right side for details.
lr_fn = utils.create_learning_rate_schedule(total_steps, base_lr, decay_type, warmup_steps)
update_fn_repl = train.make_update_fn(
    apply_fn=vit_apply, accum_steps=accum_steps, lr_fn=lr_fn)
# We use a momentum optimizer that uses half precision for state to save
# memory. It also implements gradient clipping.
opt = momentum_clip.Optimizer(grad_norm_clip=grad_norm_clip).create(params)
opt_repl = flax.jax_utils.replicate(opt)

The training loop where the memory error arises:

losses = []
lrs = []
# Completes in ~20 min on the TPU runtime.
for step, batch in zip(
    tqdm.trange(1, total_steps + 1),
    ds_train.as_numpy_iterator(),
):

  opt_repl, loss_repl, update_rng_repl = update_fn_repl(
      opt_repl, flax.jax_utils.replicate(step), batch, update_rng_repl)  # ERROR IS HERE
  losses.append(loss_repl[0])
  lrs.append(lr_fn(step))

All of the above code is copied from the ViT repo. This is how we linearize the ViT model:

def vit_apply(params, input):
  return model.apply(dict(params=params), input, train=True)
f_lin = nt.linearize(vit_apply, params)
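
For context, we then train with f_lin by calling it in place of vit_apply inside loss_fn, roughly like this (a simplified sketch of our adaptation, with cross_entropy_loss as defined above; not the exact notebook code):

def loss_fn(params, images, labels):
  # f_lin has the same (params, inputs) signature as vit_apply, so we call it
  # directly here instead of going through the original apply_fn.
  logits = f_lin(params, images)
  return cross_entropy_loss(logits=logits, labels=labels)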

RylanSchaeffer · Mar 05 '22 22:03

@romanngg, you were really helpful previously - any thoughts here? Thanks in advance :)

RylanSchaeffer · Mar 05 '22 22:03

nt.linearize is essentially a Jacobian-vector product (jax.jvp), and the peak memory consumption of the linearized forward pass should be about 2x the peak memory consumption of the original forward pass. Then, I believe the costs of the backward passes (jax.vjp) of the linearized and non-linearized models should also differ by about 2x (see https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#how-it-s-made-two-foundational-autodiff-functions or https://openreview.net/pdf?id=ym68T6OoO6L). If you have a way to diagnose peak memory consumption when you train your model, could you check that it's less than half of your GPU memory?
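
For reference, f_lin = nt.linearize(f, params_0) behaves roughly like the sketch below (illustrative, not the exact library internals): one call to f_lin runs a single jax.jvp through f, i.e. roughly two forward-pass-like computations, which is where the ~2x estimate comes from.

def linearize_sketch(f, params_0):
  def f_lin(params, x):
    # First-order Taylor expansion of f around params_0 in (params - params_0).
    dparams = jax.tree_map(lambda p, p0: p - p0, params, params_0)
    f_0, df = jax.jvp(lambda p: f(p, x), (params_0,), (dparams,))
    return f_0 + df
  return f_lin

On GPU, one rough way to check peak usage is to disable XLA preallocation (XLA_PYTHON_CLIENT_PREALLOCATE=false) and watch nvidia-smi while the non-linearized model trains.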

romanngg · Mar 06 '22 01:03

@romanngg thanks for getting back to me so soon! I'll check the peak memory consumption, but I don't think that's the cause, because we could successfully "manually" perform a forward and backward pass of f_lin on a single GPU with batch size 512. By "manually," I mean executing the following alone:

    l, g = utils.accumulate_gradient(
        jax.value_and_grad(loss_fn), opt.target, batch['image'], batch['label'],
        accum_steps)

I suspect that there might be an odd interaction between f_lin as constructed by Neural Tangents and the code used in the vision transformer notebook (pasted above).

RylanSchaeffer · Mar 06 '22 19:03

~~Another bizarre observation: if we try~~

    l, g = utils.accumulate_gradient(
        jax.value_and_grad(loss_fn), opt.target, batch['image'], batch['label'],
        accum_steps)

~~with the original (non-linearized) model outside the update_fn() defined above, we get an OOM error for about 155 MiB, even though the GPU has plenty of additional memory available. This problem does not occur with the linearized model.~~

Edit: Ignore that last observation. The problem vanished when we reduced the per-GPU batch size.

RylanSchaeffer · Mar 07 '22 17:03

Here's a self-contained Colab that reproduces the issue: https://colab.research.google.com/drive/184moQLq3tjo-wEpc8gD7fXCFguAVDBOm#scrollTo=k4CjYqp5qLvj

We suspect that pmap might be the problem: without it, the linearized model trains fine via jax.value_and_grad(loss_fn), but as soon as we wrap it in jax.pmap(jax.value_and_grad(loss_fn)), we hit the OOM (see the sketch below).
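
Concretely, the contrast looks roughly like this (simplified from our Colab; variable names are illustrative):

# Works: gradient of the linearized loss on a single device.
grad_fn = jax.value_and_grad(loss_fn)
loss, grads = grad_fn(params, images, labels)

# OOMs: the same gradient function wrapped in pmap, with params replicated
# across devices and the batch split along a leading device axis.
p_grad_fn = jax.pmap(jax.value_and_grad(loss_fn), axis_name='batch')
loss, grads = p_grad_fn(replicated_params, sharded_images, sharded_labels)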

RylanSchaeffer · Mar 07 '22 19:03

Another data point: the GPU runtime on Colab breaks, but the TPU runtime on Colab is fine.

RylanSchaeffer · Mar 07 '22 20:03

It looks like someone else has a similar problem while using Neural Tangents, also potentially arising from pmap:

https://github.com/google/jax/issues/8585#issuecomment-1061256273

RylanSchaeffer · Mar 07 '22 23:03