
TF2 DeBERTaV2 runs super slow on TPUs

Open · WissamAntoun opened this issue 2 years ago · 33 comments

System Info

Latest version of transformers, Colab TPU, TensorFlow 2

Who can help?

@kamalkraj @Rocketknight1 @BigBird01

Information

  • [ ] The official example scripts
  • [X] My own modified scripts

Tasks

  • [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [X] My own task or dataset (give details below)

Reproduction

It's currently hard to share the code and grant access to the Google bucket, but I believe any TF2 DeBERTaV2 code running on TPUs will have this issue.

Expected behavior

I've been trying to train a DeBERTa-v3 model on GPUs and TPUs. I got it to work on multi-node, multi-GPU setups using NVIDIA's Deep Learning Examples library (https://github.com/NVIDIA/DeepLearningExamples/blob/master/TensorFlow2/LanguageModeling/). I basically used the training setup and loop from the BERT code, the dataset utilities from the ELECTRA code, and the model from Hugging Face transformers, with some changes in order to share embeddings.

On 6x A40 45 GB GPUs I get around 1370 sentences per second during training (which is lower than what NVIDIA gets for ELECTRA, but it's fine).

OK, now the problem: on TPU I get 20 sentences per second.

I traced the issue back to the tf.gather function here https://github.com/huggingface/transformers/blob/main/src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py#L525

I ran TPU profiling, and this is the output: [profiler screenshot]

GatherV2 takes most of the time: [profiler screenshot]

Zoomed-in view of the fast ops: [profiler screenshot]

Also, I'm not sure this is TPU-specific, since on GPUs training is also ~30% slower than regular ELECTRA.

WissamAntoun, Jul 21 '22

Hi @WissamAntoun, this is an interesting issue! I honestly have no idea what the cause could be, but the fact that it highlights that function is interesting. The reason is that the DeBERTa code was ported from PyTorch, and so we wrote our own implementation of take_along_axis because TF didn't have one. One thing to try would be to edit the code to use tf.experimental.numpy.take_along_axis instead of that function. If that doesn't work then we might have to see if we can do things in a different, more performant way.
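A minimal sketch of that suggestion, assuming the custom helper only ever gathers along the last axis (as DeBERTa's relative-position lookups do). `take_along_last_axis` is a hypothetical name; the check compares the result against NumPy's `np.take_along_axis` on small random tensors.

```python
import numpy as np
import tensorflow as tf
import tensorflow.experimental.numpy as tnp


def take_along_last_axis(x, indices):
    """Hypothetical drop-in: out[..., p] = x[..., indices[..., p]]."""
    return tnp.take_along_axis(x, indices, axis=-1)


# Small DeBERTa-like shapes: [batch, query_len, relative_positions]
x = tf.random.uniform([2, 4, 8])
indices = tf.random.uniform([2, 4, 4], minval=0, maxval=8, dtype=tf.int32)

out = take_along_last_axis(x, indices)
# NumPy reference computed on the same values
ref = np.take_along_axis(x.numpy(), indices.numpy(), axis=-1)
print(out.shape)  # (2, 4, 4)
```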

Also, just in case XLA compilation is the issue, have you tried using jit_compile=True in compile() when running DeBERTa on GPU? If that also causes performance degradation then the problem is caused by XLA and not TPUs, and we can investigate from there.
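For reference, a sketch of enabling XLA through Keras `compile()`. The toy model is an assumption standing in for DeBERTa; with a custom training loop (as used later in this thread), the rough equivalent is wrapping the step function in `tf.function(jit_compile=True)`.

```python
import numpy as np
import tensorflow as tf

# Toy stand-in model (assumption); in the issue this would be the DeBERTa pretraining model
model = tf.keras.Sequential([
    tf.keras.Input(shape=(16,)),
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(1),
])

# jit_compile=True asks Keras to compile the training step with XLA
model.compile(optimizer="adam", loss="mse", jit_compile=True)

x = np.random.rand(64, 16).astype("float32")
y = np.random.rand(64, 1).astype("float32")
history = model.fit(x, y, batch_size=16, epochs=1, verbose=0)
```

If the same model is slow or fails only with `jit_compile=True` on GPU, that points at XLA rather than the TPU itself.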

Rocketknight1, Jul 21 '22

Also cc @sanchit-gandhi because I'm not a TPU expert - don't worry about investigating this deeply, but if anything comes to mind when you read it, let me know!

Rocketknight1, Jul 21 '22

@Rocketknight1 I read all the discussions you had with Kamal about torch.gather and take_along_axis.

On GPUs I had already enabled XLA via tf.config.optimizer.set_jit and via TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit", but I read that this isn't the optimal way to do it, so I'm now trying jit_compile=True and will report back.
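A sketch contrasting the two approaches: auto-clustering (`set_jit` / `TF_XLA_FLAGS`) lets TensorFlow choose which ops go to XLA and leaves the rest on regular kernels, while `tf.function(jit_compile=True)` requires the whole function to compile, so XLA problems surface as errors rather than silent fallbacks. `attention_scores` is a hypothetical toy function, not DeBERTa code.

```python
import numpy as np
import tensorflow as tf

# Option 1 (used earlier in this thread): global auto-clustering.
# TF decides which ops to hand to XLA; nothing guarantees the whole step is compiled.
# tf.config.optimizer.set_jit(True)   # or TF_XLA_FLAGS="--tf_xla_auto_jit=2"


# Option 2: explicit compilation of one function. Either all of it
# compiles with XLA, or the call raises an error.
@tf.function(jit_compile=True)
def attention_scores(q, k):
    # Scaled dot-product scores; 8.0 == sqrt(64), the head size used below
    return tf.nn.softmax(tf.matmul(q, k, transpose_b=True) / 8.0, axis=-1)


q = tf.random.normal([2, 4, 64])
k = tf.random.normal([2, 4, 64])

xla_out = attention_scores(q, k)
eager_out = tf.nn.softmax(tf.matmul(q, k, transpose_b=True) / 8.0, axis=-1)
max_diff = np.max(np.abs(xla_out.numpy() - eager_out.numpy()))
```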

Also, I just finished testing tf.experimental.numpy.take_along_axis: on GPUs it improved performance by ~10%, but on TPUs I still have the same issue. I will also test jit_compile on TPUs, but I don't think it will solve anything.

Thanks a lot for the replies and for the effort you put into converting the PyTorch code to TF!

WissamAntoun, Jul 21 '22

Running the training with jit_compile=True on GPU revealed a new bug, so this is now an XLA/JIT issue, not a TPU one:


2022-07-21 23:36:18.107830: W tensorflow/core/framework/op_kernel.cc:1745] OP_REQUIRES failed at bcast_ops.cc:50 : 
INVALID_ARGUMENT: 
Input 0 to node `pretraining_model/tf_deberta_v2_for_masked_lm/deberta/encoder/layer_._0/attention/self/BroadcastArgs`
with op BroadcastArgs must be a compile-time constant.

XLA compilation requires that operator arguments that represent shapes or dimensions be evaluated to concrete values at compile time. 
This error means that a shape or dimension argument could not be evaluated at compile time, 
usually because the value of the argument depends on a parameter to the computation, on a variable, 
or on a stateful operation such as a random number generator.

Stack trace for op definition: 
File "run_pretraining.py", line 204, in <module>
  config = main(start_time)
File "run_pretraining.py", line 184, in main
  trained_model = run_customized_training_loop(
File "/workspaces/nv-deberta-tf2/electra/model_training_utils.py", line 675, in run_customized_training_loop
  train_steps_strategy(
File "/workspaces/nv-deberta-tf2/electra/model_training_utils.py", line 407, in train_steps_strategy
  if num_grad_accumulates != 1:
File "/workspaces/nv-deberta-tf2/electra/model_training_utils.py", line 408, in train_steps_strategy
  for step_idx in tf.range(steps * num_grad_accumulates):
File "/workspaces/nv-deberta-tf2/electra/model_training_utils.py", line 410, in train_steps_strategy
  strategy.run(_forward, args=(next(iterator),))
File "/workspaces/nv-deberta-tf2/electra/model_training_utils.py", line 324, in _forward
  loss, model_outputs = model(inputs, is_training=True)
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
  return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1096, in __call__
  outputs = call_fn(inputs, *args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
  return fn(*args, **kwargs)
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 2491, in call
  if config.uniform_generator:
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 2496, in call
  mlm_output = self._get_masked_lm_output(
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 2541, in _get_masked_lm_output
  if self._config.uniform_generator:
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 2550, in _get_masked_lm_output
  outputs = generator(
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
  return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1096, in __call__
  outputs = call_fn(inputs, *args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
  return fn(*args, **kwargs)
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_utils.py", line 1872, in run_call_with_unpacked_inputs
  )
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 1880, in call
  outputs = self.deberta(
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
  return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1096, in __call__
  outputs = call_fn(inputs, *args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
  return fn(*args, **kwargs)
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_utils.py", line 1872, in run_call_with_unpacked_inputs
  )
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 1617, in call
  encoder_outputs = self.encoder(
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
  return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1096, in __call__
  outputs = call_fn(inputs, *args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
  return fn(*args, **kwargs)
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 527, in call
  for i, layer_module in enumerate(self.layer):
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 532, in call
  layer_outputs = layer_module(
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
  return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1096, in __call__
  outputs = call_fn(inputs, *args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
  return fn(*args, **kwargs)
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 317, in call
  attention_outputs = self.attention(
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
  return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1096, in __call__
  outputs = call_fn(inputs, *args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
  return fn(*args, **kwargs)
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 226, in call
  self_outputs = self.self(
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
  return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1096, in __call__
  outputs = call_fn(inputs, *args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
  return fn(*args, **kwargs)
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 876, in call
  if self.relative_attention:
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 878, in call
  rel_att = self.disentangled_att_bias(
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 991, in disentangled_att_bias
  if "c2p" in self.pos_att_type:
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 1012, in disentangled_att_bias
  c2p_att = tnp.take_along_axis(

2022-07-21 23:36:18.184105: W tensorflow/core/framework/op_kernel.cc:1745] OP_REQUIRES failed at xla_ops.cc:248 : 
INVALID_ARGUMENT: 
Input 0 to node `pretraining_model/tf_deberta_v2_for_masked_lm/deberta/encoder/layer_._0/attention/self/BroadcastArgs` 
with op BroadcastArgs must be a compile-time constant.

XLA compilation requires that operator arguments that represent shapes or dimensions be evaluated to concrete values at compile time. 
This error means that a shape or dimension argument could not be evaluated at compile time, 
usually because the value of the argument depends on a parameter to the computation, 
on a variable, or on a stateful operation such as a random number generator.

Stack trace for op definition: 
File "run_pretraining.py", line 204, in <module>
  config = main(start_time)
File "run_pretraining.py", line 184, in main
  trained_model = run_customized_training_loop(
File "/workspaces/nv-deberta-tf2/electra/model_training_utils.py", line 675, in run_customized_training_loop
  train_steps_strategy(
File "/workspaces/nv-deberta-tf2/electra/model_training_utils.py", line 407, in train_steps_strategy
  if num_grad_accumulates != 1:
File "/workspaces/nv-deberta-tf2/electra/model_training_utils.py", line 408, in train_steps_strategy
  for step_idx in tf.range(steps * num_grad_accumulates):
File "/workspaces/nv-deberta-tf2/electra/model_training_utils.py", line 410, in train_steps_strategy
  strategy.run(_forward, args=(next(iterator),))
File "/workspaces/nv-deberta-tf2/electra/model_training_utils.py", line 324, in _forward
  loss, model_outputs = model(inputs, is_training=True)
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
  return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1096, in __call__
  outputs = call_fn(inputs, *args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
  return fn(*args, **kwargs)
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 2491, in call
  if config.uniform_generator:
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 2496, in call
  mlm_output = self._get_masked_lm_output(
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 2541, in _get_masked_lm_output
  if self._config.uniform_generator:
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 2550, in _get_masked_lm_output
  outputs = generator(
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
  return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1096, in __call__
  outputs = call_fn(inputs, *args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
  return fn(*args, **kwargs)
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_utils.py", line 1872, in run_call_with_unpacked_inputs
  )
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 1880, in call
  outputs = self.deberta(
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
  return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1096, in __call__
  outputs = call_fn(inputs, *args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
  return fn(*args, **kwargs)
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_utils.py", line 1872, in run_call_with_unpacked_inputs
  )
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 1617, in call
  encoder_outputs = self.encoder(
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
  return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1096, in __call__
  outputs = call_fn(inputs, *args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
  return fn(*args, **kwargs)
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 527, in call
  for i, layer_module in enumerate(self.layer):
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 532, in call
  layer_outputs = layer_module(
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
  return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1096, in __call__
  outputs = call_fn(inputs, *args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
  return fn(*args, **kwargs)
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 317, in call
  attention_outputs = self.attention(
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
  return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1096, in __call__
  outputs = call_fn(inputs, *args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
  return fn(*args, **kwargs)
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 226, in call
  self_outputs = self.self(
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
  return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1096, in __call__
  outputs = call_fn(inputs, *args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
  return fn(*args, **kwargs)
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 876, in call
  if self.relative_attention:
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 878, in call
  rel_att = self.disentangled_att_bias(
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 991, in disentangled_att_bias
  if "c2p" in self.pos_att_type:
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 1012, in disentangled_att_bias
  c2p_att = tnp.take_along_axis(

         [[{{node pretraining_model/tf_deberta_v2_for_masked_lm/deberta/encoder/layer_._0/attention/self/BroadcastArgs}}]]
Traceback (most recent call last):
  File "run_pretraining.py", line 204, in <module>
    config = main(start_time)
  File "run_pretraining.py", line 184, in main
    trained_model = run_customized_training_loop(
  File "/workspaces/nv-deberta-tf2/electra/model_training_utils.py", line 675, in run_customized_training_loop
    train_steps_strategy(
  File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/eager/execute.py", line 54, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution error:

Input 0 to node `pretraining_model/tf_deberta_v2_for_masked_lm/deberta/encoder/layer_._0/attention/self/BroadcastArgs` 
with op BroadcastArgs must be a compile-time constant.

XLA compilation requires that operator arguments that represent shapes or dimensions be evaluated to concrete values at compile time. 
This error means that a shape or dimension argument could not be evaluated at compile time, 
usually because the value of the argument depends on a parameter to the computation, 
on a variable, or on a stateful operation such as a random number generator.

Stack trace for op definition: 
File "run_pretraining.py", line 204, in <module>
  config = main(start_time)
File "run_pretraining.py", line 184, in main
  trained_model = run_customized_training_loop(
File "/workspaces/nv-deberta-tf2/electra/model_training_utils.py", line 675, in run_customized_training_loop
  train_steps_strategy(
File "/workspaces/nv-deberta-tf2/electra/model_training_utils.py", line 407, in train_steps_strategy
  if num_grad_accumulates != 1:
File "/workspaces/nv-deberta-tf2/electra/model_training_utils.py", line 408, in train_steps_strategy
  for step_idx in tf.range(steps * num_grad_accumulates):
File "/workspaces/nv-deberta-tf2/electra/model_training_utils.py", line 410, in train_steps_strategy
  strategy.run(_forward, args=(next(iterator),))
File "/workspaces/nv-deberta-tf2/electra/model_training_utils.py", line 324, in _forward
  loss, model_outputs = model(inputs, is_training=True)
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
  return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1096, in __call__
  outputs = call_fn(inputs, *args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
  return fn(*args, **kwargs)
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 2491, in call
  if config.uniform_generator:
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 2496, in call
  mlm_output = self._get_masked_lm_output(
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 2541, in _get_masked_lm_output
  if self._config.uniform_generator:
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 2550, in _get_masked_lm_output
  outputs = generator(
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
  return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1096, in __call__
  outputs = call_fn(inputs, *args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
  return fn(*args, **kwargs)
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_utils.py", line 1872, in run_call_with_unpacked_inputs
  )
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 1880, in call
  outputs = self.deberta(
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
  return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1096, in __call__
  outputs = call_fn(inputs, *args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
  return fn(*args, **kwargs)
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_utils.py", line 1872, in run_call_with_unpacked_inputs
  )
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 1617, in call
  encoder_outputs = self.encoder(
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
  return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1096, in __call__
  outputs = call_fn(inputs, *args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
  return fn(*args, **kwargs)
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 527, in call
  for i, layer_module in enumerate(self.layer):
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 532, in call
  layer_outputs = layer_module(
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
  return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1096, in __call__
  outputs = call_fn(inputs, *args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
  return fn(*args, **kwargs)
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 317, in call
  attention_outputs = self.attention(
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
  return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1096, in __call__
  outputs = call_fn(inputs, *args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
  return fn(*args, **kwargs)
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 226, in call
  self_outputs = self.self(
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
  return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1096, in __call__
  outputs = call_fn(inputs, *args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
  return fn(*args, **kwargs)
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 876, in call
  if self.relative_attention:
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 878, in call
  rel_att = self.disentangled_att_bias(
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 991, in disentangled_att_bias
  if "c2p" in self.pos_att_type:
File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 1012, in disentangled_att_bias
  c2p_att = tnp.take_along_axis(

         [[{{node pretraining_model/tf_deberta_v2_for_masked_lm/deberta/encoder/layer_._0/attention/self/BroadcastArgs}}]]
         [[while/body/_1/while/StatefulPartitionedCall]] [Op:__inference_train_steps_strategy_177980]

WissamAntoun, Jul 22 '22

@WissamAntoun Confirmed reproduction of the issue here. Our TF DeBERTa implementation seems to have issues with XLA - I'm investigating now.

Rocketknight1, Jul 22 '22

@WissamAntoun We have a potential fix - I've confirmed that I can compile microsoft/deberta-v3-small with XLA on my local machine. Can you try installing this branch and let me know if this fixes the problem for you? You can use pip install git+https://github.com/huggingface/transformers.git@deberta-xla-fixes

Rocketknight1, Jul 22 '22

I can confirm it works on GPUs with XLA, and I got a ~20% speedup. I'm still testing on TPUs and will let you know ASAP.

WissamAntoun, Jul 22 '22

Weirdly enough, TPUs didn't seem to care about the changes 😅, even after we removed all the if branches.

WissamAntoun, Jul 22 '22

Hmm. Can you check that you don't get the slowdown if you switch the model to another model, like BERT or ELECTRA, while keeping all of the other code the same (especially data loading)? I know the profiling indicates that the GatherV2 is the problem, but I'm a little suspicious!

Rocketknight1, Jul 22 '22

I tried disabling relative_attention in DeBERTa, which makes the model a regular BERT, and the performance improved 40x 😅

WissamAntoun, Jul 22 '22

@WissamAntoun So the issue really is in that gather! That's extremely interesting - with the simplified code, it's just a single call to tf.gather, but perhaps the batch_dims argument is not handled elegantly on TPU, or XLA converts it in a way that doesn't run well on TPU.
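To make the `batch_dims` semantics concrete: with `batch_dims=2`, `tf.gather` computes `out[b, s, p] = x[b, s, idx[b, s, p]]`. The sketch below checks this against explicit Python loops on small, hypothetical shapes.

```python
import numpy as np
import tensorflow as tf

B, S, D, P = 2, 3, 5, 4  # hypothetical small sizes
x = tf.random.uniform([B, S, D])
idx = tf.random.uniform([B, S, P], minval=0, maxval=D, dtype=tf.int32)

# The first two dims are treated as batch dims; gathering happens along axis 2
gathered = tf.gather(x, idx, batch_dims=2)  # shape [B, S, P]

# Reference: the same lookup written as explicit loops
x_np, idx_np = x.numpy(), idx.numpy()
ref = np.empty([B, S, P], dtype=x_np.dtype)
for b in range(B):
    for s in range(S):
        for p in range(P):
            ref[b, s, p] = x_np[b, s, idx_np[b, s, p]]
```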

Is it possible that some kind of memory spill is occurring? Can you try lowering your batch size and increasing steps_per_execution?
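A sketch of both knobs in Keras terms, assuming `compile()`/`fit()`; the toy model and sizes are hypothetical. In a custom loop like the one in this thread, the analogue of `steps_per_execution` is running several steps inside one `tf.function`, as `train_steps_strategy` does with `tf.range`.

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.Input(shape=(16,)), tf.keras.layers.Dense(1)])

# steps_per_execution runs several training steps per compiled call,
# which reduces host-to-device dispatch overhead (mostly relevant on TPU)
model.compile(optimizer="sgd", loss="mse", steps_per_execution=8)

x = np.random.rand(256, 16).astype("float32")
y = np.random.rand(256, 1).astype("float32")

# A smaller per-step batch (here 8) lowers peak memory per step; 256 / 8 = 32 steps
history = model.fit(x, y, batch_size=8, epochs=1, verbose=0)
```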

If that isn't it, then I have no idea - maybe there's some way to rewrite the gather, but I don't really know what to try!

Rocketknight1, Jul 22 '22

@Rocketknight1 I tried your suggestions without any success, sadly!

Then I tried replacing the whole take_along_axis function with tf.gather(x, indices, batch_dims=2), which is equivalent according to this test I made. GPU still runs fine; TPU still has the same issue 😔.

I also ran out of ideas to try, now I'm just waiting for the TPU gods 😅

```python
#%%
import tensorflow as tf

#%%
x_shape = [32, 128, 512]
indices_shape = [32, 128, 128]
x = tf.random.uniform(shape=x_shape)
indices = tf.random.uniform(shape=indices_shape, minval=1, maxval=128, dtype=tf.int32)
#%%
flat_x = tf.reshape(x, (-1, x_shape[-1]))
print(flat_x.shape)  # (4096, 512)
flat_indices = tf.reshape(indices, (-1, indices_shape[-1]))
print(flat_indices.shape)  # (4096, 128)

#%%
gathered = tf.gather(
    params=flat_x, indices=flat_indices, batch_dims=1, validate_indices=None
)
print(gathered.shape)  # (4096, 128)
gathered_reshaped = tf.reshape(gathered, indices.shape)
print(gathered_reshaped.shape)  # (32, 128, 128)

# %%
gathered2 = tf.gather(params=x, indices=indices, batch_dims=2, validate_indices=None)
print(gathered2.shape)  # (32, 128, 128)
# %%
tf.assert_equal(gathered2, gathered_reshaped)  # passes
```

WissamAntoun, Jul 23 '22

I'm clueless in that case - @patrickvonplaten @sanchit-gandhi do you have any idea why a gather or take_along_axis op which is performant on GPU and compiles with XLA would become a huge bottleneck on TPU?

Rocketknight1, Jul 25 '22

In our JAX BLOOM experiments, we saw significant performance improvements by changing how we indexed. Swapping scatter ops for one-hot broadcasts, we obtained 3-4x speed-ups in practice. The logic is largely lifted from T5X: https://github.com/google-research/t5x/blob/63d9addf628c6d8c547a407a32095fcb527bb20b/t5x/examples/scalable_t5/layers.py#L280-L284

I wonder if applying similar logic here and swapping the gather op to one-hot indexing might help?

sanchit-gandhi, Jul 25 '22

Do you mean something like BERT's one-hot embeddings? https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/on_device_embedding.py#L79
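The idea behind the linked layer's one-hot option can be sketched as follows: an embedding lookup written either as `tf.gather` or as a one-hot matrix multiplied by the table. Both give the same result; the matmul form trades extra memory for a dense op that may run faster on TPU, especially for small vocabularies. Sizes and names here are hypothetical.

```python
import numpy as np
import tensorflow as tf

vocab_size, hidden = 10, 4  # hypothetical sizes
table = tf.random.normal([vocab_size, hidden])
ids = tf.constant([[1, 5, 9], [0, 0, 3]])  # [batch, seq]

# Gather-based lookup
via_gather = tf.gather(table, ids)  # [2, 3, 4]

# One-hot lookup: flatten ids, build [N, vocab] one-hot rows, multiply by the table
flat_ids = tf.reshape(ids, [-1])
one_hot = tf.one_hot(flat_ids, depth=vocab_size, dtype=table.dtype)  # [6, 10]
via_one_hot = tf.reshape(tf.matmul(one_hot, table), [2, 3, hidden])  # [2, 3, 4]
```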

WissamAntoun, Jul 25 '22

Simply modify the bottleneck function: https://github.com/huggingface/transformers/blob/f4e172716b91b477ce3cddc9a253094b7121a4b8/src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py#L525 to use one_hot encodings as opposed to a gather op. The example you've linked looks like the right idea! Worth a try IMO!

sanchit-gandhi, Jul 25 '22

I tried this, although I'm not sure it's the best implementation:

```python
import tensorflow as tf

def take_along_axis(x, indices):
    one_hot_indices = tf.one_hot(indices, depth=x.shape[-1], dtype=x.dtype)  # [B, S, P, D] => [B, 128, 128, 512]

    # [B, S, P, D] . [B, S, D, 1] = [B, S, P, 1]
    gathered = tf.squeeze(tf.matmul(one_hot_indices, tf.expand_dims(x, axis=-1)), axis=-1)
    return gathered
```

It improved the speed from 20 seq/s to 110 seq/s. For reference, regular ELECTRA/BERT got ~800 seq/s.
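A rough, worked estimate of what the one-hot trick costs in memory, assuming the shapes from the earlier test snippet (B=32, S=128, P=128, D=512), float32, and that XLA materializes the one-hot tensor rather than fusing it away:

```python
# Shapes assumed from the earlier test snippet: x is [B, S, D], indices are [B, S, P]
B, S, P, D = 32, 128, 128, 512
BYTES_PER_FLOAT32 = 4

# The one-hot tensor has shape [B, S, P, D]
one_hot_elements = B * S * P * D
one_hot_bytes = one_hot_elements * BYTES_PER_FLOAT32

print(one_hot_elements)       # 268435456
print(one_hot_bytes / 2**30)  # 1.0  (GiB per take_along_axis call)
```

Under these assumptions, each call would build about 1 GiB of mostly-zero data, per attention layer and per c2p/p2c term.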

Now it's the reshape and squeeze operations that are "wasting" time:

[profiler screenshot]

WissamAntoun, Jul 25 '22

@sanchit-gandhi is there a better implementation than mine, without expand_dims or squeeze, since these are unfavorable operations on TPUs?

WissamAntoun, Jul 26 '22

Nice! A 5x speed up is a good start. If we can get another 5x we'll be in business. Thanks for linking the Tensorboard profile! Super helpful in identifying bottlenecks like these 🙏

Interesting to see the expand_dims and squeeze are now accruing large amounts of runtime. I'm not a TF user (it's mainly JAX on TPU for me!), so I'm not up to speed with implementation details, but my impression from the profile is that the shapes are unfavourable for XLA. Perhaps you could have a play around and see whether changing the tensor shapes / choice of TF ops have any effect? It's been the case for me in the past that using tensors of different shape can give big speed-ups. Is there a repo you could reference for XLA optimised TF code? For JAX, we usually look to the T5X repo when deciding on tensor shapes and trying out 'hacks' like these: https://github.com/google-research/t5x/tree/main/t5x
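One possible reading of the profile, offered only as a hypothesis: TPUs lay out the last two dimensions of a tensor in tiles (8 rows by 128 lanes for float32, per the Cloud TPU performance guide), so `tf.expand_dims(x, axis=-1)` creates a trailing dimension of size 1 that may be padded to 128. The helper below is a simplified model of that padding, not XLA's actual layout logic, which may pick other layouts.

```python
import math


def tpu_padded_elements(shape, sublane=8, lane=128):
    """Element count after rounding the last two dims up to (8, 128) tiles.

    Assumption: simplified model of TPU float32 layout; real XLA layouts may differ.
    """
    *lead, rows, cols = shape
    padded_rows = math.ceil(rows / sublane) * sublane
    padded_cols = math.ceil(cols / lane) * lane
    return math.prod(lead) * padded_rows * padded_cols


# x as in the thread, and x after tf.expand_dims(x, axis=-1)
for shape in ([32, 128, 512], [32, 128, 512, 1]):
    overhead = tpu_padded_elements(shape) / math.prod(shape)
    print(shape, overhead)
# [32, 128, 512] 1.0
# [32, 128, 512, 1] 128.0
```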

cc @Rocketknight1 who's more up to speed in the TF sphere!

sanchit-gandhi, Jul 26 '22

Hey @WissamAntoun! Any luck with this? Maybe also worth trying https://www.tensorflow.org/api_docs/python/tf/experimental/numpy/take_along_axis

sanchit-gandhi, Jul 28 '22

Hey @sanchit-gandhi, I have already tried the experimental NumPy function, with no improvement at all compared to gather with batch_dims=2.

I also tried going up to a sequence length of 512: I got the exact same speedup, but it is still much slower than expected (around 20 seq/s at sequence length 512). I also tried different batch sizes, with no effect at all.

WissamAntoun, Jul 28 '22

Okay, probably worth sticking with the one-hot encoding hack then; it seems the most promising! I'm not a TF user, so I can't comment on the exact implementation changes you could make to the expand_dims or squeeze ops. Perhaps @gante could take a look here with his experience using TF and XLA?

sanchit-gandhi, Jul 29 '22

> Now it's the reshape and squeeze operations that are "wasting" time

Interesting -- I spent some time with TPU profiling on a different application (TF text generation with a myriad of models), and found that those two operations were part of the bottleneck (along with XLA's dynamic_update_slice). They accounted for 50-70% of the execution time. Do you know if they are also a bottleneck for Flax, @sanchit-gandhi (e.g. the cache updates here)?

gante, Aug 02 '22

For JAX BLOOM we couldn't even compile the 176B-parameter model with the naive implementation of concatenate_to_cache, let alone benchmark which operations consumed the bulk of the execution time! We swapped it for this more efficient implementation (with one-hot encodings etc.): https://github.com/huggingface/bloom-jax-inference/blob/2a04aa519d262729d54adef3d19d63879f81ea89/bloom_inference/modeling_bloom/modeling_bloom.py#L119 Coincidentally, we've just run the JAX profiler on this implementation and are going through the trace with some of the Google JAX guys later today. Will report back on how performance fares!
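A minimal sketch of the general one-hot cache-update idea in TF, not the linked BLOOM implementation: instead of a dynamic slice update at `position`, build a one-hot mask over the time axis and blend the new entry in with broadcast multiplies. Names and sizes are hypothetical.

```python
import numpy as np
import tensorflow as tf

B, T, H = 2, 6, 4  # hypothetical batch, max cache length, hidden size
cache = tf.zeros([B, T, H])
new_entry = tf.ones([B, 1, H])  # one new time step per sequence
position = tf.constant(3)       # decoding step to write

# Mask of shape [1, T, 1] that is 1.0 only at `position`
mask = tf.one_hot(position, depth=T, dtype=cache.dtype)[None, :, None]

# Broadcast blend instead of a dynamic slice update: [B, T, H]
updated = cache * (1.0 - mask) + new_entry * mask
```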

sanchit-gandhi, Aug 02 '22

```python
def take_along_axis(x, indices):

    one_hot_indices = tf.one_hot(indices, depth=x.shape[-1], dtype=x.dtype)  # [B, S, P, D] => [B, 128, 128, 512]

    # [B, S, P, D] . [B, S, D, 1] = [B, S, P, 1]
    gathered = tf.squeeze(tf.matmul(one_hot_indices, tf.expand_dims(x, axis=-1)), axis=-1)
    return gathered
```

@gante Do you think the one-hot trick can be done without the expand_dims and squeeze? Maybe then we can dodge the whole problem.

WissamAntoun, Aug 02 '22

@sanchit-gandhi that's interesting! I'd be interested in knowing the pro tips for XLA (which should also apply to TF)

@WissamAntoun Yeah, we can rework it with tf.einsum magic, assuming the operation can be rewritten with Einstein notation -- in this case, it is possible! Check the implementation below, give it a try, and let us know if it helped with speed on a TPU (my debug runs confirmed that they are numerically equivalent)

```python
import tensorflow as tf

def take_along_axis(x, indices):
    # [B, S, P] -> [B, S, P, D]
    one_hot_indices = tf.one_hot(indices, depth=x.shape[-1], dtype=x.dtype)

    # if we ignore the first two dims, this is equivalent to multiplying a matrix (one hot) by a vector (x)
    # grossly abusing notation: [B, S, P, D] . [B, S, D] = [B, S, P]
    gathered = tf.einsum('ijkl,ijl->ijk', one_hot_indices, x)

    return gathered
```
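A sketch checking that the three variants discussed in this thread agree numerically (gather with `batch_dims=2`, one-hot plus `matmul`, one-hot plus `einsum`), on small random tensors whose shapes are assumptions. Equal results say nothing about TPU speed, which is the open question here.

```python
import numpy as np
import tensorflow as tf


def take_along_axis_gather(x, indices):
    return tf.gather(x, indices, batch_dims=2)


def take_along_axis_matmul(x, indices):
    one_hot = tf.one_hot(indices, depth=x.shape[-1], dtype=x.dtype)  # [B, S, P, D]
    return tf.squeeze(tf.matmul(one_hot, tf.expand_dims(x, -1)), -1)  # [B, S, P]


def take_along_axis_einsum(x, indices):
    one_hot = tf.one_hot(indices, depth=x.shape[-1], dtype=x.dtype)  # [B, S, P, D]
    return tf.einsum("ijkl,ijl->ijk", one_hot, x)  # [B, S, P]


# Hypothetical small shapes: x is [B, S, D], indices are [B, S, P]
x = tf.random.uniform([4, 16, 64])
indices = tf.random.uniform([4, 16, 16], minval=0, maxval=64, dtype=tf.int32)

variants = (take_along_axis_gather, take_along_axis_matmul, take_along_axis_einsum)
results = [f(x, indices).numpy() for f in variants]
```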

gante, Aug 03 '22

@gante I tested the tf.einsum implementation. It gave me the same performance as the one_hot trick, which is about ~120 seq/second. I tried it with different batch sizes but still it didn't change much.

This is a screenshot of the profiler: [Screenshot 2022-08-03 155826]

WissamAntoun, Aug 04 '22

I'm out of suggestions :( I suspect this is a good question for Google's XLA and TPU teams -- the problem is probably at a compiler/hardware level.

gante, Aug 04 '22

Yeah this is a weird and unexpected bug. Do you know someone we can get in contact with from Google's XLA or TPU team?

And thanks a lot for the efforts you guys put into this issue!

WissamAntoun, Aug 04 '22

@sanchit-gandhi do you know a good point of contact for TPU problems?

gante, Aug 04 '22

Ping @JackCaoG for help :)

stefan-it, Aug 04 '22

Thanks, I will try to take a look or find someone from my team to help.

Nvm, this is TF2, I only know PT/XLA lol

JackCaoG, Aug 05 '22

> @sanchit-gandhi do you know a good point of contact for TPU problems?

Only for JAX on TPU, I'll ask around and see if there is anyone who can help with TF!

sanchit-gandhi, Aug 08 '22

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot], Sep 02 '22