
Post-prediction hook with access to features, model, and prediction, e.g. for WhyLogs and Revela integration

Open cosmicBboy opened this issue 2 years ago • 2 comments

Use case: As a data scientist, I want to define a custom function hook after the model.predictor function has been called so that I can track metrics and aggregates about the features and predictions, something like:

def my_custom_callback(model: ModelType, features: FeatureType, prediction: PredictionType):
    # do something
    ...

# alternatively, something more flexible would be:
def my_custom_callback(inputs: PredictCallbackInputs):
    # do something
    inputs.model
    inputs.features
    inputs.prediction
    # inputs.<other fields>: whatever else might be needed in the future
    ...


@model.predictor(post_callback=my_custom_callback)
def predictor(model: ModelType, features: FeatureType) -> PredictionType:
    predictions = ...
    return predictions
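To make the "more flexible" variant concrete, here is a minimal sketch of what a `PredictCallbackInputs` container and a callback that tracks aggregates could look like. Everything in it is an assumption for illustration: `PredictCallbackInputs` does not exist in UnionML today, and the `extras` field and the `RunningStats` class are hypothetical names.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class PredictCallbackInputs:
    """Hypothetical container; not part of the UnionML API."""

    model: Any
    features: Any
    prediction: Any
    # Room for fields that might be needed in the future.
    extras: Dict[str, Any] = field(default_factory=dict)


class RunningStats:
    """Hypothetical callback that keeps simple aggregates across calls."""

    def __init__(self) -> None:
        self.n_rows = 0
        self.prediction_sum = 0.0

    def __call__(self, inputs: PredictCallbackInputs) -> None:
        # Count the feature rows seen and accumulate the predictions.
        self.n_rows += len(inputs.features)
        self.prediction_sum += sum(inputs.prediction)

    @property
    def prediction_mean(self) -> float:
        return self.prediction_sum / self.n_rows if self.n_rows else 0.0


stats = RunningStats()
stats(PredictCallbackInputs(model=None, features=[[1.0], [2.0]], prediction=[0.0, 1.0]))
stats(PredictCallbackInputs(model=None, features=[[3.0]], prediction=[1.0]))
```

Passing a single object means new fields can be added later without changing the signature of existing callbacks.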

cosmicBboy · May 18 '22 18:05

UnionML OSS Planning Notes:

This functionality is relevant for integrations with:

  • whylogs
  • revela

from unionml.callbacks import RevelaMonitoringCallback

@model.predictor(post_callback=RevelaMonitoringCallback())
def predictor(model: ModelType, features: FeatureType) -> PredictionType:
    predictions = ...
    return predictions

We can fast-track this feature to an earlier release if that helps you prototype, @zevisert.

cosmicBboy · Aug 03 '22 16:08

That would be great! I think I can make a PR for the callbacks feature -- the logic seems straightforward:

  1. Accept some callbacks in the @model.predictor decorator
  2. Register them with the model object, like @model.predictor does with the actual prediction function
  3. Invoke the callbacks either in the same Flyte task used for prediction or in a separate downstream Flyte task (not sure which is preferred here)
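The three steps above can be sketched roughly as follows. This is only an illustration under stated assumptions, not UnionML's implementation: `SketchModel` is a hypothetical stand-in for the model object, the `callbacks` keyword is an assumed name, and the callbacks run in the same call as the predictor rather than in any Flyte task.

```python
from typing import Any, Callable, List, Optional, Sequence

# Hypothetical callback signature: (model_object, features, prediction).
Callback = Callable[[Any, Any, Any], None]


class SketchModel:
    """Hypothetical stand-in for a UnionML-style model object."""

    def __init__(self, model_object: Any) -> None:
        self.model_object = model_object
        self._predictor: Optional[Callable[[Any, Any], Any]] = None
        self._post_callbacks: List[Callback] = []

    def predictor(self, fn=None, *, callbacks: Sequence[Callback] = ()):
        """Steps 1 and 2: accept callbacks and register them with the model."""

        def register(f):
            self._predictor = f
            self._post_callbacks = list(callbacks)
            return f

        # Support both @model.predictor and @model.predictor(callbacks=[...]).
        return register(fn) if fn is not None else register

    def predict(self, features: Any) -> Any:
        """Step 3: run the predictor, then every registered callback."""
        if self._predictor is None:
            raise RuntimeError("no predictor registered")
        prediction = self._predictor(self.model_object, features)
        for callback in self._post_callbacks:
            callback(self.model_object, features, prediction)
        return prediction


seen = []
model = SketchModel(model_object={"offset": 1})


@model.predictor(callbacks=[lambda m, x, y: seen.append((x, y))])
def predictor(model_object, features):
    return [value + model_object["offset"] for value in features]


result = model.predict([1, 2, 3])
```

Running the callbacks inline keeps the features and prediction in memory for the hook, at the cost of adding the callback's latency to every prediction; a downstream task would trade that for extra serialization of the same values.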

From that, it'd be really easy for me to handle monitoring in a custom callback -- or in unionml.callbacks.RevelaMonitoringCallback if you're down with that!

zevisert · Aug 03 '22 18:08