Custom learning rate scheduler affects TPU performance
❓ Questions and Help
I have trained my Transformer model once on a single GPU and once on a multi-core TPU. In both cases a batch size of 256 per device is used (so effectively 8 × 256 on the TPU). My training results show that the TPU loss after 400 update steps is almost equal to the GPU loss after 400 updates, even though the effective batch size is 8 times as high (after 400 steps the TPU has processed roughly 400 × 8 × 256 ≈ 820k examples versus 400 × 256 ≈ 102k on the GPU). This leads me to believe that the TPU cores are somehow misaligned and are each training their own model (this trend continues as training proceeds).
I use a custom learning rate scheduler that updates the LR at each training step (see the sketch below). If I remove this scheduler, the training loss during TPU training drops significantly faster, but training becomes very unstable.
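For context, here is a minimal sketch of what such a per-step scheduler typically looks like, assuming it follows the inverse-square-root warmup formula from "Attention Is All You Need" (lr = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)). The class and method names only mirror how the Scheduler is used in the training loop below; config.d_model and config.warmup_steps are assumed fields, and the actual implementation in the notebook may differ:

class Scheduler:
    # Sketch of an "Attention Is All You Need"-style warmup scheduler
    # (assumed behaviour, not the notebook's exact code).
    def __init__(self, optimizer, config):
        self.optimizer = optimizer
        self.d_model = config.d_model            # assumed config field
        self.warmup_steps = config.warmup_steps  # assumed config field
        self.step_num = 0

    def update_learning_rate(self):
        # Called once per training step, before xm.optimizer_step().
        self.step_num += 1
        lr = self.d_model ** -0.5 * min(self.step_num ** -0.5,
                                        self.step_num * self.warmup_steps ** -1.5)
        for group in self.optimizer.param_groups:
            group['lr'] = lr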
In the training loop the optimizer is initialized on each core and wrapped in the Scheduler, which updates the learning rate before each training step:
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
from torch.optim import Adam

# WRAPPED_MODEL, Scheduler, config, dataset and get_time are defined elsewhere in the notebook.
def map_fn(index, flags):
    torch.manual_seed(flags['seed'])
    device = xm.xla_device()
    model = WRAPPED_MODEL.to(device).train()
    scheduler = Scheduler(Adam(model.parameters(), betas=(0.9, 0.98), eps=10e-9), config)  # <-----
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=config.pad_idx)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset["train"],
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    train_loader = torch.utils.data.DataLoader(
        dataset["train"],
        batch_size=flags['batch_size'],
        sampler=train_sampler,
        num_workers=flags['num_workers'],
        drop_last=True)

    def train_epoch(loader):
        model.train()
        for batch_num, batch in enumerate(loader):
            src_input, trgt_seq = batch["input_ids"], batch["labels"]
            trgt_input = trgt_seq[:, :-1]   # decoder input: target shifted right
            trgt_label = trgt_seq[:, 1:]    # prediction targets
            scheduler.optimizer.zero_grad()  # <-----
            pred = model(src_input, trgt_input)
            loss = loss_fn(pred.transpose(1, 2), trgt_label)
            loss.backward()
            scheduler.update_learning_rate()  # <-----
            xm.optimizer_step(scheduler.optimizer)  # <-----
            if batch_num % flags['log_steps'] == 0:
                xm.master_print(f'[{batch_num}/{len(loader)}] Loss={loss} Time={get_time()}')

    for epoch in range(flags['num_epochs']):
        para_train_loader = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
        train_epoch(para_train_loader)
        xm.master_print("Finished training epoch {}".format(epoch))
Any help in mitigating the performance difficulties encountered when using the scheduler is more than welcome! Thanks
Extra Information
Model: "Attention Is All You Need Transformer" (self-coded)
Environment: Colab TPUv2 using torch xla 1.12, Colab GPU T4 (non xla torch)
Train settings: Batch size 256 (per TPU core / on the GPU), same LR schedule, same loss function (CrossEntropy); the TPU uses 8 cores, so the effective batch size is 8 × 256
Data: WMT14 de-en, 4.5 million sentence pairs
@AlexWertheim do you have cycle to take a look at this one?
Sure, I can take a look!
@DanielRoeder1
Thanks for sharing your model bug with us. Can you please provide the following details so we can reproduce exactly what you experienced on your end? That way we will be able to circle back with concrete steps you can take to address the problem.
- reference to the entire code if it's more than what you have above
- reference to the commands you ran for both experiments; this should include the hparams you used in training
- reference for us to access the data you used to train the model
Thanks.
Sure, and excuse the late response. You can find the complete training code in the following Colab notebooks:
TPU: https://colab.research.google.com/drive/1fSTCbKq7b2iYaDQwrkVe18E81qDZdt3N?usp=sharing
GPU: https://colab.research.google.com/drive/1hW9_pr4B1yDI9sfMs8DRyGUFYQXkybft?usp=sharing
The hyperparameters are the same between both notebooks; the majority of the parameters are set in config.json.