pytorch-lightning
WandbLogger crashes when used with TPU VM
Bug description
On a TPU VM, using WandbLogger causes training to crash. I am using the nightly build, which I know states "no guarantees", so apologies in advance if this is already being worked on (I wasn't able to find any relevant issues or PRs). I am also unsure why this error occurs, and whether it is an issue with Lightning or with wandb.
What version are you seeing the problem on?
master
How to reproduce the bug
import lightning.pytorch as pl
import lightning.pytorch.loggers
import torch
import torch.backends.cudnn
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


class LinearDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        w = torch.randn([128])
        eps = 0.01 * torch.randn([2400, 1])
        self.X = torch.randn([2400, 128])
        self.Y = torch.sum(w * self.X, dim=-1, keepdim=True) + eps

    def loader(self):
        return DataLoader(
            TensorDataset(self.X, self.Y),
            batch_size=100,
            num_workers=4,
            shuffle=True,
        )

    def train_dataloader(self):
        return self.loader()

    def val_dataloader(self):
        return self.loader()


class LinearRegression(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(128, 1)

    def step(self, batch, split):
        X, y = batch
        loss = F.mse_loss(self.proj(X), y)
        self.log(f"{split}/loss", loss, sync_dist=(split != "train"), prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self.step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self.step(batch, "val")

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


def train():
    pl.seed_everything(100, workers=True)
    data = LinearDataModule()
    model = LinearRegression()
    trainer = pl.Trainer(
        accelerator="tpu",
        devices=8,
        enable_checkpointing=False,
        precision="bf16-mixed",
        logger=pl.loggers.WandbLogger(project="tpu_debug"),
        max_epochs=100,
        enable_progress_bar=True,
    )
    trainer.fit(model=model, datamodule=data)


if __name__ == "__main__":
    train()
The above code was written to a file train.py and run with:

PJRT_DEVICE=TPU python3 -m train
Error messages and logs
Global seed set to 100
GPU available: False, used: False
TPU available: True, using: 8 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
wandb: Currently logged in as: .... Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.15.5
wandb: Run data is saved locally in ./wandb/run-20230710_212512-3vghnzj8
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run fine-music-19
wandb: ⭐️ View project at https://wandb.ai/...
wandb: 🚀 View run at https://wandb.ai/.../runs/3vghnzj8
Exception in thread SockSrvRdThr:
Traceback (most recent call last):
File "/usr/lib/python3.8/threading.py", line 932, in _bootstrap_inner
self.run()
File "/.../.local/lib/python3.8/site-packages/wandb/sdk/service/server_sock.py", line 112, in run
shandler(sreq)
File "/.../.local/lib/python3.8/site-packages/wandb/sdk/service/server_sock.py", line 152, in server_inform_attach
dict(self._mux._streams[stream_id]._settings),
KeyError: '3vghnzj8'
(The SockSrvRdThr traceback above repeats seven more times.)
wandb: ERROR Unable to attach to run 3vghnzj8
concurrent.futures.process._RemoteTraceback:
"""
Traceback (most recent call last):
File "/usr/lib/python3.8/concurrent/futures/process.py", line 239, in _process_worker
r = call_item.fn(*call_item.args, **call_item.kwargs)
File "/usr/lib/python3.8/concurrent/futures/process.py", line 198, in _process_chunk
return [fn(*args) for args in chunk]
File "/usr/lib/python3.8/concurrent/futures/process.py", line 198, in <listcomp>
return [fn(*args) for args in chunk]
File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 92, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 245, in _run_thread_per_device
replica_results = list(
File "/usr/lib/python3.8/concurrent/futures/_base.py", line 619, in result_iterator
yield fs.pop().result()
File "/usr/lib/python3.8/concurrent/futures/_base.py", line 444, in result
return self.__get_result()
File "/usr/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
raise self._exception
File "/usr/lib/python3.8/concurrent/futures/thread.py", line 57, in run
result = self.fn(*self.args, **self.kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 238, in _thread_fn
return fn()
File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 341, in __call__
self.fn(global_ordinal(), *self.args, **self.kwargs)
File "/.../.local/lib/python3.8/site-packages/lightning/pytorch/strategies/launchers/xla.py", line 128, in _wrapping_function
trainer, function, args, kwargs = copy.deepcopy((trainer, function, args, kwargs))
File "/usr/lib/python3.8/copy.py", line 146, in deepcopy
y = copier(x, memo)
File "/usr/lib/python3.8/copy.py", line 210, in _deepcopy_tuple
y = [deepcopy(a, memo) for a in x]
File "/usr/lib/python3.8/copy.py", line 210, in <listcomp>
y = [deepcopy(a, memo) for a in x]
File "/usr/lib/python3.8/copy.py", line 172, in deepcopy
y = _reconstruct(x, memo, *rv)
File "/usr/lib/python3.8/copy.py", line 270, in _reconstruct
state = deepcopy(state, memo)
File "/usr/lib/python3.8/copy.py", line 146, in deepcopy
y = copier(x, memo)
File "/usr/lib/python3.8/copy.py", line 230, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
File "/usr/lib/python3.8/copy.py", line 146, in deepcopy
y = copier(x, memo)
File "/usr/lib/python3.8/copy.py", line 205, in _deepcopy_list
append(deepcopy(a, memo))
File "/usr/lib/python3.8/copy.py", line 161, in deepcopy
rv = reductor(4)
File "/.../.local/lib/python3.8/site-packages/lightning/pytorch/loggers/wandb.py", line 356, in __getstate__
_ = self.experiment
File "/.../.local/lib/python3.8/site-packages/lightning/fabric/loggers/logger.py", line 114, in experiment
return fn(self)
File "/.../.local/lib/python3.8/site-packages/lightning/pytorch/loggers/wandb.py", line 398, in experiment
self._experiment = wandb._attach(attach_id)
File "/.../.local/lib/python3.8/site-packages/wandb/sdk/wandb_init.py", line 877, in _attach
raise UsageError(f"Unable to attach to run {attach_id}")
wandb.errors.UsageError: Unable to attach to run 3vghnzj8
"""
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/...", line 76, in <module>
train()
File "/...", line 72, in train
trainer.fit(model=model, datamodule=data)
File "/.../.local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 536, in fit
call._call_and_handle_interrupt(
File "/.../.local/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 41, in _call_and_handle_interrupt
return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
File "/.../.local/lib/python3.8/site-packages/lightning/pytorch/strategies/launchers/xla.py", line 88, in launch
process_context = xmp.spawn(
File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 386, in spawn
return pjrt.spawn(fn, nprocs, start_method, args)
File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 365, in spawn
_run_multiprocess(spawn_fn, start_method=start_method)
File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 92, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 322, in _run_multiprocess
replica_results = list(
File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 323, in <genexpr>
itertools.chain.from_iterable(
File "/usr/lib/python3.8/concurrent/futures/process.py", line 484, in _chain_from_iterable_of_lists
for element in iterable:
File "/usr/lib/python3.8/concurrent/futures/_base.py", line 619, in result_iterator
yield fs.pop().result()
File "/usr/lib/python3.8/concurrent/futures/_base.py", line 444, in result
return self.__get_result()
File "/usr/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
raise self._exception
wandb.errors.UsageError: Unable to attach to run 3vghnzj8
(The identical _RemoteTraceback, ending in wandb.errors.UsageError: Unable to attach to run 3vghnzj8, is printed a second time.)
Environment
- TPU type: v3-8
- TPU software version: tpu-vm-pt-2.0
- Packages: other than the packages that ship with the TPU VM, I installed Lightning and WandB. The precise versions are listed below:
Package Version
------------------------ --------------------
absl-py 1.4.0
anyio 3.7.1
appdirs 1.4.4
arrow 1.2.3
attrs 19.3.0
Automat 0.8.0
backoff 2.2.1
beautifulsoup4 4.12.2
blessed 1.20.0
blinker 1.4
cachetools 5.3.0
certifi 2019.11.28
chardet 3.0.4
charset-normalizer 2.0.12
click 8.1.4
cloud-init 22.1
cloud-tpu-client 0.10
cmake 3.26.0
colorama 0.4.3
command-not-found 0.3
configobj 5.0.6
constantly 15.1.0
croniter 1.4.1
cryptography 2.8
Cython 0.29.14
dateutils 0.6.12
dbus-python 1.2.16
deepdiff 6.3.1
distlib 0.3.4
distro 1.4.0
distro-info 0.23ubuntu1
docker-pycreds 0.4.0
entrypoints 0.3
exceptiongroup 1.1.2
fastapi 0.100.0
filelock 3.7.1
fsspec 2023.6.0
gitdb 4.0.10
GitPython 3.1.32
google-api-core 1.34.0
google-api-python-client 1.8.0
google-auth 2.16.2
google-auth-httplib2 0.1.0
googleapis-common-protos 1.58.0
h11 0.14.0
httplib2 0.14.0
hyperlink 19.0.0
idna 2.8
importlib-metadata 1.5.0
incremental 16.10.1
inquirer 3.1.3
intel-openmp 2022.1.0
itsdangerous 2.1.2
Jinja2 2.10.1
jsonpatch 1.22
jsonpointer 2.0
jsonschema 3.2.0
keyring 18.0.1
language-selector 0.1
launchpadlib 1.10.13
lazr.restfulclient 0.14.2
lazr.uri 1.0.3
libtpu-nightly 0.1.dev20230213
lightning 2.1.0.dev0
lightning-cloud 0.5.37
lightning-utilities 0.9.0
lit 15.0.7
markdown-it-py 3.0.0
MarkupSafe 1.1.0
mdurl 0.1.2
mkl 2022.1.0
mkl-include 2022.1.0
more-itertools 4.2.0
mpmath 1.3.0
netifaces 0.10.4
networkx 3.0
numpy 1.24.2
nvidia-cublas-cu11 11.10.3.66
nvidia-cuda-cupti-cu11 11.7.101
nvidia-cuda-nvrtc-cu11 11.7.99
nvidia-cuda-runtime-cu11 11.7.99
nvidia-cudnn-cu11 8.5.0.96
nvidia-cufft-cu11 10.9.0.58
nvidia-curand-cu11 10.2.10.91
nvidia-cusolver-cu11 11.4.0.1
nvidia-cusparse-cu11 11.7.4.91
nvidia-nccl-cu11 2.14.3
nvidia-nvtx-cu11 11.7.91
oauth2client 4.1.3
oauthlib 3.1.0
ordered-set 4.1.0
packaging 20.3
pathtools 0.1.2
pexpect 4.6.0
Pillow 9.4.0
pip 20.0.2
platformdirs 2.5.2
protobuf 3.20.3
psutil 5.9.5
pyasn1 0.4.2
pyasn1-modules 0.2.1
pydantic 1.10.11
Pygments 2.15.1
PyGObject 3.36.0
PyHamcrest 1.9.0
PyJWT 1.7.1
pymacaroons 0.13.0
PyNaCl 1.3.0
pyOpenSSL 19.0.0
pyparsing 2.4.6
pyrsistent 0.15.5
pyserial 3.4
python-apt 2.0.0+ubuntu0.20.4.7
python-dateutil 2.8.2
python-debian 0.1.36ubuntu1
python-editor 1.0.4
python-multipart 0.0.6
pytorch-lightning 2.0.5
pytz 2023.3
PyYAML 5.4.1
readchar 4.0.5
requests 2.27.1
requests-unixsocket 0.2.0
rich 13.4.2
rsa 4.9
SecretStorage 2.3.1
sentry-sdk 1.28.0
service-identity 18.1.0
setproctitle 1.3.2
setuptools 62.3.2
simplejson 3.16.0
six 1.14.0
smmap 5.0.0
sniffio 1.3.0
sos 4.3
soupsieve 2.4.1
ssh-import-id 5.10
starlette 0.27.0
starsessions 1.3.0
sympy 1.11.1
systemd-python 234
tbb 2021.6.0
torch 2.0.0
torch-xla 2.0
torchmetrics 1.0.0
torchvision 0.15.1
tqdm 4.65.0
traitlets 5.9.0
triton 2.0.0
Twisted 18.9.0
typing-extensions 4.5.0
ubuntu-advantage-tools 27.8
ufw 0.36
unattended-upgrades 0.1
uritemplate 3.0.1
urllib3 1.26.16
uvicorn 0.22.0
virtualenv 20.14.1
wadllib 1.3.3
wandb 0.15.5
wcwidth 0.2.6
websocket-client 1.6.1
websockets 11.0.3
wheel 0.34.2
zipp 1.0.0
zope.interface 4.7.1
More info
If I train without a logger instead, then no error occurs and the script proceeds normally.
cc @carmocca @JackCaoG @steventk-g @Liyang90 @awaelchli @morganmcg1 @borisdayma @scottire @parambharat
Very likely caused by #17818. I'm seeing this with multi-GPU as well, so it's likely not TPU-related.
@rejuvyesh Why do you think it is very likely caused by #17818? Can you git-bisect or provide me with a code example for multi-GPU? I appreciate the help.
EDIT: I ran the above code example with accelerator="cuda" and couldn't see any issues.
@alstonlo Thanks for the report. I don't see anything wrong with the code example. My uneducated guess is that the PJRT launch mechanism and wandb's feature for attaching to a run from a subprocess don't work well together.
Since you have access to the TPU machine, could I ask what happens if you comment out these three lines of code in Lightning: https://github.com/Lightning-AI/lightning/blob/00496da92d9e7d17c81f51c9abfb54583ba2817f/src/lightning/pytorch/loggers/wandb.py#L354-L356
Does it work then?
@awaelchli I haven't done a git bisect yet, but downgrading to 2.0.4 fixed the issue for us. I will attempt a bisect once we have more time; my hunch was that #17818 is the only major change to that code path recently.
Only semi-related to the current issue, but rerunning the same script with the nightly build (as of now) raises an error. This is due to the local tpu variable in xla.py not being defined when _XLA_GREATER_EQUAL_2_1 is false.
@alstonlo My bad! Let me fix that quickly
Opened https://github.com/Lightning-AI/lightning/pull/18085
Thanks!
@awaelchli I have installed lightning directly from #18085 and commented out the suggested lines. The training script runs but no WandB run is ever created and nothing is logged to WandB.
One way to reduce the surface of issues would be to do
import lightning as L
from lightning.pytorch.loggers.wandb import WandbLogger
def fn(fabric, logger):
    ...

logger = WandbLogger()
fabric = L.Fabric(accelerator="tpu")
fabric.launch(fn, logger)
While trying to find a solution for this issue, I think I may have stumbled upon another potential bug (which I suspect may be causing this issue, but I am not sure). For context, I noticed that if I added the following to the LightningModule:

import torch_xla.core.xla_model as xm
from lightning.pytorch.utilities.rank_zero import rank_zero_only

class LinearRegression(pl.LightningModule):
    def setup(self, stage):
        print(f"{rank_zero_only.rank = }, {self.trainer.global_rank = }, {xm.get_ordinal() = }")

then there was a mismatch between rank_zero_only.rank and self.trainer.global_rank (and xm.get_ordinal() agrees with the latter). I think this issue is caused by an interaction between rank_zero_only and xm.rendezvous() (which is called at various points of the Trainer setup). The following is a minimal example:
# debug.py
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from lightning.pytorch.utilities.rank_zero import rank_zero_only

def f(index):
    rank_zero_only.rank = xm.get_ordinal()
    xm.rendezvous("barrier")
    print(f"{rank_zero_only.rank = }, {xm.get_ordinal() = }")

if __name__ == "__main__":
    xmp.spawn(f, args=tuple())
$ PJRT_DEVICE=TPU python3 -m debug
rank_zero_only.rank = 5, xm.get_ordinal() = 4
rank_zero_only.rank = 3, xm.get_ordinal() = 2
rank_zero_only.rank = 1, xm.get_ordinal() = 0
rank_zero_only.rank = 7, xm.get_ordinal() = 6
rank_zero_only.rank = 5, xm.get_ordinal() = 5
rank_zero_only.rank = 1, xm.get_ordinal() = 1
rank_zero_only.rank = 7, xm.get_ordinal() = 7
rank_zero_only.rank = 3, xm.get_ordinal() = 3
If I comment out the xm.rendezvous("barrier") line, then I get
$ PJRT_DEVICE=TPU python3 -m debug
rank_zero_only.rank = 4, xm.get_ordinal() = 4
rank_zero_only.rank = 5, xm.get_ordinal() = 5
rank_zero_only.rank = 2, xm.get_ordinal() = 2
rank_zero_only.rank = 3, xm.get_ordinal() = 3
rank_zero_only.rank = 0, xm.get_ordinal() = 0
rank_zero_only.rank = 1, xm.get_ordinal() = 1
rank_zero_only.rank = 6, xm.get_ordinal() = 6
rank_zero_only.rank = 7, xm.get_ordinal() = 7
If I had instead assigned xm.get_ordinal() to a local variable like so:
def f(index):
    tmp = xm.get_ordinal()
    xm.rendezvous("barrier")
    print(f"{tmp = } {xm.get_ordinal() = }")
then tmp and xm.get_ordinal() match, so I think this is an issue with rank_zero_only.rank.
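For reference, rank_zero_only stores the rank as a single attribute on the decorator object itself, and decorated functions consult that attribute at call time. The following is a simplified paraphrase of that mechanism (not the actual lightning.fabric / lightning_utilities source), included only to make the shared-state aspect concrete:

import functools

def rank_zero_only(fn):
    # Run the wrapped function only in the process whose recorded rank is 0.
    @functools.wraps(fn)
    def wrapped_fn(*args, **kwargs):
        if rank_zero_only.rank == 0:
            return fn(*args, **kwargs)
        return None
    return wrapped_fn

# Normally overwritten once per process by the strategy/launcher during setup.
rank_zero_only.rank = 0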
The xmp.spawn() on v3 TPUs is multi-process and multi-threaded: there are 4 processes, one per chip, and 2 threads in each process, one per core of the chip. So the rank_zero_only object is shared between 2 threads, which is why modifying it in one thread makes both rank_zero_only.rank values the same. Without xm.rendezvous("barrier"), the printed values seem right, but only transiently; if you sleep for 5 seconds and print again, they end up the same as the wrong ones.
This is actually the reason why trainer, function, args, kwargs = copy.deepcopy((trainer, function, args, kwargs)) is needed in the Lightning code: the objects shared between the threads need to be decoupled.
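To make the shared-object point concrete, here is a minimal sketch using plain Python threads (no TPUs or torch_xla involved; the sleep merely stands in for the barrier): both threads write their own ordinal to the same attribute, and after the pause both read whichever value was written last, mirroring the mismatch shown above.

import threading
import time

class RankHolder:
    # Stands in for the shared rank_zero_only object.
    rank = 0

shared = RankHolder()

def worker(ordinal):
    shared.rank = ordinal        # each "core" records its own ordinal
    time.sleep(0.1)              # stands in for xm.rendezvous("barrier")
    # Both threads now read the same, last-written value.
    print(f"ordinal={ordinal} sees shared.rank={shared.rank}")

threads = [threading.Thread(target=worker, args=(i,)) for i in (0, 1)]
for t in threads:
    t.start()
for t in threads:
    t.join()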
@will-cromar Isn't this a matter of delaying the init until after forking? The diff below keeps wandb from initializing 4 times (on a vx-8) and ending up with mixed stream ids.
@@ -59,12 +59,14 @@
     data = LinearDataModule()
     model = LinearRegression()
+    logger = pl.loggers.WandbLogger(project="tpu_debug")
+    logger.experiment
     trainer = pl.Trainer(
         accelerator="tpu",
         devices=8,
         enable_checkpointing=False,
         precision="bf16-mixed",
-        logger=pl.loggers.WandbLogger(project="tpu_debug"),
+        logger=logger,
         max_epochs=100,
         enable_progress_bar=True,
     )
However, there will also be 4 of these warnings from attempts to create a new session:
.../.local/lib/python3.8/site-packages/lightning/pytorch/loggers/wandb.py:391: UserWarning: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
rank_zero_warn(
and it will hang (possibly related: https://docs.wandb.ai/guides/integrations/lightning#how-to-use-multiple-gpus-with-lightning-and-wb).
I'm not sure what the proper patch would be within lightning.
After debugging this for a bit, the issue is that you have to call wandb.login before the fit (before the forks?), e.g.:
@@ -59,12 +59,14 @@
     data = LinearDataModule()
     model = LinearRegression()
+    import wandb
+    wandb.login()
     trainer = pl.Trainer(
         accelerator="tpu",
         devices=8,
As an aside, I had a (user) issue with consolidating everything under one run:
TLDR
Either set things up on Google's TPU VMs via:
python3 -m pip install --upgrade pip
python3 -m pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/master.zip wandb -U
or the _WANDB_GREATER_EQUAL_0_12_10 check needs to be less strict.
More details
When the _WANDB_GREATER_EQUAL_0_12_10 check fails, the pickling hack that unifies the runs is skipped.
(Pdb) p _WANDB_GREATER_EQUAL_0_12_10
ContextualVersionConflict: (urllib3 1.25.8 (/usr/lib/python3/dist-packages), Requirement.parse('urllib3>=1.26.11; python_version >= "3.6"'), {'sentry-sdk'}). HINT: Try running `pip install -U 'wandb>=0.12.10'`
pip on a fresh Google --version=tpu-vm-pt-2.0 TPU VM is 20.0.2, so it doesn't have requirements backtracking.
The VM also installed pip via apt, so doing python3 -m pip install --upgrade pip doesn't update the pip on the default PATH.
This is my user error in not updating and using the right pip and not fixing all the environment warnings, but maybe the RequirementCache class is a bit too strict: it could just check whether the version number is satisfied rather than whether all the sub-requirements are also satisfied.
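For illustration, a version-only check along those lines could look like the sketch below (version_only_satisfied is a hypothetical helper, not part of lightning-utilities; it assumes the packaging library is importable, which does appear in the pip list above): it compares only the installed version of the named package against the specifier and ignores whether that package's own sub-dependencies resolve.

from importlib.metadata import PackageNotFoundError, version
from packaging.requirements import Requirement
from packaging.version import Version

def version_only_satisfied(requirement: str) -> bool:
    # Check only the installed version of the named package against the
    # specifier, ignoring conflicts in its sub-dependencies.
    req = Requirement(requirement)  # e.g. "wandb>=0.12.10"
    try:
        installed = Version(version(req.name))
    except PackageNotFoundError:
        return False
    return installed in req.specifier

# Would return True for wandb==0.15.7 even with the urllib3 conflict above.
print(version_only_satisfied("wandb>=0.12.10"))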
Hi @s22chan
After debugging this for a bit, the issue is that you have to wandb.login before the fit (before the forks?). eg:
I recommend that you do wandb login in the command line instead (one time only). Then you will be automatically logged in whenever you call wandb in Python.
Regarding the other issue:
We have this trick in the logger to init the experiment when processes get launched (see comment in the code): https://github.com/Lightning-AI/lightning/blob/6511ac28759718a524dd00e627c186fb6baea763/src/lightning/pytorch/loggers/wandb.py#L349-L356
It would be very helpful if you could check whether this code path gets triggered or not in your case.
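One quick way to check could be to wrap __getstate__ with a print and watch each process's output (a debugging sketch; SpyWandbLogger is a hypothetical name, not part of Lightning):

import os

from lightning.pytorch.loggers import WandbLogger

class SpyWandbLogger(WandbLogger):
    def __getstate__(self):
        # Shows which process pickles the logger, i.e. whether __getstate__
        # (where the experiment pre-init lives) is reached at all.
        print(f"[pid {os.getpid()}] WandbLogger.__getstate__ called")
        return super().__getstate__()

# Pass SpyWandbLogger(project="tpu_debug") to the Trainer in the repro above
# instead of WandbLogger and look for the print in the spawned processes.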
I didn't fully understand your comment about _WANDB_GREATER_EQUAL_0_12_10. Are you saying you have wandb>=0.12.10 installed, yet the check failed and defaulted to False? If so, we could consider setting this version as the minimum required version, so we don't have to check it in the first place.
@awaelchli sorry if the messages were a bit scattered yesterday.
I recommend that you do wandb login in the command line instead (one time only).
I've already done that. The wandb.login() before the fork/spawn is required to avoid a data race between the two TPU threads launched on rank 0 for the logger init, which leads to the originally reported crash.
@alstonlo is inferring that much of the rank_zero machinery in place for logging/profiling (and possibly more) doesn't work in a TPU scenario with the PJRT change, because there are now two threads that have rank 0.
I didn't fully understand your comment about _WANDB_GREATER_EQUAL_0_12_10.
wandb was at wandb==0.15.7, but because of a conflict in urllib3 (a sub-dependency of wandb), the bool cast of RequirementCache fails. This is not at all obvious as a user.
Any updates on this issue?
Related: https://github.com/Lightning-AI/pytorch-lightning/issues/19035 (not wandb, but logging and data races on the threads).