
[Feature] Add Random Forest Classifier to linfa-trees

Open maxprogrammer007 opened this issue 6 months ago • 4 comments

Description

I would like to contribute a new module to the linfa-trees crate that implements the Random Forest algorithm for classification tasks. This will expand linfa-trees from single decision trees into ensemble learning, aligning closely with scikit-learn's functionality in Python.


🚀 Motivation

Random Forests are a powerful ensemble learning method used widely in classification tasks. They provide:

  • Robustness to overfitting

  • Better generalization than single trees

  • Feature importance estimates

Currently, linfa-trees provides support for single decision trees. By adding Random Forests, we unlock ensemble learning for the Rust ML ecosystem.


Proposed Design

🔹 New Module

A new file will be added:

bash
linfa-trees/src/decision_trees/random_forest.rs

This will include:

  • RandomForestClassifier<F: Float>

  • RandomForestParams<F> (unchecked)

  • RandomForestValidParams<F> (checked)
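As a rough illustration of the unchecked/checked split listed above, here is a minimal, std-only sketch. It does not use linfa's actual `ParamGuard` trait; the fields, default values, `check` method, and `ParamError` variants are all assumptions chosen to mirror the API preview, and `f64` stands in for the generic `F: Float`.

```rust
/// Checked hyperparameters. Values of this type can only be obtained through
/// `RandomForestParams::check`, so they are known to be valid.
/// The field set is hypothetical.
#[derive(Debug, Clone, PartialEq)]
pub struct RandomForestValidParams {
    n_trees: usize,
    feature_subsample: f64,
    max_depth: Option<usize>,
}

/// Unchecked builder wrapper: setters accept anything, validation is deferred.
#[derive(Debug, Clone, PartialEq)]
pub struct RandomForestParams(RandomForestValidParams);

/// Hypothetical error type for rejected hyperparameters.
#[derive(Debug, Clone, PartialEq)]
pub enum ParamError {
    ZeroTrees,
    InvalidFeatureSubsample(f64),
}

impl RandomForestParams {
    /// Placeholder defaults, not taken from any real implementation.
    pub fn new() -> Self {
        RandomForestParams(RandomForestValidParams {
            n_trees: 100,
            feature_subsample: 1.0,
            max_depth: None,
        })
    }

    pub fn n_trees(mut self, n: usize) -> Self {
        self.0.n_trees = n;
        self
    }

    pub fn feature_subsample(mut self, fraction: f64) -> Self {
        self.0.feature_subsample = fraction;
        self
    }

    pub fn max_depth(mut self, depth: Option<usize>) -> Self {
        self.0.max_depth = depth;
        self
    }

    /// Validate and convert into the checked type.
    pub fn check(self) -> Result<RandomForestValidParams, ParamError> {
        let p = self.0;
        if p.n_trees == 0 {
            return Err(ParamError::ZeroTrees);
        }
        // Written as a negated conjunction so that NaN is rejected too.
        if !(p.feature_subsample > 0.0 && p.feature_subsample <= 1.0) {
            return Err(ParamError::InvalidFeatureSubsample(p.feature_subsample));
        }
        Ok(p)
    }
}
```

In this sketch only the checked type would be handed to the fitting code, so training never has to re-validate its inputs.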

🔹 Trait Implementations

I will implement the following traits according to linfa conventions:

  • ParamGuard for parameter validation

  • Fit to train the forest using bootstrapped data and random feature subsetting

  • PredictInplace and Predict to perform inference via majority voting
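The three ingredients named above (bootstrapped rows, random feature subsetting, majority voting) can be sketched without any dependencies. This is only an illustration under stated assumptions: the helper names are hypothetical, the small `Lcg` generator is a stand-in for a proper RNG such as the `rand` crate, and nothing here reflects how the eventual PR or linfa itself implements these steps.

```rust
use std::collections::BTreeMap;

/// Tiny linear congruential generator so the sketch needs no external crates.
/// A stand-in only; it is not suitable for serious statistical use.
struct Lcg(u64);

impl Lcg {
    /// Return a pseudo-random index in `0..bound` (`bound` must be non-zero).
    fn next_usize(&mut self, bound: usize) -> usize {
        self.0 = self
            .0
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        ((self.0 >> 33) as usize) % bound
    }
}

/// Draw `n` row indices with replacement: one bootstrap sample per tree.
fn bootstrap_indices(n: usize, rng: &mut Lcg) -> Vec<usize> {
    (0..n).map(|_| rng.next_usize(n)).collect()
}

/// Pick `k` distinct feature indices out of `n_features`
/// using a partial Fisher-Yates shuffle.
fn feature_subset(n_features: usize, k: usize, rng: &mut Lcg) -> Vec<usize> {
    let mut idx: Vec<usize> = (0..n_features).collect();
    let k = k.min(n_features);
    for i in 0..k {
        let j = i + rng.next_usize(n_features - i);
        idx.swap(i, j);
    }
    idx.truncate(k);
    idx
}

/// Majority vote over the class labels predicted by each tree.
/// Ties are broken in favour of the smallest label; empty input gives `None`.
fn majority_vote(votes: &[usize]) -> Option<usize> {
    let mut counts: BTreeMap<usize, usize> = BTreeMap::new();
    for &v in votes {
        *counts.entry(v).or_insert(0) += 1;
    }
    let mut best: Option<(usize, usize)> = None;
    // BTreeMap iterates labels in ascending order, so requiring a strictly
    // greater count keeps the smallest label on ties.
    for (label, count) in counts {
        if best.map_or(true, |(_, c)| count > c) {
            best = Some((label, count));
        }
    }
    best.map(|(label, _)| label)
}
```

Under these assumptions, fitting would train each tree on the rows from `bootstrap_indices` restricted to the columns from `feature_subset`, and prediction would call `majority_vote` on the per-tree labels for each sample.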

🔹 Example

An example will be added in:

bash
linfa-trees/examples/iris_random_forest.rs

Using the Iris dataset from linfa-datasets.

🔹 Benchmark (Optional)

If approved, I can also add a benchmark using Criterion:

bash
linfa-trees/benches/random_forest.rs

File Integration Plan

  • src/lib.rs: Re-export random_forest::*

  • src/decision_trees/mod.rs: pub mod random_forest;

  • README.md: Update with a section on Random Forests and example usage

  • examples/iris_random_forest.rs: Demonstrates training and evaluation


📦 API Preview

rust
let model = RandomForest::params()
    .n_trees(100)
    .feature_subsample(0.8)
    .max_depth(Some(10))
    .fit(&dataset)?;

let predictions = model.predict(&dataset);
let acc = predictions.confusion_matrix(&dataset)?.accuracy();

✅ Conformity with CONTRIBUTING.md

  • Uses Float trait for f32/f64 compatibility

  • Follows the Params → ValidParams validation pattern

  • Implements Fit, Predict, and PredictInplace using Dataset

  • Optional serde support via feature flag

  • Will include unit tests and optionally benchmarks



🙋‍♂️ Request

Please let me know if you're open to this contribution. I'd be happy to align with maintainers on:

  • Feature scope (classifier first, regressor later?)

  • Benchmarking standards

  • Integration strategy (e.g., reuse of DecisionTree)

Looking forward to your guidance!

maxprogrammer007 • May 17 '25 16:05

Thanks for your thorough description. This looks good to me, please proceed with a PR!

relf • May 19 '25 09:05

Sorry, I just noticed prior art in #229. It would be great to take a look at it before jumping into a whole new implementation.

relf • May 19 '25 14:05

@relf Sure, I will look into #229 and then prepare my PR.

maxprogrammer007 • May 19 '25 15:05

@relf

I have opened a PR, please take a look. All checks have passed and I have successfully tested the module.

PR link - https://github.com/rust-ml/linfa/pull/390

maxprogrammer007 • May 20 '25 12:05