Feature ranges

Open • PaoloFantine opened this issue 2 years ago • 5 comments

I am using the drug consumption dataset for multiclass classification and am trying to enforce some permitted ranges, like this:

d = dice_ml.Data(dataframe=eval_data,
                 continuous_features=continuous_features,
                 outcome_name=self.target_name)

m = dice_ml.Model(model=model, backend=backend)

exp = dice_ml.Dice(d, m, method="random")

e1 = exp.generate_counterfactuals(instance=0,
                                  continuous_features=['nscore', 'escore', 'oscore', 'ascore', 'cscore', 'impulsive', 'ss'],
                                  model=model,
                                  backend="sklearn",
                                  total_CFs=5,
                                  desired_class=2,
                                  features_to_vary=['nscore', 'escore', 'oscore', 'ascore', 'cscore', 'impulsive', 'ss'],
                                  permitted_range={'oscore': [-1.0, 2.0], 'ascore': [0.0, 0.5], 'cscore': [0.0, 1.0], 'ss': [-2.0, 1.0]})

However, I get counterfactuals with values well outside those ranges. Is this a bug, or are the ranges somehow overridden when no counterfactuals exist within them?

PaoloFantine, Apr 05 '22 14:04
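To pin down exactly which values fall outside the requested ranges, a small check over the returned counterfactuals can help when reporting the problem. The sketch below assumes each counterfactual is a plain dict of feature values; if the counterfactuals are in a pandas DataFrame, `df.to_dict("records")` produces that shape. `range_violations` is a hypothetical helper, not part of dice_ml, and the rows in the usage example are made-up values.

```python
from typing import Any, Dict, List, Sequence, Tuple, Union

# A permitted range is [low, high] for a continuous feature,
# or a list of allowed values for a categorical feature.
Range = Union[Sequence[float], Sequence[str]]


def range_violations(cfs: List[Dict[str, Any]],
                     permitted_range: Dict[str, Range],
                     continuous: Sequence[str]) -> List[Tuple[int, str, Any]]:
    """Hypothetical helper: return (row_index, feature, value) for each value
    that lies outside its permitted range."""
    violations = []
    for i, row in enumerate(cfs):
        for feature, allowed in permitted_range.items():
            value = row[feature]
            if feature in continuous:
                low, high = allowed
                ok = low <= value <= high   # inclusive bounds
            else:
                ok = value in allowed       # categorical: membership test
            if not ok:
                violations.append((i, feature, value))
    return violations


# Made-up counterfactual rows for illustration only.
cfs = [
    {"oscore": 0.5, "ascore": 0.2, "ss": -1.0},
    {"oscore": 2.7, "ascore": 0.9, "ss": 0.0},
]
permitted = {"oscore": [-1.0, 2.0], "ascore": [0.0, 0.5], "ss": [-2.0, 1.0]}
print(range_violations(cfs, permitted, continuous=["oscore", "ascore", "ss"]))
# [(1, 'oscore', 2.7), (1, 'ascore', 0.9)]
```

An empty list means every listed feature stayed inside its range; the bounds are treated as inclusive here, which is an assumption about how the ranges are meant.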

@PaoloFantine there are no parameters like 'continuous_features', 'instance', 'model' or 'backend' in the generate_counterfactuals method. Could you maybe use the sample notebook at https://github.com/interpretml/DiCE/blob/master/docs/source/notebooks/DiCE_getting_started.ipynb to correct the above code?

gaugup, Apr 07 '22 07:04
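When a call goes through a wrapper, it is easy to pass keyword names the target function does not declare. If that function ends in a `**kwargs` catch-all, such names raise no `TypeError` and may be silently ignored; whether a given DiCE version does this is not confirmed in this thread. The sketch below uses `inspect.signature` to list the names a function does not declare explicitly. `unexpected_kwargs` is a hypothetical helper, and `generate` is a stand-in whose signature is assumed for illustration, not DiCE's actual one.

```python
import inspect
from typing import Any, Callable, Dict, List


def unexpected_kwargs(func: Callable, kwargs: Dict[str, Any]) -> List[str]:
    """Hypothetical helper: keyword names that are not explicit parameters of func."""
    params = inspect.signature(func).parameters
    explicit = {
        name for name, p in params.items()
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                      inspect.Parameter.KEYWORD_ONLY)
    }
    return sorted(k for k in kwargs if k not in explicit)


# Stand-in with a **kwargs catch-all (assumed signature, not DiCE's real one).
def generate(query_instances, total_CFs, desired_class="opposite",
             permitted_range=None, features_to_vary="all", **kwargs):
    return None


call = {"instance": 0, "continuous_features": ["oscore"], "model": None,
        "backend": "sklearn", "total_CFs": 5, "desired_class": 2}
print(unexpected_kwargs(generate, call))
# ['backend', 'continuous_features', 'instance', 'model']
```

Passing a bound method such as `exp.generate_counterfactuals` should also work, since `inspect.signature` omits `self` for bound methods.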

> @PaoloFantine there are no parameters like 'continuous_features', 'instance', 'model', 'backend' in the generate_counterfactuals method. Could you may be use the sample notebooks from https://github.com/interpretml/DiCE/blob/master/docs/source/notebooks/DiCE_getting_started.ipynb to correct the above code?

That function is a wrapper I use from within a class; my bad for not checking. Inside the wrapper, the call is:

e1 = exp.generate_counterfactuals(self.X[instance:instance+1],
                                  total_CFs=total_CFs,
                                  desired_class=desired_class,
                                  desired_range=None,
                                  permitted_range=permitted_range,
                                  features_to_vary=features_to_vary,
                                  stopping_threshold=stopping_threshold,
                                  posthoc_sparsity_param=posthoc_sparsity_param,
                                  posthoc_sparsity_algorithm=posthoc_sparsity_algorithm,
                                  verbose=False)

I think you can deduce the parameter values from my previous comment. The issue is the same: I get counterfactuals outside the permitted_range, and I wonder whether the range is somehow overridden for the sake of finding counterfactuals.

PaoloFantine, Apr 07 '22 13:04
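The question of whether a range can be "overridden" comes down to how the search treats it. In a random-sampling search that respects the range, candidate values are only ever drawn from inside the permitted range, so an infeasible range yields fewer counterfactuals (possibly none) rather than out-of-range ones. The toy sketch below illustrates that behavior under stated assumptions: it is not DiCE's implementation, `random_counterfactuals` and `toy_model` are hypothetical, and for simplicity it varies exactly the features listed in `permitted_range`.

```python
import random
from typing import Callable, Dict, List, Sequence, Union

Range = Union[Sequence[float], Sequence[str]]


def random_counterfactuals(query: Dict[str, object],
                           predict: Callable[[Dict[str, object]], int],
                           desired_class: int,
                           permitted_range: Dict[str, Range],
                           continuous: Sequence[str],
                           total_cfs: int = 5,
                           max_tries: int = 10_000,
                           seed: int = 0) -> List[Dict[str, object]]:
    """Toy random search (hypothetical): samples values only inside permitted_range."""
    rng = random.Random(seed)
    found = []
    for _ in range(max_tries):
        candidate = dict(query)  # features not listed keep their original values
        for feature, allowed in permitted_range.items():
            if feature in continuous:
                low, high = allowed
                candidate[feature] = rng.uniform(low, high)
            else:
                candidate[feature] = rng.choice(list(allowed))
        if predict(candidate) == desired_class:
            found.append(candidate)
            if len(found) == total_cfs:
                break
    # May return fewer than total_cfs: the range is never widened to find more.
    return found


def toy_model(x: Dict[str, object]) -> int:
    # Hypothetical classifier: class 2 when oscore + ss is high enough.
    return 2 if x["oscore"] + x["ss"] > 1.5 else 0


query = {"oscore": 0.0, "ss": 0.0, "ascore": 0.1}
cfs = random_counterfactuals(query, toy_model, desired_class=2,
                             permitted_range={"oscore": [-1.0, 2.0], "ss": [-2.0, 1.0]},
                             continuous=["oscore", "ss"])
print(len(cfs), all(-1.0 <= c["oscore"] <= 2.0 and -2.0 <= c["ss"] <= 1.0 for c in cfs))
```

With a range in which the desired class is unreachable (for example `oscore` in `[-1.0, 0.0]` and `ss` in `[-2.0, -1.0]` for this toy model), the function simply returns an empty list.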

We should be respecting the range. If you could share a sample notebook and dataset, it will be possible for us to dig into this issue.

Regards,

gaugup, Apr 07 '22 21:04

Apart from the ranges for the continuous variables, how can I set the range for categorical variables? For example, I have a categorical variable with four categories (A, B, C, D), and I only want the counterfactuals to vary among A, B and C, excluding D.

tonyabracadabra, Apr 13 '22 09:04

@tonyabracadabra, you can do something like this:

exp.generate_counterfactuals(
    query_instances=sample_custom_query_1,
    total_CFs=10,
    desired_class=desired_class,
    desired_range=None,
    # 'Categorical' stands for the name of your categorical feature
    permitted_range={'Categorical': ['A', 'B', 'C']},
    features_to_vary='all')

That tells DiCE to use only categories 'A', 'B' and 'C' when generating counterfactuals.

Hope this helps.

P.S. Since this question is not related to the original one, we would really appreciate it if you opened a new issue for such queries.

gaugup, Apr 20 '22 08:04
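Since `permitted_range` mixes `[low, high]` pairs for continuous features with lists of allowed values for categorical ones, a typo in a feature name or category label is easy to miss. The sketch below checks such a spec before it is passed on and shows one way to build the "everything except D" list. `check_permitted_range` is a hypothetical helper, not part of dice_ml; the `categories` mapping (feature name to observed labels) is an assumption you would fill from your own data.

```python
from typing import Dict, Iterable, List, Sequence


def check_permitted_range(permitted_range: Dict[str, Sequence],
                          categories: Dict[str, Iterable[str]],
                          continuous: Sequence[str]) -> List[str]:
    """Hypothetical helper: return readable problems found in a permitted_range spec."""
    problems = []
    for feature, allowed in permitted_range.items():
        if feature in continuous:
            # Continuous features expect exactly [low, high] with low <= high.
            if len(allowed) != 2 or allowed[0] > allowed[1]:
                problems.append(f"{feature}: expected [low, high], got {list(allowed)}")
        elif feature in categories:
            # Categorical features: every allowed label should exist in the data.
            unknown = sorted(set(allowed) - set(categories[feature]))
            if unknown:
                problems.append(f"{feature}: unknown categories {unknown}")
        else:
            problems.append(f"{feature}: not a known feature")
    return problems


categories = {"Categorical": ["A", "B", "C", "D"]}
allowed = [c for c in categories["Categorical"] if c != "D"]  # ['A', 'B', 'C']
spec = {"Categorical": allowed, "oscore": [-1.0, 2.0]}
print(check_permitted_range(spec, categories, continuous=["oscore"]))
# []

bad = {"Categorical": ["A", "b"], "oscore": [2.0, -1.0], "osore": [0.0, 1.0]}
print(check_permitted_range(bad, categories, continuous=["oscore"]))
# ["Categorical: unknown categories ['b']", 'oscore: expected [low, high], got [2.0, -1.0]', 'osore: not a known feature']
```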