
Perform global explainability on the inference dataset

Open · NickNtamp opened this issue 2 years ago · 2 comments

As of now, the pipeline performs explainability per inference row. Explainability for the whole inference dataset may be useful in the future. Some related code has already been written (a sketch of a possible dataset-level aggregation follows the snippets below):

  • Pipeline
import joblib
import lime.lime_tabular
import pandas as pd
from typing import Dict


def create_xai_pipeline_classification_per_inference_dataset(
    training_set: pd.DataFrame,
    target: str,
    inference_set: pd.DataFrame,
    type_of_task: str,
    load_from_path=None,
) -> Dict[str, Dict[str, float]]:

    xai_dataset = training_set.drop(columns=[target])
    explainability_report = {}

    # Mapping dict used later to map the explainer's feature indices
    # to the feature names.
    mapping_dict = dict(enumerate(xai_dataset.columns))

    # Explainability for both classification tasks.
    # We have to revisit this in the future: when the model is loaded
    # from the file system we don't care whether it is binary or multiclass.

    if type_of_task == 'multiclass_classification':

        # Give the option of retrieving the local model.
        if load_from_path is not None:
            model = joblib.load('{}/lgb_multi.pkl'.format(load_from_path))
        else:
            # Training helper defined elsewhere in the codebase.
            model, eval_report = create_multiclass_classification_training_model_pipeline(training_set, target)

        # The explainer is needed whether the model was loaded or trained,
        # so build it outside the if/else.
        explainer = lime.lime_tabular.LimeTabularExplainer(
            xai_dataset.values,
            feature_names=xai_dataset.columns.tolist(),
            mode="classification",
            random_state=1,
        )

        for inference_row in range(len(inference_set)):
            # LIME's classification mode expects class probabilities
            # (predict_proba for the sklearn API), not class labels.
            exp = explainer.explain_instance(inference_set.values[inference_row], model.predict_proba)
            med_report = exp.as_map()
            temp_dict = dict(list(med_report.values())[0])
            map_dict = {mapping_dict[idx]: val for idx, val in temp_dict.items()}
            explainability_report["row{}".format(inference_row)] = map_dict

    elif type_of_task == 'binary_classification':

        # Give the option of retrieving the local model.
        if load_from_path is not None:
            model = joblib.load('{}/lgb_binary.pkl'.format(load_from_path))
        else:
            # Training helper defined elsewhere in the codebase.
            model, eval_report = create_binary_classification_training_model_pipeline(training_set, target)

        explainer = lime.lime_tabular.LimeTabularExplainer(
            xai_dataset.values,
            feature_names=xai_dataset.columns.tolist(),
            mode="classification",
            random_state=1,
        )

        for inference_row in range(len(inference_set)):
            exp = explainer.explain_instance(inference_set.values[inference_row], model.predict_proba)
            med_report = exp.as_map()
            temp_dict = dict(list(med_report.values())[0])
            map_dict = {mapping_dict[idx]: val for idx, val in temp_dict.items()}
            explainability_report["row{}".format(inference_row)] = map_dict

    return explainability_report
  • Unit tests
def test_create_xai_pipeline_classification_per_inference_dataset(self):
    binary_class_report = create_xai_pipeline_classification_per_inference_dataset(
        df_binary, "target", df_binary_inference, "binary_classification"
    )
    multi_class_report = create_xai_pipeline_classification_per_inference_dataset(
        df_multi, "target", df_multi_inference, "multiclass_classification"
    )
    binary_contribution_check_one = binary_class_report["row0"]["worst perimeter"]
    binary_contribution_check_two = binary_class_report["row2"]["worst texture"]
    multi_contribution_check_one = multi_class_report["row0"]["hue"]
    multi_contribution_check_two = multi_class_report["row9"]["proanthocyanins"]
    assert len(binary_class_report) == len(df_binary_inference)
    assert len(multi_class_report) == len(df_multi_inference)
    assert round(binary_contribution_check_one, 3) == 0.253
    assert round(binary_contribution_check_two, 2) == -0.09
    assert round(multi_contribution_check_one, 2) == -0.08
    assert round(multi_contribution_check_two, 3) == -0.023
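
A dataset-level ("global") view could then be derived by aggregating these per-row contributions. Below is a minimal sketch of one possible aggregation, assuming the explainability_report structure produced by the pipeline above; the helper name aggregate_global_explainability and the mean-absolute-contribution metric are illustrative choices, not existing code:

  • Possible global aggregation (sketch)
from typing import Dict


def aggregate_global_explainability(
    explainability_report: Dict[str, Dict[str, float]]
) -> Dict[str, float]:
    """Collapse per-row LIME contributions into one score per feature.

    Uses the mean absolute contribution so that positive and negative
    local effects do not cancel out. Other aggregations (e.g. the mean
    signed contribution) would be equally valid.
    """
    totals: Dict[str, float] = {}
    counts: Dict[str, int] = {}
    for row_report in explainability_report.values():
        for feature, contribution in row_report.items():
            totals[feature] = totals.get(feature, 0.0) + abs(contribution)
            counts[feature] = counts.get(feature, 0) + 1

    # Rank features by average importance, most influential first.
    return dict(
        sorted(
            ((feature, totals[feature] / counts[feature]) for feature in totals),
            key=lambda item: item[1],
            reverse=True,
        )
    )

For example, aggregate_global_explainability(binary_class_report) would yield a single importance score per feature for the whole inference set.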

NickNtamp · Dec 07 '22 08:12

@momegas I would like to start working on this

aditkay95 · Feb 16 '23 19:02

That would be great. Please start by opening a PR with the proposed changes (draft). I'll assign this to you. Thanks 🙏

momegas · Feb 17 '23 09:02