long-range-arena
Is it possible to include instructions on how to run it on GPUs?
This code seems to be using 4x4 TPUs, but since I don't have access to TPUs, I wonder if you could release instructions on how to replicate the results on GPUs, which would make this code more accessible to people without abundant computational resources.
The JAX code should also run on GPUs. We have tested this on a virtual machine on Google Cloud, so it should work without any special instructions.
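If you want to double-check that JAX is actually picking up your GPUs, a quick generic check (plain JAX, not LRA-specific) is:

import jax

print(jax.default_backend())  # should print 'gpu' when a CUDA-enabled jaxlib is installed
print(jax.devices())          # lists the accelerator devices JAX can see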
Thanks for the reply! It throws an OOM error on a single Titan X GPU, though; it would be nice if there were a flag like accumulate-gradients/update-freq so the results could be reproduced on a single GPU. (Sorry if this is a dumb question, but I'm not very familiar with TensorFlow/JAX.)
Thanks for the feedback!
@ppham27 ran this on the cloud vm, so I'm looping him in and wondering if he has any thoughts on this.
A single Titan X doesn't have enough memory. For our GPU setup, we had 8 V100s for a total of 128GB of HBM. On a single Titan X, I think you could max out at a batch size of 3, which is probably too small. Adding an outer loop and doing gradient accumulation is probably the right way to address this. If there's a lot of interest in being able to train on a single GPU, we can look into this.
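The basic pattern is to split each batch into micro-batches, sum the gradients across them, and apply a single optimizer update. A minimal sketch in plain JAX (illustrative names only, not the LRA code; it assumes a loss_fn(params, batch) signature and a batch whose leading dimension is divisible by accum_steps) might look like this:

import jax
import jax.numpy as jnp

def accumulated_grad(loss_fn, params, batch, accum_steps):
  """Averages gradients over `accum_steps` micro-batches of `batch`."""
  # Reshape the leading batch dimension into (accum_steps, micro_batch, ...).
  micro = jax.tree_map(
      lambda x: jnp.asarray(x).reshape((accum_steps, -1) + x.shape[1:]), batch)
  grad_fn = jax.grad(loss_fn)

  def body(i, acc):
    # Gradient on the i-th micro-batch, added to the running total.
    g = grad_fn(params, jax.tree_map(lambda x: x[i], micro))
    return jax.tree_map(lambda a, b: a + b, acc, g)

  init = jax.tree_map(jnp.zeros_like, params)
  total = jax.lax.fori_loop(0, accum_steps, body, init)
  return jax.tree_map(lambda g: g / accum_steps, total)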
Having a way to turn on accumulate-gradients/update-freq would be amazing for reproducibility on GPUs. What is the best approach for doing this in JAX?
@MostafaDehghani has an example for this. Do you mind sharing it?
Hi, thanks for the question.
Yes. I also think using gradient accumulation is the way to go. Here is an example of implementing it in JAX, which we used in another project, but I'm sure it's easily portable to LRA. https://github.com/google-research/vision_transformer/blob/master/vit_jax/train.py#L63
Adding gradient accumulation to LRA is on our TODO list, but currently there are a few higher-priority fixes/feature requests that we need to take care of. In the meantime, a PR that adds it to our training loops is extremely welcome :)
Hi, Mostafa!
Thank you for the quick response. I was able to adapt your code for text classification, and gradient accumulation seems to be working fine. Since jax.fori_loop requires the loop carry to keep the same type and shape, I couldn't stack logits during accumulation; I worked around this by computing the logits afterwards. Here is the code:
def train_step(optimizer, batch, learning_rate_fn, accum_steps, dropout_rng=None):
  train_keys = ['inputs', 'targets']
  (inputs, targets) = [batch.get(k, None) for k in train_keys]
  dropout_rng, new_dropout_rng = random.split(dropout_rng)

  def loss_fn(model, x, y):
    """Loss function used for training."""
    with nn.stochastic(dropout_rng):
      logits = model(x, train=True)
    loss, weight_sum = train_utils.compute_weighted_cross_entropy(
        logits, y, num_classes=CLASS_MAP[FLAGS.task_name], weights=None)
    mean_loss = loss / weight_sum
    return mean_loss, logits

  step = optimizer.state.step
  lr = learning_rate_fn(step)
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  # Compute accumulated gradients; logits for the metrics are recomputed below.
  _, grad = accumulate_gradient(grad_fn, optimizer.target, inputs, targets,
                                accum_steps)
  grad = jax.tree_map(lambda x: jax.lax.pmean(x, axis_name='batch'), grad)
  logits = optimizer.target(inputs, train=False)
  # To save memory:
  # logits = optimizer.target(inputs[0][jnp.newaxis, ...], train=False)
  # for i in range(1, inputs.shape[0]):
  #   y_hat = optimizer.target(inputs[i][jnp.newaxis, ...], train=False)
  #   logits = jnp.concatenate((logits, y_hat), axis=0)
  new_optimizer = optimizer.apply_gradient(grad, learning_rate=lr)
  metrics = compute_metrics(logits, targets, None)
  metrics['learning_rate'] = lr
  return new_optimizer, metrics, new_dropout_rng


def accumulate_gradient(loss_and_grad_fn, params, inputs, labels, accum_steps):
  """Accumulate gradient over multiple steps to save on memory."""
  if accum_steps and accum_steps > 1:
    assert inputs.shape[0] % accum_steps == 0, (
        f'Bad accum_steps {accum_steps} for batch size {inputs.shape[0]}')
    step_size = inputs.shape[0] // accum_steps
    (l, _), g = loss_and_grad_fn(params, inputs[:step_size], labels[:step_size])

    def acc_grad_and_loss(i, l_and_g):
      inps = jax.lax.dynamic_slice(inputs, (i * step_size, 0),
                                   (step_size,) + inputs.shape[1:])
      lbls = jax.lax.dynamic_slice(labels[..., jnp.newaxis], (i * step_size, 0),
                                   (step_size, 1)).squeeze(axis=-1)
      (li, _), gi = loss_and_grad_fn(params, inps, lbls)
      l, g = l_and_g
      return l + li, jax.tree_multimap(lambda x, y: x + y, g, gi)

    l, g = jax.lax.fori_loop(1, accum_steps, acc_grad_and_loss, (l, g))
    l, g = jax.tree_map(lambda x: x / accum_steps, (l, g))
    return l, g
  else:
    return loss_and_grad_fn(params, inputs, labels)
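For completeness, this is roughly how I wire the adapted step into the pmapped loop. The accum_steps flag is my own addition and the partial-application details are just how I set it up locally, so treat the exact names as illustrative:

import functools
import jax

# Partially apply the non-array arguments so that pmap only maps the
# per-device arrays (optimizer, batch, dropout_rng) across devices.
p_train_step = jax.pmap(
    functools.partial(
        train_step,
        learning_rate_fn=learning_rate_fn,
        accum_steps=FLAGS.accum_steps),  # hypothetical flag
    axis_name='batch')

# Called inside the training loop, mirroring the existing LRA loop.
optimizer, metrics, dropout_rngs = p_train_step(
    optimizer, batch, dropout_rng=dropout_rngs)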
Hi! I got the following results on the test set by using a single GPU (24GB) and setting accum_steps=batch_size. All hyperparameters were kept intact, and the only thing that changed in the training procedure was the gradient accumulation part.
Task          1 GPU    TPU (paper)   Chance (baseline)
ListOps       0.1830   0.3637        0.10
Doc Class.    0.6323   0.6427        0.50
Retrieval     0.4752   0.5746        0.50
@vanzytay @MostafaDehghani Any idea why the gap is so large?
Best,
I'm also running into memory issues. I've given up on the vanilla Transformer (this is a benchmark for efficient Transformers, after all), but even for the Performer, I need 2× Tesla V100 (32GB each).
Do you think it's possible to reproduce your results with, say, a batch size of 16 or 8 (and without changing the code)?
In Table 2 you give some insights on the 'peak memory usage' per device with a batch size of 32. Does that refer to an effective batch size of 32 or to a batch size of 32 per device?
Can I expect to have a similar memory consumption on a single GPU with a batch size of 32 or 2?
Hi, I've also run into the OOM problem with a 32GB V100 card, so I really need gradient accumulation. In your implementation the state variable is missing; below is the original loss_fn code:
def loss_fn(model, inputs, targets):
  with nn.stateful(state) as new_state:
    with nn.stochastic(dropout_rng):
      logits = model(inputs, train=True)
  ...
  return mean_loss, (new_state, logits)
The returned new_state is then used for the next train_step by the train_loop method:
for step, batch in zip(range(start_step, num_train_steps), train_iter):
  batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
  optimizer, state, metrics, dropout_rngs = p_train_step(
      optimizer, state, batch, dropout_rng=dropout_rngs)
Would simply deleting this variable, as in your implementation, cause any problems during training?
I'd like to second @La-SilverLand's question. I'm currently trying to fit the Pathfinder model code onto a V100 GPU, and you have provided all the tools for that except an answer about nn.stateful. I'm very new to JAX, so I can't tell whether removing it will cripple the training process.
Sorry for the delay in my reply to this issue.
@EternalSorrrow, as long as there is nothing in your model that needs to keep global statistics (like BatchNorm), you can just delete the usage of state and nn.stateful and it should be all good.
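Concretely, the simplified loss_fn is roughly the one already posted earlier in this thread, i.e. the stateful wrapper is simply dropped (assuming the same surrounding names, train_utils, CLASS_MAP, and dropout_rng, as in the earlier snippets):

def loss_fn(model, inputs, targets):
  # Without nn.stateful there is no mutable collection (e.g. BatchNorm
  # statistics) to thread through, so only the dropout RNG scope remains.
  with nn.stochastic(dropout_rng):
    logits = model(inputs, train=True)
  loss, weight_sum = train_utils.compute_weighted_cross_entropy(
      logits, targets, num_classes=CLASS_MAP[FLAGS.task_name], weights=None)
  mean_loss = loss / weight_sum
  return mean_loss, logits

train_step then returns (new_optimizer, metrics, new_dropout_rng) without the state output, and the training loop should drop state from its unpacking accordingly.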
If you need a ResNet baseline that has BatchNorm, I recommend using the version with GroupNorm instead, to avoid the complication of handling batch statistics when using gradient accumulation.
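For illustration only, here is what that swap looks like in the newer flax.linen API (which differs from the older flax.nn API used in this repo); ConvBlock is a made-up toy module, not the LRA ResNet baseline:

import jax
import jax.numpy as jnp
import flax.linen as nn


class ConvBlock(nn.Module):
  """Toy conv block using GroupNorm instead of BatchNorm."""
  features: int = 64

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(self.features, kernel_size=(3, 3))(x)
    # GroupNorm normalizes over channel groups within each example, so it
    # keeps no running batch statistics and composes cleanly with gradient
    # accumulation (no state to thread through).
    x = nn.GroupNorm(num_groups=8)(x)
    return nn.relu(x)


x = jnp.ones((2, 32, 32, 3))
variables = ConvBlock().init(jax.random.PRNGKey(0), x)
y = ConvBlock().apply(variables, x)  # only 'params', no 'batch_stats' collection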
Thanks for the response. It seems that in this case the Transformer implementations in the repo should be fine (at least most of them), since LayerNorm doesn't use batch-wise statistics.