How to run test_experiments if the experiment uses wandb callbacks?
Hi, I noticed that wandb callbacks can't be used when running with the trainer flag ++trainer.fast_dev_run=true. This is a problem when running the test cases on experiments that use the wandb callbacks from the wandb-callbacks branch. The exact error is:
>>> pytest tests/test_sweeps.py::test_experiments
...
E rank_zero_warn(
E Error executing job with overrides: ['experiment=moons_experiment', '++trainer.fast_dev_run=True', 'logger=[]']
E Traceback (most recent call last):
E File "/home/rafay/Documents/data-augmentation/src/train.py", line 26, in main
E metric_dict, _ = train(cfg)
E File "/home/rafay/Documents/data-augmentation/src/utils/utils.py", line 42, in wrap
E raise ex
E File "/home/rafay/Documents/data-augmentation/src/utils/utils.py", line 39, in wrap
E metric_dict, object_dict = task_func(cfg=cfg)
E File "/home/rafay/Documents/data-augmentation/src/tasks/train_task.py", line 63, in train
E trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
E File "/home/rafay/.local/share/virtualenvs/data-augmentation-0AjDNs-r/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 770, in fit
E self._call_and_handle_interrupt(
E File "/home/rafay/.local/share/virtualenvs/data-augmentation-0AjDNs-r/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 723, in _call_and_handle_interrupt
E return trainer_fn(*args, **kwargs)
E File "/home/rafay/.local/share/virtualenvs/data-augmentation-0AjDNs-r/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 811, in _fit_impl
E results = self._run(model, ckpt_path=self.ckpt_path)
E File "/home/rafay/.local/share/virtualenvs/data-augmentation-0AjDNs-r/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1236, in _run
E results = self._run_stage()
E File "/home/rafay/.local/share/virtualenvs/data-augmentation-0AjDNs-r/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1323, in _run_stage
E return self._run_train()
E File "/home/rafay/.local/share/virtualenvs/data-augmentation-0AjDNs-r/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1353, in _run_train
E self.fit_loop.run()
E File "/home/rafay/.local/share/virtualenvs/data-augmentation-0AjDNs-r/lib/python3.10/site-packages/pytorch_lightning/loops/base.py", line 199, in run
E self.on_run_start(*args, **kwargs)
E File "/home/rafay/.local/share/virtualenvs/data-augmentation-0AjDNs-r/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 217, in on_run_start
E self.trainer._call_callback_hooks("on_train_start")
E File "/home/rafay/.local/share/virtualenvs/data-augmentation-0AjDNs-r/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1636, in _call_callback_hooks
E fn(self, self.lightning_module, *args, **kwargs)
E File "/home/rafay/.local/share/virtualenvs/data-augmentation-0AjDNs-r/lib/python3.10/site-packages/pytorch_lightning/utilities/rank_zero.py", line 32, in wrapped_fn
E return fn(*args, **kwargs)
E File "/home/rafay/Documents/data-augmentation/src/callbacks/wandb_callbacks.py", line 54, in on_train_start
E logger = get_wandb_logger(trainer=trainer)
E File "/home/rafay/Documents/data-augmentation/src/callbacks/wandb_callbacks.py", line 27, in get_wandb_logger
E raise Exception(
E Exception: Cannot use wandb callbacks since pytorch lightning disables loggers in `fast_dev_run=true` mode.
E
E Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
tests/helpers/run_sh_command.py:19: Failed
The problem stems from the exception raised in the get_wandb_logger function:
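from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger


def get_wandb_logger(trainer: Trainer) -> WandbLogger:
    """Safely get Weights&Biases logger from Trainer."""
    # Lightning disables loggers in fast_dev_run mode, so fail early with a clear message
    if trainer.fast_dev_run:
        raise Exception(
            "Cannot use wandb callbacks since pytorch lightning disables loggers in `fast_dev_run=true` mode."
        )
    if isinstance(trainer.logger, WandbLogger):
        return trainer.logger
    ...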
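If the goal were instead to let the wandb callbacks skip themselves in such runs rather than fail, one option is a lookup that returns None instead of raising. This is a hypothetical sketch: find_logger is not part of the template or of Lightning. It uses duck typing (trainer.loggers, falling back to trainer.logger) so it can be checked without Lightning installed; in a real callback you would pass WandbLogger as logger_cls.

```python
from typing import Any, Optional, Type


def find_logger(trainer: Any, logger_cls: Type) -> Optional[Any]:
    """Return the first logger on `trainer` that is an instance of `logger_cls`, else None.

    Hypothetical non-raising alternative to get_wandb_logger.
    """
    # Lightning disables loggers in fast_dev_run mode, so there is nothing to find
    if getattr(trainer, "fast_dev_run", False):
        return None
    # Prefer the list of loggers; fall back to the single `logger` attribute
    loggers = getattr(trainer, "loggers", None)
    if loggers is None:
        single = getattr(trainer, "logger", None)
        loggers = [single] if single is not None else []
    for candidate in loggers:
        if isinstance(candidate, logger_cls):
            return candidate
    return None
```

A callback's on_train_start could then call find_logger(trainer, WandbLogger) and return early when it gets None. The trade-off is that the callback's own logic is silently not exercised in those test runs, so limiting the number of batches, as suggested in the reply, tests more of the code.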
My experiment overrides look as follows:
# @package _global_
# to execute this experiment run:
# python train.py experiment=moons_experiment.yaml
defaults:
- override /datamodule: moons.yaml
- override /model: mlp.yaml
- override /callbacks: default_with_wandb_callbacks.yaml
- override /logger: wandb.yaml
- override /trainer: default.yaml
And the default_with_wandb_callbacks.yaml file looks as follows:
defaults:
- default.yaml
- wandb.yaml
So, how would you suggest running test cases in this instance?
As the exception message says, Lightning disables loggers in fast_dev_run mode, so you shouldn't use it to debug or test wandb callbacks.
There are, however, other flags you could use, for example, running on only 1% of the training data:
python train.py debug=limit
# or
python train.py +trainer.limit_train_batches=0.01 +trainer.limit_val_batches=0.1 +trainer.limit_test_batches=0.1
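Applied to the test itself, a rough sketch of the override list might look like this. build_test_overrides is a hypothetical helper, not something in tests/. The limit values mirror the command above, max_epochs=1 is an added assumption to keep the run short, and ++ (add or override) is used instead of + so it works whether or not the trainer config already defines these keys. Because the failing test also passes logger=[] (see the overrides in the error message), the wandb callbacks would presumably still find no WandbLogger, so by default the sketch also switches the callbacks group to default.

```python
from typing import List


def build_test_overrides(experiment: str, keep_wandb_callbacks: bool = False) -> List[str]:
    """Hydra overrides for a quick smoke test of one experiment (hypothetical helper).

    Runs on a small fraction of each split instead of using fast_dev_run,
    so Lightning keeps loggers enabled.
    """
    overrides = [
        f"experiment={experiment}",
        "++trainer.max_epochs=1",  # assumption: one epoch is enough for a smoke test
        "++trainer.limit_train_batches=0.01",
        "++trainer.limit_val_batches=0.1",
        "++trainer.limit_test_batches=0.1",
    ]
    if not keep_wandb_callbacks:
        # With logger=[] no WandbLogger exists, so fall back to the plain
        # callbacks group (configs/callbacks/default.yaml)
        overrides += ["logger=[]", "callbacks=default"]
    return overrides
```

This assumes command-line group overrides take precedence over the experiment's own override /callbacks entry; running train.py with these overrides plus --cfg job prints the composed config and is a quick way to confirm. With keep_wandb_callbacks=True the wandb logger stays on, so setting WANDB_MODE=offline in the test environment avoids network calls.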
Thanks, that should do it. I'll also override the callbacks to use just the defaults when testing.