`aorsf` engine: model fit fails if `mtry` is specified
Hi,
The model fit fails if `mtry` is specified for the `aorsf` engine. If it is not specified, the fit works with the engine's default values.
```r
library(bonsai)
#> Loading required package: parsnip

# This works with the default mtry value
rf_mod <-
  rand_forest() %>%
  set_engine(engine = "aorsf") %>%
  set_mode(mode = "regression") %>%
  set_args(min_n = 1, trees = 2, importance = "permute") %>%
  fit(
    formula = mpg ~ .,
    data = mtcars
  )
rf_mod
#> parsnip model object
#>
#> ---------- Oblique random regression forest
#>
#> Linear combinations: Accelerated Linear regression
#> N observations: 32
#> N trees: 2
#> N predictors total: 10
#> N predictors per node: 4
#> Average leaves per tree: 7.5
#> Min observations in leaf: 1
#> OOB stat value: 0.27
#> OOB stat type: RSQ
#> Variable importance: permute
#>
#> -----------------------------------------

# The error occurs once mtry is specified
rf_mod_w_mtry <-
  rand_forest() %>%
  set_engine(engine = "aorsf") %>%
  set_mode(mode = "regression") %>%
  set_args(mtry = 3, min_n = 1, trees = 2, importance = "permute") %>%
  fit(
    formula = mpg ~ .,
    data = mtcars
  )
#> Error in ncol(source): object 'x' not found
```
Created on 2024-08-08 with reprex v2.0.2
Thank you in advance and best regards
Thanks for the issue! Just confirming that I can reproduce this, and that 1) it does seem to be aorsf-specific (i.e., xgboost is not affected) and 2) it doesn't seem to be caused by any changes in parsnip (the issue persists with parsnip v1.0.0). `min_cols()` seems to be evaluated in a different environment than usual.
Just saw this and thought I'd check it out. The issue appears to occur in `parsnip:::make_form_call()`.
If we modify:

```r
if (object$engine == "spark") {
  env$x <- env$data
}
```

to

```r
if (object$engine %in% c("spark", "aorsf")) {
  env$x <- env$data
}
```

the `x` object is found when you specify `mtry`.
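The effect of that change can be sketched in base R. This is not parsnip's code: `min_cols_sketch()` is a hypothetical stand-in for parsnip's `min_cols()`, assumed here to simply cap the requested count at `ncol(source)`, and the evaluation environment is built by hand to hold only `formula`, `data`, and `weights`, the elements reported for the `form_form()` path.

```r
# Hypothetical stand-in for parsnip's min_cols(): cap the requested
# number of columns at ncol(source). Not parsnip's actual implementation.
min_cols_sketch <- function(num_cols, source) {
  min(num_cols, ncol(source))
}

# Hand-built evaluation environment with only formula, data and weights.
# baseenv() as parent keeps any global `x` from being found by accident.
env <- new.env(parent = baseenv())
env$formula <- mpg ~ .
env$data <- mtcars
env$weights <- NULL

# The proposed change: alias `data` as `x` for these engines
engine <- "aorsf"
if (engine %in% c("spark", "aorsf")) {
  env$x <- env$data
}

# Embed the function object in the call so it is found even though
# env's parent is baseenv(); the `x` argument is still looked up in env
mtry_call <- as.call(list(min_cols_sketch, 3, quote(x)))
capped <- eval(mtry_call, envir = env)
capped
#> [1] 3

# `data` still holds the outcome column: mtcars has 11 columns but only
# 10 predictors, so an oversized request is capped at 11, not 10
capped_high <- eval(as.call(list(min_cols_sketch, 20, quote(x))), envir = env)
capped_high
#> [1] 11
```

The second evaluation points at one side effect worth keeping in mind with this approach: a cap computed from `data` is one column larger than a cap computed from the predictors alone.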
@simonpcouch, would this solution be too hacky? A more general approach might be to change the call to `aorsf::orsf()` so that it uses `mtry = min_cols(~3, data)` instead of `mtry = min_cols(~3, x)`.
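A base-R sketch of this alternative, under the same assumptions as before: `min_cols_sketch()` is a hypothetical stand-in for `min_cols()`, and the environment is built by hand with only `formula`, `data`, and `weights`. The `x` symbol in the generated argument is swapped for `data` after the fact purely to show the effect; this is not a claim about where the change would be made inside parsnip.

```r
# Hypothetical stand-in for parsnip's min_cols(); not the real implementation
min_cols_sketch <- function(num_cols, source) {
  min(num_cols, ncol(source))
}

# Environment with only formula, data and weights (no `x`)
env <- new.env(parent = baseenv())
env$formula <- mpg ~ .
env$data <- mtcars
env$weights <- NULL

# The argument as currently generated references `x`
mtry_call <- as.call(list(min_cols_sketch, 3, quote(x)))

# Replace the symbol `x` with `data` (do.call + substitute idiom)
mtry_call_data <- do.call("substitute", list(mtry_call, list(x = quote(data))))

# The lookup now succeeds without adding anything to env
mtry_value <- eval(mtry_call_data, envir = env)
mtry_value
#> [1] 3
```

Because `data` includes the outcome, the cap here is also computed against all 11 columns of `mtcars` rather than the 10 predictors.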
This is what happens:
For the ranger engine, you end up in `form_xy()`, where `x` is placed in the evaluation environment:
https://github.com/tidymodels/parsnip/blob/6d4c68477ef81a9a7fe044cd628b3d80e451fe3b/R/fit_helpers.R#L148-L157
With the aorsf engine, on the other hand, you end up in `form_form()`:
https://github.com/tidymodels/parsnip/blob/6d4c68477ef81a9a7fe044cd628b3d80e451fe3b/R/fit_helpers.R#L40-L55
The created call is

```r
aorsf::orsf(formula = mpg ~ ., data = data, mtry = min_cols(~3, x),
            n_tree = ~2, leaf_min_obs = ~1, n_thread = 1, verbose_progress = FALSE)
```

and the evaluation environment contains only `formula`, `data`, and `weights`.
That leads to the error: `min_cols(~3, x)` tries to evaluate `x`, which isn't available anywhere.
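The failure can be reproduced in base R without parsnip or aorsf. In this sketch `min_cols_sketch()` is a hypothetical stand-in for `min_cols()`, and the environment is assembled by hand to contain only `formula`, `data`, and `weights`, as described above. Because `source` is a lazily evaluated argument, the missing `x` is only noticed when `ncol(source)` forces it, which is consistent with the reprex reporting the error in `ncol(source)`.

```r
# Hypothetical stand-in for parsnip's min_cols(); not the real implementation
min_cols_sketch <- function(num_cols, source) {
  min(num_cols, ncol(source))
}

# Environment with only formula, data and weights; baseenv() as parent
# so a stray global `x` cannot mask the problem
env <- new.env(parent = baseenv())
env$formula <- mpg ~ .
env$data <- mtcars
env$weights <- NULL

# Equivalent of min_cols(~3, x) from the generated call
mtry_call <- as.call(list(min_cols_sketch, 3, quote(x)))

# Forcing `source` inside ncol() triggers the lookup of `x`, which fails
err_msg <- tryCatch(
  eval(mtry_call, envir = env),
  error = function(e) conditionMessage(e)
)
err_msg
#> [1] "object 'x' not found"
```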
We could either do what @bcjaeger suggests or add `env$x <- env$data` in `form_form()`.