
Invalid prediction values for multi-class decision tree ensembles converted from scikit-learn

Open: ezheidtmann opened this issue 3 years ago • 0 comments

🐞 Describing the bug

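Multi-class decision tree ensembles (e.g. random forests) converted from scikit-learn do not produce class probabilities that match those returned by predict_proba(). Each converted probability equals the scikit-learn probability multiplied by the number of trees in the ensemble. It therefore appears that the per-tree probabilities are summed rather than averaged, i.e. they are never scaled by 1 / (number of trees).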
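To illustrate the difference, here is a minimal pure-Python sketch (the per-tree distributions are made-up numbers, not taken from the repro). scikit-learn's RandomForestClassifier.predict_proba returns the mean of the trees' class probabilities. Summing the same distributions without a 1 / n_trees factor gives values exactly n_trees times larger, which is the pattern described above.

```python
# Hypothetical class distributions reported by three trees for one sample.
# Each row sums to 1. The values are illustrative only.
per_tree_probs = [
    [0.6, 0.3, 0.1],
    [0.2, 0.5, 0.3],
    [0.4, 0.4, 0.2],
]


def forest_average(tree_probs):
    """Mean of per-tree distributions, as predict_proba computes it."""
    n_trees = len(tree_probs)
    n_classes = len(tree_probs[0])
    return [sum(p[c] for p in tree_probs) / n_trees for c in range(n_classes)]


def forest_sum(tree_probs):
    """Sum of per-tree distributions with no 1 / n_trees scaling."""
    n_classes = len(tree_probs[0])
    return [sum(p[c] for p in tree_probs) for c in range(n_classes)]


avg = forest_average(per_tree_probs)
total = forest_sum(per_tree_probs)
# The averaged scores sum to 1. The summed scores sum to the number of trees.
print(avg, sum(avg))
print(total, sum(total))
```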

To Reproduce

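```python
# Repro: the converted model's classProbability values come out as
# N_TREES times scikit-learn's predict_proba output.
# Note: coremltools can only run predictions on macOS.
# sklearn.cross_validation was deprecated in 0.18 and removed in 0.20;
# train_test_split lives in sklearn.model_selection.
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
import coremltools as ct

N_TREES = 13
N_CLASSES = 3

X, y = make_classification(
    n_samples=1000,
    n_features=8,
    n_informative=4,
    n_redundant=0,
    n_classes=N_CLASSES,
    random_state=0,
    shuffle=True,
)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=0
)

clf = RandomForestClassifier(max_depth=2, random_state=0, n_estimators=N_TREES)
clf.fit(X_train, y_train)

# Convert using a single input feature named "input".
coreml_model = ct.converters.sklearn.convert(clf, "input")
sk_results = clf.predict_proba(X_test)
for index, row in enumerate(X_test):
    coreml_result = coreml_model.predict({"input": row})
    sk_result = sk_results[index]
    for class_i in range(N_CLASSES):
        # This assertion PASSES, which demonstrates the bug: the Core ML
        # probability equals N_TREES * the scikit-learn probability
        # instead of matching it.
        assert (
            abs(
                coreml_result["classProbability"][class_i]
                - N_TREES * sk_result[class_i]
            )
            < 1e-7
        )
```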
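Because the repro's assertion holds, the converted scores for one sample sum to N_TREES rather than 1, and dividing each score by the tree count recovers predict_proba. The following is a minimal client-side sketch. It assumes the factor is exactly the number of trees, as the assertion shows. The helper name rescale_class_probability is hypothetical and not part of coremltools.

```python
def rescale_class_probability(class_probability, n_trees):
    """Hypothetical helper: divide each class score by the number of trees.

    Assumes the converted model's scores are exactly n_trees times
    scikit-learn's predict_proba, as the repro's assertion shows.
    """
    if n_trees <= 0:
        raise ValueError("n_trees must be positive")
    return {label: score / n_trees for label, score in class_probability.items()}


# Made-up scores shaped like coreml_result["classProbability"] from a
# 13-tree forest. They sum to about 13 instead of 1.
raw = {0: 6.5, 1: 3.9, 2: 2.6}
fixed = rescale_class_probability(raw, n_trees=13)
print(fixed, sum(fixed.values()))
```

Scaling does not change which class has the highest score, so only the probabilities are affected, not the ranking of classes.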

System environment

  • coremltools version: 5.2.0
  • scikit-learn: 0.19.2
  • OS (e.g. MacOS version or Linux type): macOS 12.3.1

Additional context

The scaling parameter appears to be applied for two-class decision trees but not for multi-class ones. Compare line 39 with line 28 of _tree_ensemble.py:

https://github.com/apple/coremltools/blob/aeb3e8e50dc3af06a8a8988fedcb931c4344dfa3/coremltools/converters/sklearn/_tree_ensemble.py#L25-L43
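The report implies a fix: apply the same scaling in the multi-class branch as in the two-class branch. The sketch below illustrates that idea in plain Python. The function leaf_value and its signature are hypothetical and are not the code at the linked lines. It also assumes that, for a random forest, the scaling factor is 1 / n_estimators, which is the factor averaging requires.

```python
def leaf_value(class_counts, scaling=1.0, n_classes=2):
    """Hypothetical leaf-value conversion that applies `scaling` in both branches.

    class_counts: per-class sample counts (or weights) stored at a leaf.
    scaling: assumed to be 1.0 / n_estimators for a random forest.
    This is an illustration, not the coremltools implementation.
    """
    total = float(sum(class_counts))
    if n_classes == 2:
        # Two-class: one score, the positive-class fraction, scaled.
        return class_counts[1] / total * scaling
    # Multi-class: normalise AND scale, so summing over trees gives the mean.
    return [c / total * scaling for c in class_counts]


# Four hypothetical trees, each reporting the leaf the sample falls into.
leaves = [[8, 2, 0], [5, 5, 0], [2, 6, 2], [0, 4, 6]]
n_trees = len(leaves)
per_tree = [leaf_value(counts, scaling=1.0 / n_trees, n_classes=3) for counts in leaves]
# Summing the scaled leaf values reproduces the per-class average.
forest = [sum(tree[c] for tree in per_tree) for c in range(3)]
print(forest, sum(forest))
```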

ezheidtmann, Jun 23 '22 18:06