sagemaker-spark
What is the correct way to construct a ProtobufResponseRowDeserializer in PySpark?
System Information
- Spark or PySpark: PySpark
- SDK Version: 2.3.4
- Spark Version:
- Algorithm (e.g. KMeans): Random Cut Forest Estimator
Describe the problem
I have the following PySpark code, which tries to construct a `SageMakerEstimator` for the Random Cut Forest image:
```python
# Random Cut Forest Estimator
import boto3
from pyspark.sql.types import StructType, StructField, DoubleType
from sagemaker import get_execution_role
from sagemaker.amazon.amazon_estimator import get_image_uri
from sagemaker_pyspark import IAMRole
from sagemaker_pyspark import SageMakerEstimator
from sagemaker_pyspark import RandomNamePolicyFactory
from sagemaker_pyspark import EndpointCreationPolicy
from sagemaker_pyspark.transformation.serializers.serializers import ProtobufRequestRowSerializer
from sagemaker_pyspark.transformation.deserializers.deserializers import ProtobufResponseRowDeserializer

# `region` and `role` were used but not defined in the snippet; these are assumed notebook defaults.
region = boto3.Session().region_name
role = get_execution_role()

# Schema of the rows produced from each inference response
response_schema = StructType([StructField("score", DoubleType(), False)])

estimator = SageMakerEstimator(
    trainingImage=get_image_uri(region, 'randomcutforest'),  # Training image
    modelImage=get_image_uri(region, 'randomcutforest'),     # Model image
    requestRowSerializer=ProtobufRequestRowSerializer(featuresColumnName="features"),
    # protobufKeys is the argument the Py4JError below complains about
    responseRowDeserializer=ProtobufResponseRowDeserializer(response_schema, protobufKeys=["score"]),
    sagemakerRole=IAMRole(role),
    hyperParameters={"feature_dim": "6"},
    trainingInstanceType="ml.m4.4xlarge",
    trainingInstanceCount=1,
    endpointInstanceType="ml.t2.medium",
    endpointInitialInstanceCount=1,
    trainingSparkDataFormat="sagemaker",
    namePolicyFactory=RandomNamePolicyFactory("sparksm-4-"),
    endpointCreationPolicy=EndpointCreationPolicy.CREATE_ON_CONSTRUCT
)
```
When I run this code in PySpark, I get the following error:
```
Py4JError: An error occurred while calling None.com.amazonaws.services.sagemaker.sparksdk.transformation.deserializers.ProtobufResponseRowDeserializer. Trace:
py4j.Py4JException: Constructor com.amazonaws.services.sagemaker.sparksdk.transformation.deserializers.ProtobufResponseRowDeserializer([class org.apache.spark.sql.types.StructType, class scala.collection.immutable.$colon$colon]) does not exist
	at py4j.reflection.ReflectionEngine.getConstructor(ReflectionEngine.java:179)
	at py4j.reflection.ReflectionEngine.getConstructor(ReflectionEngine.java:196)
	at py4j.Gateway.invoke(Gateway.java:237)
	at py4j.commands.ConstructorCommand.invokeConstructor(ConstructorCommand.java:80)
	at py4j.commands.ConstructorCommand.execute(ConstructorCommand.java:69)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:745)
```
The problem is in `ProtobufResponseRowDeserializer`. According to its Scala source code, the constructor should accept a `Seq`.
What is the correct counterpart in PySpark? Obviously it doesn't accept a list of strings.
I searched the sagemaker-spark SDK but couldn't find any reference to this there.
Hi @RunshengSong, to use `ProtobufResponseRowDeserializer` with the PySpark SDK, note that the constructor accepts a `StructType` instead of a string: https://github.com/aws/sagemaker-spark/blob/master/sagemaker-pyspark-sdk/src/sagemaker_pyspark/transformation/deserializers/deserializers.py#L30
You will need to build a `StructType` (https://spark.apache.org/docs/1.1.1/api/python/pyspark.sql.StructType-class.html) that contains the feature column field and pass it to the `ProtobufResponseRowDeserializer` constructor.
Hi @ChuyangDeng, thanks for the reply. I understand that I need to pass a `StructType` as the schema to `ProtobufResponseRowDeserializer`, which the code above already does.
However, my question is about the `protobufKeys` argument. When I don't pass it, I get a NullPointerException (NPE) when I display the DataFrame of prediction output.
What should the type of the `protobufKeys` argument be?
Thanks again.
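The role of `protobufKeys` can be illustrated with a minimal pure-Python sketch. It assumes the rule described in the SDK's Scala documentation for `protobufKeys` (not re-verified here): the i-th key names the protobuf label read into the i-th schema field, and when no keys are given, the schema's field names are used instead. Under that reading, the Python-side value is simply a list of strings such as `["score"]`. The trace above also hints at the JVM-side type. Py4J found no constructor taking `(StructType, scala.collection.immutable.$colon$colon)`. Because `$colon$colon` is Scala's `::` (a `List`, and therefore a `Seq`), the second Scala parameter is presumably not a bare `Seq`; the Scala source appears to declare it as `Option[Seq[String]]`. The functions `resolve_keys` and `deserialize_record` are hypothetical and not part of `sagemaker_pyspark`. A response record is modelled as a plain dict mapping each label to a list of floats, not as a real protobuf message.

```python
from typing import Dict, List, Optional, Sequence, Tuple

# Hypothetical stand-in for one protobuf response record:
# each label maps to its list of float values.
Record = Dict[str, List[float]]


def resolve_keys(field_names: Sequence[str],
                 protobuf_keys: Optional[Sequence[str]] = None) -> List[str]:
    """Return the protobuf label to read for each schema field, in schema order."""
    if protobuf_keys is None:
        # Assumed default: reuse the schema's own field names as labels.
        return list(field_names)
    if len(protobuf_keys) != len(field_names):
        # This sketch requires one key per field; the real SDK may behave differently.
        raise ValueError("expected exactly one protobuf key per schema field")
    return list(protobuf_keys)


def deserialize_record(record: Record,
                       field_names: Sequence[str],
                       protobuf_keys: Optional[Sequence[str]] = None) -> Tuple[float, ...]:
    """Build one output row from a record, handling scalar (DoubleType-like) fields only."""
    keys = resolve_keys(field_names, protobuf_keys)
    # Take the first value under each label, i.e. treat every field as a scalar.
    return tuple(record[key][0] for key in keys)


if __name__ == "__main__":
    fields = ["score"]          # same single field as response_schema in the question
    record = {"score": [1.27]}  # made-up value, for illustration only
    print(deserialize_record(record, fields))             # prints (1.27,): keys default to field names
    print(deserialize_record(record, fields, ["score"]))  # prints (1.27,): explicit list of str
```

If that default holds, omitting `protobufKeys` should already work for a field named `score`. The NPE may therefore point at how the Python wrapper forwards `None` to the JVM rather than at the schema. That is a guess, not a confirmed diagnosis.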