
Is it possible to include instructions on how to run it on GPUs

da03 opened this issue 5 years ago • 15 comments

This code seems to use 4x4 TPUs, but since I don't have access to TPUs, could you release instructions on how to replicate the results on GPUs? That would make this code more accessible to people without abundant compute resources.

da03 avatar Nov 16 '20 04:11 da03

The JAX code should also run on GPUs. We have tested this on a virtual machine on Google Cloud, so it should work without any special instructions.

vanzytay avatar Nov 16 '20 05:11 vanzytay

Thanks for the reply! However, it throws an OOM error on a single Titan X GPU. It would be nice to have a flag like accumulate-gradients/update-freq so the results can be reproduced on a single GPU. (Sorry if this is a dumb question; I'm not very familiar with TensorFlow/JAX.)

da03 avatar Dec 02 '20 22:12 da03

Thanks for the feedback!

@ppham27 ran this on the cloud vm, so I'm looping him in and wondering if he has any thoughts on this.

vanzytay avatar Dec 03 '20 01:12 vanzytay

A single Titan X doesn't have enough HBM. For our GPU setup, we had 8 V100s for a total of 128GB of HBM. On a single Titan X, I think you could max out at a batch size of 3, which is probably too small. Adding an outer loop and doing gradient accumulation is probably the right way to address this. If there's a lot of interest in training on a single GPU, we can look into it.
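For concreteness, the outer loop could look something like this minimal, untested sketch (assuming a loss_fn(params, inputs, targets) signature and a batch size divisible by accum_steps; names are placeholders, not LRA code):

import jax

def accumulated_grads(loss_fn, params, inputs, targets, accum_steps):
  """Average gradients over accum_steps microbatches to cut peak memory."""
  grad_fn = jax.grad(loss_fn)
  micro = inputs.shape[0] // accum_steps
  grads = None
  for i in range(accum_steps):
    sl = slice(i * micro, (i + 1) * micro)
    g = grad_fn(params, inputs[sl], targets[sl])
    # Sum gradient pytrees leaf by leaf.
    grads = g if grads is None else jax.tree_multimap(lambda a, b: a + b, grads, g)
  # Average so the update matches a single large-batch step.
  return jax.tree_map(lambda g: g / accum_steps, grads)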

ppham27 avatar Dec 03 '20 03:12 ppham27

Having a way to turn on accumulate-gradients/update-freq would be amazing for reproducibility on GPUs. What is the best approach for doing this in JAX?

mtreviso avatar Dec 10 '20 15:12 mtreviso

@MostafaDehghani has an example for this. Do you mind sharing it?

vanzytay avatar Dec 10 '20 15:12 vanzytay

Hi, thanks for the question.

Yes, I also think gradient accumulation is the way to go. Here is an example of implementing it in JAX, which we used in another project, but I'm sure it's easily portable to LRA: https://github.com/google-research/vision_transformer/blob/master/vit_jax/train.py#L63

Adding gradient accumulation to LRA is on our TODO list, but currently there are a few higher-priority fixes and feature requests that we need to take care of. In the meantime, a PR that adds it to our training loops is extremely welcome :)

MostafaDehghani avatar Dec 10 '20 22:12 MostafaDehghani

Hi, Mostafa!

Thank you for the quick response. I was able to adapt your code for text classification, and gradient accumulation seems to be working fine. Since jax.lax.fori_loop requires the loop carry to keep the same type and shape, I couldn't stack logits during accumulation; I've worked around this by computing the logits afterwards. Here is the code:

# Imports, assuming the pre-Linen Flax API used by LRA at the time;
# train_utils, compute_metrics, CLASS_MAP and FLAGS come from the LRA
# training script.
import jax
import jax.numpy as jnp
from jax import random
from flax import nn

def train_step(optimizer, batch, learning_rate_fn, accum_steps, dropout_rng=None):
  train_keys = ['inputs', 'targets']
  (inputs, targets) = [batch.get(k, None) for k in train_keys]
  dropout_rng, new_dropout_rng = random.split(dropout_rng)

  def loss_fn(model, x, y):
    """Loss function used for training."""
    with nn.stochastic(dropout_rng):
      logits = model(x, train=True)
    loss, weight_sum = train_utils.compute_weighted_cross_entropy(
        logits, y, num_classes=CLASS_MAP[FLAGS.task_name], weights=None)
    mean_loss = loss / weight_sum
    return mean_loss, logits

  step = optimizer.state.step
  lr = learning_rate_fn(step)
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

  # Accumulate gradients over microbatches, then average across devices.
  _, grad = accumulate_gradient(grad_fn, optimizer.target, inputs, targets, accum_steps)
  grad = jax.tree_map(lambda x: jax.lax.pmean(x, axis_name='batch'), grad)
  # Recompute logits for metrics in one forward pass (the fori_loop carry
  # cannot grow, so logits are not stacked during accumulation).
  logits = optimizer.target(inputs, train=False)
  # To save memory, compute logits one example at a time instead:
  # logits = optimizer.target(inputs[0][jnp.newaxis, ...], train=False)
  # for i in range(1, inputs.shape[0]):
  #   y_hat = optimizer.target(inputs[i][jnp.newaxis, ...], train=False)
  #   logits = jnp.concatenate((logits, y_hat), axis=0)

  new_optimizer = optimizer.apply_gradient(grad, learning_rate=lr)
  metrics = compute_metrics(logits, targets, None)
  metrics['learning_rate'] = lr

  return new_optimizer, metrics, new_dropout_rng

def accumulate_gradient(loss_and_grad_fn, params, inputs, labels, accum_steps):
  """Accumulate gradient over multiple steps to save on memory."""
  if accum_steps and accum_steps > 1:
    assert inputs.shape[0] % accum_steps == 0, (
        f'Bad accum_steps {accum_steps} for batch size {inputs.shape[0]}')
    step_size = inputs.shape[0] // accum_steps
    # The first microbatch initializes the (loss, grad) carry.
    (l, _), g = loss_and_grad_fn(params, inputs[:step_size], labels[:step_size])

    def acc_grad_and_loss(i, l_and_g):
      inps = jax.lax.dynamic_slice(inputs, (i * step_size, 0),
                                   (step_size,) + inputs.shape[1:])
      # Start index on the singleton axis must be 0 (the original 1 only
      # worked because dynamic_slice clamps out-of-bounds starts).
      lbls = jax.lax.dynamic_slice(labels[..., jnp.newaxis], (i * step_size, 0),
                                   (step_size, 1)).squeeze(axis=-1)
      (li, _), gi = loss_and_grad_fn(params, inps, lbls)
      l, g = l_and_g
      return l + li, jax.tree_multimap(lambda x, y: x + y, g, gi)

    l, g = jax.lax.fori_loop(1, accum_steps, acc_grad_and_loss, (l, g))
    l, g = jax.tree_map(lambda x: x / accum_steps, (l, g))
    return l, g

  else:
    return loss_and_grad_fn(params, inputs, labels)
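As an alternative, the fori_loop shape constraint can also be sidestepped by pre-allocating a fixed-shape logits buffer and writing each microbatch's logits into it with jax.lax.dynamic_update_slice, so the carry never changes shape and the extra forward pass is avoided. A rough, untested sketch (num_classes and 2-D inputs are assumptions; note these are training-mode logits from loss_fn's aux output, with dropout active):

def accumulate_gradient_and_logits(loss_and_grad_fn, params, inputs, labels,
                                   accum_steps, num_classes):
  """Like accumulate_gradient, but also collects per-example logits."""
  step_size = inputs.shape[0] // accum_steps
  # Fixed-shape buffer keeps the fori_loop carry shape constant.
  logits_buf = jnp.zeros((inputs.shape[0], num_classes), dtype=jnp.float32)

  (l, logits0), g = loss_and_grad_fn(params, inputs[:step_size], labels[:step_size])
  logits_buf = jax.lax.dynamic_update_slice(logits_buf, logits0, (0, 0))

  def body(i, carry):
    l, g, buf = carry
    inps = jax.lax.dynamic_slice(inputs, (i * step_size, 0),
                                 (step_size,) + inputs.shape[1:])
    lbls = jax.lax.dynamic_slice(labels, (i * step_size,), (step_size,))
    (li, logits_i), gi = loss_and_grad_fn(params, inps, lbls)
    # Write this microbatch's logits into its slot in the buffer.
    buf = jax.lax.dynamic_update_slice(buf, logits_i, (i * step_size, 0))
    return l + li, jax.tree_multimap(lambda x, y: x + y, g, gi), buf

  l, g, logits_buf = jax.lax.fori_loop(1, accum_steps, body, (l, g, logits_buf))
  l, g = jax.tree_map(lambda x: x / accum_steps, (l, g))
  return (l, logits_buf), g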

mtreviso avatar Dec 13 '20 02:12 mtreviso

Hi! I got the following results on the test set using a single GPU (24GB) and setting accum_steps=batch_size. All hyperparameters were kept intact; the only thing that changed in the training procedure was the gradient accumulation part.

              1 GPU        TPU (paper)     Chance (baseline)
ListOps       0.1830       0.3637          0.10
Doc Class.    0.6323       0.6427          0.50
Retrieval     0.4752       0.5746          0.50

@vanzytay @MostafaDehghani Any idea why the 1-GPU results fall so far short of the paper's?

Best,

mtreviso avatar Dec 14 '20 21:12 mtreviso

I'm also running into memory issues. I've given up on the vanilla Transformer (this is a benchmark for efficient Transformers, after all), but even for the Performer, I need 2× Tesla V100 (32GB each).

Do you think it's possible to reproduce your results with, say, a batch size of 16 or 8 (and without changing the code)?

cifkao avatar Jan 26 '21 10:01 cifkao

In Table 2 you give some insights into the 'peak memory usage' per device with a batch size of 32. Do you mean an effective batch size of 32, or a batch size of 32 per device?

Can I expect similar memory consumption on a single GPU with a batch size of 32, or with a batch size of 2?

GregorKobsik avatar Mar 05 '21 20:03 GregorKobsik

Hi, I've also hit the OOM problem with a 32GB V100 card, so I really need the gradient accumulation from @mtreviso's implementation above. However, the state variable is missing there; below is the original loss_fn code:

def loss_fn(model, inputs, targets):
  with nn.stateful(state) as new_state:
    with nn.stochastic(dropout_rng):
      logits = model(inputs, train=True)
  ...
  return mean_loss, (new_state, logits)

The returned new_state is then fed into the next train_step by the training loop:

  for step, batch in zip(range(start_step, num_train_steps), train_iter):
    batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
    optimizer, state, metrics, dropout_rngs = p_train_step(
        optimizer, state, batch, dropout_rng=dropout_rngs)

Would simply deleting this variable, as in your implementation, cause any problems during training?

La-SilverLand avatar Aug 12 '21 08:08 La-SilverLand

I'd also like to second @La-SilverLand's question. Currently I'm trying to fit the Pathfinder model code onto a V100 GPU, and you have provided all the tools for that except an answer about nn.stateful. I'm very new to JAX, so I can't tell whether removing it will cripple the training process.

vladyorsh avatar Oct 07 '21 11:10 vladyorsh

Sorry for the delay in my reply to this issue. @EternalSorrrow, as long as you don't have anything in your model that requires keeping global statistics (like BatchNorm), you can just delete the usage of state and nn.stateful and it should be all good.

If you need a ResNet baseline that has BatchNorm, I recommend using the version with GroupNorm instead, to avoid the complication of handling batch statistics when using gradient accumulation.
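Concretely, dropping the state handling reduces the original loss_fn to something like this (a sketch mirroring @mtreviso's adaptation above; the surrounding train_step and train loop must also stop passing state around):

def loss_fn(model, inputs, targets):
  # No nn.stateful wrapper: valid as long as the model keeps no batch
  # statistics (LayerNorm is fine, BatchNorm is not).
  with nn.stochastic(dropout_rng):
    logits = model(inputs, train=True)
  loss, weight_sum = train_utils.compute_weighted_cross_entropy(
      logits, targets, num_classes=CLASS_MAP[FLAGS.task_name], weights=None)
  return loss / weight_sum, logits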

MostafaDehghani avatar Oct 07 '21 11:10 MostafaDehghani

Thanks for the response. It seems that in this case the Transformer implementations in the repo (at least most of them) should be fine, since LayerNorm doesn't use batch-wise statistics.

vladyorsh avatar Oct 07 '21 19:10 vladyorsh