IndexError: tuple index out of range

Open pranavn91 opened this issue 3 years ago • 2 comments

I trained a random forest classifier with scikit-learn 0.24.2 and used CounterfactualProto, following the example at the link below:

(https://docs.seldon.io/projects/alibi/en/stable/examples/cfproto_housing.html)

from alibi.explainers import CounterfactualProto
y30cf = np.zeros((y30.shape[0],))
y30cf[np.where(y30 > np.median(y30))[0]] = 1

The target y becomes a binary classification label:

y30cf array([1., 0., 0., 1., 1., 1., 1., 0., 1., 0., 1., 0., 0., 1., 1., 0., 0., 1., 1., 0., 1., 1., 1., 0., 1., 0., 1., 0., 0., 1., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 1., 0., 1., 0., 0., 1., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 0., 1., 0., 1., 1., 1., 1., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1.])

The model is trained and ready:

bestmodel RandomForestClassifier(ccp_alpha=0.1, criterion='entropy', max_depth=1, max_features=None, max_samples=0.5, min_impurity_decrease=0.1, min_samples_leaf=0.5, min_samples_split=0.1, min_weight_fraction_leaf=0.5, n_estimators=1)

I took one sample:

X = X_test[1].reshape((1,) + X_test[1].shape)
shape = X.shape

shape (1, 39)

X array([[ 0. , 0. , 0. , 25. , 0. , 75. , 0. , 0. , 0. , 0. , 70. , 0. , 0. , 0. , 0. , 55. , 0. , 0. , 40. , 0. , 0. , 0. , 0. , 45. , 0. , 1.5, 2. , 0. , 3. , 0. , 0. , 0.5, 0. , 0. , 0. , 0.3, 0. , 0. , 8. ]]) 1.

Define a black-box model:

predict_fn = lambda x: bestmodel.predict(x)

Initializing the explainer as below raises IndexError: tuple index out of range. How can I resolve this error?

cf = CounterfactualProto(predict_fn, shape, use_kdtree=True, theta=10., max_iterations=1000, feature_range=(X_train.min(axis=0), X_train.max(axis=0)), c_init=1., c_steps=10)


IndexError                                Traceback (most recent call last)
Input In [27], in <cell line: 2>()
      1 # initialize explainer, fit and generate counterfactual
----> 2 cf = CounterfactualProto(predict_fn, shape, use_kdtree=True, theta=10., max_iterations=1000,
      3                          feature_range=(X_train.min(axis=0), X_train.max(axis=0)),
      4                          c_init=1., c_steps=10)

File C:\ProgramData\Anaconda3\lib\site-packages\alibi\explainers\cfproto.py:139, in CounterfactualProto.__init__(self, predict, shape, kappa, beta, feature_range, gamma, ae_model, enc_model, theta, cat_vars, ohe, use_kdtree, learning_rate_init, max_iterations, c_init, c_steps, eps, clip, update_num_grad, write_dir, sess)
    137 else:  # black-box model
    138     self.model = False
--> 139     self.classes = self.predict(np.zeros(shape)).shape[1]
    141 if is_enc:
    142     self.enc_model = True

IndexError: tuple index out of range

pranavn91 avatar Oct 24 '22 10:10 pranavn91

Hey @pranavn91, thanks for opening the issue.

Counterfactuals with prototypes requires the model, whether black-box or white-box, to be differentiable, which random forests aren't. Hence I'm not sure you'll get good results using this method. You might want to look at the CounterfactualRL method instead. (This example uses a random forest classifier.)

As for your issue, I'm finding it difficult to reproduce. The following:

import numpy as np
from alibi.explainers import CounterfactualProto
from sklearn.ensemble import RandomForestClassifier

X = x_test[1].reshape((1,) + x_test[1].shape)
shape = X.shape


clf = RandomForestClassifier(ccp_alpha=0.1, criterion='entropy', max_depth=1,
        max_features=None, max_samples=0.5,
        min_impurity_decrease=0.1, min_samples_leaf=0.5,
        min_samples_split=0.1, min_weight_fraction_leaf=0.5,
        n_estimators=1)

clf.fit(x_train, y_train)

predict_fn = lambda x: clf.predict(x)
predict_fn(np.zeros(shape)).shape[1]

cf = CounterfactualProto(predict_fn, shape, use_kdtree=True, theta=10., max_iterations=1000,
                         feature_range=(x_train.min(axis=0), x_train.max(axis=0)), 
                         c_init=1., c_steps=10)

doesn't throw the same error. How does your code differ exactly? If you could copy and paste your entire code in one piece, it might help.

mauicv avatar Oct 24 '22 11:10 mauicv

Thanks, I will try the new link. The error goes away if I one-hot encode the labels. I did not do this because the data was binary. My bad.

pranavn91 avatar Oct 24 '22 12:10 pranavn91
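
For anyone landing here with the same traceback, here is a minimal sketch of why one-hot encoding the labels makes the error go away. It uses synthetic stand-in data (the issue's X_train and y30cf are not available) and default random forest settings, so the variable names and data are assumptions for illustration only. A scikit-learn classifier fitted on 1-D labels returns a 1-D array from predict, so the .shape[1] lookup shown at cfproto.py:139 in the traceback has nothing to index. Fitting on one-hot labels makes predict return a 2-D array. Returning predict_proba from the 1-D-label model also gives a 2-D output, but that route is not discussed in the thread and is only a guess at an alternative.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

# Synthetic stand-in data (hypothetical; not the data from the issue).
rng = np.random.default_rng(0)
X_train = rng.random((40, 39))
y_binary = np.arange(40) % 2        # 1-D binary labels, shape (40,)
y_onehot = np.eye(2)[y_binary]      # one-hot labels, shape (40, 2)

# All-zeros probe with the same shape as the sample in the issue.
probe = np.zeros((1, 39))

# Model fitted on 1-D labels: predict returns a 1-D array.
clf_1d = RandomForestClassifier(n_estimators=5, random_state=0).fit(X_train, y_binary)
shape_1d = clf_1d.predict(probe).shape
shape_proba = clf_1d.predict_proba(probe).shape

# Model fitted on one-hot labels: predict returns one column per class.
clf_onehot = RandomForestClassifier(n_estimators=5, random_state=0).fit(X_train, y_onehot)
shape_onehot = clf_onehot.predict(probe).shape

# Reproduce the failing lookup from the traceback on the 1-D output.
error_message = None
try:
    n_classes = shape_1d[1]
except IndexError as err:
    error_message = str(err)

print("predict, 1-D labels:      ", shape_1d)
print("predict_proba, 1-D labels:", shape_proba)
print("predict, one-hot labels:  ", shape_onehot)
print("lookup on 1-D output:      IndexError:", error_message)
```

Checking predict_fn(np.zeros(shape)).shape before constructing the explainer is a quick way to confirm the output is 2-D.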