lightning-hydra-template

How to run test_experiments if the experiment uses wandb callbacks?

RafayAK opened this issue 3 years ago

Hi, I saw that wandb callbacks can't be used when running with the trainer flag `++trainer.fast_dev_run=true`. This causes a problem when running tests on experiments that use the wandb callbacks from the `wandb-callbacks` branch. The exact error is:

>>> pytest tests/test_sweeps.py::test_experiments
...
E             rank_zero_warn(
E           Error executing job with overrides: ['experiment=moons_experiment', '++trainer.fast_dev_run=True', 'logger=[]']
E           Traceback (most recent call last):
E             File "/home/rafay/Documents/data-augmentation/src/train.py", line 26, in main
E               metric_dict, _ = train(cfg)
E             File "/home/rafay/Documents/data-augmentation/src/utils/utils.py", line 42, in wrap
E               raise ex
E             File "/home/rafay/Documents/data-augmentation/src/utils/utils.py", line 39, in wrap
E               metric_dict, object_dict = task_func(cfg=cfg)
E             File "/home/rafay/Documents/data-augmentation/src/tasks/train_task.py", line 63, in train
E               trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
E             File "/home/rafay/.local/share/virtualenvs/data-augmentation-0AjDNs-r/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 770, in fit
E               self._call_and_handle_interrupt(
E             File "/home/rafay/.local/share/virtualenvs/data-augmentation-0AjDNs-r/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 723, in _call_and_handle_interrupt
E               return trainer_fn(*args, **kwargs)
E             File "/home/rafay/.local/share/virtualenvs/data-augmentation-0AjDNs-r/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 811, in _fit_impl
E               results = self._run(model, ckpt_path=self.ckpt_path)
E             File "/home/rafay/.local/share/virtualenvs/data-augmentation-0AjDNs-r/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1236, in _run
E               results = self._run_stage()
E             File "/home/rafay/.local/share/virtualenvs/data-augmentation-0AjDNs-r/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1323, in _run_stage
E               return self._run_train()
E             File "/home/rafay/.local/share/virtualenvs/data-augmentation-0AjDNs-r/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1353, in _run_train
E               self.fit_loop.run()
E             File "/home/rafay/.local/share/virtualenvs/data-augmentation-0AjDNs-r/lib/python3.10/site-packages/pytorch_lightning/loops/base.py", line 199, in run
E               self.on_run_start(*args, **kwargs)
E             File "/home/rafay/.local/share/virtualenvs/data-augmentation-0AjDNs-r/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 217, in on_run_start
E               self.trainer._call_callback_hooks("on_train_start")
E             File "/home/rafay/.local/share/virtualenvs/data-augmentation-0AjDNs-r/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1636, in _call_callback_hooks
E               fn(self, self.lightning_module, *args, **kwargs)
E             File "/home/rafay/.local/share/virtualenvs/data-augmentation-0AjDNs-r/lib/python3.10/site-packages/pytorch_lightning/utilities/rank_zero.py", line 32, in wrapped_fn
E               return fn(*args, **kwargs)
E             File "/home/rafay/Documents/data-augmentation/src/callbacks/wandb_callbacks.py", line 54, in on_train_start
E               logger = get_wandb_logger(trainer=trainer)
E             File "/home/rafay/Documents/data-augmentation/src/callbacks/wandb_callbacks.py", line 27, in get_wandb_logger
E               raise Exception(
E           Exception: Cannot use wandb callbacks since pytorch lightning disables loggers in `fast_dev_run=true` mode.
E           
E           Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.

tests/helpers/run_sh_command.py:19: Failed

Basically, the problem stems from the exception raised in the `get_wandb_logger` function:

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger


def get_wandb_logger(trainer: Trainer) -> WandbLogger:
    """Safely get Weights&Biases logger from Trainer."""

    if trainer.fast_dev_run:
        raise Exception(
            "Cannot use wandb callbacks since pytorch lightning disables loggers in `fast_dev_run=true` mode."
        )

    if isinstance(trainer.logger, WandbLogger):
        return trainer.logger
...

My experiment overrides look as follows:

# @package _global_

# to execute this experiment run:
# python train.py experiment=moons_experiment
defaults:
  - override /datamodule: moons.yaml
  - override /model: mlp.yaml
  - override /callbacks: default_with_wandb_callbacks.yaml
  - override /logger: wandb.yaml
  - override /trainer: default.yaml

and the default_with_wandb_callbacks.yaml file looks as follows:

defaults:
    - default.yaml
    - wandb.yaml

So, how would you suggest running test cases in this instance?

RafayAK avatar Aug 08 '22 12:08 RafayAK

As the exception message says, Lightning disables loggers in `fast_dev_run` mode, so you shouldn't use that flag to debug or test wandb callbacks.

There are, however, other flags you can use, for example limiting the run to 1% of the training batches and 10% of the validation and test batches:

python train.py debug=limit

# or
python train.py +trainer.limit_train_batches=0.01 +trainer.limit_val_batches=0.1 +trainer.limit_test_batches=0.1
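For reference, a debug config like `debug=limit` typically bundles these overrides into one file. A rough sketch of what such a config might contain (exact names and values are assumptions and depend on your template version):

```yaml
# @package _global_

# hypothetical configs/debug/limit.yaml: run on a small fraction of the data
defaults:
  - default

trainer:
  max_epochs: 1
  limit_train_batches: 0.01
  limit_val_batches: 0.1
  limit_test_batches: 0.1
```

Unlike `fast_dev_run`, these limits keep loggers enabled, so wandb callbacks still work.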

ashleve avatar Aug 17 '22 11:08 ashleve

Thanks, that should do it. I'll also override the callbacks to use the defaults when testing.
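Concretely, the test override could pin the callbacks group back to the defaults so no wandb callback is instantiated (a command sketch; the group name `callbacks=default` is an assumption based on the configs shown above):

```shell
python train.py experiment=moons_experiment callbacks=default ++trainer.fast_dev_run=true
```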

RafayAK avatar Aug 23 '22 17:08 RafayAK