
`augment`ing with training data

Open · simonpcouch opened this issue 3 years ago · 1 comment

This issue came up in a conversation with @mine-cetinkaya-rundel about teaching introductory stats / modeling courses using the tidymodels. I feel that, in some ways, parsnip's guardrails re: `augment()` make teaching broom's principles fussier than they ought to be. Fitting a model and passing it to each tidier:

library(tidyverse)
library(tidymodels)
library(palmerpenguins)

penguins <- drop_na(penguins)

penguins_tr <- penguins[1:200,]
penguins_te <- penguins[201:nrow(penguins),]

pars_fit <- linear_reg() %>%
  set_engine("lm") %>%
  fit(body_mass_g ~ flipper_length_mm, data = penguins)

# works
tidy(pars_fit)
#> # A tibble: 2 × 5
#>   term              estimate std.error statistic   p.value
#>   <chr>                <dbl>     <dbl>     <dbl>     <dbl>
#> 1 (Intercept)        -5872.     310.       -18.9 1.18e- 54
#> 2 flipper_length_mm     50.2      1.54      32.6 3.13e-105

# works
glance(pars_fit)
#> # A tibble: 1 × 12
#>   r.squared adj.r.squared sigma statistic   p.value    df logLik   AIC   BIC
#>       <dbl>         <dbl> <dbl>     <dbl>     <dbl> <dbl>  <dbl> <dbl> <dbl>
#> 1     0.762         0.761  393.     1060. 3.13e-105     1 -2461. 4928. 4940.
#> # … with 3 more variables: deviance <dbl>, df.residual <int>, nobs <int>

# oh-
augment(pars_fit)
#> Error in augment.model_fit(pars_fit): argument "new_data" is missing, with no default

I understand that the intention here was to guard folks from predicting on the training set, and maybe the conclusion is that the teaching moment about predicting on training data needs to happen this early on. However, I feel that the current approach 1) is dogmatic and, in some cases, 2) ends up encouraging the opposite behavior.

re: 1) broom's `augment()` methods distinguish between `data` (i.e. training data) and `newdata` when determining their output. When supplied the former rather than the latter, they return additional fit information that's only well-defined for training data (`.hat`, `.sigma`, `.cooksd`, `.std.resid`). Augmenting with training data makes it possible to discuss these values:

lm_fit <- lm(body_mass_g ~ flipper_length_mm, data = penguins_tr)

# default: use the model frame (built from `penguins_tr`) as `data`
augment(lm_fit)
#> # A tibble: 200 × 8
#>    body_mass_g flipper_length_… .fitted .resid    .hat .sigma .cooksd .std.resid
#>          <int>            <int>   <dbl>  <dbl>   <dbl>  <dbl>   <dbl>      <dbl>
#>  1        3750              181   3273.  477.  0.0127    415. 8.58e-3     1.16  
#>  2        3800              186   3523.  277.  0.00862   416. 1.96e-3     0.671 
#>  3        3250              195   3972. -722.  0.00511   413. 7.81e-3    -1.74  
#>  4        3450              193   3872. -422.  0.00547   415. 2.85e-3    -1.02  
#>  5        3650              190   3722.  -72.2 0.00645   416. 9.89e-5    -0.175 
#>  6        3625              181   3273.  352.  0.0127    415. 4.67e-3     0.853 
#>  7        4675              195   3972.  703.  0.00511   413. 7.42e-3     1.70  
#>  8        3200              182   3323. -123.  0.0117    416. 5.31e-4    -0.299 
#>  9        3800              191   3772.   27.9 0.00606   416. 1.39e-5     0.0675
#> 10        4400              198   4121.  279.  0.00504   416. 1.15e-3     0.674 
#> # … with 190 more rows

# pass `penguins_tr` as `data` explicitly
augment(lm_fit, data = penguins_tr)
#> # A tibble: 200 × 14
#>    species island    bill_length_mm bill_depth_mm flipper_length_mm body_mass_g
#>    <fct>   <fct>              <dbl>         <dbl>             <int>       <int>
#>  1 Adelie  Torgersen           39.1          18.7               181        3750
#>  2 Adelie  Torgersen           39.5          17.4               186        3800
#>  3 Adelie  Torgersen           40.3          18                 195        3250
#>  4 Adelie  Torgersen           36.7          19.3               193        3450
#>  5 Adelie  Torgersen           39.3          20.6               190        3650
#>  6 Adelie  Torgersen           38.9          17.8               181        3625
#>  7 Adelie  Torgersen           39.2          19.6               195        4675
#>  8 Adelie  Torgersen           41.1          17.6               182        3200
#>  9 Adelie  Torgersen           38.6          21.2               191        3800
#> 10 Adelie  Torgersen           34.6          21.1               198        4400
#> # … with 190 more rows, and 8 more variables: sex <fct>, year <int>,
#> #   .fitted <dbl>, .resid <dbl>, .hat <dbl>, .sigma <dbl>, .cooksd <dbl>,
#> #   .std.resid <dbl>

# pass `penguins_te` as `newdata`
augment(lm_fit, newdata = penguins_te)
#> # A tibble: 133 × 10
#>    species island bill_length_mm bill_depth_mm flipper_length_mm body_mass_g
#>    <fct>   <fct>           <dbl>         <dbl>             <int>       <int>
#>  1 Gentoo  Biscoe           45            15.4               220        5050
#>  2 Gentoo  Biscoe           43.8          13.9               208        4300
#>  3 Gentoo  Biscoe           45.5          15                 220        5000
#>  4 Gentoo  Biscoe           43.2          14.5               208        4450
#>  5 Gentoo  Biscoe           50.4          15.3               224        5550
#>  6 Gentoo  Biscoe           45.3          13.8               208        4200
#>  7 Gentoo  Biscoe           46.2          14.9               221        5300
#>  8 Gentoo  Biscoe           45.7          13.9               214        4400
#>  9 Gentoo  Biscoe           54.3          15.7               231        5650
#> 10 Gentoo  Biscoe           45.8          14.2               219        4700
#> # … with 123 more rows, and 4 more variables: sex <fct>, year <int>,
#> #   .fitted <dbl>, .resid <dbl>

Note that `.hat`, `.sigma`, `.cooksd`, and `.std.resid` are missing from the `newdata` output.
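For an `lm` fit, those extra columns correspond to base R influence diagnostics. These are computed per observation that the model was fit on, which is why they have no meaning for new data. The sketch below checks that correspondence. It is a minimal check that assumes broom derives these columns from `stats::influence()`, `cooks.distance()`, and `rstandard()`, as current broom versions appear to:

```r
library(broom)
library(palmerpenguins)

# Same split as in the issue: complete cases, first 200 rows for training
penguins <- tidyr::drop_na(penguins)
penguins_tr <- penguins[1:200, ]
penguins_te <- penguins[201:nrow(penguins), ]

lm_fit <- lm(body_mass_g ~ flipper_length_mm, data = penguins_tr)
aug_tr <- augment(lm_fit)                         # training data (model frame)
aug_te <- augment(lm_fit, newdata = penguins_te)  # new data

# Influence diagnostics are defined per training observation
infl <- influence(lm_fit, do.coef = FALSE)
same_hat    <- isTRUE(all.equal(unname(infl$hat), unname(aug_tr$.hat)))
same_sigma  <- isTRUE(all.equal(unname(infl$sigma), unname(aug_tr$.sigma)))
same_cooksd <- isTRUE(all.equal(unname(cooks.distance(lm_fit)), unname(aug_tr$.cooksd)))
same_std    <- isTRUE(all.equal(unname(rstandard(lm_fit)), unname(aug_tr$.std.resid)))

# ...and so none of them appear when augmenting new data
diag_cols <- c(".hat", ".sigma", ".cooksd", ".std.resid")
intersect(diag_cols, names(aug_te))
```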

re: 2) If one tries to get these values through parsnip's interface, there's no `data` argument, so one must supply the training data to `new_data`. As a result, they don't get this fit info, and the `.fitted` column is renamed to `.pred`.

augment(pars_fit, new_data = penguins_tr)
#> # A tibble: 200 × 10
#>    species island    bill_length_mm bill_depth_mm flipper_length_mm body_mass_g
#>    <fct>   <fct>              <dbl>         <dbl>             <int>       <int>
#>  1 Adelie  Torgersen           39.1          18.7               181        3750
#>  2 Adelie  Torgersen           39.5          17.4               186        3800
#>  3 Adelie  Torgersen           40.3          18                 195        3250
#>  4 Adelie  Torgersen           36.7          19.3               193        3450
#>  5 Adelie  Torgersen           39.3          20.6               190        3650
#>  6 Adelie  Torgersen           38.9          17.8               181        3625
#>  7 Adelie  Torgersen           39.2          19.6               195        4675
#>  8 Adelie  Torgersen           41.1          17.6               182        3200
#>  9 Adelie  Torgersen           38.6          21.2               191        3800
#> 10 Adelie  Torgersen           34.6          21.1               198        4400
#> # … with 190 more rows, and 4 more variables: sex <fct>, year <int>,
#> #   .pred <dbl>, .resid <dbl>

Passing the training data as `data` as well has no effect; the output is identical:

augment(pars_fit, data = penguins_tr, new_data = penguins_tr)
#> # A tibble: 200 × 10
#>    species island    bill_length_mm bill_depth_mm flipper_length_mm body_mass_g
#>    <fct>   <fct>              <dbl>         <dbl>             <int>       <int>
#>  1 Adelie  Torgersen           39.1          18.7               181        3750
#>  2 Adelie  Torgersen           39.5          17.4               186        3800
#>  3 Adelie  Torgersen           40.3          18                 195        3250
#>  4 Adelie  Torgersen           36.7          19.3               193        3450
#>  5 Adelie  Torgersen           39.3          20.6               190        3650
#>  6 Adelie  Torgersen           38.9          17.8               181        3625
#>  7 Adelie  Torgersen           39.2          19.6               195        4675
#>  8 Adelie  Torgersen           41.1          17.6               182        3200
#>  9 Adelie  Torgersen           38.6          21.2               191        3800
#> 10 Adelie  Torgersen           34.6          21.1               198        4400
#> # … with 190 more rows, and 4 more variables: sex <fct>, year <int>,
#> #   .pred <dbl>, .resid <dbl>
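Under the current guardrails, one way to get the training-data diagnostics is to step around parsnip and pass the underlying engine object to broom. This sketch assumes a parsnip version that exports `extract_fit_engine()`. Unlike the issue's example, it fits on `penguins_tr` so that the training set is unambiguous:

```r
library(tidymodels)
library(palmerpenguins)

penguins <- tidyr::drop_na(penguins)
penguins_tr <- penguins[1:200, ]

pars_fit <- linear_reg() %>%
  set_engine("lm") %>%
  fit(body_mass_g ~ flipper_length_mm, data = penguins_tr)

# Pull out the `lm` object that parsnip wraps...
engine_fit <- extract_fit_engine(pars_fit)

# ...and hand it to broom's augment() method for `lm`, which defaults to
# the model frame and so returns .fitted, .hat, .cooksd, etc.
aug <- augment(engine_fit)
aug
```

This works, but learners then have to meet engine objects before they can augment with training data. That is extra friction in an introductory course.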

I'm on board with guardrails here, but I think we can be more accommodating and use prompts to encourage the kind of behavior we want to see.

I think a better approach here could be to allow omitting `new_data`, but to warn or message about it.

That is, have `augment(pars_fit)` give output like that of `augment(lm_fit)`:

augment(lm_fit)
#> # A tibble: 200 × 8
#>    body_mass_g flipper_length_… .fitted .resid    .hat .sigma .cooksd .std.resid
#>          <int>            <int>   <dbl>  <dbl>   <dbl>  <dbl>   <dbl>      <dbl>
#>  1        3750              181   3273.  477.  0.0127    415. 8.58e-3     1.16  
#>  2        3800              186   3523.  277.  0.00862   416. 1.96e-3     0.671 
#>  3        3250              195   3972. -722.  0.00511   413. 7.81e-3    -1.74  
#>  4        3450              193   3872. -422.  0.00547   415. 2.85e-3    -1.02  
#>  5        3650              190   3722.  -72.2 0.00645   416. 9.89e-5    -0.175 
#>  6        3625              181   3273.  352.  0.0127    415. 4.67e-3     0.853 
#>  7        4675              195   3972.  703.  0.00511   413. 7.42e-3     1.70  
#>  8        3200              182   3323. -123.  0.0117    416. 5.31e-4    -0.299 
#>  9        3800              191   3772.   27.9 0.00606   416. 1.39e-5     0.0675
#> 10        4400              198   4121.  279.  0.00504   416. 1.15e-3     0.674 
#> # … with 190 more rows

with the prompt:

#> Adding information about the model fit using the data that was used to fit the model. Predict with new data using the `new_data` argument to assess predictive performance.
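The sketch below shows roughly how the proposed fallback might behave. It is written as a standalone wrapper rather than as a change to parsnip. `augment_with_prompt()` is a hypothetical name. Delegating the training-data path to broom via `extract_fit_engine()` is an assumption about how this might be implemented, not parsnip's actual code:

```r
library(tidymodels)
library(palmerpenguins)

penguins <- tidyr::drop_na(penguins)
penguins_tr <- penguins[1:200, ]

pars_fit <- linear_reg() %>%
  set_engine("lm") %>%
  fit(body_mass_g ~ flipper_length_mm, data = penguins_tr)

# Hypothetical wrapper (NOT parsnip's augment.model_fit()): when `new_data`
# is omitted, message the user and augment with the training data instead
# of erroring.
augment_with_prompt <- function(x, new_data = NULL, ...) {
  if (is.null(new_data)) {
    message(
      "Adding information about the model fit using the data that was used ",
      "to fit the model. Predict with new data using the `new_data` argument ",
      "to assess predictive performance."
    )
    # Assumed implementation: defer to broom's method for the engine fit
    return(augment(extract_fit_engine(x), ...))
  }
  # Otherwise keep parsnip's existing behavior
  augment(x, new_data = new_data, ...)
}

aug <- augment_with_prompt(pars_fit)
```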

If the user passes the training data to `new_data`, we could detect that with `model.frame(x)` and warn then, too:

#> The training data was passed as `new_data`. Model predictions may appear overly performant; please interpret cautiously. See `?repredicting` to learn more.

Here, `?repredicting` would be a new, dedicated doc topic on predicting on the training data.
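The training-data check could be approximated by comparing `new_data` against the model frame of the engine fit. `looks_like_training_data()` is a hypothetical helper, not a parsnip function. It assumes an engine, such as `lm`, whose fitted object supports `model.frame()`:

```r
library(tidymodels)
library(palmerpenguins)

penguins <- tidyr::drop_na(penguins)
penguins_tr <- penguins[1:200, ]
penguins_te <- penguins[201:nrow(penguins), ]

pars_fit <- linear_reg() %>%
  set_engine("lm") %>%
  fit(body_mass_g ~ flipper_length_mm, data = penguins_tr)

# Hypothetical helper: TRUE when the model-frame columns of `new_data`
# match the data the engine was fit on (values only; row names ignored).
looks_like_training_data <- function(x, new_data) {
  mf <- stats::model.frame(extract_fit_engine(x))
  cols <- names(mf)
  # Different columns or row counts can't be the training data
  if (!all(cols %in% names(new_data)) || nrow(new_data) != nrow(mf)) {
    return(FALSE)
  }
  isTRUE(all.equal(
    as.data.frame(new_data)[cols],
    as.data.frame(mf)[cols],
    check.attributes = FALSE
  ))
}

looks_like_training_data(pars_fit, penguins_tr)
looks_like_training_data(pars_fit, penguins_te)
```

The check is deliberately conservative. It only compares columns that appear in the model frame, so it returns `FALSE` rather than warning in two cases: transformed terms such as `log(x)`, whose model-frame column names don't appear in `new_data`, and rows that were dropped for missingness.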

Let me know if y'all would welcome a PR here. :)

Created on 2022-05-11 by the reprex package (v2.0.1)

simonpcouch · May 11 '22 17:05

Given that glance() will happily report R-squared on these data (which relies on re-predicting, after all), augment() could work as proposed without any message as well, for consistency.

This does mean dropping guardrails further, so there might be resistance to that. I'm not opposed to messaging; I just worry about messaging before learners are ready to interpret the message.
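The R-squared point can be checked directly. For a linear model with an intercept, `glance()`'s `r.squared` equals 1 - SS_res / SS_tot, computed from predictions on the training data. Here is a minimal sketch using the issue's training split:

```r
library(broom)
library(palmerpenguins)

penguins <- tidyr::drop_na(penguins)
penguins_tr <- penguins[1:200, ]
lm_fit <- lm(body_mass_g ~ flipper_length_mm, data = penguins_tr)

# Re-predict the training data, then compute 1 - SS_res / SS_tot
y <- penguins_tr$body_mass_g
y_hat <- predict(lm_fit, newdata = penguins_tr)
r_sq_manual <- 1 - sum((y - y_hat)^2) / sum((y - mean(y))^2)

r_sq_manual
glance(lm_fit)$r.squared
```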

mine-cetinkaya-rundel · May 11 '22 17:05