HSTree Multiclass Classification Support

Open Innixma opened this issue 1 year ago • 2 comments

Does HSTree support multiclass classification problems with RandomForest / ExtraTrees as the estimator?

From my initial tests it appears buggy. Calling predict_proba with the final model results in lots of NaN predictions, along with warnings during training such as:

/Users/neerick/workspace/virtual/autogluon/lib/python3.8/site-packages/imodels/tree/hierarchical_shrinkage.py:87: RuntimeWarning: invalid value encountered in double_scalars
  val = tree.value[i][0, 1] / (tree.value[i][0, 0] + tree.value[i][0, 1])  # binary classification

If helpful I can try to create a reproducible example.
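
A minimal sketch of how the quoted expression could produce NaN, assuming a node's value holds one entry per class (the array `node_value` below is hypothetical, not taken from a real tree). With three classes, a node containing only class-2 samples has zeros in the first two slots, so the binary formula divides 0 by 0:

```python
import numpy as np

# Hypothetical per-class sample counts for one node of a 3-class tree.
# Every sample that reached this node belongs to class 2.
node_value = np.array([[0.0, 0.0, 5.0]])

# Same expression as the line quoted in the warning (binary-only logic).
with np.errstate(invalid="ignore"):  # silence the RuntimeWarning for this demo
    val = node_value[0, 1] / (node_value[0, 0] + node_value[0, 1])

# A class-agnostic alternative: normalise over all classes in the node.
probs = node_value[0] / node_value[0].sum()

print("binary formula:", val)
print("all-class normalisation:", probs)
```

The binary expression only ever looks at classes 0 and 1, so any node dominated by another class hits the zero denominator.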

Here is an example result comparing against sklearn's default RF (the _og_ rows), scored by accuracy. Because HSTree returns many NaN predictions, its scores are very low.

One observation is that the scores get worse the more trees there are in the HSTree forests. I'd guess the likelihood of returning a NaN result increases with the number of trees.

                       model  score_test  score_val  pred_time_test  pred_time_val  fit_time  pred_time_test_marginal  pred_time_val_marginal  fit_time_marginal  stack_level  can_infer  fit_order
0       RandomForest_og_n300    0.711651   0.723618        0.985573       0.050956  0.519926                 0.985573                0.050956           0.519926            1       True          1
1       RandomForest_og_n100    0.710154   0.748744        0.453769       0.019050  0.170951                 0.453769                0.019050           0.170951            1       True          2
2        WeightedEnsemble_L2    0.710154   0.748744        0.464755       0.019376  0.295161                 0.010986                0.000326           0.124210            2       True         36
3        RandomForest_og_n40    0.700636   0.698492        0.193009       0.010738  0.088012                 0.193009                0.010738           0.088012            1       True          3
4        RandomForest_og_n20    0.692039   0.698492        0.103616       0.007549  0.057396                 0.103616                0.007549           0.057396            1       True          4
5        RandomForest_og_n10    0.674165   0.688442        0.075296       0.006166  0.041720                 0.075296                0.006166           0.041720            1       True          5
6     RandomForest_hs=10_n10    0.521949   0.537688        0.070260       0.005246  0.082384                 0.070260                0.005246           0.082384            1       True         15
7     RandomForest_hs=50_n10    0.520839   0.517588        0.075151       0.004875  0.071219                 0.075151                0.004875           0.071219            1       True         20
8    RandomForest_hs=0.1_n10    0.520796   0.537688        0.074070       0.005233  0.093299                 0.074070                0.005233           0.093299            1       True         35
9      RandomForest_hs=1_n10    0.520692   0.542714        0.077687       0.005690  0.075061                 0.077687                0.005690           0.075061            1       True         10
10   RandomForest_hs=100_n10    0.519246   0.517588        0.075059       0.006019  0.082536                 0.075059                0.006019           0.082536            1       True         25
11   RandomForest_hs=500_n10    0.488877   0.517588        0.072145       0.005125  0.072223                 0.072145                0.005125           0.072223            1       True         30
12     RandomForest_hs=1_n20    0.485125   0.472362        0.113002       0.006484  0.123639                 0.113002                0.006484           0.123639            1       True          9
13   RandomForest_hs=0.1_n20    0.485005   0.472362        0.111342       0.005953  0.146246                 0.111342                0.005953           0.146246            1       True         34
14    RandomForest_hs=10_n20    0.484833   0.482412        0.104076       0.006577  0.131909                 0.104076                0.006577           0.131909            1       True         14
15    RandomForest_hs=50_n20    0.482896   0.482412        0.115057       0.006263  0.130512                 0.115057                0.006263           0.130512            1       True         19
16   RandomForest_hs=100_n20    0.480840   0.482412        0.108625       0.006045  0.135224                 0.108625                0.006045           0.135224            1       True         24
17   RandomForest_hs=500_n20    0.458035   0.467337        0.108658       0.006302  0.123907                 0.108658                0.006302           0.123907            1       True         29
18     RandomForest_hs=1_n40    0.451434   0.467337        0.185129       0.010619  0.210639                 0.185129                0.010619           0.210639            1       True          8
19   RandomForest_hs=0.1_n40    0.451382   0.467337        0.170597       0.009024  0.244322                 0.170597                0.009024           0.244322            1       True         33
20    RandomForest_hs=10_n40    0.451322   0.467337        0.173382       0.009955  0.210795                 0.173382                0.009955           0.210795            1       True         13
21    RandomForest_hs=50_n40    0.450350   0.467337        0.170041       0.008673  0.236081                 0.170041                0.008673           0.236081            1       True         18
22   RandomForest_hs=100_n40    0.449119   0.467337        0.169396       0.010918  0.226784                 0.169396                0.010918           0.226784            1       True         23
23   RandomForest_hs=500_n40    0.435832   0.472362        0.162881       0.009256  0.202447                 0.162881                0.009256           0.202447            1       True         28
24    RandomForest_hs=1_n100    0.420419   0.452261        0.442328       0.017688  0.480776                 0.442328                0.017688           0.480776            1       True          7
25  RandomForest_hs=0.1_n100    0.420411   0.452261        0.354523       0.018247  0.548557                 0.354523                0.018247           0.548557            1       True         32
26   RandomForest_hs=10_n100    0.419981   0.452261        0.355097       0.017487  0.469547                 0.355097                0.017487           0.469547            1       True         12
27   RandomForest_hs=50_n100    0.419034   0.447236        0.344341       0.021125  0.465810                 0.344341                0.021125           0.465810            1       True         17
28  RandomForest_hs=100_n100    0.418672   0.447236        0.372041       0.018402  0.477048                 0.372041                0.018402           0.477048            1       True         22
29  RandomForest_hs=500_n100    0.415256   0.457286        0.338696       0.017128  0.492786                 0.338696                0.017128           0.492786            1       True         27
30  RandomForest_hs=0.1_n300    0.381049   0.391960        0.967061       0.045552  1.533075                 0.967061                0.045552           1.533075            1       True         31
31   RandomForest_hs=10_n300    0.381049   0.391960        1.109062       0.054005  1.442369                 1.109062                0.054005           1.442369            1       True         11
32    RandomForest_hs=1_n300    0.381040   0.391960        1.677277       0.055421  2.346773                 1.677277                0.055421           2.346773            1       True          6
33   RandomForest_hs=50_n300    0.380945   0.391960        0.889030       0.053650  1.320377                 0.889030                0.053650           1.320377            1       True         16
34  RandomForest_hs=100_n300    0.380885   0.391960        1.031198       0.045266  1.254918                 1.031198                0.045266           1.254918            1       True         21
35  RandomForest_hs=500_n300    0.380816   0.391960        0.948715       0.050209  1.266396                 0.948715                0.050209           1.266396            1       True         26
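
The drop in score as the tree count grows fits a simple model, sketched here under two assumptions that are not confirmed by the thread: the forest averages per-tree probabilities, and each tree independently returns NaN for a given row with some probability `p`. The value `p = 0.05` is purely illustrative, and `prob_any_nan` is a hypothetical helper:

```python
import numpy as np


def prob_any_nan(p, n_trees):
    """Chance that at least one of n_trees independent trees returns NaN."""
    return 1 - (1 - p) ** n_trees


# Illustrative per-tree NaN probability (assumed, not measured).
p = 0.05
for n_trees in (10, 20, 40, 100, 300):
    print(n_trees, round(prob_any_nan(p, n_trees), 3))

# A single NaN among the per-tree outputs makes the plain average NaN.
per_tree = np.array([0.2, 0.3, np.nan, 0.4])
forest_avg = per_tree.mean()
print("averaged prediction:", forest_avg)
```

Under these assumptions a single bad tree is enough to spoil a row, so larger forests would be expected to return more NaN rows.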

Innixma, Jul 30 '22 23:07

Hi Nick, you're right, this is currently not supported (the shrink function is written only for univariate regression/binary classification and misbehaves with multiple classes). It's a pretty straightforward extension, though, and we can get to it soon!
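
One way such an extension could look, as a rough sketch only and not necessarily what was later merged into imodels: normalise each node's value over all classes, then shrink each child's class-probability vector toward its parent's, dividing the difference by `1 + reg_param / N(parent)`. The function name `shrunk_class_probs` and the parameter name `reg_param` are assumptions made for this illustration. The function reads a fitted sklearn `tree_` and returns a new array rather than modifying the tree:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier


def shrunk_class_probs(tree, reg_param):
    """Per-node class probabilities after hierarchical shrinkage (sketch)."""
    # Normalise over all classes; works whether values are counts or fractions.
    raw = tree.value[:, 0, :]
    probs = raw / raw.sum(axis=1, keepdims=True)

    shrunk = np.zeros_like(probs)
    shrunk[0] = probs[0]  # the root is kept as-is
    stack = [0]
    while stack:
        parent = stack.pop()
        left, right = tree.children_left[parent], tree.children_right[parent]
        if left == -1:  # sklearn marks leaves with child index -1
            continue
        denom = 1 + reg_param / tree.n_node_samples[parent]
        for child in (left, right):
            # Shrink the child's deviation from its parent, class by class.
            shrunk[child] = shrunk[parent] + (probs[child] - probs[parent]) / denom
            stack.append(child)
    return shrunk


X, y = load_iris(return_X_y=True)  # 3-class toy dataset
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)
shrunk = shrunk_class_probs(clf.tree_, reg_param=10.0)
print(shrunk.sum(axis=1))  # each row should still sum to 1
```

Because the per-class differences between a child and its parent sum to zero, each shrunk row remains a probability vector, and no division by a two-class subtotal is involved.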

@aagarwal1996 @yanshuotan

csinva, Jul 31 '22 01:07

That would be amazing, thanks!

Innixma, Jul 31 '22 18:07

@Innixma should work now :)

csinva, Aug 25 '22 01:08

Awesome, thank you so much!

Innixma, Aug 25 '22 04:08