sagemaker-python-sdk
SageMaker Batch currently doesn't support Model entity with container definitions which use ModelDataSource attribute
Describe the feature you'd like
Batch Transform should support Model entities whose container definitions use the ModelDataSource attribute (uncompressed model artifacts in S3), so that LLMs can be used in batch transform operations.
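For context, a sagemaker.Model's model_data can be either a plain S3 URI string pointing at a compressed model.tar.gz, or a dict with an S3DataSource entry (the ModelDataSource form) pointing at an uncompressed S3 prefix, which avoids tarring multi-GB LLM weights. Real-time endpoints accept both forms, but CreateTransformJob currently rejects the second one. A minimal illustration of the two forms (bucket and prefix below are placeholders):
import sagemaker
# Classic form: compressed artifact, accepted by Batch Transform today
model_data_tarball = "s3://<bucket>/mistral/model.tar.gz"
# ModelDataSource form: uncompressed S3 prefix, needed for large LLM weights,
# currently rejected by CreateTransformJob
model_data_source = {
    "S3DataSource": {
        "S3Uri": "s3://<bucket>/mistral/",
        "S3DataType": "S3Prefix",
        "CompressionType": "None",
    }
}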
How would this feature be used? Please describe.
import boto3
import sagemaker
from sagemaker import Model, image_uris, serializers, deserializers
role = sagemaker.get_execution_role() # execution role for the endpoint
sess = sagemaker.session.Session() # sagemaker session for interacting with different AWS APIs
region = sess.boto_region_name # region of the current SageMaker Studio environment
account_id = sess.account_id()
from huggingface_hub import snapshot_download
from pathlib import Path
import os
# - This will download the model into the current directory, wherever the Jupyter notebook is running
local_model_path = Path(".")
local_model_path.mkdir(exist_ok=True)
model_name = 'mistralai/Mistral-7B-v0.1'
# Only download pytorch checkpoint files
allow_patterns = ["*.json", "*.txt", "*.model", "*.safetensors", "*.bin", "*.chk", "*.pth"]
# - Use snapshot_download since the model is stored in the repository using LFS
model_download_path = snapshot_download(
    repo_id=model_name,
    cache_dir=local_model_path,
    allow_patterns=allow_patterns,
    token="<HF TOKEN>",
)
# Upload the downloaded weights to S3 first so the resulting URI can be templated
# into serving.properties below
base_model_s3_uri = sess.upload_data(path=model_download_path, key_prefix="batch-transform-mistral/model")
print(f"Model uploaded to ---> {base_model_s3_uri}")
%%writefile {model_download_path}/serving.properties
engine=Python
option.tensor_parallel_degree=max
option.model_id={{model_id}}
option.max_rolling_batch_size=16
option.rolling_batch=vllm
import jinja2
from pathlib import Path
jinja_env = jinja2.Environment()
serving_properties = Path(model_download_path) / "serving.properties"
template = jinja_env.from_string(serving_properties.open().read())
serving_properties.open("w").write(
    template.render(model_id=base_model_s3_uri)
)
# Re-upload the rendered serving.properties so it sits alongside the weights in S3
sess.upload_data(path=str(serving_properties), key_prefix="batch-transform-mistral/model")
#https://github.com/aws/sagemaker-python-sdk/blob/master/tests/unit/test_djl_inference.py#L31-L33
image_uri = image_uris.retrieve(
    framework="djl-lmi",
    region=region,
    version="0.28.0",
)
model_data = {
    "S3DataSource": {
        "S3Uri": f"{base_model_s3_uri}/",
        "S3DataType": "S3Prefix",
        "CompressionType": "None",
    }
}
# Create the SageMaker Model, pointing at the uncompressed S3 prefix via ModelDataSource
model = sagemaker.Model(image_uri=image_uri, model_data=model_data, role=role)
from sagemaker.utils import name_from_base
endpoint_name = name_from_base("lmi-batch-transform-mistral-gated")
# instance type you will deploy your model to
instance_type = "ml.g5.12xlarge"
# Creating the batch transformer object. If you have a large dataset you can
# divide it into smaller chunks and use more instances for faster inference
batch_transformer = model.transformer(
    instance_count=1,
    instance_type=instance_type,
    output_path=s3_output_data_path,  # S3 output location defined elsewhere in the notebook (not shown)
    assemble_with="Line",
    accept="text/csv",
    max_payload=1,
)
batch_transformer.env = hyper_params_dict  # container environment variables defined elsewhere in the notebook (not shown)
# Making the predictions on the input data
batch_transformer.transform(
    s3_input_data_path, content_type="application/jsonlines", split_type="Line"
)
batch_transformer.wait()
This throws the error:
---------------------------------------------------------------------------
ClientError Traceback (most recent call last)
Cell In[36], line 14
11 batch_transformer.env = hyper_params_dict
13 # Making the predictions on the input data
---> 14 batch_transformer.transform(
15 s3_input_data_path, content_type="application/jsonlines", split_type="Line"
16 )
18 batch_transformer.wait()
File /opt/conda/lib/python3.10/site-packages/sagemaker/workflow/pipeline_context.py:346, in runnable_by_pipeline.<locals>.wrapper(*args, **kwargs)
342 return context
344 return _StepArguments(retrieve_caller_name(self_instance), run_func, *args, **kwargs)
--> 346 return run_func(*args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/sagemaker/transformer.py:302, in Transformer.transform(self, data, data_type, content_type, compression_type, split_type, job_name, input_filter, output_filter, join_source, experiment_config, model_client_config, batch_data_capture_config, wait, logs)
292 experiment_config = check_and_get_run_experiment_config(experiment_config)
294 batch_data_capture_config = resolve_class_attribute_from_config(
295 None,
296 batch_data_capture_config,
(...)
299 sagemaker_session=self.sagemaker_session,
300 )
--> 302 self.latest_transform_job = _TransformJob.start_new(
303 self,
304 data,
305 data_type,
306 content_type,
307 compression_type,
308 split_type,
309 input_filter,
310 output_filter,
311 join_source,
312 experiment_config,
313 model_client_config,
314 batch_data_capture_config,
315 )
317 if wait:
318 self.latest_transform_job.wait(logs=logs)
File /opt/conda/lib/python3.10/site-packages/sagemaker/transformer.py:636, in _TransformJob.start_new(cls, transformer, data, data_type, content_type, compression_type, split_type, input_filter, output_filter, join_source, experiment_config, model_client_config, batch_data_capture_config)
619 """Placeholder docstring"""
621 transform_args = cls._get_transform_args(
622 transformer,
623 data,
(...)
633 batch_data_capture_config,
634 )
--> 636 transformer.sagemaker_session.transform(**transform_args)
638 return cls(transformer.sagemaker_session, transformer._current_job_name)
File /opt/conda/lib/python3.10/site-packages/sagemaker/session.py:3805, in Session.transform(self, job_name, model_name, strategy, max_concurrent_transforms, max_payload, input_config, output_config, resource_config, experiment_config, env, tags, data_processing, model_client_config, batch_data_capture_config)
3802 logger.debug("Transform request: %s", json.dumps(request, indent=4))
3803 self.sagemaker_client.create_transform_job(**request)
-> 3805 self._intercept_create_request(transform_request, submit, self.transform.__name__)
File /opt/conda/lib/python3.10/site-packages/sagemaker/session.py:6497, in Session._intercept_create_request(self, request, create, func_name)
6480 def _intercept_create_request(
6481 self,
6482 request: typing.Dict,
(...)
6485 # pylint: disable=unused-argument
6486 ):
6487 """This function intercepts the create job request.
6488
6489 PipelineSession inherits this Session class and will override
(...)
6495 func_name (str): the name of the function needed intercepting
6496 """
-> 6497 return create(request)
File /opt/conda/lib/python3.10/site-packages/sagemaker/session.py:3803, in Session.transform.<locals>.submit(request)
3801 logger.info("Creating transform job with name: %s", job_name)
3802 logger.debug("Transform request: %s", json.dumps(request, indent=4))
-> 3803 self.sagemaker_client.create_transform_job(**request)
File /opt/conda/lib/python3.10/site-packages/botocore/client.py:565, in ClientCreator._create_api_method.<locals>._api_call(self, *args, **kwargs)
561 raise TypeError(
562 f"{py_operation_name}() only accepts keyword arguments."
563 )
564 # The "self" in this scope is referring to the BaseClient.
--> 565 return self._make_api_call(operation_name, kwargs)
File /opt/conda/lib/python3.10/site-packages/botocore/client.py:1021, in BaseClient._make_api_call(self, operation_name, api_params)
1017 error_code = error_info.get("QueryErrorCode") or error_info.get(
1018 "Code"
1019 )
1020 error_class = self.exceptions.from_code(error_code)
-> 1021 raise error_class(parsed_response, operation_name)
1022 else:
1023 return parsed_response
ClientError: An error occurred (ValidationException) when calling the CreateTransformJob operation: SageMaker Batch currently doesn't support Model entity with container definitions which use ModelDataSource attribute
Describe alternatives you've considered
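The only workaround I'm aware of is to fall back to packaging the weights as a compressed model.tar.gz so that model_data is a plain S3 URI string, which CreateTransformJob does accept. For a 7B model that means tarring and extracting roughly 15 GB of safetensors on every deployment, which is exactly what ModelDataSource is meant to avoid. A rough sketch reusing the variables from the example above (untested here):
import tarfile
import sagemaker
# Compress the downloaded weights (plus serving.properties) into a classic model.tar.gz
with tarfile.open("model.tar.gz", "w:gz") as tar:
    tar.add(model_download_path, arcname=".")
# Upload the tarball; upload_data returns the S3 URI of the uploaded file
tar_s3_uri = sess.upload_data(path="model.tar.gz", key_prefix="batch-transform-mistral/model-tar")
# A plain S3 URI string as model_data is accepted by Batch Transform
model = sagemaker.Model(image_uri=image_uri, model_data=tar_s3_uri, role=role)
batch_transformer = model.transformer(
    instance_count=1,
    instance_type=instance_type,
    output_path=s3_output_data_path,
)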
Additional context