Skater icon indicating copy to clipboard operation
Skater copied to clipboard

Surrogate Tree Explainer: "plot_global_decisions" fails with ValueError

Open christoferjulio3 opened this issue 3 years ago • 0 comments

I tried to generate the Surrogate Tree explainer plot based on your example code from GitHub, but it fails.

Below is the code:

%matplotlib inline
import matplotlib.pyplot
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn import datasets
from sklearn import svm
from skater.core.explanations import Interpretation
from skater.model import InMemoryModel

from skater.core.global_interpretation.tree_surrogate import TreeSurrogate
from skater.util.dataops import show_in_notebook

iris = datasets.load_iris()
digits = datasets.load_digits()
X = iris.data
y = iris.target

clf = RandomForestClassifier(random_state=0, n_jobs=-1)

xtrain, xtest, ytrain, ytest = train_test_split(X,y,test_size=0.2, random_state=0)
clf = clf.fit(xtrain, ytrain)

y_pred=clf.predict(xtest)
prob=clf.predict_proba(xtest)

interpreter = Interpretation(
        training_data=xtrain, training_labels=ytrain, feature_names=iris.feature_names
    )
pyint_model = InMemoryModel(
            clf.predict_proba,
            examples=xtrain,
            target_names=iris.target_names,
            unique_values=np.unique(ytrain).tolist(),
            feature_names=iris.feature_names,
        )

surrogate_explainer = interpreter.tree_surrogate(oracle=pyint_model, seed=5)
surrogate_explainer.fit(xtrain, ytrain)
`surrogate_explainer.plot_global_decisions(show_img=True)

And this is the error the code produces:

2022-03-30 01:17:01,327 - skater.core.global_interpretation.tree_surrogate - INFO - post pruning applied ...
2022-03-30 01:17:01,332 - skater.core.global_interpretation.tree_surrogate - INFO - Scorer used cross-entropy
2022-03-30 01:17:01,342 - skater.core.global_interpretation.tree_surrogate - INFO - original score using base model 2.1094237467877998e-15
2022-03-30 01:17:01,388 - skater.core.global_interpretation.tree_surrogate - INFO - Summary: childrens of the following nodes are removed []
2022-03-30 01:17:01,392 - skater.core.global_interpretation.tree_surrogate - INFO - Done generating prediction using the surrogate, shape (120, 3)
2022-03-30 01:17:01,398 - skater.core.global_interpretation.tree_surrogate - INFO - Done scoring, surrogate score 0.0; oracle score 0.033
2022-03-30 01:17:01,401 - skater.core.global_interpretation.tree_surrogate - WARNING - impurity score: 0.033 of the surrogate model is higher than the impurity threshold: 0.01. The higher the impurity score, lower is the fidelity/faithfulness of the surrogate model
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
[<ipython-input-13-dd8a23b7ccad>](https://localhost:8080/#) in <module>()
     39 surrogate_explainer = interpreter.tree_surrogate(oracle=pyint_model, seed=5)
     40 surrogate_explainer.fit(xtrain, ytrain)
---> 41 surrogate_explainer.plot_global_decisions(show_img=True)

2 frames
[/usr/local/lib/python3.7/dist-packages/skater/core/global_interpretation/tree_surrogate.py](https://localhost:8080/#) in plot_global_decisions(self, colors, enable_node_id, random_state, file_name, show_img, fig_size)
    399         """
    400         graph_inst = plot_tree(self.__model, self.__model_type, feature_names=self.feature_names, color_list=colors,
--> 401                                class_names=self.class_names, enable_node_id=enable_node_id, seed=random_state)
    402         f_name = "interpretable_tree.png" if file_name is None else file_name
    403         graph_inst.write_png(f_name)

[/usr/local/lib/python3.7/dist-packages/skater/core/visualizer/tree_visualizer.py](https://localhost:8080/#) in plot_tree(estimator, estimator_type, feature_names, class_names, color_list, colormap_reg, enable_node_id, coverage, seed)
    105         default_color = None
    106 
--> 107     graph = _set_node_properites(estimator, estimator_type, graph, color_names=colors, default_color=default_color)
    108 
    109     # Set the color scheme for the edges

[/usr/local/lib/python3.7/dist-packages/skater/core/visualizer/tree_visualizer.py](https://localhost:8080/#) in _set_node_properites(estimator, estimator_type, graph_instance, color_names, default_color)
     68         if node.get_name() not in ('node', 'edge'):
     69             if estimator_type == 'classifier':
---> 70                 value = values[int(node.get_name())][0]
     71                 # 1. Color only the leaf nodes, where one class is dominant or if it is a leaf node
     72                 # 2. For mixed population or otherwise set the default color

ValueError: invalid literal for int() with base 10: '"\\n"'

Please take a look. Thank you!

christoferjulio3 avatar Mar 30 '22 01:03 christoferjulio3