
clfviz throws KeyError, points to color_map index.

Open bridgesra opened this issue 3 years ago • 1 comments

UPDATE: I found the problem and a fix. Change line 319 in classifiers.py to:

color_map = {v: class_colors[i] for i, v in enumerate(class_values)}

It was previously:

color_map = {i: class_colors[i] for i, v in enumerate(class_values)}

I tried to open a pull request, but the newest clone already has the fix! This seems to be a problem only with the pip3 install ... version.
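For anyone else hitting this, here is a minimal sketch of the kind of mismatch that produces `KeyError: 0`. It assumes the color map's keys and the lookup key disagree: one is the class value, the other is the loop index. It is not the dtreeviz source. The color strings and the `lookup_fails` helper are made up for illustration, and the class values mirror the labels 1 to 4 in the `y` array posted at the end of this report.

```python
# Illustration only: a simplified stand-in, NOT dtreeviz's actual code.
class_values = [1, 2, 3, 4]  # labels start at 1, as in the y array in this report
class_colors = ["#fefecd", "#d9e6f5", "#a1dab4", "#f46d43"]  # hypothetical colors

# Color map keyed by class value (1..4)
color_map_by_value = {v: class_colors[i] for i, v in enumerate(class_values)}
# Color map keyed by position (0..3)
color_map_by_index = {i: class_colors[i] for i, v in enumerate(class_values)}


def lookup_fails(color_map, key):
    """Return True if color_map[key] raises KeyError (hypothetical helper)."""
    try:
        color_map[key]
        return False
    except KeyError:
        return True


# Looking up position 0 in the value-keyed map fails, because no class is labeled 0.
index_lookup_in_value_map_fails = lookup_fails(color_map_by_value, 0)

# Lookups succeed when the key type matches how the map was built.
first_color_by_index = color_map_by_index[0]
first_color_by_value = color_map_by_value[class_values[0]]
```

Either keying works as long as the lookup uses the same kind of key. The error only appears when class labels do not happen to be 0, 1, 2, ... so that index and value differ.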


This error happens with my data both with two features (X.shape = (39, 2)) and with one feature (X.shape = (39, 1)). I have included X and y at the end if you want to reproduce it.

Here's the code:

from sklearn.tree import DecisionTreeClassifier
from dtreeviz import clfviz

dt = DecisionTreeClassifier(max_depth=2)
dt.fit(X, y)
clfviz(dt, X, y, feature_names=["a", "b"], markers=['o', 'X', 's', 'D'], target_name="y")

And error message:

/usr/local/lib/python3.9/site-packages/sklearn/base.py:450: UserWarning: X does not have valid feature names, but DecisionTreeClassifier was fitted with feature names
  warnings.warn(
/usr/local/lib/python3.9/site-packages/sklearn/base.py:450: UserWarning: X does not have valid feature names, but DecisionTreeClassifier was fitted with feature names
  warnings.warn(
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Input In [374], in <module>
----> 1 clfviz(dt, X, y, feature_names=features, target_name=target)

File /usr/local/lib/python3.9/site-packages/dtreeviz/classifiers.py:103, in clfviz(model, X, y, ntiles, tile_fraction, binary_threshold, show, feature_names, target_name, class_names, markers, boundary_marker, boundary_markersize, fontsize, fontname, dot_w, yshift, sigma, colors, ranges, ax)
     88     clfviz_univar(model=model, x=X, y=y,
     89                   ntiles=ntiles,
     90                   binary_threshold=binary_threshold,
   (...)
    100                   colors=colors,
    101                   ax=ax)
    102 elif len(X.shape) == 2 and X.shape[1] == 2:
--> 103     clfviz_bivar(model=model, X=X, y=y,
    104                  ntiles=ntiles, tile_fraction=tile_fraction,
    105                  binary_threshold=binary_threshold,
    106                  show=show,
    107                  feature_names=feature_names, target_name=target_name,
    108                  class_names=class_names,
    109                  markers=markers,
    110                  boundary_marker=boundary_marker,
    111                  boundary_markersize=boundary_markersize,
    112                  fontsize=fontsize, fontname=fontname,
    113                  dot_w=dot_w, colors=colors,
    114                  ranges=ranges,
    115                  ax=ax)
    116 else:
    117     raise ValueError(f"Expecting 2D data not {X.shape}")

File /usr/local/lib/python3.9/site-packages/dtreeviz/classifiers.py:188, in clfviz_bivar(model, X, y, ntiles, tile_fraction, binary_threshold, show, feature_names, target_name, class_names, markers, boundary_marker, boundary_markersize, fontsize, fontname, dot_w, colors, ranges, ax)
    184 if 'misclassified' in show:
    185     # Show correctly classified markers
    186     good_x = x_[class_X_pred[i] == class_values[i],:]
    187     ax.scatter(good_x[:, 0], good_x[:, 1],
--> 188                s=dot_w, c=color_map[i],
    189                marker=markers[i],
    190                alpha=colors['scatter_marker_alpha'],
    191                edgecolors=colors['scatter_edge'],
    192                lw=.5)
    193     # Show misclassified markers (can't have alpha per marker so do in 2 calls)
    194     bad_x = x_[class_X_pred[i] != class_values[i],:]

KeyError: 0

data:

X = np.array([[2.], [1.], [3.], [3.], [1.], [1.], [2.], [3.], [1.], [2.], [3.], [1.], [1.], [1.], [1.], [1.], [3.], [2.], [3.], [2.], [1.], [1.], [2.], [1.], [1.], [1.], [1.], [1.], [3.], [3.], [1.], [1.], [3.], [2.], [1.], [1.], [2.], [1.], [3.], [1.], [1.], [1.], [2.], [1.], [1.], [3.], [2.], [2.], [3.], [3.], [3.], [3.], [1.], [1.], [1.], [1.], [3.], [1.], [3.]])
y = np.array([3, 1, 4, 1, 4, 2, 1, 1, 4, 3, 4, 4, 1, 2, 3, 2, 2, 3, 3, 2, 3, 3, 2, 2, 3, 2, 4, 2, 2, 1, 4, 2, 3, 4, 1, 4, 4, 1, 2, 3, 3, 3, 4, 1, 1, 1, 4, 4, 3, 2, 4, 1, 1, 4, 2, 3, 1, 1, 2])

bridgesra, Feb 17 '22 03:02

If it's already fixed, could you close the issue? Thanks.

tlapusan, Feb 17 '22 08:02