[Feature] Add RandomForestClassifier to linfa-trees
This PR extends the linfa-trees crate with a new Random Forest classifier. It builds on the existing Decision Tree implementation to provide an ensemble method that typically outperforms a single tree: many trees are trained on bootstrapped subsets of both rows and columns (features), and their predictions are aggregated by majority voting.
## What's Added

- **`src/decision_trees/random_forest.rs`**
  - `RandomForestParams` / `RandomForestValidParams`: hyperparameters (`n_trees`, `max_depth`, `feature_subsample`, `seed`) with validation via `ParamGuard`.
  - `RandomForestClassifier`: stores a `Vec<DecisionTree>` and the per-tree feature indices used for training.
  - `Fit` implementation:
    - Bootstrap rows.
    - Randomly subsample features per tree.
    - Train each `DecisionTree` on its slice.
  - `Predict` implementation:
    - For each tree, select the same feature slice on the test data.
    - Invoke `tree.predict(&sub_x)` (returns `Array1<usize>`).
    - Accumulate votes and return the argmax for each sample.
- **Exports**
  - Updated `src/decision_trees/mod.rs` and `src/lib.rs` to re-export `RandomForestParams` and `RandomForestClassifier`.
- **Example**
  - `examples/iris_random_forest.rs`: demonstrates loading the Iris dataset, training a Random Forest, and printing the confusion matrix and accuracy.
- **Unit Test**
  - `tests/random_forest.rs`: an integration test asserting ≥ 90% accuracy on Iris with a fixed RNG seed for reproducibility.
- **Dependencies**
  - Added `rand = "0.8"` to `linfa-trees/Cargo.toml` for RNG and sampling utilities.
- **README**
  - Extended `README.md` with a "Random Forest Classifier" section, usage example, and run instructions.
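The vote-aggregation step of `Predict` can be illustrated with a small self-contained sketch. `majority_vote` is a hypothetical helper for exposition only, not the PR's actual code (which operates on one `Array1<usize>` per tree):

```rust
/// Majority voting over per-tree predictions, as in the `Predict` steps
/// described above. `votes[t][i]` is the class predicted by tree `t` for
/// sample `i`. Hypothetical helper, not the PR's code.
fn majority_vote(votes: &[Vec<usize>], n_classes: usize) -> Vec<usize> {
    let n_samples = votes[0].len();
    let mut out = Vec::with_capacity(n_samples);
    for i in 0..n_samples {
        // Count how many trees voted for each class on sample `i`.
        let mut counts = vec![0usize; n_classes];
        for tree_votes in votes {
            counts[tree_votes[i]] += 1;
        }
        // Argmax over class counts (ties resolve to the highest class
        // index, since `max_by_key` keeps the last maximum).
        let pred = counts
            .iter()
            .enumerate()
            .max_by_key(|&(_, c)| *c)
            .map(|(class, _)| class)
            .unwrap();
        out.push(pred);
    }
    out
}
```

With three trees voting `[0, 1]`, `[0, 1]`, and `[0, 2]` over two samples, the aggregated prediction is `[0, 1]`.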
## Motivation

- **Ensemble performance**: Random Forests often reduce variance and improve generalization compared to a single decision tree.
- **Feature importance**: Subsampling features per tree provides insight into feature usefulness.
- **API consistency**: Follows Linfa's `Fit`/`Predict`/`ParamGuard` conventions and integrates cleanly with `Dataset`.
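Since row bootstrapping and per-tree feature subsampling are central to the design, here is a minimal, dependency-free sketch of the two sampling steps. A toy LCG stands in for the `rand` crate that the PR actually uses, and all names are illustrative:

```rust
/// Tiny deterministic LCG so this sketch needs no external crates;
/// the PR itself uses `rand = "0.8"` instead.
struct Lcg(u64);

impl Lcg {
    fn next_below(&mut self, n: usize) -> usize {
        // Standard 64-bit LCG step; take high bits for better mixing.
        self.0 = self
            .0
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        ((self.0 >> 33) as usize) % n
    }
}

/// Bootstrap: sample `n` row indices *with* replacement.
fn bootstrap_rows(rng: &mut Lcg, n: usize) -> Vec<usize> {
    (0..n).map(|_| rng.next_below(n)).collect()
}

/// Feature subsampling: choose `k` distinct feature indices
/// *without* replacement via a partial Fisher–Yates shuffle.
fn subsample_features(rng: &mut Lcg, n_features: usize, k: usize) -> Vec<usize> {
    let mut idx: Vec<usize> = (0..n_features).collect();
    for i in 0..k {
        let j = i + rng.next_below(n_features - i);
        idx.swap(i, j);
    }
    idx.truncate(k);
    idx.sort_unstable(); // sorted indices make column slicing simpler
    idx
}
```

Each tree is then trained on the bootstrapped rows restricted to its sampled feature columns, and `Predict` must reuse the same column indices.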
## Files Changed

```text
algorithms/linfa-trees/
├── Cargo.toml                  # + rand = "0.8"
├── src/
│   ├── decision_trees/
│   │   ├── algorithm.rs        # no change
│   │   ├── mod.rs              # + pub mod random_forest;
│   │   └── random_forest.rs    # NEW
│   └── lib.rs                  # + pub use decision_trees::random_forest::{…}
├── examples/
│   └── iris_random_forest.rs   # NEW
└── tests/
    └── random_forest.rs        # NEW
```
## Example

```shell
cargo run --release --example iris_random_forest
```

```text
classes |  0 |  1 |  2
--------------------------------
   0    | 50 |  0 |  0
   1    |  0 | 48 |  2
   2    |  0 |  1 | 49

Accuracy: 0.97
```
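For reference, accuracy can be read off a confusion matrix as the trace (correct predictions on the diagonal) divided by the total count. A generic helper, hypothetical and not part of the PR:

```rust
/// Accuracy from a square confusion matrix: sum of the diagonal
/// divided by the total number of samples. Hypothetical helper.
fn accuracy(cm: &[Vec<usize>]) -> f64 {
    let correct: usize = cm.iter().enumerate().map(|(i, row)| row[i]).sum();
    let total: usize = cm.iter().flatten().sum();
    correct as f64 / total as f64
}
```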
## Checklist

- [x] Implements `ParamGuard` for hyperparameter validation
- [x] Implements `Fit<Array2<F>, Array1<usize>>`
- [x] Implements `Predict<Array2<F>, Array1<usize>>` with correct feature-slice logic
- [x] Example runs without errors (`cargo run --example iris_random_forest`)
- [x] Unit test passes (`cargo test`)
- [x] README updated with usage snippet
- [x] `rand` dependency added
Thank you for reviewing! Iβm happy to address any feedback or suggestions.
Codecov Report
:x: Patch coverage is 37.50000% with 40 lines in your changes missing coverage. Please review.
:white_check_mark: Project coverage is 36.21%. Comparing base (11ea07a) to head (35b425b).
:warning: Report is 10 commits behind head on master.
| Files with missing lines | Patch % | Lines |
|---|---|---|
| ...ms/linfa-trees/src/decision_trees/random_forest.rs | 37.50% | 40 Missing :warning: |
Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           master     #390      +/-   ##
==========================================
+ Coverage   36.09%   36.21%    +0.12%
==========================================
  Files          99      100       +1
  Lines        6502     6566      +64
==========================================
+ Hits         2347     2378      +31
- Misses       4155     4188      +33
```
Thanks for your contribution, but at the moment I have no time to review it properly. I still think the #229 implementation is more general and was close to being merged. Did you take a look? Why not start from it?
I did check it, but I felt my implementation is more specific to random forests, so I proceeded with my PR.
I have tested it across various datasets and it works perfectly.
Please add serde support: https://github.com/joelchen/linfa/blob/master/algorithms/linfa-trees/src/decision_trees/random_forest.rs
@maxprogrammer007, I've just merged #392, which introduces `linfa-ensemble` and a bagging algorithm.
This mirrors the scikit-learn structure a bit (though it is far from as complete).
To get a proper RandomForest algorithm in this new ensemble sub-crate, we need to add feature sub-sampling.
So if you agree, I suggest you reuse part of your code to implement a RandomForest in `linfa-ensemble` based on `EnsembleLearner` and get something like:
```rust
struct RandomForest<F: Float, L: Label> {
    ensemble_learner: EnsembleLearner<DecisionTree<F, L>>,
    bootstrap_features_ratio: f64,
    feature_indices: Vec<Vec<usize>>,
}
```
A step further would be to manage feature subsampling directly in `DecisionTree`; then `RandomForest<F, L>` would be just a thin wrapper around `EnsembleLearner<DecisionTree<F, L>>`.
If you proceed, maybe start over with a new PR and close this one. What do you think?
@relf Sure, I will proceed with a new PR.
superseded by #410