recipes icon indicating copy to clipboard operation
recipes copied to clipboard

step_quantile

Open ttrodrigz opened this issue 1 year ago • 2 comments

Feature

This idea came to mind after posting here on the Posit Community page, which was turned into PR #1075.

I believe it would be useful to have a step_quantile() recipe step. I realize that step_discretize() already exists, but it would be nice to have some additional control via the optional arguments available in stats::quantile() such as the type argument.

This function would work very similarly to step_percentile(), the main difference being that after the breakpoints are created with quantile(), the data would then be passed along to cut() rather than creating the percentiles with approx(), and the result would be an integer value.

I have a working example below which even incorporates @EmilHvitfeldt's implementation of the outside argument.

A few notes on the default settings chosen for the function:

  • In cut(), include.lowest is set to FALSE by default. I believe this is sub-optimal for this recipe step because this setting would result in the minimum value in the data receiving a value of NA. This is why in the code below I set include.lowest = TRUE by default.
  • Set labels = FALSE to make sure an integer is returned; this overrides any user input.
  • Set default na.rm = TRUE for quantile_options

Thanks for taking the time to read this!

Reprex

# Packages ----------------------------------------------------------------

library(tidyverse)
library(tidymodels)


# User-facing function ----------------------------------------------------

# Add a quantile-binning step to a recipe.
#
# Arguments:
#   recipe            Recipe the step is appended to (via `add_step()`).
#   ...               Selector expressions, captured with `enquos()` and
#                     stored as the step's `terms`.
#   role, trained, ref_dist, skip, id
#                     Stored on the step object unchanged.
#   quantile_options  Named list of options stored for the quantile
#                     computation done at prep time.
#   cut_options       Named list of options stored for the binning done at
#                     bake time.
#   outside           One of "none", "both", "upper", "lower"; validated
#                     here so a bad value fails before the step is added.
#
# Returns the result of `add_step()` for the recipe and the new step.
step_quantile <- function(
        recipe,
        ...,
        role = NA,
        trained = FALSE,
        ref_dist = NULL,
        quantile_options = list(probs = (0:5) / 5, na.rm = TRUE),
        cut_options = list(labels = FALSE, include.lowest = TRUE, right = TRUE),
        outside = "none",
        skip = FALSE,
        id = rand_id("quantile")
) {

    # Reject anything that is not one of the four supported policies.
    allowed_outside <- c("none", "both", "upper", "lower")
    outside <- rlang::arg_match(outside, values = allowed_outside)

    # Case weights are only resolved at prep time, hence NULL here.
    new_step <- step_quantile_new(
        terms = enquos(...),
        trained = trained,
        role = role,
        ref_dist = ref_dist,
        quantile_options = quantile_options,
        cut_options = cut_options,
        outside = outside,
        skip = skip,
        id = id,
        case_weights = NULL
    )

    add_step(recipe, new_step)
}


# Initialize new recipe step ----------------------------------------------

# Internal constructor for the step object. Every argument is forwarded
# unchanged to `step()` together with `subclass = "quantile"`. It is called by
# `step_quantile()` (untrained step, `case_weights = NULL`) and by
# `prep.step_quantile()` (trained step carrying the estimated `ref_dist`).
step_quantile_new <- function(
        terms, role, trained, ref_dist, quantile_options, cut_options, outside, skip, id, case_weights
) {
    step(
        subclass = "quantile",
        terms = terms,
        role = role,
        trained = trained,
        ref_dist = ref_dist,
        quantile_options = quantile_options,
        cut_options = cut_options,
        outside = outside,
        skip = skip,
        id = id,
        case_weights = case_weights
    )
}


# Function to calculate quantile ------------------------------------------

# Note these were stolen directly from step_percentile()

# Compute the breakpoints for one column.
#
#   x              Values of the column.
#   wts            Case weights, or NULL for the unweighted computation.
#   quantile_args  Named list of extra arguments spliced into the call.
#
# Without weights the work is done by `quantile()`; with weights it is done
# by `weighted_quantile()`, which receives the weights coerced to double.
# Repeated breakpoint values are dropped (first occurrence kept, names kept).
get_train_pctl <- function(x, wts, quantile_args = NULL) {

    unweighted <- is.null(wts)

    # Pick the worker and the weight argument (if any) up front so a single
    # call site handles both cases.
    worker <- if (unweighted) "quantile" else "weighted_quantile"
    weight_arg <- if (unweighted) list() else list(wts = as.double(wts))

    breaks <- rlang::exec(worker, x = x, !!!weight_arg, !!!quantile_args)

    breaks[!duplicated(breaks)]
}

# Weighted quantiles of `x`.
#
#   x      Numeric vector.
#   wts    Weights, one per element of `x`.
#   probs  Probabilities at which to evaluate the quantiles.
#   na.rm  If TRUE, missing values of `x` are dropped together with their
#          weights before anything is computed. With the default FALSE,
#          missing values stay in (sorted last) and still carry weight, so
#          upper quantiles can come back NA.
#   ...    Accepted and ignored, so a shared options list can be passed.
#
# For each probability the result is the smallest observed value whose
# cumulative normalised weight reaches that probability. Returns a double
# vector named "<100 * prob>%".
weighted_quantile <- function(x, wts, probs, na.rm = FALSE, ...) {
    if (isTRUE(na.rm)) {
        keep <- !is.na(x)
        x <- x[keep]
        wts <- wts[keep]
    }

    order_x <- order(x)
    x <- x[order_x]
    wts <- wts[order_x]

    wts_norm <- cumsum(wts) / sum(wts)

    res <- vapply(
        probs,
        function(p) {
            # Nothing to report for a missing probability or empty data.
            if (is.na(p) || length(x) == 0L) {
                return(NA_real_)
            }
            hits <- which(wts_norm >= p)
            if (length(hits) > 0L) {
                as.double(x[hits[1L]])
            } else {
                # No cumulative weight reached `p` (e.g. rounding left the
                # last cumulative weight just below 1): use the largest value.
                as.double(x[length(x)])
            }
        },
        numeric(1)
    )

    names(res) <- paste0(probs * 100, "%")
    res
}


# Prep method -------------------------------------------------------------

# Prep method: estimate the quantile breakpoints for every selected column.
#
#   x         The untrained step.
#   training  Training data.
#   info      Variable information used to resolve the selectors and to
#             look up case weights.
#   ...       Unused.
#
# Returns a new step with `trained = TRUE` whose `ref_dist` is a list holding
# one breakpoint vector per selected column. `case_weights` records whether
# case weights were used.
prep.step_quantile <- function(x, training, info = NULL, ...) {

    col_names <- recipes_eval_select(x$terms, training, info)
    check_type(training[, col_names], quant = TRUE)

    wts <- get_case_weights(info, training)
    were_weights_used <- are_weights_used(wts, unsupervised = TRUE)
    if (isFALSE(were_weights_used)) {
        wts <- NULL
    }

    # --- quantile() options -------------------------------------------------
    quantile_options <- x$quantile_options

    # Keep the percent names on the breakpoint vectors.
    quantile_options$names <- TRUE

    # Exact name checks: `$` on a list partially matches names, so it is not
    # used to test for presence.
    if ("probs" %in% names(quantile_options)) {
        quantile_options$probs <- sort(unique(quantile_options$probs))
    } else {
        quantile_options$probs <- (0:5) / 5
    }

    if (!"na.rm" %in% names(quantile_options)) {
        quantile_options$na.rm <- TRUE
    }

    # --- cut() options ------------------------------------------------------
    cut_options <- x$cut_options

    # Overriding user input: always return integer bin codes rather than an
    # ordered factor.
    cut_options$labels <- FALSE

    if (!"include.lowest" %in% names(cut_options)) {
        cut_options$include.lowest <- TRUE
    }

    if (!"right" %in% names(cut_options)) {
        cut_options$right <- TRUE
    }

    # --- breakpoints --------------------------------------------------------
    ref_dist <- purrr::map(
        training[, col_names],
        get_train_pctl,
        wts = wts,
        quantile_args = quantile_options
    )

    step_quantile_new(
        terms = x$terms,
        trained = TRUE,
        role = x$role,
        ref_dist = ref_dist,
        quantile_options = quantile_options,
        cut_options = cut_options,
        outside = x$outside,
        skip = x$skip,
        id = x$id,
        case_weights = were_weights_used
    )
}


# Custom cut function -----------------------------------------------------

# This executes `cut()` and controls what happens to values outside
# of the range seen by the training data

# Bin `x` with `cut()` using the trained breakpoints and decide what happens
# to values outside the range of those breakpoints.
#
#   x         Numeric vector to bin.
#   ref       Breakpoints estimated on the training data.
#   cut_args  Named list of further arguments for `cut()`. `labels` is forced
#             to FALSE so the result is always an integer vector.
#   outside   "none":  out-of-range values keep whatever `cut()` gave (NA).
#             "lower": values below the smallest breakpoint get bin 1.
#             "upper": values above the largest breakpoint get the label
#                      `length(breaks)`.
#             "both":  both of the above.
#
# Returns an integer vector the same length as `x`; missing inputs stay NA.
cut_custom <- function(x, ref, cut_args, outside) {

    breaks <- ref[!is.na(ref)]

    # Integer codes keep the column type identical whether or not any value
    # had to be patched below.
    cut_args$labels <- FALSE

    if (length(breaks) < 2L) {
        # A single number would be read by cut() as "number of intervals",
        # not as a breakpoint, so no interval can be formed here.
        res <- rep(NA_integer_, length(x))
    } else {
        res <- do.call(cut, c(list(x = x, breaks = breaks), cut_args))
    }

    if (length(breaks) == 0L) {
        return(res)
    }

    lower_bound <- min(breaks)
    upper_bound <- max(breaks)

    # which() drops NA comparisons, so missing inputs are never overwritten
    # and all-NA input needs no special casing.
    if (outside %in% c("both", "lower")) {
        res[which(x < lower_bound)] <- 1L
    }

    if (outside %in% c("both", "upper")) {
        # NOTE(review): k breakpoints give k - 1 bins, so this label is one
        # past the top bin, whereas the lower side is folded into bin 1.
        # Kept as-is; confirm whether `length(breaks) - 1L` was intended.
        res[which(x > upper_bound)] <- length(breaks)
    }

    res
}

# Bake method -------------------------------------------------------------

# Bake method: replace each trained column of `new_data` with its bin codes.
#
#   object    The trained step; `ref_dist` holds one breakpoint vector per
#             column, and `cut_options` / `outside` are handed to
#             `cut_custom()`.
#   new_data  Data to transform.
#   ...       Unused.
#
# Returns `new_data` with the trained columns overwritten.
bake.step_quantile <- function(object, new_data, ...) {

    vars <- names(object$ref_dist)
    check_new_data(vars, object, new_data)

    # One column at a time, matched by name. Nothing happens when no column
    # was selected.
    for (var in vars) {
        new_data[[var]] <- cut_custom(
            x = new_data[[var]],
            ref = object$ref_dist[[var]],
            cut_args = object$cut_options,
            outside = object$outside
        )
    }

    new_data
}


# Print method ------------------------------------------------------------

# Print method: hands the trained column names, the original selectors, the
# trained flag, the title, the width and the case-weight flag to
# `print_step()`. Returns the step invisibly so it can be piped.
print.step_quantile <- function(x, width = max(20, options()$width - 35), ...) {
    print_step(
        names(x$ref_dist),
        x$terms,
        x$trained,
        "Quantile transformation on ",
        width,
        case_weights = x$case_weights
    )
    invisible(x)
}




# Test --------------------------------------------------------------------

# Training data: integer predictor `x` drawn from 5:10 and a numeric
# outcome `y`; the seed makes the draw reproducible.
set.seed(111)
data_train <- tibble(
    x = sample(5:10, size = 15, replace = TRUE),
    y = rnorm(15)
)

# Inject one missing value to exercise NA handling.
data_train$x[[4]] <- NA

# New data: a missing value plus values below (4) and above (11) the 5:10
# range the training predictor was drawn from.
data_test <- tibble(
    x = c(NA, 4:11)
)

# Quartile breakpoints, with out-of-range values handled on both sides.
rec <-
    data_train %>%
    recipe(y ~ x) %>%
    step_quantile(
        all_predictors(),
        quantile_options = list(
            probs = (0:4)/4
        ),
        # NOTE(review): `type` is an argument of stats::quantile(), not of
        # cut(); placed here it is forwarded to cut() instead. Presumably it
        # was meant to go in `quantile_options` -- confirm.
        cut_options = list(
            type = 2
        ),
        outside = "both"
    )

rec.trained <- prep(rec)

# Training data: bins 1-4, with the injected NA preserved.
bake(rec.trained, data_train)
#> # A tibble: 15 x 2
#>        x       y
#>    <int>   <dbl>
#>  1     4  1.83  
#>  2     2  0.291 
#>  3     3 -0.566 
#>  4    NA -0.288 
#>  5     1 -0.462 
#>  6     2 -0.573 
#>  7     4  0.243 
#>  8     2  0.0248
#>  9     3 -0.0769
#> 10     1  0.563 
#> 11     1  0.786 
#> 12     4  0.860 
#> 13     4 -0.444 
#> 14     1  1.46  
#> 15     3 -0.341
# New data: 4 (below the range) lands in bin 1 and 11 (above it) gets 5.
bake(rec.trained, data_test)
#> # A tibble: 9 x 1
#>       x
#>   <dbl>
#> 1    NA
#> 2     1
#> 3     1
#> 4     1
#> 5     2
#> 6     3
#> 7     4
#> 8     4
#> 9     5

Created on 2023-01-06 by the reprex package (v2.0.1)

ttrodrigz avatar Jan 06 '23 19:01 ttrodrigz