
logistic regression with glmnet engine when penalty = 0

Open GilHenriques opened this issue 2 years ago • 2 comments

The problem

When fitting a logistic regression with the glmnet engine, I encounter two issues that I believe are related. The reproducible example below showcases both. They arise when, as a reliability check, I set penalty = 0. The purpose of the check was to confirm that mixture has no effect when penalty = 0.

In short, the issues are:

  1. Even though a penalty value is explicitly provided in the model specification -- logistic_reg(penalty = 0, mixture = tune()) -- I get a "no_penalty()" error when tuning the workflow. The same error occurs for nonzero penalty values.
  2. In an effort to avoid this error, I set penalty = tune() and then include penalty = 0 in my tuning grid. The code then runs, but contrary to my expectation, the mixture had an effect on accuracy and ROC AUC.
  3. When I implement a similar model directly in the glmnet package, I confirm that when lambda = 0 (no penalty), there is no effect of alpha (mixture), whereas when lambda is larger than zero, there is an effect of alpha. This appears inconsistent with point 2 above.
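
For reference, the expectation behind point 3 follows from the elastic-net objective as given in the glmnet documentation. The block below is a sketch with unit observation weights; the symbols are illustrative notation, with $\lambda$ corresponding to penalty, $\alpha$ to mixture, and $\ell$ to the negative log-likelihood contribution of one observation.

```latex
\min_{\beta_0,\,\beta}\;
  \frac{1}{N}\sum_{i=1}^{N} \ell\!\left(y_i,\; \beta_0 + x_i^{\top}\beta\right)
  \;+\; \lambda\left[\frac{1-\alpha}{2}\,\lVert\beta\rVert_2^2 + \alpha\,\lVert\beta\rVert_1\right]
```

When $\lambda = 0$ the bracketed term is multiplied by zero, so $\alpha$ drops out of the objective and the fit reduces to an unpenalized logistic regression.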

Reproducible example

``` r
library(tidyverse)
library(tidymodels)

set.seed(123)

# Create an example data frame
df <- tibble(Y = sample(c(1, 0), 1000, replace = TRUE),
       X1 = rnorm(1000),
       X2 = rnorm(1000),
       X3 = rnorm(1000),
       X4 = rnorm(1000)) |> 
  mutate(Y = factor(Y))

# Initial split
splits <- initial_split(df)
train <- training(splits)
folds <- vfold_cv(train)

# Issue 1: No penalty error, even though a penalty is specified
model <- logistic_reg(penalty = 0, mixture = tune())|> set_engine('glmnet')
rec <- recipe(Y ~ ., data = train)
wflow <- workflow() |> add_model(model) |> add_recipe(rec)

wflow |> tune_grid(folds)
#> Error in `no_penalty()`:
#> ! At least one penalty value is required for glmnet.

#> Backtrace:
#>      ▆
#>   1. ├─tune::tune_grid(wflow, folds)
#>   2. └─tune:::tune_grid.workflow(wflow, folds)
#>   3.   └─tune:::tune_grid_workflow(...)
#>   4.     └─tune:::tune_grid_loop(...)
#>   5.       └─tune (local) fn_tune_grid_loop(...)
#>   6.         └─tune:::tune_grid_loop_impl(...)
#>   7.           └─tune:::compute_grid_info(workflow, grid)
#>   8.             └─tune:::compute_grid_info_model(workflow, grid, parameters_model)
#>   9.               ├─generics::min_grid(spec, grid)
#>  10.               └─tune::min_grid.logistic_reg(spec, grid)
#>  11.                 └─tune:::no_penalty(grid, sub_nm)
#>  12.                   └─rlang::abort("At least one penalty value is required for glmnet.")

# Issue 2: If penalty = 0 in the tuning grid, mixture still has an effect
model <- logistic_reg(penalty = tune(), mixture = tune())|> set_engine('glmnet')
rec <- recipe(Y ~ ., data = train)
wflow <- workflow() |> add_model(model) |> add_recipe(rec)

reg_grid <- expand_grid(penalty = 0, mixture = c(0.001, 0.01, 0.1, 0.25, 0.5, 0.6))

wflow |> tune_grid(folds, grid = reg_grid) |> 
  autoplot() # Parameter makes a difference even though penalty = 0


# Issue 3: When we use glmnet directly, if lambda = 0 alpha makes no difference
X <- df[1:500,-1] |> as.matrix()
Y <- df[1:500,] |> pull(Y)
fit1 <- glmnet::glmnet(X, Y, family = 'binomial', lambda = 0, alpha = 0.001)
fit2 <- glmnet::glmnet(X, Y, family = 'binomial', lambda = 0, alpha = 0.1)
fit3 <-  glmnet::glmnet(X, Y, family = 'binomial', lambda = 0, alpha = 0.5)

pred1 <- predict(fit1, newx = as.matrix(df[500:1000,-1]), type = 'class') |> as.vector()
pred2 <- predict(fit2, newx = as.matrix(df[500:1000,-1]), type = 'class') |> as.vector()
pred3 <- predict(fit3, newx = as.matrix(df[500:1000,-1]), type = 'class') |> as.vector()

tibble((df[500:1000,1]), pred1, pred2, pred3) |> 
  mutate(Y = as.character(Y)) |> 
  summarize(accuracy1 = sum(pred1 == Y)/n(),
            accuracy2 = sum(pred2 == Y)/n(),
            accuracy3 = sum(pred3 == Y)/n())
#> # A tibble: 1 × 3
#>   accuracy1 accuracy2 accuracy3
#>       <dbl>     <dbl>     <dbl>
#> 1     0.507     0.507     0.507

# ... But if lambda > 0 alpha does make a difference
X <- df[1:500,-1] |> as.matrix()
Y <- df[1:500,] |> pull(Y)
fit1 <- glmnet::glmnet(X, Y, family = 'binomial', lambda = 0.1, alpha = 0.001)
fit2 <- glmnet::glmnet(X, Y, family = 'binomial', lambda = 0.1, alpha = 0.1)
fit3 <-  glmnet::glmnet(X, Y, family = 'binomial', lambda = 0.1, alpha = 0.5)

pred1 <- predict(fit1, newx = as.matrix(df[500:1000,-1]), type = 'class') |> as.vector()
pred2 <- predict(fit2, newx = as.matrix(df[500:1000,-1]), type = 'class') |> as.vector()
pred3 <- predict(fit3, newx = as.matrix(df[500:1000,-1]), type = 'class') |> as.vector()

tibble((df[500:1000,1]), pred1, pred2, pred3) |> 
  mutate(Y = as.character(Y)) |> 
  summarize(accuracy1 = sum(pred1 == Y)/n(),
            accuracy2 = sum(pred2 == Y)/n(),
            accuracy3 = sum(pred3 == Y)/n())
#> # A tibble: 1 × 3
#>   accuracy1 accuracy2 accuracy3
#>       <dbl>     <dbl>     <dbl>
#> 1     0.509     0.489     0.491

```

Created on 2022-12-09 with reprex v2.0.2

Session info

``` r
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.0 (2022-04-22)
#>  os       macOS Monterey 12.6
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       Europe/Stockholm
#>  date     2022-12-09
#>  pandoc   2.19.2 @ /Applications/RStudio.app/Contents/MacOS/quarto/bin/tools/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package       * version    date (UTC) lib source
#>  assertthat      0.2.1      2019-03-21 [1] CRAN (R 4.2.0)
#>  backports       1.4.1      2021-12-13 [1] CRAN (R 4.2.0)
#>  broom         * 1.0.1      2022-08-29 [1] CRAN (R 4.2.0)
#>  cellranger      1.1.0      2016-07-27 [1] CRAN (R 4.2.0)
#>  class           7.3-20     2022-01-16 [1] CRAN (R 4.2.0)
#>  cli             3.4.1      2022-09-23 [1] CRAN (R 4.2.0)
#>  codetools       0.2-18     2020-11-04 [1] CRAN (R 4.2.0)
#>  colorspace      2.0-3      2022-02-21 [1] CRAN (R 4.2.0)
#>  crayon          1.5.2      2022-09-29 [1] CRAN (R 4.2.0)
#>  curl            4.3.2      2021-06-23 [1] CRAN (R 4.2.0)
#>  DBI             1.1.2      2021-12-20 [1] CRAN (R 4.2.0)
#>  dbplyr          2.2.0      2022-06-05 [1] CRAN (R 4.2.0)
#>  dials         * 1.0.0      2022-06-14 [1] CRAN (R 4.2.0)
#>  DiceDesign      1.9        2021-02-13 [1] CRAN (R 4.2.0)
#>  digest          0.6.29     2021-12-01 [1] CRAN (R 4.2.0)
#>  dplyr         * 1.0.10     2022-09-01 [1] CRAN (R 4.2.0)
#>  ellipsis        0.3.2      2021-04-29 [1] CRAN (R 4.2.0)
#>  evaluate        0.16       2022-08-09 [1] CRAN (R 4.2.0)
#>  fansi           1.0.3      2022-03-24 [1] CRAN (R 4.2.0)
#>  farver          2.1.1      2022-07-06 [1] CRAN (R 4.2.0)
#>  fastmap         1.1.0      2021-01-25 [1] CRAN (R 4.2.0)
#>  forcats       * 0.5.2      2022-08-19 [1] CRAN (R 4.2.0)
#>  foreach         1.5.2      2022-02-02 [1] CRAN (R 4.2.0)
#>  fs              1.5.2      2021-12-08 [1] CRAN (R 4.2.0)
#>  furrr           0.3.1      2022-08-15 [1] CRAN (R 4.2.0)
#>  future          1.27.0     2022-07-22 [1] CRAN (R 4.2.0)
#>  future.apply    1.9.0      2022-04-25 [1] CRAN (R 4.2.0)
#>  gargle          1.2.0      2021-07-02 [1] CRAN (R 4.2.0)
#>  generics        0.1.3      2022-07-05 [1] CRAN (R 4.2.0)
#>  ggplot2       * 3.3.6      2022-05-03 [1] CRAN (R 4.2.0)
#>  glmnet        * 4.1-4      2022-04-15 [1] CRAN (R 4.2.0)
#>  globals         0.15.1     2022-06-24 [1] CRAN (R 4.2.0)
#>  glue            1.6.2      2022-02-24 [1] CRAN (R 4.2.0)
#>  googledrive     2.0.0      2021-07-08 [1] CRAN (R 4.2.0)
#>  googlesheets4   1.0.0      2021-07-21 [1] CRAN (R 4.2.0)
#>  gower           1.0.0      2022-02-03 [1] CRAN (R 4.2.0)
#>  GPfit           1.0-8      2019-02-08 [1] CRAN (R 4.2.0)
#>  gtable          0.3.1      2022-09-01 [1] CRAN (R 4.2.0)
#>  hardhat         1.2.0      2022-06-30 [1] CRAN (R 4.2.0)
#>  haven           2.5.1      2022-08-22 [1] CRAN (R 4.2.0)
#>  highr           0.9        2021-04-16 [1] CRAN (R 4.2.0)
#>  hms             1.1.2      2022-08-19 [1] CRAN (R 4.2.0)
#>  htmltools       0.5.3      2022-07-18 [1] CRAN (R 4.2.0)
#>  httr            1.4.3      2022-05-04 [1] CRAN (R 4.2.0)
#>  infer         * 1.0.2      2022-06-10 [1] CRAN (R 4.2.0)
#>  ipred           0.9-13     2022-06-02 [1] CRAN (R 4.2.0)
#>  iterators       1.0.14     2022-02-05 [1] CRAN (R 4.2.0)
#>  jsonlite        1.8.2      2022-10-02 [1] CRAN (R 4.2.0)
#>  knitr           1.40       2022-08-24 [1] CRAN (R 4.2.0)
#>  labeling        0.4.2      2020-10-20 [1] CRAN (R 4.2.0)
#>  lattice         0.20-45    2021-09-22 [1] CRAN (R 4.2.0)
#>  lava            1.6.10     2021-09-02 [1] CRAN (R 4.2.0)
#>  lhs             1.1.5      2022-03-22 [1] CRAN (R 4.2.0)
#>  lifecycle       1.0.2      2022-09-09 [1] CRAN (R 4.2.0)
#>  listenv         0.8.0      2019-12-05 [1] CRAN (R 4.2.0)
#>  lubridate       1.8.0      2021-10-07 [1] CRAN (R 4.2.0)
#>  magrittr        2.0.3      2022-03-30 [1] CRAN (R 4.2.0)
#>  MASS            7.3-56     2022-03-23 [1] CRAN (R 4.2.0)
#>  Matrix        * 1.5-1      2022-09-13 [1] CRAN (R 4.2.0)
#>  mime            0.12       2021-09-28 [1] CRAN (R 4.2.0)
#>  modeldata     * 1.0.0      2022-07-01 [1] CRAN (R 4.2.0)
#>  modelr          0.1.8      2020-05-19 [1] CRAN (R 4.2.0)
#>  munsell         0.5.0      2018-06-12 [1] CRAN (R 4.2.0)
#>  nnet            7.3-17     2022-01-13 [1] CRAN (R 4.2.0)
#>  parallelly      1.32.1     2022-07-21 [1] CRAN (R 4.2.0)
#>  parsnip       * 1.0.0      2022-06-16 [1] CRAN (R 4.2.0)
#>  pillar          1.8.1      2022-08-19 [1] CRAN (R 4.2.0)
#>  pkgconfig       2.0.3      2019-09-22 [1] CRAN (R 4.2.0)
#>  prodlim         2019.11.13 2019-11-17 [1] CRAN (R 4.2.0)
#>  purrr         * 0.3.4      2020-04-17 [1] CRAN (R 4.2.0)
#>  R.cache         0.16.0     2022-07-21 [1] CRAN (R 4.2.0)
#>  R.methodsS3     1.8.2      2022-06-13 [1] CRAN (R 4.2.0)
#>  R.oo            1.25.0     2022-06-12 [1] CRAN (R 4.2.0)
#>  R.utils         2.12.2     2022-11-11 [1] CRAN (R 4.2.0)
#>  R6              2.5.1      2021-08-19 [1] CRAN (R 4.2.0)
#>  Rcpp            1.0.9      2022-07-08 [1] CRAN (R 4.2.0)
#>  readr         * 2.1.3      2022-10-01 [1] CRAN (R 4.2.0)
#>  readxl          1.4.1      2022-08-17 [1] CRAN (R 4.2.0)
#>  recipes       * 1.0.1      2022-07-07 [1] CRAN (R 4.2.0)
#>  reprex          2.0.2      2022-08-17 [1] CRAN (R 4.2.0)
#>  rlang           1.0.6      2022-09-24 [1] CRAN (R 4.2.0)
#>  rmarkdown       2.14       2022-04-25 [1] CRAN (R 4.2.0)
#>  rpart           4.1.16     2022-01-24 [1] CRAN (R 4.2.0)
#>  rsample       * 1.0.0      2022-06-24 [1] CRAN (R 4.2.0)
#>  rstudioapi      0.13       2020-11-12 [1] CRAN (R 4.2.0)
#>  rvest           1.0.2      2021-10-16 [1] CRAN (R 4.2.0)
#>  scales        * 1.2.1      2022-08-20 [1] CRAN (R 4.2.0)
#>  sessioninfo     1.2.2      2021-12-06 [1] CRAN (R 4.2.0)
#>  shape           1.4.6      2021-05-19 [1] CRAN (R 4.2.0)
#>  stringi         1.7.8      2022-07-11 [1] CRAN (R 4.2.0)
#>  stringr       * 1.4.1      2022-08-20 [1] CRAN (R 4.2.0)
#>  styler          1.8.1      2022-11-07 [1] CRAN (R 4.2.0)
#>  survival        3.3-1      2022-03-03 [1] CRAN (R 4.2.0)
#>  tibble        * 3.1.8      2022-07-22 [1] CRAN (R 4.2.0)
#>  tidymodels    * 1.0.0      2022-07-13 [1] CRAN (R 4.2.0)
#>  tidyr         * 1.2.1      2022-09-08 [1] CRAN (R 4.2.0)
#>  tidyselect      1.1.2      2022-02-21 [1] CRAN (R 4.2.0)
#>  tidyverse     * 1.3.2      2022-07-18 [1] CRAN (R 4.2.0)
#>  timeDate        4021.104   2022-07-19 [1] CRAN (R 4.2.0)
#>  tune          * 1.0.0      2022-07-07 [1] CRAN (R 4.2.0)
#>  tzdb            0.3.0      2022-03-28 [1] CRAN (R 4.2.0)
#>  utf8            1.2.2      2021-07-24 [1] CRAN (R 4.2.0)
#>  vctrs           0.4.2      2022-09-29 [1] CRAN (R 4.2.0)
#>  withr           2.5.0      2022-03-03 [1] CRAN (R 4.2.0)
#>  workflows     * 1.0.0      2022-07-05 [1] CRAN (R 4.2.0)
#>  workflowsets  * 1.0.0      2022-07-12 [1] CRAN (R 4.2.0)
#>  xfun            0.33       2022-09-12 [1] CRAN (R 4.2.0)
#>  xml2            1.3.3      2021-11-30 [1] CRAN (R 4.2.0)
#>  yaml            2.3.5      2022-02-21 [1] CRAN (R 4.2.0)
#>  yardstick     * 1.0.0      2022-06-06 [1] CRAN (R 4.2.0)
#> 
#>  [1] /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/library
#> 
#> ──────────────────────────────────────────────────────────────────────────────
```
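
A complementary check, not part of the original reprex: accuracy on random labels is a coarse signal, so comparing the fitted coefficients directly may be more informative. The sketch below uses simulated data and two hypothetical helpers (`coef_at()` and `max_spread()`, introduced here for illustration only). It assumes glmnet is installed and calls it directly, outside tidymodels.

``` r
library(glmnet)

set.seed(123)

# Simulated data, similar in spirit to the reprex (pure noise predictors)
n <- 500
X <- matrix(rnorm(n * 4), ncol = 4, dimnames = list(NULL, paste0("X", 1:4)))
Y <- factor(sample(c(0, 1), n, replace = TRUE))

# Hypothetical helper: fit at a single lambda for several alpha values and
# return one column of coefficients (intercept first) per alpha
coef_at <- function(lambda, alphas = c(0.001, 0.1, 0.5)) {
  sapply(alphas, function(a) {
    fit <- glmnet(X, Y, family = "binomial", lambda = lambda, alpha = a)
    as.vector(coef(fit))
  })
}

# Hypothetical helper: largest range of any single coefficient across alphas
max_spread <- function(m) max(apply(m, 1, function(r) diff(range(r))))

coefs_zero <- coef_at(0)    # lambda = 0: alpha should not matter
coefs_pos  <- coef_at(0.1)  # lambda > 0: alpha should matter

max_spread(coefs_zero)
max_spread(coefs_pos)
```

If the spread at lambda = 0 is essentially zero while the spread at lambda = 0.1 is not, that matches the behaviour described in point 3 and narrows the discrepancy in point 2 down to how the penalty reaches glmnet during tuning.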

GilHenriques, Dec 09 '22 15:12

Thank you for the issue! Just wanted to let you know this hasn't fallen off our radar. Related to https://github.com/tidymodels/tune/issues/28 and https://github.com/tidymodels/tune/issues/45.

simonpcouch, Oct 31 '23 19:10

+1. Thank you, @simonpcouch. In the meantime, is there any way to work around this?

marcozanotti, Jan 16 '24 08:01