bonsai icon indicating copy to clipboard operation
bonsai copied to clipboard

aorsf support for `mtry_prop`

Open cgoo4 opened this issue 7 months ago • 2 comments

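aorsf is a great addition to bonsai! Any chance the aorsf engine could support `mtry_prop` (i.e. accept `count = FALSE` in `set_engine()`), the way the lightgbm engine does?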

library(tidymodels)
library(bonsai)

# Seed the RNG before the resamples are created.
set.seed(1)
# 5-fold cross-validation resamples of mtcars.
folds <- vfold_cv(mtcars, v = 5)

# Recipe with cyl as the outcome and all other columns as predictors;
# no preprocessing steps are added.
rec <- recipe(cyl ~ ., data = mtcars)

# Boosted-tree regression spec with mtry left as a tuning placeholder.
# `count = FALSE` is passed through set_engine() to the lightgbm engine.
# NOTE(review): presumably this makes the engine read mtry as a proportion
# of predictors rather than a count -- confirm against the bonsai docs.
mod_lgbm <- boost_tree(mtry = tune()) |> 
  set_engine("lightgbm", count = FALSE) |>
  set_mode("regression")

# Random-forest regression spec on the aorsf engine, configured the same
# way (tunable mtry, `count = FALSE`). The captured tuning output for this
# spec reports "count is unrecognized".
mod_aorsf <- rand_forest(mtry = tune()) |> 
  set_engine("aorsf", count = FALSE) |>
  set_mode("regression")

# One workflow per model spec; both share the same recipe.
lgbm_wflow <- workflow() |>
  add_model(mod_lgbm) |>
  add_recipe(rec)

aorsf_wflow <- workflow() |>
  add_model(mod_aorsf) |>
  add_recipe(rec)

# lightgbm supports mtry_prop
# Extract the workflow's tunable parameter set and replace the entry for
# mtry with mtry_prop() over the range c(0, 1).
param_info <-
  lgbm_wflow |>
  extract_parameter_set_dials() |>
  update(mtry = mtry_prop(c(0, 1)))

# Tune the lightgbm workflow over the folds using the overridden parameter
# set, scoring with RMSE only.
tune_grid(
  lgbm_wflow, 
  resamples = folds, 
  param_info = param_info,
  metrics = metric_set(rmse)
  )
#> # Tuning results
#> # 5-fold cross-validation 
#> # A tibble: 5 × 4
#>   splits         id    .metrics          .notes          
#>   <list>         <chr> <list>            <list>          
#> 1 <split [25/7]> Fold1 <tibble [10 × 5]> <tibble [0 × 3]>
#> 2 <split [25/7]> Fold2 <tibble [10 × 5]> <tibble [0 × 3]>
#> 3 <split [26/6]> Fold3 <tibble [10 × 5]> <tibble [0 × 3]>
#> 4 <split [26/6]> Fold4 <tibble [10 × 5]> <tibble [0 × 3]>
#> 5 <split [26/6]> Fold5 <tibble [10 × 5]> <tibble [0 × 3]>

# could aorsf do the same?
# Same override as for lightgbm, applied to the aorsf workflow: the mtry
# entry is replaced with mtry_prop() over c(0, 1). This overwrites the
# earlier `param_info`.
param_info <-
  aorsf_wflow |>
  extract_parameter_set_dials() |>
  update(mtry = mtry_prop(c(0, 1)))

# Identical tuning call, now on the aorsf workflow. In the captured output
# every model fit fails with "count is unrecognized".
tune_grid(
  aorsf_wflow, 
  resamples = folds, 
  param_info = param_info,
  metrics = metric_set(rmse)
  )
#> → A | error:   there were unrecognized arguments:
#>                  count is unrecognized - did you mean control?
#> There were issues with some computations   A: x1
#> There were issues with some computations   A: x50
#> 
#> Warning: All models failed. Run `show_notes(.Last.tune.result)` for more
#> information.
#> # Tuning results
#> # 5-fold cross-validation 
#> # A tibble: 5 × 4
#>   splits         id    .metrics .notes           
#>   <list>         <chr> <list>   <list>           
#> 1 <split [25/7]> Fold1 <NULL>   <tibble [10 × 3]>
#> 2 <split [25/7]> Fold2 <NULL>   <tibble [10 × 3]>
#> 3 <split [26/6]> Fold3 <NULL>   <tibble [10 × 3]>
#> 4 <split [26/6]> Fold4 <NULL>   <tibble [10 × 3]>
#> 5 <split [26/6]> Fold5 <NULL>   <tibble [10 × 3]>
#> 
#> There were issues with some computations:
#> 
#>   - Error(s) x50: there were unrecognized arguments:   count is unrecognized - did ...
#> 
#> Run `show_notes(.Last.tune.result)` for more information.

Created on 2024-07-21 with reprex v2.1.1

cgoo4 avatar Jul 21 '24 14:07 cgoo4