
[WIP] add decision function to LogisticGAM, for compat with sklearn multiclass

Open dswah opened this issue 5 years ago • 9 comments

closes https://github.com/dswah/pyGAM/issues/196

  • [x] add decision_function method to the LogisticGAM class for compatibility with sklearn's OneVsRestClassifier class

  • [x] test

  • [ ] refactor terms so that all argument processing and sanitization are done in the fit() method

dswah avatar Oct 03 '18 15:10 dswah
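
For context, here is a minimal sketch of what a `decision_function` on a binary logistic model typically returns: the log-odds of the positive class, so a score of 0 corresponds to a probability of 0.5. A one-vs-rest wrapper then picks, for each sample, the class whose binary model gives the highest score. The class `ToyLogisticModel`, its fixed weights, and the helper `one_vs_rest_predict` are hypothetical names for illustration only; this is not pyGAM's or scikit-learn's implementation.

```python
import math


class ToyLogisticModel:
    """Hypothetical stand-in for a fitted binary logistic model."""

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def decision_function(self, X):
        # Linear predictor, i.e. the log-odds of the positive class.
        return [self.weight * x + self.bias for x in X]

    def predict_proba_positive(self, X):
        # Inverse-logit transform of the decision scores.
        return [1.0 / (1.0 + math.exp(-z)) for z in self.decision_function(X)]


def one_vs_rest_predict(models, X):
    """Pick, per sample, the index of the model with the highest score."""
    scores = [m.decision_function(X) for m in models]  # one list per class
    return [max(range(len(models)), key=lambda k: scores[k][i])
            for i in range(len(X))]


# Two toy "class vs rest" models over a single feature.
models = [ToyLogisticModel(-1.0, 0.0), ToyLogisticModel(1.0, 0.0)]
print(one_vs_rest_predict(models, [-2.0, 3.0]))
```

Because the scores are unbounded log-odds rather than probabilities, comparing them across the per-class models is a simple argmax.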

Codecov Report

Merging #213 into master will increase coverage by 0.06%. The diff coverage is 100%.


@@            Coverage Diff             @@
##           master     #213      +/-   ##
==========================================
+ Coverage   94.86%   94.92%   +0.06%     
==========================================
  Files          22       22              
  Lines        3056     3074      +18     
==========================================
+ Hits         2899     2918      +19     
+ Misses        157      156       -1

Impacted Files                     Coverage Δ
pygam/tests/test_GAM_methods.py    100% <100%> (ø)
pygam/pygam.py                     94.9% <100%> (+0.14%)

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update 1c43c05...27741dd.

codecov[bot] avatar Oct 03 '18 15:10 codecov[bot]

Can't use predict_proba with #196 example

903124 avatar Oct 24 '18 06:10 903124

Can't use predict_proba with #196 example

@903124 ahh good point. this is because pygam's predict_proba api didn't conform to sklearn's. i am making an update.

dswah avatar Oct 24 '18 14:10 dswah
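
scikit-learn classifiers return class probabilities with one row per sample and one column per class, so for a binary problem `predict_proba(X)[:, 1]` is the positive-class probability. Below is a minimal sketch of the conversion, assuming the mismatch is that the estimator returns a flat vector of positive-class probabilities (the thread does not spell this out). `to_two_column_proba` is a hypothetical helper, not part of pyGAM or scikit-learn.

```python
def to_two_column_proba(p_positive):
    """Turn a flat list of P(y=1) into rows of [P(y=0), P(y=1)]."""
    rows = []
    for p in p_positive:
        if not 0.0 <= p <= 1.0:
            raise ValueError(f"probability out of range: {p}")
        rows.append([1.0 - p, p])
    return rows


flat = [0.25, 0.5, 1.0]
table = to_two_column_proba(flat)
# Column 1 recovers the original positive-class probabilities.
print([row[1] for row in table])
```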

predict_proba is now working properly, but f() terms are still broken. If an f() term is given, every row gets the prediction of the first item.

903124 avatar Oct 24 '18 16:10 903124

@903124 hmm that's interesting, but i don't completely understand.

can you please show me an example?

dswah avatar Oct 24 '18 17:10 dswah

from pygam import LogisticGAM, s, f
from sklearn.multiclass import OneVsRestClassifier

base_estimator = LogisticGAM(n_splines=5)
ensemble = OneVsRestClassifier(base_estimator, n_jobs=1)
ensemble.fit(X_train, y_train)

OneVsRestClassifier(estimator=LogisticGAM(callbacks=['deviance', 'diffs', 'accuracy'],
   fit_intercept=True, max_iter=100, n_splines=5, terms='auto',
   tol=0.0001, verbose=False),
          n_jobs=1)

ensemble.predict_proba(X_test)
array([[1.19215720e-01, 9.04500103e-02, 1.77326087e-05, ...,
        2.42330347e-03, 1.32568439e-01, 2.24134228e-01],
       [1.66961024e-01, 1.26811544e-01, 1.92006828e-04, ...,
        2.29616707e-03, 1.00314986e-01, 7.13239544e-02],
       [3.84291300e-02, 3.18707670e-02, 4.24383742e-10, ...,
        2.49137619e-03, 4.38926722e-01, 3.26176537e-01],
       ...,
       [9.50068415e-02, 8.09269049e-02, 2.80842523e-06, ...,
        2.86022038e-03, 1.46092712e-01, 2.74045716e-01],
       [1.46784609e-01, 1.06025123e-01, 2.14491296e-06, ...,
        2.60290652e-03, 1.48957364e-01, 1.20364635e-02],
       [1.52544943e-02, 1.36320357e-02, 1.01640876e-11, ...,
        1.68127415e-03, 3.55414339e-01, 5.55598935e-01]])

base_estimator = LogisticGAM(s(0)+f(1)+f(2)+f(3)+f(4)+s(5), n_splines=5)
ensemble = OneVsRestClassifier(base_estimator, n_jobs=1)
ensemble.fit(X_train, y_train)

OneVsRestClassifier(estimator=LogisticGAM(callbacks=['deviance', 'diffs', 'accuracy'],
   fit_intercept=True, max_iter=100, n_splines=5,
   terms=s(0) + f(1) + f(2) + f(3) + f(4) + s(5), tol=0.0001,
   verbose=False),
          n_jobs=1)

ensemble.predict_proba(X_test)
array([[0.10159551, 0.08261063, 0.00094507, ..., 0.00239048, 0.21650545,
        0.23143207],
       [0.10159551, 0.08261063, 0.00094507, ..., 0.00239048, 0.21650545,
        0.23143207],
       [0.10159551, 0.08261063, 0.00094507, ..., 0.00239048, 0.21650545,
        0.23143207],
       ...,
       [0.10159551, 0.08261063, 0.00094507, ..., 0.00239048, 0.21650545,
        0.23143207],
       [0.10159551, 0.08261063, 0.00094507, ..., 0.00239048, 0.21650545,
        0.23143207],
       [0.10159551, 0.08261063, 0.00094507, ..., 0.00239048, 0.21650545,
        0.23143207]])

903124 avatar Oct 24 '18 17:10 903124

@903124 thanks for the example. i see what you mean

dswah avatar Oct 25 '18 12:10 dswah

@903124 thanks for your comments!

i took a look inside, and it appears to be a complex fix.

scikit-learn defers any processing of arguments until the fit() method is called, but i violated that rule with pygam's terms.

fixing this is going to require a little more effort: all processing of the terms' arguments has to be deferred to fit().

dswah avatar Oct 26 '18 12:10 dswah
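
The convention described in the comment above can be sketched as follows: the constructor only stores its arguments untouched, and all validation and conversion happen in `fit()`. This matters because wrappers such as `OneVsRestClassifier` copy the base estimator by reading its parameters back and passing them to the constructor again, so anything the constructor rewrote may not survive the copy. `DeferredEstimator`, its `terms` handling, and `toy_clone` are hypothetical names for illustration; they imitate, but are not, scikit-learn's or pyGAM's code.

```python
class DeferredEstimator:
    """Hypothetical estimator that follows the 'process in fit()' rule."""

    def __init__(self, terms="auto", n_splines=5):
        # Store arguments exactly as given: no validation, no conversion.
        self.terms = terms
        self.n_splines = n_splines

    def get_params(self):
        return {"terms": self.terms, "n_splines": self.n_splines}

    def fit(self, X, y):
        # Argument processing happens here, on every fit.
        if self.n_splines < 1:
            raise ValueError("n_splines must be positive")
        n_features = len(X[0])
        # Derived state gets a trailing underscore and never overwrites
        # the constructor arguments.
        if self.terms == "auto":
            self.terms_ = list(range(n_features))
        else:
            self.terms_ = list(self.terms)
        return self


def toy_clone(estimator):
    """Rebuild an unfitted copy from the constructor parameters."""
    return type(estimator)(**estimator.get_params())


original = DeferredEstimator().fit([[0.0, 1.0], [1.0, 0.0]], [0, 1])
copy = toy_clone(original)
print(copy.terms, hasattr(copy, "terms_"))
```

The copy carries the same constructor arguments as the original but none of its fitted state, which is what a wrapper needs before fitting one estimator per class.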

Is this still going to be merged into master? Would love to try it out!

Excidion avatar Feb 04 '20 10:02 Excidion