MLFlowLogger saves copies of ModelCheckpoints
Bug description
I am trying to achieve the following behavior:

- `ModelCheckpoint` callbacks save model checkpoint files to a certain location
- `MLFlowLogger` (with `log_model=True`) only references the saved checkpoints
The problem is that no matter what I do, MLFlowLogger tries to save copies of the checkpoints in a new location.
What version are you seeing the problem on?
v2.0
How to reproduce the bug
```python
import os
from datetime import datetime

from lightning.pytorch import Trainer
from lightning.pytorch import loggers as pl_loggers
from lightning.pytorch.callbacks import (
    RichProgressBar, EarlyStopping, ModelCheckpoint
)

# NOTE: EXPERIMENT_LOGS_DIR, EXPERIMENT_NAME, MNISTClassifier and
# MNISTDataModule are defined elsewhere (not shown).


def get_lightning_mlflow_logger(experiment_name: str, _artifact_location: str) -> pl_loggers.MLFlowLogger:
    return pl_loggers.MLFlowLogger(
        experiment_name=experiment_name,
        run_name=datetime.now().isoformat(),
        tracking_uri=os.path.join(EXPERIMENT_LOGS_DIR, './mlruns'),
        log_model=True,
        artifact_location=_artifact_location
    )


def _configure_callbacks():
    early_stopping = EarlyStopping(
        monitor="val_loss",
        mode='min',
        patience=10,
        stopping_threshold=0.05,
        divergence_threshold=5.0
    )
    checkpoint_callback = ModelCheckpoint(
        save_top_k=2,
        save_last=True,
        monitor="val_loss",
        mode="min",
        verbose=True
    )
    checkpoint_callback.CHECKPOINT_JOIN_CHAR = '_'
    return (
        [
            early_stopping,
            checkpoint_callback,
            RichProgressBar()
        ],
        checkpoint_callback.dirpath
    )


def cli_main():
    model = MNISTClassifier()
    data_module = MNISTDataModule()
    callbacks, checkpoints_dirpath = _configure_callbacks()
    print(f'ModelCheckpoint Callback dirpath: {checkpoints_dirpath}')
    mlflow_logger = get_lightning_mlflow_logger(EXPERIMENT_NAME, checkpoints_dirpath)
    trainer = Trainer(
        callbacks=callbacks,
        logger=mlflow_logger,
        max_epochs=5
    )
    print(f'ModelCheckpoint Callback dirpath: {checkpoints_dirpath}')
    trainer.fit(model, datamodule=data_module)
    trainer.test(model=model, datamodule=data_module)


if __name__ == "__main__":
    cli_main()
```
Error messages and logs
The above code saves model checkpoints in the `tracking_uri` location of the MLFlowLogger, even though checkpoints already exist in the directory from which I ran the script (which is where the ModelCheckpoint callback saves them by default).
So, I did some digging around.
Based on my experiments, I observed that no matter what I do, ModelCheckpoint saves checkpoints to the directory I run the script from, unless I give it a `dirpath`, in which case it saves the checkpoints there.
Based on the source code, this makes sense (a quick sketch of the resulting behavior follows the list):

- `ModelCheckpoint` sets the value of its internal variable `self.dirpath` through its `__init_ckpt_dir()` function: https://github.com/Lightning-AI/lightning/blob/6eae2310d6dae086596e5bdddd08e8cd3884336e/src/lightning/pytorch/callbacks/model_checkpoint.py#L442
- And `__init_ckpt_dir()` gets called in the `__init__()` function: https://github.com/Lightning-AI/lightning/blob/6eae2310d6dae086596e5bdddd08e8cd3884336e/src/lightning/pytorch/callbacks/model_checkpoint.py#L246
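As a quick illustration of the bullets above (my own minimal sketch, assuming a recent `lightning.pytorch` release, not code from the issue):

```python
from lightning.pytorch.callbacks import ModelCheckpoint

# When a dirpath is passed, it is stored on the callback during __init__,
# i.e. before any Trainer/logger resolution happens (local paths may be
# expanded to absolute paths).
ckpt = ModelCheckpoint(monitor="val_loss", dirpath="my_checkpoints/")
print(ckpt.dirpath)  # already populated right after construction
```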
In any case, it's impossible to get ModelCheckpoint to use a logger's `save_dir` as the save location. I thought this should be possible when I saw the docstring for `__resolve_ckpt_dir()`: https://github.com/Lightning-AI/lightning/blob/6eae2310d6dae086596e5bdddd08e8cd3884336e/src/lightning/pytorch/callbacks/model_checkpoint.py#L580
But that resolution has no effect, because the ModelCheckpoint's `self.dirpath` value is already set during `__init__()`.
I think it might be worthwhile to remove/modify the `__resolve_ckpt_dir()` function in ModelCheckpoint, since it has no effect. We should also remove any mention of coupling between a logger's `save_dir` and ModelCheckpoint save locations from the documentation.
I am encountering something similar to this issue as well.
Using MLFlowLogger with `log_model=True` and a `tracking_uri` set to some HTTP URL results in the checkpoints being saved both as artifacts on the MLflow tracking server and in subdirectories of whatever directory is configured as the trainer's root. This issue doesn't occur when using `save_dir`.
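For concreteness, the two configurations being contrasted would look roughly like this (a sketch; the experiment name and tracking URL are hypothetical):

```python
from lightning.pytorch.loggers import MLFlowLogger

# Remote tracking server: checkpoints end up both as MLflow artifacts and
# under the Trainer's root directory (the duplication described above).
remote_logger = MLFlowLogger(
    experiment_name="my-experiment",                # hypothetical
    tracking_uri="http://mlflow.example.com:5000",  # hypothetical
    log_model=True,
)

# Local file store via save_dir: the duplication reportedly does not occur.
local_logger = MLFlowLogger(
    experiment_name="my-experiment",
    save_dir="./mlruns",
    log_model=True,
)
```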
Also having this issue when trying to use the MLFlowLogger and a ModelCheckpoint callback together. Something in the MLFlowLogger overrides the `dirpath` specified in the provided checkpoint callback. Specifying a `save_dir` when initializing the logger didn't help either. This is pretty unintuitive behavior; imo it's a bug.
I didn't trace through to see exactly where this happens, but I was able to work around it by modifying the trainer's checkpoint callback directly:
```python
model = MyLightningModule()  # your LightningModule subclass

trainer = Trainer(
    callbacks=[ModelCheckpoint(...)],
    logger=MLFlowLogger(...),
)

# Overwrite whatever dirpath was resolved for the checkpoint callback
trainer.checkpoint_callback.dirpath = my_desired_dirpath

trainer.fit(model=model)
```
This might be a little bit off-topic, but wouldn't it be better if the MLFlowLogger completely took over the model checkpointing behavior when it is used? ModelCheckpoint is currently designed to save to a file, whereas the MLflow API is agnostic of files. Another way I could imagine is to subclass ModelCheckpoint and change its behavior to log to MLflow.
Something along the lines of this:
```python
class MLflowModelCheckpoint(ModelCheckpoint):
    def __init__(...):
        super().__init__(...)
        if not mlflow.active_run():
            raise Exception('Not an MLflow run')
        # ...

    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        super().on_train_start(trainer, pl_module)
        mlflow.register_model(...)

    def _save_checkpoint(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        checkpoint = trainer.get_checkpoint(...)  # This method currently doesn't exist
        mlflow.log_model(checkpoint, ...)
```
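A possibly more directly workable variant (my own sketch, not the proposal above; `_save_checkpoint` is a private Lightning hook whose signature may differ between versions) would keep the normal file-based saving and additionally log each written checkpoint to the active MLflow run:

```python
import mlflow
from lightning.pytorch.callbacks import ModelCheckpoint


class MLflowArtifactCheckpoint(ModelCheckpoint):
    """Hypothetical subclass: save to disk as usual, then upload the file
    to the active MLflow run as an artifact."""

    def _save_checkpoint(self, trainer, filepath):
        # NOTE: this overrides a private hook; treat it as a sketch only.
        super()._save_checkpoint(trainer, filepath)
        if trainer.is_global_zero and mlflow.active_run() is not None:
            mlflow.log_artifact(filepath, artifact_path="checkpoints")
```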
I think the main conflict in using MLflow and Lightning together is that the Trainer is currently not designed to delegate checkpoint saving.
I also want to add that MLFlowLogger currently only copies checkpoints from a file: https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/loggers/mlflow.py#L335
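Roughly speaking, the copying behavior boils down to the fact that the public MLflow client API only takes local files and always uploads them (an illustration, not the logger's actual code; the URI, run id, and checkpoint path below are placeholders):

```python
from mlflow.tracking import MlflowClient

client = MlflowClient(tracking_uri="http://mlflow.example.com:5000")  # placeholder URI
run_id = "<some-existing-run-id>"  # placeholder

# log_artifact has no "reference only" mode: the file is copied/uploaded
# into the run's artifact location.
client.log_artifact(run_id, "checkpoints/epoch=4-step=500.ckpt", artifact_path="model/checkpoints")
```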