yardstick icon indicating copy to clipboard operation
yardstick copied to clipboard

[Feature request or question] Provide option for metric analysis by class in a multi-class setting

Open deschen1 opened this issue 2 years ago • 11 comments

Not sure if already available easily, but when having a multi-class classification problem, then it would be good to automatically get the performance on several metrics separately for each class, not just combined. Currently, metric_set just returns a combined/aggregated (?) analysis.

(Sorry for the long reprex):

iris

model_recipe <- recipes::recipe(Species ~ ., data = iris)

# Create a workflow
model_final <- parsnip::naive_Bayes(Laplace = 1) |>
  parsnip::set_mode("classification") |>
  parsnip::set_engine("klaR",
                      prior = rep(1/3, 3),
                      usekernel = FALSE)

model_final_wf <- workflows::workflow() |>
  workflows::add_recipe(model_recipe) |>
  workflows::add_model(model_final)

train_fit <- model_final_wf |>
  generics::fit(data = iris)

# Add predictions
train_predictions <- predict(train_fit, iris, type = "prob") |> 
  dplyr::mutate(class_pred = as.factor(apply(dplyr::across(tidyselect::everything()), 1, which.max))) |> 
  dplyr::bind_cols(iris)

# Check some metrics
multimetric <- yardstick::metric_set(yardstick::f_meas,
                                     yardstick::accuracy,
                                     yardstick::bal_accuracy,
                                     yardstick::sens,
                                     yardstick::spec,
                                     yardstick::precision,
                                     yardstick::recall,
                                     yardstick::ppv,
                                     yardstick::npv)

train_predictions |>
  dplyr::mutate(Species = dplyr::case_when(.data$Species == "setosa"~ 1,
                                           .data$Species == "versicolor" ~ 2,
                                           .data$Species == "virginica" ~ 3),
                Species = as.factor(.data$Species)) |> 
  multimetric(truth    = .data$Species,
              estimate = .data$class_pred)

# A tibble: 9 × 3
  .metric      .estimator .estimate
  <chr>        <chr>          <dbl>
1 f_meas       macro           0.96
2 accuracy     multiclass      0.96
3 bal_accuracy macro           0.97
4 sens         macro           0.96
5 spec         macro           0.98
6 precision    macro           0.96
7 recall       macro           0.96
8 ppv          macro           0.96
9 npv          macro           0.98

So in addition to this analysis, it would be great to be able to see e.g. accuracy for each class. Is that already doable, am I missing an easy way of getting there, or could this be added? It's a pretty common and standard step to check class performance in a multi-class setting.

deschen1 avatar Oct 12 '22 09:10 deschen1

Hello @deschen1 👋

You should already be able to do this with the current version of yardstick. All the metrics work on grouped data.frames, so you can call

train_predictions |>
  group_by(Species) |>
  multimetric(truth = Species, estimate = .pred_class)

This gives you the results you want. I also took the liberty of showing augment(), which is a great function for combining your original data with predictions. For classification models it returns predicted probabilities and classes. You will get some warnings because, if you stratify your calculation by outcome, some levels will have no true events, giving undefined results.

library(tidymodels)
library(discrim)

model_recipe <- recipe(Species ~ ., data = iris)

# Create a workflow
model_final <- naive_Bayes(Laplace = 1) |>
  set_mode("classification") |>
  set_engine("klaR", prior = rep(1/3, 3), usekernel = FALSE)

model_final_wf <- workflow() |>
  add_recipe(model_recipe) |>
  add_model(model_final)

train_fit <- fit(model_final_wf, data = iris)

# Add predictions
train_predictions <- augment(train_fit, iris)
train_predictions
#> # A tibble: 150 Γ— 9
#>    Sepal.Len…¹ Sepal…² Petal…³ Petal…⁴ Species .pred…⁡ .pred…⁢ .pred_…⁷ .pred_…⁸
#>          <dbl>   <dbl>   <dbl>   <dbl> <fct>   <fct>     <dbl>    <dbl>    <dbl>
#>  1         5.1     3.5     1.4     0.2 setosa  setosa     1    2.98e-18 2.15e-25
#>  2         4.9     3       1.4     0.2 setosa  setosa     1    3.17e-17 6.94e-25
#>  3         4.7     3.2     1.3     0.2 setosa  setosa     1    2.37e-18 7.24e-26
#>  4         4.6     3.1     1.5     0.2 setosa  setosa     1    3.07e-17 8.69e-25
#>  5         5       3.6     1.4     0.2 setosa  setosa     1    1.02e-18 8.89e-26
#>  6         5.4     3.9     1.7     0.4 setosa  setosa     1.00 2.72e-14 4.34e-21
#>  7         4.6     3.4     1.4     0.3 setosa  setosa     1    2.32e-17 7.99e-25
#>  8         5       3.4     1.5     0.2 setosa  setosa     1    1.39e-17 8.17e-25
#>  9         4.4     2.9     1.4     0.2 setosa  setosa     1    1.99e-17 3.61e-25
#> 10         4.9     3.1     1.5     0.1 setosa  setosa     1    7.38e-18 3.62e-25
#> # … with 140 more rows, and abbreviated variable names ¹​Sepal.Length,
#> #   ²​Sepal.Width, ³​Petal.Length, ⁴​Petal.Width, ⁡​.pred_class, ⁢​.pred_setosa,
#> #   ⁷​.pred_versicolor, ⁸​.pred_virginica

# Check some metrics
multimetric <- metric_set(
  yardstick::f_meas,
  yardstick::accuracy,
  yardstick::bal_accuracy,
  yardstick::sens,
  yardstick::spec,
  yardstick::precision,
  yardstick::recall,
  yardstick::ppv,
  yardstick::npv
)

train_predictions |>
  group_by(Species) |>
  multimetric(truth = Species, estimate = .pred_class)
#> # A tibble: 27 Γ— 4
#>    Species    .metric      .estimator .estimate
#>    <fct>      <chr>        <chr>          <dbl>
#>  1 setosa     f_meas       macro          1    
#>  2 versicolor f_meas       macro          0.969
#>  3 virginica  f_meas       macro          0.969
#>  4 setosa     accuracy     multiclass     1    
#>  5 versicolor accuracy     multiclass     0.94 
#>  6 virginica  accuracy     multiclass     0.94 
#>  7 setosa     bal_accuracy macro         NA    
#>  8 versicolor bal_accuracy macro         NA    
#>  9 virginica  bal_accuracy macro         NA    
#> 10 setosa     sens         macro          1    
#> # … with 17 more rows

Created on 2022-10-25 with reprex v2.0.2

EmilHvitfeldt avatar Oct 25 '22 23:10 EmilHvitfeldt

Thanks a lot @EmilHvitfeldt . One more question: how could I add the roc_auc for each class?

Simply adding yardstick::roc_auc to the multimetric set does not work, which is, I think, because it is a metric that takes the class probabilities. However, even if I do a standalone metric it doesn't work:

train_predictions |>
  roc_auc(truth = Species,
          estimate = .pred_setosa:.pred_virginica)

Adding a group_by(Species) to it gives a lot of warnings and returns NaN for each class.

The standalone call only gives me one combined area. However, when using roc_curve + autoplot I get one curve (and hence one AUC) for each level.

deschen1 avatar Oct 26 '22 07:10 deschen1

I think this is the same as https://github.com/tidymodels/yardstick/issues/4

i.e. it sounds like you want one-vs-all output (with regard to factor levels). Like:

# truth
"a" "b" "c" "c" "a"

# for a, recode as
"y" "n" "n" "n" "y"

# for b, recode as
"n" "y" "n" "n" "n"

# for c, recode as
"n" "n" "y" "y" "n"

Then do the same thing for estimate, and compute 3 separate binary accuracy() calculations for each set of recoded values. yardstick:::one_vs_all_impl() is an internal helper that does this because we need to do those computations to compute some macro/micro estimates, but the kind of output that we'd get back doesn't natively fit into the rest of our API so I decided not to add it.

It is possible it could be a separate helper function that just wouldn't be usable in tuning and couldn't be combined in a metric set. It could maybe look something like:

# Sketch of a possible one-vs-all helper; it is a proposal only and has no
# implementation. In the example call that follows, `data` is a data frame,
# `truth` and `estimate` are bare column names, and `fn` is a metric set.
one_vs_all <- function(data, truth, estimate, fn) {
 # impl (placeholder: the body is intentionally left unwritten here)
}

one_vs_all(data, Species, pred, metric_set(accuracy, precision))
#> # A tibble: 6 Γ— 4
#>   .level     .metric   .estimator .estimate
#>   <chr>      <chr>     <chr>          <dbl>
#> 1 setosa     accuracy  binary           0.7
#> 2 setosa     precision binary           0.4
#> 3 versicolor accuracy  binary           0.8
#> 4 versicolor precision binary           0.3
#> 5 virginica  accuracy  binary           0.2
#> 6 virginica  precision binary           0.5

We'd have to consider how the API would look for numeric/class/class-prob metrics, because you supply estimate for some and ... for class-prob metrics.

DavisVaughan avatar Oct 27 '22 20:10 DavisVaughan

Thanks Davis!

Not entirely sure if what you are saying is that I can get what I want by using these already implemented (?) functions or if this is something you are going to add to the package?

deschen1 avatar Oct 27 '22 21:10 deschen1

You cannot currently do it.

I am not sure if we should expose these utilities or not, but it has come up a few times so it might be worth it to expose them in a limited way.

DavisVaughan avatar Oct 28 '22 18:10 DavisVaughan

Got it. I think it would generally be a great idea to offer class-specific metrics in a multi-class setting. From what I've seen in other tools and for certain analyses, it is a common but also very important task.

E.g. in a market research study where you segment/cluster customers, there might be a very important class that you want to predict very well even if that means you are missing out on some other classes. But you could only judge by seeing the class-specific metrics.

deschen1 avatar Oct 28 '22 18:10 deschen1

Hello @deschen1 👋

You should already be able to do this with the current version of yardstick. All the metrics work on grouped data.frames, so you can call

train_predictions |>
  group_by(Species) |>
  multimetric(truth = Species, estimate = .pred_class)

This gives you the results you want. I also took the liberty of showing augment(), which is a great function for combining your original data with predictions. For classification models it returns predicted probabilities and classes. You will get some warnings because, if you stratify your calculation by outcome, some levels will have no true events, giving undefined results.

library(tidymodels)
library(discrim)

model_recipe <- recipe(Species ~ ., data = iris)

# Create a workflow
model_final <- naive_Bayes(Laplace = 1) |>
  set_mode("classification") |>
  set_engine("klaR", prior = rep(1/3, 3), usekernel = FALSE)

model_final_wf <- workflow() |>
  add_recipe(model_recipe) |>
  add_model(model_final)

train_fit <- fit(model_final_wf, data = iris)

# Add predictions
train_predictions <- augment(train_fit, iris)
train_predictions
#> # A tibble: 150 Γ— 9
#>    Sepal.Len…¹ Sepal…² Petal…³ Petal…⁴ Species .pred…⁡ .pred…⁢ .pred_…⁷ .pred_…⁸
#>          <dbl>   <dbl>   <dbl>   <dbl> <fct>   <fct>     <dbl>    <dbl>    <dbl>
#>  1         5.1     3.5     1.4     0.2 setosa  setosa     1    2.98e-18 2.15e-25
#>  2         4.9     3       1.4     0.2 setosa  setosa     1    3.17e-17 6.94e-25
#>  3         4.7     3.2     1.3     0.2 setosa  setosa     1    2.37e-18 7.24e-26
#>  4         4.6     3.1     1.5     0.2 setosa  setosa     1    3.07e-17 8.69e-25
#>  5         5       3.6     1.4     0.2 setosa  setosa     1    1.02e-18 8.89e-26
#>  6         5.4     3.9     1.7     0.4 setosa  setosa     1.00 2.72e-14 4.34e-21
#>  7         4.6     3.4     1.4     0.3 setosa  setosa     1    2.32e-17 7.99e-25
#>  8         5       3.4     1.5     0.2 setosa  setosa     1    1.39e-17 8.17e-25
#>  9         4.4     2.9     1.4     0.2 setosa  setosa     1    1.99e-17 3.61e-25
#> 10         4.9     3.1     1.5     0.1 setosa  setosa     1    7.38e-18 3.62e-25
#> # … with 140 more rows, and abbreviated variable names ¹​Sepal.Length,
#> #   ²​Sepal.Width, ³​Petal.Length, ⁴​Petal.Width, ⁡​.pred_class, ⁢​.pred_setosa,
#> #   ⁷​.pred_versicolor, ⁸​.pred_virginica

# Check some metrics
multimetric <- metric_set(
  yardstick::f_meas,
  yardstick::accuracy,
  yardstick::bal_accuracy,
  yardstick::sens,
  yardstick::spec,
  yardstick::precision,
  yardstick::recall,
  yardstick::ppv,
  yardstick::npv
)

train_predictions |>
  group_by(Species) |>
  multimetric(truth = Species, estimate = .pred_class)
#> # A tibble: 27 Γ— 4
#>    Species    .metric      .estimator .estimate
#>    <fct>      <chr>        <chr>          <dbl>
#>  1 setosa     f_meas       macro          1    
#>  2 versicolor f_meas       macro          0.969
#>  3 virginica  f_meas       macro          0.969
#>  4 setosa     accuracy     multiclass     1    
#>  5 versicolor accuracy     multiclass     0.94 
#>  6 virginica  accuracy     multiclass     0.94 
#>  7 setosa     bal_accuracy macro         NA    
#>  8 versicolor bal_accuracy macro         NA    
#>  9 virginica  bal_accuracy macro         NA    
#> 10 setosa     sens         macro          1    
#> # … with 17 more rows

Created on 2022-10-25 with reprex v2.0.2

Hi @EmilHvitfeldt, is it possible with this solution to also get the standard error of the metrics? Basically, is it possible to get the same style of output that collect_metrics provides, but per class?

lardenoije avatar Aug 23 '23 14:08 lardenoije

Hi @EmilHvitfeldt , @DavisVaughan

Following up the above information on category-wise metrics, I've found a potential issue when using the estimator = "macro_weighted".


I've used a slightly different data set:

library(tidymodels)
library(palmerpenguins)

conflicts_prefer(palmerpenguins::penguins)

penguins_split <- initial_split(penguins, strata = "species")
penguins_train <- training(penguins_split)

rf_spec <- rand_forest() |>
	set_mode("classification")

# Bundle the recipe and the model spec into a workflow, then pass it to
# last_fit() together with the train/test split. The `workflow()` call must be
# closed before the pipe; without the closing parenthesis the snippet does not
# parse.
results <-
  workflow(
    preprocessor = recipe(species ~ island + year, data = penguins_train),
    spec = rf_spec
  ) |>
  last_fit(penguins_split)


results

# A tibble: 86 Γ— 8
   id               .pred_Adelie .pred_Chinstrap .pred_Gentoo  .row .pred_class species .config             
   <chr>                   <dbl>           <dbl>        <dbl> <int> <fct>       <fct>   <chr>               
 1 train/test split        0.788           0.152       0.0599     3 Adelie      Adelie  Preprocessor1_Model1
 2 train/test split        0.788           0.152       0.0599     9 Adelie      Adelie  Preprocessor1_Model1
 3 train/test split        0.788           0.152       0.0599    15 Adelie      Adelie  Preprocessor1_Model1
 4 train/test split        0.788           0.152       0.0599    19 Adelie      Adelie  Preprocessor1_Model1
 5 train/test split        0.457           0.484       0.0599    31 Chinstrap   Adelie  Preprocessor1_Model1
 6 train/test split        0.457           0.484       0.0599    32 Chinstrap   Adelie  Preprocessor1_Model1
 7 train/test split        0.457           0.484       0.0599    36 Chinstrap   Adelie  Preprocessor1_Model1
 8 train/test split        0.457           0.484       0.0599    40 Chinstrap   Adelie  Preprocessor1_Model1
 9 train/test split        0.457           0.484       0.0599    43 Chinstrap   Adelie  Preprocessor1_Model1
10 train/test split        0.457           0.484       0.0599    49 Chinstrap   Adelie  Preprocessor1_Model1
# ? 76 more rows

Using estimator = "macro":

results |>
	collect_predictions() |>
	precision(
		species ,
		.pred_class,
		estimator = "macro"
	)

# A tibble: 1 Γ— 3
  .metric   .estimator .estimate
  <chr>     <chr>          <dbl>
1 precision macro          0.603
results |>
	collect_predictions() |>
	group_by(.pred_class) |>
	precision(
		species ,
		.pred_class,
		estimator = "macro"
	)

# A tibble: 3 Γ— 4
  .pred_class .metric   .estimator .estimate
  <fct>       <chr>     <chr>          <dbl>
1 Adelie      precision macro          0.636
2 Chinstrap   precision macro          0.417
3 Gentoo      precision macro          0.756

Using the estimator = "macro_weighted":

results |>
	collect_predictions() |>
	precision(
		species ,
		.pred_class,
		estimator = "macro_weighted"
	)
# A tibble: 1 Γ— 3
  .metric   .estimator     .estimate
  <chr>     <chr>              <dbl>
1 precision macro_weighted     0.636
results |>
	collect_predictions() |>
	group_by(.pred_class) |>
	precision(
		species ,
		.pred_class,
		estimator = "macro_weighted"
	)

# A tibble: 3 Γ— 4
  .pred_class .metric   .estimator     .estimate
  <fct>       <chr>     <chr>              <dbl>
1 Adelie      precision macro_weighted     0.636
2 Chinstrap   precision macro_weighted     0.417
3 Gentoo      precision macro_weighted     0.756

The values on a class-level for "macro_weighted" are identical to "macro".

At the overall level, it seems to be correct.

viv-analytics avatar Sep 26 '23 09:09 viv-analytics

I implemented a function to calculate the one-vs-all per-class metrics:

library(tidymodels)
#> ── Attaching packages ────────────────────────────────────── tidymodels 1.1.1 ──
#> βœ” broom        1.0.5      βœ” recipes      1.0.10
#> βœ” dials        1.2.1      βœ” rsample      1.2.0 
#> βœ” dplyr        1.1.4      βœ” tibble       3.2.1 
#> βœ” ggplot2      3.5.0      βœ” tidyr        1.3.1 
#> βœ” infer        1.0.6      βœ” tune         1.1.2 
#> βœ” modeldata    1.3.0      βœ” workflows    1.1.4 
#> βœ” parsnip      1.2.0      βœ” workflowsets 1.0.1 
#> βœ” purrr        1.0.2      βœ” yardstick    1.3.0
#> ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
#> βœ– purrr::discard() masks scales::discard()
#> βœ– dplyr::filter()  masks stats::filter()
#> βœ– dplyr::lag()     masks stats::lag()
#> βœ– recipes::step()  masks stats::step()
#> β€’ Search for functions across packages at https://www.tidymodels.org/find/
library(ranger)

#' One-vs-all metrics for each class of a multiclass outcome
#'
#' Every row's truth and estimate value is expanded into one entry per factor
#' level. An entry keeps the class label where the level matches the row's
#' value, is `NA` where the value is missing, and is `"class_0"` otherwise.
#' The expanded rows are then split by level, and `metric_set` is evaluated on
#' each per-level data frame. Warnings raised while computing the metrics are
#' suppressed.
#'
#' @param x A data frame containing the truth and estimate columns.
#' @param truth,estimate Unquoted names of factor columns in `x`.
#' @param metric_set A function called as
#'   `metric_set(data, truth = truth_ova, estimate = estimate_ova)` on each
#'   per-level data frame (for example a yardstick metric set).
#' @return The per-level metric results bound into one data frame, with a
#'   `class_group` column naming the level each row belongs to.
ova_metrics <- function(x, truth, estimate, metric_set) {
  # Turn a factor column into a list-column: one named character vector per
  # row, with one element per factor level (names are the levels).
  expand_ova <- function(values) {
    lvls <- levels(values)
    purrr::map(values, function(value) {
      dplyr::case_when(
        lvls %in% value ~ as.character(value),
        is.na(value) ~ NA_character_,
        .default = "class_0"
      ) %>%
        rlang::set_names(lvls)
    })
  }

  x %>%
    dplyr::mutate(
      truth_ova = expand_ova({{ truth }}),
      estimate_ova = expand_ova({{ estimate }})
    ) %>%
    # One row per (original row, level); the level names land in `*_id` columns.
    tidyr::unnest_longer(col = c(truth_ova, estimate_ova)) %>%
    dplyr::mutate(class_group = purrr::map2_chr(truth_ova_id, estimate_ova_id, unique)) %>%
    tidyr::nest(.by = class_group) %>%
    dplyr::mutate(
      data = data %>%
        # Make "class_0" a level in every group and move it last, so the class
        # label comes first in the factor.
        purrr::map(~ dplyr::mutate(.x, dplyr::across(
          dplyr::matches("ova"),
          ~ factor(.x) %>%
            forcats::fct_expand("class_0") %>%
            forcats::fct_relevel("class_0", after = Inf)
        ))) %>%
        purrr::map(metric_set, truth = truth_ova, estimate = estimate_ova) %>%
        suppressWarnings()
    ) %>%
    tidyr::unnest(cols = data)
}

model_recipe <- recipe(Species ~ ., data = iris)

model_final <- rand_forest(
  mode = "classification",
  engine = "ranger",
  mtry = 3,
  min_n = 20,
  trees = 500
) %>%
  set_engine("ranger", importance = "impurity")

model_final_wf <- workflow() |>
  add_recipe(model_recipe) |>
  add_model(model_final)

set.seed(2024)
train_fit <- fit(model_final_wf, data = iris)

train_predictions <- augment(train_fit, iris)
train_predictions %>% conf_mat(Species, .pred_class)
#>             Truth
#> Prediction   setosa versicolor virginica
#>   setosa         50          0         0
#>   versicolor      0         47         1
#>   virginica       0          3        49

per_class_mset <-
  metric_set(accuracy, sensitivity, specificity, f_meas, bal_accuracy, kap)

train_predictions %>% 
  per_class_mset(truth = Species, estimate = .pred_class)
#> # A tibble: 6 Γ— 3
#>   .metric      .estimator .estimate
#>   <chr>        <chr>          <dbl>
#> 1 accuracy     multiclass     0.973
#> 2 sensitivity  macro          0.973
#> 3 specificity  macro          0.987
#> 4 f_meas       macro          0.973
#> 5 bal_accuracy macro          0.98 
#> 6 kap          multiclass     0.96

train_predictions %>% 
  ova_metrics(truth = Species, estimate = .pred_class, metric_set = per_class_mset)
#> # A tibble: 18 Γ— 4
#>    class_group .metric      .estimator .estimate
#>    <chr>       <chr>        <chr>          <dbl>
#>  1 setosa      accuracy     binary         1    
#>  2 setosa      sensitivity  binary         1    
#>  3 setosa      specificity  binary         1    
#>  4 setosa      f_meas       binary         1    
#>  5 setosa      bal_accuracy binary         1    
#>  6 setosa      kap          binary         1    
#>  7 versicolor  accuracy     binary         0.973
#>  8 versicolor  sensitivity  binary         0.94 
#>  9 versicolor  specificity  binary         0.99 
#> 10 versicolor  f_meas       binary         0.959
#> 11 versicolor  bal_accuracy binary         0.965
#> 12 versicolor  kap          binary         0.939
#> 13 virginica   accuracy     binary         0.973
#> 14 virginica   sensitivity  binary         0.98 
#> 15 virginica   specificity  binary         0.97 
#> 16 virginica   f_meas       binary         0.961
#> 17 virginica   bal_accuracy binary         0.975
#> 18 virginica   kap          binary         0.941

Created on 2024-04-02 by the reprex package (v0.3.0)

dchiu911 avatar Apr 03 '24 00:04 dchiu911

I've also come up with a solution:

  • [x] works with grouped data
  • [x] keeps metrics with multiclass support as-is

library(tidymodels)
library(ranger)

model_recipe <- recipe(Species ~ ., data = iris)

model_final <- rand_forest(
  mode = "classification",
  engine = "ranger",
  mtry = 3,
  min_n = 20,
  trees = 500
) %>%
  set_engine("ranger", importance = "impurity")

model_final_wf <- workflow() |>
  add_recipe(model_recipe) |>
  add_model(model_final)

set.seed(2024)
train_fit <- fit(model_final_wf, data = iris)

train_predictions <- augment(train_fit, iris)
train_predictions %>% conf_mat(Species, .pred_class)
#>             Truth
#> Prediction   setosa versicolor virginica
#>   setosa         50          0         0
#>   versicolor      0         47         2
#>   virginica       0          3        48

per_class_mset <-
  metric_set(accuracy, sensitivity, specificity, f_meas, bal_accuracy, kap)

train_predictions |> 
  metric_by_event_by(per_class_mset, truth = Species, estimate = .pred_class)
#> # A tibble: 14 Γ— 4
#>    .class     .metric      .estimator .estimate
#>    <chr>      <chr>        <chr>          <dbl>
#>  1 setosa     sensitivity  event          1    
#>  2 setosa     specificity  event          1    
#>  3 setosa     f_meas       event          1    
#>  4 setosa     bal_accuracy event          1    
#>  5 versicolor sensitivity  event          0.94 
#>  6 versicolor specificity  event          0.98 
#>  7 versicolor f_meas       event          0.949
#>  8 versicolor bal_accuracy event          0.96 
#>  9 virginica  sensitivity  event          0.96 
#> 10 virginica  specificity  event          0.97 
#> 11 virginica  f_meas       event          0.950
#> 12 virginica  bal_accuracy event          0.965
#> 13 NA         accuracy     multiclass     0.967
#> 14 NA         kap          multiclass     0.95

The function:

#' Per-class ("event") metrics alongside the multiclass metrics
#'
#' Evaluates `metric_set` once on `data` as supplied, keeping only results
#' whose `.estimator` is not `"macro"`, `"macro_weighted"` or `"micro"`. It
#' then evaluates it once per level of `truth`, after collapsing both `truth`
#' and `estimate` to a `Yes` (that level) / `No` (every other level) factor.
#' Per-class rows are labelled with `.estimator = "event"` and a `.class`
#' column. Metrics already reported by the first evaluation are dropped from
#' the per-class rows and appended once at the end.
#'
#' @param data A data frame (grouped data frames are supported).
#' @param metric_set An object inheriting from `"class_prob_metric_set"`. It
#'   may be given as any expression that evaluates to the metric set in the
#'   caller's frame (a bare name, `pkg::name`, or an inline call).
#' @param truth,estimate Bare names of the factor columns in `data`.
#' @param ... Further arguments, forwarded unchanged to the metric set call.
#' @return A data frame of metric results with a `.class` column. For grouped
#'   input the grouping columns are moved to the front and used for sorting.
metric_by_event_by <- function(data, metric_set, truth, estimate, ...) {
  # Only for class_prob_metric_set
  stopifnot(inherits(metric_set, "class_prob_metric_set"))

  # Re-target the captured call at the metric set itself. Using the captured
  # expression directly (rather than as.name(as.character(...))) keeps this
  # working when the metric set is not passed as a bare symbol.
  cl <- match.call()
  cl[[1]] <- cl$metric_set
  cl$metric_set <- NULL
  cl$event_level <- "first"

  # Find names and levels
  truth_name <- as.character(cl$truth)
  estimate_name <- as.character(cl$estimate)
  class_levels <- levels(data[[truth_name]])

  # Current default behavior, minus the averaged multiclass estimators
  out_raw <- eval.parent(cl) |>
    dplyr::filter(!.estimator %in% c("macro", "macro_weighted", "micro"))

  # Collapse a multiclass factor to `lvl` ("Yes") versus everything else ("No")
  fct_binarize <- function(.f, lvl) {
    forcats::fct_collapse(.f, Yes = lvl, other_level = "No")
  }

  # For each class...
  out <- class_levels |>
    purrr::set_names() |>
    purrr::map(function(lvl) {
      tmp_data <- data
      tmp_cl <- cl

      # binarize
      tmp_data[[truth_name]] <- fct_binarize(tmp_data[[truth_name]], lvl)
      tmp_data[[estimate_name]] <- fct_binarize(tmp_data[[estimate_name]], lvl)
      tmp_cl$data <- tmp_data

      # Get the metrics
      eval.parent(tmp_cl)
    }) |>
    # rbind with a `.class` column for each class
    dplyr::bind_rows(.id = ".class") |>
    dplyr::mutate(
      # New estimator name?
      .estimator = "event"
    ) |>
    # keep only metrics that use macro/micro multiclass estimators
    dplyr::anti_join(out_raw, by = dplyr::join_by(.metric)) |>
    # Add the (non-classwise) metrics
    dplyr::bind_rows(out_raw)

  # Deal with grouped data frames
  if (inherits(data, "grouped_df")) {
    out <- out |>
      dplyr::relocate(dplyr::all_of(dplyr::group_vars(data)),
                      .before = 1) |>
      dplyr::arrange(dplyr::pick(dplyr::all_of(dplyr::group_vars(data))))
  }

  out
}

mattansb avatar Aug 03 '24 06:08 mattansb

You may be interested in new_groupwise_metric() from yardstick 1.3.0. :) The new "grouping behavior in yardstick" vignette describes what we mean by "groupwise."

I believe we may be able to close this issue, but will let @EmilHvitfeldt decide!

simonpcouch avatar Aug 03 '24 21:08 simonpcouch