MLFlowLogger does not save config.yaml for each run
Bug description
The MLFlowLogger seems to save config.yaml in the top-level save_dir directory (e.g. ./mlruns), not even inside the experiment directory, instead of in the run-specific directory as the other loggers do. See below for a minimal example. When the same experiment is run twice, this results in an error because config.yaml already exists.
Here is an example folder structure where you can see config.yaml sitting at the top level.
mlruns/
├── 557060468949431600 (experiment ID)
│ ├── 14625fca5e654f7faff19061b1ed44fa (run ID)
│ ├── 8b0a025336d6492391929adb37c18d2b (run ID)
│ └── meta.yaml
└── config.yaml
Expected behavior: just like with the default logger, we expect config.yaml to be saved inside the directory of each run of the given experiment.
mlruns/
└── 519079607625374876 (experiment ID)
├── 71d8f4b93eac490c8046d07bf7b49d31 (run ID)
│ ├── ...
│ └── config.yaml
├── 81a4e345f552487ea0d591e6bc14c881 (run ID)
│ ├── ...
│ └── config.yaml
└── meta.yaml
Solution idea: two lines of interest seem to be:
Workaround 1: we can simply avoid the error with LightningCLI(save_config_kwargs={"overwrite": True}), as suggested in the error message. However, this does not save the config per run.
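For reference, here is a minimal sketch of Workaround 1, reusing the DemoModel/BoringDataModule setup from the reproduction below. The error goes away, but the single config.yaml at the top of ./mlruns is simply overwritten on every run.
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import DemoModel, BoringDataModule

def cli_main():
    # Allow the existing ./mlruns/config.yaml to be overwritten instead of raising.
    LightningCLI(DemoModel, BoringDataModule, save_config_kwargs={"overwrite": True})

if __name__ == "__main__":
    cli_main()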
Workaround 2: we can subclass SaveConfigCallback, set save_to_log_dir=False, and override save_config with logic that saves into the correct folder using the logger's experiment ID and run ID:
from pathlib import Path

from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.pytorch.cli import LightningCLI, SaveConfigCallback
from lightning.pytorch.demos.boring_classes import DemoModel, BoringDataModule


class MLFlowSaveConfigCallback(SaveConfigCallback):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Disable the default behavior of writing to trainer.log_dir (./mlruns).
        self.save_to_log_dir = False

    def save_config(self, trainer, pl_module, stage):
        # Build the run directory from the logger's experiment ID and run ID.
        dir_runs = Path(trainer.logger.save_dir)
        dir_run = dir_runs / trainer.logger.experiment_id / trainer.logger.run_id
        path_config = dir_run / self.config_filename
        fs = get_filesystem(dir_run)
        fs.makedirs(dir_run, exist_ok=True)
        self.parser.save(
            self.config, path_config, skip_none=False, overwrite=self.overwrite, multifile=self.multifile
        )


def cli_main():
    LightningCLI(DemoModel, BoringDataModule, save_config_callback=MLFlowSaveConfigCallback)


if __name__ == "__main__":
    cli_main()
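With this callback, each run should get its own config.yaml under mlruns/<experiment_id>/<run_id>/, matching the expected layout above, and running the same experiment twice no longer raises the error.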
What version are you seeing the problem on?
v2.4
How to reproduce the bug
With the files below, run python main.py fit --config config.yaml twice. The first run will succeed, and the second one will fail with the error message below.
main.py
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import DemoModel, BoringDataModule
def cli_main():
    LightningCLI(DemoModel, BoringDataModule)

if __name__ == "__main__":
    cli_main()
config.yaml
# lightning.pytorch==2.4.0
trainer:
  logger:
    class_path: lightning.pytorch.loggers.MLFlowLogger
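No init_args are needed here: as far as I can tell, MLFlowLogger defaults to save_dir="./mlruns", which is why config.yaml ends up directly under mlruns/.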
Error messages and logs
RuntimeError: SaveConfigCallback expected ./mlruns/config.yaml to NOT exist. Aborting to avoid overwriting results of a previous run. You can delete the previous config file, set `LightningCLI(save_config_callback=None)` to disable config saving, or set `LightningCLI(save_config_kwargs={"overwrite": True})` to overwrite the config file.
Environment
Current environment
- CUDA:
  - GPU:
    - NVIDIA RTX 2000 Ada Generation Laptop GPU
  - available: True
  - version: 12.1
- Lightning:
- efficientnet-pytorch: 0.7.1
- lightning: 2.4.0
- lightning-utilities: 0.11.3.post0
- pytorch-lightning: 2.3.1
- segmentation-models-pytorch: 0.3.3
- torch: 2.3.1
- torchgeo: 0.5.2
- torchmetrics: 1.4.0.post0
- torchvision: 0.18.1
- Packages:
- aenum: 3.1.15
- affine: 2.4.0
- aiohttp: 3.9.5
- aiosignal: 1.3.1
- albucore: 0.0.12
- albumentations: 1.4.10
- alembic: 1.13.2
- aniso8601: 9.0.1
- annotated-types: 0.7.0
- antlr4-python3-runtime: 4.9.3
- asttokens: 2.4.1
- async-timeout: 4.0.3
- attrs: 23.2.0
- basemap: 1.4.1
- basemap-data: 1.3.2
- bitsandbytes: 0.43.1
- blinker: 1.8.2
- cachetools: 5.3.3
- certifi: 2024.6.2
- charset-normalizer: 3.3.2
- click: 8.1.7
- click-plugins: 1.1.1
- cligj: 0.7.2
- cloudpickle: 3.0.0
- comm: 0.2.2
- contourpy: 1.2.1
- cycler: 0.12.1
- databricks-sdk: 0.29.0
- debugpy: 1.8.2
- decorator: 5.1.1
- deprecated: 1.2.14
- docker: 7.1.0
- docstring-parser: 0.16
- efficientnet-pytorch: 0.7.1
- einops: 0.8.0
- entrypoints: 0.4
- exceptiongroup: 1.2.1
- executing: 2.0.1
- filelock: 3.15.4
- fiona: 1.9.6
- flask: 3.0.3
- fonttools: 4.53.0
- frozenlist: 1.4.1
- fsspec: 2024.6.1
- gitdb: 4.0.11
- gitpython: 3.1.43
- google-auth: 2.33.0
- graphene: 3.3
- graphql-core: 3.2.3
- graphql-relay: 3.2.0
- greenlet: 3.0.3
- gunicorn: 22.0.0
- huggingface-hub: 0.23.4
- hydra-core: 1.3.2
- idna: 3.7
- imageio: 2.34.2
- importlib-metadata: 7.2.1
- importlib-resources: 6.4.0
- ipykernel: 6.29.5
- ipython: 8.26.0
- itsdangerous: 2.2.0
- jedi: 0.19.1
- jinja2: 3.1.4
- joblib: 1.4.2
- jsonargparse: 4.31.0
- jupyter-client: 8.6.2
- jupyter-core: 5.7.2
- kiwisolver: 1.4.5
- kornia: 0.7.3
- kornia-rs: 0.1.4
- lazy-loader: 0.4
- lightly: 1.5.8
- lightly-utils: 0.0.2
- lightning: 2.4.0
- lightning-utilities: 0.11.3.post0
- mako: 1.3.5
- markdown: 3.6
- markdown-it-py: 3.0.0
- markupsafe: 2.1.5
- matplotlib: 3.8.4
- matplotlib-inline: 0.1.7
- mdurl: 0.1.2
- mlflow: 2.15.1
- mlflow-skinny: 2.15.1
- mpmath: 1.3.0
- multidict: 6.0.5
- munch: 4.0.0
- nest-asyncio: 1.6.0
- networkx: 3.3
- numpy: 1.26.4
- nvidia-cublas-cu12: 12.1.3.1
- nvidia-cuda-cupti-cu12: 12.1.105
- nvidia-cuda-nvrtc-cu12: 12.1.105
- nvidia-cuda-runtime-cu12: 12.1.105
- nvidia-cudnn-cu12: 8.9.2.26
- nvidia-cufft-cu12: 11.0.2.54
- nvidia-curand-cu12: 10.3.2.106
- nvidia-cusolver-cu12: 11.4.5.107
- nvidia-cusparse-cu12: 12.1.0.106
- nvidia-ml-py: 12.535.161
- nvidia-nccl-cu12: 2.20.5
- nvidia-nvjitlink-cu12: 12.5.82
- nvidia-nvtx-cu12: 12.1.105
- nvitop: 1.3.2
- omegaconf: 2.3.0
- opencv-python-headless: 4.10.0.84
- opentelemetry-api: 1.26.0
- opentelemetry-sdk: 1.26.0
- opentelemetry-semantic-conventions: 0.47b0
- packaging: 23.2
- pandas: 2.2.2
- parso: 0.8.4
- pexpect: 4.9.0
- pillow: 10.4.0
- pip: 24.1.1
- platformdirs: 4.2.2
- pretrainedmodels: 0.7.4
- prompt-toolkit: 3.0.47
- protobuf: 5.27.2
- psutil: 6.0.0
- ptyprocess: 0.7.0
- pure-eval: 0.2.2
- pyarrow: 15.0.2
- pyasn1: 0.6.0
- pyasn1-modules: 0.4.0
- pydantic: 2.8.0
- pydantic-core: 2.20.0
- pygments: 2.18.0
- pyparsing: 3.1.2
- pyproj: 3.6.1
- pyshp: 2.3.1
- python-dateutil: 2.9.0.post0
- pytorch-lightning: 2.3.1
- pytz: 2024.1
- pyyaml: 6.0.1
- pyzmq: 26.0.3
- querystring-parser: 1.2.4
- rasterio: 1.3.10
- requests: 2.32.3
- rich: 13.7.1
- rsa: 4.9
- rtree: 1.2.0
- safetensors: 0.4.3
- scikit-image: 0.24.0
- scikit-learn: 1.5.0
- scipy: 1.14.0
- segmentation-models-pytorch: 0.3.3
- setuptools: 65.5.0
- shapely: 2.0.4
- six: 1.16.0
- smmap: 5.0.1
- snuggs: 1.4.7
- sqlalchemy: 2.0.32
- sqlparse: 0.5.1
- stack-data: 0.6.3
- sympy: 1.12.1
- tensorboardx: 2.6.2.2
- termcolor: 2.4.0
- threadpoolctl: 3.5.0
- tifffile: 2024.6.18
- timm: 0.9.2
- tomli: 2.0.1
- torch: 2.3.1
- torchgeo: 0.5.2
- torchmetrics: 1.4.0.post0
- torchvision: 0.18.1
- tornado: 6.4.1
- tqdm: 4.66.4
- traitlets: 5.14.3
- triton: 2.3.1
- typeshed-client: 2.5.1
- typing-extensions: 4.12.2
- tzdata: 2024.1
- urllib3: 2.2.2
- wcwidth: 0.2.13
- werkzeug: 3.0.3
- wrapt: 1.16.0
- yarl: 1.9.4
- zipp: 3.19.2
- System:
  - OS: Linux
  - architecture:
    - 64bit
    - ELF
  - processor: x86_64
  - python: 3.10.14
  - release: 6.5.0-1025-oem
  - version: #26-Ubuntu SMP PREEMPT_DYNAMIC Tue Jun 18 12:35:22 UTC 2024
More info
No response