
Can't register a model in model registry without specifying `inference_instances` and `transform_instances`

Open acere opened this issue 2 years ago • 0 comments

Describe the bug

Registering a model with the SageMaker Python SDK's `model.register()` without specifying `inference_instances` and `transform_instances` fails with the following error:

ParamValidationError: Parameter validation failed:
Invalid type for parameter InferenceSpecification.SupportedRealtimeInferenceInstanceTypes, value: None, type: <class 'NoneType'>, valid types: <class 'list'>, <class 'tuple'>
Invalid type for parameter InferenceSpecification.SupportedTransformInstanceTypes, value: None, type: <class 'NoneType'>, valid types: <class 'list'>, <class 'tuple'>

The same operation through boto3, `sagemaker_client.create_model_package()`, completes successfully.
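For comparison, here is a minimal sketch of the kind of boto3 request that goes through. The helper names `build_model_package_request` and `register_with_boto3` are hypothetical, the content types are copied from the reproduction code, and the sketch assumes the model package group already exists. The point is that the two instance-type keys are simply absent from `InferenceSpecification` rather than set to `None`.

```python
from typing import Any, Dict


def build_model_package_request(
    group_name: str, image_uri: str, model_data_url: str
) -> Dict[str, Any]:
    """Hypothetical helper: build a CreateModelPackage request for a versioned model.

    SupportedRealtimeInferenceInstanceTypes and SupportedTransformInstanceTypes
    are left out entirely instead of being passed as None.
    """
    return {
        "ModelPackageGroupName": group_name,
        "InferenceSpecification": {
            "Containers": [{"Image": image_uri, "ModelDataUrl": model_data_url}],
            "SupportedContentTypes": ["application/json"],
            "SupportedResponseMIMETypes": ["application/json"],
        },
    }


def register_with_boto3(request: Dict[str, Any]) -> Dict[str, Any]:
    """Send the request with boto3 (needs AWS credentials and an existing group)."""
    import boto3  # imported lazily so the request builder works without boto3 installed

    sagemaker_client = boto3.client("sagemaker")
    return sagemaker_client.create_model_package(**request)
```

Calling `register_with_boto3(build_model_package_request(...))` with real values is what I mean by "the same operation"; only the request builder can be run offline.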

I traced the issue to these two lines of code: https://github.com/aws/sagemaker-python-sdk/blob/9369a8781da9be92e1850ff62b42428cd61e23a6/src/sagemaker/session.py#L4502-L4503

When `inference_instances` and `transform_instances` are None in `model.register()`, they should not be included in the `inference_specification` dictionary at all.
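A rough sketch of the behaviour I have in mind follows. This is not the SDK's actual code: `build_inference_specification` is a hypothetical helper, and only the dictionary keys are taken from the error message and the CreateModelPackage request shape. The two instance-type keys are added only when a value was actually supplied.

```python
from typing import Any, Dict, List, Optional


def build_inference_specification(
    containers: List[Dict[str, Any]],
    content_types: List[str],
    response_types: List[str],
    inference_instances: Optional[List[str]] = None,
    transform_instances: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: build InferenceSpecification, skipping unset instance lists."""
    spec: Dict[str, Any] = {
        "Containers": containers,
        "SupportedContentTypes": content_types,
        "SupportedResponseMIMETypes": response_types,
    }
    # Only include the optional keys when the caller provided a value, so
    # botocore never sees None where it expects a list or tuple.
    if inference_instances is not None:
        spec["SupportedRealtimeInferenceInstanceTypes"] = inference_instances
    if transform_instances is not None:
        spec["SupportedTransformInstanceTypes"] = transform_instances
    return spec
```

With both arguments left at their default of None, the resulting dictionary contains only the three required keys, which matches what the boto3 call accepts.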

To reproduce

The bug can be reproduced by running this code block in a SageMaker Studio notebook:

import sagemaker

model_data = "test.tar.gz"
model_package_group_name = "test-model-group-name"

!touch {model_data}

model_data_uri = sagemaker.s3.S3Uploader.upload(
    local_path=model_data,
    desired_s3_uri=f"s3://{sagemaker.Session().default_bucket()}/{model_package_group_name}",
)


inference_image = sagemaker.image_uris.retrieve(
    framework="xgboost",
    region=sagemaker.Session().boto_region_name,
    image_scope="inference",
    version="latest",
)

model = sagemaker.Model(
    image_uri=inference_image,
    model_data=model_data_uri,
    sagemaker_session=sagemaker.Session(),
)

model.register(
    content_types=["application/json"],
    response_types=["application/json"],
    model_package_group_name=model_package_group_name,
)
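As the title implies, a possible workaround until this is fixed is to pass both lists explicitly. The sketch below only builds the keyword arguments; the instance type `ml.m5.large` is an arbitrary choice for illustration, not something the issue requires, and the final call is commented out because it needs the `model` object and AWS access from the block above.

```python
# Keyword arguments for model.register() with both instance lists set explicitly.
# "ml.m5.large" is an illustrative assumption; use whatever types fit the model.
register_kwargs = dict(
    content_types=["application/json"],
    response_types=["application/json"],
    model_package_group_name="test-model-group-name",
    inference_instances=["ml.m5.large"],
    transform_instances=["ml.m5.large"],
)

# In the notebook, with `model` defined as in the reproduction code:
# model.register(**register_kwargs)
```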

Expected behavior

I would expect the model to be registered.

Screenshots or logs

This is the error returned when running the code block above:

---------------------------------------------------------------------------
ParamValidationError                      Traceback (most recent call last)
Input In [2], in <cell line: 27>()
     14 inference_image = sagemaker.image_uris.retrieve(
     15     framework="xgboost",
     16     region=sagemaker.Session().boto_region_name,
     17     image_scope="inference",
     18     version="latest",
     19 )
     21 model = sagemaker.Model(
     22     image_uri=inference_image,
     23     model_data=model_data_uri,
     24     sagemaker_session=sagemaker.Session(),
     25 )
---> 27 model.register(
     28     content_types=["application/json"],
     29     response_types=["application/json"],
     30     model_package_group_name=model_package_group_name,
     31 )

File /opt/conda/lib/python3.8/site-packages/sagemaker/workflow/pipeline_context.py:209, in runnable_by_pipeline.<locals>.wrapper(*args, **kwargs)
    206     run_func(*args, **kwargs)
    207     return self_instance.sagemaker_session.context
--> 209 return run_func(*args, **kwargs)

File /opt/conda/lib/python3.8/site-packages/sagemaker/model.py:409, in Model.register(self, content_types, response_types, inference_instances, transform_instances, model_package_name, model_package_group_name, image_uri, model_metrics, metadata_properties, marketplace_cert, approval_status, description, drift_check_baselines, customer_metadata_properties, validation_specification, domain, task, sample_payload_url, framework, framework_version, nearest_model_name, data_input_configuration)
    384     container_def = {
    385         "Image": self.image_uri,
    386         "ModelDataUrl": self.model_data,
    387     }
    389 model_pkg_args = sagemaker.get_model_package_args(
    390     content_types,
    391     response_types,
   (...)
    407     task=task,
    408 )
--> 409 model_package = self.sagemaker_session.create_model_package_from_containers(
    410     **model_pkg_args
    411 )
    412 if isinstance(self.sagemaker_session, PipelineSession):
    413     return None

File /opt/conda/lib/python3.8/site-packages/sagemaker/session.py:2896, in Session.create_model_package_from_containers(self, containers, content_types, response_types, inference_instances, transform_instances, model_package_name, model_package_group_name, model_metrics, metadata_properties, marketplace_cert, approval_status, description, drift_check_baselines, customer_metadata_properties, validation_specification, domain, sample_payload_url, task)
   2891             self.sagemaker_client.create_model_package_group(
   2892                 ModelPackageGroupName=request["ModelPackageGroupName"]
   2893             )
   2894     return self.sagemaker_client.create_model_package(**request)
-> 2896 return self._intercept_create_request(
   2897     model_pkg_request, submit, self.create_model_package_from_containers.__name__
   2898 )

File /opt/conda/lib/python3.8/site-packages/sagemaker/session.py:4230, in Session._intercept_create_request(self, request, create, func_name)
   4217 def _intercept_create_request(
   4218     self, request: typing.Dict, create, func_name: str = None  # pylint: disable=unused-argument
   4219 ):
   4220     """This function intercepts the create job request.
   4221 
   4222     PipelineSession inherits this Session class and will override
   (...)
   4228         func_name (str): the name of the function needed intercepting
   4229     """
-> 4230     return create(request)

File /opt/conda/lib/python3.8/site-packages/sagemaker/session.py:2894, in Session.create_model_package_from_containers.<locals>.submit(request)
   2890     except ClientError:
   2891         self.sagemaker_client.create_model_package_group(
   2892             ModelPackageGroupName=request["ModelPackageGroupName"]
   2893         )
-> 2894 return self.sagemaker_client.create_model_package(**request)

File /opt/conda/lib/python3.8/site-packages/botocore/client.py:508, in ClientCreator._create_api_method.<locals>._api_call(self, *args, **kwargs)
    504     raise TypeError(
    505         f"{py_operation_name}() only accepts keyword arguments."
    506     )
    507 # The "self" in this scope is referring to the BaseClient.
--> 508 return self._make_api_call(operation_name, kwargs)

File /opt/conda/lib/python3.8/site-packages/botocore/client.py:874, in BaseClient._make_api_call(self, operation_name, api_params)
    865     logger.debug(
    866         'Warning: %s.%s() is deprecated', service_name, operation_name
    867     )
    868 request_context = {
    869     'client_region': self.meta.region_name,
    870     'client_config': self.meta.config,
    871     'has_streaming_input': operation_model.has_streaming_input,
    872     'auth_type': operation_model.auth_type,
    873 }
--> 874 request_dict = self._convert_to_request_dict(
    875     api_params, operation_model, context=request_context
    876 )
    877 resolve_checksum_context(request_dict, operation_model, api_params)
    879 service_id = self._service_model.service_id.hyphenize()

File /opt/conda/lib/python3.8/site-packages/botocore/client.py:935, in BaseClient._convert_to_request_dict(self, api_params, operation_model, context)
    929 def _convert_to_request_dict(
    930     self, api_params, operation_model, context=None
    931 ):
    932     api_params = self._emit_api_params(
    933         api_params, operation_model, context
    934     )
--> 935     request_dict = self._serializer.serialize_to_request(
    936         api_params, operation_model
    937     )
    938     if not self._client_config.inject_host_prefix:
    939         request_dict.pop('host_prefix', None)

File /opt/conda/lib/python3.8/site-packages/botocore/validate.py:381, in ParamValidationDecorator.serialize_to_request(self, parameters, operation_model)
    377     report = self._param_validator.validate(
    378         parameters, operation_model.input_shape
    379     )
    380     if report.has_errors():
--> 381         raise ParamValidationError(report=report.generate_report())
    382 return self._serializer.serialize_to_request(
    383     parameters, operation_model
    384 )

ParamValidationError: Parameter validation failed:
Invalid type for parameter InferenceSpecification.SupportedRealtimeInferenceInstanceTypes, value: None, type: <class 'NoneType'>, valid types: <class 'list'>, <class 'tuple'>
Invalid type for parameter InferenceSpecification.SupportedTransformInstanceTypes, value: None, type: <class 'NoneType'>, valid types: <class 'list'>, <class 'tuple'>

System information

  • SageMaker Python SDK version: 2.99.0
  • Framework name (eg. PyTorch) or algorithm (eg. KMeans): Any
  • Framework version: Any
  • Python version: 3.8
  • CPU or GPU: Any
  • Custom Docker image (Y/N): N

acere avatar Jul 10 '22 05:07 acere