
Loss does not decrease on single TPU

Open · pjspol opened this issue 3 years ago · 17 comments

Originally reported in https://github.com/pytorch/xla/issues/2735#issue-786787271

🐛 Bug

With pytorch-lightning, the trainer hangs at the end of the epoch. After a keyboard interrupt, the traceback shows it is stuck in `fd_event_list = self._poll.poll(timeout)`:

---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-4-fd5234884739> in <module>()
      4 autoencoder_1 = LitAutoEncoder()
      5 trainer_1 = pl.Trainer(tpu_cores=8, max_epochs=1, progress_bar_refresh_rate=20)
----> 6 trainer_1.fit(autoencoder_1, DataLoader(dataset_mnist_train))

6 frames
/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloader, val_dataloaders, datamodule)
    471         self.call_hook('on_fit_start')
    472 
--> 473         results = self.accelerator_backend.train()
    474         self.accelerator_backend.teardown()
    475 

/usr/local/lib/python3.6/dist-packages/pytorch_lightning/accelerators/tpu_accelerator.py in train(self)
    111                 args=(model, self.trainer, self.mp_queue),
    112                 nprocs=self.trainer.tpu_cores,
--> 113                 start_method=self.start_method
    114             )
    115 

/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py in spawn(fn, args, nprocs, join, daemon, start_method)
    393         join=join,
    394         daemon=daemon,
--> 395         start_method=start_method)
    396 
    397 

/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/spawn.py in start_processes(fn, args, nprocs, join, daemon, start_method)
    155 
    156     # Loop on join until it returns True or raises an exception.
--> 157     while not context.join():
    158         pass
    159 

/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/spawn.py in join(self, timeout)
     75         ready = multiprocessing.connection.wait(
     76             self.sentinels.keys(),
---> 77             timeout=timeout,
     78         )
     79 

/usr/lib/python3.6/multiprocessing/connection.py in wait(object_list, timeout)
    909 
    910             while True:
--> 911                 ready = selector.select(timeout)
    912                 if ready:
    913                     return [key.fileobj for (key, events) in ready]

/usr/lib/python3.6/selectors.py in select(self, timeout)
    374             ready = []
    375             try:
--> 376                 fd_event_list = self._poll.poll(timeout)
    377             except InterruptedError:
    378                 return ready

KeyboardInterrupt: 

To Reproduce

Steps to reproduce the behavior:

  1. Run this Colab (or another Colab for issue reporting) with a TPU runtime.
  2. Wait until the first (and only) epoch finishes. All steps complete, but the cell keeps running.
  3. Interrupt the kernel.

Expected behavior

Training should finish as soon as all steps have completed.

Environment

  • Reproducible on XLA backend [CPU/TPU]: TPU
  • torch_xla version: 1.7 and nightly

cc @kaushikb11 @rohitgr7 @akihironitta

pjspol, Aug 30 '22

@JackCaoG @kaushikb11 I'm reposting this to the pytorch-lightning issue tracker as you asked.

pjspol, Aug 30 '22

Can you try upgrading dependencies?

I believe we test with torch_xla==1.12. The stack trace also shows you are using python==3.6, but 3.7 is our minimum supported version. Finally, have you tried the latest pytorch_lightning version?

carmocca, Aug 31 '22
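Before re-running, it can help to record the interpreter and package versions from inside the notebook itself. A minimal sketch using only the standard library (`importlib.metadata` needs Python 3.8 or newer); the helper names are hypothetical, and the `(3, 7)` floor is the minimum quoted in the comment above.

```python
import sys
from importlib import metadata
from typing import Optional, Tuple


def python_is_supported(minimum: Tuple[int, int] = (3, 7)) -> bool:
    """Return True if the running interpreter is at least `minimum` (major, minor)."""
    return sys.version_info[:2] >= minimum


def installed_version(package: str) -> Optional[str]:
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    print("python", sys.version.split()[0], "supported:", python_is_supported())
    for name in ("torch", "torch_xla", "pytorch_lightning", "torchmetrics"):
        print(name, installed_version(name) or "not installed")
```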

@carmocca I have another notebook where the same issue occurred with torch_xla==1.12 and python==3.7. However, because of other dependencies, it is pinned to pytorch_lightning==1.0.6, and using the latest pytorch_lightning runs into other problems. Is there a way to resolve this issue with a version close to pytorch_lightning==1.0.6?

pjspol, Aug 31 '22

Unfortunately, we cannot backport a bugfix for an old release. It would need to be reproducible in the latest version so that we can look into it.

carmocca, Aug 31 '22

@carmocca Upgrading to pytorch_lightning==1.2.9 fixed the hang. However, I've noticed that when I train on GPU the loss converges as expected, but on TPU it stays exactly the same whether I train for 1 epoch or 500, and no error messages appear. I searched the web for answers, but the solutions I found were for using torch_xla without pytorch_lightning. Any idea what might be going on?

pjspol, Sep 01 '22

Sorry, I don't. In your situation, I would upgrade one minor version at a time (1.2.* -> 1.3.* -> 1.4.* -> ... -> 1.7.*), updating the code at each step to minimize breakage (look at the warning messages!). Then check whether things work as expected.

carmocca, Sep 01 '22
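The one-minor-release-at-a-time upgrade suggested above can be scripted as a list of pip requirement strings. A small sketch; `upgrade_path` is a hypothetical helper, and the `==1.3.*` wildcard is standard PEP 440 syntax that pip accepts.

```python
from typing import List


def upgrade_path(start: str, end: str, package: str = "pytorch_lightning") -> List[str]:
    """Return one pip requirement per minor release from `start` to `end`, inclusive.

    `start` and `end` are "MAJOR.MINOR" strings sharing a major version, e.g. "1.2" and "1.7".
    """
    start_major, start_minor = (int(part) for part in start.split("."))
    end_major, end_minor = (int(part) for part in end.split("."))
    if start_major != end_major or start_minor > end_minor:
        raise ValueError("expected start <= end within a single major version")
    return [f"{package}=={start_major}.{minor}.*" for minor in range(start_minor, end_minor + 1)]


if __name__ == "__main__":
    for spec in upgrade_path("1.2", "1.7"):
        # Install one step, rerun the notebook, and fix any warnings before the next step
        print(f'pip install "{spec}"')
```

Installing and rerunning after each step keeps the gap between two versions small, so each warning can be traced to a single release.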

@carmocca I upgraded to pytorch_lightning==1.7.4 and only had to adjust a few things, but I'm now stuck on this error: pytorch_lightning.utilities.exceptions.MisconfigurationException: Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged. You can fix this by setting an attribute for the metric in your `LightningModule`. I'm not sure what exactly I should do to fix this. Any help would be much appreciated!

pjspol, Sep 01 '22

Can you share your LightningModule code?

carmocca, Sep 01 '22

@carmocca Here it is:

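from typing import Callable, Optional, Type

import pytorch_lightning as pl
import torch
from torch import Tensor
from torchmetrics import Metric

# PretrainedModelBase, BatchWeightedLoss and BatchEncodingProtocol come from the
# author's own project; their definitions are not shown in this thread.

class TrainingModule(pl.LightningModule):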
    def __init__(self,
                 model: PretrainedModelBase, *,
                 loss_fn: Callable[[Tensor, Tensor], Tensor],
                 optimizer: torch.optim.Optimizer,
                 metric_cls: Type[Metric]):
        super().__init__()
        self.model = model
        self.loss_fn = loss_fn
        self.optimizer = optimizer

        self.weighted_loss = {phase: BatchWeightedLoss() for phase in ['train', 'valid', 'test']}
        self.metric = {phase: metric_cls() for phase in ['train', 'valid', 'test']}
        self.metric_name = metric_cls.__name__.lower()

    def cuda(self, device: Optional[int] = None):
        setattr(self.model, 'device', device)
        return super(TrainingModule, self).cuda(device)

    def forward(self, batch: BatchEncodingProtocol):
        return self.model.forward(batch)

    def _step(self, mode: str, batch: BatchEncodingProtocol, batch_idx: int) -> torch.Tensor:
        output = self.forward(batch)
        loss = self.loss_fn(output, batch.y)
        preds = torch.mean(torch.stack(output), dim=0) if isinstance(output, tuple) else output

        self.weighted_loss[mode](loss, len(batch))
        self.metric[mode](preds.cpu(), batch.y.cpu())

        self.log(f'{mode}_loss', self.weighted_loss[mode], on_epoch=True, on_step=False)
        self.log(f'{mode}_{self.metric_name}', self.metric[mode], on_epoch=True, on_step=False)

        return loss

    def on_train_epoch_start(self) -> None:
        self.weighted_loss['train'].reset()
        self.metric['train'].reset()

    def training_step(self, batch: BatchEncodingProtocol, batch_idx: int) -> torch.Tensor:
        return self._step('train', batch, batch_idx)

    def validation_step(self, batch: BatchEncodingProtocol, batch_idx: int) -> torch.Tensor:
        return self._step('valid', batch, batch_idx)

    def test_step(self, batch: BatchEncodingProtocol, batch_idx: int) -> torch.Tensor:
        return self._step('test', batch, batch_idx)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return self.optimizer

pjspol, Sep 01 '22

You need to define the metrics with a ModuleDict: https://torchmetrics.readthedocs.io/en/stable/pages/overview.html#metrics-and-devices

carmocca, Sep 02 '22
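`torchmetrics.Metric` subclasses `torch.nn.Module`, so the point of the ModuleDict advice is to register the per-phase metrics as submodules. The sketch below uses only `torch`, with a hypothetical `SumMetric` standing in for a real metric, to contrast a plain dict with an `nn.ModuleDict`. One catch: `nn.ModuleDict` rejects the key `"train"` with a `KeyError` because it collides with the existing `nn.Module.train()` method, so the `['train', 'valid', 'test']` keys from the code above need a suffix or prefix.

```python
import torch
from torch import nn


class SumMetric(nn.Module):
    """Hypothetical stand-in for a torchmetrics.Metric: keeps a running sum in a buffer."""

    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("total", torch.zeros(()))

    def forward(self, value: torch.Tensor) -> torch.Tensor:
        self.total += value.detach().sum()
        return self.total

    def reset(self) -> None:
        self.total.zero_()


class PhaseMetrics(nn.Module):
    """Contrasts a plain dict of metrics with an nn.ModuleDict of metrics."""

    def __init__(self, phases=("train", "valid", "test")) -> None:
        super().__init__()
        # A plain dict is an ordinary Python attribute: nn.Module never sees the
        # metrics inside it, so .to(...), .double() and state_dict() skip them.
        self.plain = {phase: SumMetric() for phase in phases}
        # nn.ModuleDict registers each metric as a submodule. The "_metric"
        # suffix avoids the KeyError that the bare key "train" would raise.
        self.registered = nn.ModuleDict(
            {f"{phase}_metric": SumMetric() for phase in phases}
        )


if __name__ == "__main__":
    # .double() goes through the same nn.Module._apply path as .to(device)
    module = PhaseMetrics().double()
    print(module.registered["train_metric"].total.dtype)  # torch.float64
    print(module.plain["train"].total.dtype)  # torch.float32, left behind
```

Applied to the `TrainingModule` in this thread, that would mean turning both `self.weighted_loss` and `self.metric` into `nn.ModuleDict`s with renamed keys (assuming `BatchWeightedLoss` is also a `Metric`, which the thread does not show) and using those keys in `_step` and `on_train_epoch_start`.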

@carmocca I've done that, and training now runs, but the same thing is happening: when training on TPU, the loss stays roughly the same regardless of the number of epochs. I understand conceptually what is happening, but I don't know how to fix the code. Any help would be deeply appreciated!

pjspol, Sep 02 '22
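One plausible cause worth ruling out, which nobody in this thread has confirmed: `TrainingModule.__init__` receives an `optimizer` built outside the module, and `configure_optimizers()` simply returns it. When a module is moved between tensor types that PyTorch cannot update in place (CPU and CUDA tensors can be; XLA tensors most likely cannot), `nn.Module._apply` installs brand-new `Parameter` objects instead. An optimizer created before the move then keeps pointing at the old CPU tensors, which never receive gradients, so the loss never changes; this would also fit the observation that the same code trains on GPU. The sketch reproduces the effect on CPU by enabling `torch.__future__.set_overwrite_module_params_on_conversion(True)` and calling `.double()`; treating that as equivalent to the CPU-to-XLA move is an assumption.

```python
import torch
from torch import nn


def weights_move_after_one_step(build_optimizer_before_move: bool) -> bool:
    """Run one SGD step and report whether the model's current weights changed.

    The overwrite flag makes .double() replace every Parameter with a new
    object; this stands in for a CPU -> XLA move (an assumption).
    """
    torch.manual_seed(0)
    model = nn.Linear(4, 1)
    previous = torch.__future__.get_overwrite_module_params_on_conversion()
    torch.__future__.set_overwrite_module_params_on_conversion(True)
    try:
        if build_optimizer_before_move:
            # Pattern in the thread: the optimizer is built outside the module,
            # before the module is moved, and configure_optimizers() returns it.
            optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
            model = model.double()
        else:
            # Move first, then build the optimizer from the moved parameters,
            # which is what constructing it inside configure_optimizers() achieves.
            model = model.double()
            optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    finally:
        torch.__future__.set_overwrite_module_params_on_conversion(previous)

    before = model.weight.detach().clone()
    inputs = torch.randn(8, 4, dtype=torch.float64)
    loss = model(inputs).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  # silently skips parameters whose .grad is None
    return not torch.equal(before, model.weight.detach())


if __name__ == "__main__":
    print(weights_move_after_one_step(build_optimizer_before_move=True))   # False
    print(weights_move_after_one_step(build_optimizer_before_move=False))  # True
```

If this turns out to be the cause, the Lightning-side fix would be to pass optimizer settings such as the learning rate into the module instead of an optimizer instance, and to construct the optimizer from `self.parameters()` inside `configure_optimizers()`.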

Can you share the Colab link now? Does it not happen on CPU or GPU?

carmocca, Sep 02 '22

@carmocca It does not happen on GPU. Here is the notebook link: https://colab.research.google.com/drive/1-igDGsR4b0yZ31VaBeAT-pAoR2Y-FdBO?usp=sharing The corresponding dataset is attached: RMAT_TPU_test_dataset.csv

pjspol, Sep 03 '22

Any ideas @kaushikb11? A bug impacting optimization would be of high priority.

carmocca, Sep 05 '22

Thanks @carmocca!

Hi @pjspol! Is this issue also occurring with single TPU core?

kaushikb11, Sep 05 '22

@kaushikb11 Yes, the issue also occurs with a single TPU core.

pjspol, Sep 05 '22

@carmocca @kaushikb11 Any updates on this issue?

pjspol, Sep 16 '22

@carmocca @kaushikb11 @Borda I should mention that I tried the same thing on Kaggle (slightly adjusted from the Google Colab version), and the loss also stayed the same, with the following warning appearing: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/conda/lib. I'm not sure whether this is relevant, but I figured it couldn't hurt to share.

pjspol, Nov 17 '22

@carmocca @kaushikb11 @Borda Any updates on this issue?

pjspol, Mar 17 '23

Hi @pjspol! What's the latest reproduction snippet for this bug? The link in https://github.com/Lightning-AI/lightning/issues/14457#issuecomment-1236043241 is private now.

This issue needs careful narrowing down to find the underlying problem. Several improvements to the XLA internals by @Liyang90 were released in the latest versions, so I would try reproducing it again.

carmocca, Mar 17 '23