Sagemaker runtime invoke endpoint with custom results

Open LiMuBei opened this issue 9 months ago • 7 comments

I've been trying to mock a call to Sagemaker invoke_endpoint so that it returns a custom response. The documentation states that one should make a POST request to add a custom response. To me, this seems to require running the Moto server. So far, though, simple unit-test usage with the decorator has been enough. I do receive the static dummy data, but that is of very limited use. Looking at the code, I could not find any other way to add custom responses to the Sagemaker runtime mock. Is that correct?

LiMuBei avatar Mar 28 '25 15:03 LiMuBei

Hi @LiMuBei, if you make the POST request to http://motoapi.amazonaws.com/moto-api/static/sagemaker/endpoint-results, the Moto decorator should intercept it and calls to invoke_endpoint should return the custom response.

If it doesn't work with that URL, please share a reproducible example, and we can have a closer look.

bblommers avatar Mar 30 '25 11:03 bblommers

Hi @bblommers and thanks for your reply.

Here's what I am trying, boiled down to a simple example. Basically, I want to have a test for a Lambda handler, which internally makes the call to the Sagemaker endpoint.

The Lambda code:

import json
import os
import boto3


sagemaker_runtime = boto3.client("sagemaker-runtime")


def handler(event, context):
    input_payload = json.dumps({"bucket": "my_bucket", "key": "my_image"})
    # Invoke SageMaker endpoint
    response = sagemaker_runtime.invoke_endpoint(
        EndpointName=os.environ["SAGEMAKER_INFERENCE_ENDPOINT_NAME"],
        ContentType="application/json",
        Body=input_payload,
    )

    print(response)

And the test code:

import os
from unittest.mock import Mock
from moto import mock_aws
import pytest
import requests


@pytest.fixture
def context():
    return Mock(
        aws_request_id="test-request-id",
        log_stream_name="test-log-stream",
        log_group_name="test-log-group",
        function_name="test-function",
        function_version="test-version",
        invoked_function_arn="test-arn",
        memory_limit_in_mb=128,
        get_remaining_time_in_millis=Mock(return_value=5000),
    )


@pytest.fixture
def event():
    return {}


@mock_aws
def test_working_case(monkeypatch, event, context):
    monkeypatch.setenv("AWS_ACCOUNT_ID", "1234567890")
    monkeypatch.setenv("AWS_REGION", "us-east-1")
    monkeypatch.setenv("SAGEMAKER_INFERENCE_ENDPOINT_NAME", "sm-endpoint")

    from lambdas.sagemaker_dummy import handler

    # Setup Sagemaker mock response
    expected_results = {
        "account_id": os.environ["AWS_ACCOUNT_ID"],
        "region": os.environ["AWS_REGION"],
        "results": [
            {
                "Body": "foo",
                "ContentType": "application/json",
                "InvokedProductionVariant": "prod",
                "CustomAttributes": "myattr",
            },
        ],
    }
    requests.post(
        "http://motoapi.amazonaws.com/moto-api/static/sagemaker/endpoint-results",
        json=expected_results,
    )

    assert handler(event, context) == ""

I would expect to see 'foo' in the 'Body' field of the response and 'myattr' in the 'CustomAttributes'. What I get is

{
  'ResponseMetadata': 
  {
    'RequestId': 'h9IaGTpkDLOb1ipWbUOjITVQnfPKzLFzuwbGhLVW4R7H8bL13w8g', 
    'HTTPStatusCode': 200, 
    'HTTPHeaders': 
    {
      'server': 'amazon.com', 
      'date': 'Mon, 31 Mar 2025 10:35:02 GMT', 
      'content-type': 'content_type', 
      'x-amzn-invoked-production-variant': 'invoked_production_variant', 
      'x-amzn-sagemaker-custom-attributes': 'custom_attributes', 
      'x-amzn-requestid': 'h9IaGTpkDLOb1ipWbUOjITVQnfPKzLFzuwbGhLVW4R7H8bL13w8g'
    }, 
    'RetryAttempts': 0
  }, 
  'ContentType': 'content_type', 
  'InvokedProductionVariant': 'invoked_production_variant', 
  'CustomAttributes': 'custom_attributes', 
  'Body': <botocore.response.StreamingBody object at 0x109ff3400>
}

And the content of 'Body' is 'body'.

Versions used: moto 5.1.1, boto3 1.35.55, pytest 8.3.3

LiMuBei avatar Mar 31 '25 08:03 LiMuBei

Hi @LiMuBei, the test uses a custom account_id, but it's different from the one that Moto uses. It does work if you change that:

- monkeypatch.setenv("AWS_ACCOUNT_ID", "1234567890")
+ monkeypatch.setenv("AWS_ACCOUNT_ID", "123456789012")

If you want to stick with a custom account ID, it is possible to override it using a Moto-specific environment variable called MOTO_ACCOUNT_ID. The documentation has a little bit more info, if you're curious: https://docs.getmoto.org/en/latest/docs/multi_account.html
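The account mismatch described above can be caught before the POST is sent. The following is only a sketch: `effective_account_id` and `check_payload_account` are hypothetical helpers, not part of Moto, and they assume the default account `123456789012` and the `MOTO_ACCOUNT_ID` override mentioned in this thread.

```python
import os
import re

# Default account ID that Moto uses according to this thread.
MOTO_DEFAULT_ACCOUNT_ID = "123456789012"


def effective_account_id(environ=os.environ):
    """Account the mock is expected to use: MOTO_ACCOUNT_ID if set, else the default."""
    return environ.get("MOTO_ACCOUNT_ID", MOTO_DEFAULT_ACCOUNT_ID)


def check_payload_account(payload, environ=os.environ):
    """Hypothetical pre-flight check for the endpoint-results payload.

    Raises ValueError if the payload names an account that is malformed
    or differs from the account the mock is expected to use.
    """
    account_id = payload.get("account_id")
    if account_id is None:
        return  # nothing to compare; the default account applies
    if not re.fullmatch(r"\d{12}", account_id):
        raise ValueError(f"account_id {account_id!r} is not a 12-digit account ID")
    expected = effective_account_id(environ)
    if account_id != expected:
        raise ValueError(f"account_id {account_id!r} does not match {expected!r}")
```

Calling `check_payload_account(expected_results)` right before `requests.post(...)` would flag the 10-digit `1234567890` from the test above.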

bblommers avatar Apr 04 '25 20:04 bblommers

Hi @bblommers, I changed the account ID, and I'm still not getting my custom response. When setting the expected results, shouldn't there be anything mentioning the endpoint name? Or would it just use those results for all endpoints?

LiMuBei avatar Apr 07 '25 08:04 LiMuBei

Hi @LiMuBei, the endpoint name might not be known, which is why it's never mentioned. The API request sends a list of expected results, and the first call to invoke_endpoint simply returns the first result of that list.
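As a rough mental model of that behaviour, here is a toy backend written for illustration only. It is not Moto's implementation: `ToyResultsBackend` is a made-up class, the fallback tuple is the static dummy data shown earlier in this thread, and caching a served result per request body is an assumption based on the `results` output printed further down.

```python
from collections import deque

# Static dummy data returned when no custom result is configured (values from this thread).
STATIC_RESULT = ("body", "content_type", "invoked_production_variant", "custom_attributes")


class ToyResultsBackend:
    """Illustrative stand-in for a single account/region backend."""

    def __init__(self):
        self.results_queue = deque()  # configured results, oldest first
        self.results = {}             # request body -> result already served (assumed)

    def configure(self, results):
        """Queue results shaped like the entries of the POST payload's "results" list."""
        for r in results:
            self.results_queue.append(
                (r["Body"], r["ContentType"], r["InvokedProductionVariant"], r["CustomAttributes"])
            )

    def invoke_endpoint(self, endpoint_name, body):
        """Serve the next queued result; endpoint_name is deliberately ignored."""
        if body not in self.results:
            if self.results_queue:
                self.results[body] = self.results_queue.popleft()
            else:
                self.results[body] = STATIC_RESULT
        return self.results[body]
```

In this model the queue is shared by all endpoints, so the order of calls, not the endpoint name, decides which result comes back.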

For reference, this is the full test that passes on my machine. Changes:

  • Explicitly read() the body, and assert the response is correct
  • Remove the account_id from the POST request - if not set, it will simply use the default. This avoids any typos.

import json
import boto3
import os
from unittest.mock import Mock
from moto import mock_aws
import pytest
import requests


sagemaker_runtime = boto3.client("sagemaker-runtime")


def handler(event, context):
    input_payload = json.dumps({"bucket": "my_bucket", "key": "my_image"})
    # Invoke SageMaker endpoint
    response = sagemaker_runtime.invoke_endpoint(
        EndpointName=os.environ["SAGEMAKER_INFERENCE_ENDPOINT_NAME"],
        ContentType="application/json",
        Body=input_payload,
    )

    assert response["CustomAttributes"] == "myattr"
    assert response["Body"].read() == b"foo"


@pytest.fixture
def context():
    return Mock(
        aws_request_id="test-request-id",
        log_stream_name="test-log-stream",
        log_group_name="test-log-group",
        function_name="test-function",
        function_version="test-version",
        invoked_function_arn="test-arn",
        memory_limit_in_mb=128,
        get_remaining_time_in_millis=Mock(return_value=5000),
    )


@pytest.fixture
def event():
    return {}


@mock_aws
def test_working_case(monkeypatch, event, context):
    monkeypatch.setenv("AWS_ACCOUNT_ID", "1234567890")
    monkeypatch.setenv("AWS_REGION", "us-east-1")
    monkeypatch.setenv("SAGEMAKER_INFERENCE_ENDPOINT_NAME", "sm-endpoint")

    # Setup Sagemaker mock response
    expected_results = {
        "region": os.environ["AWS_REGION"],
        "results": [
            {
                "Body": "foo",
                "ContentType": "application/json",
                "InvokedProductionVariant": "prod",
                "CustomAttributes": "myattr",
            },
        ],
    }
    requests.post(
        "http://motoapi.amazonaws.com/moto-api/static/sagemaker/endpoint-results",
        json=expected_results,
    )

    handler(event, context)

bblommers avatar Apr 07 '25 20:04 bblommers

With the code you provided, I'm getting a 'No credentials found' error. If I move the boto3.client initialization inside the handler function, I get the following:

event = {}, context = <Mock id='4567592400'>

    def handler(event, context):
        sagemaker_runtime = boto3.client("sagemaker-runtime")
        input_payload = json.dumps({"bucket": "my_bucket", "key": "my_image"})
        # Invoke SageMaker endpoint
        response = sagemaker_runtime.invoke_endpoint(
            EndpointName=os.environ["SAGEMAKER_INFERENCE_ENDPOINT_NAME"],
            ContentType="application/json",
            Body=input_payload,
        )
    
>       assert response["CustomAttributes"] == "myattr"
E       AssertionError: assert 'custom_attributes' == 'myattr'
E         
E         - myattr
E         + custom_attributes

tests/lambda/test_dummy.py:22: AssertionError

I can even add an assertion to the requests.post like so

result = requests.post(
    "http://motoapi.amazonaws.com/moto-api/static/sagemaker/endpoint-results",
    json=expected_results,
)
assert result.status_code == 201

And this assertion holds.

Seems like something in my setup is different from yours.
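One setup difference worth ruling out is which credentials and region boto3 resolves. The sketch below pins dummy values through the standard AWS SDK environment variables. `apply_dummy_aws_env` is a hypothetical helper, and the explanation is an assumption rather than a confirmed diagnosis: a client created at import time, before any credentials exist, may end up without them, and boto3 is generally documented to read `AWS_DEFAULT_REGION` rather than the `AWS_REGION` variable set in the test.

```python
import os

# Dummy values only; the variable names are the standard AWS SDK ones.
DUMMY_AWS_ENV = {
    "AWS_ACCESS_KEY_ID": "testing",
    "AWS_SECRET_ACCESS_KEY": "testing",
    "AWS_SESSION_TOKEN": "testing",
    "AWS_DEFAULT_REGION": "us-east-1",
}


def apply_dummy_aws_env(environ=os.environ):
    """Set dummy credentials and a region; return the previous values for restoring."""
    previous = {name: environ.get(name) for name in DUMMY_AWS_ENV}
    environ.update(DUMMY_AWS_ENV)
    return previous
```

Inside a pytest test, the same effect can be had by calling `monkeypatch.setenv(name, value)` for each entry before the boto3 client is created, with the `region` in the POST payload set to the same value.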

I'm using uv in a venv with Python 3.11 to run this. Could this have an effect? Is there any way to find out if the queue with the expected responses is actually present and filled?

In any case, thank you already for taking the time to look into this!

LiMuBei avatar Apr 08 '25 07:04 LiMuBei

Hmm, interesting. @LiMuBei, if you replace the handler(..) call with this code, you can see what the queue looks like:

from moto.sagemakerruntime.models import sagemakerruntime_backends

for account, regions in sagemakerruntime_backends.items():
    for region in regions.keys():
        backend = sagemakerruntime_backends[account][region]
        print(f"\t=== {account} :: {region}")
        print(backend.results)
        print(backend.results_queue)
        print("")

handler(event, context)

for account, regions in sagemakerruntime_backends.items():
    for region in regions:
        backend = sagemakerruntime_backends[account][region]
        print(f"\t=== {account} :: {region}")
        print(backend.results)
        print(backend.results_queue)
        print("")

This will print both the queue and the results, for every account and region.

For reference, this is the output that I would expect:

=== 123456789012 :: us-east-1
{}
[('foo', 'application/json', 'prod', 'myattr')]

=== 123456789012 :: us-east-1
{None: {b'eyJBY2NlcHQiOiBudWxsLCAiQm9keSI6ICJ7XCJidWNrZXRcIjogXCJteV9idWNrZXRcIiwgXCJrZXlcIjogXCJteV9pbWFnZVwifSJ9': ('foo', 'application/json', 'prod', 'myattr')}}
[]

What does the output look like when you run it? There should be a queue with the expected result, if the POST request succeeds - we just need to find out in which account/region this happens.
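To narrow that down without reading the full dump, the loop can be condensed into a small filter. This is a sketch: `find_configured_backends` is a hypothetical helper that only relies on the access pattern used above (`.items()`, indexing by account and region, and the `results` / `results_queue` attributes), so it is shown here against stand-in objects rather than the real Moto backends.

```python
from types import SimpleNamespace


def find_configured_backends(backends):
    """Return (account, region, queued, served) for backends holding custom results."""
    hits = []
    for account, regions in backends.items():
        for region in regions:
            backend = backends[account][region]
            queued = len(backend.results_queue)  # results still waiting to be served
            served = len(backend.results)        # top-level entries already served
            if queued or served:
                hits.append((account, region, queued, served))
    return hits


# Stand-in data shaped like the structure iterated over above.
fake_backends = {
    "123456789012": {
        "us-east-1": SimpleNamespace(
            results={}, results_queue=[("foo", "application/json", "prod", "myattr")]
        ),
        "eu-west-1": SimpleNamespace(results={}, results_queue=[]),
    }
}
print(find_configured_backends(fake_backends))
```

If the account or region reported here differs from the one the boto3 client uses, that would explain why invoke_endpoint falls back to the static dummy data.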

bblommers avatar Apr 12 '25 11:04 bblommers

Closing due to lack of response. If this is still an issue, just reply to the questions I posed earlier, and we can revisit this.

bblommers avatar Oct 18 '25 11:10 bblommers