
Enable passing column type to SHAPConfig in combination with ClarifyCheckStep

Open oskarklang-private opened this issue 8 months ago • 0 comments

Describe the feature you'd like

Add a parameter to SHAPConfig (as imported from sagemaker.workflow.clarify_check_step) that lets the user specify the column types of the dataset used to create the baseline for the SHAP analysis (e.g. float, int, category). Alternatively, make it possible to run ClarifyCheckStep when an S3 URI has been passed as the baseline to SHAPConfig.

How would this feature be used? Please describe.

When using ClarifyCheckStep and SHAPConfig from sagemaker.workflow.clarify_check_step, I am currently unable to specify my dataset's column types (e.g. some columns should be numerical while others should be categorical).

When ClarifyCheckStep runs as part of a SageMaker pipeline, Clarify calculates a baseline that is erroneous because it does not take the column types into account. For example, columns that should be categorical get the column mean as their baseline, where preferably they should get the column mode or something else more appropriate.
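As a rough illustration of the desired behaviour, here is a minimal sketch in plain Python of a type-aware baseline: the mean for numerical columns and the mode for categorical ones. This is not how Clarify computes its baseline; the function name `compute_baseline` and the `"numerical"` / `"categorical"` labels are hypothetical and exist only for this example.

```python
from statistics import mean, mode


def compute_baseline(rows, column_types):
    """Return one baseline row: mean for numerical columns, mode for categorical.

    `rows` is a list of equal-length lists; `column_types` holds one label per
    column, either "numerical" or "categorical" (hypothetical labels).
    """
    baseline = []
    for index, column_type in enumerate(column_types):
        values = [row[index] for row in rows]
        if column_type == "numerical":
            baseline.append(mean(values))
        elif column_type == "categorical":
            # mode() returns the most common value (the first one seen on ties)
            baseline.append(mode(values))
        else:
            raise ValueError(f"unknown column type: {column_type!r}")
    return baseline


rows = [
    [1.0, 0, 3],
    [2.0, 1, 3],
    [6.0, 1, 4],
]
baseline = compute_baseline(rows, ["numerical", "categorical", "categorical"])
print(baseline)
```

With a column-type parameter on SHAPConfig, Clarify could make an equivalent choice per column instead of treating every column as numerical.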

I know that I can pass my own baseline to SHAPConfig, but I don't want it hard-coded in my SageMaker pipeline definition. I want it to be computed at runtime, based on previous steps in my SageMaker pipeline. An alternative solution would be to pass SHAPConfig the S3 URI of a baseline dataset created in a previous step; however, this doesn't seem to work with how ClarifyCheckStep is currently implemented.
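On the producing side, the S3-based alternative could look roughly like the sketch below: an earlier pipeline step serializes the baseline it computed to CSV text and writes it to its output location, and the resulting S3 URI is what would then be handed to SHAPConfig. The sketch covers only the serialization. The helper name `baseline_to_csv` is hypothetical, and the headerless layout is an assumption about what Clarify expects, not something confirmed in this issue.

```python
import csv
import io


def baseline_to_csv(baseline_rows):
    """Serialize baseline rows (a list of lists) as headerless CSV text."""
    buffer = io.StringIO()
    writer = csv.writer(buffer, lineterminator="\n")
    writer.writerows(baseline_rows)
    return buffer.getvalue()


# Example baseline row, e.g. produced by a preprocessing step at runtime.
csv_text = baseline_to_csv([[3.0, 1, 3]])
print(csv_text)
```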

Describe alternatives you've considered

Make it possible to run ClarifyCheckStep when an S3 URI has been passed as the baseline to SHAPConfig.

Additional context

from sagemaker.workflow.clarify_check_step import (
    ClarifyCheckStep,
    ModelExplainabilityCheckConfig,
    SHAPConfig,
)

# No baseline is passed here, so Clarify computes one itself.
shap_config = SHAPConfig(seed=123, num_samples=100, num_clusters=5)

# The data config, model config, job config, pipeline parameters and
# model package group name are defined elsewhere in the pipeline.
model_explainability_check_config = ModelExplainabilityCheckConfig(
    data_config=model_explainability_data_config,
    model_config=model_config,
    explainability_config=shap_config,
)

step_model_explainability_check = ClarifyCheckStep(
    name="ModelExplainabilityCheckStep",
    display_name="Model Explainability Check",
    clarify_check_config=model_explainability_check_config,
    check_job_config=check_job_config_clarify,
    skip_check=skipCheckModelExplainabilityParam,
    register_new_baseline=registerNewBaselineModelExplainabilityParam,
    supplied_baseline_constraints=suppliedBaselineConstraintsModelExplainabilityParam,
    model_package_group_name=model_package_group_name,
)

oskarklang-private, Apr 16 '25 12:04