MLFlow logger with remote tracking fails with CLI

Open Benjamin-Etheredge opened this issue 2 years ago • 9 comments

Bug description

Running the LightningCLI with the MLflow logger while the MLFLOW_TRACKING_URI environment variable is set causes an assertion failure during setup. I think that with a remote tracking server no local log files are created, which the CLI doesn't handle.

I suspect it's a similar issue to #12748.

How to reproduce the bug

from pytorch_lightning.cli import LightningCLI
from helpers import BoringModel, BoringDataModule

cli = LightningCLI(
    BoringModel, 
    BoringDataModule, 
    trainer_defaults=dict(
        max_epochs=1,
        logger="pytorch_lightning.loggers.MLFlowLogger"
    )
)

$ mlflow server
...
$ MLFLOW_TRACKING_URI=http://localhost:5000 python main.py fit

Error messages and logs

/home/vscode/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/configuration_validator.py:106: UserWarning: You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.
  rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.")
Traceback (most recent call last):
  File "/workspaces/mlflow_log_error/main.py", line 4, in <module>
    cli = LightningCLI(
  File "/home/vscode/.local/lib/python3.10/site-packages/pytorch_lightning/cli.py", line 354, in __init__
    self._run_subcommand(self.subcommand)
  File "/home/vscode/.local/lib/python3.10/site-packages/pytorch_lightning/cli.py", line 665, in _run_subcommand
    fn(**fn_kwargs)
  File "/home/vscode/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 603, in fit
    call._call_and_handle_interrupt(
  File "/home/vscode/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 38, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/vscode/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 645, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/home/vscode/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1037, in _run
    self._call_setup_hook()  # allow user to setup lightning_module in accelerator environment
  File "/home/vscode/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1285, in _call_setup_hook
    self._call_callback_hooks("setup", stage=fn)
  File "/home/vscode/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1380, in _call_callback_hooks
    fn(self, self.lightning_module, *args, **kwargs)
  File "/home/vscode/.local/lib/python3.10/site-packages/pytorch_lightning/cli.py", line 216, in setup
    assert log_dir is not None
AssertionError

Environment

* CUDA:
        - GPU:               None
        - available:         False
        - version:           11.7
* Lightning:
        - lightning:         1.9.0rc0
        - lightning-cloud:   0.5.16
        - lightning-utilities: 0.5.0
        - pytorch-lightning: 1.8.6
        - torch:             1.13.1
        - torchmetrics:      0.11.0
        - torchvision:       0.14.1
* Packages:
        - aiohttp:           3.8.3
        - aiosignal:         1.3.1
        - alembic:           1.9.1
        - antlr4-python3-runtime: 4.9.3
        - anyio:             3.6.2
        - argon2-cffi:       21.3.0
        - argon2-cffi-bindings: 21.2.0
        - arrow:             1.2.3
        - asttokens:         2.2.1
        - async-timeout:     4.0.2
        - attrs:             22.2.0
        - babel:             2.11.0
        - backcall:          0.2.0
        - beautifulsoup4:    4.11.1
        - bleach:            5.0.1
        - blessed:           1.19.1
        - build:             0.9.0
        - certifi:           2022.12.7
        - cffi:              1.15.1
        - charset-normalizer: 2.1.1
        - click:             8.1.3
        - cloudpickle:       2.2.0
        - comm:              0.1.2
        - commonmark:        0.9.1
        - contourpy:         1.0.6
        - croniter:          1.3.8
        - cycler:            0.11.0
        - databricks-cli:    0.17.4
        - dateutils:         0.6.12
        - debugpy:           1.6.5
        - decorator:         5.1.1
        - deepdiff:          6.2.3
        - defusedxml:        0.7.1
        - dnspython:         2.2.1
        - docker:            6.0.1
        - docstring-parser:  0.15
        - email-validator:   1.3.0
        - entrypoints:       0.4
        - executing:         1.2.0
        - fastapi:           0.88.0
        - fastjsonschema:    2.16.2
        - flask:             2.2.2
        - fonttools:         4.38.0
        - fqdn:              1.5.1
        - frozenlist:        1.3.3
        - fsspec:            2022.11.0
        - gitdb:             4.0.10
        - gitpython:         3.1.30
        - greenlet:          2.0.1
        - gunicorn:          20.1.0
        - h11:               0.14.0
        - httpcore:          0.16.3
        - httptools:         0.5.0
        - httpx:             0.23.3
        - hydra-core:        1.3.1
        - idna:              3.4
        - importlib-metadata: 5.2.0
        - importlib-resources: 5.10.2
        - inquirer:          3.1.2
        - ipykernel:         6.20.1
        - ipython:           8.8.0
        - ipython-genutils:  0.2.0
        - isoduration:       20.11.0
        - itsdangerous:      2.1.2
        - jedi:              0.18.2
        - jinja2:            3.1.2
        - joblib:            1.2.0
        - json5:             0.9.11
        - jsonargparse:      4.19.0
        - jsonpointer:       2.3
        - jsonschema:        4.17.3
        - jupyter-client:    7.4.8
        - jupyter-core:      5.1.3
        - jupyter-events:    0.6.0
        - jupyter-server:    2.0.6
        - jupyter-server-terminals: 0.4.4
        - jupyterlab:        3.5.2
        - jupyterlab-pygments: 0.2.2
        - jupyterlab-server: 2.18.0
        - kiwisolver:        1.4.4
        - lightning:         1.9.0rc0
        - lightning-cloud:   0.5.16
        - lightning-utilities: 0.5.0
        - llvmlite:          0.39.1
        - mako:              1.2.4
        - markdown:          3.4.1
        - markupsafe:        2.1.1
        - matplotlib:        3.6.2
        - matplotlib-inline: 0.1.6
        - mistune:           2.0.4
        - mlflow:            2.1.1
        - multidict:         6.0.4
        - nbclassic:         0.4.8
        - nbclient:          0.7.2
        - nbconvert:         7.2.7
        - nbformat:          5.7.1
        - nest-asyncio:      1.5.6
        - notebook:          6.5.2
        - notebook-shim:     0.2.2
        - numba:             0.56.4
        - numpy:             1.23.5
        - nvidia-cublas-cu11: 11.10.3.66
        - nvidia-cuda-nvrtc-cu11: 11.7.99
        - nvidia-cuda-runtime-cu11: 11.7.99
        - nvidia-cudnn-cu11: 8.5.0.96
        - oauthlib:          3.2.2
        - omegaconf:         2.3.0
        - ordered-set:       4.1.0
        - orjson:            3.8.4
        - packaging:         21.3
        - pandas:            1.5.2
        - pandocfilters:     1.5.0
        - parso:             0.8.3
        - pep517:            0.13.0
        - pexpect:           4.8.0
        - pickleshare:       0.7.5
        - pillow:            9.4.0
        - pip:               22.3.1
        - pip-tools:         6.12.1
        - platformdirs:      2.6.2
        - prometheus-client: 0.15.0
        - prompt-toolkit:    3.0.36
        - protobuf:          3.20.1
        - psutil:            5.9.4
        - ptyprocess:        0.7.0
        - pure-eval:         0.2.2
        - pyarrow:           10.0.1
        - pycparser:         2.21
        - pydantic:          1.10.4
        - pygments:          2.14.0
        - pyjwt:             2.6.0
        - pyparsing:         3.0.9
        - pyrsistent:        0.19.3
        - python-dateutil:   2.8.2
        - python-dotenv:     0.21.0
        - python-editor:     1.0.4
        - python-json-logger: 2.0.4
        - python-multipart:  0.0.5
        - pytorch-lightning: 1.8.6
        - pytz:              2022.7
        - pyyaml:            6.0
        - pyzmq:             24.0.1
        - querystring-parser: 1.2.4
        - readchar:          4.0.3
        - requests:          2.28.1
        - rfc3339-validator: 0.1.4
        - rfc3986:           1.5.0
        - rfc3986-validator: 0.1.1
        - rich:              13.0.1
        - scikit-learn:      1.2.0
        - scipy:             1.10.0
        - send2trash:        1.8.0
        - setuptools:        65.5.0
        - shap:              0.41.0
        - six:               1.16.0
        - slicer:            0.0.7
        - smmap:             5.0.0
        - sniffio:           1.3.0
        - soupsieve:         2.3.2.post1
        - sqlalchemy:        1.4.46
        - sqlparse:          0.4.3
        - stack-data:        0.6.2
        - starlette:         0.22.0
        - starsessions:      1.3.0
        - tabulate:          0.9.0
        - tensorboardx:      2.5.1
        - terminado:         0.17.1
        - threadpoolctl:     3.1.0
        - tinycss2:          1.2.1
        - tomli:             2.0.1
        - torch:             1.13.1
        - torchmetrics:      0.11.0
        - torchvision:       0.14.1
        - tornado:           6.2
        - tqdm:              4.64.1
        - traitlets:         5.8.1
        - typeshed-client:   2.1.0
        - typing-extensions: 4.4.0
        - ujson:             5.7.0
        - uri-template:      1.2.0
        - urllib3:           1.26.13
        - uvicorn:           0.20.0
        - uvloop:            0.17.0
        - watchfiles:        0.18.1
        - wcwidth:           0.2.5
        - webcolors:         1.12
        - webencodings:      0.5.1
        - websocket-client:  1.4.2
        - websockets:        10.4
        - werkzeug:          2.2.2
        - wheel:             0.38.4
        - yarl:              1.8.2
        - zipp:              3.11.0
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                - ELF
        - processor:         
        - python:            3.10.9
        - version:           #1 SMP Wed Mar 2 00:30:59 UTC 2022

More info

No response

cc @carmocca @mauvilsa

Benjamin-Etheredge avatar Jan 09 '23 19:01 Benjamin-Etheredge

A temporary workaround for this issue is to declare a TensorBoard logger ahead of the MLflow one, like so:

cli = LightningCLI(
    BoringModel, 
    BoringDataModule, 
    trainer_defaults=dict(
        max_epochs=1,
        logger=[
            {
                "class_path": "pytorch_lightning.loggers.TensorBoardLogger", 
                "init_args": {
                    "save_dir": "tb_logs",
                }
            },
            "pytorch_lightning.loggers.MLFlowLogger"
        ],
    )
)

Benjamin-Etheredge avatar Jan 25 '23 20:01 Benjamin-Etheredge

@Benjamin-Etheredge Here is my workaround, which still leverages the CLI module and its YAML config file.

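# Import paths are assumed (pytorch_lightning, as in the original report);
# with the unified package, import from lightning.pytorch instead.
import yaml
from pytorch_lightning import Trainer
from pytorch_lightning.cli import LightningCLI
from pytorch_lightning.loggers import MLFlowLogger

# LightningToneClassifier, ToneDataModule and prepare_fit_dataloader are the
# poster's own project code and are not shown in this thread.
# run=False: parse the config and instantiate classes without starting a fit.
cli = LightningCLI(
    LightningToneClassifier,
    ToneDataModule,
    run=False,
)

# Build the Trainer by hand so the logger can be created in Python.
with open("lightning/trainer_config.yaml", "r") as f:
    config = yaml.safe_load(f)
config["trainer"]["logger"] = MLFlowLogger(
    experiment_name="xxxx",
    tracking_uri="xxxx",
    log_model=True,
)
train_dataloader, val_dataloader = prepare_fit_dataloader(cli)
trainer = Trainer(**config["trainer"])
trainer.logger.log_hyperparams(config)
trainer.fit(cli.model, train_dataloader, val_dataloader)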

vincentwu0730 avatar Feb 18 '23 07:02 vincentwu0730

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions - the Lightning Team!

stale[bot] avatar Apr 14 '23 06:04 stale[bot]

Hi @vincentwu0730 ,

Thank you for your workaround! Can you share what prepare_fit_dataloader is?

goncalomcorreia avatar Jul 27 '23 17:07 goncalomcorreia

The issue surfaces in LightningCLI because it reads the log dir, but the root cause, as @Benjamin-Etheredge suspected, is that the save_dir of MLFlowLogger returns None when tracking is not done locally:

https://github.com/Lightning-AI/lightning/blob/41f0425a8dbd54030c5b711f92340dc8dc41c173/src/lightning/pytorch/loggers/mlflow.py#L299-L301
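
To make that behaviour concrete, here is a minimal standalone sketch of the logic. It paraphrases the idea behind the linked lines rather than copying them: the helper name `mlflow_save_dir` and the `file:` prefix constant are assumptions for illustration, not the library's actual code.

```python
from typing import Optional

# Assumed prefix that marks a local, file-based MLflow tracking URI.
LOCAL_FILE_URI_PREFIX = "file:"


def mlflow_save_dir(tracking_uri: str) -> Optional[str]:
    """Return a local directory for file-based tracking, else None."""
    if tracking_uri.startswith(LOCAL_FILE_URI_PREFIX):
        # Strip the scheme so only the filesystem path remains.
        return tracking_uri[len(LOCAL_FILE_URI_PREFIX):]
    # A remote server (http://, https://, ...) has no local directory.
    return None


local = mlflow_save_dir("file:./mlruns")
remote = mlflow_save_dir("http://localhost:5000")
```

With a remote URI such as the one in the reproduction, the result is None, which is the value the CLI's `assert log_dir is not None` then trips over.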

Two possible solutions that come to my mind to address this:

  1. Return a default local directory instead of None so LightningCLI can save the config
  2. In the LightningCLI, if the value returned by the log dir is None, save the config to a different place (as if there is no logger).
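
Option 2 could look roughly like the sketch below. It is dependency-free and only illustrates the fallback decision; the function name `resolve_config_dir` and the `default_root_dir` parameter are hypothetical and do not refer to existing LightningCLI internals.

```python
from pathlib import Path
from typing import Optional


def resolve_config_dir(log_dir: Optional[str], default_root_dir: str) -> Path:
    """Pick where to write the config: the logger's log dir, or a fallback."""
    if log_dir is not None:
        return Path(log_dir)
    # A remote-tracking logger reports no local directory, so behave as if
    # there were no logger and use the fallback location instead.
    return Path(default_root_dir)


with_local_logger = resolve_config_dir("lightning_logs/version_0", "outputs")
with_remote_logger = resolve_config_dir(None, "outputs")
```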

awaelchli avatar Jul 28 '23 13:07 awaelchli

Quoting @awaelchli: "Two possible solutions that come to my mind to address this:"

I can suggest another solution: implement a custom save-config class that saves the config to MLflow as an artifact instead of saving it locally. When logging remotely, it makes sense to save the config in the same place.

mauvilsa avatar Jul 28 '23 14:07 mauvilsa

A realization of @mauvilsa's idea:

import torch
from lightning import LightningModule, Trainer
from lightning.pytorch.cli import LightningCLI, SaveConfigCallback


class MLFlowSaveConfigCallback(SaveConfigCallback):
    def __init__(self, parser, config, config_filename='config.yaml', overwrite=False, multifile=False):
        # save_to_log_dir=False skips writing the config file to the trainer's log dir
        super().__init__(parser, config, config_filename, overwrite, multifile, save_to_log_dir=False)

    def save_config(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        # Convert Namespace to dict
        config_dict = vars(self.config)

        # Log parameters to MLFlow
        pl_module.logger.log_hyperparams(config_dict)


def cli_compile_main():
    # PRDataModule is the poster's own datamodule and is not shown in this thread.
    cli = LightningCLI(datamodule_class=PRDataModule, run=False, save_config_callback=MLFlowSaveConfigCallback)
    compiled_model = torch.compile(cli.model)
    cli.trainer.fit(compiled_model, datamodule=cli.datamodule)
    cli.trainer.test(datamodule=cli.datamodule)

terbed avatar Mar 06 '24 10:03 terbed

A slight modification of @terbed's code, if you want to save the file as YAML:

from pathlib import Path
import tempfile

from lightning.pytorch.cli import SaveConfigCallback
from lightning import Trainer, LightningModule


class MLFlowSaveConfigCallback(SaveConfigCallback):
    def __init__(self, parser, config, config_filename='config.yaml', overwrite=False, multifile=False):
        super().__init__(parser, config, config_filename, overwrite, multifile, save_to_log_dir=False)

    def save_config(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        # only the main process writes and uploads the config
        if trainer.is_global_zero:
            with tempfile.TemporaryDirectory() as tmp_dir:
                # write the config to a temporary YAML file, then upload it as an artifact
                config_path = Path(tmp_dir) / 'config.yaml'
                self.parser.save(
                    self.config, config_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile
                )
                trainer.logger.experiment.log_artifact(local_path=config_path,
                                                       run_id=trainer.logger.run_id)

adrianomartinelli avatar Jul 20 '24 07:07 adrianomartinelli

A slight refactoring of @adrianomartinelli's code, with a full example:

from pathlib import Path
import tempfile

from lightning.pytorch.cli import LightningCLI, SaveConfigCallback
from lightning import Trainer, LightningModule


class MLFlowSaveConfigCallback(SaveConfigCallback):
    def __init__(self, parser, config, config_filename='config.yaml', overwrite=False, multifile=False):
        super().__init__(parser, config, config_filename,
                         overwrite, multifile, save_to_log_dir=False)

    def save_config(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        if trainer.is_global_zero:
            with tempfile.TemporaryDirectory() as tmp_dir:
                config_path = Path(tmp_dir) / 'config.yaml'
                self.parser.save(
                    self.config, config_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile
                )
                trainer.logger.experiment.log_artifact(local_path=config_path,
                                                       run_id=trainer.logger.run_id)


def main():
    LightningCLI(save_config_callback=MLFlowSaveConfigCallback)


if __name__ == "__main__":
    main()

AlessandroW avatar Sep 24 '24 16:09 AlessandroW

I guess saving the config to MLflow is the best approach. Unfortunately, the user of the CLI is then forced to use MLflow, and selecting a different logger from the command line doesn't work.

Perhaps SaveConfigCallback should skip saving the config when log_dir is None instead of failing an assertion. There is a save_to_log_dir argument, but it can only be set to False by subclassing.
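
As a rough illustration of the "skip instead of assert" behaviour, here is a dependency-free sketch. The function `save_config_if_possible` is hypothetical and stands in for what the callback's setup step might do; it is not existing Lightning code.

```python
import tempfile
from pathlib import Path
from typing import Optional


def save_config_if_possible(
    log_dir: Optional[str], config_text: str, filename: str = "config.yaml"
) -> Optional[Path]:
    """Write the config under log_dir, or skip quietly when there is none."""
    if log_dir is None:
        # e.g. a logger that tracks remotely: there is nowhere local to write.
        return None
    target = Path(log_dir) / filename
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(config_text)
    return target


# Usage: the remote case is skipped, the local case writes a file.
skipped = save_config_if_possible(None, "trainer:\n  max_epochs: 1\n")
with tempfile.TemporaryDirectory() as tmp_dir:
    written = save_config_if_possible(tmp_dir, "trainer:\n  max_epochs: 1\n")
    written_text = written.read_text()
```

The trade-off is that a skipped save is silent, so a real implementation would probably want to emit a warning in the None branch.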

senarvi avatar Jul 11 '25 12:07 senarvi