imodels
HSTree Multiclass Classification Support
Does HSTree support multiclass classification problems with RandomForest / ExtraTrees as the estimator?
From my initial tests it appears buggy: calling predict_proba on the final model returns many NaN predictions, along with training-time warnings such as:
/Users/neerick/workspace/virtual/autogluon/lib/python3.8/site-packages/imodels/tree/hierarchical_shrinkage.py:87: RuntimeWarning: invalid value encountered in double_scalars
val = tree.value[i][0, 1] / (tree.value[i][0, 0] + tree.value[i][0, 1]) # binary classification
If helpful I can try to create a reproducible example.
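The warning points at the likely source of the NaNs: the quoted line divides the class-1 count by the sum of the first two columns of tree.value, but a multiclass node can contain zero samples of both class 0 and class 1, giving 0/0. A minimal NumPy sketch (the node values are illustrative, not taken from the run above):

```python
import numpy as np

# sklearn stores per-node class counts in tree_.value with shape
# (n_nodes, 1, n_classes). For a 3-class problem, a pure node
# containing only class-2 samples looks like this:
node_value = np.array([[0.0, 0.0, 5.0]])

# The binary-only formula from hierarchical_shrinkage.py:
# classes 0 and 1 both have zero counts, so this is 0/0 -> NaN
with np.errstate(invalid="ignore"):
    val_binary = node_value[0, 1] / (node_value[0, 0] + node_value[0, 1])
print(val_binary)  # nan

# A multiclass-safe alternative normalizes over all classes:
val_multiclass = node_value[0] / node_value[0].sum()
print(val_multiclass)  # [0. 0. 1.]
```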
Here is an example result comparing with sklearn's default RF (_og_) on the accuracy metric. Because HSTree returns many NaN predictions, its scores are very low.
One observation is that the scores get worse as the HSTree forests grow; I'd guess the likelihood of returning a NaN result increases with the number of trees.
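That trend would be consistent with each tree independently having some chance p of containing a degenerate node: the probability that at least one of n trees emits a NaN (and poisons the averaged predict_proba) grows as 1 - (1 - p)^n. An illustrative calculation, where p = 0.05 is an arbitrary made-up value, not a measured rate:

```python
# Probability that at least one of n independent trees hits a
# NaN-producing node, assuming each tree does so with probability p.
p = 0.05  # illustrative only
for n in (10, 20, 40, 100, 300):
    at_least_one_nan = 1 - (1 - p) ** n
    print(n, round(at_least_one_nan, 3))
```

By n = 300 the probability is essentially 1, matching the pattern of uniformly bad scores in the n300 rows below.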
model score_test score_val pred_time_test pred_time_val fit_time pred_time_test_marginal pred_time_val_marginal fit_time_marginal stack_level can_infer fit_order
0 RandomForest_og_n300 0.711651 0.723618 0.985573 0.050956 0.519926 0.985573 0.050956 0.519926 1 True 1
1 RandomForest_og_n100 0.710154 0.748744 0.453769 0.019050 0.170951 0.453769 0.019050 0.170951 1 True 2
2 WeightedEnsemble_L2 0.710154 0.748744 0.464755 0.019376 0.295161 0.010986 0.000326 0.124210 2 True 36
3 RandomForest_og_n40 0.700636 0.698492 0.193009 0.010738 0.088012 0.193009 0.010738 0.088012 1 True 3
4 RandomForest_og_n20 0.692039 0.698492 0.103616 0.007549 0.057396 0.103616 0.007549 0.057396 1 True 4
5 RandomForest_og_n10 0.674165 0.688442 0.075296 0.006166 0.041720 0.075296 0.006166 0.041720 1 True 5
6 RandomForest_hs=10_n10 0.521949 0.537688 0.070260 0.005246 0.082384 0.070260 0.005246 0.082384 1 True 15
7 RandomForest_hs=50_n10 0.520839 0.517588 0.075151 0.004875 0.071219 0.075151 0.004875 0.071219 1 True 20
8 RandomForest_hs=0.1_n10 0.520796 0.537688 0.074070 0.005233 0.093299 0.074070 0.005233 0.093299 1 True 35
9 RandomForest_hs=1_n10 0.520692 0.542714 0.077687 0.005690 0.075061 0.077687 0.005690 0.075061 1 True 10
10 RandomForest_hs=100_n10 0.519246 0.517588 0.075059 0.006019 0.082536 0.075059 0.006019 0.082536 1 True 25
11 RandomForest_hs=500_n10 0.488877 0.517588 0.072145 0.005125 0.072223 0.072145 0.005125 0.072223 1 True 30
12 RandomForest_hs=1_n20 0.485125 0.472362 0.113002 0.006484 0.123639 0.113002 0.006484 0.123639 1 True 9
13 RandomForest_hs=0.1_n20 0.485005 0.472362 0.111342 0.005953 0.146246 0.111342 0.005953 0.146246 1 True 34
14 RandomForest_hs=10_n20 0.484833 0.482412 0.104076 0.006577 0.131909 0.104076 0.006577 0.131909 1 True 14
15 RandomForest_hs=50_n20 0.482896 0.482412 0.115057 0.006263 0.130512 0.115057 0.006263 0.130512 1 True 19
16 RandomForest_hs=100_n20 0.480840 0.482412 0.108625 0.006045 0.135224 0.108625 0.006045 0.135224 1 True 24
17 RandomForest_hs=500_n20 0.458035 0.467337 0.108658 0.006302 0.123907 0.108658 0.006302 0.123907 1 True 29
18 RandomForest_hs=1_n40 0.451434 0.467337 0.185129 0.010619 0.210639 0.185129 0.010619 0.210639 1 True 8
19 RandomForest_hs=0.1_n40 0.451382 0.467337 0.170597 0.009024 0.244322 0.170597 0.009024 0.244322 1 True 33
20 RandomForest_hs=10_n40 0.451322 0.467337 0.173382 0.009955 0.210795 0.173382 0.009955 0.210795 1 True 13
21 RandomForest_hs=50_n40 0.450350 0.467337 0.170041 0.008673 0.236081 0.170041 0.008673 0.236081 1 True 18
22 RandomForest_hs=100_n40 0.449119 0.467337 0.169396 0.010918 0.226784 0.169396 0.010918 0.226784 1 True 23
23 RandomForest_hs=500_n40 0.435832 0.472362 0.162881 0.009256 0.202447 0.162881 0.009256 0.202447 1 True 28
24 RandomForest_hs=1_n100 0.420419 0.452261 0.442328 0.017688 0.480776 0.442328 0.017688 0.480776 1 True 7
25 RandomForest_hs=0.1_n100 0.420411 0.452261 0.354523 0.018247 0.548557 0.354523 0.018247 0.548557 1 True 32
26 RandomForest_hs=10_n100 0.419981 0.452261 0.355097 0.017487 0.469547 0.355097 0.017487 0.469547 1 True 12
27 RandomForest_hs=50_n100 0.419034 0.447236 0.344341 0.021125 0.465810 0.344341 0.021125 0.465810 1 True 17
28 RandomForest_hs=100_n100 0.418672 0.447236 0.372041 0.018402 0.477048 0.372041 0.018402 0.477048 1 True 22
29 RandomForest_hs=500_n100 0.415256 0.457286 0.338696 0.017128 0.492786 0.338696 0.017128 0.492786 1 True 27
30 RandomForest_hs=0.1_n300 0.381049 0.391960 0.967061 0.045552 1.533075 0.967061 0.045552 1.533075 1 True 31
31 RandomForest_hs=10_n300 0.381049 0.391960 1.109062 0.054005 1.442369 1.109062 0.054005 1.442369 1 True 11
32 RandomForest_hs=1_n300 0.381040 0.391960 1.677277 0.055421 2.346773 1.677277 0.055421 2.346773 1 True 6
33 RandomForest_hs=50_n300 0.380945 0.391960 0.889030 0.053650 1.320377 0.889030 0.053650 1.320377 1 True 16
34 RandomForest_hs=100_n300 0.380885 0.391960 1.031198 0.045266 1.254918 1.031198 0.045266 1.254918 1 True 21
35 RandomForest_hs=500_n300 0.380816 0.391960 0.948715 0.050209 1.266396 0.948715 0.050209 1.266396 1 True 26
Hi Nick, you're right, this is currently not supported: the shrink function is written only for univariate regression and binary classification, and misbehaves with multiple classes. It's a pretty straightforward extension though, and we can get to it soon!
@aagarwal1996 @yanshuotan
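For reference, the extension amounts to applying the shrinkage per class vector instead of to a single scalar. A hedged sketch of multiclass hierarchical shrinkage on a fitted sklearn tree (the function name and structure are mine, not the eventual imodels implementation):

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

def shrink_multiclass(clf, reg_param=10.0):
    """Hierarchical shrinkage sketch for a fitted multiclass sklearn tree.

    Telescopes each node's class-probability vector down from the root,
    damping deeper splits by a running product of (1 + reg_param / n_parent).
    Modifies clf.tree_.value in place.
    """
    t = clf.tree_
    # Normalize rows so this works whether tree_.value holds class counts
    # or class fractions (this changed across sklearn versions).
    probs = t.value[:, 0, :] / t.value[:, 0, :].sum(axis=1, keepdims=True)
    shrunk = np.empty_like(probs)

    def recurse(node, parent, cum, denom):
        if parent is None:
            cum, denom = probs[node].copy(), 1.0  # root keeps its estimate
        else:
            denom = denom * (1.0 + reg_param / t.weighted_n_node_samples[parent])
            cum = cum + (probs[node] - probs[parent]) / denom
        shrunk[node] = cum
        left, right = t.children_left[node], t.children_right[node]
        if left != -1:  # not a leaf
            recurse(left, node, cum, denom)
            recurse(right, node, cum, denom)

    recurse(0, None, None, None)
    t.value[:, 0, :] = shrunk  # write the shrunk per-class estimates back

X, y = load_iris(return_X_y=True)  # 3 classes
clf = DecisionTreeClassifier(random_state=0).fit(X, y)
shrink_multiclass(clf, reg_param=10.0)
proba = clf.predict_proba(X)
print(np.isnan(proba).any())  # False: no NaNs on a 3-class problem
```

Because each increment (probs[node] - probs[parent]) sums to zero across classes, every shrunk node vector still sums to 1, so predict_proba stays a valid distribution.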
That would be amazing, thanks!
@Innixma should work now :)
Awesome, thank you so much!