
Create Awaitable predict capability

Open hooman-bayer opened this issue 1 year ago • 10 comments

Describe the feature you'd like

Like many other Python inference libraries (e.g. OpenAI), create a truly awaitable version of predict for real-time SageMaker inference endpoints. This would help Python applications built on FastAPI and asyncio deliver real-time responses without blocking the main event loop. Please note that this feature is different from the existing one, where the predictions are written to an S3 bucket. It would work exactly like https://sagemaker.readthedocs.io/en/stable/api/inference/predictors.html#sagemaker.predictor.Predictor.predict but with an await, in real asyncio style.

SageMaker is an amazing library, and having this feature would make it much better for production environments that use FastAPI.

How would this feature be used? Please describe.

Currently, the sync version looks like this:

response = predictor.predict(input_data)

The async version might look like this:

response = await predictor.apredict(input_data)

Describe alternatives you've considered

I considered subclassing the predictor and adding an async version.
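A minimal sketch of that subclassing idea, assuming the blocking call is simply moved to a worker thread with `asyncio.to_thread`. `AsyncPredictMixin`, `apredict` and `FakePredictor` are hypothetical names, not part of the SageMaker SDK; `FakePredictor` stands in for a real predictor so the example runs without AWS access. This is not a truly non-blocking client, since each request still occupies a thread while it waits.

```python
import asyncio
import time
from typing import Any


class AsyncPredictMixin:
    """Hypothetical mixin: adds an awaitable wrapper around a blocking predict()."""

    async def apredict(self, data: Any, **kwargs: Any) -> Any:
        # Run the synchronous predict() in a worker thread so the event loop
        # stays free to serve other coroutines while the request is in flight.
        return await asyncio.to_thread(self.predict, data, **kwargs)


class FakePredictor:
    """Stand-in for a real predictor; sleeps instead of calling an endpoint."""

    def predict(self, data: Any) -> dict[str, Any]:
        time.sleep(0.05)  # simulates network latency
        return {"echo": data}


class AsyncFakePredictor(AsyncPredictMixin, FakePredictor):
    pass


async def main() -> list[dict[str, Any]]:
    predictor = AsyncFakePredictor()
    # Both calls overlap because each one runs in its own worker thread.
    return await asyncio.gather(predictor.apredict("a"), predictor.apredict("b"))


results = asyncio.run(main())
```

With a real predictor, the same mixin could in principle be combined with the SDK class instead of `FakePredictor`, though that combination is untested here.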

Additional context

For modern Python applications built on FastAPI and asyncio, it is crucial to use async APIs to avoid blocking the server's main event loop (especially in scalable applications). A truly awaitable predict would therefore avoid blocking the main event loop of applications that use SageMaker.

Thanks a lot

hooman-bayer avatar Jul 01 '23 20:07 hooman-bayer

If you're using FastAPI, the author (Tiangolo) has a nice project called Asyncer, which has a very nice asyncify function. It is just a wrapper on top of anyio, which does the heavy lifting, and it makes it trivial to call the sync SageMaker I/O within asyncio flows. Here's an example with Hugging Face:


from typing import Any, cast

import boto3
import sagemaker
from asyncer import asyncify
from fastapi import FastAPI
from sagemaker.huggingface import HuggingFacePredictor

app = FastAPI()

@app.get("/completion")
async def get_completion():
    return await get_huggingface_completion("Please give me the stuff!")

async def get_huggingface_completion(total_prompt: str) -> dict[str, Any]:
    # asyncify runs the blocking function in a worker thread and awaits its result.
    response = await asyncify(_communicate_with_sagemaker)(total_prompt)
    # _get_completion is a helper that is not shown in this comment.
    return _get_completion(response, total_prompt)

def _communicate_with_sagemaker(total_prompt: str) -> Any:
    session = boto3.Session(
        region_name="us-west-2",
    )
    sage_session = sagemaker.Session(boto_session=session)

    predictor = HuggingFacePredictor("your-endpoint", sagemaker_session=sage_session)
    payload = {"inputs": total_prompt, "parameters": {"max_new_tokens": 1024}}

    return cast(dict[str, Any], predictor.predict(payload))

phillipuniverse avatar Sep 01 '23 06:09 phillipuniverse

@phillipuniverse thanks for the suggestion. That certainly helps, but it still runs the call on a separate thread, so it's still different from a pure async module, which is better optimized for Python.
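A small sketch of that distinction, using only the standard library. `time.sleep` and `asyncio.sleep` stand in for a blocking and a non-blocking request, and every name here is made up for illustration. The point it tries to show: thread offloading is capped by the size of the worker pool, whereas native coroutines are not.

```python
import asyncio
import threading
import time
from concurrent.futures import ThreadPoolExecutor


class PeakCounter:
    """Tracks how many calls are in flight at once (thread-safe)."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self.current = 0
        self.peak = 0

    def enter(self) -> None:
        with self._lock:
            self.current += 1
            self.peak = max(self.peak, self.current)

    def leave(self) -> None:
        with self._lock:
            self.current -= 1


def blocking_call(counter: PeakCounter) -> None:
    counter.enter()
    time.sleep(0.05)  # stands in for a synchronous HTTP request
    counter.leave()


async def native_call(counter: PeakCounter) -> None:
    counter.enter()
    await asyncio.sleep(0.05)  # stands in for a non-blocking HTTP request
    counter.leave()


async def run_threaded(n: int, workers: int) -> int:
    counter = PeakCounter()
    loop = asyncio.get_running_loop()
    with ThreadPoolExecutor(max_workers=workers) as pool:
        # Each blocking call needs a free worker thread before it can start.
        await asyncio.gather(
            *(loop.run_in_executor(pool, blocking_call, counter) for _ in range(n))
        )
    return counter.peak


async def run_native(n: int) -> int:
    counter = PeakCounter()
    # All coroutines start immediately and wait on the same event loop.
    await asyncio.gather(*(native_call(counter) for _ in range(n)))
    return counter.peak


threaded_peak = asyncio.run(run_threaded(n=6, workers=2))
native_peak = asyncio.run(run_native(n=6))
```

Here `threaded_peak` cannot exceed the two workers, while all six native coroutines are in flight together. In practice the default thread pools are larger than two, so the cap only matters at higher concurrency.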

hooman-bayer avatar Sep 17 '23 19:09 hooman-bayer

Dropping a comment in case anyone else happens to stumble across this thread. ~~The suggestion to use asyncify unfortunately didn't work.~~ It appeared to be functional under low concurrency loads, but started failing a large percentage of requests once I increased the number of concurrent requests above ~5.

Going to try directly issuing requests over HTTP with an async aware library and see if that does the trick.

ewellinger avatar Jan 05 '24 01:01 ewellinger

Turns out that may not have been correct. After a lot of frustration it appears that the errors that were getting generated were actually coming from the Sagemaker endpoint (which in hindsight seems kind of obvious). The immediacy of the errors made me think it was something to do with how they were being initiated but now I don't think that was true.

In case this helps anyone else, this was how I implemented it by leveraging aiohttp:

import hashlib
import json
from typing import Any

import aiohttp
import boto3
from aws_request_signer import AwsRequestSigner


async def invoke_sagemaker(region_name: str, endpoint_name: str, payload: dict[str, Any]) -> Any:
    sagemaker_endpoint_url = f"https://runtime.sagemaker.{region_name}.amazonaws.com/endpoints/{endpoint_name}/invocations"

    # Get signed headers
    _refreshable_credentials = boto3.Session(region_name=region_name).get_credentials()
    creds = _refreshable_credentials.get_frozen_credentials()
    signer = AwsRequestSigner(
        region=region_name,
        access_key_id=creds.access_key,
        secret_access_key=creds.secret_key,
        session_token=creds.token,
        service="sagemaker",
    )
    # Serialize once so the signed hash matches the bytes that are actually sent.
    payload_bytes = json.dumps(payload).encode("utf-8")
    payload_hash = hashlib.sha256(payload_bytes).hexdigest()
    headers = {"Content-Type": "application/json"}
    headers.update(
        signer.sign_with_headers("POST", sagemaker_endpoint_url, headers, payload_hash)
    )

    try:
        # In a long-running server, create one ClientSession at startup and reuse it.
        async with aiohttp.ClientSession() as session:
            async with session.post(sagemaker_endpoint_url, headers=headers, data=payload_bytes) as response:
                response.raise_for_status()
                return await response.json()
    except aiohttp.ClientError as e:
        raise RuntimeError(f"Request to SageMaker endpoint failed: {e}") from e
    except Exception as e:
        raise RuntimeError(f"An error occurred: {e}") from e


# Example usage (needs `import asyncio`):
# asyncio.run(invoke_sagemaker("region", "endpoint_name", {"inputs": "Test", "parameters": {}}))

Only slightly pseudo code but that gives the general idea of how to go about signing the headers and would be easy to wrap in an async endpoint.
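One detail worth keeping in mind when signing by hand: the signature covers a SHA-256 hash of the request body, so the bytes that are hashed should be the exact bytes that are sent. A rough illustration using only the standard library (the variable names are illustrative and nothing here calls AWS):

```python
import hashlib
import json

payload = {"inputs": "Test", "parameters": {}}

# Serialize once and reuse these exact bytes for both the hash and the body.
body = json.dumps(payload).encode("utf-8")
signed_hash = hashlib.sha256(body).hexdigest()

# A different (but equivalent) JSON encoding yields different bytes, so a
# signature computed over `body` would not match if this were sent instead.
compact_body = json.dumps(payload, separators=(",", ":")).encode("utf-8")
compact_hash = hashlib.sha256(compact_body).hexdigest()

hashes_match = signed_hash == compact_hash
```

Passing the pre-serialized bytes to the HTTP client, rather than letting the client re-serialize the dict, avoids depending on both sides using the same JSON settings.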

ewellinger avatar Jan 05 '24 18:01 ewellinger

Thanks a lot @ewellinger. Have you tried benchmarking your approach against the following, which uses asyncify?

app = FastAPI()

@app.get("/completion")
async def get_completion():
    return await get_huggingface_completion("Please give me the stuff!")

async def get_huggingface_completion(total_prompt: str) -> dict[str, Any]:
    response = await asyncify(_communicate_with_sagemaker)(total_prompt)
    return _get_completion(response, total_prompt)

def _communicate_with_sagemaker(total_prompt: str) -> Any:
    session = boto3.Session(
        region_name="us-west-2",
    )
    sage_session = sagemaker.Session(boto_session=session)

    predictor = HuggingFacePredictor("your-endpoint", sagemaker_session=sage_session)
    payload = {"inputs": total_prompt, "parameters": {"max_new_tokens": 1024}}

    return cast(dict[str, Any], predictor.predict(payload))

I assume yours would perform much better. It is beyond me why SageMaker does not offer such a basic feature. On the one hand, I assume they want it to be used with LLMs; on the other hand, there is no support for the basic needs of that use case (e.g. async and streaming).

hooman-bayer avatar Jan 05 '24 19:01 hooman-bayer

I haven't, I only realized what was happening pretty late last night.

We're hitting these endpoints for LLM predictions so the latency is already pretty high and I'd imagine the difference between directly using aiohttp and asyncify would probably be negligible in this case.

ewellinger avatar Jan 05 '24 19:01 ewellinger

Just a really quick and dirty comparison between the two implementations. This was hitting an LLM with 20 total requests, with 5 requests running concurrently.

This was leveraging the asyncer approach:

Summary:
  Total:	60.2530 secs
  Slowest:	18.1773 secs
  Fastest:	9.3643 secs
  Average:	13.5625 secs
  Requests/sec:	0.3319

  Total data:	45417 bytes
  Size/request:	2390 bytes

Response time histogram:
  9.364 [1]	|■■■■■■■■■■
  10.246 [0]	|
  11.127 [2]	|■■■■■■■■■■■■■■■■■■■■
  12.008 [3]	|■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■
  12.889 [1]	|■■■■■■■■■■
  13.771 [3]	|■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■
  14.652 [4]	|■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■
  15.533 [0]	|
  16.415 [3]	|■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■
  17.296 [0]	|
  18.177 [2]	|■■■■■■■■■■■■■■■■■■■■


Latency distribution:
  10% in 10.3807 secs
  25% in 11.7108 secs
  50% in 13.7912 secs
  75% in 15.7309 secs
  90% in 18.1773 secs
  0% in 0.0000 secs
  0% in 0.0000 secs

Details (average, fastest, slowest):
  DNS+dialup:	0.0024 secs, 9.3643 secs, 18.1773 secs
  DNS-lookup:	0.0017 secs, 0.0000 secs, 0.0083 secs
  req write:	0.0004 secs, 0.0000 secs, 0.0034 secs
  resp wait:	13.5570 secs, 9.3516 secs, 18.1769 secs
  resp read:	0.0024 secs, 0.0002 secs, 0.0170 secs

Status code distribution:
  [200]	19 responses

Error distribution:
  [1]	Post "http://localhost:8000/api/llm/text_generation/v1": context deadline exceeded (Client.Timeout exceeded while awaiting headers)

Here was the breakdown with aiohttp:

Summary:
  Total:	50.7949 secs
  Slowest:	16.6846 secs
  Fastest:	5.4675 secs
  Average:	11.4514 secs
  Requests/sec:	0.3937

  Total data:	47699 bytes
  Size/request:	2384 bytes

Response time histogram:
  5.468 [1]	|■■■■■■
  6.589 [0]	|
  7.711 [1]	|■■■■■■
  8.833 [1]	|■■■■■■
  9.954 [2]	|■■■■■■■■■■■
  11.076 [1]	|■■■■■■
  12.198 [7]	|■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■
  13.319 [3]	|■■■■■■■■■■■■■■■■■
  14.441 [1]	|■■■■■■
  15.563 [2]	|■■■■■■■■■■■
  16.685 [1]	|■■■■■■


Latency distribution:
  10% in 8.8009 secs
  25% in 10.2199 secs
  50% in 11.4831 secs
  75% in 13.1079 secs
  90% in 15.2409 secs
  95% in 16.6846 secs
  0% in 0.0000 secs

Details (average, fastest, slowest):
  DNS+dialup:	0.0023 secs, 5.4675 secs, 16.6846 secs
  DNS-lookup:	0.0020 secs, 0.0000 secs, 0.0093 secs
  req write:	0.0002 secs, 0.0000 secs, 0.0017 secs
  resp wait:	11.4449 secs, 5.4571 secs, 16.6785 secs
  resp read:	0.0039 secs, 0.0001 secs, 0.0286 secs

Status code distribution:
  [200]	20 responses

So it does look like there is a benefit to using aiohttp (about 11.5 secs vs. 13.6 secs average latency in this run, and no timeouts vs. one), but it would probably need more extensive testing to say how large the difference really is.

Also I can confirm that using asyncer does, in fact, work. The issues I hit previously were actually on the endpoint side.
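For more extensive testing, the same load shape (20 requests, 5 in flight at a time) could also be driven from Python. A rough sketch, where `fake_request` is a hypothetical placeholder for whichever implementation is being measured and `asyncio.sleep` stands in for the endpoint latency:

```python
import asyncio
import time


async def fake_request(i: int) -> float:
    """Hypothetical stand-in for one call to the endpoint; returns its latency."""
    start = time.perf_counter()
    await asyncio.sleep(0.01)
    return time.perf_counter() - start


async def run_load(total: int, concurrency: int) -> list[float]:
    semaphore = asyncio.Semaphore(concurrency)

    async def bounded(i: int) -> float:
        # At most `concurrency` requests hold the semaphore at any moment.
        async with semaphore:
            return await fake_request(i)

    return await asyncio.gather(*(bounded(i) for i in range(total)))


latencies = asyncio.run(run_load(total=20, concurrency=5))
average = sum(latencies) / len(latencies)
```

Swapping `fake_request` for the asyncify-based call and then the aiohttp-based call, with a larger `total`, would give a like-for-like comparison.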

ewellinger avatar Jan 05 '24 20:01 ewellinger

Awesome @ewellinger, thanks a lot! It looks like your approach is slightly better. SageMaker has recently introduced invoke_endpoint_with_response_stream, but it is still a synchronous operation (in the SageMaker Python SDK). With your approach, one can get close to a decent streaming and async solution (something like below):

import hashlib
import json
import aiohttp
import boto3
from aws_request_signer import AwsRequestSigner

async def invoke_sagemaker_stream(region_name, endpoint_name, payload):
    sagemaker_endpoint_url = f"https://runtime.sagemaker.{region_name}.amazonaws.com/endpoints/{endpoint_name}/invocations-response-stream"

    async with aiohttp.ClientSession() as session:
        _refreshable_credentials = boto3.Session(region_name=region_name).get_credentials()

        # Get signed headers
        creds = _refreshable_credentials.get_frozen_credentials()
        signer = AwsRequestSigner(
            region=region_name,
            access_key_id=creds.access_key,
            secret_access_key=creds.secret_key,
            session_token=creds.token,
            service="sagemaker",
        )
        payload_bytes = json.dumps(payload).encode("utf-8")
        payload_hash = hashlib.sha256(payload_bytes).hexdigest()
        headers = {"Content-Type": "application/json"}
        headers.update(
            signer.sign_with_headers("POST", sagemaker_endpoint_url, headers, payload_hash)
        )

        try:
            # Send the exact bytes that were hashed so the signature matches the body.
            async with session.post(sagemaker_endpoint_url, headers=headers, data=payload_bytes) as response:
                response.raise_for_status()
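                # Now, instead of returning a JSON response, we handle the stream.
                # Note: this endpoint is expected to return AWS event-stream framing
                # (binary headers around each chunk), so the raw lines may need
                # further parsing; errors="replace" avoids failing on non-UTF-8 bytes.
                async for line in response.content:
                    print(line.decode("utf-8", errors="replace"))
                    # Process each line here as needed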
        except aiohttp.ClientError as e:
            raise RuntimeError(f"Request to SageMaker endpoint failed: {e}") from e
        except Exception as e:
            raise RuntimeError(f"An error occurred: {e}") from e

# Example usage (needs `import asyncio`):
# asyncio.run(invoke_sagemaker_stream("your-region", "your-endpoint", {"inputs": "Test", "parameters": {}}))

hooman-bayer avatar Jan 07 '24 03:01 hooman-bayer

+1 on this, would be great to have

jameseggers avatar Feb 14 '24 14:02 jameseggers

I'm not terribly hopeful that this will be implemented in a timely manner.

My company has priority support, so I raised a support ticket asking them to respond to this feature request. It was closed not long after, without anyone chiming in here. I didn't have the time for the back and forth since the implementation above technically works :/

ewellinger avatar Feb 14 '24 19:02 ewellinger