The distributed training example fails to mention batch and LR scaling
Keras.io example: https://keras.io/examples/nlp/data_parallel_training_with_keras_nlp/
Merged PR: https://github.com/keras-team/keras-io/pull/1395
This example is good on the whole but it would be much better with proper batch size and learning rate scaling. Without this, using two accelerators instead of one will not train any faster.
The usual scaling is:
batch_size = strategy.num_replicas_in_sync * single_worker_batch_size
The large global batch is processed across the multiple accelerators in chunks, one chunk per accelerator. Without increasing the batch size, you are sending smaller per-worker batches to the accelerators, potentially under-utilizing them.
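A minimal sketch of what this could look like with `tf.distribute` (the per-worker batch size of 32 and the `Dataset.range` data are just illustrative placeholders, not values from the example):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

# Illustrative per-accelerator batch size; tune for your model and hardware.
single_worker_batch_size = 32
global_batch_size = strategy.num_replicas_in_sync * single_worker_batch_size

# Batch with the *global* size; tf.distribute splits each global batch
# into per-replica chunks, one chunk per accelerator.
dataset = tf.data.Dataset.range(1024).batch(global_batch_size)
```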
lr = strategy.num_replicas_in_sync * single_worker_lr
Bigger batches also mean fewer gradient updates per epoch. Without scaling the LR, the model will learn more slowly on multiple workers. Gradient updates computed on bigger batches need to be allowed to do "more training work", through a higher learning rate.
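And the matching LR scaling, again as a sketch (the `1e-3` baseline LR, the Adam optimizer, and the tiny Dense model are placeholders standing in for whatever the real example builds):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

single_worker_lr = 1e-3  # assumed baseline LR tuned on one accelerator
scaled_lr = strategy.num_replicas_in_sync * single_worker_lr

with strategy.scope():
    # Stand-in model; the KerasNLP classifier from the example
    # would be built here instead.
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=scaled_lr),
        loss="mse",
    )
```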
Of course, these are just rules of thumb. Actual optimal values can only be obtained through careful hyper-parameter tuning, evaluating both raw speed and time-to-convergence metrics.