Added per-class probability prediction for random forests

Open AlanRace opened this issue 1 year ago • 10 comments

Added a function that predicts the probability of each class for each observation.

let probabilities = forest.predict_probs(&data).unwrap();

probabilities is a KxC matrix, where K is the number of observations and C is the number of classes. Probabilities are calculated as the fraction of trees in the random forest that predicted the given class.
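The vote-fraction rule described above can be sketched in plain Rust. This is only an illustration and not the code in this PR: the function name `vote_fractions` and the nested-`Vec` layout are hypothetical, chosen to keep the example free of smartcore types.

```rust
/// Turn per-tree class votes into per-class probabilities.
///
/// `votes[t][i]` is the class index that tree `t` predicted for observation `i`
/// (hypothetical layout, assumed for this sketch only; every tree is assumed
/// to vote on the same number of observations).
/// Returns K rows (observations) by C columns (classes).
fn vote_fractions(votes: &[Vec<usize>], n_classes: usize) -> Vec<Vec<f64>> {
    let n_trees = votes.len();
    let n_obs = votes.first().map_or(0, |v| v.len());
    let mut probs = vec![vec![0.0; n_classes]; n_obs];

    // Count how many trees voted for each class, per observation.
    for tree_votes in votes {
        for (i, &class) in tree_votes.iter().enumerate() {
            probs[i][class] += 1.0;
        }
    }

    // Divide the counts by the number of trees to get fractions.
    for row in probs.iter_mut() {
        for p in row.iter_mut() {
            *p /= n_trees as f64;
        }
    }
    probs
}
```

With four trees and two observations, for example, an observation that three trees assign to class 0 and one tree assigns to class 1 gets the row `[0.75, 0.25]`. Each row sums to 1 because every tree casts exactly one vote per observation.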

Answer to #50 for random forests.

AlanRace avatar Jul 11 '22 14:07 AlanRace

@AlanRace thank you for your contribution to Smartcore! The change looks good, but you might want to address the clippy warnings and increase test coverage so that this code passes the automatic checks.

VolodymyrOrlov avatar Jul 14 '22 00:07 VolodymyrOrlov

I have added a test here but there is something wrong, please take a look: https://github.com/AlanRace/smartcore/pull/1

Mec-iS avatar Aug 24 '22 11:08 Mec-iS

@Mec-iS Thanks for supplying the test - I am guessing there was a problem due to row-major vs column-major storage of DenseMatrix? Swapping the number of rows and columns in your test and then transposing the matrix results in a passing test.

AlanRace avatar Aug 29 '22 14:08 AlanRace

Codecov Report

Merging #138 (7f7b2ed) into development (b4a807e) will increase coverage by 0.60%. The diff coverage is 100.00%.

@@               Coverage Diff               @@
##           development     #138      +/-   ##
===============================================
+ Coverage        83.40%   84.01%   +0.60%     
===============================================
  Files               78       81       +3     
  Lines             8377     8751     +374     
===============================================
+ Hits              6987     7352     +365     
- Misses            1390     1399       +9     

| Impacted Files | Coverage Δ | |
|---|---|---|
| src/ensemble/random_forest_classifier.rs | 75.58% <100.00%> (+4.54%) | :arrow_up: |
| src/linalg/evd.rs | 86.06% <0.00%> (ø) | |
| src/linear/lasso_optimizer.rs | 94.11% <0.00%> (ø) | |
| src/algorithm/neighbour/mod.rs | 78.57% <0.00%> (ø) | |
| src/algorithm/neighbour/distances.rs | 66.66% <0.00%> (ø) | |
| src/preprocessing/numerical.rs | 88.88% <0.00%> (ø) | |
| src/algorithm/neighbour/fastpair.rs | 95.67% <0.00%> (ø) | |
| src/linalg/naive/dense_matrix.rs | 80.11% <0.00%> (+0.89%) | :arrow_up: |
| src/optimization/first_order/lbfgs.rs | 94.44% <0.00%> (+1.58%) | :arrow_up: |
| src/linalg/mod.rs | 58.57% <0.00%> (+5.49%) | :arrow_up: |

... and 1 more


codecov-commenter avatar Aug 29 '22 14:08 codecov-commenter

thanks @AlanRace

it is probably better to adhere to the DenseMatrix format, so it would be nice for the method to return the transposed values or directly a DenseMatrix.

Mec-iS avatar Aug 29 '22 14:08 Mec-iS

Maybe I am misunderstanding, but predict_probs does return a DenseMatrix.

It looks like DenseMatrix::from_vec (which is called from DenseMatrix::from_array as part of your test) assumes that the given vector is in row-major form, but the entered values in the test are in column-major form.

Would you prefer that the matrix returned from predict_probs is num classes x num observations, rather than the current num observations x num classes?
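To illustrate the storage-order mismatch discussed above, here is a minimal sketch using flat `f64` buffers. It does not use smartcore's `DenseMatrix`; the helper names `get_row_major` and `col_major_to_row_major` are hypothetical and exist only for this example.

```rust
/// Read element (r, c) from a flat buffer, assuming row-major storage.
fn get_row_major(data: &[f64], n_cols: usize, r: usize, c: usize) -> f64 {
    data[r * n_cols + c]
}

/// Re-order a column-major buffer into row-major order.
/// The logical n_rows x n_cols matrix is unchanged; only the layout differs.
fn col_major_to_row_major(data: &[f64], n_rows: usize, n_cols: usize) -> Vec<f64> {
    let mut out = Vec::with_capacity(data.len());
    for r in 0..n_rows {
        for c in 0..n_cols {
            // In column-major storage, element (r, c) lives at c * n_rows + r.
            out.push(data[c * n_rows + r]);
        }
    }
    out
}
```

Take the 2x3 matrix with rows `[1, 2, 3]` and `[4, 5, 6]`. Its column-major buffer is `[1, 4, 2, 5, 3, 6]`. Reading that buffer as a row-major 2x3 matrix puts 4 at position (0, 1) instead of 2, while reading it as a row-major 3x2 matrix yields exactly the transpose. That is consistent with the observation that swapping the dimensions and transposing makes the test pass.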

AlanRace avatar Aug 30 '22 07:08 AlanRace

yeah, the shape returned by from_vector and from_array is probably handier. thanks again

Mec-iS avatar Aug 30 '22 09:08 Mec-iS

@morenol @VolodymyrOrlov could you please take a look at the failing WASM test? It looks like we get different results for different targets. Rounding seems to work differently for WASM; the results are close but not close enough.
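If the discrepancy turns out to be floating-point noise between targets, one common way to make such a test portable is to compare against a tolerance rather than exact values. The sketch below is an assumption about a possible fix, not what the thread settled on; the helper name `approx_eq` and the tolerance values are hypothetical.

```rust
/// Element-wise comparison of two slices with an absolute tolerance.
/// Returns false if the lengths differ or any pair differs by more than `tol`.
fn approx_eq(a: &[f64], b: &[f64], tol: f64) -> bool {
    a.len() == b.len() && a.iter().zip(b).all(|(x, y)| (x - y).abs() <= tol)
}
```

A test could then assert `approx_eq(&predicted, &expected, 1e-9)` instead of exact equality. The right tolerance depends on how far apart the targets really are, so it would need to be chosen from the observed differences.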

Mec-iS avatar Aug 30 '22 10:08 Mec-iS

Hello guys, when will this function be available? I really need it in order to perform model ensembling!

Thanks a lot

alexis2804 avatar Oct 03 '22 09:10 alexis2804

@alexis2804 unfortunately we have problems with some tests, you can take a look at them by fetching this branch

Mec-iS avatar Oct 03 '22 11:10 Mec-iS

moved to #211 to solve conflicts

Mec-iS avatar Oct 31 '22 19:10 Mec-iS