Trait bounds of RandomForest::fit conflict with those of predict and other metrics

Open DaGaMs opened this issue 5 months ago • 5 comments

I am trying to perform cross validation of a RandomForest like this:

let cv_score = cross_validate(
    RandomForestClassifier::new(),
    x_train,
    y_train,
    RandomForestClassifierParameters::default()
        .with_criterion(SplitCriterion::ClassificationError)
        .with_n_trees(*n_tree)
        .with_m(*m_feat)
        .with_min_samples_split(*m_split)
        .with_min_samples_leaf(*m_leaf),
    &KFold::default().with_n_splits(5),
    &precision
).unwrap();

x_train is a DenseMatrix<f32> and y_train is a Vec<u16> of 0 and 1.

The problem is that RandomForestClassifier::fit expects y to be Number + Ord, whereas precision expects y to be Number + RealNumber + FloatNumber. As far as I can see, RealNumber can never be Ord, so precision, f1 and roc_auc_score cannot be used in cross_validate with RandomForestClassifier directly.
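The conflict can be reproduced without smartcore. The sketch below is only an illustration: `FloatLike`, `distinct_labels` and `mean_score` are hypothetical stand-ins, not smartcore's traits or functions. It relies on one fact about the Rust standard library: `f32` and `f64` implement `PartialOrd` but not `Ord` (NaN has no total order), while integer types such as `u16` implement `Ord`.

```rust
/// Hypothetical float-only bound, standing in for a "real number" trait.
/// It is implemented for f32 and f64 only, so integer labels do not satisfy it.
trait FloatLike: Copy + PartialOrd {
    fn to_f64(self) -> f64;
}

impl FloatLike for f32 {
    fn to_f64(self) -> f64 {
        self as f64
    }
}

impl FloatLike for f64 {
    fn to_f64(self) -> f64 {
        self
    }
}

/// Stand-in for a `fit`-style function: it needs a total order on the labels
/// (here, to sort and deduplicate the class labels).
fn distinct_labels<T: Ord + Copy>(y: &[T]) -> Vec<T> {
    let mut labels = y.to_vec();
    labels.sort();
    labels.dedup();
    labels
}

/// Stand-in for a metric-style function: it only accepts float-like values.
fn mean_score<T: FloatLike>(y: &[T]) -> f64 {
    if y.is_empty() {
        return 0.0;
    }
    y.iter().map(|v| v.to_f64()).sum::<f64>() / y.len() as f64
}

fn main() {
    let y_int: Vec<u16> = vec![0, 1, 1, 0];
    let y_real: Vec<f32> = vec![0.0, 1.0, 1.0, 0.0];

    println!("{:?}", distinct_labels(&y_int));
    println!("{}", mean_score(&y_real));

    // Neither of the following lines compiles, which is the conflict in miniature:
    // distinct_labels(&y_real); // error: `f32` does not implement `Ord`
    // mean_score(&y_int);       // error: `u16` does not implement `FloatLike`
}
```

With bounds shaped like this, no single label type can be passed to both functions, so a generic function that forwards the same `y` to both cannot be instantiated.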

For now, I worked around it by defining my own precision function that converts u16 to f32 before calling precision, but I suppose this should be fixed in the framework. The logical thing to do IMO would be to stop requiring Ord for y in RandomForestClassifier::fit?
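A minimal sketch of that kind of workaround, kept self-contained so it compiles without smartcore. `precision_f32` is a hypothetical stand-in for a float-only metric (binary precision with 1.0 as the positive class), and `precision_u16` is the adapter that converts the integer labels first. The exact signature `cross_validate` expects for its scoring argument is not shown here, so a real adapter would need to match it.

```rust
/// Hypothetical stand-in for a float-only precision metric.
/// Binary precision = TP / (TP + FP), treating 1.0 as the positive class.
fn precision_f32(y_true: &[f32], y_pred: &[f32]) -> f64 {
    let mut true_pos = 0usize;
    let mut false_pos = 0usize;
    for (t, p) in y_true.iter().zip(y_pred.iter()) {
        if *p == 1.0 {
            if *t == 1.0 {
                true_pos += 1;
            } else {
                false_pos += 1;
            }
        }
    }
    if true_pos + false_pos == 0 {
        // No positive predictions: report 0.0 rather than dividing by zero.
        0.0
    } else {
        true_pos as f64 / (true_pos + false_pos) as f64
    }
}

/// Adapter: accepts integer labels, converts them losslessly to f32,
/// then delegates to the float-only metric.
fn precision_u16(y_true: &[u16], y_pred: &[u16]) -> f64 {
    let t: Vec<f32> = y_true.iter().map(|&v| f32::from(v)).collect();
    let p: Vec<f32> = y_pred.iter().map(|&v| f32::from(v)).collect();
    precision_f32(&t, &p)
}

fn main() {
    let y_true: Vec<u16> = vec![1, 0, 1, 1, 0];
    let y_pred: Vec<u16> = vec![1, 1, 1, 0, 0];
    // Three positive predictions, two of them correct.
    println!("precision = {:.4}", precision_u16(&y_true, &y_pred));
}
```

The conversion allocates two temporary vectors per call, which is usually negligible next to the cost of fitting a forest in each fold.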

DaGaMs avatar Jul 13 '25 07:07 DaGaMs

The initial reasoning behind this was that classification labels need to be integers; otherwise it is ambiguous which element belongs to which class. Integers also have the advantage of being ordered.

If you think it is necessary to support real-valued labels, you can propose your own implementation of another classifier, such as a RandomForestContinuous, that implements your solution.

@DanielLacina please advise, as you worked on trees recently

Mec-iS avatar Jul 13 '25 11:07 Mec-iS

I agree that labels should be ordinal, but for some reason most of the metrics expect them to be reals!

In other words, I think the arguments to precision, f1 etc need to have their trait bounds changed.

DaGaMs avatar Jul 13 '25 11:07 DaGaMs

I understand, but this would be a breaking change for existing programs, so we need to think through whether it is really necessary and, if so, how to improve the situation

Mec-iS avatar Jul 13 '25 11:07 Mec-iS

I see that - I just don't know how anyone would use cross_validate with RFs in the current situation 🤷‍♂️

I suppose the problem is that precision etc are meant to also work with regression algorithms, and that's why they want real numbers?

DaGaMs avatar Jul 13 '25 11:07 DaGaMs

I understand the problem, but the solution may not be the one you suggest, so it is better to think about it and find viable options for a long-term solution rather than just a temporary patch. Keeping y as an integer is the right approach, I think.

For example: https://gist.github.com/Mec-iS/de4865c13712314600fc20533277b884

Mec-iS avatar Jul 13 '25 12:07 Mec-iS