embed
catboost method to embed categorical variables
Hi Emil,
I am planning to implement a step_catboost (along these lines). IMHO, it belongs here.
Let me know if you are open to a PR.
Unfortunately catboost (the R package) is not on CRAN 😔 which is a blocker for us being able to implement catboost methods in our packages. You can see related discussion in catboost/catboost#439.
Hey Julia, step_catboost would not depend on the catboost package. The step involves permutations and target encoding. Here is the Python implementation of the same.
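For readers unfamiliar with the method, here is a minimal base-R sketch of the ordered target statistic usually meant by "CatBoost encoding": each row is encoded from the target values of earlier rows in the same category, smoothed toward the global mean. The function name `catboost_encode_ordered` and the smoothing weight `a` are illustrative assumptions, not part of {embed} or the catboost package, and the sketch uses the rows in their given order rather than random permutations.

```r
# Ordered target statistic for one categorical column (illustrative sketch).
# x: factor or character vector without NA, y: numeric target without NA,
# a: smoothing weight (hypothetical parameter name).
catboost_encode_ordered <- function(x, y, a = 1) {
  stopifnot(length(x) == length(y), is.numeric(y), !anyNA(y), !anyNA(x))
  prior <- mean(y)
  x <- as.character(x)
  # Sum and count of the target over *earlier* rows of the same level
  prev_sum <- ave(y, x, FUN = function(v) cumsum(v) - v)
  prev_n <- ave(y, x, FUN = function(v) seq_along(v) - 1)
  (prev_sum + a * prior) / (prev_n + a)
}

x <- c("a", "b", "a", "a", "b")
y <- c(1, 2, 3, 4, 5)
enc <- catboost_encode_ordered(x, y)
print(enc)
```

The first row of every level has no history, so it falls back to the global mean; later rows move toward the running mean of their level.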
Hey @talegari 👋
That sounds great! Feel free to open an issue, and ping me if you need any help or assistance!
Hello @talegari 👋 Are you still interested in opening a PR for this step? If not, I will do it.
Hey @EmilHvitfeldt ... it just fell off the radar. I will submit a PR. I am planning along these lines. Let me know if you have a different suggestion.
Amazing! That looks like a great place to start! Do you know when you will have time to work on this? No rush!
by 24th Mar
Hey @EmilHvitfeldt, something unforeseen stopped me from working on this. This is to let you know that I am on it and will raise a PR shortly.
no problem! It might not make it into the next {embed} release, but that is fine, we can send it in later
@EmilHvitfeldt , I am one step away from raising a PR. I need your help in resolving a small issue. Here is the context:
I have implemented the catboost encoder as an R6 class here:
Category encoder R6 class
# catboost encoder core logic
pacman::p_load("tidyverse")
#' catboost_encoder R6 class
#'
#' An R6 class to encode categorical variables with the CatBoost method.
#'
#' @name catboost_encoder
#' @docType class
#' @importFrom R6 R6Class
#'
#' @slot dataset The dataset to fit the encoder
#' @slot mean The mean of the response variable in the dataset
#' @slot varnames_to_encode The names of the categorical variables to encode
#' @slot response_varname The name of the response variable in the dataset
#' @slot is_fitted A flag indicating whether the encoder has been fitted
#' @slot a A hyperparameter to control the strength of the encoding
#'
#' @section Public methods: \describe{
#' \item{\code{initialize(dataset)}}{Constructor method for the
#' catboost_encoder class} \item{\code{fit(varnames_to_encode,
#' response_varname, a = 1)}}{Fit the encoder to the data}
#' \item{\code{transform(new_data = NULL)}}{Transform a new dataset using the
#' fitted encoder} }
#'
#' @section Private methods: \describe{ \item{\code{encode_with_y(df,
#' varname_to_encode, response_varname)}}{Encode a categorical variable using
#' the response variable} \item{\code{encode_without_y(df, varname_to_encode,
#' response_varname)}}{Encode a categorical variable without using the
#' response variable} }
#'
#' @section Usage:
#'
#' catboost_encoder <- catboost_encoder$new(dataset)
#' catboost_encoder$fit(varnames_to_encode, response_varname)
#' encoded_data <- catboost_encoder$transform(new_data)
#'
#' @export catboost_encoder
catboost_encoder = R6::R6Class(
"catboost_encoder",
public = list(
dataset = NULL,
mean = NULL,
varnames_to_encode = NULL,
response_varname = NULL,
is_fitted = FALSE,
a = NULL,
encode_novel_levels = NULL,
encode_missing_levels = NULL,
initialize = function(dataset){
checkmate::assert_data_frame(dataset)
self$dataset = dataset
return(invisible(NULL))
},
fit = function(varnames_to_encode,
response_varname,
a = 1,
encode_novel_levels = TRUE,
encode_missing_levels = FALSE
){
checkmate::assert_string(response_varname)
checkmate::assert_subset(response_varname,
choices = colnames(self$dataset)
)
checkmate::assert_numeric(self$dataset[[response_varname]],
any.missing = FALSE
)
checkmate::assert_character(varnames_to_encode)
checkmate::assert_subset(varnames_to_encode,
choices = colnames(self$dataset)
)
for (avarname in varnames_to_encode){
checkmate::assert_factor(self$dataset[[avarname]])
}
checkmate::assert_number(a)
checkmate::assert_flag(encode_novel_levels)
checkmate::assert_flag(encode_missing_levels)
self$varnames_to_encode = varnames_to_encode
self$response_varname = response_varname
self$mean = mean(self$dataset[[response_varname]], na.rm = TRUE)
self$a = a
# store the user-supplied flags instead of hard-coding the defaults
self$encode_novel_levels = encode_novel_levels
self$encode_missing_levels = encode_missing_levels
self$is_fitted = TRUE
return(invisible(NULL))
},
transform = function(new_data = NULL){
new_data_is_null = TRUE
if (!is.null(new_data)){
checkmate::assert_data_frame(new_data)
checkmate::assert_false(self$response_varname %in% colnames(new_data))
names_sorted = sort(colnames(new_data))
checkmate::assert_set_equal(colnames(new_data),
setdiff(colnames(self$dataset),
self$response_varname
)
)
checkmate::assert_set_equal(
sapply(new_data, class)[names_sorted],
sapply(dplyr::select(self$dataset, -c(self$response_varname))
, class
)[names_sorted]
)
new_data_is_null = FALSE
}
if (!self$is_fitted){
stop("please 'fit' before 'transform'")
}
if (new_data_is_null){
message("transforming on the dataset")
new_data = self$dataset
}
if (new_data_is_null){
encoded_cols = map(self$varnames_to_encode,
~ private$encode_with_y(new_data, .x)
)
} else {
encoded_cols = map(self$varnames_to_encode,
~ private$encode_without_y(new_data,.x)
)
}
names(encoded_cols) = self$varnames_to_encode
res = as_tibble(encoded_cols) %>%
bind_cols(select(new_data, -c(self$varnames_to_encode))) %>%
relocate(colnames(new_data))
# encode novel (in new data case only)
if (self$encode_novel_levels && !new_data_is_null){
for (avarname in self$varnames_to_encode){
new_levels = setdiff(levels(new_data[[avarname]]),
levels(self$dataset[[avarname]])
)
if (length(new_levels) > 0){
res[[avarname]] = ifelse(new_data[[avarname]] %in% new_levels,
self$mean,
res[[avarname]]
)
}
}
}
# encode missing (in new data case only)
if (self$encode_missing_levels && !new_data_is_null){
for (avarname in self$varnames_to_encode){
res[[avarname]][ is.na(new_data[[avarname]]) ] = NA
}
}
return(res)
}
),
private = list(
encode_with_y = function(df, varname_to_encode){
# new levels: not applicable
# NA: encoded
res = df %>%
select(all_of(c(varname_to_encode, self$response_varname))) %>%
group_by(.data[[varname_to_encode]]) %>%
mutate(cs__ = cumsum(.data[[self$response_varname]]),
cc__ = row_number() - 1L
) %>%
ungroup() %>%
transmute({{varname_to_encode}} := (cs__ -
.data[[self$response_varname]] +
mean(.data[[self$response_varname]], na.rm = TRUE) *
self$a
) / (cc__ + self$a)
) %>%
pull()
return(res)
},
encode_without_y = function(df, varname_to_encode){
# new levels: NA
# NA: NA
level_means = "level_means__"
agg_frame = self$dataset %>%
select(all_of(c(varname_to_encode, self$response_varname))) %>%
group_by(.data[[varname_to_encode]]) %>%
summarise(sum__ = sum(.data[[self$response_varname]], na.rm = TRUE),
count__ = n()
) %>%
ungroup() %>%
mutate(level_means__ =
ifelse(count__ == 1,
self$mean,
(sum__ + self$mean * self$a) / (count__ + self$a)
)
) %>%
drop_na(all_of(varname_to_encode)) %>%
select(all_of(c(varname_to_encode, level_means)))
res = df %>%
select(all_of(c(varname_to_encode))) %>%
left_join(agg_frame, by = varname_to_encode) %>%
pull(level_means)
return(res)
}
)
)
recipe wrapper as 'step_catboost'
step_catboost = function(recipe,
...,
role = NA,
trained = FALSE,
outcome = NULL,
mapping = NULL,
skip = FALSE,
id = rand_id("catboost")
){
if (is.null(outcome)) {
rlang::abort("Please list a variable in `outcome`")
}
recipes::add_step(
recipe,
step_catboost_new(
terms = enquos(...),
role = role,
trained = trained,
outcome = outcome,
mapping = mapping,
skip = skip,
id = id
)
)
}
step_catboost_new =
function(terms,
role,
trained,
outcome,
mapping,
skip,
id
){
step(
subclass = "catboost",
terms = terms,
role = role,
trained = trained,
outcome = outcome,
mapping = mapping,
skip = skip,
id = id
)
}
#' @export
prep.step_catboost = function(x,
training,
info = NULL,
...
){
col_names = recipes_eval_select(x$terms, training, info)
if (length(col_names) > 0) {
y_name = recipes_eval_select(x$outcome, training, info)
# instantiate R6 class obj
ce = catboost_encoder$new(training)
ce$fit(varnames_to_encode = col_names,
response_varname = y_name
)
} else {
ce = list()
}
step_catboost_new(
terms = x$terms,
role = x$role,
trained = TRUE,
outcome = x$outcome,
mapping = ce,
skip = x$skip,
id = x$id
)
}
#' @export
bake.step_catboost = function(object, new_data, ...) {
if (!is.null(new_data)){
y_name = purrr::map_chr(object$outcome, rlang::as_name) # string
ce = object$mapping
if (y_name %in% colnames(new_data)){
new_data[[y_name]] = NULL
}
res = ce$transform(new_data)
} else {
res = ce$transform()
}
res = ce$transform(new_data)
return(res)
}
#' @rdname required_pkgs.embed
#' @export
required_pkgs.step_catboost = function(x, ...) {
c("embed")
}
Example
pacman::p_load("recipes", "tidyverse")
source("~/personal/catboost_encoding_r6.R")
#> transforming on the dataset
#> transforming on the dataset
source("~/personal/step_catboost.R")
pen1 = palmerpenguins::penguins %>%
drop_na(bill_length_mm) %>%
slice_sample(prop = 0.7, by = 'species')
pen2 = palmerpenguins::penguins %>%
drop_na(bill_length_mm) %>%
setdiff(pen1)
# example with R6 class
ce = catboost_encoder$new(pen1)
ce$fit(c('species', 'sex'), response_varname = 'bill_length_mm')
# when the input to transform is empty, it uses the training dataset
# (here it is pen1)
ce$transform()
#> transforming on the dataset
#> # A tibble: 238 × 8
#> species island bill_length_mm bill_depth_mm flipper_…¹ body_…² sex year
#> <dbl> <fct> <dbl> <dbl> <int> <int> <dbl> <int>
#> 1 43.8 Torgersen 39.6 17.2 196 3550 43.8 2008
#> 2 41.7 Dream 37.5 18.9 179 2975 43.8 2007
#> 3 40.3 Biscoe 35.5 16.2 195 3350 41.7 2008
#> 4 39.1 Torgersen 40.6 19 199 4000 43.8 2009
#> 5 39.4 Biscoe 40.1 18.9 188 4300 42.2 2008
#> 6 39.5 Dream 39.6 18.8 190 4600 41.5 2007
#> 7 39.5 Dream 32.1 15.5 188 3050 39.6 2009
#> 8 38.6 Dream 39.8 19.1 184 4650 41.0 2007
#> 9 38.7 Torgersen 34.1 18.1 193 3475 40.6 2007
#> 10 38.3 Dream 37 16.9 185 3000 37.7 2007
#> # … with 228 more rows, and abbreviated variable names ¹flipper_length_mm,
#> # ²body_mass_g
# transform on a new dataset
ce$transform(pen2 %>% select(-bill_length_mm))
#> # A tibble: 104 × 7
#> species island bill_depth_mm flipper_length_mm body_mass_g sex year
#> <dbl> <fct> <dbl> <int> <int> <dbl> <int>
#> 1 38.7 Torgersen 18 195 3250 42.2 2007
#> 2 38.7 Torgersen 20.6 190 3650 45.6 2007
#> 3 38.7 Torgersen 17.8 181 3625 42.2 2007
#> 4 38.7 Torgersen 19.6 195 4675 45.6 2007
#> 5 38.7 Torgersen 21.2 191 3800 45.6 2007
#> 6 38.7 Torgersen 17.8 185 3700 42.2 2007
#> 7 38.7 Torgersen 20.7 197 4500 45.6 2007
#> 8 38.7 Torgersen 21.5 194 4200 45.6 2007
#> 9 38.7 Biscoe 18.6 172 3150 42.2 2007
#> 10 38.7 Dream 16.7 178 3250 42.2 2007
#> # … with 94 more rows
# example with step_catboost recipe
ar = recipe(bill_length_mm ~ ., data = pen1) %>%
step_catboost(species, outcome = "bill_length_mm") %>%
prep(training = pen1)
ar
#> Recipe
#>
#> Inputs:
#>
#> role #variables
#> outcome 1
#> predictor 7
#>
#> Training data contained 238 data points and 9 incomplete rows.
#>
#> Operations:
#>
#> $terms
#> <list_of<quosure>>
#>
#> [[1]]
#> <quosure>
#> expr: ^species
#> env: 0x7fbbb5a65120
#>
#>
#> $role
#> [1] NA
#>
#> $trained
#> [1] TRUE
#>
#> $outcome
#> [1] "bill_length_mm"
#>
#> $mapping
#> <catboost_encoder>
#> Public:
#> a: 1
#> clone: function (deep = FALSE)
#> dataset: tbl_df, tbl, data.frame
#> encode_missing_levels: FALSE
#> encode_novel_levels: TRUE
#> fit: function (varnames_to_encode, response_varname, a = 1, encode_novel_levels = TRUE,
#> initialize: function (dataset)
#> is_fitted: TRUE
#> mean: 43.7655462184874
#> response_varname: bill_length_mm
#> transform: function (new_data = NULL)
#> varnames_to_encode: species
#> Private:
#> encode_with_y: function (df, varname_to_encode)
#> encode_without_y: function (df, varname_to_encode)
#>
#> $skip
#> [1] FALSE
#>
#> $id
#> [1] "catboost_LGVzz"
#>
#> attr(,"class")
#> [1] "step_catboost" "step"
ar %>%
juice()
#> # A tibble: 238 × 7
#> species island bill_depth_mm flipper_length_mm body_mass_g sex year
#> <dbl> <fct> <dbl> <int> <int> <fct> <int>
#> 1 38.7 Torgersen 17.2 196 3550 female 2008
#> 2 38.7 Dream 18.9 179 2975 <NA> 2007
#> 3 38.7 Biscoe 16.2 195 3350 female 2008
#> 4 38.7 Torgersen 19 199 4000 male 2009
#> 5 38.7 Biscoe 18.9 188 4300 male 2008
#> 6 38.7 Dream 18.8 190 4600 male 2007
#> 7 38.7 Dream 15.5 188 3050 female 2009
#> 8 38.7 Dream 19.1 184 4650 male 2007
#> 9 38.7 Torgersen 18.1 193 3475 <NA> 2007
#> 10 38.7 Dream 16.9 185 3000 female 2007
#> # … with 228 more rows
ar %>%
bake(new_data = NULL)
#> # A tibble: 238 × 7
#> species island bill_depth_mm flipper_length_mm body_mass_g sex year
#> <dbl> <fct> <dbl> <int> <int> <fct> <int>
#> 1 38.7 Torgersen 17.2 196 3550 female 2008
#> 2 38.7 Dream 18.9 179 2975 <NA> 2007
#> 3 38.7 Biscoe 16.2 195 3350 female 2008
#> 4 38.7 Torgersen 19 199 4000 male 2009
#> 5 38.7 Biscoe 18.9 188 4300 male 2008
#> 6 38.7 Dream 18.8 190 4600 male 2007
#> 7 38.7 Dream 15.5 188 3050 female 2009
#> 8 38.7 Dream 19.1 184 4650 male 2007
#> 9 38.7 Torgersen 18.1 193 3475 <NA> 2007
#> 10 38.7 Dream 16.9 185 3000 female 2007
#> # … with 228 more rows
ar %>%
bake(new_data = pen1)
#> # A tibble: 238 × 7
#> species island bill_depth_mm flipper_length_mm body_mass_g sex year
#> <dbl> <fct> <dbl> <int> <int> <fct> <int>
#> 1 38.7 Torgersen 17.2 196 3550 female 2008
#> 2 38.7 Dream 18.9 179 2975 <NA> 2007
#> 3 38.7 Biscoe 16.2 195 3350 female 2008
#> 4 38.7 Torgersen 19 199 4000 male 2009
#> 5 38.7 Biscoe 18.9 188 4300 male 2008
#> 6 38.7 Dream 18.8 190 4600 male 2007
#> 7 38.7 Dream 15.5 188 3050 female 2009
#> 8 38.7 Dream 19.1 184 4650 male 2007
#> 9 38.7 Torgersen 18.1 193 3475 <NA> 2007
#> 10 38.7 Dream 16.9 185 3000 female 2007
#> # … with 228 more rows
ar %>%
bake(new_data = pen2)
#> # A tibble: 104 × 7
#> species island bill_depth_mm flipper_length_mm body_mass_g sex year
#> <dbl> <fct> <dbl> <int> <int> <fct> <int>
#> 1 38.7 Torgersen 18 195 3250 female 2007
#> 2 38.7 Torgersen 20.6 190 3650 male 2007
#> 3 38.7 Torgersen 17.8 181 3625 female 2007
#> 4 38.7 Torgersen 19.6 195 4675 male 2007
#> 5 38.7 Torgersen 21.2 191 3800 male 2007
#> 6 38.7 Torgersen 17.8 185 3700 female 2007
#> 7 38.7 Torgersen 20.7 197 4500 male 2007
#> 8 38.7 Torgersen 21.5 194 4200 male 2007
#> 9 38.7 Biscoe 18.6 172 3150 female 2007
#> 10 38.7 Dream 16.7 178 3250 female 2007
#> # … with 94 more rows
Issue: The ce$transform() and ar %>% bake(new_data = NULL) give different results. How do I resolve this?
Hello @talegari Sorry for taking a while to answer.
I'm not terribly familiar with {R6}, so I'm not sure how much I can help you. However, I can tell you where something might be going wrong. In bake.step_catboost() you have
if (!is.null(new_data)){
y_name = purrr::map_chr(object$outcome, rlang::as_name) # string
ce = object$mapping
if (y_name %in% colnames(new_data)){
new_data[[y_name]] = NULL
}
res = ce$transform(new_data)
} else {
res = ce$transform()
}
I'm assuming that you thought this was needed to deal with bake(new_data = NULL). This is actually not the case: the data passed to any bake method will always be a non-NULL tibble. When you call bake(new_data = NULL), it extracts ar$template and does a couple of other things. So it just returns the data we got when running prep/bake() the first time, and the else branch that calls ce$transform() is never reached.
Secondly, I'm sad to say this since you put in a lot of effort, but I don't want to include {R6} and {checkmate} as dependencies just to include this step. If you don't want to go through the work of translating away from {R6} and {checkmate}, I understand, and if you want I can take over and do the last parts.
Thanks again for all the work!
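To make the two points above concrete, here is a rough base-R sketch of what a version without {R6} or {checkmate} might look like: the fitted state is a plain list, and the apply function always expects a data frame, with no NULL branch. The names `fit_catboost_mapping` and `apply_catboost_mapping` are hypothetical. The sketch only covers the per-level smoothed means used for new data; it does not reproduce the ordered training-time encoding or the single-observation special case from the class posted above.

```r
# Fit: per-level smoothed means stored in a plain list (no R6 object).
fit_catboost_mapping <- function(training, col, outcome, a = 1) {
  y <- training[[outcome]]
  stopifnot(is.data.frame(training), is.numeric(y), !anyNA(y))
  x <- as.character(training[[col]])
  prior <- mean(y)
  keep <- !is.na(x)
  sums <- tapply(y[keep], x[keep], sum)
  counts <- tapply(y[keep], x[keep], length)
  list(
    col = col,
    prior = prior,
    levels = data.frame(
      level = names(sums),
      value = as.numeric((sums + a * prior) / (counts + a)),
      stringsAsFactors = FALSE
    )
  )
}

# Apply: look up each level; novel levels get the prior, NA stays NA.
apply_catboost_mapping <- function(new_data, mapping) {
  x <- as.character(new_data[[mapping$col]])
  value <- mapping$levels$value[match(x, mapping$levels$level)]
  value[is.na(value) & !is.na(x)] <- mapping$prior
  new_data[[mapping$col]] <- value
  new_data
}

train <- data.frame(g = c("a", "a", "b"), y = c(1, 3, 5))
mapping <- fit_catboost_mapping(train, col = "g", outcome = "y")
baked <- apply_catboost_mapping(data.frame(g = c("a", "b", "z", NA)), mapping)
print(baked)
```

In a real step, the list returned by the fit function would presumably be stored in the step's `mapping` field during prep and looked up again in bake.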