sagemaker-spark
SageMakerModel.transform() doesn't use model's sagemakerClient
If you create a new SageMakerModel instance (say, with fromModelS3Path()), you can pass in your own sagemakerClient. However, when you go to use the model after it's been created, transform() does not use that client to send prediction requests. Instead, it appears to hardcode an AmazonSageMakerRuntimeClientBuilder.defaultClient in RequestBatchIterator.
Pardon my ignorance, but is there a reason that it can't just pass the sagemakerClient through?
Hi @harthur,
Thanks for using Amazon SageMaker!
There are two SageMaker clients: the AmazonSageMaker client, which is used to create and manage Training Jobs, Endpoints and such, and the AmazonSageMakerRuntime client, which is just used for predictions (with InvokeEndpointRequest in transform()).
Instead of injecting this client, you have to change the value of this var in the RequestBatchIterator singleton:
RequestBatchIterator.sagemakerRuntime = mySageMakerRuntimeClient
Why? Because Spark has to serialize tasks to send them to workers in the mapPartitions() call in SageMakerModel.transform(), and these AWS clients aren't serializable. So instead of serializing the RequestBatchIterator directly, we serialize a factory method that creates a RequestBatchIterator.
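To illustrate the explanation above, here is a minimal, self-contained sketch of that pattern in plain Scala. FakeRuntimeClient, ClientHolder, BatchIterator, BatchIteratorFactory and SerializationDemo are hypothetical stand-ins chosen for illustration, not the sagemaker-spark classes; the real code is at the links that follow.

```scala
import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

// Hypothetical stand-in for an AWS client: it does not extend Serializable.
class FakeRuntimeClient(val region: String)

// Hypothetical singleton holder, playing the role of the var on the
// RequestBatchIterator object. Every JVM has its own copy of this var.
object ClientHolder {
  @volatile var client: FakeRuntimeClient = new FakeRuntimeClient("default")
}

// The per-partition iterator reads the client from the singleton
// in whichever JVM it is constructed.
class BatchIterator(rows: Iterator[String]) extends Iterator[String] {
  private val client = ClientHolder.client
  def hasNext: Boolean = rows.hasNext
  def next(): String = s"${client.region}:${rows.next()}"
}

// The factory holds no client reference, so it can be serialized and shipped.
class BatchIteratorFactory extends Serializable {
  def create(rows: Iterator[String]): Iterator[String] = new BatchIterator(rows)
}

object SerializationDemo {
  // Returns true if Java serialization accepts the object.
  def isJavaSerializable(obj: AnyRef): Boolean = {
    val out = new ObjectOutputStream(new ByteArrayOutputStream())
    try { out.writeObject(obj); true }
    catch { case _: NotSerializableException => false }
    finally out.close()
  }
}
```

In this sketch, a value assigned to ClientHolder.client in one JVM is not carried along with the serialized factory; each JVM that calls create() reads its own holder.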
https://github.com/aws/sagemaker-spark/blob/master/sagemaker-spark-sdk/src/main/scala/com/amazonaws/services/sagemaker/sparksdk/SageMakerModel.scala#L509
https://github.com/aws/sagemaker-spark/blob/master/sagemaker-spark-sdk/src/main/scala/com/amazonaws/services/sagemaker/sparksdk/transformation/util/RequestBatchIterator.scala#L228-L239
Please feel free to reopen this if it doesn't answer your question. Thanks!
That does answer my question, thanks!
However, how do you set RequestBatchIterator.sagemakerRuntime without first unintentionally building the default client? Before I can override it, I believe this line is building the client and looking for env vars: https://github.com/aws/sagemaker-spark/blob/master/sagemaker-spark-sdk/src/main/scala/com/amazonaws/services/sagemaker/sparksdk/transformation/util/RequestBatchIterator.scala#L34
I'm getting this error: Caused by: com.amazonaws.SdkClientException: Unable to find a region via the region provider chain. Must provide an explicit region in the builder or setup environment to supply a region.
Let me know if this isn't the right place to ask more questions.
Hey @harthur,
I'm not sure exactly when the client is created. It's possible we should make that a lazy val or otherwise delay instantiation. Do you have a stack trace / code you're running where the client is instantiated before you can set it?
Otherwise: you're right about the client builder -- to get around that error, you can set AWS_DEFAULT_REGION to your region (for example us-west-2, us-east-1, us-east-2, or eu-west-1; any of the SageMaker regions) or set it in your AWS config file with aws configure.
Thanks!
The stack trace is:
Exception in thread "main" java.lang.ExceptionInInitializerError
at TestSagemakerJob.run(TestSagemakerJob.scala:59)
...
Caused by: com.amazonaws.SdkClientException: Unable to find a region via the region provider chain. Must provide an explicit region in the builder or setup environment to supply a region.
at com.amazonaws.client.builder.AwsClientBuilder.setRegion(AwsClientBuilder.java:371)
at com.amazonaws.client.builder.AwsClientBuilder.configureMutableProperties(AwsClientBuilder.java:337)
at com.amazonaws.client.builder.AwsSyncClientBuilder.build(AwsSyncClientBuilder.java:46)
at com.amazonaws.services.sagemakerruntime.AmazonSageMakerRuntimeClientBuilder.defaultClient(AmazonSageMakerRuntimeClientBuilder.java:44)
at com.amazonaws.services.sagemaker.sparksdk.transformation.util.RequestBatchIterator$.<init>(RequestBatchIterator.scala:35)
at com.amazonaws.services.sagemaker.sparksdk.transformation.util.RequestBatchIterator$.<clinit>(RequestBatchIterator.scala)
where line 59 of TestSagemakerJob.scala is:
RequestBatchIterator.sagemakerRuntime = sagemakerRuntimeClient
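The trace shows the singleton's initializer (RequestBatchIterator$.<clinit>) calling defaultClient as soon as the object is first touched, which here is the assignment itself. Below is a minimal sketch of that behavior, and of one way instantiation could be deferred along the lines of the lazy val suggestion, in plain Scala. All names (FakeClient, FakeClientBuilder, EagerHolder, LazyHolder, and the demo.region property) are hypothetical; this is neither the library's code nor a committed fix.

```scala
// Hypothetical stand-in for a client that needs a region to be built.
class FakeClient(val region: String)

object FakeClientBuilder {
  // Mimics a default builder that fails when no region is configured.
  def defaultClient(): FakeClient =
    sys.props.get("demo.region") match {
      case Some(region) => new FakeClient(region)
      case None => throw new IllegalStateException("Unable to find a region")
    }
}

// Eager: the default client is built in the object's initializer, so the
// first reference to EagerHolder (even an assignment to `client`) runs
// defaultClient() before the assignment can take effect.
object EagerHolder {
  var client: FakeClient = FakeClientBuilder.defaultClient()
}

// Deferred: the default is only built if no client was supplied
// before the first read.
object LazyHolder {
  private var overridden: Option[FakeClient] = None
  private lazy val default: FakeClient = FakeClientBuilder.defaultClient()

  def client: FakeClient = overridden.getOrElse(default)
  def client_=(c: FakeClient): Unit = { overridden = Some(c) }
}
```

With no region configured, `LazyHolder.client = new FakeClient("us-west-2")` succeeds because the default is never built, while the same assignment on EagerHolder fails during object initialization.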
Hi @harthur,
Thanks for the stacktrace! Just FYI: I haven't gotten a chance to reproduce this yet, but this definitely seems like a bug. I suppose that workers are still trying to create the standard client. If you're able to, could you post your code?
Otherwise, to unblock yourself in the short term, it seems like you'll have to get the region from the environment in your workers (by setting AWS_DEFAULT_REGION or writing to ~/.aws/config first).
Thanks again!
Yeah, I got around it by adding some Java system properties, but ideally you would be able to build your own client and keep that info isolated. That's the pattern we use for all of our other AWS connections, so it's a bit awkward to break that just for SageMaker.
Sorry, to address your first question, I think it's happening in the driver rather than on any workers (or, that's where this particular line of code is). It happens just by instantiating RequestBatchIterator.
Hey @harthur,
Ah, interesting, thanks for the update! Glad to hear you got it working, but you're right, we should let users build their own client. I've put a fix for this on our backlog, thanks for reporting this.
We'll keep this issue open and update this when the fix is in, but I can't give an ETA on when we'll be able to do this.
Thanks again!
Hi @harthur, I am also facing this same issue. I created a model from an endpoint, and calling transform() gives this error. Can you please share the Java system properties that helped you get around it?
Thanks
@sreemani You have to set these system properties: https://docs.aws.amazon.com/sdk-for-java/v1/developer-guide/credentials.html#credentials-default
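For reference, a hedged sketch of the system-property workaround in Scala. It assumes the AWS SDK for Java v1 default provider chains read the JVM properties aws.region, aws.accessKeyId and aws.secretKey (the linked page covers the credential properties), and that they are set before anything references RequestBatchIterator. AwsSystemProps and configure are hypothetical helper names, not part of any library.

```scala
// Hypothetical helper; call it on the driver before the first use of
// SageMakerModel.transform() or RequestBatchIterator.
object AwsSystemProps {
  def configure(region: String, accessKeyId: String, secretKey: String): Unit = {
    // Assumption: read by the SDK v1 default region provider chain.
    System.setProperty("aws.region", region)
    // Assumption: read by the SDK v1 default credentials provider chain.
    System.setProperty("aws.accessKeyId", accessKeyId)
    System.setProperty("aws.secretKey", secretKey)
  }
}
```

The same properties can presumably also be passed as JVM flags, for example -Daws.region=us-west-2 via Spark's spark.driver.extraJavaOptions setting, which keeps them out of application code.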
Any update on this issue? This bug has been open since 2018, and it is causing our team some problems.