[FR] Why is there no 'log_model()' function in mlflow.client?
Willingness to contribute
Yes. I would be willing to contribute this feature with guidance from the MLflow community.
Proposal Summary
I want to run MLflow experiments in parallel, but the fluent API is not thread-safe: when using `mlflow.start_run()`, the active run ID is stored as a global variable, so parallel runs clobber each other and crash.
I found the `log_artifact()` method on `mlflow.client.MlflowClient`. However, this method only uploads the file to the artifact store; it does not log the model and its metadata to the MLflow tracking server.
I would like a method on `MlflowClient`, analogous to `mlflow.sklearn.log_model()`, that both saves the model to the artifact store/registry and logs it to the MLflow tracking server.
Is this possible?
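The race described above can be sketched without MLflow at all (all names below are hypothetical, not MLflow internals): a fluent-style API that keeps the "active run" in a module-level global lets parallel workers overwrite each other, while a client-style API that takes the run ID per call has no shared state to corrupt.

```python
import threading

# Fluent-style: one module-level "active run" shared by every thread.
_active_run = None

def fluent_start_run(run_id):
    global _active_run
    _active_run = run_id          # last writer wins: races under threads

def fluent_log(results):
    results.append(_active_run)   # may record another thread's run id

# Client-style: run_id is passed explicitly, so there is no shared state.
def client_log(run_id, results):
    results.append(run_id)

fluent_results, client_results = [], []

def worker(run_id):
    fluent_start_run(run_id)
    fluent_log(fluent_results)
    client_log(run_id, client_results)

threads = [threading.Thread(target=worker, args=(f"run-{i}",)) for i in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# The client-style log always holds 8 distinct run ids; the fluent-style
# log may contain duplicates, because threads overwrite the global.
print(len(set(client_results)))  # 8
```

This is why a per-call `run_id` on `MlflowClient` sidesteps the problem that `mlflow.start_run()` has under parallelism.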
Motivation
What is the use case for this feature?
When using `mlflow.client.MlflowClient`, there is no way to log a model to the MLflow tracking server. The `log_artifact()` method on `MlflowClient` only uploads the artifact; it does not log model metadata.
Why is this use case valuable to support for MLflow users in general?
This feature is necessary for users who want to use MLflow with parallel runs.
Why is this use case valuable to support for your project(s) or organization?
Why is it currently difficult to achieve this use case?
Currently I am using `mlflow.onnx.log_model()`, so parallel runs crash.
Details
No response
What component(s) does this bug affect?
- [X] `area/artifacts`: Artifact stores and artifact logging
- [ ] `area/build`: Build and test infrastructure for MLflow
- [ ] `area/docs`: MLflow documentation pages
- [ ] `area/examples`: Example code
- [X] `area/model-registry`: Model Registry service, APIs, and the fluent client calls for Model Registry
- [X] `area/models`: MLmodel format, model serialization/deserialization, flavors
- [ ] `area/recipes`: Recipes, Recipe APIs, Recipe configs, Recipe Templates
- [ ] `area/projects`: MLproject format, project running backends
- [ ] `area/scoring`: MLflow Model server, model deployment tools, Spark UDFs
- [ ] `area/server-infra`: MLflow Tracking server backend
- [ ] `area/tracking`: Tracking Service, tracking client APIs, autologging
What interface(s) does this bug affect?
- [ ] `area/uiux`: Front-end, user experience, plotting, JavaScript, JavaScript dev server
- [ ] `area/docker`: Docker use across MLflow's components, such as MLflow Projects and MLflow Models
- [ ] `area/sqlalchemy`: Use of SQLAlchemy in the Tracking Service or Model Registry
- [ ] `area/windows`: Windows support
What language(s) does this bug affect?
- [ ] `language/r`: R APIs and clients
- [ ] `language/java`: Java APIs and clients
- [ ] `language/new`: Proposals for new client languages
What integration(s) does this bug affect?
- [ ] `integrations/azure`: Azure and Azure ML integrations
- [ ] `integrations/sagemaker`: SageMaker integrations
- [ ] `integrations/databricks`: Databricks integrations
I think we can add a method `MlflowClient.log_model`, like:

```python
class MlflowClient:
    def log_model(
        self,
        run_id,
        artifact_path,
        flavor,
        registered_model_name=None,
        await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
        **kwargs,
    ): ...
```
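As a minimal illustration of the flavor-dispatch idea in this signature (all names below are hypothetical stand-ins, not MLflow's actual internals): the client method receives a flavor module or object and delegates serialization to its `save_model`, keeping the client itself flavor-agnostic.

```python
import os
import tempfile

class FakeFlavor:
    """Hypothetical stand-in for a flavor module such as mlflow.sklearn."""
    @staticmethod
    def save_model(path, model):
        # A real flavor would serialize the model; here we just write a file.
        os.makedirs(path, exist_ok=True)
        with open(os.path.join(path, "model.txt"), "w") as f:
            f.write(repr(model))

def log_model(run_id, artifact_path, flavor, model):
    """Sketch: serialize via the flavor, then report what would be uploaded."""
    with tempfile.TemporaryDirectory() as tmp:
        local = os.path.join(tmp, artifact_path)
        flavor.save_model(path=local, model=model)
        saved = sorted(os.listdir(local))
    # A real implementation would upload `saved` under run_id/artifact_path.
    return {"run_id": run_id, "artifact_path": artifact_path, "files": saved}

info = log_model("abc123", "model", FakeFlavor, {"weights": [1, 2, 3]})
print(info["files"])  # ['model.txt']
```

The point of the pattern is that `MlflowClient.log_model` never needs to know how a given flavor serializes; it only needs the `flavor.save_model(...)` contract.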
@dbczumar @harupy @BenWilson2 WDYT ?
Can we just add `run_id` (an optional argument that defaults to `None`) to `mlflow.<flavor>.log_model`?
@WeichenXu123
As I tested, it works after adding a `log_model()` function to `mlflow.client`, like this:
```python
# This function is referred from mlflow.models.model.Model.log()
@classmethod
def log(
    cls,
    run_id,
    local_path,
    artifact_path,
    flavor,
    registered_model_name=None,
    await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
    **kwargs,
):
    with TempDir() as tmp:
        mlflow_model = cls(artifact_path=artifact_path, run_id=run_id)
        flavor.save_model(path=local_path, mlflow_model=mlflow_model, **kwargs)
        mlflow.tracking.fluent.log_artifacts(local_path, mlflow_model.artifact_path)
        try:
            mlflow.tracking.fluent._record_logged_model(mlflow_model)
        except MlflowException:
            # We need to swallow all mlflow exceptions to maintain backwards
            # compatibility with older tracking servers. Only print out a warning for now.
            _logger.warning(_LOG_MODEL_METADATA_WARNING_TEMPLATE, mlflow.get_artifact_uri())
            _logger.debug("", exc_info=True)
        if registered_model_name is not None:
            # run_id = mlflow.tracking.fluent.active_run().info.run_id
            mlflow.register_model(
                "runs:/%s/%s" % (run_id, mlflow_model.artifact_path),
                registered_model_name,
                await_registration_for=await_registration_for,
            )
    return mlflow_model.get_model_info()
```
Also, the type-check error raised inside the `mlflow.tracking.fluent._record_logged_model(mlflow_model)` call was handled:
```python
# This function is referred from mlflow.tracking.client._record_logged_model()
def _record_logged_model(self, run_id, mlflow_model):
    from mlflow.models import Model

    if not isinstance(mlflow_model, Model):
        raise TypeError(
            "Argument 'mlflow_model' should be of type mlflow.models.Model but was "
            "{}".format(type(mlflow_model))
        )
    self.store.record_logged_model(run_id, mlflow_model)
```
I tested it and the above settings work.
@harupy
> Can we just add `run_id` (an optional argument that defaults to `None`) to `mlflow.<flavor>.log_model`?
This is a benefit of using `MlflowClient`: a `MlflowClient` instance might use a `tracking_uri` and `registry_uri` that are different from the globally set `tracking_uri` / `registry_uri`.
@jaehyeongAN
Your testing code uses a `classmethod`, so it has a similar issue: it cannot use the `tracking_uri` / `registry_uri` that are set on a specific `MlflowClient` instance.
@WeichenXu123 Got it.
> I think we can add a method `MlflowClient.log_model`, like:
This approach sounds good to me.
Sounds good to me as well!
@jaehyeongAN
We have consensus to add an API like:

```python
class MlflowClient:
    def log_model(
        self,
        run_id,
        artifact_path,
        flavor,
        registered_model_name=None,
        await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
        **kwargs,
    ): ...
```
Would you contribute this new API ?
Thank you!
@WeichenXu123 Cool, I want to do that.
@BenWilson2 @dbczumar @harupy @WeichenXu123 Please assign a maintainer and start triaging this issue.
@jaehyeongAN @WeichenXu123 any updates on this? This would be really nice to have :)
@jaehyeongAN I assigned the task to you. :)
@jaehyeongAN @WeichenXu123 I understand when the plate gets full in life and work. I can try to take this on as it's applicable to my work if that's okay with y'all. I just may need some guidance as it would be my first time contributing to MLflow.
On a related note, not only is `run_id` global, but IIRC `register_model` would use the global tracking and registry URIs. I remember a couple of months ago needing to use `mlflow.set_tracking_uri` and `mlflow.set_registry_uri` with MLflow 2.4.0. Is this still the case, and something we'd have to take into account when implementing `MlflowClient.log_model`?
I had totally forgotten about this issue until recently... Since there has been no activity since, I went ahead and created a WIP PR here: https://github.com/mlflow/mlflow/pull/11906
@WeichenXu123 following up on the above WIP PR that may warrant discussion before I proceed further. Thanks!