
[Feature] Add RandomForestClassifier to linfa-trees

Open · maxprogrammer007 opened this issue 11 months ago · 6 comments

This PR extends the linfa-trees crate by introducing a new Random Forest classifier. It builds upon the existing Decision Tree implementation to provide an ensemble method that typically outperforms a single tree by training many trees on bootstrapped subsets of both rows and columns (features), and then aggregating their predictions via majority voting.
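The aggregation step described above (each tree votes, the forest returns the most-voted class per sample) can be sketched as follows. This is an illustrative stand-alone function, not the actual linfa-trees API; the name `majority_vote` and the plain `Vec` types are assumptions for the sketch.

```rust
/// Hypothetical sketch of majority voting: `per_tree_preds` holds one
/// prediction vector per tree (assumed non-empty), and the result is the
/// most-voted class index for each sample.
fn majority_vote(per_tree_preds: &[Vec<usize>], n_classes: usize) -> Vec<usize> {
    let n_samples = per_tree_preds[0].len();
    let mut out = Vec::with_capacity(n_samples);
    for i in 0..n_samples {
        // Count how many trees voted for each class on sample i.
        let mut votes = vec![0usize; n_classes];
        for preds in per_tree_preds {
            votes[preds[i]] += 1;
        }
        // Argmax over vote counts; ties resolve to the lowest class index.
        let winner = votes
            .iter()
            .enumerate()
            .max_by_key(|&(idx, &count)| (count, std::cmp::Reverse(idx)))
            .map(|(idx, _)| idx)
            .unwrap();
        out.push(winner);
    }
    out
}
```

In the real implementation the per-tree predictions come back as `Array1<usize>` from `tree.predict`, but the voting logic is the same.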


🚀 What's Added

  1. src/decision_trees/random_forest.rs

    • RandomForestParams / RandomForestValidParams: hyperparameters (n_trees, max_depth, feature_subsample, seed) with validation via ParamGuard.

    • RandomForestClassifier: stores a Vec<DecisionTree> and the per-tree feature indices used for training.

    • Fit implementation:

      • Bootstrap rows.
      • Randomly subsample features per tree.
      • Train each DecisionTree on its slice.
    • Predict implementation:

      • For each tree, select the same feature slice on the test data.
      • Invoke tree.predict(&sub_x) (returns Array1<usize>).
      • Accumulate votes and return the argmax for each sample.
  2. Exports

    • Updated src/decision_trees/mod.rs and src/lib.rs to re-export RandomForestParams and RandomForestClassifier.
  3. Example

    • examples/iris_random_forest.rs: demonstrates loading the Iris dataset, training a Random Forest, printing the confusion matrix and accuracy.
  4. Unit Test

    • tests/random_forest.rs: an integration test asserting ≥ 90% accuracy on Iris with a fixed RNG seed for reproducibility.
  5. Dependencies

    • Added rand = "0.8" to linfa-trees/Cargo.toml for RNG and sampling utilities.
  6. README

    • Extended README.md with a "Random Forest Classifier" section, a usage example, and run instructions.
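The per-tree sampling in the fit step (bootstrap rows with replacement, then pick a distinct feature subset) can be sketched like this. To keep the example self-contained, a tiny seeded LCG stands in for the `rand` crate that the PR actually depends on; the function names are illustrative, not the PR's API.

```rust
/// Minimal seeded generator standing in for `rand` in this sketch.
struct Lcg(u64);

impl Lcg {
    fn next(&mut self) -> u64 {
        // Standard 64-bit LCG constants (Knuth / Numerical Recipes).
        self.0 = self
            .0
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        self.0
    }
    fn below(&mut self, n: usize) -> usize {
        (self.next() % n as u64) as usize
    }
}

/// Sample `n_rows` row indices with replacement (the bootstrap).
fn bootstrap_rows(n_rows: usize, rng: &mut Lcg) -> Vec<usize> {
    (0..n_rows).map(|_| rng.below(n_rows)).collect()
}

/// Pick `k` distinct feature indices out of `n_features` via a
/// partial Fisher-Yates shuffle: shuffle the first `k` slots, keep them.
fn subsample_features(n_features: usize, k: usize, rng: &mut Lcg) -> Vec<usize> {
    let mut idx: Vec<usize> = (0..n_features).collect();
    for i in 0..k {
        let j = i + rng.below(n_features - i);
        idx.swap(i, j);
    }
    idx.truncate(k);
    idx
}
```

Storing the `Vec<usize>` returned by `subsample_features` per tree is what lets predict later select the same feature slice on the test data.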

🧐 Motivation

  • Ensemble performance: Random Forests often reduce variance and improve generalization compared to a single decision tree.
  • Feature importance: Subsampling features per tree provides insight into feature usefulness.
  • API consistency: Follows Linfa's Fit / Predict / ParamGuard conventions and integrates cleanly with Dataset.

πŸ” Files Changed

algorithms/linfa-trees/
├─ Cargo.toml            # + rand = "0.8"
├─ src/
│  ├─ decision_trees/
│  │  ├─ algorithm.rs    # no change
│  │  ├─ mod.rs          # + pub mod random_forest;
│  │  └─ random_forest.rs # NEW
│  └─ lib.rs             # + pub use decision_trees::random_forest::{…}
├─ examples/
│  └─ iris_random_forest.rs # NEW
└─ tests/
   └─ random_forest.rs   # NEW

📦 Example

cargo run --release --example iris_random_forest
classes    | 0  | 1  | 2
--------------------------------
0          | 50 |  0 |  0
1          |  0 | 48 |  2
2          |  0 |  1 | 49

Accuracy: 0.97

✅ Checklist

  • [x] Implements ParamGuard for hyperparameter validation
  • [x] Implements Fit<Array2<F>, Array1<usize>>
  • [x] Implements Predict<Array2<F>, Array1<usize>> with correct feature-slice logic
  • [x] Example runs without errors (cargo run --example iris_random_forest)
  • [x] Unit test passes (cargo test)
  • [x] README updated with usage snippet
  • [x] rand dependency added

Thank you for reviewing! I'm happy to address any feedback or suggestions.

maxprogrammer007 avatar May 20 '25 11:05 maxprogrammer007

Codecov Report

❌ Patch coverage is 37.50% with 40 lines in your changes missing coverage. Please review.
✅ Project coverage is 36.21%. Comparing base (11ea07a) to head (35b425b).
⚠️ Report is 10 commits behind head on master.

Files with missing lines | Patch % | Lines
...ms/linfa-trees/src/decision_trees/random_forest.rs | 37.50% | 40 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master     #390      +/-   ##
==========================================
+ Coverage   36.09%   36.21%   +0.12%     
==========================================
  Files          99      100       +1     
  Lines        6502     6566      +64     
==========================================
+ Hits         2347     2378      +31     
- Misses       4155     4188      +33     

☂️ View full report in Codecov by Sentry.


codecov[bot] avatar May 20 '25 12:05 codecov[bot]

Thanks for your contribution, but at the moment I have no time to review it properly. I still think the #229 implementation is more general and was close to being merged. Did you take a look? Why not start from it?

relf avatar May 21 '25 17:05 relf

I did check it, but I felt my PR is more specific to random forests, hence I proceeded with it.

I have tested it across various datasets and it works perfectly.

maxprogrammer007 avatar May 21 '25 18:05 maxprogrammer007

Please add serde support: https://github.com/joelchen/linfa/blob/master/algorithms/linfa-trees/src/decision_trees/random_forest.rs

joelchen avatar May 25 '25 22:05 joelchen

@maxprogrammer007, I've just merged #392, which introduces linfa-ensemble and a bagging algorithm. This loosely mirrors the scikit-learn structure (though it is far from as complete). To get a proper RandomForest algorithm into this new ensemble sub-crate, we need to add feature sub-sampling.

So if you agree, I suggest you reuse part of your code to implement a RandomForest in linfa-ensemble based on EnsembleLearner, ending up with something like:

struct RandomForest<F: Float, L: Label> {
    ensemble_learner: EnsembleLearner<DecisionTree<F, L>>,
    bootstrap_features_ratio: f64,
    feature_indices: Vec<Vec<usize>>
}

A step further would be to manage feature subsampling directly in DecisionTree; then RandomForest<F, L> would be just a thin wrapper around EnsembleLearner<DecisionTree<F, L>>. If you proceed, maybe start over with a new PR and close this one. What do you think?

relf avatar May 26 '25 13:05 relf

@relf Sure, I will proceed with a new PR.

maxprogrammer007 avatar May 26 '25 13:05 maxprogrammer007

superseded by #410

relf avatar Oct 23 '25 14:10 relf