neural-tangents
nt.linearize induces a CUDA_ERROR_OUT_OF_MEMORY error
We're trying to fine-tune a linearized Vision Transformer by adapting code from https://github.com/google-research/vision_transformer/blob/main/vit_jax.ipynb.
We're running into a really puzzling problem: when we load a model, we can train it, and after we linearize it, we can still train the original (pre-linearized) model. However, as soon as we try to train the linearized model, we get:
RuntimeError: Internal: Failed to load in-memory CUBIN: CUDA_ERROR_OUT_OF_MEMORY: out of memory
This error emerges regardless of whether we use 1 GPU or several, and regardless of whether the batch is large (512) or small (1).
We manually verified that a forward pass and a backward pass both complete without error. We suspect that the error might arise from the following code (although we could be wrong!):
Their code:
def make_update_fn(*, apply_fn, accum_steps, lr_fn):
  """Returns update step for data parallel training."""

  def update_fn(opt, step, batch, rng):

    _, new_rng = jax.random.split(rng)
    # Bind the rng key to the device id (which is unique across hosts)
    # Note: This is only used for multi-host training (i.e. multiple computers
    # each with multiple accelerators).
    dropout_rng = jax.random.fold_in(rng, jax.lax.axis_index('batch'))

    def cross_entropy_loss(*, logits, labels):
      logp = jax.nn.log_softmax(logits)
      return -jnp.mean(jnp.sum(logp * labels, axis=1))

    def loss_fn(params, images, labels):
      logits = apply_fn(
          dict(params=params),
          rngs=dict(dropout=dropout_rng),
          inputs=images,
          train=True)
      return cross_entropy_loss(logits=logits, labels=labels)

    l, g = utils.accumulate_gradient(
        jax.value_and_grad(loss_fn), opt.target, batch['image'], batch['label'],
        accum_steps)
    g = jax.tree_map(lambda x: jax.lax.pmean(x, axis_name='batch'), g)
    l = jax.lax.pmean(l, axis_name='batch')

    opt = opt.apply_gradient(g, learning_rate=lr_fn(step))
    return opt, l, new_rng

  return jax.pmap(update_fn, axis_name='batch', donate_argnums=(0,))
That function is then called via:
# Check out train.make_update_fn in the editor on the right side for details.
lr_fn = utils.create_learning_rate_schedule(total_steps, base_lr, decay_type, warmup_steps)
update_fn_repl = train.make_update_fn(
    apply_fn=vit_apply, accum_steps=accum_steps, lr_fn=lr_fn)
# We use a momentum optimizer that uses half precision for state to save
# memory. It also implements the gradient clipping.
opt = momentum_clip.Optimizer(grad_norm_clip=grad_norm_clip).create(params)
opt_repl = flax.jax_utils.replicate(opt)
The training loop where the memory error arises:
losses = []
lrs = []
# Completes in ~20 min on the TPU runtime.
for step, batch in zip(
    tqdm.trange(1, total_steps + 1),
    ds_train.as_numpy_iterator(),
):

  opt_repl, loss_repl, update_rng_repl = update_fn_repl(
      opt_repl, flax.jax_utils.replicate(step), batch, update_rng_repl)  # ERROR IS HERE

  losses.append(loss_repl[0])
  lrs.append(lr_fn(step))
The above code is all copied from the ViT repo. This is how we linearize the ViT model:
def vit_apply(params, input):
  return model.apply(dict(params=params), input, train=True)

f_lin = nt.linearize(vit_apply, params)
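Roughly speaking, f_lin then stands in for vit_apply in the update step above. A sketch of the adapter is below (illustrative only: make_update_fn calls apply_fn with keyword rngs/inputs/train, while f_lin takes (params, inputs), and the exact wiring in our notebook may differ slightly):

# Illustrative adapter (the name f_lin_apply is ours): match the call
# signature that make_update_fn expects, i.e.
# apply_fn(dict(params=params), rngs=..., inputs=..., train=...).
def f_lin_apply(variables, *, rngs=None, inputs=None, train=True):
  del rngs, train  # the linearized forward pass ignores dropout here
  return f_lin(variables['params'], inputs)

update_fn_repl = train.make_update_fn(
    apply_fn=f_lin_apply, accum_steps=accum_steps, lr_fn=lr_fn)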
@romanngg, you were really helpful previously - any thoughts here? Thanks in advance :)
nt.linearize is essentially a Jacobian-vector product (jax.jvp), so the peak memory consumption of the linearized forward pass should be about 2x the peak memory consumption of the original forward pass. Then, I believe the costs of the backward passes (jax.vjp) of the linearized and non-linearized models should also differ by about 2x (see https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#how-it-s-made-two-foundational-autodiff-functions or https://openreview.net/pdf?id=ym68T6OoO6L). If you have a way to diagnose the peak memory consumption when you train your model, could you check that it's less than half of your GPU memory?
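For intuition, the linearized function is a first-order Taylor expansion of the original network around the initial parameters, computed with jax.jvp, which is why each forward pass does roughly double the work. A simplified sketch (not the exact nt.linearize implementation):

import jax

def linearize_sketch(f, params_0):
  """First-order Taylor expansion of f around params_0 via jax.jvp.
  Conceptually what nt.linearize returns (details simplified)."""
  def f_lin(params, *args, **kwargs):
    dparams = jax.tree_util.tree_map(lambda p, p0: p - p0, params, params_0)
    f_0, df = jax.jvp(lambda p: f(p, *args, **kwargs), (params_0,), (dparams,))
    return f_0 + df
  return f_lin

As for measuring peak memory, one option is calling jax.profiler.save_device_memory_profile('memory.prof') right after a training step and inspecting the result with pprof.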
@romanngg thanks for getting back to me so soon! I'll check the max memory consumption, but I don't think that's the reason, because we could successfully "manually" perform a forward and backward pass of f_lin on a single GPU with batch size = 512. By "manually," I mean executing the following alone:
l, g = utils.accumulate_gradient(
    jax.value_and_grad(loss_fn), opt.target, batch['image'], batch['label'],
    accum_steps)
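(For context on that call: as far as we understand, utils.accumulate_gradient splits the batch into accum_steps chunks and averages the per-chunk losses and gradients, roughly like the sketch below; the repo's real implementation uses lax control flow and may differ in details.)

def accumulate_gradient_sketch(loss_and_grad_fn, params, images, labels, accum_steps):
  # Rough paraphrase (ours, not the repo's code): evaluate loss/grad on
  # accum_steps smaller chunks and average the results.
  chunk = images.shape[0] // accum_steps
  l, g = loss_and_grad_fn(params, images[:chunk], labels[:chunk])
  for i in range(1, accum_steps):
    li, gi = loss_and_grad_fn(
        params, images[i * chunk:(i + 1) * chunk],
        labels[i * chunk:(i + 1) * chunk])
    l = l + li
    g = jax.tree_util.tree_map(lambda a, b: a + b, g, gi)
  return l / accum_steps, jax.tree_util.tree_map(lambda x: x / accum_steps, g)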
I suspect that there might be an odd interaction between f_lin as constructed by Neural Tangents and the code used in the vision transformer notebook (pasted above).
~~Another bizarre observation: if we try~~

l, g = utils.accumulate_gradient(
    jax.value_and_grad(loss_fn), opt.target, batch['image'], batch['label'],
    accum_steps)

~~with the original (non-linearized) model outside the update_fn() (defined above), we get an OOM error for about 155 MiB, even though the GPU has tons of additional available memory. This problem does not occur when using the linearized model.~~

Edit: Ignore that last observation. That problem vanished when we reduced the per-GPU batch size.
Here's a self-contained Colab that reproduces the issue: https://colab.research.google.com/drive/184moQLq3tjo-wEpc8gD7fXCFguAVDBOm#scrollTo=k4CjYqp5qLvj
We suspect that pmap might be causing a problem, because if we don't use it, the linearized model can train via jax.value_and_grad(loss_fn), but once we try pmap(jax.value_and_grad(loss_fn)), we hit OOM.
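Concretely, the comparison we ran looks roughly like the sketch below (the replication of params and the per-device batch layout are written from memory and may not match our colab exactly):

grad_fn = jax.value_and_grad(loss_fn)  # loss_fn closes over f_lin, as above

# Single GPU: this runs fine with the linearized model (batch size 512).
l, g = grad_fn(params, batch['image'], batch['label'])

# Data parallel: the same function under pmap is where the OOM appears.
# Here params are replicated with flax.jax_utils.replicate and the batch
# arrays are assumed to have a leading [num_devices, ...] axis, as produced
# by the ViT input pipeline.
p_grad_fn = jax.pmap(grad_fn, axis_name='batch')
params_repl = flax.jax_utils.replicate(params)
l_repl, g_repl = p_grad_fn(params_repl, batch['image'], batch['label'])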
Another insight: GPU on Colab breaks, but TPU on Colab is fine
Looks like someone else has a similar problem while using neural tangents, also potentially arising from pmap: https://github.com/google/jax/issues/8585#issuecomment-1061256273