yardstick integration
If you were good with a yardstick dependency, we could expand the list of metrics that can be used.
The more metrics the merrier! Would you be able to put together a PR for this too 😁? If not, I could probably get around to it later this month.
I'll be coming back to this soon (as well as some other PRs for model-based methods).
Off-hand, where do you handle the direction of the metric (i.e. maximize or minimize)?
Sounds good @topepo. There's an argument called smaller_is_better in vi_permute() that defaults to NULL; when a user picks one of the built-in metrics, some internal logic sets the direction automatically.
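To make the question of direction concrete, here is a minimal base-R sketch of permutation importance showing where a smaller_is_better flag enters: it only decides the sign of the difference between the permuted and baseline metric values. The helper `perm_importance()` and the `rmse()` function are hypothetical and written for illustration; this is not vip's actual implementation, and the `lm()` fit on `mtcars` is just a stand-in model.

```r
# Hypothetical helper for illustration; not vip's implementation.
# Importance = change in the metric when one feature is shuffled, signed so
# that a positive value always means "the model got worse".
perm_importance <- function(object, train, target, metric, pred_wrapper,
                            smaller_is_better, nsim = 10) {
  train_y <- train[[target]]
  train_x <- train[setdiff(names(train), target)]
  baseline <- metric(train_y, pred_wrapper(object, newdata = train_x))
  diffs <- lapply(names(train_x), function(feature) {
    permuted <- replicate(nsim, {
      shuffled <- train_x
      shuffled[[feature]] <- sample(shuffled[[feature]])
      metric(train_y, pred_wrapper(object, newdata = shuffled))
    })
    # The only place where the direction of the metric matters
    if (smaller_is_better) permuted - baseline else baseline - permuted
  })
  data.frame(
    Variable = names(train_x),
    Importance = vapply(diffs, mean, numeric(1)),
    StDev = vapply(diffs, stats::sd, numeric(1)),
    stringsAsFactors = FALSE
  )
}

# Toy usage with base R only (stand-in model and metric)
rmse <- function(truth, estimate) sqrt(mean((truth - estimate)^2))
dat <- mtcars[c("mpg", "wt", "hp", "qsec")]
fit <- lm(mpg ~ ., data = dat)
pfun <- function(object, newdata) predict(object, newdata = newdata)

set.seed(1125)
perm_importance(fit, train = dat, target = "mpg", metric = rmse,
                pred_wrapper = pfun, smaller_is_better = TRUE)
```

Negating the metric and flipping the flag together gives identical importances, which is a cheap sanity check for any direction logic.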
How much of the under-the-hood stuff would you be open to changing/refactoring?
For example, it would be good to put the directionality bits into the metrics file and so on.
I would also try to consolidate the metric checking into a separate function (or multiple functions for user- and pre-defined metrics).
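One way to act on this suggestion is a single table holding each built-in metric's direction and expected estimate type, plus one checker that treats user-supplied functions and pre-defined names separately. Everything here (`builtin_metrics`, `check_metric()`, `check_estimate()`, and the small subset of metric names) is a hypothetical sketch, not the refactor that landed in vip.

```r
# Hypothetical sketch: built-in metric metadata kept in one place
# (e.g. a metrics file), so directionality is not scattered around.
builtin_metrics <- data.frame(
  metric = c("accuracy", "roc_auc", "rmse", "mae", "rsq"),
  smaller_is_better = c(FALSE, FALSE, TRUE, TRUE, FALSE),
  estimate_type = c("class", "prob", "numeric", "numeric", "numeric"),
  stringsAsFactors = FALSE
)

check_metric <- function(metric, smaller_is_better = NULL) {
  # User-defined metric: direction cannot be inferred, so it is required
  if (is.function(metric)) {
    if (is.null(smaller_is_better)) {
      stop("`smaller_is_better` must be supplied when `metric` is a function.",
           call. = FALSE)
    }
    return(list(metric = metric, smaller_is_better = smaller_is_better,
                estimate_type = NA_character_))
  }
  # Pre-defined metric: look up its metadata by name
  if (!is.character(metric) || length(metric) != 1L) {
    stop("`metric` must be a function or a single string.", call. = FALSE)
  }
  info <- builtin_metrics[builtin_metrics$metric == metric, ]
  if (nrow(info) == 0L) {
    stop(sprintf("Unknown metric \"%s\".", metric), call. = FALSE)
  }
  # Use the table's direction unless the user explicitly overrides it
  if (is.null(smaller_is_better)) smaller_is_better <- info$smaller_is_better
  list(metric = metric, smaller_is_better = smaller_is_better,
       estimate_type = info$estimate_type)
}

# Catch class labels passed to a probability metric before the metric runs
check_estimate <- function(estimate, estimate_type) {
  if (identical(estimate_type, "prob") && !is.numeric(estimate)) {
    stop("This metric needs a numeric vector of predicted probabilities; ",
         "check that `pred_wrapper` returns probabilities, not class labels.",
         call. = FALSE)
  }
  invisible(estimate)
}
```

A check along the lines of `check_estimate()` would report the class-labels-with-"roc_auc" mistake in vip's own terms instead of surfacing yardstick's internal error.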
Hey @topepo, sorry for the late reply. I'm not against making useful changes to the package as long as they don't break backward compatibility too much. And good call-out on the directionality info! Were you planning on contributing PRs for the metric checking?
A PR (or maybe more than one, so they're easier to review). I have a set of longish flights coming up and will probably work on this then, maybe in the next few weeks.
Sounds good @topepo 👌
From a quick look through the package, it seems that using the *_vec() family of functions (e.g., roc_auc_vec()) would be relatively straightforward.
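Since yardstick's vector functions follow a `<metric>_vec()` naming pattern (e.g., `accuracy_vec()`, `roc_auc_vec()`), a metric name can be mapped to a function by pasting on the suffix. The sketch below assumes that pattern; `get_metric_vec()` is a hypothetical helper, and its `ns` argument exists only so the lookup can be exercised without yardstick installed.

```r
# Hypothetical helper: map a metric name such as "roc_auc" to the matching
# yardstick vector function (roc_auc_vec()). `ns` defaults to yardstick's
# namespace and is only evaluated if the caller does not supply one.
get_metric_vec <- function(metric, ns = asNamespace("yardstick")) {
  fname <- paste0(metric, "_vec")
  if (!exists(fname, envir = ns, mode = "function", inherits = FALSE)) {
    stop(sprintf("No function `%s()` found.", fname), call. = FALSE)
  }
  get(fname, envir = ns, mode = "function", inherits = FALSE)
}

# Exercising the lookup with a stand-in environment instead of yardstick
fake_ns <- new.env()
fake_ns$accuracy_vec <- function(truth, estimate) mean(truth == estimate)
acc <- get_metric_vec("accuracy", ns = fake_ns)
acc(c("a", "b", "a", "b"), c("a", "a", "a", "b"))
```

With yardstick installed, `get_metric_vec("roc_auc")` would return `roc_auc_vec()`, which can then be called with `truth`, `estimate`, and `event_level`. Searching the namespace environment also finds unexported functions; checking `fname` against `getNamespaceExports("yardstick")` would be stricter.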
@topepo I'll be moving on to this package soon; I'm just trying to tidy up some issues with fastshap first.
Hey @topepo, I've started integrating yardstick if you want to take a look (changes are only on the devel branch, with no unit tests or anything yet, but I'm planning to rebuild the pkgdown site with plenty of examples to help):
library(ranger)
library(vip)
#>
#> Attaching package: 'vip'
#> The following object is masked from 'package:utils':
#>
#> vi
library(yardstick)
# Complete (i.e., imputed) version of titanic data set
head(t3 <- titanic_mice[[1L]])
#> survived pclass age sex sibsp parch
#> 1 yes 1 29.00 female 0 0
#> 2 yes 1 0.92 male 1 2
#> 3 no 1 2.00 female 1 2
#> 4 no 1 30.00 male 1 2
#> 5 no 1 25.00 female 1 2
#> 6 yes 1 48.00 male 0 0
#
# Predicting class labels
#
set.seed(1120)
rfo <- ranger(survived ~ ., data = t3, probability = FALSE)
# Prediction wrapper
pfun_rfo <- function(object, newdata) {
  predict(object, data = newdata)$predictions
}
pfun_rfo(rfo, newdata = head(t3))
#> [1] yes yes no no yes no
#> Levels: no yes
# Should throw an error since ROC AUC needs a numeric vector of probabilities, not class labels
set.seed(1125)
vi_permute(
  rfo,
  train = t3,
  target = "survived",
  pred_wrapper = pfun_rfo,
  metric = "roc_auc",
  smaller_is_better = FALSE
)
#> Warning: Consider setting the `event_level` argument when using "roc_auc" as
#> the metric; see `?vip::vi_permute` for details. Defaulting to `event_level =
#> "first"`.
#> Error in `metric_fun()` at vip/R/vi_permute.R:311:2:
#> ! `estimate` should be a numeric vector, not a `factor` vector.
#> Backtrace:
#> ▆
#> 1. ├─vip::vi_permute(...)
#> 2. └─vip:::vi_permute.default(...) at vip/R/vi_permute.R:148:2
#> 3. └─yardstick (local) metric_fun(truth = train_y, estimate = pred_wrapper(object, newdata = train_x)) at vip/R/vi_permute.R:311:2
#> 4. └─yardstick::check_prob_metric(truth, estimate, case_weights, estimator)
#> 5. └─yardstick:::validate_factor_truth_matrix_estimate(...)
#> 6. └─rlang::abort(...)
# Use yardstick function directly; need to specify `smaller_is_better`
set.seed(1125)
vi_permute(
  rfo,
  train = t3,
  target = "survived",
  pred_wrapper = pfun_rfo,
  metric = accuracy_vec, # use yardstick function directly
  smaller_is_better = FALSE, # needed when supplying a function
  nsim = 10
)
#> # A tibble: 5 × 3
#> Variable Importance StDev
#> <chr> <dbl> <dbl>
#> 1 pclass 0.0764 0.00414
#> 2 age 0.0728 0.00837
#> 3 sex 0.221 0.0134
#> 4 sibsp 0.0348 0.00369
#> 5 parch 0.0146 0.00363
# Use built-in yardstick function; no need to specify `smaller_is_better`
set.seed(1125)
vi_permute(
  rfo,
  train = t3,
  target = "survived",
  pred_wrapper = pfun_rfo,
  metric = "accuracy", # uses yardstick internally
  nsim = 10
)
#> # A tibble: 5 × 3
#> Variable Importance StDev
#> <chr> <dbl> <dbl>
#> 1 pclass 0.0764 0.00414
#> 2 age 0.0728 0.00837
#> 3 sex 0.221 0.0134
#> 4 sibsp 0.0348 0.00369
#> 5 parch 0.0146 0.00363
#
# Predicting probabilities
#
set.seed(1120)
pfo <- ranger(survived ~ ., data = t3, probability = TRUE) # probability forest
# Prediction wrapper
pfun_pfo <- function(object, newdata) {
  predict(object, data = newdata)$predictions[, "yes"]
}
pfun_pfo(pfo, newdata = head(t3))
#> [1] 0.9383245 0.8597809 0.6351732 0.3805831 0.7631647 0.3235726
# Use default event level; should throw a warning message
set.seed(1125)
vi_permute(
  pfo,
  train = t3,
  target = "survived",
  pred_wrapper = pfun_pfo,
  metric = "roc_auc",
  nsim = 10
)
#> Warning: Consider setting the `event_level` argument when using "roc_auc" as
#> the metric; see `?vip::vi_permute` for details. Defaulting to `event_level =
#> "first"`.
#> # A tibble: 5 × 3
#> Variable Importance StDev
#> <chr> <dbl> <dbl>
#> 1 pclass -0.103 0.00455
#> 2 age -0.0986 0.00678
#> 3 sex -0.238 0.0124
#> 4 sibsp -0.0336 0.00339
#> 5 parch -0.0226 0.00159
# Change the event level
set.seed(1125)
vi_permute(
  pfo,
  train = t3,
  target = "survived",
  pred_wrapper = pfun_pfo,
  metric = "roc_auc",
  event_level = "second",
  nsim = 10
)
#> # A tibble: 5 × 3
#> Variable Importance StDev
#> <chr> <dbl> <dbl>
#> 1 pclass 0.103 0.00455
#> 2 age 0.0986 0.00678
#> 3 sex 0.238 0.0124
#> 4 sibsp 0.0336 0.00339
#> 5 parch 0.0226 0.00159
# Could also do this with a wrapper function
mfun <- function(truth, estimate) {
  roc_auc_vec(truth = truth, estimate = estimate, event_level = "second")
}
set.seed(1125)
vi_permute(
  pfo,
  train = t3,
  target = "survived",
  pred_wrapper = pfun_pfo,
  metric = mfun,
  smaller_is_better = FALSE,
  nsim = 10
)
#> # A tibble: 5 × 3
#> Variable Importance StDev
#> <chr> <dbl> <dbl>
#> 1 pclass 0.103 0.00455
#> 2 age 0.0986 0.00678
#> 3 sex 0.238 0.0124
#> 4 sibsp 0.0336 0.00339
#> 5 parch 0.0226 0.00159
Created on 2023-05-08 with reprex v2.0.2
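In the reprex, the importances from the default `event_level = "first"` run are exactly the negatives of the `event_level = "second"` run, with identical StDev. A base-R AUC makes the reason visible: `survived` has levels `no`, `yes`, so with the first level as the event, the predicted probability of `yes` is scored as if it were the probability of `no`, and every AUC becomes 1 - AUC. The `auc()` helper below is a hypothetical Mann-Whitney rank implementation for illustration, not yardstick's code, and the six observations are toy values.

```r
# Base-R AUC via the Mann-Whitney rank formula; average ranks make ties
# count as 1/2. `event` names the positive level of `truth`, and `prob`
# is treated as the predicted probability of that event.
auc <- function(truth, prob, event) {
  is_event <- truth == event
  n1 <- sum(is_event)
  n0 <- sum(!is_event)
  r <- rank(prob)
  (sum(r[is_event]) - n1 * (n1 + 1) / 2) / (n1 * n0)
}

truth <- factor(c("no", "yes", "yes", "no", "yes", "no"), levels = c("no", "yes"))
prob_yes <- c(0.2, 0.9, 0.6, 0.4, 0.35, 0.1)

auc(truth, prob_yes, event = "yes") # correct pairing: 8/9, about 0.889
auc(truth, prob_yes, event = "no")  # first level as event: 1/9, about 0.111
```

For a larger-is-better metric, permutation importance is baseline minus permuted, so with flipped AUCs it becomes (1 - b) - (1 - p) = -(b - p): each importance changes sign while its spread stays the same.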
In devel now and will be part of next release.