scikit-learn
[Refactor] Modularize the tree class in both Python and Cython to enable easy extensions
Describe the workflow you want to enable
As we are waiting for a reviewer to review #22754 , @thomasjpfan suggested we just move forward with our goals of creating a package of more exotic tree splits. E.g. https://arxiv.org/abs/1909.11799. While we wait for reviewers, the suggestion is to make a package within scikit-learn-contrib.
Although we would like for #22754 to be eventually merged into scikit-learn, we understand reviewer backlog is an issue. To move forward while reviews occur, we would need to subclass existing scikit-learn code. Ideally, we would like to introduce minor refactoring changes that would make this task significantly easier.
We would like to subclass directly from scikit-learn without requiring us to keep an up-to-date fork of scikit-learn with all the bug fixes and maintenance that the dev team here does. We can limit this if we modularize the Python/Cython functions inside the sklearn/tree module.
Describe your proposed solution
I am proposing two refactoring modifications that have no impact on the performance of the current tree estimators in scikit-learn.
- Refactor the `BaseDecisionTree` Python class to have the following functions that can be overridden in a subclass:
  - `_set_tree_class`: sets the `Tree` Cython class that the Python API uses
  - `_set_splitter`: sets the `Splitter` Cython class that the Python API uses
For example, this makes the subclassing of BaseDecisionTree cleaner: https://github.com/scikit-learn/scikit-learn/blob/de06afa78458d773d7c0482ff17a55a235d0886b/sklearn/tree/_classes.py#L410-L416.
- Refactor the `Tree` Cython class to have the following functions:
  - `_set_node_values`: transfers split node values to the storage node
  - `_compute_feature_value`: uses the storage node and the input data to compute the feature value to split on
For example, see https://github.com/scikit-learn/scikit-learn/blob/de06afa78458d773d7c0482ff17a55a235d0886b/sklearn/tree/_tree.pyx#L770-L787.
- Refactor the `TreeBuilder` Cython class to pass around a `Split` pointer, rather than the struct itself
This will enable someone to use C-level functions to pass around another struct with a similar structure to `Split`.
For example, see https://github.com/scikit-learn/scikit-learn/blob/de06afa78458d773d7c0482ff17a55a235d0886b/sklearn/tree/_tree.pyx#L193.
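To make the proposed extension points concrete, here is a minimal Python sketch, assuming hook names from the proposal (`_set_tree_class`, `_set_splitter`). These would be Cython classes in scikit-learn itself; the plain-Python classes and string return values here are illustrative stand-ins, not the actual API:

```python
class BaseDecisionTree:
    """Toy stand-in for the proposed hooks on sklearn's BaseDecisionTree."""

    def _set_tree_class(self):
        # default: the axis-aligned Cython Tree
        return "Tree"

    def _set_splitter(self, criterion):
        return "BestSplitter"

    def fit(self, X, y):
        # the fitting logic only talks to the hooks, so subclasses
        # never need to copy it
        tree_cls = self._set_tree_class()
        splitter = self._set_splitter(criterion="gini")
        return (tree_cls, splitter)


class ObliqueDecisionTree(BaseDecisionTree):
    # only the two factory hooks change; fit() is inherited as-is
    def _set_tree_class(self):
        return "ObliqueTree"

    def _set_splitter(self, criterion):
        return "ObliqueSplitter"
```

The point is that a third-party estimator overrides two small factories instead of re-implementing (and then maintaining a fork of) the whole fitting routine.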
Describe alternatives you've considered, if relevant
Alternatives would require maintaining a copy of the sklearn/tree module and keeping it up-to-date with sklearn changes. If this were just one Cython file, I would say it is feasible, but the necessary ingredients span some of the underlying private API, making this a very time-consuming task. Introducing modularity into the private API that does not impact existing performance therefore seems to be the best path forward?
Moreover, by introducing these refactoring changes, #22754 has a smaller diff and lower-cost to review.
Additional context
#22754 demonstrates that there is no performance regression or issue with the existing DecisionTree or RandomForest when introducing these changes.
Note that `_set_tree_class` and the other extension points will be private in scikit-learn. This means that they can change without deprecation between scikit-learn versions. A third-party library that subclasses these classes may have more maintenance overhead to support different scikit-learn versions.
I know I suggested it, but item 3 with the pointer_size is kind of hacky. I am okay with it being used internally, but I am uneasy about it as a third-party extension point. (I'm around +0.3)
Stepping back, the main issue is that the trees were designed around axis-aligned splits. For third party libraries that only need axis-aligned splits, it is easy to subclass. For scikit-learn, do we want to add these extension points to allow for non axis-aligned splits, when they are not in the library itself? I am slightly in favor at +0.5.
Thanks Thomas! Adding some comments to augment the discussion. The high level goal of all 3 I would say is to enable more flexibility wrt Python/Cython code to allow maximal reusability.
I know I suggested it, but item 3 with the `pointer_size` is kind of hacky. I am okay with it being used internally, but I am uneasy about it as a third-party extension point. (I'm around +0.3)
After you showed me this trick, I looked around more, and this actually seems to be an accepted idiom for "inheritance" of structs in C. Unfortunately, Cython takes after C in this regard for structs, not C++. Perhaps some proper inline documentation in the Tree code would reduce confusion for anyone reading it in the future, so that they don't need extensive C knowledge.
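For readers without a C background: the idiom works because C guarantees a pointer to a struct is convertible to a pointer to its first member. A small `ctypes` demonstration of the same layout trick, with illustrative field names rather than scikit-learn's actual split structs:

```python
import ctypes


class Split(ctypes.Structure):
    # illustrative fields, not scikit-learn's actual split record
    _fields_ = [("feature", ctypes.c_int),
                ("threshold", ctypes.c_double)]


class ObliqueSplit(ctypes.Structure):
    # the "base" struct is the FIRST member, so a pointer to
    # ObliqueSplit may be safely reinterpreted as a pointer to Split
    _fields_ = [("base", Split),
                ("n_weights", ctypes.c_int)]


oblique = ObliqueSplit(Split(3, 0.5), 7)
as_base = ctypes.cast(ctypes.pointer(oblique), ctypes.POINTER(Split))

# code that only knows about Split still reads the right values
print(as_base.contents.feature, as_base.contents.threshold)  # 3 0.5
```

This is why passing a `Split*` around (item 3) lets a subclass smuggle in a larger struct whose leading layout matches `Split`.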
Stepping back, the main issue is that the trees were designed around axis-aligned splits. For third party libraries that only need axis-aligned splits, it is easy to subclass. For scikit-learn, do we want to add these extension points to allow for non axis-aligned splits, when they are not in the library itself? I am slightly in favor at +0.5.
I somewhat agree with the point here. I would argue the extension points enable someone to subclass the existing code more flexibly, rather than just enabling non axis-aligned splits. Say someone wants to define their own axis-aligned splitter and provide a Python API, e.g. a function that chooses splits using some graph traversal instead of at random. I'm just postulating based on ideas I've had in the past for improving RF, but even this hypothetical workflow is not possible in the current setup.
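To make the axis-aligned vs. oblique distinction concrete, here is a toy sketch in the spirit of the proposed `_compute_feature_value` hook. The node layout and method names are hypothetical, not scikit-learn's internals:

```python
import numpy as np


class AxisAlignedNode:
    def __init__(self, feature, threshold):
        self.feature = feature
        self.threshold = threshold

    def compute_feature_value(self, x):
        # axis-aligned split: read a single feature of the sample
        return x[self.feature]


class ObliqueNode(AxisAlignedNode):
    def __init__(self, weights, threshold):
        super().__init__(feature=None, threshold=threshold)
        self.weights = np.asarray(weights)

    def compute_feature_value(self, x):
        # oblique split: project the sample onto a weight vector
        return float(self.weights @ x)
```

Tree traversal only ever calls `compute_feature_value` and compares against `threshold`, so the same traversal code serves both split families.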
In order to move #22754 along, and possibly get it into a state that can be easily refactored for a stand-alone scikit-learn-contrib package, just wanted to see if there was any consensus on this?
I would say that maintainers have not discussed and agreed on modularizing those implementations yet. Ideally, it should be modular, but I do not think that's on the agenda of any maintainer: it's not that we are against it, it's just that the few maintainers working on this part of the code-base are busy with other maintenance aspects.
If you want https://github.com/scikit-learn/scikit-learn/pull/22754 to move along, I think the best option is for you to try to refactor it (I would like to help, but I can't yet). Also, when refactoring private Cython code, mind sharing responsibility for unexpected subsequent user issues.
Hey @jjerphan thanks for the reply!
I think for me I'm faced with a few obstacles and I'm hoping for some clarity here. Happy to discuss at OH whenever they restart...
The original issue #20819 raised a few concerns that were since cleared up: i) the citation criterion and ii) maintainability due to the stiffness of the tree code. I addressed i), and then spent a lot of time figuring out how to make oblique trees inherit directly from the RF code to lower the LOC to be added. Once I got the PR set up, the dev team requested a number of experiments to support this addition; those are now documented in PR #22754. Then it was said that this would require a maintainer to review the code, everyone is busy, and it was suggested that I try adding this to scikit-learn-contrib so we can build off the work done here. I totally understand that the dev team has other goals beyond what I want and is busy, but I want to see what's the best path forward, since a considerable amount of work has gone into making a 10k+ cited paper with considerable empirical evidence compatible with the scikit-learn tree code. If the concern is maintainability, I'm happy to ask @jovo whether his lab is willing to maintain this part of the code, since they are one lab invested in extending oblique trees beyond what Breiman proposed.
The two options that were floated to me are:
- PRing oblique trees into scikit-learn: it seems inclusion now mainly depends on a review cycle, which devs would prefer not to do for a big PR(?). I'm not sure how else to improve the code, but I'm willing to chop up the PR into smaller, digestible bites. I offered to help refactor the Cython and Python code base to enable extensions that allow generalizations of the splitter, but haven't gotten a reply :/.
- Oblique trees as a separate scikit-learn-contrib package for now: it was suggested this can help me move along with some of my projects while waiting for reviews on the direct PR of oblique trees. However, this relies on the refactoring I noted above, and even that refactoring to make the scikit-learn tree code more generalizable is facing considerable bottlenecks.
The refactoring would not change any functionality, based on the diff you sent. I understand, though, that this is OSS and the dev team is pretty busy with consolidation of the package. So anything I can do to help either the PR itself or the scikit-learn-contrib package move along, I'm happy to do. I'm also happy to contribute more to the general scikit-learn package so the devs' time is freed up. I would definitely appreciate any guidance here.
Lmk if I'm perhaps misunderstanding anything.
Hi @adam2392,
Thank you for this long comment. I will add your contributions to the discussion at the next core-developer meeting. You should be able to join it freely.
Hi @jjerphan, sorry about logging in late. Today was Labor Day in the USA, so I was actually traveling during this time. I tried joining a bit after the meeting had started, but Google Meet needed someone to "let me in". I think the meeting had probably ended because I wasn't let in, so I left after a few minutes.
Wanted to follow up to see what I could do to help out with making trees more modularizable and platform-friendly?
Following up on our discussion in dev meeting:
tldr: Both scikit-survival and quantile-forests generalize RF in different ways from what I am proposing. I am proposing generalizations of the split function. Scikit-survival proposes a generalization of the criterion. Quantile-forests proposes a generalization of how the leaf nodes are set, which honest forests also generalize.
Scikit-survival
IIUC, this package cannot subclass `BaseDecisionTree` right now because of fundamental differences in the `y` variable implemented in scikit-survival. The `Criterion` changes, but that can be passed directly into the decision tree class. Another change is the structure of the `y` labels: in scikit-survival, each sample in `y` is a tuple consisting of a binary event indicator as the first field and the time of event or time of censoring as the second field. Basically, it records whether or not the sample survived and then the length of survival. Right now, `event_times` is passed directly into `LogRankCriterion` here, which prevents it from using `BaseDecisionTree`.
I think scikit-survival could benefit from this refactoring if we also enable a private `_set_criterion` function, which would allow subclasses of `BaseDecisionTree` to alter the `Criterion` class as well. For example:

def _set_criterion(self, X, y):
    # sets the Criterion class given metadata from the input X and y

This could potentially allow more exotic criteria. However, I can see this not being 100% necessary.
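A hedged sketch of how a survival-tree subclass might use such a hook, assuming the proposed `_set_criterion` name; the classes and the structured-`y` layout below are illustrative, not the actual scikit-learn or scikit-survival API:

```python
import numpy as np


class BaseDecisionTree:
    def _set_criterion(self, X, y):
        return "gini"  # default axis-aligned impurity


class SurvivalTree(BaseDecisionTree):
    def _set_criterion(self, X, y):
        # scikit-survival's y is structured: a binary event indicator
        # and a time of event/censoring per sample
        assert {"event", "time"} <= set(y.dtype.names)
        return "logrank"


y = np.array([(True, 3.0), (False, 5.0)],
             dtype=[("event", "?"), ("time", "<f8")])
print(SurvivalTree()._set_criterion(None, y))  # logrank
```

The subclass inspects the structured `y` it receives and picks its own criterion, without touching the rest of the fitting code.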
skurv refs
See the tree code: https://github.com/sebp/scikit-survival/blob/52c713dc8e3beb266f9cad7e6785dce5fa282f01/sksurv/tree/tree.py#L217-L230
And relevant Cython criterion that subclasses scikit-learn: https://github.com/sebp/scikit-survival/blob/master/sksurv/tree/_criterion.pyx
Quantile Forest
Quantile random forest (QRF) implements the entire forest using Cython. Fundamentally, the QRF needs to store information about all observations in its leaf nodes, rather than just the mean of the labels, which is currently done in scikit-learn's RF.
This refactoring unfortunately would not play a role here, because it targets generalizing the splitting rather than the leaf storage.
However, QRF and honest forests are candidates that would probably benefit from a generalization of how the leaves are set within the RF code of scikit-learn. Both QRF and honest forests use the leaves differently than RF does, but both perform the splitting in the same way. The refactoring here would simply help QRF/honest forests take advantage of more exotic splitting functions.
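As a toy illustration of what "generalizing how the leaf values are set" could mean (the two functions below are hypothetical stand-ins for a leaf-value hook, not existing scikit-learn code):

```python
import numpy as np


def mean_leaf(y_in_leaf):
    # standard RF regression: store only the mean of the labels
    return float(np.mean(y_in_leaf))


def quantile_leaf(y_in_leaf):
    # QRF-style: keep every observation so arbitrary quantiles
    # can be queried at predict time
    return np.sort(np.asarray(y_in_leaf, dtype=float))


y_leaf = [3.0, 1.0, 2.0]
print(mean_leaf(y_leaf))      # 2.0
print(quantile_leaf(y_leaf))  # [1. 2. 3.]
```

Both variants consume the same samples routed to a leaf by the same splitting code; only the summary stored in the leaf differs.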
cc: @ogrisel @jjerphan @thomasjpfan Feel free to add anyone I missed.
In general, I have written up a more lengthy proposal for improving/simplifying the tree submodule. I have summarized it in this hackmd doc. Feel free to comment and take a look...
https://hackmd.io/@-5q0aS6xT4Sq6RULDW8EVw/HkpsR9hMi
cc: @thomasjpfan @jjerphan @ogrisel @glemaitre
Hi, I just wanted to follow up on this thread to see what the next steps are for any of the following:
- oblique trees PR: https://github.com/scikit-learn/scikit-learn/pull/22754
- modularizations of criterion: https://github.com/scikit-learn/scikit-learn/issues/24577
- modularization of just the Tree/TreeBuilder: this issue
- general proposal of modularizations to Tree/TreeBuilder: https://hackmd.io/@-5q0aS6xT4Sq6RULDW8EVw/HkpsR9hMi
Thanks!
Hi @adam2392,
I've seen your request. I would like to get back to a few other things first and then I will have a look at your suggestion (hopefully today). Thank you for your patience!
I read your general proposal of Tree and TreeBuilder modularizations and have some comments and suggestions. Is it possible to make you HackMD pad commentable (or even editable if relevant)?
Hey @jjerphan oops. Made the link editable now!
https://hackmd.io/@-5q0aS6xT4Sq6RULDW8EVw/HkpsR9hMi/edit
@scikit-learn/core-devs We need a decision here. Even a straight no is better than lingering. And with a bit of courage, we could get a little lift of the tree module by someone who really cares.
My main concern is that building an internal API for a 3rd party package thereby becomes a somewhat public API. On the other side, it makes integration much easier should we one day wish for it.
@adam2392 : do you want to maintain the implementations of the tree module?
Appreciate the feedback I have received and discussions here.
@adam2392 : do you want to maintain the implementations of the `tree` module?
Yes, I am always happy to contribute to scikit-learn anyway I can that is wanted by the team/community. I hope that is reflected as well from my PRs, GH activity, and Discord discussions.
This contribution is interesting but it will complexify the project.
I think I would only accept those contributions if the tree module has a higher bus factor.
I think this is quite a good idea. And our tree code has become more maintainable in the past couple of years also thanks to @thomasjpfan
Two concerns with adding more extension points:
- We have to be careful with benchmarking. The tree code is quite optimized in some places and adding some Cython abstractions can make it slower.
My main concern is that building an internal API for a 3rd party package thereby becomes a somewhat public API. On the other side, it makes integration much easier should we one day wish for it.
- This is still my concern. As long as we do not promise backward compatibility, I am okay with having more extension points.
- BTW as a work around, a third party library can pin scikit-learn as a build dependency to get the Cython tree code for a specific scikit-learn version. During runtime, their code can still work with newer versions of scikit-learn.
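The build-time pin described above could look roughly like this in a third-party package's `pyproject.toml` (package name and version numbers are illustrative):

```toml
[build-system]
# pin the exact scikit-learn whose Cython tree sources we compile against
requires = ["setuptools", "Cython", "scikit-learn==1.3.0"]

[project]
name = "my-exotic-trees"  # hypothetical package name
# at runtime, newer scikit-learn releases are still allowed
dependencies = ["scikit-learn>=1.3"]
```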
In general to summarize the points in this thread and others, private extension API is fine if:
- documentation clearly states no stability
- benchmarks against `main` demonstrate the trees do not lose performance in terms of runtime (accuracy is a given)
- the introduced abstractions are not difficult to maintain and are not too complex
Lmk if I missed anything.
BTW as a work around, a third party library can pin scikit-learn as a build dependency to get the Cython tree code for a specific scikit-learn version. During runtime, their code can still work with newer versions of scikit-learn.
Yes, this is a good point. It is actually what I ended up doing in scikit-tree, by leveraging a forked submodule of scikit-learn purely for the Cython tree code.