
Deploying SAM with TorchServe


I am working on deploying the SAM model using TorchServe. My current implementation performs both the image embedding computation and the mask prediction in a single request-response cycle, which is wasteful because the expensive embedding is recomputed on every request. It looks something like this:

from io import BytesIO

import numpy as np
from PIL import Image
from ts.torch_handler.base_handler import BaseHandler

from segment_anything import SamPredictor, sam_model_registry


class SAMHandler(BaseHandler):
    def initialize(self, context):
        model_type = "vit_b"
        model_dir = context.system_properties.get("model_dir")
        self.sam = sam_model_registry[model_type](checkpoint=f"{model_dir}/vit_b.pth")
        self.sam.eval()
        self.predictor = SamPredictor(self.sam)
        self.initialized = True

    def preprocess(self, data):
        image_data = BytesIO(data[0].get("body"))
        image = Image.open(image_data).convert("RGB")
        image = np.array(image)

        # Computes the image embedding
        self.predictor.set_image(image)

        # Dummy input for this example
        point_coords = np.array(data[0].get("point_coords", [[500, 375]]))
        point_labels = np.array(data[0].get("point_labels", [1]))
        return point_coords, point_labels

    def inference(self, model_input):
        point_coords, point_labels = model_input

        # Does the mask prediction
        masks, scores, logits = self.predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            multimask_output=True,
        )
        return masks, scores, logits

    def postprocess(self, inference_output):
        masks, scores, logits = inference_output
        return [[masks.tolist(), logits.tolist(), scores.tolist()]]

As you can see, the image embedding is computed in the preprocess function and the mask prediction happens in the inference function. I would like to separate these two steps into different requests, so that the expensive embedding step is only paid once per image:

Step 1: Generate image embeddings and store them with a uid.
Step 2: Use the stored embeddings for quick mask prediction.
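
From the client side, the flow I have in mind would look roughly like this (just a sketch; the sam_embed and sam_predict endpoint names and the embedding_uid field are placeholders, not an existing API):

import requests

BASE = "http://localhost:8080/predictions"  # TorchServe's default inference endpoint

# Request 1: upload the image once; the server computes and stores the embedding.
with open("example.jpg", "rb") as f:
    resp = requests.post(f"{BASE}/sam_embed", data=f.read())
uid = resp.json()["embedding_uid"]

# Requests 2..N: cheap mask predictions that reuse the stored embedding.
resp = requests.post(
    f"{BASE}/sam_predict",
    json={"embedding_uid": uid, "point_coords": [[500, 375]], "point_labels": [1]},
)
result = resp.json()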

What is the optimal way to implement this? The two options I've come up with are:

Option 1: Separate Handlers for Embedding and Prediction

Create two separate TorchServe handlers: one for generating and storing the image embeddings, and another for performing mask prediction using the stored embeddings. In this scenario, the client gets back the uid for the embedding and sends it along with the request when doing mask prediction later.
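
A rough skeleton of what I mean (hypothetical handler names; initialize() would be the same as in the single-handler version above, and the module-level dict is only to keep the sketch short, since the two handlers run in separate worker processes and would need a shared external store in practice):

import uuid

from ts.torch_handler.base_handler import BaseHandler

# Stand-in for a shared store; separate TorchServe workers cannot actually
# share a Python dict, so this would be Redis, disk, etc. in reality.
EMBEDDINGS = {}

class EmbeddingHandler(BaseHandler):
    # Request 1: pay the encoder cost once and hand back a uid.
    def inference(self, image):
        self.predictor.set_image(image)  # heavy: runs the ViT image encoder
        uid = str(uuid.uuid4())
        EMBEDDINGS[uid] = (
            self.predictor.get_image_embedding().cpu(),
            self.predictor.original_size,
            self.predictor.input_size,
        )
        return uid

    def postprocess(self, uid):
        return [{"embedding_uid": uid}]

class PredictionHandler(BaseHandler):
    # Requests 2..N: cheap mask decoding against a stored embedding.
    def inference(self, model_input):
        uid, point_coords, point_labels = model_input
        features, original_size, input_size = EMBEDDINGS[uid]
        # Restore the state that set_image() would normally build up.
        # Note these attributes are SamPredictor internals, so they may
        # change between versions of segment-anything.
        self.predictor.features = features.to(self.predictor.device)
        self.predictor.original_size = original_size
        self.predictor.input_size = input_size
        self.predictor.is_image_set = True
        return self.predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            multimask_output=True,
        )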

Option 2: Single Handler with an Action Parameter

Use a single TorchServe handler and let the client specify in the request whether the action is 'generate-embedding' or 'predict-mask'. It would look something like this:

def preprocess(self, data):
    action = data[0].get("action")
    if action == 'generate-embedding':
        # Prepare the image for embedding generation
        ...
    elif action == 'predict-mask':
        # Prepare the embedding uid and point inputs for mask prediction
        ...
    return prepared_data

def inference(self, data):
    action = data.get("action")
    if action == 'generate-embedding':
        # Generate and store the embedding, return its uid
        ...
    elif action == 'predict-mask':
        # Run mask prediction using the stored embedding
        ...
    return result

For both options, I guess I could store the embeddings in something like Redis?
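
If Redis is the route, I imagine serializing the embedding tensor roughly like this (a sketch; the key prefix, TTL, and helper names are made up, and it assumes the redis-py client):

from io import BytesIO

import redis
import torch

r = redis.Redis(host="localhost", port=6379)

def save_embedding(uid, features):
    buf = BytesIO()
    torch.save(features.cpu(), buf)  # serialize the tensor to bytes
    r.set(f"sam:emb:{uid}", buf.getvalue(), ex=3600)  # expire after an hour

def load_embedding(uid, device):
    raw = r.get(f"sam:emb:{uid}")
    if raw is None:
        raise KeyError(f"embedding {uid} expired or was never stored")
    return torch.load(BytesIO(raw), map_location=device)

The predictor's original_size and input_size would have to travel with the features too, e.g. by saving a small dict instead of the bare tensor.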

What does everyone think of these alternatives? Is there something completely different I'm missing that's cleaner? I appreciate all the help I can get!

holma91 · Sep 03 '23

Hi,

Did you manage to do what you wanted? I'm also looking to deploy the model with TorchServe, but I don't know the best way to do it.

Thanks

JulesVision · Sep 14 '23

The actual model is a separate object, and the Predictor just seems to handle the bookkeeping: caching the embedding (see the set_image method) and performing the appropriate transformations during inference. Couldn't you instantiate a new Predictor for every image and give them all the same model?
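
Something like this (a sketch; it assumes everything runs in one worker process, so the per-image embeddings can live in memory on their predictors):

from segment_anything import SamPredictor, sam_model_registry

sam = sam_model_registry["vit_b"](checkpoint="vit_b.pth")
sam.eval()

# One predictor per image, all sharing the same (large) model weights.
# In practice you'd evict old entries to bound memory.
predictors = {}

def predictor_for(uid, image=None):
    if uid not in predictors:
        p = SamPredictor(sam)  # cheap: just bookkeeping around the shared model
        p.set_image(image)     # expensive encoder pass, cached on this predictor
        predictors[uid] = p
    return predictors[uid]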

feffy380 · Mar 21 '24