NOCATS: Categorical splits for tree-based learners (ctnd.)

Open adrinjalali opened this issue 6 years ago • 56 comments

This PR continues the work of #4899. For now I've merged master into the PR and made it compile and the tests run. There are several issues which need to be fixed; the list will be updated as I encounter them. Also, not all of these items are necessarily open: I have only collected them from the comments on the original PR, and need to make sure they're either already addressed or address them.

  • merge master into the PR (done)
  • sparse tests pass (done)
    • The code is supposed to be the same as the status quo implementation if categories are not passed. But right now the tests related to sparse data fail.
    • EDIT: The tests pass if we compare floats with almost_equal
  • LabelEncoder -> CategoricalEncoder (done)
    • Preprocessing is not a part of NOCATS anymore.
  • Is maximum random generations 20 or 40 (done)
    • It's actually 60
  • Don't quantize features automatically (done)
    • Doesn't happen anymore: https://github.com/scikit-learn/scikit-learn/pull/4899#issuecomment-271504258
  • check the category count limits for given data. (done)
  • add a benchmark
    • done. Results: https://github.com/scikit-learn/scikit-learn/pull/12866#issuecomment-453856876
  • add tests (right now only invalid input are tested)
    • tree/tests done
    • ensemble/tests done
  • benchmark against master
  • add an example with plots
  • check numpy upgrade related issues (we've upgraded our numpy requirement in the meantime)
  • run some benchmarks with a simple integer coding of the features (with arbitrary ordering)
  • add cat_split to NODE_DTYPE once joblib.hash can handle it (padded struct)
    • joblib issue: joblib/joblib#826

Closes #4899

Future Work: These are the possible future work items we already know of (i.e. outside the scope of this PR):

  • Heuristic methods to allow fast Breiman-like training for multi-class classification
  • export to graphviz
  • One-hot emulation using the NOCATS machinery
  • support sparse input
  • handle categories as their unique values instead of [0, max(feature)]
    • This is to be consistent with our encoders' behavior
    • moved this to future work per https://github.com/scikit-learn/scikit-learn/pull/12866#issuecomment-455021204

P.S. I moved away from "task list" due to the extremely buggy interface when used in combination with editing the post, which I'm extensively doing to keep it easy for us to keep up with the status.

adrinjalali avatar Dec 26 '18 12:12 adrinjalali

Wow. Good on you for taking this on!

jnothman avatar Dec 26 '18 12:12 jnothman

I̶ ̶a̶s̶s̶u̶m̶e̶ ̶t̶h̶e̶ ̶a̶p̶p̶v̶e̶y̶o̶r̶ ̶f̶a̶i̶l̶u̶r̶e̶ ̶i̶s̶ ̶u̶n̶r̶e̶l̶a̶t̶e̶d̶ ̶t̶o̶ ̶t̶h̶i̶s̶ ̶P̶R̶ ̶I̶ ̶s̶u̶p̶p̶o̶s̶e̶.̶

adrinjalali avatar Dec 26 '18 14:12 adrinjalali

Just realized there are no tests covering the categorical feature, and it's gone undetected because codecov doesn't cover Cython code, I guess.

UPDATE: It can, and #12872 should do it, plus the directives added in this PR for the tree code.

adrinjalali avatar Dec 27 '18 09:12 adrinjalali

Question: to test NOCATS, should I

  • synthesize input/output in the tests.
  • use a dataset from openml which has some categorical data (like the amazon employee one)
  • copy a part of such dataset somewhere in the repo and use that
  • add an sklearn.datasets.make_categorical or something and use that.

adrinjalali avatar Jan 12 '19 16:01 adrinjalali

Results of the benchmark on the amazon employee dataset with a 5-fold StratifiedKFold:

                                                  train time mean  train time std  test time mean  test time std  auc mean   auc std
name                                                                                                                                
(RandomForestClassifier, truncated(2), NOCATS)           0.011249        0.000308        0.000964       0.000037  0.550338  0.066464
(ExtraTreesClassifier, truncated(2), NOCATS)             0.016150        0.001321        0.001583       0.000493  0.540433  0.087579
(RandomForestClassifier, truncated(2), One-hot)          0.021381        0.002790        0.002191       0.000651  0.534658  0.091971
(ExtraTreesClassifier, truncated(2), One-hot)            0.026961        0.001253        0.002737       0.000229  0.553065  0.073866
(RandomForestClassifier, truncated(3), NOCATS)           0.093635        0.014216        0.003194       0.000236  0.527636  0.014012
(ExtraTreesClassifier, truncated(3), NOCATS)             0.086353        0.004150        0.003881       0.000571  0.596458  0.028335
(RandomForestClassifier, truncated(3), One-hot)          0.284854        0.047508        0.004015       0.000867  0.530455  0.015537
(ExtraTreesClassifier, truncated(3), One-hot)            0.274569        0.026760        0.002963       0.000131  0.530211  0.016491
(RandomForestClassifier, truncated(4), NOCATS)           0.267862        0.051063        0.006633       0.001659  0.582438  0.013270
(ExtraTreesClassifier, truncated(4), NOCATS)             0.233336        0.028149        0.009689       0.002523  0.694099  0.025380
(RandomForestClassifier, truncated(4), One-hot)          1.098896        0.128780        0.005357       0.000866  0.590333  0.015817
(ExtraTreesClassifier, truncated(4), One-hot)            1.098966        0.036243        0.006173       0.001389  0.590816  0.014309
(RandomForestClassifier, truncated(5), NOCATS)           0.450167        0.034263        0.006791       0.000124  0.624847  0.011485
(ExtraTreesClassifier, truncated(5), NOCATS)             0.262979        0.038994        0.009673       0.002543  0.730078  0.012169
(RandomForestClassifier, truncated(5), One-hot)          1.348990        0.257518        0.006812       0.002311  0.625683  0.020087
(ExtraTreesClassifier, truncated(5), One-hot)            1.305170        0.180154        0.006293       0.001040  0.624495  0.019361
(RandomForestClassifier, truncated(6), NOCATS)           0.660743        0.047563        0.008925       0.001506  0.598595  0.008608
(ExtraTreesClassifier, truncated(6), NOCATS)             0.281398        0.019996        0.010045       0.001925  0.730779  0.018518
(RandomForestClassifier, truncated(6), One-hot)          2.069925        0.070151        0.006579       0.000087  0.624124  0.014505
(ExtraTreesClassifier, truncated(6), One-hot)            2.141421        0.161909        0.007982       0.001005  0.626265  0.013016
(RandomForestClassifier, truncated(8), NOCATS)           1.678582        0.172174        0.010039       0.001731  0.617643  0.010375
(ExtraTreesClassifier, truncated(8), NOCATS)             0.457223        0.082331        0.016995       0.002914  0.764477  0.009395
(RandomForestClassifier, truncated(8), One-hot)          3.219421        0.206134        0.010428       0.002274  0.663004  0.011065
(ExtraTreesClassifier, truncated(8), One-hot)            3.244661        0.268812        0.010613       0.002959  0.662469  0.010370
(RandomForestClassifier, truncated(10), NOCATS)          2.913669        0.683039        0.011459       0.002094  0.627301  0.006148
(ExtraTreesClassifier, truncated(10), NOCATS)            0.409481        0.040790        0.012915       0.001220  0.790844  0.009100
(RandomForestClassifier, truncated(10), One-hot)         3.385529        0.143956        0.009309       0.001476  0.665676  0.020495
(ExtraTreesClassifier, truncated(10), One-hot)           3.557266        0.202766        0.010070       0.002159  0.663913  0.021057
(RandomForestClassifier, truncated(12), NOCATS)          6.791204        1.000485        0.010759       0.001237  0.633280  0.022892
(ExtraTreesClassifier, truncated(12), NOCATS)            0.430389        0.048933        0.014804       0.001880  0.798362  0.022032
(RandomForestClassifier, truncated(12), One-hot)         3.590943        0.149930        0.008884       0.000254  0.693790  0.014326
(ExtraTreesClassifier, truncated(12), One-hot)           3.758259        0.384995        0.008873       0.000208  0.695171  0.014020
(RandomForestClassifier, truncated(14), NOCATS)          8.663956        6.493043        0.009320       0.000107  0.641795  0.017540
(ExtraTreesClassifier, truncated(14), NOCATS)            0.395362        0.003310        0.012943       0.000085  0.792042  0.005735
(RandomForestClassifier, truncated(14), One-hot)         3.971470        0.077007        0.009382       0.000323  0.668610  0.006327
(ExtraTreesClassifier, truncated(14), One-hot)           3.999125        0.089887        0.009440       0.000541  0.667159  0.007891
(RandomForestClassifier, truncated(16), NOCATS)         65.699274       21.271701        0.009849       0.000543  0.629030  0.015845
(ExtraTreesClassifier, truncated(16), NOCATS)            0.495190        0.062269        0.015646       0.001640  0.799396  0.012293
(RandomForestClassifier, truncated(16), One-hot)         4.472251        0.118174        0.010816       0.001675  0.685805  0.007541
(ExtraTreesClassifier, truncated(16), One-hot)           4.247601        0.078162        0.010696       0.001912  0.684779  0.006410
(ExtraTreesClassifier, truncated(64), NOCATS)            0.513263        0.002021        0.019326       0.000150  0.823812  0.015814
(RandomForestClassifier, truncated(64), One-hot)         5.698007        0.260483        0.016509       0.002000  0.707647  0.009821
(ExtraTreesClassifier, truncated(64), One-hot)           5.734862        0.232589        0.015705       0.000615  0.706847  0.008810
(ExtraTreesClassifier, full, NOCATS)                     1.946335        0.007195        0.369055       0.003971  0.827463  0.004715
(RandomForestClassifier, full, One-hot)                 29.707810        1.943698        0.102112       0.017712  0.731165  0.004592
(ExtraTreesClassifier, full, One-hot)                   29.560783        1.966576        0.094439       0.011784  0.730904  0.004287

Conclusions:

  • RandomForest + NOCATS becomes intractable pretty quickly with too many categories, so settings with more than 16 categories are not even included in the benchmark for it.
  • ExtraTreesClassifier + NOCATS outperforms all other cases in both training time and AUC from truncated(3) upward.
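One way to see why an exhaustive search gets out of hand: the number of distinct ways to split k categories into two non-empty groups is 2^(k-1) - 1. The snippet below is only a counting sketch; `n_binary_partitions` is a hypothetical helper, and it assumes the best-split search would have to look at every binary partition, which the benchmark above does not spell out.

```python
def n_binary_partitions(n_categories: int) -> int:
    """Number of ways to split n categories into two non-empty groups.

    Each category goes left or right (2**n assignments); drop the two
    assignments with an empty side and halve, since left/right are
    interchangeable: (2**n - 2) / 2 == 2**(n - 1) - 1.
    """
    if n_categories < 2:
        return 0
    return 2 ** (n_categories - 1) - 1


if __name__ == "__main__":
    # Category counts taken from the truncation levels in the benchmark.
    for k in (2, 4, 8, 16, 64):
        print(f"{k:>3} categories -> {n_binary_partitions(k)} candidate splits")
```

A splitter that draws candidate partitions at random instead of enumerating them does not pay this cost, which would be consistent with the ExtraTrees timings staying flat.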

adrinjalali avatar Jan 13 '19 19:01 adrinjalali

handle categories as their unique values instead of [0, max(feature)]. This is to be consistent with our encoders' behavior

I'm not sure this is necessary as at any point in the tree there may be categories with no variance at that node. We can consider the encoding of string-valued/mixed arrays a bit later...

jnothman avatar Jan 17 '19 02:01 jnothman

I think synthesising data is fine for tests. An example using openml would be good

jnothman avatar Jan 17 '19 22:01 jnothman

@adrinjalali do you need an in-depth review at this point? If yes I'd be happy to try.

Also, for core devs: let's say this PR is reviewed and accepted with +2, does it have any chance to be merged quickly or are we still in need of a general consensus regarding categorical variable support? IMHO this should be one of the top priorities.

NicolasHug avatar Jan 31 '19 16:01 NicolasHug

I think there is consensus to support categoricals. It needs benchmarking, and we need to be happy with our design choices.

jnothman avatar Jan 31 '19 22:01 jnothman

@adrinjalali do you need an in-depth review at this point? If yes I'd be happy to try.

I'm still reviewing the code, in the sense that I don't want tree/*.pyx to be less maintainable than they are now, but there are some tricky parts which I'm not sure how to fix yet. I'm happy to share some of those issues and get help if you're willing and happy to put some time here @NicolasHug.

and we need to be happy with our design choices.

One issue left regarding design choices is that we can't expose the categorical splits of the tree to the user before https://github.com/joblib/joblib/issues/826 is fixed.

adrinjalali avatar Feb 01 '19 09:02 adrinjalali

Thanks both for the feedback.

@adrinjalali sure, LMK if there's anything I can help with!

NicolasHug avatar Feb 01 '19 13:02 NicolasHug

Question: I introduced a BitSet in this commit: https://github.com/scikit-learn/scikit-learn/pull/12866/commits/532061ff53db321cabcaa213e520a055ecec36df, but the object is a Python one and it's not easy to have an array or an array of arrays of it.

My efforts so far have gotten me to a stackoverflow question of mine, this disappointing question, and a solution using boost.

I guess some options are:

  • cover what we exactly need in a separate header/cpp file and use it in our cython code
  • use C++'s std::vector but with a custom bitset written in C++
  • other cython voodoo which I'm not aware of

@NicolasHug do you happen to have a good answer to this one?

adrinjalali avatar Feb 04 '19 14:02 adrinjalali

It seems that the BitSet class is just a wrapper on a uint64, right? And you don't need to use it inside Python? In this case I would just directly create arrays of uint64 and translate the methods into pure cdef functions. Pretty much like what you would do if you were writing pure C. You can make a typedef from uint64 to bitset if you want to be more explicit.
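To make the idea concrete, here is a minimal pure-Python sketch of the "plain integers plus free functions" layout. It is not the code from the PR: the names (`bs_set`, `bs_clear`, `bs_contains`, `bs_from_categories`) are made up for illustration, and Python ints are masked to 64 bits to mimic a uint64. In Cython the same functions would be cdef functions over a uint64 typedef.

```python
MASK64 = (1 << 64) - 1  # emulate a fixed-width uint64 with a Python int


def _check_index(i: int) -> None:
    """A 64-bit word can only hold category codes 0..63."""
    if not 0 <= i < 64:
        raise ValueError(f"category code {i} does not fit in a 64-bit bitset")


def bs_set(bits: int, i: int) -> int:
    """Return `bits` with bit `i` switched on."""
    _check_index(i)
    return (bits | (1 << i)) & MASK64


def bs_clear(bits: int, i: int) -> int:
    """Return `bits` with bit `i` switched off."""
    _check_index(i)
    return bits & ~(1 << i) & MASK64


def bs_contains(bits: int, i: int) -> bool:
    """True if bit `i` is on, i.e. category `i` is in the set."""
    _check_index(i)
    return bool((bits >> i) & 1)


def bs_from_categories(categories) -> int:
    """Build a bitset from an iterable of category codes."""
    bits = 0
    for c in categories:
        bits = bs_set(bits, c)
    return bits


# An "array of bitsets" is then just a flat list (or uint64 array) of words,
# for example one word per tree node.
node_bitsets = [bs_from_categories([1, 2, 5]), bs_from_categories([0, 63])]

if __name__ == "__main__":
    print([bin(b) for b in node_bitsets])
    print(bs_contains(node_bitsets[0], 2), bs_contains(node_bitsets[0], 3))
```

Since each bitset is a plain word, an array of them needs no Python objects at all, which is the point of dropping the wrapper class.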

I wanted to declare arrays of cdef classes as well, but it doesn't seem to be possible. In #12807 I have this class SplitInfo that I need to use from Python, and I had to create a split_info_struct that has the same attributes. I can create arrays of split_info_struct in C-mode, and when I need to manipulate such an object in Python I just wrap it into the class. It works in my case because:

  • I don't need arrays in Python, just single objects
  • It's a pure data class (no methods). If there were any methods I think I would have to duplicate them: methods for the class, and equivalent functions for C. Which would be pretty annoying.

Hope that helps!


Unrelated but might be useful: I've found that some weird stuff happens when using 1d slices of 2d arrays. For example having a function cdef f(int [:] slice_1d): ... called with f(some_2d_array[index, :]) will generate strange Python interactions (it's related to the GIL and probably also to the use of prange so that might not affect you). My work-around was to change f's signature to cdef f(int [:, :] array_2d, const unsigned int index): ... . Looking at the annotated html files (cython -a) will show you the Python interactions and sometimes they appear in unexpected places!

NicolasHug avatar Feb 04 '19 15:02 NicolasHug

sprint discussion conclusions:

The new implementation in #12807 and the fact that it makes sense to have NOCATS there, deprioritizes this PR.

The _splitter.pyx can be very much simplified, and presort can be removed from that code, since it's only used for gbc, and the new gbc is much faster anyway.

adrinjalali avatar Feb 25 '19 17:02 adrinjalali

Any update on this? :)

Also is there an issue to track categorical features for the new (currently still experimental) implementation? (It's currently a bit hard to find out what the status of categorical features is, and (possibly) discuss how it should be tackled, now that the implementation changed. Readers of #9960 are redirected to this PR here.)

h-vetinari avatar Jun 17 '19 15:06 h-vetinari

@h-vetinari we'll soon work on NOCATS for HGB models, it'll hopefully be there by next release.

adrinjalali avatar Jun 17 '19 15:06 adrinjalali

@adrinjalali considering that we decided not to implement categorical support in the tree module, I think we can close this one and #4899 ?

NicolasHug avatar Aug 07 '19 19:08 NicolasHug

We decided not to prioritize this one. I'm still planning to finish this, but first I want to clean up and simplify the splitter code, which we said we could do once the HGBT is released, which it is now. I just haven't gotten to it. This also implements split policies which we probably won't have in HGBT (random splitter to be specific), which work really well with extra trees.

adrinjalali avatar Aug 08 '19 07:08 adrinjalali

This paper is interesting: https://peerj.com/articles/6339/

Unfortunately it doesn't consider high cardinality categorical variables, and it considers only a small set of datasets. But it shows that actually ordering categories once before building the trees might be a good strategy. That's interesting because that's waaay easier to implement ;) Even if we also implement the NOCATS approach, we could provide an estimator that does the "once and for all" ordering for regression and binary classification. The authors also provide a similar heuristic for multi-class, which also sounds interesting.
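As a rough illustration of the "order the categories, then treat them as ordinal" idea for regression: sort categories by their mean target and only consider the k-1 splits that cut that ordering. The sketch below is not taken from the paper or from scikit-learn; the function names and the toy `groups` data are invented, and it uses squared error as the split criterion. It compares the ordered search against a brute-force search over all subsets.

```python
from itertools import combinations


def sse(values):
    """Sum of squared errors around the mean (0.0 for an empty list)."""
    if not values:
        return 0.0
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values)


def split_cost(groups, left):
    """Total squared error when the categories in `left` go to one child."""
    left_vals = [y for c in groups if c in left for y in groups[c]]
    right_vals = [y for c in groups if c not in left for y in groups[c]]
    return sse(left_vals) + sse(right_vals)


def best_split_ordered(groups):
    """Sort categories by mean target, then try only the k-1 prefix cuts."""
    order = sorted(groups, key=lambda c: sum(groups[c]) / len(groups[c]))
    candidates = [frozenset(order[:i]) for i in range(1, len(order))]
    return min(candidates, key=lambda left: split_cost(groups, left))


def best_split_exhaustive(groups):
    """Brute force over every non-empty proper subset of categories."""
    cats = list(groups)
    candidates = [
        frozenset(subset)
        for r in range(1, len(cats))
        for subset in combinations(cats, r)
    ]
    return min(candidates, key=lambda left: split_cost(groups, left))


# Toy data (made up): target values observed for each category.
groups = {
    "A": [1.0, 1.2],
    "B": [5.0, 5.5],
    "C": [0.8],
    "D": [4.9, 5.1, 5.3],
    "E": [3.0],
}

if __name__ == "__main__":
    ordered = best_split_ordered(groups)
    exhaustive = best_split_exhaustive(groups)
    print(sorted(ordered), split_cost(groups, ordered))
    print(sorted(exhaustive), split_cost(groups, exhaustive))
```

The two searches may report opposite sides of the same partition, so the costs are the thing to compare. For squared-error regression the ordered search is known to find the optimal partition; for multi-class targets there is no single natural ordering, which is where heuristics come in.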

amueller avatar Nov 22 '19 22:11 amueller

Even if we also implement the NOCATS approach, we could provide an estimator that does the "once and for all" ordering for regression and binary classification.

What we plan to implement in the Histogram-GBDTs is what the paper does, except that we sort categories at each split instead of at the beginning (the paper also discusses that).

CC @adrinjalali

NicolasHug avatar Nov 26 '19 13:11 NicolasHug

Also I'm super late to the party, but what is the benefit of NOCATs over One-Hot-Encoding the categories? As far as I understand the strategy proposed here is equivalent to re-implementing the OHE within the tree logic. So what are the main benefits of NOCATs over OHE, apart from using less memory?

NicolasHug avatar Nov 26 '19 13:11 NicolasHug

Also I'm super late to the party, but what is the benefit of NOCATs over One-Hot-Encoding the categories?

One-hot encoding only allows you to split off 1-vs-the-rest, whereas the optimal split for a categorical variable may be many-vs-many. For example, the optimal split at a given node may be:

{A, B, C, D, E, F, G} --> {B, C, F} vs. {A, D, E, G}

but one-hot encoding would only be able to yield one of

{A, B, C, D, E, F, G} --> {A} vs. {B, C, D, E, F, G}
{A, B, C, D, E, F, G} --> {B} vs. {A, C, D, E, F, G}
{A, B, C, D, E, F, G} --> {C} vs. {A, B, D, E, F, G}
{A, B, C, D, E, F, G} --> {D} vs. {A, B, C, E, F, G}
{A, B, C, D, E, F, G} --> {E} vs. {A, B, C, D, F, G}
{A, B, C, D, E, F, G} --> {F} vs. {A, B, C, D, E, G}
{A, B, C, D, E, F, G} --> {G} vs. {A, B, C, D, E, F}

This obviously affects the depth / number of splits that are necessary to get a similarly good result.
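A small sketch to count the two families of splits for the seven categories above. The helper names are hypothetical, and "one-hot" here means a single split on one indicator column at one node.

```python
from itertools import combinations


def one_vs_rest_splits(categories):
    """Splits reachable by thresholding one one-hot column: {c} vs. the rest."""
    cats = frozenset(categories)
    return {frozenset({frozenset({c}), cats - {c}}) for c in cats}


def many_vs_many_splits(categories):
    """Every way to divide the categories into two non-empty groups."""
    cats = sorted(categories)
    full = frozenset(cats)
    splits = set()
    for r in range(1, len(cats)):
        for left in combinations(cats, r):
            left = frozenset(left)
            # A split is an unordered pair, so {left, right} == {right, left}.
            splits.add(frozenset({left, full - left}))
    return splits


cats = "ABCDEFG"
ovr = one_vs_rest_splits(cats)
mvm = many_vs_many_splits(cats)
# The example split from the text: {B, C, F} vs. {A, D, E, G}.
target = frozenset({frozenset("BCF"), frozenset("ADEG")})

if __name__ == "__main__":
    print(len(ovr), "one-vs-rest splits")
    print(len(mvm), "many-vs-many splits")
    print("target reachable with one one-hot split:", target in ovr)
    print("target reachable with a categorical split:", target in mvm)
```

A one-hot tree can still isolate {B, C, F} eventually, but only by stacking several one-vs-rest splits, each of which sees fewer samples than the one before.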

h-vetinari avatar Nov 26 '19 14:11 h-vetinari

@NicolasHug this is only one benchmark, but at least on this dataset, there are benefits to using NOCATS: https://github.com/scikit-learn/scikit-learn/pull/12866#issuecomment-453856876

adrinjalali avatar Nov 26 '19 14:11 adrinjalali

@NicolasHug

What we plan to implement in the Histogram-GBDTs is what the paper does, except that we sort categories at each split instead of at the beginning (the paper also discusses that).

That is the exact solution for regression and some binary cases. This PR mentions in the beginning "Heuristic methods to allow fast Breiman-like training for multi-class classification" which is basically what you're implementing.

Though I imagine you're doing this based on the unnormalized probabilities, which is one more level of indirection compared with the trees.

amueller avatar Dec 02 '19 22:12 amueller

Actually, because you're doing a regression tree each time, the sorting may always be exact, depending on the loss. I need to think about that again and look at the formula.

amueller avatar Dec 02 '19 23:12 amueller

@adrinjalali Is there a chance you'll pick this up again in the foreseeable future? It's still a sore need in many situations, and it is forcing those affected to use packages other than scikit-learn.

h-vetinari avatar Mar 09 '20 08:03 h-vetinari

We probably are going to implement this (or a version of it) for the new HistGradientBoosting ones instead. Realistically, we should have it in 6 months or so.

adrinjalali avatar Mar 09 '20 09:03 adrinjalali

Great to hear, thanks!

h-vetinari avatar Mar 09 '20 09:03 h-vetinari

@adrinjalali Any idea of an ETA on this? Just planning a few projects and this feature would be great to have!

dlong11 avatar Mar 26 '20 01:03 dlong11

we're planning to have a version of this, probably for HistGradientBoosting* in the October/November release.

adrinjalali avatar Mar 26 '20 10:03 adrinjalali