
[BERT/TF2] Possible extra optimization step in bert pretraining

WissamAntoun opened this issue 2 years ago · 2 comments

Related to BERT/TF2

Describe the bug: https://github.com/NVIDIA/DeepLearningExamples/blob/master/TensorFlow2/LanguageModeling/BERT/official/modeling/model_training_utils.py#L342 and line 357:

```python
if num_grad_accumulates != 1:
  for _ in tf.range(steps * num_grad_accumulates):
    strategy.experimental_run_v2(forward, args=(next(iterator),))
    # Fires on the very first iteration (_ == 0) *and* at the end of each
    # accumulation window, so the first window produces two optimizer steps.
    if _ == 0 or (_ + 1) % num_grad_accumulates == 0:
      strategy.experimental_run_v2(step, args=(num_grad_accumulates,))
else:
  for _ in tf.range(steps):
    strategy.experimental_run_v2(_replicated_step, args=(next(iterator),))
```

The optimizer will run a step as soon as the first batch finishes, ignoring the gradient accumulator. This probably wouldn't have been an issue if the +1 weren't there, since the step would then apply the gradients accumulated over the previous loop. As it stands, if steps equals 100, optimizer.iterations increases by 101 while current_step increases by 100. This causes a mismatch when training is resumed, since current_step is initialized from optimizer.iterations.

Hence, I'm not sure what the purpose of the `_ == 0` check is here.
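To make the off-by-one concrete, here is a minimal, framework-free sketch of the counting behavior (the values of steps and num_grad_accumulates are illustrative only, and `optimizer_iterations` stands in for `optimizer.iterations`):

```python
# Count how many times the optimizer step fires in the accumulation loop.
steps = 100
num_grad_accumulates = 4

optimizer_iterations = 0
for i in range(steps * num_grad_accumulates):
    # forward/accumulate happens on every iteration (omitted here)
    if i == 0 or (i + 1) % num_grad_accumulates == 0:
        optimizer_iterations += 1  # optimizer step fires

current_step = steps
print(optimizer_iterations, current_step)  # -> 101 100: off by one
```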

To Reproduce: run the official pretraining script and check the values of optimizer.iterations and current_step after a few training steps.

Environment: Any

Who can help: @nv-kkudrynski

WissamAntoun · Aug 02 '22 14:08

The purpose of running a step on the first batch is for Horovod. According to Horovod's examples, we need to broadcast variables after the first batch to initialize them, and the broadcast needs to be done after the first gradient step. As you mentioned, optimizer.iterations and current_step may be mismatched in this case. We will check and fix this issue.
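For context, a sketch of the Horovod pattern being referenced, based on Horovod's TF2 examples rather than the exact DeepLearningExamples code (`model`, `optimizer`, `loss_fn`, and `first_batch` are assumed inputs here):

```python
import tensorflow as tf
import horovod.tensorflow as hvd

def train_step(model, optimizer, loss_fn, inputs, labels, first_batch):
    with tf.GradientTape() as tape:
        loss = loss_fn(labels, model(inputs, training=True))
    # Average gradients across workers.
    tape = hvd.DistributedGradientTape(tape)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    # Broadcast initial state from rank 0 after the first apply_gradients,
    # so the optimizer's slot variables exist before they are synchronized.
    if first_batch:
        hvd.broadcast_variables(model.variables, root_rank=0)
        hvd.broadcast_variables(optimizer.variables(), root_rank=0)
    return loss
```

This is why the extra step exists at all: the broadcast cannot run until the optimizer has created its slot variables, which only happens on the first gradient application.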

meatybobby · Aug 03 '22 20:08

I think the best fix would be to take the single extra step only when first_batch is true. I'll write a pull request ASAP with my suggested fix; a sketch of the idea follows.
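A hypothetical sketch of that fix: gate the extra first step on a `first_batch` flag (assumed here to be threaded in from the caller, as in Horovod's examples) instead of firing on `_ == 0` in every loop, so the broadcast-triggering step happens only once per training run:

```python
if num_grad_accumulates != 1:
  for _ in tf.range(steps * num_grad_accumulates):
    strategy.experimental_run_v2(forward, args=(next(iterator),))
    # Extra step only on the literal first batch of training, not on the
    # first iteration of every call; otherwise step once per window.
    if (first_batch and _ == 0) or (_ + 1) % num_grad_accumulates == 0:
      strategy.experimental_run_v2(step, args=(num_grad_accumulates,))
else:
  for _ in tf.range(steps):
    strategy.experimental_run_v2(_replicated_step, args=(next(iterator),))
```

Whether current_step also needs a one-time correction for resumed runs would be up to the actual pull request.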

WissamAntoun · Aug 04 '22 07:08