Support for hinge loss on Sklearn SGDClassifier
Error message from hummingbird-ml, for reference:
AssertionError: predict_proba for linear models currently only support {'modified_huber', 'squared_hinge', 'log'}. (Given hinge). Please fill an issue at https://github.com/microsoft/hummingbird
Switching to squared_hinge is a simple workaround, but it yields a significant performance loss compared to hinge, at least for a single epoch.
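For context on why the two losses can behave differently, here is a minimal sketch of the standard hinge and squared hinge definitions as functions of the signed margin y * f(x), with y in {-1, +1}. The function names are illustrative only and are not sklearn or Hummingbird APIs. Squared hinge penalizes large margin violations quadratically, so the gradient steps taken during a single epoch of SGD can differ noticeably from those of plain hinge.

```python
def hinge(margin: float) -> float:
    """Hinge loss for a signed margin y * f(x), with y in {-1, +1}."""
    return max(0.0, 1.0 - margin)


def squared_hinge(margin: float) -> float:
    """Squared hinge loss: the hinge loss, squared."""
    return max(0.0, 1.0 - margin) ** 2


# Compare both losses on a few margins (negative = misclassified sample).
for m in (-2.0, 0.0, 0.5, 1.0, 2.0):
    print(f"margin={m:+.1f}  hinge={hinge(m):.2f}  squared_hinge={squared_hinge(m):.2f}")
```

Both losses are zero once the margin reaches 1, so they only disagree on samples that are misclassified or inside the margin.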
Hummingbird version: '0.4.5'
Ran on Python 3.9.12 (main, Jun 1 2022, 11:38:51) [GCC 7.5.0] :: Anaconda, Inc. on linux.
It is simple to reproduce; see the code below:
from sklearn.linear_model import SGDClassifier
from sklearn import datasets
from sklearn import metrics
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from hummingbird.ml import convert, load
# Built on the sklearn intro example: https://scikit-learn.org/stable/tutorial/basic/tutorial.html
# Data loading
# iris = datasets.load_iris()
digits = datasets.load_digits()
# Data engineering
n_samples = len(digits.images)
data = digits.images.reshape((n_samples, -1))
# Training split
X_train, X_test, y_train, y_test = train_test_split(data, digits.target, test_size=0.90, shuffle=False)
# Model definition and parameter selection
clf = SGDClassifier(loss="hinge")
# Model training
clf.fit(X_train, y_train)
model = convert(clf, "pytorch")
# Model prediction
# predicted = clf.predict(X_test)
predicted = model.predict(X_test)
model.save("hb_model")
model = load("hb_model")
# Model evaluation
# Classification report
print(f"Classification report for classifier {clf}:\n" f"{metrics.classification_report(y_test, predicted)}\n")
# Confusion matrix - plot
disp = metrics.ConfusionMatrixDisplay.from_predictions(y_test, predicted)
disp.figure_.suptitle("Confusion Matrix")
print(f"Confusion matrix:\n{disp.confusion_matrix}")
plt.show()
# Classification report as a string (writing it to a file is not shown here)
report = metrics.classification_report(y_test, predicted)
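Note that the script above only asks for class labels (predict), not probabilities. For a linear one-vs-rest classifier, labels can be derived directly from the raw decision scores by taking the class with the largest score, with no probability estimate involved. The sketch below illustrates that idea in plain Python. The weights, intercepts, and function names are made up for illustration, and this is not a description of how Hummingbird implements the conversion.

```python
from typing import List, Sequence


def decision_scores(x: Sequence[float],
                    coef: Sequence[Sequence[float]],
                    intercept: Sequence[float]) -> List[float]:
    """One linear score per class: w_k . x + b_k."""
    return [sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(coef, intercept)]


def predict_label(x: Sequence[float],
                  coef: Sequence[Sequence[float]],
                  intercept: Sequence[float],
                  classes: Sequence[int]) -> int:
    """Pick the class whose decision score is largest (no probabilities)."""
    scores = decision_scores(x, coef, intercept)
    best = max(range(len(scores)), key=scores.__getitem__)
    return classes[best]


# Hypothetical 3-class model over 2 features (values chosen for illustration).
coef = [[1.0, -1.0], [-1.0, 1.0], [0.5, 0.5]]
intercept = [0.0, 0.0, -1.0]
classes = [0, 1, 2]

for sample in ([2.0, 0.0], [0.0, 3.0], [3.0, 3.0]):
    print(sample, "->", predict_label(sample, coef, intercept, classes))
```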
Thanks, we'll look into what needs to be done to add this param. I don't remember there being a technical reason why it wasn't added originally.
Awesome!
On Tue, 13 Feb 2024, 19:16, Karla Saur wrote:
Closed #626 (https://github.com/microsoft/hummingbird/issues/626) as completed via #758 (https://github.com/microsoft/hummingbird/pull/758).