
Enable control over the MLFlowLogger run_name string so it matches a pre-existing run_name tag in MLflow and model training can be resumed

Open • wolliq opened this issue 1 year ago • 1 comment

🚀 Feature Request

We would like to resume model training with the MLFlowLogger, passing the run_name from the YAML config.

Motivation

Currently, MLFlowLogger receives the run_name string from the YAML config, but we have no control over it because a random string is automatically appended to it at runtime, e.g. my-test => my-test-sgftKr.

In the MLFlowLogger docs:

        run_name: (str, optional): MLflow run name. If not set it will be the same as the Trainer run name

However, it is always overridden by the randomized value after YAML parsing.

In MLFlowLogger, the filter string used to find the run to resume is built from this randomly generated run_name, so it can never match a pre-existing run:

    def _start_mlflow_run(self, state):
        import mlflow

        env_run_id = os.getenv(
            mlflow.environment_variables.MLFLOW_RUN_ID.name,  # pyright: ignore[reportGeneralTypeIssues]
            None,
        )
        if env_run_id is not None:
            self._run_id = env_run_id
        elif self.resume:
            # Search for an existing run tagged with this Composer run if `self.resume=True`.
            assert self._experiment_id is not None
            run_name = self.tags['run_name']
            existing_runs = mlflow.search_runs(
                experiment_ids=[self._experiment_id],
                filter_string=f'tags.run_name = "{run_name}"',    # <<< HERE
                output_format='list',
            )
...

As explained above, the {run_name} in this filter will always have a different random string appended to it for each new run.
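To make the mismatch concrete, here is a minimal pure-Python sketch that does not call MLflow. Both helpers are hypothetical: `add_random_suffix` only imitates the observed `my-test => my-test-sgftKr` behavior (it is not the actual Composer or platform code), and `find_run_by_tag` imitates the exact-match semantics of the `tags.run_name = "..."` filter over an in-memory list of runs.

```python
import random
import string
from typing import Dict, List, Optional


def add_random_suffix(run_name: str, rng: random.Random, length: int = 6) -> str:
    # Hypothetical stand-in for the randomization observed at runtime
    # (my-test => my-test-sgftKr); not the actual Composer or platform code.
    suffix = ''.join(rng.choice(string.ascii_letters) for _ in range(length))
    return f'{run_name}-{suffix}'


def find_run_by_tag(runs: List[Dict[str, str]], run_name: str) -> Optional[str]:
    # Imitates filter_string=f'tags.run_name = "{run_name}"': exact string equality.
    for run in runs:
        if run['tags.run_name'] == run_name:
            return run['run_id']
    return None


rng = random.Random(0)

# First job: the run is logged under a randomized name.
original_name = add_random_suffix('my-test', rng)
stored_runs = [{'run_id': 'run-1', 'tags.run_name': original_name}]

# Resumed job: the same YAML run_name gets a fresh suffix.
resumed_name = add_random_suffix('my-test', rng)

print(original_name, resumed_name)
print(find_run_by_tag(stored_runs, original_name))  # run-1
print(find_run_by_tag(stored_runs, resumed_name))   # no match unless the suffixes collide
print(find_run_by_tag(stored_runs, 'my-test'))      # None: the bare YAML name does not match either
```

Because each job draws a fresh suffix, the resumed job's lookup finds nothing, so a new run ID is created and the metrics end up split across two runs.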

[Optional] Implementation

One possible solution would be to disable the random string generation by defining an additional environment variable during YAML parsing, such as:

mlflow_tag_run_name=True

so that, when training is resumed, the run name is used unchanged and matches the run_name tag in MLflow,

or directly

mlflow_tag_run_name="my-run-asdasd"

so that the run_name string is passed as-is to MLFlowLogger to handle the resume.
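A minimal sketch of how the proposed option could be resolved, folding the two alternatives above into a single variable purely for illustration. Only the variable name `mlflow_tag_run_name` comes from this proposal; `resolve_run_name`, the six-letter suffix, and the `True`/`False` parsing rules are assumptions, not existing Composer behavior.

```python
import os
import random
import string
from typing import Mapping, Optional


def resolve_run_name(
    yaml_run_name: str,
    env: Mapping[str, str],
    rng: Optional[random.Random] = None,
) -> str:
    """Hypothetical resolution of the proposed ``mlflow_tag_run_name`` option."""
    value = env.get('mlflow_tag_run_name')
    if value is None or value.strip().lower() == 'false':
        # Current behavior described in this issue: append a random suffix.
        rng = rng or random.Random()
        suffix = ''.join(rng.choice(string.ascii_letters) for _ in range(6))
        return f'{yaml_run_name}-{suffix}'
    if value.strip().lower() == 'true':
        # Alternative 1: keep the YAML run_name unchanged so it can match the MLflow tag.
        return yaml_run_name
    # Alternative 2: use the explicit value verbatim as the run name.
    return value


if __name__ == '__main__':
    # In practice the mapping would be the process environment.
    print(resolve_run_name('my-test', os.environ))
```

Treating `True`/`False` as flags means an explicit run name literally spelled "true" cannot be passed; keeping the two alternatives as separate variables would avoid that ambiguity.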

Additional context

This is for a use case where we run training on the MosaicML platform and log to MLflow on the Databricks platform. Checkpointing works fine, but the loss logging is wrong and split across runs: because the randomized run_name does not match, MLflow creates a new run ID for the resumed training.

wolliq, May 07 '24 14:05

@wolliq would you mind sharing your YAML? Are you saying you pass the run name directly to the MLflow logger but it is always overridden?

mvpatel2000, May 07 '24 21:05