metaflow
Amazon Sagemaker Training, Deployment, and Inference APIs
Pull request for an initial implementation of streamlined APIs for Amazon Sagemaker. It includes three functions: "fit", "deploy", and "predict". The primary algorithms this was built around were XGBoost and Linear Learner, but it should be compatible with any built-in algorithm that accepts 'text/csv' as a content type.
Two additional environment variables are required to run these flows.
- METAFLOW_SAGEMAKER_REGION is the region of your Sagemaker run.
- METAFLOW_SAGEMAKER_IAM_ROLE is a Sagemaker execution role with permissions to access all appropriate cloud resources, most notably S3.
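As a rough sketch of a pre-flight check for these two variables (the helper name, the example region, and the example role ARN are placeholders for illustration, not part of this PR):

```python
import os

# The two variables named in this PR description.
REQUIRED_VARS = ("METAFLOW_SAGEMAKER_REGION", "METAFLOW_SAGEMAKER_IAM_ROLE")


def missing_sagemaker_vars(environ=None):
    """Return the names of required variables that are unset or empty.

    Hypothetical helper: the PR itself may validate these differently.
    """
    env = os.environ if environ is None else environ
    return [name for name in REQUIRED_VARS if not env.get(name)]


# Placeholder values, for illustration only.
example_env = {
    "METAFLOW_SAGEMAKER_REGION": "us-west-2",
    "METAFLOW_SAGEMAKER_IAM_ROLE": "arn:aws:iam::123456789012:role/ExampleSagemakerRole",
}

print(missing_sagemaker_vars(example_env))  # nothing missing
print(missing_sagemaker_vars({}))           # both names reported
```

Calling missing_sagemaker_vars() with no argument checks the real process environment before a flow is launched.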
Usage starts with the following import: from metaflow import Sagemaker
Sagemaker.fit(data, image, hyperparameters, stopping_condition, resource_config)
Returns a string object with the S3 URI of the model artifact generated by the fit.
- data is a dictionary whose keys are Sagemaker "channel names", found here, and whose values are CSV data with no headers or indexes.
- image is a string containing the registry path of a Sagemaker built-in algorithm container, also found here. Automatic mapping will be coming soon.
- hyperparameters is a dictionary of hyperparameters for the specific algorithm referenced by image. An example for XGBoost can be found here.
- stopping_condition and resource_config are optional dictionaries for overriding the defaults: a single ml.m4.xlarge training instance with a 5 GB volume and a 1 hour maximum runtime. Syntax for those overrides can be found here.
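The arguments above might be assembled roughly as follows. This is a sketch under assumptions. The "train" and "validation" channel names and the hyperparameter names follow the built-in XGBoost algorithm. The override dictionaries are assumed to use the field names of the Sagemaker CreateTrainingJob request, since the linked syntax is not reproduced here. The to_headerless_csv helper and the sample rows are hypothetical.

```python
import csv
import io


def to_headerless_csv(rows):
    """Serialize rows to CSV text with no header row and no index column."""
    buf = io.StringIO()
    csv.writer(buf, lineterminator="\n").writerows(rows)
    return buf.getvalue()


# Made-up rows. Built-in XGBoost expects the label in the first column.
train_rows = [[1, 0.5, 3.2], [0, 1.5, 0.7]]
validation_rows = [[1, 0.4, 2.9]]

# Keys are channel names, values are headerless CSV data.
data = {
    "train": to_headerless_csv(train_rows),
    "validation": to_headerless_csv(validation_rows),
}

# XGBoost hyperparameters. The low-level Sagemaker API takes string values;
# whether this wrapper converts other types is not stated in the PR.
hyperparameters = {
    "objective": "binary:logistic",
    "num_round": "50",
    "max_depth": "5",
    "eta": "0.2",
}

# Assumed override syntax, spelling out the defaults listed above.
stopping_condition = {"MaxRuntimeInSeconds": 3600}
resource_config = {
    "InstanceType": "ml.m4.xlarge",
    "InstanceCount": 1,
    "VolumeSizeInGB": 5,
}

# With this PR's branch installed, the call would then look like:
# model_uri = Sagemaker.fit(data, image, hyperparameters,
#                           stopping_condition, resource_config)
```

The image argument is left out of the sketch because the registry path depends on the region and algorithm.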
Sagemaker.deploy(model_uri, image, instanceType, instanceCount, instanceWeight, variantName)
Returns a string object with the endpoint name generated by the model deployment.
- model_uri is a string with the S3 path of the model artifact. This string is returned by Sagemaker.fit.
- image is a string. It should be the same image used for training.
- instanceType, instanceCount, instanceWeight, and variantName are optional parameters for overriding the defaults of, respectively, "ml.m4.xlarge", 1, 1, and "AllTraffic".
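One plausible reading of how these four parameters reach Sagemaker is as a single production variant in an endpoint configuration. The sketch below assumes that mapping; the PR's actual implementation is not shown here. The dictionary keys are those of the Sagemaker CreateEndpointConfig request, and the production_variant helper and model name are hypothetical.

```python
def production_variant(
    model_name,
    instanceType="ml.m4.xlarge",
    instanceCount=1,
    instanceWeight=1,
    variantName="AllTraffic",
):
    """Build one ProductionVariants entry from the deploy parameters.

    Assumed mapping, for illustration only.
    """
    return {
        "VariantName": variantName,
        "ModelName": model_name,
        "InitialInstanceCount": instanceCount,
        "InstanceType": instanceType,
        "InitialVariantWeight": instanceWeight,
    }


# Defaults as listed above, then an override of the instance settings.
default_variant = production_variant("example-model")
larger_variant = production_variant(
    "example-model", instanceType="ml.m5.xlarge", instanceCount=2
)
print(default_variant)
print(larger_variant)
```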
Sagemaker.predict(data, endpoint_name)
Returns a list of predictions.
- data is CSV data with no headers or indexes representing the features for inference.
- endpoint_name is a string with the name of the Sagemaker endpoint to run inference against. This value is returned by Sagemaker.deploy.
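A minimal sketch of the data handling around predict follows. It assumes the endpoint answers a 'text/csv' request with comma- or newline-separated numeric predictions, as the built-in XGBoost container does; other algorithms may respond differently. Both helper names and the sample values are hypothetical.

```python
import csv
import io


def features_to_csv(rows):
    """Serialize feature rows (no label column) to headerless CSV text."""
    buf = io.StringIO()
    csv.writer(buf, lineterminator="\n").writerows(rows)
    return buf.getvalue()


def parse_predictions(body):
    """Split a comma/newline separated response body into a list of floats."""
    tokens = body.replace("\n", ",").split(",")
    return [float(tok) for tok in tokens if tok.strip()]


# Made-up feature rows: same column order as training, minus the label.
payload = features_to_csv([[0.5, 3.2], [1.5, 0.7]])
print(payload)

# Made-up response body in the assumed format.
print(parse_predictions("0.91,0.12\n0.55"))

# With this PR's branch installed, the call would then look like:
# predictions = Sagemaker.predict(payload, endpoint_name)
```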
A short example of the usage can be found here. This PR also adds an 08-sagemaker tutorial, for use with metaflow tutorials pull, that demonstrates the above sample flow.
@queueburt Thanks for this great proposal! Do you have any news on the topic?
Closing this PR in favor of native support for hosting models on Sagemaker with Metaflow