# MLFlow logger with remote tracking fails with CLI
### Bug description

Running with the LightningCLI, the MLflow logger, and the `MLFLOW_TRACKING_URI` environment variable set causes an assertion failure when the CLI tries to save its config. I think using a remote tracking server means no local log files are created, which the CLI doesn't handle.
I suspect it's a similar issue to #12748.
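The suspicion can be checked in isolation; a minimal sketch (assumes `mlflow` is installed; no server needs to be running just to read the property):

```python
from pytorch_lightning.loggers import MLFlowLogger

# With a remote tracking URI the logger has no local directory to write to,
# which appears to be what trips the CLI's assertion later on.
logger = MLFlowLogger(tracking_uri="http://localhost:5000")
print(logger.save_dir)  # expected: None
```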
### How to reproduce the bug

```python
from pytorch_lightning.cli import LightningCLI
from helpers import BoringModel, BoringDataModule

cli = LightningCLI(
    BoringModel,
    BoringDataModule,
    trainer_defaults=dict(
        max_epochs=1,
        logger="pytorch_lightning.loggers.MLFlowLogger"
    )
)
```

```console
$ mlflow server
...
$ MLFLOW_TRACKING_URI=http://localhost:5000 python main.py fit
```
### Error messages and logs

```
/home/vscode/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/configuration_validator.py:106: UserWarning: You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.
  rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.")
Traceback (most recent call last):
  File "/workspaces/mlflow_log_error/main.py", line 4, in <module>
    cli = LightningCLI(
  File "/home/vscode/.local/lib/python3.10/site-packages/pytorch_lightning/cli.py", line 354, in __init__
    self._run_subcommand(self.subcommand)
  File "/home/vscode/.local/lib/python3.10/site-packages/pytorch_lightning/cli.py", line 665, in _run_subcommand
    fn(**fn_kwargs)
  File "/home/vscode/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 603, in fit
    call._call_and_handle_interrupt(
  File "/home/vscode/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 38, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/vscode/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 645, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/home/vscode/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1037, in _run
    self._call_setup_hook()  # allow user to setup lightning_module in accelerator environment
  File "/home/vscode/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1285, in _call_setup_hook
    self._call_callback_hooks("setup", stage=fn)
  File "/home/vscode/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1380, in _call_callback_hooks
    fn(self, self.lightning_module, *args, **kwargs)
  File "/home/vscode/.local/lib/python3.10/site-packages/pytorch_lightning/cli.py", line 216, in setup
    assert log_dir is not None
AssertionError
```
### Environment
* CUDA:
- GPU: None
- available: False
- version: 11.7
* Lightning:
- lightning: 1.9.0rc0
- lightning-cloud: 0.5.16
- lightning-utilities: 0.5.0
- pytorch-lightning: 1.8.6
- torch: 1.13.1
- torchmetrics: 0.11.0
- torchvision: 0.14.1
* Packages:
- aiohttp: 3.8.3
- aiosignal: 1.3.1
- alembic: 1.9.1
- antlr4-python3-runtime: 4.9.3
- anyio: 3.6.2
- argon2-cffi: 21.3.0
- argon2-cffi-bindings: 21.2.0
- arrow: 1.2.3
- asttokens: 2.2.1
- async-timeout: 4.0.2
- attrs: 22.2.0
- babel: 2.11.0
- backcall: 0.2.0
- beautifulsoup4: 4.11.1
- bleach: 5.0.1
- blessed: 1.19.1
- build: 0.9.0
- certifi: 2022.12.7
- cffi: 1.15.1
- charset-normalizer: 2.1.1
- click: 8.1.3
- cloudpickle: 2.2.0
- comm: 0.1.2
- commonmark: 0.9.1
- contourpy: 1.0.6
- croniter: 1.3.8
- cycler: 0.11.0
- databricks-cli: 0.17.4
- dateutils: 0.6.12
- debugpy: 1.6.5
- decorator: 5.1.1
- deepdiff: 6.2.3
- defusedxml: 0.7.1
- dnspython: 2.2.1
- docker: 6.0.1
- docstring-parser: 0.15
- email-validator: 1.3.0
- entrypoints: 0.4
- executing: 1.2.0
- fastapi: 0.88.0
- fastjsonschema: 2.16.2
- flask: 2.2.2
- fonttools: 4.38.0
- fqdn: 1.5.1
- frozenlist: 1.3.3
- fsspec: 2022.11.0
- gitdb: 4.0.10
- gitpython: 3.1.30
- greenlet: 2.0.1
- gunicorn: 20.1.0
- h11: 0.14.0
- httpcore: 0.16.3
- httptools: 0.5.0
- httpx: 0.23.3
- hydra-core: 1.3.1
- idna: 3.4
- importlib-metadata: 5.2.0
- importlib-resources: 5.10.2
- inquirer: 3.1.2
- ipykernel: 6.20.1
- ipython: 8.8.0
- ipython-genutils: 0.2.0
- isoduration: 20.11.0
- itsdangerous: 2.1.2
- jedi: 0.18.2
- jinja2: 3.1.2
- joblib: 1.2.0
- json5: 0.9.11
- jsonargparse: 4.19.0
- jsonpointer: 2.3
- jsonschema: 4.17.3
- jupyter-client: 7.4.8
- jupyter-core: 5.1.3
- jupyter-events: 0.6.0
- jupyter-server: 2.0.6
- jupyter-server-terminals: 0.4.4
- jupyterlab: 3.5.2
- jupyterlab-pygments: 0.2.2
- jupyterlab-server: 2.18.0
- kiwisolver: 1.4.4
- lightning: 1.9.0rc0
- lightning-cloud: 0.5.16
- lightning-utilities: 0.5.0
- llvmlite: 0.39.1
- mako: 1.2.4
- markdown: 3.4.1
- markupsafe: 2.1.1
- matplotlib: 3.6.2
- matplotlib-inline: 0.1.6
- mistune: 2.0.4
- mlflow: 2.1.1
- multidict: 6.0.4
- nbclassic: 0.4.8
- nbclient: 0.7.2
- nbconvert: 7.2.7
- nbformat: 5.7.1
- nest-asyncio: 1.5.6
- notebook: 6.5.2
- notebook-shim: 0.2.2
- numba: 0.56.4
- numpy: 1.23.5
- nvidia-cublas-cu11: 11.10.3.66
- nvidia-cuda-nvrtc-cu11: 11.7.99
- nvidia-cuda-runtime-cu11: 11.7.99
- nvidia-cudnn-cu11: 8.5.0.96
- oauthlib: 3.2.2
- omegaconf: 2.3.0
- ordered-set: 4.1.0
- orjson: 3.8.4
- packaging: 21.3
- pandas: 1.5.2
- pandocfilters: 1.5.0
- parso: 0.8.3
- pep517: 0.13.0
- pexpect: 4.8.0
- pickleshare: 0.7.5
- pillow: 9.4.0
- pip: 22.3.1
- pip-tools: 6.12.1
- platformdirs: 2.6.2
- prometheus-client: 0.15.0
- prompt-toolkit: 3.0.36
- protobuf: 3.20.1
- psutil: 5.9.4
- ptyprocess: 0.7.0
- pure-eval: 0.2.2
- pyarrow: 10.0.1
- pycparser: 2.21
- pydantic: 1.10.4
- pygments: 2.14.0
- pyjwt: 2.6.0
- pyparsing: 3.0.9
- pyrsistent: 0.19.3
- python-dateutil: 2.8.2
- python-dotenv: 0.21.0
- python-editor: 1.0.4
- python-json-logger: 2.0.4
- python-multipart: 0.0.5
- pytorch-lightning: 1.8.6
- pytz: 2022.7
- pyyaml: 6.0
- pyzmq: 24.0.1
- querystring-parser: 1.2.4
- readchar: 4.0.3
- requests: 2.28.1
- rfc3339-validator: 0.1.4
- rfc3986: 1.5.0
- rfc3986-validator: 0.1.1
- rich: 13.0.1
- scikit-learn: 1.2.0
- scipy: 1.10.0
- send2trash: 1.8.0
- setuptools: 65.5.0
- shap: 0.41.0
- six: 1.16.0
- slicer: 0.0.7
- smmap: 5.0.0
- sniffio: 1.3.0
- soupsieve: 2.3.2.post1
- sqlalchemy: 1.4.46
- sqlparse: 0.4.3
- stack-data: 0.6.2
- starlette: 0.22.0
- starsessions: 1.3.0
- tabulate: 0.9.0
- tensorboardx: 2.5.1
- terminado: 0.17.1
- threadpoolctl: 3.1.0
- tinycss2: 1.2.1
- tomli: 2.0.1
- torch: 1.13.1
- torchmetrics: 0.11.0
- torchvision: 0.14.1
- tornado: 6.2
- tqdm: 4.64.1
- traitlets: 5.8.1
- typeshed-client: 2.1.0
- typing-extensions: 4.4.0
- ujson: 5.7.0
- uri-template: 1.2.0
- urllib3: 1.26.13
- uvicorn: 0.20.0
- uvloop: 0.17.0
- watchfiles: 0.18.1
- wcwidth: 0.2.5
- webcolors: 1.12
- webencodings: 0.5.1
- websocket-client: 1.4.2
- websockets: 10.4
- werkzeug: 2.2.2
- wheel: 0.38.4
- yarl: 1.8.2
- zipp: 3.11.0
* System:
- OS: Linux
- architecture:
  - 64bit
  - ELF
- processor:
- python: 3.10.9
- version: #1 SMP Wed Mar 2 00:30:59 UTC 2022
### More info
No response
cc @carmocca @mauvilsa
A temporary workaround for this issue is to declare a TensorBoard logger ahead of the MLflow one, so the CLI picks up the TensorBoard logger's local `save_dir`. Like so,
```python
cli = LightningCLI(
    BoringModel,
    BoringDataModule,
    trainer_defaults=dict(
        max_epochs=1,
        logger=[
            {
                "class_path": "pytorch_lightning.loggers.TensorBoardLogger",
                "init_args": {
                    "save_dir": "tb_logs",
                },
            },
            "pytorch_lightning.loggers.MLFlowLogger",
        ],
    ),
)
```
@Benjamin-Etheredge Here is my workaround, which still leverages the goodness of the CLI module and its YAML file.
```python
import yaml
from pytorch_lightning import Trainer
from pytorch_lightning.cli import LightningCLI
from pytorch_lightning.loggers import MLFlowLogger

# LightningToneClassifier, ToneDataModule, and prepare_fit_dataloader
# are user-defined helpers from my project.
cli = LightningCLI(
    LightningToneClassifier,
    ToneDataModule,
    run=False,
)

with open("lightning/trainer_config.yaml", "r") as f:
    config = yaml.safe_load(f)

# Replace the logger from the YAML config with an MLflow logger
config["trainer"]["logger"] = MLFlowLogger(
    experiment_name="xxxx",
    tracking_uri="xxxx",
    log_model=True,
)

train_dataloader, val_dataloader = prepare_fit_dataloader(cli)
trainer = Trainer(**config["trainer"])
trainer.logger.log_hyperparams(config)
trainer.fit(cli.model, train_dataloader, val_dataloader)
```
Hi @vincentwu0730,
thank you for your workaround! Can you share what `prepare_fit_dataloader` is?
The issue surfaces through the usage in LightningCLI because it accesses the log dir, but the origin of the problem, as suspected by @Benjamin-Etheredge, is that `save_dir` from `MLFlowLogger` returns `None` when tracking is not done locally:
https://github.com/Lightning-AI/lightning/blob/41f0425a8dbd54030c5b711f92340dc8dc41c173/src/lightning/pytorch/loggers/mlflow.py#L299-L301
Two possible solutions come to mind to address this:
- Return a default local directory instead of `None` so LightningCLI can save the config (see the sketch after this list).
- In the LightningCLI, if the log dir is `None`, save the config to a different place (as if there were no logger).
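A minimal sketch of the first option, assuming one subclasses `MLFlowLogger` (the class name and the `./mlruns` fallback are illustrative, not an existing API):

```python
from typing import Optional

from pytorch_lightning.loggers import MLFlowLogger


class LocalFallbackMLFlowLogger(MLFlowLogger):
    """Illustrative only: report a default local directory instead of None
    so that LightningCLI always has somewhere to save the config."""

    @property
    def save_dir(self) -> Optional[str]:
        # Fall back to a local directory when tracking is remote
        return super().save_dir or "./mlruns"
```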
> Two possible solutions come to mind to address this:
I can suggest another solution: implement a custom save-config class that saves the config to MLflow as an artifact instead of saving it locally. If logging remotely, it makes sense to also save the config in the same place.
A realization of @mauvilsa's idea:
```python
import torch
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.cli import LightningCLI, SaveConfigCallback


class MLFlowSaveConfigCallback(SaveConfigCallback):
    def __init__(self, parser, config, config_filename='config.yaml', overwrite=False, multifile=False):
        super().__init__(parser, config, config_filename, overwrite, multifile, save_to_log_dir=False)

    def save_config(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        # Convert Namespace to dict
        config_dict = vars(self.config)
        # Log parameters to MLflow
        pl_module.logger.log_hyperparams(config_dict)


def cli_compile_main():
    # PRDataModule is a datamodule from my project
    cli = LightningCLI(datamodule_class=PRDataModule, run=False, save_config_callback=MLFlowSaveConfigCallback)
    compiled_model = torch.compile(cli.model)
    cli.trainer.fit(compiled_model, datamodule=cli.datamodule)
    cli.trainer.test(datamodule=cli.datamodule)
```
A slight modification of @terbed's version, if you want to save the file as YAML:
```python
import tempfile
from pathlib import Path

from lightning import Trainer, LightningModule
from lightning.pytorch.cli import SaveConfigCallback


class MLFlowSaveConfigCallback(SaveConfigCallback):
    def __init__(self, parser, config, config_filename='config.yaml', overwrite=False, multifile=False):
        super().__init__(parser, config, config_filename, overwrite, multifile, save_to_log_dir=False)

    def save_config(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        if trainer.is_global_zero:
            # Write the config to a temporary file, then upload it as an MLflow artifact
            with tempfile.TemporaryDirectory() as tmp_dir:
                config_path = Path(tmp_dir) / 'config.yaml'
                self.parser.save(
                    self.config, config_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile
                )
                trainer.logger.experiment.log_artifact(local_path=config_path,
                                                       run_id=trainer.logger.run_id)
```
A slight refactoring of @adrianomartinelli's code, with a full example:
```python
import tempfile
from pathlib import Path

from lightning import Trainer, LightningModule
from lightning.pytorch.cli import LightningCLI, SaveConfigCallback


class MLFlowSaveConfigCallback(SaveConfigCallback):
    def __init__(self, parser, config, config_filename='config.yaml', overwrite=False, multifile=False):
        super().__init__(parser, config, config_filename, overwrite, multifile, save_to_log_dir=False)

    def save_config(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        if trainer.is_global_zero:
            with tempfile.TemporaryDirectory() as tmp_dir:
                config_path = Path(tmp_dir) / 'config.yaml'
                self.parser.save(
                    self.config, config_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile
                )
                trainer.logger.experiment.log_artifact(local_path=config_path,
                                                       run_id=trainer.logger.run_id)


def main():
    LightningCLI(save_config_callback=MLFlowSaveConfigCallback)


if __name__ == "__main__":
    main()
```
I guess saving the config to MLflow is the best approach. It's unfortunate that the user of the CLI is then forced to use MLflow, and that selecting a different logger from the command line doesn't work.
Perhaps `SaveConfigCallback` should skip saving the config when `log_dir` is `None`, instead of failing an assertion. There's the `save_to_log_dir` argument, but it's not possible to set it to `False` except by subclassing.
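A minimal sketch of that behavior, assuming a `SaveConfigCallback` subclass (the class name is illustrative):

```python
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.cli import SaveConfigCallback


class OptionalSaveConfigCallback(SaveConfigCallback):
    """Illustrative only: skip saving instead of asserting when the logger
    provides no local log dir (e.g. a remote MLflow tracking server)."""

    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        if self.save_to_log_dir and trainer.log_dir is None:
            return  # nothing to save locally; behave as if there were no logger
        super().setup(trainer, pl_module, stage)
```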