
Support for hinge loss on Sklearn SGDClassifier

Opened by Economax.

Error message from hummingbird-ml:

AssertionError: predict_proba for linear models currently only support {'modified_huber', 'squared_hinge', 'log'}. (Given hinge). Please fill an issue at https://github.com/microsoft/hummingbird

This is simple enough to work around by switching to squared_hinge, but that loss gives significantly worse performance than hinge, at least for a single epoch.
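To quantify the hinge vs. squared_hinge gap on the same data, both losses can be trained with scikit-learn alone, with no Hummingbird involved. This is a minimal sketch, assuming the digits split used in the repro, `max_iter=1` as the meaning of "a single epoch", and an arbitrary `random_state=0`; the resulting accuracies will vary with the seed and the split.

```python
import warnings

from sklearn import datasets
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split

# Same data preparation as the repro: flatten 8x8 digit images to 64 features
digits = datasets.load_digits()
X = digits.images.reshape((len(digits.images), -1))
X_train, X_test, y_train, y_test = train_test_split(
    X, digits.target, test_size=0.90, shuffle=False
)

scores = {}
for loss in ("hinge", "squared_hinge"):
    # Assumption: max_iter=1 stands in for "a single epoch"; random_state is arbitrary
    clf = SGDClassifier(loss=loss, max_iter=1, random_state=0)
    with warnings.catch_warnings():
        # One epoch is not enough to converge, so silence the expected warning
        warnings.simplefilter("ignore", ConvergenceWarning)
        clf.fit(X_train, y_train)
    # score() returns mean accuracy on the held-out data
    scores[loss] = clf.score(X_test, y_test)

for loss, acc in scores.items():
    print(f"{loss}: test accuracy = {acc:.3f}")
```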

Hummingbird version: '0.4.5'

Ran on Python 3.9.12 (main, Jun 1 2022, 11:38:51) [GCC 7.5.0] :: Anaconda, Inc. on linux.

Simple to reproduce; see the code below:

```python
from sklearn.linear_model import SGDClassifier
from sklearn import datasets
from sklearn import metrics
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from hummingbird.ml import convert, load

#! Built on the sklearn intro example: https://scikit-learn.org/stable/tutorial/basic/tutorial.html

# Data loading
# iris = datasets.load_iris()
digits = datasets.load_digits()

# Data engineering: flatten each 8x8 image into a 64-feature vector
n_samples = len(digits.images)
data = digits.images.reshape((n_samples, -1))

# Training split
X_train, X_test, y_train, y_test = train_test_split(data, digits.target, test_size=0.90, shuffle=False)

# Model definition and parameter selection
clf = SGDClassifier(loss="hinge")

# Model training
clf.fit(X_train, y_train)

# With Hummingbird 0.4.5, conversion fails here with the AssertionError quoted above
model = convert(clf, "pytorch")

# Model prediction
# predicted = clf.predict(X_test)
predicted = model.predict(X_test)

model.save("hb_model")
model = load("hb_model")

# Model evaluation

# Classification report
print(f"Classification report for classifier {clf}:\n" f"{metrics.classification_report(y_test, predicted)}\n")

# Confusion matrix - plot
disp = metrics.ConfusionMatrixDisplay.from_predictions(y_test, predicted)
disp.figure_.suptitle("Confusion Matrix")
print(f"Confusion matrix:\n{disp.confusion_matrix}")

plt.show()

# Write results to file (the file name is illustrative)
report = metrics.classification_report(y_test, predicted)
with open("report.txt", "w") as f:
    f.write(report)
```

Economax, Aug 25 '22 15:08

Thanks, we'll look into what it takes to add support for this loss. I don't remember a technical reason why it wasn't added originally.

ksaur, Aug 26 '22 22:08
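One way to see why hinge support may be straightforward, offered as a hedged sketch rather than Hummingbird's actual fix: for a fitted `SGDClassifier`, `predict` depends only on the linear decision function `X @ coef_.T + intercept_` (the loss only shapes training), so a hinge model could in principle be converted like other linear models, with only `predict_proba` left undefined (scikit-learn itself does not offer it for hinge). The helper `linear_predict` below is hypothetical; it relies only on scikit-learn's documented `coef_`, `intercept_` and `classes_` attributes and checks itself against `clf.predict`.

```python
import numpy as np
from sklearn import datasets
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split


def linear_predict(coef, intercept, classes, X):
    """Hypothetical helper: predict labels from a linear classifier's parameters.

    Multiclass (one-vs-all) models pick the class with the highest score;
    binary models have a single row in coef and pick classes[1] when score > 0.
    """
    scores = X @ coef.T + intercept
    if scores.shape[1] == 1:
        return classes[(scores[:, 0] > 0).astype(int)]
    return classes[np.argmax(scores, axis=1)]


# Same data preparation as the repro in the issue
digits = datasets.load_digits()
X = digits.images.reshape((len(digits.images), -1))
X_train, X_test, y_train, y_test = train_test_split(
    X, digits.target, test_size=0.90, shuffle=False
)

# random_state is an arbitrary choice for reproducibility
clf = SGDClassifier(loss="hinge", random_state=0).fit(X_train, y_train)
manual = linear_predict(clf.coef_, clf.intercept_, clf.classes_, X_test)

# Should agree with scikit-learn's own predict for every test sample
print((manual == clf.predict(X_test)).all())
```

The same matrix product plus argmax is what a tensor backend would need to compute for hinge predictions, which is consistent with the maintainer's recollection that there was no technical blocker.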

Awesome!

On Tue, 13 Feb 2024, 19:16 Karla Saur, @.***> wrote:

Closed #626 https://github.com/microsoft/hummingbird/issues/626 as completed via #758 https://github.com/microsoft/hummingbird/pull/758.


Economax, Feb 13 '24 18:02