
`collect_metrics` ignores case weights in calculating `concordance_survival`

lilykoff opened this issue on Apr 17, 2024 · 5 comments

The problem

When fitting models with case weights in tidymodels, collect_metrics() does not take the weights into account when calculating certain metrics, including concordance_survival. The concordance_survival() function has an argument for case weights, but when collect_metrics() is called on an object from fit_resamples(), the case weights aren't used. I'm not sure whether this is more of a feature request than a bug report, but I think the collect_metrics() documentation should make it clearer that case weights will not be used in calculating certain metrics.

Example dataset

These data are from 50–79 year olds in NHANES 2011–2014. To keep things simple, I fit models predicting time to mortality from a single predictor: age in years. The weights are the normalized survey weights provided by NHANES. The dataset is here: nhanes_mortality.csv.gz

Reproducible example

library(tidyverse)
library(tidymodels)
library(censored)
#> Loading required package: survival
# read in example dataset
nhanes_mortality = readr::read_csv("https://github.com/tidymodels/tune/files/15015680/nhanes_mortality.csv.gz")
#> Rows: 3728 Columns: 4
#> ── Column specification ────────────────────────────────────────────────────────
#> Delimiter: ","
#> dbl (4): age, mortstat, event_time, weight_norm
#> 
#> ℹ Use `spec()` to retrieve the full column specification for this data.
#> ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# event time: event_time, mortality (y/n) is mortstat, age is age, weight_norm is survey weight we want to use

# create a survival object
all_mort_surv =
  nhanes_mortality %>%
  mutate(mort_surv = Surv(event_time, mortstat)) %>%
  mutate(case_weights_imp = hardhat::importance_weights(weight_norm))

# if we are not using cross validation, we can fit the model and
# ensure that the concordance estimates from survival::coxph and tidymodels are the same

# fit the model using cox regression on entire dataset
fit = coxph(mort_surv ~ age, data=all_mort_surv, weights = weight_norm)
# calculate concordance
coxph_c = unname(fit$concordance[6]); coxph_c
#> [1] 0.6707893

# now use tidymodels framework:
survival_metrics = metric_set(concordance_survival)
survreg_spec =
  proportional_hazards() %>%
  set_engine("survival") %>%
  set_mode("censored regression")

model_fit =
  survreg_spec %>%
  fit(mort_surv ~ age, data = all_mort_surv, case_weights = all_mort_surv$case_weights_imp)
# predict on same dataset
augmented = all_mort_surv %>%
  mutate(pred = predict(model_fit, new_data = all_mort_surv, type = "time")) %>%
  unnest(pred)

tidy_c = concordance_survival(data = augmented, truth = mort_surv, estimate = ".pred_time",
                               case_weights = case_weights_imp)$.estimate; tidy_c
#> [1] 0.6707893
tidy_c; coxph_c
#> [1] 0.6707893
#> [1] 0.6707893
all.equal(tidy_c, coxph_c)
#> [1] TRUE


# now we want to do cross validation
set.seed(123)
# get five folds
folds = vfold_cv(all_mort_surv, v = 5, repeats = 1)
get_conc = function(ind, folds){
  # get concordance using coxph
  # for each fold, get the training and testing data, calculate concordance
  dat = get_rsplit(folds, index = ind)
  fit = coxph(mort_surv ~ age, data=analysis(dat), weights = weight_norm)
  test = assessment(dat)
  concordance(mort_surv ~ predict(fit, newdata = test), data = test, weights = weight_norm, reverse = TRUE)[[1]]
}
coxph_c_cv = map_dbl(.x = 1:5, .f = get_conc, folds = folds)
coxph_c_cv
#> [1] 0.6460921 0.6106318 0.7961755 0.6727030 0.6150353
mean(coxph_c_cv)
#> [1] 0.6681275

# now we do the same in the tidy framework
wflow =
  workflow() %>%
  add_model(survreg_spec) %>%
  add_variables(outcomes = mort_surv,
                predictors = age) %>%
  add_case_weights(case_weights_imp)

res = fit_resamples(
  wflow,
  resamples = folds,
  metrics = survival_metrics,
  control = control_resamples(save_pred = TRUE)
)

tidy_c_cv = collect_metrics(res)
all.equal(mean(coxph_c_cv), tidy_c_cv$mean) # they aren't the same
#> [1] "Mean relative difference: 0.002871308"
# this is because collect_metrics() isn't calculating a weighted concordance

# if we "manually" get concordance for each fold, it agrees with the coxph way
get_concordance = function(row_num, df){
  df %>%
    slice(row_num) %>%
    unnest(.predictions) %>%
    select(.pred_time, .row, mort_surv) %>%
    left_join(all_mort_surv %>% mutate(row_ind = row_number()) %>% select(row_ind, case_weights_imp), by = c(".row" = "row_ind")) %>%
    concordance_survival(truth = mort_surv, estimate = ".pred_time", case_weights = case_weights_imp) %>%
    pull(.estimate)
}

tidy_c_cv_manual = map_dbl(.x = 1:nrow(res), .f = get_concordance, df = res)
all.equal(mean(coxph_c_cv), mean(tidy_c_cv_manual))
#> [1] TRUE

Session Info

sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.0 (2022-04-22)
#>  os       macOS Big Sur/Monterey 10.16
#>  system   x86_64, darwin17.0
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       America/New_York
#>  date     2024-04-17
#>  pandoc   3.1.1 @ /Applications/RStudio.app/Contents/Resources/app/quarto/bin/tools/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package      * version    date (UTC) lib source
#>  backports      1.4.1      2021-12-13 [1] CRAN (R 4.2.0)
#>  bit            4.0.5      2022-11-15 [1] CRAN (R 4.2.0)
#>  bit64          4.0.5      2020-08-30 [1] CRAN (R 4.2.0)
#>  broom        * 1.0.5      2023-06-09 [1] CRAN (R 4.2.0)
#>  censored     * 0.3.0      2024-01-31 [1] CRAN (R 4.2.0)
#>  class          7.3-22     2023-05-03 [1] CRAN (R 4.2.0)
#>  cli            3.6.2      2023-12-11 [1] CRAN (R 4.2.0)
#>  codetools      0.2-19     2023-02-01 [1] CRAN (R 4.2.0)
#>  colorspace     2.1-0      2023-01-23 [1] CRAN (R 4.2.0)
#>  crayon         1.5.2      2022-09-29 [1] CRAN (R 4.2.0)
#>  data.table     1.14.8     2023-02-17 [1] CRAN (R 4.2.0)
#>  dials        * 1.2.1      2024-02-22 [1] CRAN (R 4.2.0)
#>  DiceDesign     1.9        2021-02-13 [1] CRAN (R 4.2.0)
#>  digest         0.6.34     2024-01-11 [1] CRAN (R 4.2.0)
#>  dplyr        * 1.1.4      2023-11-17 [1] CRAN (R 4.2.0)
#>  evaluate       0.23       2023-11-01 [1] CRAN (R 4.2.0)
#>  fansi          1.0.6      2023-12-08 [1] CRAN (R 4.2.0)
#>  fastmap        1.1.1      2023-02-24 [1] CRAN (R 4.2.0)
#>  forcats      * 1.0.0      2023-01-29 [1] CRAN (R 4.2.0)
#>  foreach        1.5.2      2022-02-02 [1] CRAN (R 4.2.0)
#>  fs             1.6.3      2023-07-20 [1] CRAN (R 4.2.0)
#>  furrr          0.3.1      2022-08-15 [1] CRAN (R 4.2.0)
#>  future         1.33.0     2023-07-01 [1] CRAN (R 4.2.0)
#>  future.apply   1.11.0     2023-05-21 [1] CRAN (R 4.2.0)
#>  generics       0.1.3      2022-07-05 [1] CRAN (R 4.2.0)
#>  ggplot2      * 3.5.0      2024-02-23 [1] CRAN (R 4.2.0)
#>  globals        0.16.2     2022-11-21 [1] CRAN (R 4.2.0)
#>  glue           1.7.0      2024-01-09 [1] CRAN (R 4.2.0)
#>  gower          1.0.1      2022-12-22 [1] CRAN (R 4.2.0)
#>  GPfit          1.0-8      2019-02-08 [1] CRAN (R 4.2.0)
#>  gtable         0.3.4      2023-08-21 [1] CRAN (R 4.2.0)
#>  hardhat        1.3.1      2024-02-02 [1] CRAN (R 4.2.0)
#>  hms            1.1.3      2023-03-21 [1] CRAN (R 4.2.0)
#>  htmltools      0.5.7      2023-11-03 [1] CRAN (R 4.2.0)
#>  infer        * 1.0.7      2024-03-25 [1] CRAN (R 4.2.0)
#>  ipred          0.9-14     2023-03-09 [1] CRAN (R 4.2.0)
#>  iterators      1.0.14     2022-02-05 [1] CRAN (R 4.2.0)
#>  knitr          1.45       2023-10-30 [1] CRAN (R 4.2.0)
#>  lattice        0.21-8     2023-04-05 [1] CRAN (R 4.2.0)
#>  lava           1.7.2.1    2023-02-27 [1] CRAN (R 4.2.0)
#>  lhs            1.1.6      2022-12-17 [1] CRAN (R 4.2.0)
#>  lifecycle      1.0.4      2023-11-07 [1] CRAN (R 4.2.0)
#>  listenv        0.9.0      2022-12-16 [1] CRAN (R 4.2.0)
#>  lubridate    * 1.9.3      2023-09-27 [1] CRAN (R 4.2.0)
#>  magrittr       2.0.3      2022-03-30 [1] CRAN (R 4.2.0)
#>  MASS           7.3-60     2023-05-04 [1] CRAN (R 4.2.0)
#>  Matrix         1.5-4.1    2023-05-18 [1] CRAN (R 4.2.0)
#>  modeldata    * 1.3.0      2024-01-21 [1] CRAN (R 4.2.0)
#>  munsell        0.5.0      2018-06-12 [1] CRAN (R 4.2.0)
#>  nnet           7.3-19     2023-05-03 [1] CRAN (R 4.2.0)
#>  parallelly     1.36.0     2023-05-26 [1] CRAN (R 4.2.0)
#>  parsnip      * 1.2.1      2024-03-22 [1] CRAN (R 4.2.0)
#>  pillar         1.9.0      2023-03-22 [1] CRAN (R 4.2.0)
#>  pkgconfig      2.0.3      2019-09-22 [1] CRAN (R 4.2.0)
#>  prodlim        2023.03.31 2023-04-02 [1] CRAN (R 4.2.0)
#>  purrr        * 1.0.2      2023-08-10 [1] CRAN (R 4.2.0)
#>  R6             2.5.1      2021-08-19 [1] CRAN (R 4.2.0)
#>  Rcpp           1.0.12     2024-01-09 [1] CRAN (R 4.2.0)
#>  readr        * 2.1.5      2024-01-10 [1] CRAN (R 4.2.0)
#>  recipes      * 1.0.10     2024-02-18 [1] CRAN (R 4.2.0)
#>  reprex         2.0.2      2022-08-17 [1] CRAN (R 4.2.0)
#>  rlang          1.1.3      2024-01-10 [1] CRAN (R 4.2.0)
#>  rmarkdown      2.25       2023-09-18 [1] CRAN (R 4.2.0)
#>  rpart          4.1.19     2022-10-21 [1] CRAN (R 4.2.0)
#>  rsample      * 1.2.1      2024-03-25 [1] CRAN (R 4.2.0)
#>  rstudioapi     0.16.0     2024-03-24 [1] CRAN (R 4.2.0)
#>  scales       * 1.3.0      2023-11-28 [1] CRAN (R 4.2.0)
#>  sessioninfo    1.2.2      2021-12-06 [1] CRAN (R 4.2.0)
#>  stringi        1.8.3      2023-12-11 [1] CRAN (R 4.2.0)
#>  stringr      * 1.5.1      2023-11-14 [1] CRAN (R 4.2.0)
#>  survival     * 3.5-5      2023-03-12 [1] CRAN (R 4.2.0)
#>  tibble       * 3.2.1      2023-03-20 [1] CRAN (R 4.2.0)
#>  tidymodels   * 1.2.0      2024-03-25 [1] CRAN (R 4.2.0)
#>  tidyr        * 1.3.1      2024-01-24 [1] CRAN (R 4.2.0)
#>  tidyselect     1.2.0      2022-10-10 [1] CRAN (R 4.2.0)
#>  tidyverse    * 2.0.0      2023-02-22 [1] CRAN (R 4.2.0)
#>  timechange     0.3.0      2024-01-18 [1] CRAN (R 4.2.0)
#>  timeDate       4022.108   2023-01-07 [1] CRAN (R 4.2.0)
#>  tune         * 1.2.0      2024-03-20 [1] CRAN (R 4.2.0)
#>  tzdb           0.4.0      2023-05-12 [1] CRAN (R 4.2.0)
#>  utf8           1.2.4      2023-10-22 [1] CRAN (R 4.2.0)
#>  vctrs          0.6.5      2023-12-01 [1] CRAN (R 4.2.0)
#>  vroom          1.6.5      2023-12-05 [1] CRAN (R 4.2.0)
#>  withr          3.0.0      2024-01-16 [1] CRAN (R 4.2.0)
#>  workflows    * 1.1.4      2024-02-19 [1] CRAN (R 4.2.0)
#>  workflowsets * 1.1.0      2024-03-21 [1] CRAN (R 4.2.0)
#>  xfun           0.41       2023-11-01 [1] CRAN (R 4.2.0)
#>  yaml           2.3.8      2023-12-11 [1] CRAN (R 4.2.0)
#>  yardstick    * 1.3.1      2024-03-21 [1] CRAN (R 4.2.0)
#> 
#>  [1] /Library/Frameworks/R.framework/Versions/4.2/Resources/library
#> 
#> ──────────────────────────────────────────────────────────────────────────────

lilykoff · Apr 17, 2024

Thank you for the thorough issue description!

Noting that I've edited the call to read_csv() in the OP to read directly from GH.

simonpcouch · Apr 17, 2024

Ah, it appears this may be intentional. In ?.use_case_weights_with_yardstick():

https://github.com/tidymodels/tune/blob/668eec2b747443d7b374a0fb79e38af92ed46960/R/case_weights.R#L3-L8
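In essence, those lines amount to something like the following (a paraphrased sketch of tune's behavior, not the verbatim source):

# generic that tune consults to decide whether to hand case weights
# to yardstick when computing metrics
.use_case_weights_with_yardstick <- function(x) {
  UseMethod(".use_case_weights_with_yardstick")
}

# frequency weights: passed along to metric calculations
.use_case_weights_with_yardstick.hardhat_frequency_weights <- function(x) {
  TRUE
}

# importance weights: deliberately excluded from metric calculations
.use_case_weights_with_yardstick.hardhat_importance_weights <- function(x) {
  FALSE
}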

That is, internally, tune intentionally doesn't use importance case weights when calculating metrics. That's about as far as I can get with GitHub archeology. This tidyverse.org blog post reads:

As a counter example, importance weights reflect the idea that they should only influence the model fitting procedure. It wouldn’t make sense to use a weighted mean to center a predictor; the weight shouldn’t influence an unsupervised operation in the same way as model estimation. More critically, any holdout data set used to quantify model efficacy should reflect the data as seen in the wild (without the impact of the weights).

At the least, we can try and find a more public-facing place in the documentation to surface this and provide a reference or two.

simonpcouch · Apr 17, 2024

So the end result is that you cannot compute weighted survival concordance with anything except frequency weights (i.e., no survey weighting)?

muschellij2 · Apr 17, 2024

Mm, yes and no. tune contends that the safest, most reasonable default for importance weights is to apply them when fitting but not when calculating metrics, so it doesn't allow overriding that when resampling with tune. At the same time, tune does respect the output of use_case_weights_with_yardstick() for the given case weights class. So, if you'd like to inherit all of the behavior of importance weights except that they're applied at metric calculation time, you can make a subclass of importance_weights and define a use_case_weights_with_yardstick() method for it. This example adapts @lilykoff's code to do so:

library(tidyverse)
library(tidymodels)
library(censored)
#> Loading required package: survival

nhanes_mortality = readr::read_csv("https://github.com/tidymodels/tune/files/15015680/nhanes_mortality.csv.gz")
#> Rows: 3728 Columns: 4
#> ── Column specification ────────────────────────────────────────────────────────
#> Delimiter: ","
#> dbl (4): age, mortstat, event_time, weight_norm
#> 
#> ℹ Use `spec()` to retrieve the full column specification for this data.
#> ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.

# create a survival object
all_mort_surv =
  nhanes_mortality %>%
  mutate(mort_surv = Surv(event_time, mortstat)) %>%
  mutate(case_weights_imp = hardhat::importance_weights(weight_norm))

# start: simon's edits ------------------------------------------------------------------------
# subclass the case weights and introduce needed methods for them
class(all_mort_surv$case_weights_imp)
#> [1] "hardhat_importance_weights" "hardhat_case_weights"      
#> [3] "vctrs_vctr"

all_mort_surv$case_weights_imp <-
  structure(
    all_mort_surv$case_weights_imp, 
    class = c("survey_weights", class(all_mort_surv$case_weights_imp))
  )

class(all_mort_surv$case_weights_imp)
#> [1] "survey_weights"             "hardhat_importance_weights"
#> [3] "hardhat_case_weights"       "vctrs_vctr"

# define a method that tells tune to pass case weights to metrics calculations
.use_case_weights_with_yardstick.survey_weights <- function(x) {TRUE}

# define a method that tells vctrs how to convert to double
vec_cast.double.survey_weights <- hardhat:::vec_cast.double.hardhat_importance_weights

# end: simon's edits ------------------------------------------------------------------------

# now we want to do cross validation
set.seed(123)
# get five folds
folds = vfold_cv(all_mort_surv, v = 5, repeats = 1)
get_conc = function(ind, folds){
  # get concordance using coxph
  # for each fold, get the training and testing data, calculate concordance
  dat = get_rsplit(folds, index = ind)
  fit = coxph(mort_surv ~ age, data=analysis(dat), weights = weight_norm)
  test = assessment(dat)
  concordance(mort_surv ~ predict(fit, newdata = test), data = test, weights = weight_norm, reverse = TRUE)[[1]]
}
coxph_c_cv = map_dbl(.x = 1:5, .f = get_conc, folds = folds)
coxph_c_cv
#> [1] 0.6460921 0.6106318 0.7961755 0.6727030 0.6150353

mean(coxph_c_cv)
#> [1] 0.6681275


# now we do the same in the tidy framework
survival_metrics = metric_set(concordance_survival)
survreg_spec =
  proportional_hazards() %>%
  set_engine("survival") %>%
  set_mode("censored regression")


wflow =
  workflow() %>%
  add_model(survreg_spec) %>%
  add_variables(outcomes = mort_surv,
                predictors = age) %>%
  add_case_weights(case_weights_imp)

res = fit_resamples(
  wflow,
  resamples = folds,
  metrics = survival_metrics,
  control = control_resamples(save_pred = TRUE)
)

tidy_c_cv = collect_metrics(res)
all.equal(mean(coxph_c_cv), tidy_c_cv$mean)
#> [1] TRUE

Created on 2024-04-17 with reprex v2.1.0

If y'all think it's an unsafe default to not apply importance weights when calculating metrics, we'd definitely be interested in hearing why and in what contexts. This may be a matter of tidymodels introducing an importance weight subclass that does default to applying case weights when calculating metrics. I didn't implement the use_case_weights_with_yardstick() methods, so I'm not sure which references those decisions draw from (and will follow up on that later), but it looks to me like they very intentionally chose not to include importance weights in metric calculations when resampling.
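Concretely, such a subclass might look roughly like this, packaging the ad-hoc edits from the reprex above into a reusable constructor (a hypothetical sketch; metric_importance_weights() and its class don't exist in hardhat or tune today):

# hypothetical: importance weights that ARE applied at metric calculation time
metric_importance_weights <- function(x) {
  wts <- hardhat::importance_weights(x)
  structure(wts, class = c("metric_importance_weights", class(wts)))
}

# opt this class in to weighted metric calculations
.use_case_weights_with_yardstick.metric_importance_weights <- function(x) {
  TRUE
}

# reuse the existing conversion-to-double behavior
vec_cast.double.metric_importance_weights <-
  hardhat:::vec_cast.double.hardhat_importance_weights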

simonpcouch · Apr 17, 2024

So the end result is that you cannot compute weighted survival concordance with anything except frequency weights (i.e., no survey weighting)?

We think that our logic is pretty sound for the cases that we've outlined in the blog post (referenced above). We did ask in several venues (including that blog) for help on survey weights, so we'd love to hear what you think should happen. We don't have much experience in that area.

It's not too hard to define your own type of case weight. The current definitions are here. We can help with it and, if needed, update the logic for your cases.
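For example, a standalone survey-weights type could be built with hardhat's developer-facing constructor rather than by re-classing an importance-weights vector (a minimal sketch; the survey_weights() helper and class are hypothetical, and downstream packages may need the logic updates mentioned above to fully accept them):

library(hardhat)

# hypothetical constructor for a dedicated survey-weights class
survey_weights <- function(x) {
  x <- vctrs::vec_cast(x, to = double())
  new_case_weights(x, class = "survey_weights")
}

# tell tune to pass these weights on to yardstick
.use_case_weights_with_yardstick.survey_weights <- function(x) {
  TRUE
}

# tell vctrs how to convert the weights back to double
vec_cast.double.survey_weights <- function(x, to, ...) {
  vctrs::vec_data(x)
}

wts <- survey_weights(c(0.5, 1.2, 2.0))
class(wts)
#> [1] "survey_weights"       "hardhat_case_weights" "vctrs_vctr"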

topepo · Apr 17, 2024

Going to go ahead and close. Feel free to ping here or in a separate issue if you think this default ought to be changed, as we're very much interested in making improvements here.

simonpcouch · May 23, 2024

This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.

github-actions[bot] · Jun 7, 2024