yardstick integration

Open topepo opened this issue 4 years ago • 1 comments

If you were good with a yardstick dependency, we could expand the list of metrics that can be used.

topepo avatar Apr 16 '21 18:04 topepo

The more metrics the merrier! Would you be able to put together a PR for this too 😁? If not, I could probably get around to it later this month.

bgreenwell avatar Apr 17 '21 21:04 bgreenwell

I'll be coming back to this soon (as well as some other PRs for model-based methods).

Off-hand, where do you handle the direction of the metric (i.e. maximize or minimize)?

topepo avatar Dec 04 '22 20:12 topepo

Sounds good @topepo. There's an option called smaller_is_better in the call to vi_permute() that defaults to NULL, with some logic to set it automatically when a user picks one of the built-in metrics.
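To make the role of the direction flag concrete, here is a minimal sketch of how a permutation-based importance could depend on it. The helper name `importance_from_scores()` is hypothetical and this is an assumption about the general technique, not a copy of vip's internals:

```r
# Hypothetical helper (not part of vip): convert a baseline score and a
# post-permutation score into an importance value, given the metric direction.
importance_from_scores <- function(baseline, permuted, smaller_is_better) {
  if (smaller_is_better) {
    # Error-type metrics (e.g., RMSE): importance is the increase in error
    permuted - baseline
  } else {
    # Score-type metrics (e.g., accuracy): importance is the drop in score
    baseline - permuted
  }
}

# Accuracy falls from 0.80 to 0.65 after permuting a feature
importance_from_scores(baseline = 0.80, permuted = 0.65, smaller_is_better = FALSE)

# RMSE rises from 1.2 to 1.9 after permuting a feature
importance_from_scores(baseline = 1.2, permuted = 1.9, smaller_is_better = TRUE)
```

Under this convention a positive value means the feature mattered regardless of the metric's direction, which is why getting the flag wrong flips the sign of every importance.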

bgreenwell avatar Dec 05 '22 01:12 bgreenwell

How much of the under-the-hood stuff would you be open to changing/refactoring?

For example, it would be good to put the directionality bits into the metrics file and so on.

I would also try to consolidate the metric checking into a separate function (or multiple functions, for user-defined and pre-defined metrics).
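One possible shape for that refactor, sketched in base R. The names `metric_directions` and `resolve_direction()` are hypothetical and only illustrate the idea of keeping direction next to the metric definitions and checking everything in one place:

```r
# Hypothetical registry: metric name -> TRUE if smaller values are better.
metric_directions <- c(
  rmse     = TRUE,
  mae      = TRUE,
  accuracy = FALSE,
  roc_auc  = FALSE
)

# Hypothetical checker covering both pre-defined and user-defined metrics.
resolve_direction <- function(metric, smaller_is_better = NULL) {
  if (is.function(metric)) {
    # User-defined metric: the direction cannot be inferred
    if (is.null(smaller_is_better)) {
      stop("`smaller_is_better` must be supplied when `metric` is a function.",
           call. = FALSE)
    }
    return(isTRUE(smaller_is_better))
  }
  if (!is.character(metric) || length(metric) != 1L) {
    stop("`metric` must be a function or a single character string.",
         call. = FALSE)
  }
  if (!metric %in% names(metric_directions)) {
    stop("Unknown metric: ", metric, call. = FALSE)
  }
  # Pre-defined metric: look the direction up in the registry
  metric_directions[[metric]]
}

resolve_direction("rmse")
resolve_direction(function(truth, estimate) 0, smaller_is_better = FALSE)
```

Keeping the lookup in one table would mean adding a metric only requires one new entry rather than edits scattered through the permutation code.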

topepo avatar Jan 06 '23 20:01 topepo

Hey @topepo, sorry for the late reply. I'm not against making useful changes to the package as long as they don't break backward compatibility too much. And good callout on the directionality info! Were you planning on contributing PRs for the metric checking?

bgreenwell avatar Jan 18 '23 21:01 bgreenwell

A PR (or maybe more than one to make it more easily reviewed). I have a set of longish flights coming up and I'll probably work on this then. Maybe in the next few weeks.

topepo avatar Jan 19 '23 20:01 topepo

Sounds good @topepo 👌

bgreenwell avatar Jan 20 '23 16:01 bgreenwell

Just peeking through the package, it seems that using the *_vec() (e.g., roc_auc_vec()) family of functions would be relatively straightforward.
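For reference, a small sketch of what calling the vector interface looks like. The toy data below is made up for illustration; `accuracy_vec()` and `roc_auc_vec()` take `truth` and `estimate` vectors directly, which is the same shape a permutation routine already has on hand:

```r
library(yardstick)

# Toy data (made up): true labels, predicted labels, and P(class = "yes")
truth      <- factor(c("no", "yes", "yes", "no"), levels = c("no", "yes"))
pred_class <- factor(c("no", "yes", "no", "no"), levels = c("no", "yes"))
prob_yes   <- c(0.2, 0.9, 0.4, 0.3)

# Class metrics take a factor of predicted labels
accuracy_vec(truth = truth, estimate = pred_class)

# Probability metrics take a numeric vector; `event_level` says which factor
# level the probabilities refer to ("yes" is the second level here)
roc_auc_vec(truth = truth, estimate = prob_yes, event_level = "second")
```

Here three of the four predicted labels match, so the accuracy is 0.75, and the toy probabilities rank every "yes" case above every "no" case. Because yardstick treats the first factor level as the event by default, probabilities for the second level need `event_level = "second"`.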

bgreenwell avatar Jan 31 '23 21:01 bgreenwell

@topepo I'll be moving onto this package soon, just trying to tidy up some issues with fastshap first.

bgreenwell avatar Apr 30 '23 16:04 bgreenwell

Hey @topepo, started to integrate with yardstick if you want to take a look (changes are only on the devel branch and no unit tests or anything yet, but planning to rebuild the pkgdown site with plenty of examples to help):

library(ranger)
library(vip)
#> 
#> Attaching package: 'vip'
#> The following object is masked from 'package:utils':
#> 
#>     vi
library(yardstick)

# Complete (i.e., imputed) version of titanic data set
head(t3 <- titanic_mice[[1L]])
#>   survived pclass   age    sex sibsp parch
#> 1      yes      1 29.00 female     0     0
#> 2      yes      1  0.92   male     1     2
#> 3       no      1  2.00 female     1     2
#> 4       no      1 30.00   male     1     2
#> 5       no      1 25.00 female     1     2
#> 6      yes      1 48.00   male     0     0

#
# Predicting class labels
#

set.seed(1120)
rfo <- ranger(survived ~ ., data = t3, probability = FALSE)

# Prediction wrapper
pfun_rfo <- function(object, newdata) {
  predict(object, data = newdata)$predictions
}
pfun_rfo(rfo, newdata = head(t3))
#> [1] yes yes no  no  yes no 
#> Levels: no yes

# Should throw an error since ROC AUC needs a vector of probabilities, not class labels
set.seed(1125)
vi_permute(
  rfo,
  train = t3,
  target = "survived",
  pred_wrapper = pfun_rfo,
  metric = "roc_auc",
  smaller_is_better = FALSE
)
#> Warning: Consider setting the `event_level` argument when using "roc_auc" as
#> the metric; see `?vip::vi_permute` for details. Defaulting to `event_level =
#> "first"`.
#> Error in `metric_fun()` at vip/R/vi_permute.R:311:2:
#> ! `estimate` should be a numeric vector, not a `factor` vector.
#> Backtrace:
#>     ▆
#>  1. ├─vip::vi_permute(...)
#>  2. └─vip:::vi_permute.default(...) at vip/R/vi_permute.R:148:2
#>  3.   └─yardstick (local) metric_fun(truth = train_y, estimate = pred_wrapper(object, newdata = train_x)) at vip/R/vi_permute.R:311:2
#>  4.     └─yardstick::check_prob_metric(truth, estimate, case_weights, estimator)
#>  5.       └─yardstick:::validate_factor_truth_matrix_estimate(...)
#>  6.         └─rlang::abort(...)

# Use yardstick function directly; need to specify `smaller_is_better`
set.seed(1125)
vi_permute(
  rfo,
  train = t3,
  target = "survived",
  pred_wrapper = pfun_rfo,
  metric = accuracy_vec,      # use yardstick function directly
  smaller_is_better = FALSE,  # needed when supplying a function
  nsim = 10
)
#> # A tibble: 5 × 3
#>   Variable Importance   StDev
#>   <chr>         <dbl>   <dbl>
#> 1 pclass       0.0764 0.00414
#> 2 age          0.0728 0.00837
#> 3 sex          0.221  0.0134 
#> 4 sibsp        0.0348 0.00369
#> 5 parch        0.0146 0.00363

# Use built-in yardstick function; no need to specify `smaller_is_better`
set.seed(1125)
vi_permute(
  rfo,
  train = t3,
  target = "survived",
  pred_wrapper = pfun_rfo,
  metric = "accuracy",      # uses yardstick internally
  nsim = 10
)
#> # A tibble: 5 × 3
#>   Variable Importance   StDev
#>   <chr>         <dbl>   <dbl>
#> 1 pclass       0.0764 0.00414
#> 2 age          0.0728 0.00837
#> 3 sex          0.221  0.0134 
#> 4 sibsp        0.0348 0.00369
#> 5 parch        0.0146 0.00363

#
# Predicting probabilities
#

set.seed(1120)
pfo <- ranger(survived ~ ., data = t3, probability = TRUE)  # probability forest

# Prediction wrapper
pfun_pfo <- function(object, newdata) {
  predict(object, data = newdata)$predictions[, "yes"]
}
pfun_pfo(pfo, newdata = head(t3))
#> [1] 0.9383245 0.8597809 0.6351732 0.3805831 0.7631647 0.3235726

# Use default event level; should throw a warning message
set.seed(1125)
vi_permute(
  pfo,
  train = t3,
  target = "survived",
  pred_wrapper = pfun_pfo,
  metric = "roc_auc",
  nsim = 10
)
#> Warning: Consider setting the `event_level` argument when using "roc_auc" as
#> the metric; see `?vip::vi_permute` for details. Defaulting to `event_level =
#> "first"`.
#> # A tibble: 5 × 3
#>   Variable Importance   StDev
#>   <chr>         <dbl>   <dbl>
#> 1 pclass      -0.103  0.00455
#> 2 age         -0.0986 0.00678
#> 3 sex         -0.238  0.0124 
#> 4 sibsp       -0.0336 0.00339
#> 5 parch       -0.0226 0.00159

# Change the event level
set.seed(1125)
vi_permute(
  pfo,
  train = t3,
  target = "survived",
  pred_wrapper = pfun_pfo,
  metric = "roc_auc",
  event_level = "second",
  nsim = 10
)
#> # A tibble: 5 × 3
#>   Variable Importance   StDev
#>   <chr>         <dbl>   <dbl>
#> 1 pclass       0.103  0.00455
#> 2 age          0.0986 0.00678
#> 3 sex          0.238  0.0124 
#> 4 sibsp        0.0336 0.00339
#> 5 parch        0.0226 0.00159

# Could also do this with a wrapper function
mfun <- function(truth, estimate) {
  roc_auc_vec(truth = truth, estimate = estimate, event_level = "second")
}
set.seed(1125)
vi_permute(
  pfo,
  train = t3,
  target = "survived",
  pred_wrapper = pfun_pfo,
  metric = mfun,
  smaller_is_better = FALSE,
  nsim = 10
)
#> # A tibble: 5 × 3
#>   Variable Importance   StDev
#>   <chr>         <dbl>   <dbl>
#> 1 pclass       0.103  0.00455
#> 2 age          0.0986 0.00678
#> 3 sex          0.238  0.0124 
#> 4 sibsp        0.0336 0.00339
#> 5 parch        0.0226 0.00159

Created on 2023-05-08 with reprex v2.0.2

bgreenwell avatar May 08 '23 16:05 bgreenwell

In devel now and will be part of next release.

bgreenwell avatar Jul 12 '23 20:07 bgreenwell