
Custom learning rate scheduler affects TPU performance

Open DanielRoeder1 opened this issue 2 years ago • 2 comments

❓ Questions and Help

I have trained my transformer model once on a single GPU and once on a multi-core TPU. In both cases a batch size of 256 is used (times 8 for the TPU). My training results show that the TPU loss after 400 update steps is almost equal to the GPU loss after 400 updates, even though the effective batch size is 8 times as high (this trend continues over later steps). This leads me to believe that the TPU cores are somehow out of sync and each training their own model. I use a custom learning rate scheduler that updates the LR at each training step, see Train.py, Scheduler. If I remove this scheduler, the training loss during TPU training drops significantly faster, but training becomes very unstable.
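For context, the Scheduler class itself is not shown here (it lives in Train.py in the notebook); it wraps the optimizer and recomputes the learning rate on every step. A simplified sketch, assuming the inverse-square-root warmup from the paper (d_model and warmup_steps are placeholder names, the exact formula is in Train.py):

class Scheduler:
  def __init__(self, optimizer, config):
    self.optimizer = optimizer
    self.d_model = config.d_model            # model dimension
    self.warmup_steps = config.warmup_steps  # warmup length in update steps
    self.step_num = 0

  def update_learning_rate(self):
    # lr = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
    self.step_num += 1
    lr = self.d_model ** -0.5 * min(self.step_num ** -0.5,
                                    self.step_num * self.warmup_steps ** -1.5)
    for group in self.optimizer.param_groups:
      group['lr'] = lr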

In the training loop the optimizer is initialized per core and wrapped by the Scheduler, which updates the learning rate before each training step:

import torch
from torch.optim import Adam
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

def map_fn(index, flags):
  torch.manual_seed(flags['seed'])
  device = xm.xla_device()
  model = WRAPPED_MODEL.to(device).train()
  scheduler = Scheduler(Adam(model.parameters(), betas=(0.9, 0.98), eps=10e-9), config) #<-----
  loss_fn = torch.nn.CrossEntropyLoss(ignore_index=config.pad_idx)

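  # DistributedSampler shards the training set so each of the 8 cores works on a distinct slice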
  train_sampler = torch.utils.data.distributed.DistributedSampler(
    dataset["train"],
    num_replicas=xm.xrt_world_size(),
    rank=xm.get_ordinal(),
    shuffle=True)

  train_loader = torch.utils.data.DataLoader(
      dataset["train"],
      batch_size=flags['batch_size'],
      sampler=train_sampler,
      num_workers=flags['num_workers'],
      drop_last=True)

  def train_epoch(loader):
    model.train()
    for batch_num, batch in enumerate(loader):
      src_input, trgt_seq = batch["input_ids"], batch["labels"]
      trgt_input = trgt_seq[:,:-1]
      trgt_label = trgt_seq[:,1:]
      scheduler.optimizer.zero_grad() #<-----
      pred = model(src_input, trgt_input)
      loss = loss_fn(pred.transpose(1,2), trgt_label)
      loss.backward()

      scheduler.update_learning_rate() #<-----
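      # xm.optimizer_step() all-reduces (averages) the gradients across the 8 cores before stepping the optimizer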
      xm.optimizer_step(scheduler.optimizer) #<-----

      if batch_num % flags['log_steps'] == 0:
        xm.master_print(f'[{batch_num}/{len(loader)}] Loss={loss} Time={get_time()}')
        
  for epoch in range(flags['num_epochs']):
    para_train_loader = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
    train_epoch(para_train_loader)
    xm.master_print("Finished training epoch {}".format(epoch))

Any help in mitigating the performance difficulties encountered when using the scheduler is more than welcome! Thanks

Extra Information

Model: "Attention Is All You Need Transformer" (self-coded)

Environment: Colab TPU v2 with torch_xla 1.12; Colab GPU T4 (plain PyTorch, no XLA)

Train settings: batch size 256 per core, same LR schedule, same loss function (CrossEntropy); the TPU uses 8 cores, so the effective batch size is 8 * 256

Data: WMT14 de-en, 4.5 million sentence pairs
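
In case it helps with debugging, one check I can add inside map_fn is to compare a parameter checksum across the cores after each epoch; if the printed values differ, the replicas really are training separate models. A sketch:

# Gather one parameter checksum per core via rendezvous and print them all on the master ordinal
param_sum = float(sum(p.detach().sum() for p in model.parameters()))
per_core_sums = xm.mesh_reduce('param_check', param_sum, list)
xm.master_print(f'per-core parameter sums: {per_core_sums}')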

DanielRoeder1 avatar Oct 10 '22 20:10 DanielRoeder1

@AlexWertheim do you have cycles to take a look at this one?

JackCaoG avatar Oct 11 '22 00:10 JackCaoG

Sure, I can take a look!

AlexWertheim avatar Oct 11 '22 16:10 AlexWertheim

@DanielRoeder1

Thanks for sharing your model bug with us. Can you please provide the following details so we can reproduce exactly what you experienced on your end? That way we will be able to circle back with concrete steps you can take to address the problem.

  • a reference to the entire code, if it's more than what you have above
  • a reference to the commands you ran for both experiments, including the hparams you used in training
  • a reference for us to access the data you used to train the model

Thanks.

miladm avatar Oct 13 '22 23:10 miladm

Sure, excuse the late response. You can find the complete training code in the following colab notebooks:

TPU: https://colab.research.google.com/drive/1fSTCbKq7b2iYaDQwrkVe18E81qDZdt3N?usp=sharing

GPU: https://colab.research.google.com/drive/1hW9_pr4B1yDI9sfMs8DRyGUFYQXkybft?usp=sharing

The hyperparameters are the same between both notebooks. The majority of the parameters are set in config.json.

DanielRoeder1 avatar Oct 17 '22 16:10 DanielRoeder1