yardstick
[Feature request or question] Provide option for metric analysis by class in a multi-class setting
Not sure if this is already easily available, but in a multi-class classification problem it would be good to automatically get the performance on several metrics separately for each class, not just combined. Currently, metric_set
just returns a combined/aggregated (?) analysis.
(Sorry for the long reprex):
iris
model_recipe <- recipes::recipe(Species ~ ., data = iris)
# Create a workflow
model_final <- parsnip::naive_Bayes(Laplace = 1) |>
parsnip::set_mode("classification") |>
parsnip::set_engine("klaR",
prior = rep(1/3, 3),
usekernel = FALSE)
model_final_wf <- workflows::workflow() |>
workflows::add_recipe(model_recipe) |>
workflows::add_model(model_final)
train_fit <- model_final_wf |>
generics::fit(data = iris)
# Add predictions
train_predictions <- predict(train_fit, iris, type = "prob") |>
dplyr::mutate(class_pred = as.factor(apply(dplyr::across(tidyselect::everything()), 1, which.max))) |>
dplyr::bind_cols(iris)
# Check some metrics
multimetric <- yardstick::metric_set(yardstick::f_meas,
yardstick::accuracy,
yardstick::bal_accuracy,
yardstick::sens,
yardstick::spec,
yardstick::precision,
yardstick::recall,
yardstick::ppv,
yardstick::npv)
train_predictions |>
dplyr::mutate(Species = dplyr::case_when(.data$Species == "setosa"~ 1,
.data$Species == "versicolor" ~ 2,
.data$Species == "virginica" ~ 3),
Species = as.factor(.data$Species)) |>
multimetric(truth = .data$Species,
estimate = .data$class_pred)
# A tibble: 9 × 3
.metric .estimator .estimate
<chr> <chr> <dbl>
1 f_meas macro 0.96
2 accuracy multiclass 0.96
3 bal_accuracy macro 0.97
4 sens macro 0.96
5 spec macro 0.98
6 precision macro 0.96
7 recall macro 0.96
8 ppv macro 0.96
9 npv macro 0.98
So in addition to this analysis, it would be great to be able to see e.g. accuracy for each class. Is that already doable and I'm just missing an easy way of getting there, or could this be added, since it's a pretty common and standard step to check class performance in a multi-class setting?
Hello @deschen1 👋
You should already be able to do this with the current version of yardstick. All the metrics work on grouped data frames, so you can call
train_predictions |>
group_by(Species) |>
multimetric(truth = Species, estimate = .pred_class)
to get the results you want. I also took the liberty to show augment(),
which is a great function to combine your original data with predictions. For classification models it will return predicted probabilities and classes. You will get some warnings because if you stratify your calculation by outcome, you will have some levels with no true events, giving undefined results.
library(tidymodels)
library(discrim)
model_recipe <- recipe(Species ~ ., data = iris)
# Create a workflow
model_final <- naive_Bayes(Laplace = 1) |>
set_mode("classification") |>
set_engine("klaR", prior = rep(1/3, 3), usekernel = FALSE)
model_final_wf <- workflow() |>
add_recipe(model_recipe) |>
add_model(model_final)
train_fit <- fit(model_final_wf, data = iris)
# Add predictions
train_predictions <- augment(train_fit, iris)
train_predictions
#> # A tibble: 150 × 9
#> Sepal.Len…¹ Sepal…² Petal…³ Petal…⁴ Species .pred…⁵ .pred…⁶ .pred_…⁷ .pred_…⁸
#> <dbl> <dbl> <dbl> <dbl> <fct> <fct> <dbl> <dbl> <dbl>
#> 1 5.1 3.5 1.4 0.2 setosa setosa 1 2.98e-18 2.15e-25
#> 2 4.9 3 1.4 0.2 setosa setosa 1 3.17e-17 6.94e-25
#> 3 4.7 3.2 1.3 0.2 setosa setosa 1 2.37e-18 7.24e-26
#> 4 4.6 3.1 1.5 0.2 setosa setosa 1 3.07e-17 8.69e-25
#> 5 5 3.6 1.4 0.2 setosa setosa 1 1.02e-18 8.89e-26
#> 6 5.4 3.9 1.7 0.4 setosa setosa 1.00 2.72e-14 4.34e-21
#> 7 4.6 3.4 1.4 0.3 setosa setosa 1 2.32e-17 7.99e-25
#> 8 5 3.4 1.5 0.2 setosa setosa 1 1.39e-17 8.17e-25
#> 9 4.4 2.9 1.4 0.2 setosa setosa 1 1.99e-17 3.61e-25
#> 10 4.9 3.1 1.5 0.1 setosa setosa 1 7.38e-18 3.62e-25
#> # … with 140 more rows, and abbreviated variable names ¹Sepal.Length,
#> # ²Sepal.Width, ³Petal.Length, ⁴Petal.Width, ⁵.pred_class, ⁶.pred_setosa,
#> # ⁷.pred_versicolor, ⁸.pred_virginica
# Check some metrics
multimetric <- metric_set(
yardstick::f_meas,
yardstick::accuracy,
yardstick::bal_accuracy,
yardstick::sens,
yardstick::spec,
yardstick::precision,
yardstick::recall,
yardstick::ppv,
yardstick::npv
)
train_predictions |>
group_by(Species) |>
multimetric(truth = Species, estimate = .pred_class)
#> # A tibble: 27 × 4
#> Species .metric .estimator .estimate
#> <fct> <chr> <chr> <dbl>
#> 1 setosa f_meas macro 1
#> 2 versicolor f_meas macro 0.969
#> 3 virginica f_meas macro 0.969
#> 4 setosa accuracy multiclass 1
#> 5 versicolor accuracy multiclass 0.94
#> 6 virginica accuracy multiclass 0.94
#> 7 setosa bal_accuracy macro NA
#> 8 versicolor bal_accuracy macro NA
#> 9 virginica bal_accuracy macro NA
#> 10 setosa sens macro 1
#> # … with 17 more rows
Created on 2022-10-25 with reprex v2.0.2
Thanks a lot @EmilHvitfeldt. One more question: how could I add the roc_auc for each class?
Simply adding yardstick::roc_auc
to the multimetric set does not work, which is, I think, because it is a metric that takes the class probabilities. However, even running it as a standalone metric doesn't work:
train_predictions |>
roc_auc(truth = Species,
estimate = .pred_setosa:.pred_virginica)
This only gives me one combined area. Adding a group_by(Species)
to it gives a lot of warnings and returns NaN for each class. However, when using roc_curve
+ autoplot
I will get one curve (and hence AUC) for each level.
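For example (a minimal illustration, using the .pred_* probability columns from the reprex above):
library(yardstick)
library(ggplot2)

train_predictions |>
  roc_curve(Species, .pred_setosa, .pred_versicolor, .pred_virginica) |>
  autoplot()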
I think this is the same as https://github.com/tidymodels/yardstick/issues/4
i.e. it sounds like you want one-vs-all output (with regard to factor levels). Like:
# truth
"a" "b" "c" "c" "a"
# for a, recode as
"y" "n" "n" "n" "y"
# for b, recode as
"n" "y" "n" "n" "n"
# for c, recode as
"n" "n" "y" "y" "n"
Then do the same thing for estimate
, and compute 3 separate binary accuracy()
calculations for each set of recoded values. yardstick:::one_vs_all_impl()
is an internal helper that does this because we need to do those computations to compute some macro/micro estimates, but the kind of output that we'd get back doesn't natively fit into the rest of our API so I decided not to add it.
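Rolled by hand, that recoding could look roughly like this. This is just a sketch against the train_predictions from the reprex above, not a yardstick API:
library(dplyr)
library(purrr)
library(yardstick)

# One-vs-all by hand: recode truth/estimate per level, then binary accuracy()
map_dfr(levels(train_predictions$Species), function(lvl) {
  train_predictions |>
    mutate(
      truth_bin = factor(ifelse(Species == lvl, "y", "n"), levels = c("y", "n")),
      est_bin   = factor(ifelse(.pred_class == lvl, "y", "n"), levels = c("y", "n"))
    ) |>
    accuracy(truth_bin, est_bin) |>
    mutate(.level = lvl, .before = 1)
})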
It is possible it could be a separate helper function that just wouldn't be usable in tuning and couldn't be combined in a metric set. It could maybe look something like:
one_vs_all <- function(data, truth, estimate, fn) {
# impl
}
one_vs_all(data, Species, pred, metric_set(accuracy, precision))
#> # A tibble: 6 × 4
#> .level .metric .estimator .estimate
#> <chr> <chr> <chr> <dbl>
#> 1 setosa accuracy binary 0.7
#> 2 setosa precision binary 0.4
#> 3 versicolor accuracy binary 0.8
#> 4 versicolor precision binary 0.3
#> 5 virginica accuracy binary 0.2
#> 6 virginica precision binary 0.5
We'd have to consider how the API would look for numeric/class/class-prob metrics, because you supply estimate
for some and ...
for class-prob metrics.
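For the class-prob case, a hand-rolled sketch would have to grab the matching probability column per level rather than an estimate factor, e.g. (again not an existing API; column names come from the augment() reprex above):
library(dplyr)
library(purrr)
library(yardstick)

# One-vs-all ROC AUC by hand: binary truth per level + that level's probability column
map_dfr(levels(train_predictions$Species), function(lvl) {
  train_predictions |>
    mutate(truth_bin = factor(ifelse(Species == lvl, lvl, "other"),
                              levels = c(lvl, "other"))) |>
    roc_auc(truth_bin, all_of(paste0(".pred_", lvl))) |>
    mutate(.level = lvl, .before = 1)
})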
Thanks Davis!
I'm not entirely sure if what you are saying is that I can get what I want by using these already implemented (?) functions, or if this is something you are going to add to the package?
You cannot currently do it.
I am not sure if we should expose these utilities or not, but it has come up a few times so it might be worth it to expose them in a limited way.
Got it. I think it would generally be a great idea to offer class-specific metrics in a multi-class setting. From what I've seen in other tools and for certain analyses, it is a common but also very important task.
E.g. in a market research study where you segment/cluster customers, there might be a very important class that you want to predict very well, even if that means you are missing out on some other classes. But you could only judge that by seeing the class-specific metrics.
Hi @EmilHvitfeldt, is it possible with this solution to also get the standard error of the metrics? Basically, is it possible to get the same style of output that collect_metrics
provides, but per class?
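Something along the lines of this rough sketch, i.e. per-class metrics computed per resample and then summarised the way collect_metrics() does. Here folds and res are hypothetical; the workflow and multimetric come from the reprex above:
library(tidymodels)
library(discrim)

folds <- vfold_cv(iris, v = 5, strata = Species)
res <- fit_resamples(model_final_wf, folds, control = control_resamples(save_pred = TRUE))

collect_predictions(res) |>
  group_by(id, Species) |>
  multimetric(truth = Species, estimate = .pred_class) |>
  group_by(Species, .metric, .estimator) |>
  summarise(
    mean = mean(.estimate, na.rm = TRUE),
    n = sum(!is.na(.estimate)),
    std_err = sd(.estimate, na.rm = TRUE) / sqrt(sum(!is.na(.estimate))),
    .groups = "drop"
  )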
Hi @EmilHvitfeldt, @DavisVaughan,
following up on the above information on category-wise metrics, I've found a potential issue when using estimator = "macro_weighted".
I've used a slightly different data set:
library(tidymodels)
library(palmerpenguins)
conflicts_prefer(palmerpenguins::penguins)
penguins_split <- initial_split(penguins, strata = "species")
penguins_train <- training(penguins_split)
rf_spec <- rand_forest() |>
set_mode("classification")
results <-
  workflow(preprocessor = recipe(species ~ island + year, data = penguins_train),
           spec = rf_spec) |>
  last_fit(penguins_split)
results |> collect_predictions()
# A tibble: 86 × 8
id .pred_Adelie .pred_Chinstrap .pred_Gentoo .row .pred_class species .config
<chr> <dbl> <dbl> <dbl> <int> <fct> <fct> <chr>
1 train/test split 0.788 0.152 0.0599 3 Adelie Adelie Preprocessor1_Model1
2 train/test split 0.788 0.152 0.0599 9 Adelie Adelie Preprocessor1_Model1
3 train/test split 0.788 0.152 0.0599 15 Adelie Adelie Preprocessor1_Model1
4 train/test split 0.788 0.152 0.0599 19 Adelie Adelie Preprocessor1_Model1
5 train/test split 0.457 0.484 0.0599 31 Chinstrap Adelie Preprocessor1_Model1
6 train/test split 0.457 0.484 0.0599 32 Chinstrap Adelie Preprocessor1_Model1
7 train/test split 0.457 0.484 0.0599 36 Chinstrap Adelie Preprocessor1_Model1
8 train/test split 0.457 0.484 0.0599 40 Chinstrap Adelie Preprocessor1_Model1
9 train/test split 0.457 0.484 0.0599 43 Chinstrap Adelie Preprocessor1_Model1
10 train/test split 0.457 0.484 0.0599 49 Chinstrap Adelie Preprocessor1_Model1
# … with 76 more rows
Using estimator = "macro":
results |>
  collect_predictions() |>
  precision(species, .pred_class, estimator = "macro")
# A tibble: 1 × 3
.metric .estimator .estimate
<chr> <chr> <dbl>
1 precision macro 0.603
results |>
  collect_predictions() |>
  group_by(.pred_class) |>
  precision(species, .pred_class, estimator = "macro")
# A tibble: 3 × 4
.pred_class .metric .estimator .estimate
<fct> <chr> <chr> <dbl>
1 Adelie precision macro 0.636
2 Chinstrap precision macro 0.417
3 Gentoo precision macro 0.756
Using estimator = "macro_weighted":
results |>
  collect_predictions() |>
  precision(species, .pred_class, estimator = "macro_weighted")
# A tibble: 1 × 3
.metric .estimator .estimate
<chr> <chr> <dbl>
1 precision macro_weighted 0.636
results |>
  collect_predictions() |>
  group_by(.pred_class) |>
  precision(species, .pred_class, estimator = "macro_weighted")
# A tibble: 3 × 4
.pred_class .metric .estimator .estimate
<fct> <chr> <chr> <dbl>
1 Adelie precision macro_weighted 0.636
2 Chinstrap precision macro_weighted 0.417
3 Gentoo precision macro_weighted 0.756
The values on a class level for "macro_weighted" are identical to "macro".
On an overall level it seems to be correct.
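For reference, the two overall numbers can be reproduced by hand from the one-vs-all class precisions. A sketch, assuming yardstick's documented averaging (macro = unweighted mean of the per-class values, macro_weighted = mean weighted by how often each class occurs in truth):
library(dplyr)
library(purrr)
library(yardstick)

preds <- collect_predictions(results)

# Binary one-vs-all precision for each class
per_class <- map_dfr(levels(preds$species), function(lvl) {
  preds |>
    mutate(
      truth_bin = factor(ifelse(species == lvl, lvl, "other"), levels = c(lvl, "other")),
      est_bin   = factor(ifelse(.pred_class == lvl, lvl, "other"), levels = c(lvl, "other"))
    ) |>
    precision(truth_bin, est_bin) |>
    mutate(class = lvl, n_truth = sum(preds$species == lvl))
})

# Unweighted vs. truth-frequency-weighted averages of the per-class precisions
per_class |>
  summarise(
    macro = mean(.estimate),
    macro_weighted = weighted.mean(.estimate, w = n_truth)
  )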
I implemented a function to calculate the one-vs-all per-class metrics:
library(tidymodels)
#> ── Attaching packages ──────────────────────────────────── tidymodels 1.1.1 ──
#> ✔ broom 1.0.5 ✔ recipes 1.0.10
#> ✔ dials 1.2.1 ✔ rsample 1.2.0
#> ✔ dplyr 1.1.4 ✔ tibble 3.2.1
#> ✔ ggplot2 3.5.0 ✔ tidyr 1.3.1
#> ✔ infer 1.0.6 ✔ tune 1.1.2
#> ✔ modeldata 1.3.0 ✔ workflows 1.1.4
#> ✔ parsnip 1.2.0 ✔ workflowsets 1.0.1
#> ✔ purrr 1.0.2 ✔ yardstick 1.3.0
#> ── Conflicts ─────────────────────────────────────── tidymodels_conflicts() ──
#> ✖ purrr::discard() masks scales::discard()
#> ✖ dplyr::filter() masks stats::filter()
#> ✖ dplyr::lag() masks stats::lag()
#> ✖ recipes::step() masks stats::step()
#> • Search for functions across packages at https://www.tidymodels.org/find/
library(ranger)
ova_metrics <- function(x, truth, estimate, metric_set) {
  x %>%
    dplyr::mutate(
      # For every row, build a named vector over the outcome levels:
      # the row's own class at its level, "class_0" everywhere else
      truth_ova = purrr::map({{ truth }}, ~ {
        case_when(
          levels({{ truth }}) %in% .x ~ as.character(.x),
          is.na(.x) ~ NA_character_,
          .default = "class_0"
        ) %>%
          rlang::set_names(levels({{ truth }}))
      }),
      estimate_ova = purrr::map({{ estimate }}, ~ {
        case_when(
          levels({{ estimate }}) %in% .x ~ as.character(.x),
          is.na(.x) ~ NA_character_,
          .default = "class_0"
        ) %>%
          rlang::set_names(levels({{ estimate }}))
      })
    ) %>%
    # One row per (observation, class): the one-vs-all recoding for that class
    tidyr::unnest_longer(col = c(truth_ova, estimate_ova)) %>%
    dplyr::mutate(class_group = purrr::map2_chr(truth_ova_id, estimate_ova_id, unique)) %>%
    tidyr::nest(.by = class_group) %>%
    dplyr::mutate(
      data = data %>%
        # Turn the recoded columns into binary factors with "class_0" as the last level
        purrr::map(~ dplyr::mutate(.x, dplyr::across(
          dplyr::matches("ova"),
          ~ factor(.x) %>%
            forcats::fct_expand("class_0") %>%
            forcats::fct_relevel("class_0", after = Inf)
        ))) %>%
        # Apply the metric set to each class-wise binary problem
        purrr::map(metric_set, truth = truth_ova, estimate = estimate_ova) %>%
        suppressWarnings()
    ) %>%
    tidyr::unnest(cols = data)
}
model_recipe <- recipe(Species ~ ., data = iris)
model_final <- rand_forest(
mode = "classification",
engine = "ranger",
mtry = 3,
min_n = 20,
trees = 500
) %>%
set_engine("ranger", importance = "impurity")
model_final_wf <- workflow() |>
add_recipe(model_recipe) |>
add_model(model_final)
set.seed(2024)
train_fit <- fit(model_final_wf, data = iris)
train_predictions <- augment(train_fit, iris)
train_predictions %>% conf_mat(Species, .pred_class)
#> Truth
#> Prediction setosa versicolor virginica
#> setosa 50 0 0
#> versicolor 0 47 1
#> virginica 0 3 49
per_class_mset <-
metric_set(accuracy, sensitivity, specificity, f_meas, bal_accuracy, kap)
train_predictions %>%
per_class_mset(truth = Species, estimate = .pred_class)
#> # A tibble: 6 × 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 accuracy multiclass 0.973
#> 2 sensitivity macro 0.973
#> 3 specificity macro 0.987
#> 4 f_meas macro 0.973
#> 5 bal_accuracy macro 0.98
#> 6 kap multiclass 0.96
train_predictions %>%
ova_metrics(truth = Species, estimate = .pred_class, metric_set = per_class_mset)
#> # A tibble: 18 × 4
#> class_group .metric .estimator .estimate
#> <chr> <chr> <chr> <dbl>
#> 1 setosa accuracy binary 1
#> 2 setosa sensitivity binary 1
#> 3 setosa specificity binary 1
#> 4 setosa f_meas binary 1
#> 5 setosa bal_accuracy binary 1
#> 6 setosa kap binary 1
#> 7 versicolor accuracy binary 0.973
#> 8 versicolor sensitivity binary 0.94
#> 9 versicolor specificity binary 0.99
#> 10 versicolor f_meas binary 0.959
#> 11 versicolor bal_accuracy binary 0.965
#> 12 versicolor kap binary 0.939
#> 13 virginica accuracy binary 0.973
#> 14 virginica sensitivity binary 0.98
#> 15 virginica specificity binary 0.97
#> 16 virginica f_meas binary 0.961
#> 17 virginica bal_accuracy binary 0.975
#> 18 virginica kap binary 0.941
Created on 2024-04-02 by the reprex package (v0.3.0)
I've also come up with a solution:
- [x] works with grouped data
- [x] keeps metrics with multiclass support as-is
library(tidymodels)
library(ranger)
model_recipe <- recipe(Species ~ ., data = iris)
model_final <- rand_forest(
mode = "classification",
engine = "ranger",
mtry = 3,
min_n = 20,
trees = 500
) %>%
set_engine("ranger", importance = "impurity")
model_final_wf <- workflow() |>
add_recipe(model_recipe) |>
add_model(model_final)
set.seed(2024)
train_fit <- fit(model_final_wf, data = iris)
train_predictions <- augment(train_fit, iris)
train_predictions %>% conf_mat(Species, .pred_class)
#> Truth
#> Prediction setosa versicolor virginica
#> setosa 50 0 0
#> versicolor 0 47 2
#> virginica 0 3 48
per_class_mset <-
metric_set(accuracy, sensitivity, specificity, f_meas, bal_accuracy, kap)
train_predictions |>
metric_by_event_by(per_class_mset, truth = Species, estimate = .pred_class)
#> # A tibble: 14 × 4
#> .class .metric .estimator .estimate
#> <chr> <chr> <chr> <dbl>
#> 1 setosa sensitivity event 1
#> 2 setosa specificity event 1
#> 3 setosa f_meas event 1
#> 4 setosa bal_accuracy event 1
#> 5 versicolor sensitivity event 0.94
#> 6 versicolor specificity event 0.98
#> 7 versicolor f_meas event 0.949
#> 8 versicolor bal_accuracy event 0.96
#> 9 virginica sensitivity event 0.96
#> 10 virginica specificity event 0.97
#> 11 virginica f_meas event 0.950
#> 12 virginica bal_accuracy event 0.965
#> 13 NA accuracy multiclass 0.967
#> 14 NA kap multiclass 0.95
The function:
metric_by_event_by <- function(data, metric_set, truth, estimate, ...) {
  # Only for class_prob_metric_set
  stopifnot(inherits(metric_set, "class_prob_metric_set"))

  # Make a call with the metric set
  cl <- match.call()
  cl[[1]] <- as.name(as.character(cl$metric_set))
  cl$metric_set <- NULL
  cl$event_level <- "first"

  # Find names and levels
  truth_name <- as.character(cl$truth)
  estimate_name <- as.character(cl$estimate)
  levels <- levels(data[[truth_name]])

  # Get current default behavior
  out_raw <- eval.parent(cl) |>
    dplyr::filter(!.estimator %in% c("macro", "macro_weighted", "micro"))

  # A function to binarize a multiclass factor
  fct_binarize <- function(.f, lvl) {
    forcats::fct_collapse(.f, Yes = lvl, other_level = "No")
  }

  # For each class...
  out <- levels |>
    purrr::set_names() |>
    purrr::map(function(lvl) {
      tmp_data <- data
      tmp_cl <- cl
      # binarize
      tmp_data[[truth_name]] <- fct_binarize(tmp_data[[truth_name]], lvl)
      tmp_data[[estimate_name]] <- fct_binarize(tmp_data[[estimate_name]], lvl)
      tmp_cl$data <- tmp_data
      # Get the metrics
      eval.parent(tmp_cl)
    }) |>
    # rbind with a `.class` column for each class
    dplyr::bind_rows(.id = ".class") |>
    dplyr::mutate(
      # New estimator name?
      .estimator = "event"
    ) |>
    # keep only metrics that use macro/micro multiclass estimators
    dplyr::anti_join(out_raw, by = dplyr::join_by(.metric)) |>
    # Add the (non-classwise) metrics
    bind_rows(out_raw)

  # Deal with grouped data frames
  if (inherits(data, "grouped_df")) {
    out <- out |>
      dplyr::relocate(dplyr::all_of(dplyr::group_vars(data)), .before = 1) |>
      dplyr::arrange(dplyr::pick(dplyr::all_of(dplyr::group_vars(data))))
  }

  out
}
You may be interested in new_groupwise_metric()
from yardstick 1.3.0. :) The new "grouping behavior in yardstick" vignette describes what we mean by "groupwise."
I believe we may be able to close this issue, but will let @EmilHvitfeldt decide!