imodels
Speeding up HS with LOOCV
Hello, thanks again for the great library!
I'm interested in applying HSTree to RandomForest and ExtraTrees models.
According to the documentation, I can specify a random forest object in the estimator_ argument; however, this raises an error when I try to fit:
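```python
# Placeholder data so the snippet runs on its own; the original report did not show how X, y were built.
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from imodels.tree.hierarchical_shrinkage import HSTreeClassifier

X, y = load_breast_cancer(return_X_y=True)

rf = RandomForestClassifier()
model = HSTreeClassifier(estimator_=rf)  # wrap the forest with hierarchical shrinkage
model = model.fit(X, y)
```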
```
File "/Users/neerick/workspace/code/autogluon/tabular/src/autogluon/tabular/models/rf/rf_model.py", line 232, in _fit
  model = model.fit(X, y, sample_weight=sample_weight)
File "/Users/neerick/workspace/virtual/autogluon/lib/python3.8/site-packages/imodels/tree/hierarchical_shrinkage.py", line 64, in fit
  self.complexity_ = compute_tree_complexity(self.estimator_.tree_)
AttributeError: 'RandomForestClassifier' object has no attribute 'tree_'
```
https://github.com/csinva/imodels/blob/master/imodels/tree/hierarchical_shrinkage.py
I don't see any tutorial / documentation for creating a random forest or extra trees model via HSTree, but the paper mentions that this is possible and gets good results. I was wondering if the maintainers could point me to an example or tutorial on this.
Thanks!
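As a rough summary of how HS carries over to forests (paraphrased, not quoted from the paper): each fitted tree keeps its structure and only its node values are rewritten. For a sample whose root-to-leaf path is $t_0, t_1, \dots, t_L$, the shrunk prediction is $\hat f(x) = \mathbb{E}_{t_0}[y] + \sum_{l=1}^{L} \frac{\mathbb{E}_{t_l}[y] - \mathbb{E}_{t_{l-1}}[y]}{1 + \lambda / N(t_{l-1})}$, where $N(t)$ is the number of samples in node $t$ and $\lambda$ is `reg_param`. For RandomForest or ExtraTrees, the shrinkage is applied to every tree independently and the shrunk trees are averaged as usual. The from-scratch sketch below illustrates this with scikit-learn regression forests. It is not imodels' implementation; `hs_node_values` and `hs_forest_predict` are hypothetical helper names, and it sticks to regression because the layout of classifier `tree_.value` has changed across scikit-learn versions.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import ExtraTreesRegressor, RandomForestRegressor


def hs_node_values(tree, reg_param):
    """Return the hierarchically shrunk value of every node in a fitted sklearn regression tree.

    Walking root-to-leaf, each step's change in node mean is damped by
    1 / (1 + reg_param / N(parent)), where N(parent) is the (weighted) sample count.
    """
    left, right = tree.children_left, tree.children_right
    values = tree.value[:, 0, 0]            # mean response per node (single-output regression)
    counts = tree.weighted_n_node_samples   # N(t); includes bootstrap weights in a random forest
    shrunk = np.empty_like(values)
    shrunk[0] = values[0]                   # the root keeps its (unshrunk) mean
    stack = [0]
    while stack:
        parent = stack.pop()
        for child in (left[parent], right[parent]):
            if child == -1:                 # sklearn uses -1 for "no child" (parent is a leaf)
                continue
            step = values[child] - values[parent]
            shrunk[child] = shrunk[parent] + step / (1.0 + reg_param / counts[parent])
            stack.append(child)
    return shrunk


def hs_forest_predict(forest, X, reg_param):
    """Shrink every tree of a fitted forest independently, then average the trees as usual."""
    leaves = forest.apply(X)  # (n_samples, n_estimators): leaf index of each sample in each tree
    per_tree = [
        hs_node_values(est.tree_, reg_param)[leaves[:, k]]
        for k, est in enumerate(forest.estimators_)
    ]
    return np.mean(per_tree, axis=0)


X, y = make_regression(n_samples=300, n_features=8, noise=5.0, random_state=0)
for Forest in (RandomForestRegressor, ExtraTreesRegressor):
    forest = Forest(n_estimators=20, random_state=0).fit(X, y)
    # reg_param=0 turns shrinkage off, so this should reproduce forest.predict.
    print(Forest.__name__, np.allclose(hs_forest_predict(forest, X, 0.0), forest.predict(X)))
```

With `reg_param=0` every denominator is 1, the sum telescopes back to the leaf mean, and the shrunk forest matches `forest.predict`; larger values pull each tree's leaves toward that tree's root mean.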
Thanks for the question Nick! Indeed, we should have publicly released this with some documentation by now; will get to it shortly!
@aagarwal1996 - can you add in the HS + RF / ExtraTrees code with a little doc?
Best, Chandan
Just fixed it in this commit :)
Will have to bump the version to 1.3.3 to get it to work, but the code you were using above should work now. Also added a little snippet to the doc here.
Cheers, Chandan
Incredible. It works now and is showing very strong results on the Adult dataset (I plan to test more). For example, n_estimators=40 with HSTreeClassifierCV is getting significantly better test scores than RandomForestClassifier(n_estimators=300). The API is perfect, and I can add it to AutoGluon with <10 lines of code.
A question related to the ICML oral presentation / paper: I recall a mention of efficient leave-one-out CV (LOOCV) for optimizing reg_param. Is this implemented, or do I need to pay the cost of HSTreeClassifierCV to get the optimal reg_param value?
Glad to hear it's working for you 😄
This isn't currently implemented, but it should be easy to do; we'll get it done soon, especially if you're putting this into AutoGluon!
@aagarwal1996 @yanshuotan @OmerRonen - someone wanna take this on?
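On the cost question: HS only rewrites node values after fitting, so the tree structure does not depend on `reg_param`. That suggests a cheaper alternative to refitting for every candidate value: fit the forest once per cross-validation fold, record which leaf each validation row lands in, and score the whole `reg_param` grid by recomputing node values. This is plain K-fold CV, not the efficient LOOCV mentioned in the paper, and it is not a description of how `HSTreeClassifierCV` works internally. The helper names (`hs_node_values`, `cv_mse_per_reg_param`) and the grid are hypothetical, and the block repeats a regression-only shrinkage recursion so it runs on its own.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import KFold


def hs_node_values(tree, reg_param):
    """Hierarchically shrunk value of every node in a fitted sklearn regression tree."""
    left, right = tree.children_left, tree.children_right
    values = tree.value[:, 0, 0]           # node means (single-output regression)
    counts = tree.weighted_n_node_samples  # N(t), including bootstrap weights
    shrunk = np.empty_like(values)
    shrunk[0] = values[0]                  # root keeps its mean
    stack = [0]
    while stack:
        parent = stack.pop()
        for child in (left[parent], right[parent]):
            if child == -1:                # -1 means "no child": parent is a leaf
                continue
            step = values[child] - values[parent]
            shrunk[child] = shrunk[parent] + step / (1.0 + reg_param / counts[parent])
            stack.append(child)
    return shrunk


def cv_mse_per_reg_param(X, y, reg_params, n_splits=5, n_estimators=30, seed=0):
    """Mean validation MSE for each reg_param, fitting only one forest per fold."""
    errors = np.zeros((n_splits, len(reg_params)))
    splitter = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
    for fold, (train_idx, val_idx) in enumerate(splitter.split(X)):
        # Expensive step: one fit per fold, independent of reg_param.
        forest = RandomForestRegressor(n_estimators=n_estimators, random_state=seed)
        forest.fit(X[train_idx], y[train_idx])
        leaves = forest.apply(X[val_idx])  # (n_val, n_estimators) leaf index per tree
        for j, lam in enumerate(reg_params):
            # Cheap step: recompute node values for this reg_param and average the trees.
            per_tree = [
                hs_node_values(est.tree_, lam)[leaves[:, k]]
                for k, est in enumerate(forest.estimators_)
            ]
            pred = np.mean(per_tree, axis=0)
            errors[fold, j] = np.mean((y[val_idx] - pred) ** 2)
    return errors.mean(axis=0)


X, y = make_regression(n_samples=300, n_features=8, noise=10.0, random_state=0)
reg_params = [0.0, 1.0, 10.0, 50.0, 100.0]
scores = cv_mse_per_reg_param(X, y, reg_params)
best = reg_params[int(np.argmin(scores))]
print("mean CV MSE per reg_param:", dict(zip(reg_params, np.round(scores, 2))))
print("best reg_param:", best)
```

The forest fits, which dominate the runtime, scale with the number of folds only; each additional `reg_param` adds just a pass over the already-fitted trees' nodes.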
Here is a hyperparameter grid result on Adult, subsampled to 2000 rows of training data and sorted by test score (AUC).
Naming: og = vanilla sklearn; hs=10_n300 = reg_param=10, n_estimators=300.
Amazingly, the HS models with reg_param 50, 100, and 500 and only n_estimators=10 get a better test score than og with n_estimators=1000!
model score_test score_val pred_time_test pred_time_val fit_time pred_time_test_marginal pred_time_val_marginal fit_time_marginal stack_level can_infer fit_order
0 RandomForest_hs=50_n300 0.906977 0.904691 0.095684 0.050235 2.733138 0.095684 0.050235 2.733138 1 True 20
1 RandomForest_hs=50_n100 0.906954 0.905086 0.037076 0.087729 0.922209 0.037076 0.087729 0.922209 1 True 21
2 RandomForest_hs=50_n1000 0.906939 0.905020 0.287926 0.150001 9.144297 0.287926 0.150001 9.144297 1 True 19
3 RandomForest_hs=100_n300 0.906854 0.905185 0.098557 0.050372 2.767945 0.098557 0.050372 2.767945 1 True 26
4 RandomForest_hs=100_n1000 0.906818 0.905020 0.297747 0.155764 9.122136 0.297747 0.155764 9.122136 1 True 25
5 RandomForest_hs=100_n100 0.906746 0.904823 0.040018 0.022333 0.951915 0.040018 0.022333 0.951915 1 True 27
6 WeightedEnsemble_L2 0.906464 0.906106 0.822571 0.493818 25.186706 0.003169 0.000933 0.448992 2 True 43
7 RandomForest_hs=50_n40 0.905625 0.903869 0.019060 0.012893 0.386387 0.019060 0.012893 0.386387 1 True 22
8 RandomForest_hs=100_n40 0.905619 0.904691 0.019219 0.015171 0.412162 0.019219 0.015171 0.412162 1 True 28
9 RandomForest_hs=10_n100 0.904859 0.904955 0.038016 0.021933 0.915743 0.038016 0.021933 0.915743 1 True 15
10 RandomForest_hs=10_n300 0.904659 0.905481 0.098096 0.049020 2.781127 0.098096 0.049020 2.781127 1 True 14
11 RandomForest_hs=10_n1000 0.904594 0.904659 0.296132 0.150198 9.204939 0.296132 0.150198 9.204939 1 True 13
12 RandomForest_hs=100_n20 0.904273 0.901533 0.013146 0.013468 0.245283 0.013146 0.013468 0.245283 1 True 29
13 RandomForest_hs=50_n20 0.903543 0.901336 0.012793 0.009559 0.201715 0.012793 0.009559 0.201715 1 True 23
14 RandomForest_hs=500_n1000 0.903443 0.903310 0.292724 0.153926 9.449500 0.292724 0.153926 9.449500 1 True 31
15 RandomForest_hs=500_n300 0.903273 0.904033 0.098899 0.058785 2.789253 0.098899 0.058785 2.789253 1 True 32
16 RandomForest_hs=10_n40 0.902558 0.899888 0.020626 0.015472 0.387508 0.020626 0.015472 0.387508 1 True 16
17 RandomForest_hs=500_n100 0.902558 0.904066 0.055166 0.022750 0.920178 0.055166 0.022750 0.920178 1 True 33
18 RandomForest_hs=500_n40 0.901187 0.901928 0.038185 0.013044 0.375549 0.038185 0.013044 0.375549 1 True 34
19 RandomForest_hs=500_n20 0.901066 0.899888 0.019732 0.012050 0.197903 0.019732 0.012050 0.197903 1 True 35
20 RandomForest_hs=100_n10 0.901057 0.897684 0.010498 0.008219 0.132994 0.010498 0.008219 0.132994 1 True 30
21 RandomForest_hs=50_n10 0.900010 0.896697 0.010349 0.007744 0.106932 0.010349 0.007744 0.106932 1 True 24
22 RandomForest_hs=1_n100 0.899934 0.903639 0.038109 0.021423 0.925629 0.038109 0.021423 0.925629 1 True 9
23 RandomForest_hs=1_n1000 0.899916 0.904823 0.292249 0.153203 9.213101 0.292249 0.153203 9.213101 1 True 7
24 RandomForest_hs=500_n10 0.899908 0.899296 0.011997 0.008104 0.118062 0.011997 0.008104 0.118062 1 True 36
25 RandomForest_hs=1_n300 0.899902 0.903704 0.097325 0.054082 2.725383 0.097325 0.054082 2.725383 1 True 8
26 RandomForest_hs=10_n20 0.898988 0.899197 0.017777 0.009672 0.217858 0.017777 0.009672 0.217858 1 True 17
27 RandomForest_hs=0.1_n100 0.898405 0.901961 0.038886 0.019792 1.188789 0.038886 0.019792 1.188789 1 True 39
28 RandomForest_hs=0.1_n1000 0.898272 0.903869 0.300892 0.151734 11.785147 0.300892 0.151734 11.785147 1 True 37
29 RandomForest_hs=0.1_n300 0.898254 0.903211 0.102140 0.059019 3.490166 0.102140 0.059019 3.490166 1 True 38
30 RandomForest_og_n1000 0.898019 0.903721 0.300054 0.163219 2.121544 0.300054 0.163219 2.121544 1 True 1
31 RandomForest_og_n300 0.897952 0.903145 0.098421 0.053610 0.375765 0.098421 0.053610 0.375765 1 True 2
32 RandomForest_og_n100 0.897658 0.901533 0.044796 0.021702 0.158356 0.044796 0.021702 0.158356 1 True 3
33 RandomForest_hs=1_n40 0.895661 0.894657 0.019914 0.014062 0.378873 0.019914 0.014062 0.378873 1 True 10
34 RandomForest_hs=10_n10 0.894480 0.893670 0.011885 0.008050 0.107222 0.011885 0.008050 0.107222 1 True 18
35 RandomForest_hs=0.1_n40 0.894129 0.893604 0.018879 0.011820 0.469591 0.018879 0.011820 0.469591 1 True 40
36 RandomForest_og_n40 0.891790 0.891910 0.018995 0.014379 0.073364 0.018995 0.014379 0.073364 1 True 4
37 RandomForest_hs=1_n20 0.890021 0.896532 0.013560 0.010845 0.239654 0.013560 0.010845 0.239654 1 True 11
38 RandomForest_hs=0.1_n20 0.888886 0.896105 0.013955 0.010772 0.243502 0.013955 0.010772 0.243502 1 True 41
39 RandomForest_og_n20 0.883450 0.890907 0.013461 0.010590 0.045405 0.013461 0.010590 0.045405 1 True 5
40 RandomForest_hs=1_n10 0.883443 0.888012 0.010171 0.008007 0.111472 0.010171 0.008007 0.111472 1 True 12
41 RandomForest_hs=0.1_n10 0.882777 0.887584 0.009979 0.007687 0.138142 0.009979 0.007687 0.138142 1 True 42
42 RandomForest_og_n10 0.869862 0.874079 0.010571 0.008906 0.030527 0.010571 0.008906 0.030527 1 True 6
Similar findings with Extra Trees!
model score_test score_val pred_time_test pred_time_val fit_time pred_time_test_marginal pred_time_val_marginal fit_time_marginal stack_level can_infer fit_order
0 ExtraTrees_hs=50_n1000 0.898862 0.895578 0.360272 0.152184 12.766335 0.360272 0.152184 12.766335 1 True 19
1 ExtraTrees_hs=10_n1000 0.898727 0.898638 0.375004 0.154642 13.135831 0.375004 0.154642 13.135831 1 True 13
2 ExtraTrees_hs=50_n300 0.898286 0.896006 0.126808 0.057546 3.828264 0.126808 0.057546 3.828264 1 True 20
3 ExtraTrees_hs=50_n100 0.898091 0.895545 0.050484 0.020662 1.293077 0.050484 0.020662 1.293077 1 True 21
4 ExtraTrees_hs=10_n300 0.897934 0.898802 0.129800 0.050548 3.923364 0.129800 0.050548 3.923364 1 True 14
5 ExtraTrees_hs=10_n100 0.897840 0.897848 0.050855 0.021138 1.308959 0.050855 0.021138 1.308959 1 True 15
6 ExtraTrees_hs=100_n1000 0.897263 0.894295 0.363426 0.148167 12.866699 0.363426 0.148167 12.866699 1 True 25
7 ExtraTrees_hs=100_n300 0.896848 0.894756 0.129441 0.049175 3.806397 0.129441 0.049175 3.806397 1 True 26
8 ExtraTrees_hs=50_n40 0.896566 0.893868 0.024298 0.012653 0.611947 0.024298 0.012653 0.611947 1 True 22
9 WeightedEnsemble_L2 0.896430 0.900053 1.370685 0.559450 50.761817 0.003386 0.000914 0.446805 2 True 43
10 ExtraTrees_hs=100_n100 0.896420 0.893868 0.046713 0.026253 1.349054 0.046713 0.026253 1.349054 1 True 27
11 ExtraTrees_hs=10_n40 0.895319 0.892486 0.025234 0.012230 0.530321 0.025234 0.012230 0.530321 1 True 16
12 ExtraTrees_hs=100_n40 0.895164 0.893341 0.022953 0.013198 0.523787 0.022953 0.013198 0.523787 1 True 28
13 ExtraTrees_hs=1_n1000 0.894024 0.898441 0.351593 0.148842 13.060638 0.351593 0.148842 13.060638 1 True 7
14 ExtraTrees_hs=1_n300 0.892950 0.897848 0.125843 0.050800 3.828052 0.125843 0.050800 3.828052 1 True 8
15 ExtraTrees_hs=1_n100 0.892657 0.895743 0.052110 0.021437 1.310609 0.052110 0.021437 1.310609 1 True 9
16 ExtraTrees_hs=0.1_n1000 0.891799 0.897914 0.385060 0.153704 16.367127 0.385060 0.153704 16.367127 1 True 37
17 ExtraTrees_og_n1000 0.891447 0.897816 0.372907 0.220728 2.279349 0.372907 0.220728 2.279349 1 True 1
18 ExtraTrees_hs=0.1_n300 0.890612 0.896532 0.136989 0.050140 4.911353 0.136989 0.050140 4.911353 1 True 38
19 ExtraTrees_hs=0.1_n100 0.890299 0.893999 0.050237 0.020473 1.644582 0.050237 0.020473 1.644582 1 True 39
20 ExtraTrees_hs=50_n10 0.890179 0.895940 0.011973 0.007939 0.142516 0.011973 0.007939 0.142516 1 True 24
21 ExtraTrees_og_n300 0.890174 0.895924 0.124822 0.050993 0.383026 0.124822 0.050993 0.383026 1 True 2
22 ExtraTrees_hs=50_n20 0.890151 0.893078 0.015931 0.009639 0.273224 0.015931 0.009639 0.273224 1 True 23
23 ExtraTrees_hs=500_n1000 0.890060 0.888834 0.361120 0.149288 12.790663 0.361120 0.149288 12.790663 1 True 31
24 ExtraTrees_hs=500_n300 0.889899 0.889657 0.188049 0.047488 3.835120 0.188049 0.047488 3.835120 1 True 32
25 ExtraTrees_og_n100 0.889513 0.893440 0.049970 0.024512 0.153703 0.049970 0.024512 0.153703 1 True 3
26 ExtraTrees_hs=10_n20 0.889493 0.894657 0.016829 0.010440 0.279750 0.016829 0.010440 0.279750 1 True 17
27 ExtraTrees_hs=500_n100 0.888305 0.888999 0.090708 0.020563 1.297165 0.090708 0.020563 1.297165 1 True 33
28 ExtraTrees_hs=1_n40 0.888270 0.885972 0.025700 0.013705 0.542178 0.025700 0.013705 0.542178 1 True 10
29 ExtraTrees_hs=100_n10 0.888002 0.893604 0.011640 0.009796 0.155926 0.011640 0.009796 0.155926 1 True 30
30 ExtraTrees_hs=10_n10 0.887781 0.893868 0.012549 0.007775 0.151623 0.012549 0.007775 0.151623 1 True 18
31 ExtraTrees_hs=500_n40 0.887740 0.889031 0.024130 0.012207 0.537571 0.024130 0.012207 0.537571 1 True 34
32 ExtraTrees_hs=100_n20 0.887497 0.890413 0.015986 0.009240 0.286911 0.015986 0.009240 0.286911 1 True 29
33 ExtraTrees_hs=0.1_n40 0.885989 0.883340 0.024764 0.011959 0.671531 0.024764 0.011959 0.671531 1 True 40
34 ExtraTrees_og_n40 0.883635 0.881021 0.025266 0.014620 0.075054 0.025266 0.014620 0.075054 1 True 4
35 ExtraTrees_hs=1_n20 0.880430 0.889196 0.017302 0.009496 0.323042 0.017302 0.009496 0.323042 1 True 11
36 ExtraTrees_hs=0.1_n20 0.878517 0.888044 0.016657 0.009894 0.348929 0.016657 0.009894 0.348929 1 True 41
37 ExtraTrees_hs=500_n20 0.876458 0.881596 0.017727 0.009557 0.295510 0.017727 0.009557 0.295510 1 True 35
38 ExtraTrees_hs=500_n10 0.876298 0.883208 0.013577 0.007943 0.155019 0.013577 0.007943 0.155019 1 True 36
39 ExtraTrees_hs=1_n10 0.875740 0.881333 0.012778 0.008113 0.143919 0.012778 0.008113 0.143919 1 True 12
40 ExtraTrees_hs=0.1_n10 0.874861 0.880445 0.012299 0.007819 0.180453 0.012299 0.007819 0.180453 1 True 42
41 ExtraTrees_og_n20 0.872678 0.882090 0.016654 0.013582 0.046497 0.016654 0.013582 0.046497 1 True 5
42 ExtraTrees_og_n10 0.861819 0.862235 0.012075 0.013637 0.047344 0.012075 0.013637 0.047344 1 True 6