Speeding up HS with LOOCV

Open · Innixma opened this issue 1 year ago · 5 comments

Hello, thanks again for the great library!

I'm interested in applying HSTree to RandomForest and ExtraTrees models.

According to the documentation, I can pass a random forest object as the estimator_ argument; however, this raises an error when I try to fit:

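        from sklearn.ensemble import RandomForestClassifier
        from imodels.tree.hierarchical_shrinkage import HSTreeClassifier

        # X, y: the training features and labels (defined elsewhere)
        rf = RandomForestClassifier()
        # Wrap the forest so that hierarchical shrinkage is applied post hoc
        model = HSTreeClassifier(estimator_=rf)
        model = model.fit(X, y)  # raises the AttributeError below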
  File "/Users/neerick/workspace/code/autogluon/tabular/src/autogluon/tabular/models/rf/rf_model.py", line 232, in _fit
    model = model.fit(X, y, sample_weight=sample_weight)
  File "/Users/neerick/workspace/virtual/autogluon/lib/python3.8/site-packages/imodels/tree/hierarchical_shrinkage.py", line 64, in fit
    self.complexity_ = compute_tree_complexity(self.estimator_.tree_)
AttributeError: 'RandomForestClassifier' object has no attribute 'tree_'

https://github.com/csinva/imodels/blob/master/imodels/tree/hierarchical_shrinkage.py

I don't see any tutorial or documentation for creating a random forest or extra trees model via HSTree, but the paper mentions that this is possible and gives good results. Could the maintainers point me to an example or tutorial for this?
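For context, here is our reading of the shrinkage rule described in the HS paper. Take a sample x whose root-to-leaf path in a fitted tree is t_0, t_1, ..., t_L. Let N(t) be the number of training samples in node t and E_t[y] their mean response. HS replaces the tree's prediction with the shrunk version below, where λ is reg_param. As we understand it, a forest of M trees is handled by applying the same rule to each tree and averaging the shrunk predictions:

```latex
\hat{f}_\lambda(x) = \mathbb{E}_{t_0}[y] + \sum_{l=1}^{L} \frac{\mathbb{E}_{t_l}[y] - \mathbb{E}_{t_{l-1}}[y]}{1 + \lambda / N(t_{l-1})},
\qquad
\hat{f}^{\mathrm{RF}}_\lambda(x) = \frac{1}{M} \sum_{m=1}^{M} \hat{f}^{(m)}_\lambda(x)
```

With λ = 0 the sum telescopes to E_{t_L}[y], the ordinary leaf prediction. Larger λ shrinks each parent-to-child increment, most strongly where the parent node holds few samples.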

Thanks!

Innixma, Jul 27 '22 21:07

Thanks for the question, Nick! Indeed, we should have publicly released this with some documentation by now; we'll get to it shortly!

@aagarwal1996 - can you add in the HS + RF / ExtraTrees code with a little doc?

Best, Chandan

csinva, Jul 28 '22 20:07

Just fixed it in this commit :)

We'll have to bump the version to 1.3.3 to get it to work, but the code you were using above should work now. I also added a short snippet to the docs here.
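The commit itself is not reproduced here. The original error arose because a single sklearn tree exposes `tree_`, while a fitted forest keeps its trees in `estimators_`. Below is a rough sketch of the idea behind the fix. It assumes HS is applied to each tree of the forest independently and the shrunk trees are then averaged. `Node`, `hs_predict`, and `hs_forest_predict` are hypothetical names, and `Node` is a plain-Python stand-in for sklearn's `tree_` arrays:

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence


@dataclass
class Node:
    """Hypothetical minimal tree node (a stand-in for sklearn's tree_ arrays)."""
    value: float                    # E_t[y]: mean response of training samples in node t
    n_samples: int                  # N(t): number of training samples in node t
    feature: int = -1               # split feature (unused for leaves)
    threshold: float = 0.0          # go left if x[feature] <= threshold
    left: Optional["Node"] = None   # internal nodes are assumed to have both children
    right: Optional["Node"] = None


def hs_predict(root: Node, x: Sequence[float], reg_param: float) -> float:
    """Hierarchically shrunk prediction of a single tree for one sample."""
    pred = root.value
    node = root
    while node.left is not None:  # descend until we reach a leaf
        child = node.left if x[node.feature] <= node.threshold else node.right
        # Shrink the parent-to-child increment by 1 / (1 + reg_param / N(parent))
        pred += (child.value - node.value) / (1.0 + reg_param / node.n_samples)
        node = child
    return pred


def hs_forest_predict(trees: List[Node], x: Sequence[float], reg_param: float) -> float:
    """Shrink every tree independently, then average the shrunk predictions."""
    return sum(hs_predict(t, x, reg_param) for t in trees) / len(trees)
```

For example, take a stump whose root holds 10 samples with mean 0.5, split into two leaves of 5 samples with means 0.2 and 0.8. With `reg_param=10`, the left-leaf prediction moves from 0.2 to 0.35: the increment of -0.3 is halved because 1 + 10/10 = 2.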

Cheers, Chandan

csinva, Jul 28 '22 21:07

Incredible. It works now and is showing very strong results on the Adult dataset (I plan to test more). For example, n_estimators=40 with HSTreeClassifierCV gets significantly better test scores than RandomForestClassifier(n_estimators=300). The API is perfect, and I can add it to AutoGluon with fewer than 10 lines of code.

A question related to the ICML oral presentation and paper: I recall a mention of efficient leave-one-out CV (LOOCV) for optimizing reg_param. Is this implemented, or do I need to pay the cost of HSTreeClassifierCV to get the optimal reg_param value?

Innixma, Jul 29 '22 02:07

Glad to hear it's working for you 😄

This isn't implemented yet, but it should be easy to do; we'll get it done soon, especially if you're putting this into AutoGluon!
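The paper's LOOCV procedure is not reproduced here, and the following is not necessarily what it refers to. It is a minimal sketch of one cheap variant that holds the tree structure fixed: the splits are not refit, and only node means and counts are recomputed without sample i. Under that assumption, removing sample i changes only the nodes on its own path. Each such node gets count N - 1 and mean (N·mean - y_i)/(N - 1), so exact fixed-structure LOO predictions cost O(depth) per sample and per reg_param value, with no refitting. `Node`, `loo_hs_predictions`, and `select_reg_param` are hypothetical names, and `Node` is a plain-Python stand-in for a fitted tree whose `value` and `n_samples` were computed on the same `X, y` (at least two samples):

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence


@dataclass
class Node:
    """Hypothetical tree node; value and n_samples come from the training set X, y."""
    value: float                    # mean of y over training samples in this node
    n_samples: int                  # number of training samples in this node
    feature: int = -1
    threshold: float = 0.0          # go left if x[feature] <= threshold
    left: Optional["Node"] = None
    right: Optional["Node"] = None


def _path(root: Node, x: Sequence[float]) -> List[Node]:
    """Root-to-leaf path followed by sample x."""
    path = [root]
    while path[-1].left is not None:
        node = path[-1]
        path.append(node.left if x[node.feature] <= node.threshold else node.right)
    return path


def loo_hs_predictions(root: Node, X, y, reg_param: float) -> List[float]:
    """Leave-one-out HS predictions with the tree structure held fixed."""
    preds = []
    for xi, yi in zip(X, y):
        # Remove sample i from every node on its path: count N-1, mean (N*mean - y_i)/(N-1)
        stats = []
        for node in _path(root, xi):
            n = node.n_samples - 1
            if n == 0:  # node would be empty without sample i, and so would its descendants
                break
            stats.append(((node.n_samples * node.value - yi) / n, n))
        pred = stats[0][0]
        for (m_parent, n_parent), (m_child, _) in zip(stats, stats[1:]):
            pred += (m_child - m_parent) / (1.0 + reg_param / n_parent)
        preds.append(pred)
    return preds


def select_reg_param(root: Node, X, y, reg_params: Sequence[float]) -> float:
    """Pick the reg_param with the lowest fixed-structure leave-one-out MSE."""
    def loo_mse(lam: float) -> float:
        preds = loo_hs_predictions(root, X, y, lam)
        return sum((p - t) ** 2 for p, t in zip(preds, y)) / len(y)
    return min(reg_params, key=loo_mse)
```

Because the splits were chosen using sample i, this estimate is likely somewhat optimistic compared with full LOOCV that refits the tree. It only tunes the shrinkage strength. For a forest, the same loop could be run per tree and the per-tree predictions averaged before computing the error.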

@aagarwal1996 @yanshuotan @OmerRonen - someone wanna take this on?

csinva, Jul 29 '22 02:07

Here are the results of a hyperparameter grid on Adult, subsampled to 2,000 training rows and sorted by test score (AUC). Times are in seconds.

Naming: og = vanilla sklearn; hs=10_n300 = reg_param=10, n_estimators=300.

Amazingly, HS with reg_param 50, 100, or 500 and only 10 estimators gets a better test score than OG with 1000 estimators!

                        model  score_test  score_val  pred_time_test  pred_time_val   fit_time  pred_time_test_marginal  pred_time_val_marginal  fit_time_marginal  stack_level  can_infer  fit_order
0     RandomForest_hs=50_n300    0.906977   0.904691        0.095684       0.050235   2.733138                 0.095684                0.050235           2.733138            1       True         20
1     RandomForest_hs=50_n100    0.906954   0.905086        0.037076       0.087729   0.922209                 0.037076                0.087729           0.922209            1       True         21
2    RandomForest_hs=50_n1000    0.906939   0.905020        0.287926       0.150001   9.144297                 0.287926                0.150001           9.144297            1       True         19
3    RandomForest_hs=100_n300    0.906854   0.905185        0.098557       0.050372   2.767945                 0.098557                0.050372           2.767945            1       True         26
4   RandomForest_hs=100_n1000    0.906818   0.905020        0.297747       0.155764   9.122136                 0.297747                0.155764           9.122136            1       True         25
5    RandomForest_hs=100_n100    0.906746   0.904823        0.040018       0.022333   0.951915                 0.040018                0.022333           0.951915            1       True         27
6         WeightedEnsemble_L2    0.906464   0.906106        0.822571       0.493818  25.186706                 0.003169                0.000933           0.448992            2       True         43
7      RandomForest_hs=50_n40    0.905625   0.903869        0.019060       0.012893   0.386387                 0.019060                0.012893           0.386387            1       True         22
8     RandomForest_hs=100_n40    0.905619   0.904691        0.019219       0.015171   0.412162                 0.019219                0.015171           0.412162            1       True         28
9     RandomForest_hs=10_n100    0.904859   0.904955        0.038016       0.021933   0.915743                 0.038016                0.021933           0.915743            1       True         15
10    RandomForest_hs=10_n300    0.904659   0.905481        0.098096       0.049020   2.781127                 0.098096                0.049020           2.781127            1       True         14
11   RandomForest_hs=10_n1000    0.904594   0.904659        0.296132       0.150198   9.204939                 0.296132                0.150198           9.204939            1       True         13
12    RandomForest_hs=100_n20    0.904273   0.901533        0.013146       0.013468   0.245283                 0.013146                0.013468           0.245283            1       True         29
13     RandomForest_hs=50_n20    0.903543   0.901336        0.012793       0.009559   0.201715                 0.012793                0.009559           0.201715            1       True         23
14  RandomForest_hs=500_n1000    0.903443   0.903310        0.292724       0.153926   9.449500                 0.292724                0.153926           9.449500            1       True         31
15   RandomForest_hs=500_n300    0.903273   0.904033        0.098899       0.058785   2.789253                 0.098899                0.058785           2.789253            1       True         32
16     RandomForest_hs=10_n40    0.902558   0.899888        0.020626       0.015472   0.387508                 0.020626                0.015472           0.387508            1       True         16
17   RandomForest_hs=500_n100    0.902558   0.904066        0.055166       0.022750   0.920178                 0.055166                0.022750           0.920178            1       True         33
18    RandomForest_hs=500_n40    0.901187   0.901928        0.038185       0.013044   0.375549                 0.038185                0.013044           0.375549            1       True         34
19    RandomForest_hs=500_n20    0.901066   0.899888        0.019732       0.012050   0.197903                 0.019732                0.012050           0.197903            1       True         35
20    RandomForest_hs=100_n10    0.901057   0.897684        0.010498       0.008219   0.132994                 0.010498                0.008219           0.132994            1       True         30
21     RandomForest_hs=50_n10    0.900010   0.896697        0.010349       0.007744   0.106932                 0.010349                0.007744           0.106932            1       True         24
22     RandomForest_hs=1_n100    0.899934   0.903639        0.038109       0.021423   0.925629                 0.038109                0.021423           0.925629            1       True          9
23    RandomForest_hs=1_n1000    0.899916   0.904823        0.292249       0.153203   9.213101                 0.292249                0.153203           9.213101            1       True          7
24    RandomForest_hs=500_n10    0.899908   0.899296        0.011997       0.008104   0.118062                 0.011997                0.008104           0.118062            1       True         36
25     RandomForest_hs=1_n300    0.899902   0.903704        0.097325       0.054082   2.725383                 0.097325                0.054082           2.725383            1       True          8
26     RandomForest_hs=10_n20    0.898988   0.899197        0.017777       0.009672   0.217858                 0.017777                0.009672           0.217858            1       True         17
27   RandomForest_hs=0.1_n100    0.898405   0.901961        0.038886       0.019792   1.188789                 0.038886                0.019792           1.188789            1       True         39
28  RandomForest_hs=0.1_n1000    0.898272   0.903869        0.300892       0.151734  11.785147                 0.300892                0.151734          11.785147            1       True         37
29   RandomForest_hs=0.1_n300    0.898254   0.903211        0.102140       0.059019   3.490166                 0.102140                0.059019           3.490166            1       True         38
30      RandomForest_og_n1000    0.898019   0.903721        0.300054       0.163219   2.121544                 0.300054                0.163219           2.121544            1       True          1
31       RandomForest_og_n300    0.897952   0.903145        0.098421       0.053610   0.375765                 0.098421                0.053610           0.375765            1       True          2
32       RandomForest_og_n100    0.897658   0.901533        0.044796       0.021702   0.158356                 0.044796                0.021702           0.158356            1       True          3
33      RandomForest_hs=1_n40    0.895661   0.894657        0.019914       0.014062   0.378873                 0.019914                0.014062           0.378873            1       True         10
34     RandomForest_hs=10_n10    0.894480   0.893670        0.011885       0.008050   0.107222                 0.011885                0.008050           0.107222            1       True         18
35    RandomForest_hs=0.1_n40    0.894129   0.893604        0.018879       0.011820   0.469591                 0.018879                0.011820           0.469591            1       True         40
36        RandomForest_og_n40    0.891790   0.891910        0.018995       0.014379   0.073364                 0.018995                0.014379           0.073364            1       True          4
37      RandomForest_hs=1_n20    0.890021   0.896532        0.013560       0.010845   0.239654                 0.013560                0.010845           0.239654            1       True         11
38    RandomForest_hs=0.1_n20    0.888886   0.896105        0.013955       0.010772   0.243502                 0.013955                0.010772           0.243502            1       True         41
39        RandomForest_og_n20    0.883450   0.890907        0.013461       0.010590   0.045405                 0.013461                0.010590           0.045405            1       True          5
40      RandomForest_hs=1_n10    0.883443   0.888012        0.010171       0.008007   0.111472                 0.010171                0.008007           0.111472            1       True         12
41    RandomForest_hs=0.1_n10    0.882777   0.887584        0.009979       0.007687   0.138142                 0.009979                0.007687           0.138142            1       True         42
42        RandomForest_og_n10    0.869862   0.874079        0.010571       0.008906   0.030527                 0.010571                0.008906           0.030527            1       True          6
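To check the 10-vs-1000-estimator claim programmatically, here is a small sketch. `parse_model_name` is a hypothetical helper that splits leaderboard names following the naming convention above (`og` or `hs=<reg_param>`, then `n<n_estimators>`). It is applied to a handful of (model, score_test) rows copied from the RandomForest table:

```python
import re
from typing import NamedTuple, Optional


class ModelConfig(NamedTuple):
    family: str                 # e.g. "RandomForest" or "ExtraTrees"
    reg_param: Optional[float]  # None for the unshrunk "og" baseline
    n_estimators: int


# Matches names such as "RandomForest_hs=10_n300" or "RandomForest_og_n1000"
_NAME = re.compile(r"(?P<family>[A-Za-z]+)_(?:og|hs=(?P<reg>[0-9.]+))_n(?P<n>[0-9]+)")


def parse_model_name(name: str) -> Optional[ModelConfig]:
    m = _NAME.fullmatch(name)
    if m is None:
        return None  # e.g. "WeightedEnsemble_L2"
    reg = m.group("reg")
    return ModelConfig(m.group("family"), None if reg is None else float(reg), int(m.group("n")))


# (model, score_test) pairs copied from the RandomForest table above
rows = [
    ("RandomForest_og_n1000", 0.898019),
    ("RandomForest_hs=0.1_n10", 0.882777),
    ("RandomForest_hs=1_n10", 0.883443),
    ("RandomForest_hs=10_n10", 0.894480),
    ("RandomForest_hs=50_n10", 0.900010),
    ("RandomForest_hs=100_n10", 0.901057),
    ("RandomForest_hs=500_n10", 0.899908),
    ("WeightedEnsemble_L2", 0.906464),
]

configs = [(parse_model_name(name), score) for name, score in rows]
baseline = next(s for cfg, s in configs if cfg == ModelConfig("RandomForest", None, 1000))
# reg_param values whose 10-estimator HS model beats the 1000-estimator baseline
winners = sorted(cfg.reg_param for cfg, s in configs
                 if cfg is not None and cfg.n_estimators == 10 and s > baseline)
print(winners)  # [50.0, 100.0, 500.0]
```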

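Similar findings with Extra Trees, though slightly weaker: here HS needs 40 estimators (with reg_param 10, 50, or 100), rather than 10, to beat OG with 1000 estimators.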

                      model  score_test  score_val  pred_time_test  pred_time_val   fit_time  pred_time_test_marginal  pred_time_val_marginal  fit_time_marginal  stack_level  can_infer  fit_order
0    ExtraTrees_hs=50_n1000    0.898862   0.895578        0.360272       0.152184  12.766335                 0.360272                0.152184          12.766335            1       True         19
1    ExtraTrees_hs=10_n1000    0.898727   0.898638        0.375004       0.154642  13.135831                 0.375004                0.154642          13.135831            1       True         13
2     ExtraTrees_hs=50_n300    0.898286   0.896006        0.126808       0.057546   3.828264                 0.126808                0.057546           3.828264            1       True         20
3     ExtraTrees_hs=50_n100    0.898091   0.895545        0.050484       0.020662   1.293077                 0.050484                0.020662           1.293077            1       True         21
4     ExtraTrees_hs=10_n300    0.897934   0.898802        0.129800       0.050548   3.923364                 0.129800                0.050548           3.923364            1       True         14
5     ExtraTrees_hs=10_n100    0.897840   0.897848        0.050855       0.021138   1.308959                 0.050855                0.021138           1.308959            1       True         15
6   ExtraTrees_hs=100_n1000    0.897263   0.894295        0.363426       0.148167  12.866699                 0.363426                0.148167          12.866699            1       True         25
7    ExtraTrees_hs=100_n300    0.896848   0.894756        0.129441       0.049175   3.806397                 0.129441                0.049175           3.806397            1       True         26
8      ExtraTrees_hs=50_n40    0.896566   0.893868        0.024298       0.012653   0.611947                 0.024298                0.012653           0.611947            1       True         22
9       WeightedEnsemble_L2    0.896430   0.900053        1.370685       0.559450  50.761817                 0.003386                0.000914           0.446805            2       True         43
10   ExtraTrees_hs=100_n100    0.896420   0.893868        0.046713       0.026253   1.349054                 0.046713                0.026253           1.349054            1       True         27
11     ExtraTrees_hs=10_n40    0.895319   0.892486        0.025234       0.012230   0.530321                 0.025234                0.012230           0.530321            1       True         16
12    ExtraTrees_hs=100_n40    0.895164   0.893341        0.022953       0.013198   0.523787                 0.022953                0.013198           0.523787            1       True         28
13    ExtraTrees_hs=1_n1000    0.894024   0.898441        0.351593       0.148842  13.060638                 0.351593                0.148842          13.060638            1       True          7
14     ExtraTrees_hs=1_n300    0.892950   0.897848        0.125843       0.050800   3.828052                 0.125843                0.050800           3.828052            1       True          8
15     ExtraTrees_hs=1_n100    0.892657   0.895743        0.052110       0.021437   1.310609                 0.052110                0.021437           1.310609            1       True          9
16  ExtraTrees_hs=0.1_n1000    0.891799   0.897914        0.385060       0.153704  16.367127                 0.385060                0.153704          16.367127            1       True         37
17      ExtraTrees_og_n1000    0.891447   0.897816        0.372907       0.220728   2.279349                 0.372907                0.220728           2.279349            1       True          1
18   ExtraTrees_hs=0.1_n300    0.890612   0.896532        0.136989       0.050140   4.911353                 0.136989                0.050140           4.911353            1       True         38
19   ExtraTrees_hs=0.1_n100    0.890299   0.893999        0.050237       0.020473   1.644582                 0.050237                0.020473           1.644582            1       True         39
20     ExtraTrees_hs=50_n10    0.890179   0.895940        0.011973       0.007939   0.142516                 0.011973                0.007939           0.142516            1       True         24
21       ExtraTrees_og_n300    0.890174   0.895924        0.124822       0.050993   0.383026                 0.124822                0.050993           0.383026            1       True          2
22     ExtraTrees_hs=50_n20    0.890151   0.893078        0.015931       0.009639   0.273224                 0.015931                0.009639           0.273224            1       True         23
23  ExtraTrees_hs=500_n1000    0.890060   0.888834        0.361120       0.149288  12.790663                 0.361120                0.149288          12.790663            1       True         31
24   ExtraTrees_hs=500_n300    0.889899   0.889657        0.188049       0.047488   3.835120                 0.188049                0.047488           3.835120            1       True         32
25       ExtraTrees_og_n100    0.889513   0.893440        0.049970       0.024512   0.153703                 0.049970                0.024512           0.153703            1       True          3
26     ExtraTrees_hs=10_n20    0.889493   0.894657        0.016829       0.010440   0.279750                 0.016829                0.010440           0.279750            1       True         17
27   ExtraTrees_hs=500_n100    0.888305   0.888999        0.090708       0.020563   1.297165                 0.090708                0.020563           1.297165            1       True         33
28      ExtraTrees_hs=1_n40    0.888270   0.885972        0.025700       0.013705   0.542178                 0.025700                0.013705           0.542178            1       True         10
29    ExtraTrees_hs=100_n10    0.888002   0.893604        0.011640       0.009796   0.155926                 0.011640                0.009796           0.155926            1       True         30
30     ExtraTrees_hs=10_n10    0.887781   0.893868        0.012549       0.007775   0.151623                 0.012549                0.007775           0.151623            1       True         18
31    ExtraTrees_hs=500_n40    0.887740   0.889031        0.024130       0.012207   0.537571                 0.024130                0.012207           0.537571            1       True         34
32    ExtraTrees_hs=100_n20    0.887497   0.890413        0.015986       0.009240   0.286911                 0.015986                0.009240           0.286911            1       True         29
33    ExtraTrees_hs=0.1_n40    0.885989   0.883340        0.024764       0.011959   0.671531                 0.024764                0.011959           0.671531            1       True         40
34        ExtraTrees_og_n40    0.883635   0.881021        0.025266       0.014620   0.075054                 0.025266                0.014620           0.075054            1       True          4
35      ExtraTrees_hs=1_n20    0.880430   0.889196        0.017302       0.009496   0.323042                 0.017302                0.009496           0.323042            1       True         11
36    ExtraTrees_hs=0.1_n20    0.878517   0.888044        0.016657       0.009894   0.348929                 0.016657                0.009894           0.348929            1       True         41
37    ExtraTrees_hs=500_n20    0.876458   0.881596        0.017727       0.009557   0.295510                 0.017727                0.009557           0.295510            1       True         35
38    ExtraTrees_hs=500_n10    0.876298   0.883208        0.013577       0.007943   0.155019                 0.013577                0.007943           0.155019            1       True         36
39      ExtraTrees_hs=1_n10    0.875740   0.881333        0.012778       0.008113   0.143919                 0.012778                0.008113           0.143919            1       True         12
40    ExtraTrees_hs=0.1_n10    0.874861   0.880445        0.012299       0.007819   0.180453                 0.012299                0.007819           0.180453            1       True         42
41        ExtraTrees_og_n20    0.872678   0.882090        0.016654       0.013582   0.046497                 0.016654                0.013582           0.046497            1       True          5
42        ExtraTrees_og_n10    0.861819   0.862235        0.012075       0.013637   0.047344                 0.012075                0.013637           0.047344            1       True          6

Innixma, Jul 29 '22 03:07