yardstick

Allow `matrix` class for `estimate` argument of custom numeric metric

Open brshallo opened this issue 2 years ago • 2 comments

The documentation for `metric_vec_template()` suggests that matrix input should be fine. From the function documentation:

For matrices, it is best to supply "numeric" as the class to check here.

However, the code for `yardstick:::validate_truth_estimate_types.numeric()` looks like it errors whenever `estimate` is a matrix. See this `if` statement from the source code:

if (is.matrix(estimate)) {
  abort(paste0("`estimate` should be a numeric vector, not a numeric matrix."))
}

My guess is that matrix input is perhaps only allowed for probability metrics (there are examples of matrix input for these, e.g. `mn_log_loss()`), though I'm not sure...

Context + Example

I am trying to build a custom metric for "coverage" that would take multiple inputs for `estimate`: the "lower" and "upper" bounds of the prediction intervals (see #220). I tried using the source code of `mn_log_loss()` as a template.
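For reference, coverage here means the share of observations whose truth falls inside its interval: with truth y_i and interval [l_i, u_i] for i = 1, ..., n, coverage = (1/n) Σ 1[l_i ≤ y_i ≤ u_i]. Below is a minimal base-R sketch of that quantity, outside of yardstick's metric machinery. `interval_coverage()` is a hypothetical name, and the sketch assumes column 1 of `estimate` holds the lower bounds and column 2 the upper bounds.

```r
# Hypothetical helper (not part of yardstick): empirical coverage of
# prediction intervals, i.e. the (optionally weighted) share of
# observations with lower <= truth <= upper.
interval_coverage <- function(truth, estimate, case_weights = NULL) {
  stopifnot(
    is.numeric(truth),
    is.matrix(estimate), is.numeric(estimate),
    ncol(estimate) == 2L,
    nrow(estimate) == length(truth)
  )
  # Assumed layout: column 1 = lower bounds, column 2 = upper bounds
  covered <- estimate[, 1] <= truth & truth <= estimate[, 2]
  if (is.null(case_weights)) {
    mean(covered)
  } else {
    weighted.mean(covered, w = case_weights)
  }
}

truth <- c(0.5:8.5, 11)
intervals <- matrix(c(0:9, 1:10), ncol = 2)
interval_coverage(truth, intervals)
#> [1] 0.9
```

Using `weighted.mean()` for case weights is just one simple choice here; it is not necessarily how yardstick would apply them.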

Example (attempting to write a custom `coverage()` metric)

(At this point I didn't worry about passing through `case_weights` or whether to sum, as I figured these shouldn't affect whether the code errors or passes the validation checks.)

library(tidymodels)
library(rlang)

coverage <- function(data, ...) {
  UseMethod("coverage")
}
coverage <- new_numeric_metric(
  coverage,
  direction = "maximize"
)

#' @export
#' @rdname coverage
coverage.data.frame <- function(data,
                                truth,
                                ...,
                                na_rm = TRUE,
                                sum = FALSE,
                                # event_level = yardstick_event_level(),
                                case_weights = NULL) {
  estimate <- dots_to_estimate(data, !!! enquos(...))
  
  metric_summarizer(
    metric_nm = "coverage",
    metric_fn = coverage_vec,
    data = data,
    truth = !!enquo(truth),
    estimate = !!estimate,
    na_rm = na_rm,
    # event_level = event_level,
    case_weights = !!enquo(case_weights),
    # Extra argument passed on to coverage_vec()
    metric_fn_options = list(sum = sum)
  )
}

###' @rdname coverage
###' @export
coverage_vec <- function(truth,
                         estimate,
                         na_rm = TRUE,
                         sum = FALSE,
                         # event_level = yardstick_event_level(),
                         case_weights = NULL,
                         ...) {
  estimator <- finalize_estimator(truth, metric_class = "coverage")
  
  # estimate here is a matrix with the lower and upper interval bounds as columns
  coverage_impl <- function(truth,
                            estimate,
                            ...,
                            sum = FALSE,
                            case_weights = NULL) {
    check_dots_empty()
    
    coverage_estimator_impl(
      truth = truth,
      estimate = estimate,
      estimator = estimator,
      # event_level = event_level,
      sum = sum,
      case_weights = case_weights
    )
  }
  
  metric_vec_template(
    metric_impl = coverage_impl,
    truth = truth,
    estimate = estimate,
    na_rm = na_rm,
    estimator = estimator,
    case_weights = case_weights,
    cls = "numeric",
    sum = sum
  )
}

coverage_estimator_impl <- function(truth,
                                    estimate,
                                    estimator,
                                    # event_level,
                                    sum,
                                    case_weights) {
  
  # Share of observations whose truth lies within [lower, upper]
  mean(estimate[, 1] <= truth & truth <= estimate[, 2])
  # if (is_binary(estimator)) {
  #   mn_log_loss_binary(
  #     truth = truth,
  #     estimate = estimate,
  #     event_level = event_level,
  #     sum = sum,
  #     case_weights = case_weights
  #   )
  # }
  # else {
  #   mn_log_loss_multiclass(
  #     truth = truth,
  #     estimate = estimate,
  #     sum = sum,
  #     case_weights = case_weights
  #   )
  # }
}

Error message from above attempt

This results in the following error:

truth <- c(0.5:8.5, 11)
intervals <- matrix(c(0:9, 1:10), ncol = 2)

coverage_vec(truth, intervals)
#> Error in `validate_truth_estimate_types()`:
#>   ! `estimate` should be a numeric vector, not a numeric matrix.
#> Run `rlang::last_error()` to see where the error occurred.

Notes

I'm not sure whether this constitutes a feature request or whether I'm just doing something wrong. If the latter, a little more documentation on writing custom numeric metrics that take multiple inputs for `estimate` might be useful (e.g. pinball loss would make a good example). Feel free to tell me if this issue actually belongs on RStudio Community or Stack Overflow.
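To illustrate the kind of metric meant here, a base-R sketch of mean pinball (quantile) loss where `estimate` has one column of predicted quantiles per level. For quantile level τ and residual r = y − q the loss is max(τ r, (τ − 1) r), averaged over all observations and levels. `pinball_loss()` and its `quantile_levels` argument are hypothetical names, not part of yardstick's API.

```r
# Hypothetical sketch (not a yardstick function): mean pinball loss when
# `estimate` is a matrix with one column per quantile level.
pinball_loss <- function(truth, estimate, quantile_levels) {
  stopifnot(
    is.numeric(truth), is.matrix(estimate), is.numeric(estimate),
    nrow(estimate) == length(truth),
    ncol(estimate) == length(quantile_levels)
  )
  # Residuals: one row per observation, one column per quantile level
  resid <- truth - estimate
  # Each column's quantile level, repeated down the rows
  tau <- matrix(quantile_levels, nrow = nrow(estimate),
                ncol = ncol(estimate), byrow = TRUE)
  # tau * r when truth is above the quantile, (tau - 1) * r when below
  loss <- pmax(tau * resid, (tau - 1) * resid)
  mean(loss)
}

truth <- c(1, 2, 3)
quantiles <- cbind(q10 = c(0, 2, 4), q90 = c(2, 2, 2))
pinball_loss(truth, quantiles, quantile_levels = c(0.1, 0.9))
#> [1] 0.3333333
```

With a single 0.5 level this reduces to half the mean absolute error, which is a quick sanity check on the sign convention.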


brshallo, May 16, 2022, 19:05

Thanks for your patience on this! I spent some time looking through this, and you haven't done anything wrong.

My guess is that perhaps a matrix input is only allowed in the case of prob metrics (as have examples of matrix input e.g. for mn_log_loss), though not sure...

This is correct, and yep, this is how `validate_truth_estimate_types()` works right now. Matrices for `estimate` are currently allowed when `truth` is a factor, but not when `truth` is numeric.

I'm not totally sure why we have disallowed this entirely, but we currently check for it, as in this example:

library(yardstick)
rmse_vec(solubility_test$solubility, matrix(1:5))
#> Error in `validate_truth_estimate_types()`:
#> ! `estimate` should be a numeric vector, not a numeric matrix.

Created on 2022-05-23 by the reprex package (v2.0.1)

Let's wait for someone with deeper knowledge of what's going on in yardstick to chime in here with more context.

juliasilge, May 23, 2022, 22:05

Any updates on the thinking here?

This kind of works, except for the data type checking:

library(tidymodels)

coverage_impl <- function(truth, estimate, case_weights = NULL) {
  # TRUE where lower <= truth <= upper
  covs <- estimate[, 1] <= truth & truth <= estimate[, 2]
  
  yardstick:::yardstick_mean(covs, case_weights = case_weights)
}

coverage_vec <- function(truth, estimate, na_rm = TRUE, case_weights = NULL, ...) {
  check_numeric_metric(truth, estimate, case_weights)
  
  if (na_rm) {
    result <- yardstick_remove_missing(truth, estimate, case_weights)
    
    truth <- result$truth
    estimate <- result$estimate
    case_weights <- result$case_weights
  } else if (yardstick_any_missing(truth, estimate, case_weights)) {
    return(NA_real_)
  }
  
  coverage_impl(truth, estimate, case_weights = case_weights)
}

This returns the correct value and doesn't error (again, just without type checking):

truth <- c(0.5:8.5, 11)
intervals <- matrix(c(0:9, 1:10), ncol = 2)
coverage_vec(truth, intervals)
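One detail in the `na_rm` branch above: with a matrix `estimate`, missing values have to be dropped row-wise so that `truth`, every estimate column, and `case_weights` stay aligned. How `yardstick_remove_missing()` treats matrices is not verified here, so this is a base-R sketch using `complete.cases()`; `remove_missing_rows()` is a hypothetical helper name.

```r
# Hypothetical helper (base R, not yardstick's yardstick_remove_missing()):
# drop every observation where truth, any estimate column, or the case
# weight is NA, keeping the matrix rows aligned with `truth`.
remove_missing_rows <- function(truth, estimate, case_weights = NULL) {
  keep <- if (is.null(case_weights)) {
    complete.cases(truth, estimate)
  } else {
    complete.cases(truth, estimate, case_weights)
  }
  list(
    truth = truth[keep],
    # drop = FALSE keeps a one-row result as a matrix
    estimate = estimate[keep, , drop = FALSE],
    case_weights = if (is.null(case_weights)) NULL else case_weights[keep]
  )
}

truth <- c(1, NA, 3, 4)
intervals <- cbind(c(0, 1, NA, 3), c(2, 3, 4, 5))
str(remove_missing_rows(truth, intervals))
```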

Adding type checking would require a variant of `check_numeric_metric()` that expects `estimate` to be a numeric matrix, i.e. replacing `validate_numeric_truth_numeric_estimate()` with a new internal function such as `validate_numeric_truth_matrix_estimate()`. This would mirror the classification setting, where `check_class_metric()` and `check_prob_metric()` rely on different internal `validate_*()` functions.
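A rough sketch of what such a validator might check, written in base R with `stop()` rather than the `abort()` calls used in yardstick's source. `validate_numeric_truth_matrix_estimate()` is only the proposed name from the paragraph above; it does not exist in yardstick, and the exact checks and messages here are assumptions.

```r
# Hypothetical validator (not in yardstick): numeric `truth` vector plus a
# numeric `estimate` matrix with one row per observation.
validate_numeric_truth_matrix_estimate <- function(truth, estimate) {
  if (!is.numeric(truth) || is.matrix(truth)) {
    stop("`truth` should be a numeric vector.", call. = FALSE)
  }
  if (!is.numeric(estimate) || !is.matrix(estimate)) {
    stop("`estimate` should be a numeric matrix.", call. = FALSE)
  }
  # Rows of the matrix, not its total length, must match `truth`
  if (nrow(estimate) != length(truth)) {
    stop(
      sprintf(
        "`truth` (%d) and `estimate` (%d rows) must have the same length.",
        length(truth), nrow(estimate)
      ),
      call. = FALSE
    )
  }
  invisible(TRUE)
}

truth <- c(0.5:8.5, 11)
intervals <- matrix(c(0:9, 1:10), ncol = 2)
validate_numeric_truth_matrix_estimate(truth, intervals)   # passes silently
try(validate_numeric_truth_matrix_estimate(truth, truth))  # not a matrix
```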


Related changes would also be needed on the data frame method side, including allowing `...` in place of the `estimate` argument. I'll note, though, that just telling yardstick that it's a "prob_metric" does allow it to return an output:

library(rlang)

coverage <- function(data, ...) {
  UseMethod("coverage")
}

coverage <- new_prob_metric(coverage, direction = "maximize")

coverage.data.frame <- function(data, truth, ..., na_rm = TRUE, case_weights = NULL) {
  
  prob_metric_summarizer(
    name = "coverage",
    fn = coverage_vec,
    data = data,
    truth = !!enquo(truth),
    ...,
    na_rm = na_rm,
    case_weights = !!enquo(case_weights)
  )
}

That is, this also doesn't error:

truth <- c(0.5:8.5, 11)
intervals <- matrix(c(0:9, 1:10), ncol = 2)

tibble(truth = truth, lower = intervals[, 1], upper = intervals[, 2]) %>% 
  coverage(truth, lower, upper)
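For the data frame side, the core step is turning the columns named in `...` into the matrix passed as `estimate`. yardstick does this with tidy evaluation (e.g. `dots_to_estimate()` in the first attempt above); the base-R sketch below uses `substitute()` instead. `interval_metric_df()`, `metric_name`, and `metric_vec` are hypothetical names, and the output only mimics yardstick's `.metric` / `.estimator` / `.estimate` columns.

```r
# Hypothetical sketch (base R NSE, not yardstick's machinery): collect the
# bare column names passed through `...` into a numeric matrix and hand it
# to a vector-level metric function as `estimate`.
interval_metric_df <- function(data, truth, ..., metric_name, metric_vec) {
  # Capture bare column names: `truth` and every column passed via `...`
  truth_col <- deparse(substitute(truth))
  estimate_cols <- vapply(
    as.list(substitute(list(...)))[-1], deparse, character(1)
  )
  # Bind the estimate columns, in the order given, into one matrix
  estimate <- as.matrix(data[estimate_cols])
  data.frame(
    .metric = metric_name,
    .estimator = "standard",
    .estimate = metric_vec(data[[truth_col]], estimate)
  )
}

df <- data.frame(truth = c(0.5:8.5, 11), lower = 0:9, upper = 1:10)
res <- interval_metric_df(
  df, truth, lower, upper,
  metric_name = "coverage",
  metric_vec = function(truth, estimate) {
    mean(estimate[, 1] <= truth & truth <= estimate[, 2])
  }
)
res$.estimate  # 0.9 for this data
```

Because `metric_name` and `metric_vec` come after `...`, they must be passed by name, which keeps them out of the captured column list.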

brshallo, Feb 14, 2024, 07:02