
catboost method to embed categorical variables

Open talegari opened this issue 3 years ago • 11 comments

Hi Emil, I am planning to implement a step_catboost (along these lines). IMHO, it should belong here.

Let me know if you are open to a PR.

talegari, Jun 22 '22 13:06

Unfortunately, catboost (the R package) is not on CRAN 😔, which is a blocker for us being able to implement catboost methods in our packages. You can see related discussion in catboost/catboost#439.

juliasilge, Jun 22 '22 14:06

Hey Julia, step_catboost would not depend on the catboost package. The step involves permutations and target encoding. Here is a Python implementation of the same method.
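For readers unfamiliar with the method, here is a minimal base-R sketch of the ordered target statistic such a step computes for one categorical column, assuming the usual CatBoost formulation: rows are visited in a random permutation, and each row is encoded from the outcomes of the *earlier* rows with the same level, smoothed toward a prior with weight `a`. The function `ordered_target_stat()` and its arguments are hypothetical (not part of {embed}), and this is not a port of the linked Python code.

```r
# Hypothetical helper (not an {embed} function): CatBoost-style ordered
# target statistic for a single categorical vector `x` and numeric outcome `y`.
ordered_target_stat <- function(x, y, a = 1, prior = mean(y),
                                perm = sample.int(length(y))) {
  stopifnot(length(x) == length(y), is.numeric(y), !anyNA(y), a > 0)
  xs <- as.character(x)[perm]           # levels in permuted row order
  ys <- y[perm]                         # outcomes in permuted row order
  xs[is.na(xs)] <- "<NA>"               # treat missing values as their own level
  # Sum and count of *earlier* rows with the same level (current row excluded)
  prev_sum <- ave(ys, xs, FUN = cumsum) - ys
  prev_n   <- ave(ys, xs, FUN = seq_along) - 1
  enc <- (prev_sum + a * prior) / (prev_n + a)
  out <- numeric(length(y))
  out[perm] <- enc                      # map back to the original row order
  out
}

# perm = seq_along(y) keeps the data order, which is effectively what the
# R6 encode_with_y() method later in this thread does.
ordered_target_stat(c("a", "a", "b", "a"), c(1, 2, 3, 4), perm = 1:4)
# returns c(2.5, 1.75, 2.5, 5.5 / 3)
```

In normal use `perm` is left at its random default (call `set.seed()` first for reproducibility), so no row's own outcome ever leaks into its encoding.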

talegari, Jun 22 '22 15:06

Hey @talegari 👋

That sounds great! Feel free to open an issue, and ping me if you need any help or assistance!

EmilHvitfeldt, Jun 22 '22 17:06

Hello @talegari 👋 Are you still interested in opening a PR for this step? If not, then I will do it.

EmilHvitfeldt, Mar 15 '23 22:03

Hey @EmilHvitfeldt, it just fell off my radar. I will submit a PR. I am planning along these lines. Let me know if you have a different suggestion.

talegari, Mar 16 '23 06:03

Amazing! That looks like a great place to start! Do you know when you will have time to work on this? No rush!

EmilHvitfeldt, Mar 16 '23 16:03

By 24th Mar.

On Thu, Mar 16, 2023 at 09:34 PM, Emil Hvitfeldt < @.***> wrote:

Amazing! That looks like a great place to start! Do you know when you will have time to work on this? No rush!


talegari, Mar 17 '23 06:03

Hey @EmilHvitfeldt, something unforeseen stopped me from working on this. Just letting you know that I am on it and will raise a PR shortly.

talegari, Mar 26 '23 18:03

No problem! It might not make it into the next {embed} release, but that is fine; we can send it in later.

EmilHvitfeldt, Mar 27 '23 18:03

@EmilHvitfeldt, I am one step away from raising a PR, but I need your help resolving a small issue. Here is the context:

I have implemented the catboost encoder as an R6 class here:

Category encoder R6 class
# catboost encoder core logic
pacman::p_load("tidyverse")

#' catboost_encoder R6 class
#'
#' An R6 class to encode categorical variables with the CatBoost method.
#'
#' @name catboost_encoder
#' @docType class
#' @importFrom R6 R6Class
#'
#' @field dataset The dataset to fit the encoder
#' @field mean The mean of the response variable in the dataset
#' @field varnames_to_encode The names of the categorical variables to encode
#' @field response_varname The name of the response variable in the dataset
#' @field is_fitted A flag indicating whether the encoder has been fitted
#' @field a A hyperparameter to control the strength of the encoding
#' @field encode_novel_levels Whether levels unseen during fitting are encoded
#'   with the overall mean of the response
#' @field encode_missing_levels Whether missing values in new data are
#'   explicitly set to NA
#'
#' @section Public methods: \describe{
#'   \item{\code{initialize(dataset)}}{Constructor method for the
#'   catboost_encoder class} \item{\code{fit(varnames_to_encode,
#'   response_varname, a = 1, encode_novel_levels = TRUE,
#'   encode_missing_levels = FALSE)}}{Fit the encoder to the data}
#'   \item{\code{transform(new_data = NULL)}}{Transform a new dataset using the
#'   fitted encoder} }
#'
#' @section Private methods: \describe{ \item{\code{encode_with_y(df,
#'   varname_to_encode)}}{Encode a categorical variable using
#'   the response variable} \item{\code{encode_without_y(df,
#'   varname_to_encode)}}{Encode a categorical variable without using the
#'   response variable} }
#'
#' @section Usage:
#'
#'   encoder <- catboost_encoder$new(dataset)
#'   encoder$fit(varnames_to_encode, response_varname)
#'   encoded_data <- encoder$transform(new_data)
#'
#' @export catboost_encoder
catboost_encoder = R6::R6Class(
  "catboost_encoder",
  public = list(
    
    dataset               = NULL,
    mean                  = NULL,
    varnames_to_encode    = NULL,
    response_varname      = NULL,
    is_fitted             = FALSE,
    a                     = NULL,
    encode_novel_levels   = NULL,
    encode_missing_levels = NULL,
    
    initialize = function(dataset){
      checkmate::assert_data_frame(dataset)
      self$dataset = dataset
      return(invisible(NULL))
    },
    
    fit = function(varnames_to_encode,
                   response_varname,
                   a = 1,
                   encode_novel_levels = TRUE,
                   encode_missing_levels = FALSE
                   ){
      
      checkmate::assert_string(response_varname)
      checkmate::assert_subset(response_varname,
                               choices = colnames(self$dataset)
                               )
      checkmate::assert_numeric(self$dataset[[response_varname]],
                                any.missing = FALSE
                                )
      checkmate::assert_character(varnames_to_encode)
      checkmate::assert_subset(varnames_to_encode,
                               choices = colnames(self$dataset)
                               )
      for (avarname in varnames_to_encode){
        checkmate::assert_factor(self$dataset[[avarname]])
      }
      
      checkmate::assert_number(a)
      checkmate::assert_flag(encode_novel_levels)
      checkmate::assert_flag(encode_missing_levels)
      
      self$varnames_to_encode = varnames_to_encode
      self$response_varname = response_varname
      self$mean = mean(self$dataset[[response_varname]], na.rm = TRUE)
      self$a = a
      self$encode_novel_levels = encode_novel_levels
      self$encode_missing_levels = encode_missing_levels
      
      self$is_fitted = TRUE
      return(invisible(NULL))
    },
    
    transform = function(new_data = NULL){
      new_data_is_null = TRUE
      if (!is.null(new_data)){
        checkmate::assert_data_frame(new_data)
        checkmate::assert_false(self$response_varname %in% colnames(new_data))
        names_sorted = sort(colnames(new_data))
        checkmate::assert_set_equal(colnames(new_data),
                                    setdiff(colnames(self$dataset),
                                            self$response_varname
                                            )
                                    )
        checkmate::assert_set_equal(
          sapply(new_data, class)[names_sorted],
          sapply(dplyr::select(self$dataset, -dplyr::all_of(self$response_varname)),
                 class
                 )[names_sorted]
          )
        new_data_is_null = FALSE
      }
      
      if (!self$is_fitted){
        stop("please 'fit' before 'transform'")
      }
      
      if (new_data_is_null){
        message("transforming on the dataset")
        new_data = self$dataset
      }
      
      if (new_data_is_null){
        encoded_cols = map(self$varnames_to_encode,
                           ~ private$encode_with_y(new_data, .x)
                           )
        
      } else {
        encoded_cols = map(self$varnames_to_encode,
                           ~ private$encode_without_y(new_data,.x)
                           )
      }
      
      names(encoded_cols) = self$varnames_to_encode
      
      res = as_tibble(encoded_cols) %>% 
        bind_cols(select(new_data, -all_of(self$varnames_to_encode))) %>% 
        relocate(all_of(colnames(new_data)))
      
      # encode novel (in new data case only)
      if (self$encode_novel_levels && !new_data_is_null){
        for (avarname in self$varnames_to_encode){
          new_levels = setdiff(levels(new_data[[avarname]]),
                               levels(self$dataset[[avarname]])
                               )
          if (length(new_levels) > 0){
            res[[avarname]] = ifelse(new_data[[avarname]] %in% new_levels,
                                     self$mean,
                                     res[[avarname]]
                                     )
          }
        }
      }
            
      # encode missing (in new data case only)
      if (self$encode_missing_levels && !new_data_is_null){
        for (avarname in self$varnames_to_encode){
          res[[avarname]][ is.na(new_data[[avarname]]) ] = NA
        }
      }
      
      return(res)
    }
  ),
  private = list(
    
    encode_with_y = function(df, varname_to_encode){
      
      # new levels: not applicable
      # NA: encoded
      
      res = df %>% 
        select(all_of(c(varname_to_encode, self$response_varname))) %>% 
        group_by(.data[[varname_to_encode]]) %>% 
        mutate(cs__ = cumsum(.data[[self$response_varname]]),
               cc__ = row_number() - 1L
               ) %>% 
        ungroup() %>% 
        transmute(encoded__ = (cs__ -
                      .data[[self$response_varname]] +
                      mean(.data[[self$response_varname]], na.rm = TRUE) *
                      self$a
                      ) / (cc__ + self$a)
               ) %>% 
        pull(encoded__)
      
      return(res)
    },
    
    encode_without_y = function(df, varname_to_encode){
      
      # new levels: NA
      # NA: NA
      
      level_means = "level_means__"
      
      agg_frame = self$dataset %>% 
        select(all_of(c(varname_to_encode, self$response_varname))) %>% 
        group_by(.data[[varname_to_encode]]) %>% 
        summarise(sum__ = sum(.data[[self$response_varname]], na.rm = TRUE),
                  count__ = n()
                  ) %>% 
        ungroup() %>% 
        mutate(level_means__ = 
                 ifelse(count__ == 1,
                        self$mean,
                        (sum__ + self$mean * self$a) / (count__ + self$a)
                        )
               ) %>% 
        drop_na(all_of(varname_to_encode)) %>% 
        select(all_of(c(varname_to_encode, level_means)))
      
      res = df %>% 
        select(all_of(c(varname_to_encode))) %>%
        left_join(agg_frame, by = varname_to_encode) %>% 
        pull(level_means)
      
      return(res)
    }
    
  )
)
recipe wrapper as 'step_catboost'
step_catboost = function(recipe,
                         ...,
                         role = NA,
                         trained = FALSE,
                         outcome = NULL,
                         mapping = NULL,
                         skip = FALSE,
                         id = rand_id("catboost")
                         ){
    if (is.null(outcome)) {
      rlang::abort("Please list a variable in `outcome`")
    }
    recipes::add_step(
      recipe,
      step_catboost_new(
        terms = enquos(...),
        role = role,
        trained = trained,
        outcome = outcome,
        mapping = mapping,
        skip = skip,
        id = id
      )
    )
  }

step_catboost_new = 
  function(terms,
           role,
           trained,
           outcome,
           mapping,
           skip,
           id
           ){
    step(
      subclass = "catboost",
      terms = terms,
      role = role,
      trained = trained,
      outcome = outcome,
      mapping = mapping,
      skip = skip,
      id = id
      )
  }

#' @export
prep.step_catboost = function(x,
                              training,
                              info = NULL,
                              ...
                              ){
  col_names = recipes_eval_select(x$terms, training, info)

  if (length(col_names) > 0) {
    y_name = recipes_eval_select(x$outcome, training, info)
    
    # instantiate R6 class obj
    ce = catboost_encoder$new(training)
    ce$fit(varnames_to_encode = col_names,
           response_varname = y_name
           )
  } else {
    ce = list()
  }
  step_catboost_new(
    terms = x$terms,
    role = x$role,
    trained = TRUE,
    outcome = x$outcome,
    mapping = ce,
    skip = x$skip,
    id = x$id
  )
}

#' @export
bake.step_catboost = function(object, new_data, ...) {
  
  if (!is.null(new_data)){
    y_name = purrr::map_chr(object$outcome, rlang::as_name) # string
    ce = object$mapping
    if (y_name %in% colnames(new_data)){
      new_data[[y_name]] = NULL
    }
    res = ce$transform(new_data)
  } else {
    res = ce$transform()
  }
  
  return(res)
}

#' @rdname required_pkgs.embed
#' @export
required_pkgs.step_catboost = function(x, ...) {
  c("embed")
}
Example
pacman::p_load("recipes", "tidyverse")
source("~/personal/catboost_encoding_r6.R")
#> transforming on the dataset
#> transforming on the dataset
source("~/personal/step_catboost.R")

pen1 = palmerpenguins::penguins %>% 
  drop_na(bill_length_mm) %>% 
  slice_sample(prop = 0.7, by = 'species')

pen2 = palmerpenguins::penguins %>% 
  drop_na(bill_length_mm) %>% 
  setdiff(pen1)

# example with R6 class
ce = catboost_encoder$new(pen1)
ce$fit(c('species', 'sex'), response_varname = 'bill_length_mm')

# when the input to transform is empty, it uses the training dataset
# (here it is pen1)
ce$transform()
#> transforming on the dataset
#> # A tibble: 238 × 8
#>    species island    bill_length_mm bill_depth_mm flipper_…¹ body_…²   sex  year
#>      <dbl> <fct>              <dbl>         <dbl>      <int>   <int> <dbl> <int>
#>  1    43.8 Torgersen           39.6          17.2        196    3550  43.8  2008
#>  2    41.7 Dream               37.5          18.9        179    2975  43.8  2007
#>  3    40.3 Biscoe              35.5          16.2        195    3350  41.7  2008
#>  4    39.1 Torgersen           40.6          19          199    4000  43.8  2009
#>  5    39.4 Biscoe              40.1          18.9        188    4300  42.2  2008
#>  6    39.5 Dream               39.6          18.8        190    4600  41.5  2007
#>  7    39.5 Dream               32.1          15.5        188    3050  39.6  2009
#>  8    38.6 Dream               39.8          19.1        184    4650  41.0  2007
#>  9    38.7 Torgersen           34.1          18.1        193    3475  40.6  2007
#> 10    38.3 Dream               37            16.9        185    3000  37.7  2007
#> # … with 228 more rows, and abbreviated variable names ¹flipper_length_mm,
#> #   ²body_mass_g

# transform on a new dataset
ce$transform(pen2 %>% select(-bill_length_mm))
#> # A tibble: 104 × 7
#>    species island    bill_depth_mm flipper_length_mm body_mass_g   sex  year
#>      <dbl> <fct>             <dbl>             <int>       <int> <dbl> <int>
#>  1    38.7 Torgersen          18                 195        3250  42.2  2007
#>  2    38.7 Torgersen          20.6               190        3650  45.6  2007
#>  3    38.7 Torgersen          17.8               181        3625  42.2  2007
#>  4    38.7 Torgersen          19.6               195        4675  45.6  2007
#>  5    38.7 Torgersen          21.2               191        3800  45.6  2007
#>  6    38.7 Torgersen          17.8               185        3700  42.2  2007
#>  7    38.7 Torgersen          20.7               197        4500  45.6  2007
#>  8    38.7 Torgersen          21.5               194        4200  45.6  2007
#>  9    38.7 Biscoe             18.6               172        3150  42.2  2007
#> 10    38.7 Dream              16.7               178        3250  42.2  2007
#> # … with 94 more rows

# example with step_catboost recipe
ar = recipe(bill_length_mm ~ ., data = pen1) %>% 
  step_catboost(species, outcome = "bill_length_mm") %>% 
  prep(training = pen1)

ar
#> Recipe
#> 
#> Inputs:
#> 
#>       role #variables
#>    outcome          1
#>  predictor          7
#> 
#> Training data contained 238 data points and 9 incomplete rows. 
#> 
#> Operations:
#> 
#> $terms
#> <list_of<quosure>>
#> 
#> [[1]]
#> <quosure>
#> expr: ^species
#> env:  0x7fbbb5a65120
#> 
#> 
#> $role
#> [1] NA
#> 
#> $trained
#> [1] TRUE
#> 
#> $outcome
#> [1] "bill_length_mm"
#> 
#> $mapping
#> <catboost_encoder>
#>   Public:
#>     a: 1
#>     clone: function (deep = FALSE) 
#>     dataset: tbl_df, tbl, data.frame
#>     encode_missing_levels: FALSE
#>     encode_novel_levels: TRUE
#>     fit: function (varnames_to_encode, response_varname, a = 1, encode_novel_levels = TRUE, 
#>     initialize: function (dataset) 
#>     is_fitted: TRUE
#>     mean: 43.7655462184874
#>     response_varname: bill_length_mm
#>     transform: function (new_data = NULL) 
#>     varnames_to_encode: species
#>   Private:
#>     encode_with_y: function (df, varname_to_encode) 
#>     encode_without_y: function (df, varname_to_encode) 
#> 
#> $skip
#> [1] FALSE
#> 
#> $id
#> [1] "catboost_LGVzz"
#> 
#> attr(,"class")
#> [1] "step_catboost" "step"

ar %>% 
  juice()
#> # A tibble: 238 × 7
#>    species island    bill_depth_mm flipper_length_mm body_mass_g sex     year
#>      <dbl> <fct>             <dbl>             <int>       <int> <fct>  <int>
#>  1    38.7 Torgersen          17.2               196        3550 female  2008
#>  2    38.7 Dream              18.9               179        2975 <NA>    2007
#>  3    38.7 Biscoe             16.2               195        3350 female  2008
#>  4    38.7 Torgersen          19                 199        4000 male    2009
#>  5    38.7 Biscoe             18.9               188        4300 male    2008
#>  6    38.7 Dream              18.8               190        4600 male    2007
#>  7    38.7 Dream              15.5               188        3050 female  2009
#>  8    38.7 Dream              19.1               184        4650 male    2007
#>  9    38.7 Torgersen          18.1               193        3475 <NA>    2007
#> 10    38.7 Dream              16.9               185        3000 female  2007
#> # … with 228 more rows

ar %>% 
  bake(new_data = NULL)
#> # A tibble: 238 × 7
#>    species island    bill_depth_mm flipper_length_mm body_mass_g sex     year
#>      <dbl> <fct>             <dbl>             <int>       <int> <fct>  <int>
#>  1    38.7 Torgersen          17.2               196        3550 female  2008
#>  2    38.7 Dream              18.9               179        2975 <NA>    2007
#>  3    38.7 Biscoe             16.2               195        3350 female  2008
#>  4    38.7 Torgersen          19                 199        4000 male    2009
#>  5    38.7 Biscoe             18.9               188        4300 male    2008
#>  6    38.7 Dream              18.8               190        4600 male    2007
#>  7    38.7 Dream              15.5               188        3050 female  2009
#>  8    38.7 Dream              19.1               184        4650 male    2007
#>  9    38.7 Torgersen          18.1               193        3475 <NA>    2007
#> 10    38.7 Dream              16.9               185        3000 female  2007
#> # … with 228 more rows

ar %>% 
  bake(new_data = pen1)
#> # A tibble: 238 × 7
#>    species island    bill_depth_mm flipper_length_mm body_mass_g sex     year
#>      <dbl> <fct>             <dbl>             <int>       <int> <fct>  <int>
#>  1    38.7 Torgersen          17.2               196        3550 female  2008
#>  2    38.7 Dream              18.9               179        2975 <NA>    2007
#>  3    38.7 Biscoe             16.2               195        3350 female  2008
#>  4    38.7 Torgersen          19                 199        4000 male    2009
#>  5    38.7 Biscoe             18.9               188        4300 male    2008
#>  6    38.7 Dream              18.8               190        4600 male    2007
#>  7    38.7 Dream              15.5               188        3050 female  2009
#>  8    38.7 Dream              19.1               184        4650 male    2007
#>  9    38.7 Torgersen          18.1               193        3475 <NA>    2007
#> 10    38.7 Dream              16.9               185        3000 female  2007
#> # … with 228 more rows

ar %>% 
  bake(new_data = pen2)
#> # A tibble: 104 × 7
#>    species island    bill_depth_mm flipper_length_mm body_mass_g sex     year
#>      <dbl> <fct>             <dbl>             <int>       <int> <fct>  <int>
#>  1    38.7 Torgersen          18                 195        3250 female  2007
#>  2    38.7 Torgersen          20.6               190        3650 male    2007
#>  3    38.7 Torgersen          17.8               181        3625 female  2007
#>  4    38.7 Torgersen          19.6               195        4675 male    2007
#>  5    38.7 Torgersen          21.2               191        3800 male    2007
#>  6    38.7 Torgersen          17.8               185        3700 female  2007
#>  7    38.7 Torgersen          20.7               197        4500 male    2007
#>  8    38.7 Torgersen          21.5               194        4200 male    2007
#>  9    38.7 Biscoe             18.6               172        3150 female  2007
#> 10    38.7 Dream              16.7               178        3250 female  2007
#> # … with 94 more rows

Issue: ce$transform() and ar %>% bake(new_data = NULL) give different results. How do I resolve this?
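For reference, reading the two private methods of the R6 class above, they compute different quantities. With $p$ the training mean of the outcome, $a$ the smoothing weight, $x_i, y_i$ the level and outcome of row $i$ in data order, and $n_\ell$ the number of training rows with level $\ell$, `encode_with_y()` (used when `ce$transform()` is called without arguments) encodes row $i$ as $\tilde{x}_i$, while `encode_without_y()` (used whenever `transform()` receives data) maps every row with level $\ell$ to $\tilde{x}(\ell)$:

```latex
\tilde{x}_i = \frac{\sum_{j < i,\; x_j = x_i} y_j + a\,p}
                   {\left|\{\, j < i : x_j = x_i \,\}\right| + a},
\qquad
\tilde{x}(\ell) =
\begin{cases}
  p, & n_\ell = 1,\\[4pt]
  \dfrac{\sum_{j:\, x_j = \ell} y_j + a\,p}{n_\ell + a}, & n_\ell > 1.
\end{cases}
```

The first depends on row order and excludes each row's own outcome; the second is a single value per level, consistent with the constant `species` values in the baked output. The two calls will generally disagree unless `bake()` ends up on the same code path as `ce$transform()`.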

talegari, Apr 04 '23 06:04

Hello @talegari Sorry for taking a while to answer.

I'm not terribly familiar with {R6}, so I'm not sure how much I can help you. However, I can tell you where the problem might come from. In bake.step_catboost() you have

  if (!is.null(new_data)){
    y_name = purrr::map_chr(object$outcome, rlang::as_name) # string
    ce = object$mapping
    if (y_name %in% colnames(new_data)){
      new_data[[y_name]] = NULL
    }
    res = ce$transform(new_data)
  } else {
    res = ce$transform()
  }

I'm assuming that you thought this was needed to handle bake(new_data = NULL). That is actually not the case: the data passed to any bake method is always a non-NULL tibble. When you call bake(new_data = NULL), recipes extracts ar$template (and does a couple of other things), so it simply returns the data produced when the steps were first baked on the training set during prep(). As a result, the else branch never runs: bake.step_catboost() always calls ce$transform() with data, which uses encode_without_y() even for the training set.
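Following that explanation, a bake method needs no NULL branch at all. Below is a minimal sketch, assuming `object$outcome` holds the outcome name as a single string (as in the prepped step printed above) and `object$mapping` is either an empty list or a fitted encoder exposing `transform()`. Unlike the original, it also puts the outcome column back, since the baked output above has lost `bill_length_mm`. Note that it still encodes the training set without the outcome, so on its own it does not make `bake(new_data = NULL)` reproduce the ordered values from `ce$transform()`.

```r
# Sketch of bake.step_catboost() without a NULL branch: recipes always hands
# bake() a non-NULL tibble, even for bake(new_data = NULL).
# Assumes object$outcome is a single string and object$mapping is either an
# empty list (no columns selected at prep time) or an object with $transform().
bake.step_catboost <- function(object, new_data, ...) {
  encoder <- object$mapping
  if (length(encoder) == 0) {
    return(new_data)                    # nothing was selected in prep()
  }
  y_name <- object$outcome
  has_y <- y_name %in% names(new_data)
  y <- if (has_y) new_data[[y_name]] else NULL
  if (has_y) {
    new_data[[y_name]] <- NULL          # transform() rejects the outcome column
  }
  res <- encoder$transform(new_data)
  if (has_y) {
    res[[y_name]] <- y                  # keep the outcome in the baked data
  }
  res
}
```

A stand-in object such as `list(outcome = "y", mapping = list(transform = function(d) d))` is enough to exercise the method without {recipes} or {R6}.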

Secondly, and I'm sad to say this since you put in a lot of effort, I don't want to add {R6} and {checkmate} as dependencies just for this step. If you don't want to go through the work of translating away from {R6} and {checkmate}, I understand, and if you want, I can take over and do the last parts.
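Since the concern is the {R6} and {checkmate} dependencies, here is a rough sketch of how the new-data path (`encode_without_y()`) might be expressed with plain functions and base R: fitting returns an ordinary list that a step could store in its `mapping` element, and input checks use `stopifnot()`. The names `catboost_fit()` and `catboost_apply()` are hypothetical, and the sketch deliberately omits the R6 version's rule that levels seen only once get the overall mean.

```r
# Hypothetical base-R replacement for the R6 fit/transform pair (new-data path
# only). Each level maps to its outcome sum smoothed toward the overall mean
# `prior` with weight `a`: (sum + a * prior) / (count + a).
catboost_fit <- function(data, vars, outcome, a = 1) {
  stopifnot(is.data.frame(data), all(c(vars, outcome) %in% names(data)),
            is.numeric(data[[outcome]]), is.numeric(a), length(a) == 1, a > 0)
  y <- data[[outcome]]
  prior <- mean(y, na.rm = TRUE)
  tables <- lapply(vars, function(v) {
    x <- as.character(data[[v]])
    keep <- !is.na(x) & !is.na(y)         # missing levels get no entry
    sums <- tapply(y[keep], x[keep], sum)
    counts <- tapply(y[keep], x[keep], length)
    setNames(as.vector((sums + a * prior) / (counts + a)), names(sums))
  })
  names(tables) <- vars
  list(tables = tables, prior = prior, a = a)
}

catboost_apply <- function(fit, new_data, novel = fit$prior) {
  for (v in names(fit$tables)) {
    x <- as.character(new_data[[v]])
    enc <- unname(fit$tables[[v]][x])     # NA for missing or unseen levels
    enc[!is.na(x) & is.na(enc)] <- novel  # unseen levels fall back to `novel`
    new_data[[v]] <- enc
  }
  new_data
}

train <- data.frame(g = c("a", "a", "b"), y = c(1, 3, 5))
fit <- catboost_fit(train, vars = "g", outcome = "y")
catboost_apply(fit, data.frame(g = c("b", "a", "z", NA)))$g
# returns c(4, 7 / 3, 3, NA)
```

Here the unseen level `"z"` gets the prior (the training mean, 3) and a missing value stays `NA`, mirroring the defaults `encode_novel_levels = TRUE` and `encode_missing_levels = FALSE`.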

Thanks again for all the work!

EmilHvitfeldt, Apr 11 '23 20:04