sagemaker-python-sdk
Can't register a model in model registry without specifying `inference_instances` and `transform_instances`
Describe the bug
Registering a model with the SageMaker SDK's `model.register()` without specifying `inference_instances` and `transform_instances` fails with the following error:

```
ParamValidationError: Parameter validation failed:
Invalid type for parameter InferenceSpecification.SupportedRealtimeInferenceInstanceTypes, value: None, type: <class 'NoneType'>, valid types: <class 'list'>, <class 'tuple'>
Invalid type for parameter InferenceSpecification.SupportedTransformInstanceTypes, value: None, type: <class 'NoneType'>, valid types: <class 'list'>, <class 'tuple'>
```
The same operation performed directly through boto3 with `sagemaker_client.create_model_package()` completes successfully.
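For comparison, here is a minimal sketch of the direct boto3 call. It assumes the model package group `test-model-group-name` already exists, because the SDK creates the group on demand but a bare `create_model_package` call does not. It also assumes the placeholder values are replaced with real ones. `build_create_model_package_request` and `submit` are hypothetical helper names. Only `create_model_package` and its parameter names come from the boto3 SageMaker client.

```python
from typing import Any, Dict, List


def build_create_model_package_request(
    model_package_group_name: str,
    image_uri: str,
    model_data_url: str,
    content_types: List[str],
    response_types: List[str],
) -> Dict[str, Any]:
    """Build CreateModelPackage kwargs without the optional instance-type lists."""
    return {
        "ModelPackageGroupName": model_package_group_name,
        "InferenceSpecification": {
            "Containers": [{"Image": image_uri, "ModelDataUrl": model_data_url}],
            "SupportedContentTypes": content_types,
            "SupportedResponseMIMETypes": response_types,
            # SupportedRealtimeInferenceInstanceTypes and SupportedTransformInstanceTypes
            # are left out entirely instead of being sent as None.
        },
    }


def submit(sagemaker_client: Any, request: Dict[str, Any]) -> Dict[str, Any]:
    """Send the request with a boto3 SageMaker client, e.g. boto3.client("sagemaker")."""
    return sagemaker_client.create_model_package(**request)


# Placeholder values (assumptions): substitute the real image URI and S3 location.
request = build_create_model_package_request(
    model_package_group_name="test-model-group-name",
    image_uri="<inference-image-uri>",
    model_data_url="s3://<bucket>/test-model-group-name/test.tar.gz",
    content_types=["application/json"],
    response_types=["application/json"],
)
```

The request builder is kept separate from the network call so that the shape of the request can be inspected without AWS credentials.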
I traced the issue to these two lines of code: https://github.com/aws/sagemaker-python-sdk/blob/9369a8781da9be92e1850ff62b42428cd61e23a6/src/sagemaker/session.py#L4502-L4503
When `inference_instances` and `transform_instances` are `None` in `model.register()`, they should not be included in the `inference_specification` dictionary at all.
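A minimal sketch of the suggested behavior follows, written as a standalone helper rather than as a patch to `session.py`. `build_inference_specification` is a hypothetical name. The two conditional keys are the `InferenceSpecification` fields named in the error message. The other keys are the container and MIME-type fields that `CreateModelPackage` accepts.

```python
from typing import Any, Dict, List, Optional


def build_inference_specification(
    containers: List[Dict[str, Any]],
    content_types: List[str],
    response_types: List[str],
    inference_instances: Optional[List[str]] = None,
    transform_instances: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Build an InferenceSpecification dict, omitting instance-type keys that are unset."""
    spec: Dict[str, Any] = {
        "Containers": containers,
        "SupportedContentTypes": content_types,
        "SupportedResponseMIMETypes": response_types,
    }
    # botocore rejects None for these list-typed fields, so add them only when provided.
    if inference_instances is not None:
        spec["SupportedRealtimeInferenceInstanceTypes"] = inference_instances
    if transform_instances is not None:
        spec["SupportedTransformInstanceTypes"] = transform_instances
    return spec


spec = build_inference_specification(
    containers=[{"Image": "<inference-image-uri>", "ModelDataUrl": "s3://<bucket>/test.tar.gz"}],
    content_types=["application/json"],
    response_types=["application/json"],
)
print(sorted(spec))  # ['Containers', 'SupportedContentTypes', 'SupportedResponseMIMETypes']
```

The helper checks `is not None` rather than truthiness. As a result, any list the caller passes explicitly is forwarded unchanged, and only the unset (`None`) case is dropped.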
To reproduce
The bug can be reproduced by running this code block in a SageMaker (Studio) notebook:
```python
import sagemaker

model_data = "test.tar.gz"
model_package_group_name = "test-model-group-name"

# Create an empty placeholder model artifact (IPython shell magic) and upload it to S3
!touch {model_data}
model_data_uri = sagemaker.s3.S3Uploader.upload(
    local_path=model_data,
    desired_s3_uri=f"s3://{sagemaker.Session().default_bucket()}/{model_package_group_name}",
)

inference_image = sagemaker.image_uris.retrieve(
    framework="xgboost",
    region=sagemaker.Session().boto_region_name,
    image_scope="inference",
    version="latest",
)

model = sagemaker.Model(
    image_uri=inference_image,
    model_data=model_data_uri,
    sagemaker_session=sagemaker.Session(),
)

# inference_instances and transform_instances are intentionally not passed
model.register(
    content_types=["application/json"],
    response_types=["application/json"],
    model_package_group_name=model_package_group_name,
)
```
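Until the SDK omits these keys, one workaround is to always pass explicit lists so that `model.register()` never forwards `None`. This is a sketch under that assumption. `register_with_explicit_instances` is a hypothetical wrapper, and `ml.m5.large` is only an example instance type. The trade-off is that the registered model package then advertises whichever instance types you choose. The `model` argument would be the `sagemaker.Model` built in the snippet above.

```python
from typing import Any, Sequence

# Example instance types (assumption): replace with the types you actually intend to support.
DEFAULT_INSTANCE_TYPES = ("ml.m5.large",)


def register_with_explicit_instances(
    model: Any,
    model_package_group_name: str,
    inference_instances: Sequence[str] = DEFAULT_INSTANCE_TYPES,
    transform_instances: Sequence[str] = DEFAULT_INSTANCE_TYPES,
) -> Any:
    """Call model.register() with list-valued instance types so None is never sent."""
    return model.register(
        content_types=["application/json"],
        response_types=["application/json"],
        inference_instances=list(inference_instances),
        transform_instances=list(transform_instances),
        model_package_group_name=model_package_group_name,
    )
```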
Expected behavior
I would expect the model to be registered successfully, just as it is when `create_model_package()` is called directly through boto3.
Screenshots or logs
This is the error returned when running the code block above:

```
---------------------------------------------------------------------------
ParamValidationError Traceback (most recent call last)
Input In [2], in <cell line: 27>()
14 inference_image = sagemaker.image_uris.retrieve(
15 framework="xgboost",
16 region=sagemaker.Session().boto_region_name,
17 image_scope="inference",
18 version="latest",
19 )
21 model = sagemaker.Model(
22 image_uri=inference_image,
23 model_data=model_data_uri,
24 sagemaker_session=sagemaker.Session(),
25 )
---> 27 model.register(
28 content_types=["application/json"],
29 response_types=["application/json"],
30 model_package_group_name=model_package_group_name,
31 )
File /opt/conda/lib/python3.8/site-packages/sagemaker/workflow/pipeline_context.py:209, in runnable_by_pipeline.<locals>.wrapper(*args, **kwargs)
206 run_func(*args, **kwargs)
207 return self_instance.sagemaker_session.context
--> 209 return run_func(*args, **kwargs)
File /opt/conda/lib/python3.8/site-packages/sagemaker/model.py:409, in Model.register(self, content_types, response_types, inference_instances, transform_instances, model_package_name, model_package_group_name, image_uri, model_metrics, metadata_properties, marketplace_cert, approval_status, description, drift_check_baselines, customer_metadata_properties, validation_specification, domain, task, sample_payload_url, framework, framework_version, nearest_model_name, data_input_configuration)
384 container_def = {
385 "Image": self.image_uri,
386 "ModelDataUrl": self.model_data,
387 }
389 model_pkg_args = sagemaker.get_model_package_args(
390 content_types,
391 response_types,
(...)
407 task=task,
408 )
--> 409 model_package = self.sagemaker_session.create_model_package_from_containers(
410 **model_pkg_args
411 )
412 if isinstance(self.sagemaker_session, PipelineSession):
413 return None
File /opt/conda/lib/python3.8/site-packages/sagemaker/session.py:2896, in Session.create_model_package_from_containers(self, containers, content_types, response_types, inference_instances, transform_instances, model_package_name, model_package_group_name, model_metrics, metadata_properties, marketplace_cert, approval_status, description, drift_check_baselines, customer_metadata_properties, validation_specification, domain, sample_payload_url, task)
2891 self.sagemaker_client.create_model_package_group(
2892 ModelPackageGroupName=request["ModelPackageGroupName"]
2893 )
2894 return self.sagemaker_client.create_model_package(**request)
-> 2896 return self._intercept_create_request(
2897 model_pkg_request, submit, self.create_model_package_from_containers.__name__
2898 )
File /opt/conda/lib/python3.8/site-packages/sagemaker/session.py:4230, in Session._intercept_create_request(self, request, create, func_name)
4217 def _intercept_create_request(
4218 self, request: typing.Dict, create, func_name: str = None # pylint: disable=unused-argument
4219 ):
4220 """This function intercepts the create job request.
4221
4222 PipelineSession inherits this Session class and will override
(...)
4228 func_name (str): the name of the function needed intercepting
4229 """
-> 4230 return create(request)
File /opt/conda/lib/python3.8/site-packages/sagemaker/session.py:2894, in Session.create_model_package_from_containers.<locals>.submit(request)
2890 except ClientError:
2891 self.sagemaker_client.create_model_package_group(
2892 ModelPackageGroupName=request["ModelPackageGroupName"]
2893 )
-> 2894 return self.sagemaker_client.create_model_package(**request)
File /opt/conda/lib/python3.8/site-packages/botocore/client.py:508, in ClientCreator._create_api_method.<locals>._api_call(self, *args, **kwargs)
504 raise TypeError(
505 f"{py_operation_name}() only accepts keyword arguments."
506 )
507 # The "self" in this scope is referring to the BaseClient.
--> 508 return self._make_api_call(operation_name, kwargs)
File /opt/conda/lib/python3.8/site-packages/botocore/client.py:874, in BaseClient._make_api_call(self, operation_name, api_params)
865 logger.debug(
866 'Warning: %s.%s() is deprecated', service_name, operation_name
867 )
868 request_context = {
869 'client_region': self.meta.region_name,
870 'client_config': self.meta.config,
871 'has_streaming_input': operation_model.has_streaming_input,
872 'auth_type': operation_model.auth_type,
873 }
--> 874 request_dict = self._convert_to_request_dict(
875 api_params, operation_model, context=request_context
876 )
877 resolve_checksum_context(request_dict, operation_model, api_params)
879 service_id = self._service_model.service_id.hyphenize()
File /opt/conda/lib/python3.8/site-packages/botocore/client.py:935, in BaseClient._convert_to_request_dict(self, api_params, operation_model, context)
929 def _convert_to_request_dict(
930 self, api_params, operation_model, context=None
931 ):
932 api_params = self._emit_api_params(
933 api_params, operation_model, context
934 )
--> 935 request_dict = self._serializer.serialize_to_request(
936 api_params, operation_model
937 )
938 if not self._client_config.inject_host_prefix:
939 request_dict.pop('host_prefix', None)
File /opt/conda/lib/python3.8/site-packages/botocore/validate.py:381, in ParamValidationDecorator.serialize_to_request(self, parameters, operation_model)
377 report = self._param_validator.validate(
378 parameters, operation_model.input_shape
379 )
380 if report.has_errors():
--> 381 raise ParamValidationError(report=report.generate_report())
382 return self._serializer.serialize_to_request(
383 parameters, operation_model
384 )
ParamValidationError: Parameter validation failed:
Invalid type for parameter InferenceSpecification.SupportedRealtimeInferenceInstanceTypes, value: None, type: <class 'NoneType'>, valid types: <class 'list'>, <class 'tuple'>
Invalid type for parameter InferenceSpecification.SupportedTransformInstanceTypes, value: None, type: <class 'NoneType'>, valid types: <class 'list'>, <class 'tuple'>
```
System information
- SageMaker Python SDK version: 2.99.0
- Framework name (eg. PyTorch) or algorithm (eg. KMeans): Any
- Framework version: Any
- Python version: 3.8
- CPU or GPU: Any
- Custom Docker image (Y/N): N