Loggers fail to create metrics.csv file when running on multiple TPU cores
Bug description
Running the MNIST tutorial from Lightning-AI does not create a metrics.csv file when run on a v2-8 Cloud TPU using all 8 cores. The issue reproduces even after killing all running Python processes and restarting the Python 3 kernel in Jupyter.
When devices=1 is set so that the model trains on a single core, metrics.csv always seems to be created. I reproduced this issue on the latest stable release (2.1.0) as well as on the nightly build (2.2.0dev).
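For reference, the only difference from the repro script below in the run that does produce metrics.csv is the devices argument (a minimal sketch under that assumption):

# Single-core variant that writes metrics.csv as expected; only `devices` differs from the repro below
trainer = L.Trainer(
    max_epochs=3,
    accelerator="tpu",
    devices=1,
)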
What version are you seeing the problem on?
master
How to reproduce the bug
import lightning as L
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, random_split
# from torchmetrics.functional import accuracy
from torchvision import transforms
# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST
BATCH_SIZE = 1024
class MNISTDataModule(L.LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)
class LitModel(L.LightningModule):
    def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):
        super().__init__()
        self.num_classes = num_classes  # Needed to calculate metrics in val step
        self.save_hyperparameters()
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_classes),
        )

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        # self.logger.log_metrics("train_loss", loss, step=batch_idx)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        # acc = accuracy(preds, y, task='multiclass', num_classes=self.num_classes)
        self.log("val_loss", loss, on_epoch=True, prog_bar=True)
        # self.log("val_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer
# Init DataModule
dm_2 = MNISTDataModule()
# Init model from datamodule's attributes
model_2 = LitModel(*dm_2.dims, dm_2.num_classes)
# Init trainer
trainer = L.Trainer(
    max_epochs=3,
    accelerator="tpu",
    devices=8,
)
# Train
print(f"Running on {len(trainer.device_ids)} devices.")
print(f"Logging metrics under: {trainer.logger.log_dir}...")
trainer.fit(model_2, dm_2)

from pathlib import Path
import pandas as pd

csv_path = Path(trainer.logger.log_dir) / 'metrics.csv'
pd.read_csv(csv_path)
Error messages and logs
Running on 8 devices.
Logging metrics under: /home/carlos.gaitan/bigrna-torch/lightning_logs/version_11...
...
FileNotFoundError: [Errno 2] No such file or directory: '/home/carlos.gaitan/bigrna-torch/lightning_logs/version_11/metrics.csv'
Environment
Current environment
- CUDA:
- GPU: None
- available: False
- version: 11.7
- Lightning:
- lightning: 2.2.0.dev0
- lightning-utilities: 0.9.0
- pytorch-lightning: 2.1.0
- torch: 2.0.0
- torch-xla: 2.0
- torchmetrics: 1.2.0
- torchvision: 0.15.1
- Packages:
- absl-py: 1.4.0
- aiohttp: 3.8.6
- aiosignal: 1.3.1
- anyio: 4.0.0
- appdirs: 1.4.4
- argon2-cffi: 23.1.0
- argon2-cffi-bindings: 21.2.0
- asttokens: 2.4.1
- async-lru: 2.0.4
- async-timeout: 4.0.3
- attrs: 23.1.0
- automat: 0.8.0
- babel: 2.13.1
- backcall: 0.2.0
- beautifulsoup4: 4.12.2
- bleach: 6.1.0
- blinker: 1.4
- cachetools: 5.3.0
- certifi: 2019.11.28
- cffi: 1.16.0
- chardet: 3.0.4
- charset-normalizer: 2.0.12
- click: 8.1.7
- cloud-init: 23.1.2
- cloud-tpu-client: 0.10
- cmake: 3.26.0
- colorama: 0.4.3
- comm: 0.2.0
- command-not-found: 0.3
- configobj: 5.0.6
- constantly: 15.1.0
- cryptography: 2.8
- cython: 0.29.14
- dbus-python: 1.2.16
- debugpy: 1.8.0
- decorator: 5.1.1
- defusedxml: 0.7.1
- distlib: 0.3.4
- distro: 1.4.0
- distro-info: 0.23ubuntu1
- docker-pycreds: 0.4.0
- einops: 0.7.0
- entrypoints: 0.3
- exceptiongroup: 1.1.3
- executing: 2.0.1
- fastjsonschema: 2.19.0
- filelock: 3.7.1
- frozenlist: 1.4.0
- fsspec: 2023.10.0
- gitdb: 4.0.11
- gitpython: 3.1.40
- google-api-core: 1.34.0
- google-api-python-client: 1.8.0
- google-auth: 2.23.4
- google-auth-httplib2: 0.1.0
- google-cloud-core: 2.3.3
- google-cloud-storage: 2.13.0
- google-crc32c: 1.5.0
- google-resumable-media: 2.6.0
- googleapis-common-protos: 1.58.0
- httplib2: 0.14.0
- hyperlink: 19.0.0
- idna: 2.8
- importlib-metadata: 6.8.0
- importlib-resources: 6.1.1
- incremental: 16.10.1
- intel-openmp: 2022.1.0
- ipykernel: 6.26.0
- ipython: 8.12.3
- ipywidgets: 8.1.1
- jedi: 0.19.1
- jinja2: 3.1.2
- json5: 0.9.14
- jsonpatch: 1.22
- jsonpointer: 2.0
- jsonschema: 4.19.2
- jsonschema-specifications: 2023.11.1
- jupyter: 1.0.0
- jupyter-client: 8.6.0
- jupyter-console: 6.6.3
- jupyter-core: 5.5.0
- jupyter-events: 0.9.0
- jupyter-lsp: 2.2.0
- jupyter-server: 2.10.1
- jupyter-server-terminals: 0.4.4
- jupyterlab: 4.0.8
- jupyterlab-pygments: 0.2.2
- jupyterlab-server: 2.25.1
- jupyterlab-widgets: 3.0.9
- keyring: 18.0.1
- language-selector: 0.1
- launchpadlib: 1.10.13
- lazr.restfulclient: 0.14.2
- lazr.uri: 1.0.3
- libtpu-nightly: 0.1.dev20230213
- lightning: 2.2.0.dev0
- lightning-utilities: 0.9.0
- lit: 15.0.7
- markdown-it-py: 3.0.0
- markupsafe: 2.1.3
- matplotlib-inline: 0.1.6
- mdurl: 0.1.2
- memray: 1.10.0
- mistune: 3.0.2
- mkl: 2022.1.0
- mkl-include: 2022.1.0
- more-itertools: 4.2.0
- mpmath: 1.3.0
- multidict: 6.0.4
- nbclient: 0.9.0
- nbconvert: 7.11.0
- nbformat: 5.9.2
- nest-asyncio: 1.5.8
- netifaces: 0.10.4
- networkx: 3.0
- notebook: 7.0.6
- notebook-shim: 0.2.3
- numpy: 1.24.2
- nvidia-cublas-cu11: 11.10.3.66
- nvidia-cuda-cupti-cu11: 11.7.101
- nvidia-cuda-nvrtc-cu11: 11.7.99
- nvidia-cuda-runtime-cu11: 11.7.99
- nvidia-cudnn-cu11: 8.5.0.96
- nvidia-cufft-cu11: 10.9.0.58
- nvidia-curand-cu11: 10.2.10.91
- nvidia-cusolver-cu11: 11.4.0.1
- nvidia-cusparse-cu11: 11.7.4.91
- nvidia-nccl-cu11: 2.14.3
- nvidia-nvtx-cu11: 11.7.91
- oauth2client: 4.1.3
- oauthlib: 3.1.0
- overrides: 7.4.0
- packaging: 20.3
- pandas: 2.0.3
- pandocfilters: 1.5.0
- parso: 0.8.3
- pathtools: 0.1.2
- pexpect: 4.6.0
- pickleshare: 0.7.5
- pillow: 9.4.0
- pip: 20.0.2
- pkgutil-resolve-name: 1.3.10
- platformdirs: 2.5.2
- prometheus-client: 0.18.0
- prompt-toolkit: 3.0.41
- protobuf: 3.20.3
- psutil: 5.9.6
- ptyprocess: 0.7.0
- pure-eval: 0.2.2
- pyasn1: 0.4.2
- pyasn1-modules: 0.2.1
- pycparser: 2.21
- pydantic: 1.10.13
- pydantic-cli: 4.3.0
- pygments: 2.16.1
- pygobject: 3.36.0
- pyhamcrest: 1.9.0
- pyjwt: 1.7.1
- pymacaroons: 0.13.0
- pynacl: 1.3.0
- pyopenssl: 19.0.0
- pyparsing: 2.4.6
- pyrsistent: 0.15.5
- pyserial: 3.4
- python-apt: 2.0.0+ubuntu0.20.4.7
- python-dateutil: 2.8.2
- python-debian: 0.1.36ubuntu1
- python-json-logger: 2.0.7
- pytorch-lightning: 2.1.0
- pytz: 2023.3.post1
- pyyaml: 5.4.1
- pyzmq: 25.1.1
- qtconsole: 5.5.1
- qtpy: 2.4.1
- referencing: 0.31.0
- requests: 2.31.0
- requests-unixsocket: 0.2.0
- rfc3339-validator: 0.1.4
- rfc3986-validator: 0.1.1
- rich: 13.6.0
- rpds-py: 0.12.0
- rsa: 4.9
- scipy: 1.10.1
- secretstorage: 2.3.1
- send2trash: 1.8.2
- sentry-sdk: 1.34.0
- service-identity: 18.1.0
- setproctitle: 1.3.3
- setuptools: 62.3.2
- simplejson: 3.16.0
- six: 1.14.0
- smmap: 5.0.1
- sniffio: 1.3.0
- sos: 4.3
- soupsieve: 2.5
- ssh-import-id: 5.10
- stack-data: 0.6.3
- sympy: 1.11.1
- systemd-python: 234
- tbb: 2021.6.0
- terminado: 0.18.0
- tinycss2: 1.2.1
- tomli: 2.0.1
- torch: 2.0.0
- torch-xla: 2.0
- torchmetrics: 1.2.0
- torchvision: 0.15.1
- tornado: 6.3.3
- tqdm: 4.66.1
- traitlets: 5.13.0
- triton: 2.0.0
- twisted: 18.9.0
- typing-extensions: 4.5.0
- tzdata: 2023.3
- ubuntu-advantage-tools: 27.8
- ufw: 0.36
- unattended-upgrades: 0.1
- uritemplate: 3.0.1
- urllib3: 1.25.8
- virtualenv: 20.14.1
- wadllib: 1.3.3
- wandb: 0.15.12
- wcwidth: 0.2.10
- webencodings: 0.5.1
- websocket-client: 1.6.4
- wheel: 0.34.2
- widgetsnbextension: 4.0.9
- yarl: 1.9.2
- zipp: 1.0.0
- zope.interface: 4.7.1
- System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.8.10
- release: 5.13.0-1027-gcp
- version: #32~20.04.1-Ubuntu SMP Thu May 26 10:53:08 UTC 2022
More info
The CSV file sometimes shows up later (anywhere from several minutes to roughly an hour after training), but more often than not it never appears. It looks like the CSVLogger does not always materialize the logged metrics when running in a distributed setting on TPUs. Sometimes the version_X folder is not created at all, so the hparams.yaml file containing metadata about the run is never written to disk either. A plain polling check (sketch below) makes the behavior easy to observe.
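This is only a diagnostic sketch under my assumptions (the 10-minute timeout is arbitrary), not a fix:

import time
from pathlib import Path

# Hypothetical check: poll for metrics.csv for up to ~10 minutes after trainer.fit() returns
csv_path = Path(trainer.logger.log_dir) / "metrics.csv"
for _ in range(60):
    if csv_path.exists():
        print(f"metrics.csv appeared at {csv_path}")
        break
    time.sleep(10)
else:
    print(f"metrics.csv never appeared under {csv_path.parent}")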
cc @carmocca @JackCaoG @Liyang90 @gkroiz
Replacing CSVLogger() with WandbLogger() shows the same pattern: it works on a single TPU core and crashes when run on all 8 cores. The issue(s) might not be specific to CSVLogger.
Changes:
from lightning import pytorch as pl
logger = pl.loggers.WandbLogger()
dm = MNISTDataModule()
model = LitModel(*dm.dims, dm.num_classes)
trainer = L.Trainer(
    max_epochs=3,
    accelerator="tpu",
    devices=8,
    logger=logger,
)
print(f"Running on {len(trainer.device_ids)} devices.")
print(f"Logging metrics under: {trainer.logger.log_dir}...")
trainer.fit(model, dm)
Call stack:
WARNING:root:XRT configuration not detected. Defaulting to preview PJRT runtime. To silence this warning and continue using PJRT, explicitly set PJRT_DEVICE to a supported device or configure XRT. To disable default device selection, set PJRT_SELECT_DEFAULT_DEVICE=0
WARNING:root:For more information about the status of PJRT, see https://github.com/pytorch/xla/blob/master/docs/pjrt.md
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
GPU available: False, used: False
TPU available: True, using: 8 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Running on 8 devices.
Logging metrics under: None...
wandb: Currently logged in as: carlos-gaitan (deep-genomics-ml). Use `wandb login --relogin` to force relogin
Tracking run with wandb version 0.16.0
Run data is saved locally in ./wandb/run-20231121_194518-zw18chan
Syncing run [expert-thunder-3](https://wandb.ai/deep-genomics-ml/lightning_logs/runs/zw18chan) to [Weights & Biases](https://wandb.ai/deep-genomics-ml/lightning_logs) ([docs](https://wandb.me/run))
View project at https://wandb.ai/deep-genomics-ml/lightning_logs
View run at https://wandb.ai/deep-genomics-ml/lightning_logs/runs/zw18chan
Exception in thread SockSrvRdThr:
Traceback (most recent call last):
  File "/usr/lib/python3.8/threading.py", line 932, in _bootstrap_inner
    self.run()
  File "/home/carlos.gaitan/.local/lib/python3.8/site-packages/wandb/sdk/service/server_sock.py", line 111, in run
    shandler(sreq)
  File "/home/carlos.gaitan/.local/lib/python3.8/site-packages/wandb/sdk/service/server_sock.py", line 150, in server_inform_attach
    self._mux._streams[stream_id]._settings._proto,
KeyError: 'zw18chan'
(this identical exception is printed 8 times in total)
---------------------------------------------------------------------------
_RemoteTraceback Traceback (most recent call last)
_RemoteTraceback:
"""
Traceback (most recent call last):
  File "/usr/lib/python3.8/concurrent/futures/process.py", line 239, in _process_worker
    r = call_item.fn(*call_item.args, **call_item.kwargs)
  File "/usr/lib/python3.8/concurrent/futures/process.py", line 198, in _process_chunk
    return [fn(*args) for args in chunk]
  File "/usr/lib/python3.8/concurrent/futures/process.py", line 198, in <listcomp>
    return [fn(*args) for args in chunk]
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 92, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 245, in _run_thread_per_device
    replica_results = list(
  File "/usr/lib/python3.8/concurrent/futures/_base.py", line 619, in result_iterator
    yield fs.pop().result()
  File "/usr/lib/python3.8/concurrent/futures/_base.py", line 444, in result
    return self.__get_result()
  File "/usr/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
    raise self._exception
  File "/usr/lib/python3.8/concurrent/futures/thread.py", line 57, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 238, in _thread_fn
    return fn()
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 341, in __call__
    self.fn(global_ordinal(), *self.args, **self.kwargs)
  File "/home/carlos.gaitan/.local/lib/python3.8/site-packages/lightning/pytorch/strategies/launchers/xla.py", line 142, in _wrapping_function
    trainer, function, args, kwargs = copy.deepcopy((trainer, function, args, kwargs))
  File "/usr/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.8/copy.py", line 210, in _deepcopy_tuple
    y = [deepcopy(a, memo) for a in x]
  File "/usr/lib/python3.8/copy.py", line 210, in <listcomp>
    y = [deepcopy(a, memo) for a in x]
  File "/usr/lib/python3.8/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.8/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.8/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.8/copy.py", line 205, in _deepcopy_list
    append(deepcopy(a, memo))
  File "/usr/lib/python3.8/copy.py", line 161, in deepcopy
    rv = reductor(4)
  File "/home/carlos.gaitan/.local/lib/python3.8/site-packages/lightning/pytorch/loggers/wandb.py", line 353, in __getstate__
    _ = self.experiment
  File "/home/carlos.gaitan/.local/lib/python3.8/site-packages/lightning/fabric/loggers/logger.py", line 118, in experiment
    return fn(self)
  File "/home/carlos.gaitan/.local/lib/python3.8/site-packages/lightning/pytorch/loggers/wandb.py", line 397, in experiment
    self._experiment = wandb._attach(attach_id)
  File "/home/carlos.gaitan/.local/lib/python3.8/site-packages/wandb/sdk/wandb_init.py", line 888, in _attach
    raise UsageError(f"Unable to attach to run {attach_id}")
wandb.errors.UsageError: Unable to attach to run zw18chan
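My reading of the stack above (a sketch only; the deepcopy call and method names are taken from the traceback, the surrounding code is an assumption): the XLA launcher deep-copies the trainer into every worker, WandbLogger.__getstate__ then touches self.experiment, and re-attaching to the run via wandb._attach fails in the child processes.

import copy

# Hypothetical minimal illustration of the failing step from the traceback above:
# deep-copying a Trainer whose WandbLogger already holds a live run forces
# WandbLogger.__getstate__ -> self.experiment -> wandb._attach(attach_id),
# which raises UsageError("Unable to attach to run ...") inside the TPU workers.
trainer_copy = copy.deepcopy(trainer)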
Yes, I can confirm that multi-GPU training also causes this logger IO issue every now and then:
  File "/opt/conda/lib/python3.10/site-packages/wandb/sdk/internal/sender.py", line 1174, in _update_summary
    with open(summary_path, "w") as f:
FileNotFoundError: [Errno 2] No such file or directory: 'logs/custom_lfq_2024-01-15T06-28/wandb/run-20240115_062902-custom_lfq_2024-01-15T06-28/files/wandb-summary.json'
wandb: ERROR Internal wandb error: file data was not synced
This problem is still present in the latest version of Lightning (2.2.4), and also occurs on a TPU v4-8.