pytorch-lightning

MLFlowLogger does not save config.yaml for each run

Open • jeangud opened this issue 1 year ago • 0 comments

Bug description

The MLFlowLogger seems to save config.yaml in the top-level save_dir directory (e.g. ./mlruns), not even inside the experiment directory, instead of in the specific run directory as the other loggers do. See below for a minimal example. Running the same experiment twice results in an error because config.yaml already exists.

Here is an example folder structure showing config.yaml at the top level:

mlruns/
├── 557060468949431600 (experiment ID)
│   ├── 14625fca5e654f7faff19061b1ed44fa (run ID)
│   ├── 8b0a025336d6492391929adb37c18d2b (run ID)
│   └── meta.yaml
└── config.yaml
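The top-level placement appears to follow from how Lightning 2.x resolves `trainer.log_dir`, which `SaveConfigCallback` joins with the config filename: TensorBoardLogger and CSVLogger contribute their per-version `log_dir`, while any other logger contributes its `save_dir`, which carries no experiment or run ID. The sketch below models that branch with hypothetical stand-in classes (`FakeTensorBoardLogger`, `FakeMLFlowLogger`) and hypothetical helpers (`resolve_log_dir`, `default_config_path`); it is a simplification under that reading of the Trainer, not Lightning's code.

```python
import os
from dataclasses import dataclass
from typing import Optional, Union


# Hypothetical stand-ins for the real loggers; only the attributes read below are modelled.
@dataclass
class FakeTensorBoardLogger:
    save_dir: str
    log_dir: str  # per-version directory, e.g. <save_dir>/lightning_logs/version_0


@dataclass
class FakeMLFlowLogger:
    save_dir: Optional[str]  # local path for a "file:" tracking URI, otherwise None
    experiment_id: str
    run_id: str


def resolve_log_dir(logger: Union[FakeTensorBoardLogger, FakeMLFlowLogger, None],
                    default_root_dir: str = ".") -> Optional[str]:
    """Simplified model of the branch in Trainer.log_dir (Lightning 2.x)."""
    if logger is None:
        return default_root_dir
    if isinstance(logger, FakeTensorBoardLogger):
        # TensorBoardLogger and CSVLogger report a per-version log_dir
        return logger.log_dir
    # Any other logger, MLFlowLogger included, falls back to save_dir,
    # which does not contain the experiment or run ID
    return logger.save_dir


def default_config_path(logger, filename: str = "config.yaml") -> str:
    # SaveConfigCallback writes to os.path.join(trainer.log_dir, config_filename)
    return os.path.join(resolve_log_dir(logger), filename)


exp_id = "557060468949431600"
run_a = FakeMLFlowLogger("./mlruns", exp_id, "14625fca5e654f7faff19061b1ed44fa")
run_b = FakeMLFlowLogger("./mlruns", exp_id, "8b0a025336d6492391929adb37c18d2b")
print(default_config_path(run_a))  # ./mlruns/config.yaml on POSIX
print(default_config_path(run_b))  # same path, so the second run hits the overwrite check
```

Both runs resolve to the same file, which matches the tree above and the RuntimeError reported below.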

Expected behavior: just like with the default logger, config.yaml should be saved inside the directory of each run of the given experiment:

mlruns/
└── 519079607625374876 (experiment ID)
    ├── 71d8f4b93eac490c8046d07bf7b49d31 (run ID)
    │   ├── ...
    │   └── config.yaml
    ├── 81a4e345f552487ea0d591e6bc14c881 (run ID)
    │   ├── ...
    │   └── config.yaml
    └── meta.yaml
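The expected layout can be checked mechanically after training. The sketch below is a hypothetical helper (`runs_missing_config`), assuming a local MLflow file store in which every non-hidden top-level directory is an experiment and every subdirectory of it is a run; other top-level folders (for example a model registry folder, if present) are not filtered out. It rebuilds the layout observed above in a temporary directory and reports the runs that lack their own config.yaml.

```python
import tempfile
from pathlib import Path


def runs_missing_config(save_dir: Path, filename: str = "config.yaml") -> list[str]:
    """Hypothetical check for <save_dir>/<experiment_id>/<run_id>/<filename>.

    Hidden directories such as ".trash" are skipped; files like meta.yaml are ignored.
    """
    missing = []
    experiments = sorted(p for p in save_dir.iterdir() if p.is_dir() and not p.name.startswith("."))
    for exp_dir in experiments:
        for run_dir in sorted(p for p in exp_dir.iterdir() if p.is_dir()):
            if not (run_dir / filename).is_file():
                missing.append(f"{exp_dir.name}/{run_dir.name}")
    return missing


# Rebuild the observed layout: two runs, one shared top-level config.yaml
with tempfile.TemporaryDirectory() as tmp:
    mlruns = Path(tmp) / "mlruns"
    for run_id in ("14625fca5e654f7faff19061b1ed44fa", "8b0a025336d6492391929adb37c18d2b"):
        (mlruns / "557060468949431600" / run_id).mkdir(parents=True)
    (mlruns / "557060468949431600" / "meta.yaml").write_text("name: demo\n")
    (mlruns / "config.yaml").write_text("trainer: {}\n")
    missing = runs_missing_config(mlruns)

print(missing)  # both runs are reported as missing their own config
```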

Solution idea: two lines of interest seem to be:

Workaround 1: the error can be avoided with LightningCLI(save_config_kwargs={"overwrite": True}), as the error message suggests. However, this does not save the config per run: every run overwrites the same top-level config.yaml.
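A stdlib-only simulation makes the limitation of Workaround 1 concrete: because the target path depends only on `save_dir`, enabling overwrite lets the second run replace the first run's file. `save_top_level_config` is a hypothetical stand-in for the current behaviour, not Lightning's implementation, and the config contents are placeholders.

```python
import tempfile
from pathlib import Path


def save_top_level_config(save_dir: Path, text: str, overwrite: bool,
                          filename: str = "config.yaml") -> Path:
    """Hypothetical stand-in for the current behaviour: the path ignores the run ID."""
    path = save_dir / filename
    if path.exists() and not overwrite:
        raise RuntimeError(f"SaveConfigCallback expected {path} to NOT exist.")
    save_dir.mkdir(parents=True, exist_ok=True)
    path.write_text(text)
    return path


with tempfile.TemporaryDirectory() as tmp:
    mlruns = Path(tmp) / "mlruns"
    save_top_level_config(mlruns, "seed_everything: 1\n", overwrite=True)  # first run
    save_top_level_config(mlruns, "seed_everything: 2\n", overwrite=True)  # second run
    surviving = [p.read_text() for p in mlruns.rglob("config.yaml")]

print(surviving)  # ['seed_everything: 2\n']: the first run's config is gone
```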

Workaround 2: subclass cli.SaveConfigCallback, set save_to_log_dir=False to skip the default save into the top-level directory, and override save_config to write the config into the run directory built from the experiment ID and run ID:

from pathlib import Path

from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.pytorch.cli import LightningCLI, SaveConfigCallback
from lightning.pytorch.demos.boring_classes import DemoModel, BoringDataModule

class MLFlowSaveConfigCallback(SaveConfigCallback):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Skip the default save into trainer.log_dir, which is the top-level save_dir for MLFlowLogger
        self.save_to_log_dir = False

    def save_config(self, trainer, pl_module, stage):
        logger = trainer.logger
        # MLFlowLogger.save_dir is only set for a local "file:" tracking URI; it is None otherwise
        if logger.save_dir is None:
            raise RuntimeError("MLFlowSaveConfigCallback needs a local file-based MLflow tracking URI")

        # Target: <save_dir>/<experiment_id>/<run_id>/<config_filename>
        dir_run = Path(logger.save_dir) / logger.experiment_id / logger.run_id
        path_config = dir_run / self.config_filename

        # Make sure the run directory exists before writing
        fs = get_filesystem(dir_run)
        fs.makedirs(dir_run, exist_ok=True)

        self.parser.save(
            self.config, path_config, skip_none=False, overwrite=self.overwrite, multifile=self.multifile
        )

def cli_main():
    LightningCLI(DemoModel, BoringDataModule,
                 save_config_callback=MLFlowSaveConfigCallback)

if __name__ == "__main__":
    cli_main()
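Stripped of Lightning and MLflow, the core of Workaround 2 is a target path that includes both IDs, plus the existing guard against clobbering a file. The stdlib-only sketch below models that; `save_run_config` is hypothetical, and the real callback writes through `self.parser.save(..., overwrite=self.overwrite)` rather than `Path.write_text`.

```python
import tempfile
from pathlib import Path


def save_run_config(save_dir: Path, experiment_id: str, run_id: str, text: str,
                    overwrite: bool = False, filename: str = "config.yaml") -> Path:
    """Hypothetical stdlib-only stand-in for MLFlowSaveConfigCallback.save_config."""
    run_dir = save_dir / experiment_id / run_id
    # Mirrors fs.makedirs(dir_run, exist_ok=True)
    run_dir.mkdir(parents=True, exist_ok=True)
    path = run_dir / filename
    if path.exists() and not overwrite:
        raise RuntimeError(f"{path} already exists")
    path.write_text(text)
    return path


exp_id = "557060468949431600"
cfg = "trainer:\n  logger:\n    class_path: lightning.pytorch.loggers.MLFlowLogger\n"
with tempfile.TemporaryDirectory() as tmp:
    mlruns = Path(tmp) / "mlruns"
    for run_id in ("14625fca5e654f7faff19061b1ed44fa", "8b0a025336d6492391929adb37c18d2b"):
        save_run_config(mlruns, exp_id, run_id, cfg)
    saved = sorted(p.relative_to(mlruns).as_posix() for p in mlruns.rglob("config.yaml"))
    # Saving twice into the same run directory is still refused without overwrite
    try:
        save_run_config(mlruns, exp_id, "14625fca5e654f7faff19061b1ed44fa", cfg)
        refused = False
    except RuntimeError:
        refused = True

print(saved)    # one config.yaml per run directory
print(refused)  # True
```

Since each new run gets a fresh run ID, repeated launches no longer collide, while the overwrite guard still protects an individual run's file.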

What version are you seeing the problem on?

v2.4

How to reproduce the bug

With the files below, run python main.py fit --config config.yaml twice. The first run will succeed, and the second one will fail with the error message below.

main.py

from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import DemoModel, BoringDataModule

def cli_main():
    LightningCLI(DemoModel, BoringDataModule)

if __name__ == "__main__":
    cli_main()

config.yaml

# lightning.pytorch==2.4.0
trainer:
  logger:
    class_path: lightning.pytorch.loggers.MLFlowLogger

Error messages and logs

RuntimeError: SaveConfigCallback expected ./mlruns/config.yaml to NOT exist. Aborting to avoid overwriting results of a previous run. You can delete the previous config file, set `LightningCLI(save_config_callback=None)` to disable config saving, or set `LightningCLI(save_config_kwargs={"overwrite": True})` to overwrite the config file.

Environment

Current environment
  • CUDA:
    • GPU:
      • NVIDIA RTX 2000 Ada Generation Laptop GPU
    • available: True
    • version: 12.1
  • Lightning:
    • efficientnet-pytorch: 0.7.1
    • lightning: 2.4.0
    • lightning-utilities: 0.11.3.post0
    • pytorch-lightning: 2.3.1
    • segmentation-models-pytorch: 0.3.3
    • torch: 2.3.1
    • torchgeo: 0.5.2
    • torchmetrics: 1.4.0.post0
    • torchvision: 0.18.1
  • Packages:
    • aenum: 3.1.15
    • affine: 2.4.0
    • aiohttp: 3.9.5
    • aiosignal: 1.3.1
    • albucore: 0.0.12
    • albumentations: 1.4.10
    • alembic: 1.13.2
    • aniso8601: 9.0.1
    • annotated-types: 0.7.0
    • antlr4-python3-runtime: 4.9.3
    • asttokens: 2.4.1
    • async-timeout: 4.0.3
    • attrs: 23.2.0
    • basemap: 1.4.1
    • basemap-data: 1.3.2
    • bitsandbytes: 0.43.1
    • blinker: 1.8.2
    • cachetools: 5.3.3
    • certifi: 2024.6.2
    • charset-normalizer: 3.3.2
    • click: 8.1.7
    • click-plugins: 1.1.1
    • cligj: 0.7.2
    • cloudpickle: 3.0.0
    • comm: 0.2.2
    • contourpy: 1.2.1
    • cycler: 0.12.1
    • databricks-sdk: 0.29.0
    • debugpy: 1.8.2
    • decorator: 5.1.1
    • deprecated: 1.2.14
    • docker: 7.1.0
    • docstring-parser: 0.16
    • efficientnet-pytorch: 0.7.1
    • einops: 0.8.0
    • entrypoints: 0.4
    • exceptiongroup: 1.2.1
    • executing: 2.0.1
    • filelock: 3.15.4
    • fiona: 1.9.6
    • flask: 3.0.3
    • fonttools: 4.53.0
    • frozenlist: 1.4.1
    • fsspec: 2024.6.1
    • gitdb: 4.0.11
    • gitpython: 3.1.43
    • google-auth: 2.33.0
    • graphene: 3.3
    • graphql-core: 3.2.3
    • graphql-relay: 3.2.0
    • greenlet: 3.0.3
    • gunicorn: 22.0.0
    • huggingface-hub: 0.23.4
    • hydra-core: 1.3.2
    • idna: 3.7
    • imageio: 2.34.2
    • importlib-metadata: 7.2.1
    • importlib-resources: 6.4.0
    • ipykernel: 6.29.5
    • ipython: 8.26.0
    • itsdangerous: 2.2.0
    • jedi: 0.19.1
    • jinja2: 3.1.4
    • joblib: 1.4.2
    • jsonargparse: 4.31.0
    • jupyter-client: 8.6.2
    • jupyter-core: 5.7.2
    • kiwisolver: 1.4.5
    • kornia: 0.7.3
    • kornia-rs: 0.1.4
    • lazy-loader: 0.4
    • lightly: 1.5.8
    • lightly-utils: 0.0.2
    • lightning: 2.4.0
    • lightning-utilities: 0.11.3.post0
    • mako: 1.3.5
    • markdown: 3.6
    • markdown-it-py: 3.0.0
    • markupsafe: 2.1.5
    • matplotlib: 3.8.4
    • matplotlib-inline: 0.1.7
    • mdurl: 0.1.2
    • mlflow: 2.15.1
    • mlflow-skinny: 2.15.1
    • mpmath: 1.3.0
    • multidict: 6.0.5
    • munch: 4.0.0
    • nest-asyncio: 1.6.0
    • networkx: 3.3
    • numpy: 1.26.4
    • nvidia-cublas-cu12: 12.1.3.1
    • nvidia-cuda-cupti-cu12: 12.1.105
    • nvidia-cuda-nvrtc-cu12: 12.1.105
    • nvidia-cuda-runtime-cu12: 12.1.105
    • nvidia-cudnn-cu12: 8.9.2.26
    • nvidia-cufft-cu12: 11.0.2.54
    • nvidia-curand-cu12: 10.3.2.106
    • nvidia-cusolver-cu12: 11.4.5.107
    • nvidia-cusparse-cu12: 12.1.0.106
    • nvidia-ml-py: 12.535.161
    • nvidia-nccl-cu12: 2.20.5
    • nvidia-nvjitlink-cu12: 12.5.82
    • nvidia-nvtx-cu12: 12.1.105
    • nvitop: 1.3.2
    • omegaconf: 2.3.0
    • opencv-python-headless: 4.10.0.84
    • opentelemetry-api: 1.26.0
    • opentelemetry-sdk: 1.26.0
    • opentelemetry-semantic-conventions: 0.47b0
    • packaging: 23.2
    • pandas: 2.2.2
    • parso: 0.8.4
    • pexpect: 4.9.0
    • pillow: 10.4.0
    • pip: 24.1.1
    • platformdirs: 4.2.2
    • pretrainedmodels: 0.7.4
    • prompt-toolkit: 3.0.47
    • protobuf: 5.27.2
    • psutil: 6.0.0
    • ptyprocess: 0.7.0
    • pure-eval: 0.2.2
    • pyarrow: 15.0.2
    • pyasn1: 0.6.0
    • pyasn1-modules: 0.4.0
    • pydantic: 2.8.0
    • pydantic-core: 2.20.0
    • pygments: 2.18.0
    • pyparsing: 3.1.2
    • pyproj: 3.6.1
    • pyshp: 2.3.1
    • python-dateutil: 2.9.0.post0
    • pytorch-lightning: 2.3.1
    • pytz: 2024.1
    • pyyaml: 6.0.1
    • pyzmq: 26.0.3
    • querystring-parser: 1.2.4
    • rasterio: 1.3.10
    • requests: 2.32.3
    • rich: 13.7.1
    • rsa: 4.9
    • rtree: 1.2.0
    • safetensors: 0.4.3
    • scikit-image: 0.24.0
    • scikit-learn: 1.5.0
    • scipy: 1.14.0
    • segmentation-models-pytorch: 0.3.3
    • setuptools: 65.5.0
    • shapely: 2.0.4
    • six: 1.16.0
    • smmap: 5.0.1
    • snuggs: 1.4.7
    • sqlalchemy: 2.0.32
    • sqlparse: 0.5.1
    • stack-data: 0.6.3
    • sympy: 1.12.1
    • tensorboardx: 2.6.2.2
    • termcolor: 2.4.0
    • threadpoolctl: 3.5.0
    • tifffile: 2024.6.18
    • timm: 0.9.2
    • tomli: 2.0.1
    • torch: 2.3.1
    • torchgeo: 0.5.2
    • torchmetrics: 1.4.0.post0
    • torchvision: 0.18.1
    • tornado: 6.4.1
    • tqdm: 4.66.4
    • traitlets: 5.14.3
    • triton: 2.3.1
    • typeshed-client: 2.5.1
    • typing-extensions: 4.12.2
    • tzdata: 2024.1
    • urllib3: 2.2.2
    • wcwidth: 0.2.13
    • werkzeug: 3.0.3
    • wrapt: 1.16.0
    • yarl: 1.9.4
    • zipp: 3.19.2
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.10.14
    • release: 6.5.0-1025-oem
    • version: #26-Ubuntu SMP PREEMPT_DYNAMIC Tue Jun 18 12:35:22 UTC 2024

More info

No response

jeangud, Aug 10 '24 00:08