
Transfer Learning via tidymodels

Open uriahf opened this issue 3 years ago • 7 comments

First of all, thank you everyone for your hard work. I've been addicted to tidymodels and a heavy user ever since I discovered it.

I'm not sure if this is the right place, but I wonder if you have ever thought about implementing transfer learning via a tidymodels workflow?

uriahf avatar Mar 03 '22 19:03 uriahf

Hello @uriahf 👋 Thank you for your interest in tidymodels!

Could you give us some examples of methods that you wish to use but are unable to use within the current tidymodels set of packages?

EmilHvitfeldt avatar Mar 03 '22 20:03 EmilHvitfeldt

I will look for a good reproducible example; it might take a while.

Thank you!

uriahf avatar Mar 07 '22 09:03 uriahf

I don't think we need a reproducible example (a reprex) here, but if you can point us to the specific method or type of analysis you want to do, that would be super helpful!

juliasilge avatar Mar 07 '22 14:03 juliasilge

Well, I made a reprex anyway 😅

I have an important feature which is missing (not at random) and I want to use it directly: I don't want to use imputation because that might mess things up with the explainability of the model and I definitely don't want to use categorization because of the loss of information.

I think that the titanic data set might be a good example: Age is an important feature which is missing for some observations.

The workflow goes like this:

1. Training a "thin model" on the train set, for all observations, using only the features that have no missing values (Fare and Sex).

2. Adding predictions from the "thin model" as a separate predictor, thin_model_preds, in both the train set and the test set (does not contain missing values).

3. Training a "full model" on the train set, only for observations with Age, using Fare and Sex (as in the thin model), Age, and the predictions from the "thin model", thin_model_preds.

4. Adding predictions from the "full model", full_model_preds, to both the train and test sets (contains missing values for observations without Age).

5. Choosing, as the final predictions final_preds, the thin model predictions thin_model_preds for observations without Age and the full model predictions full_model_preds for observations with Age.

I hope that's clear enough.

Here is the reprex:

library(titanic)
library(magrittr)

data(titanic_train)
data(titanic_test)

titanic_train <- titanic_train %>% dplyr::select(Age, Fare, Sex, Survived)
titanic_test <- titanic_test %>% dplyr::select(Age, Fare, Sex)

# 1. Training a thin model for all observations

thin_model <- glm(
  Survived ~ Fare + Sex,
  data = titanic_train,
  family = "binomial"
)

# 2. Adding thin model predictions to all observations in train and test sets

titanic_train_with_thin_model_predictions <- predict(thin_model,
  type = "response",
  newdata = titanic_train
) %>%
  tibble::tibble("thin_model_preds" = .) %>%
  dplyr::bind_cols(titanic_train)

titanic_test_with_thin_model_predictions <- predict(thin_model,
  type = "response",
  newdata = titanic_test
) %>%
  tibble::tibble("thin_model_preds" = .) %>%
  dplyr::bind_cols(titanic_test)


# 3. Training full model only for observations with age
# while using age and thin model predictions as predictors

full_model <- glm(
  Survived ~ Fare + Sex + Age + thin_model_preds,
  data = titanic_train_with_thin_model_predictions,
  family = "binomial"
)
#> Warning: glm.fit: fitted probabilities numerically 0 or 1 occurred

# 4. Adding full model predictions to observations with age in train and test sets

titanic_train_with_predictions <- predict(full_model,
  type = "response",
  newdata = titanic_train_with_thin_model_predictions
) %>%
  tibble::tibble("full_model_preds" = .) %>%
  dplyr::bind_cols(titanic_train_with_thin_model_predictions)

titanic_test_with_predictions <- predict(full_model,
  type = "response",
  newdata = titanic_test_with_thin_model_predictions
) %>%
  tibble::tibble("full_model_preds" = .) %>%
  dplyr::bind_cols(titanic_test_with_thin_model_predictions) 

# 5. Full model predictions as final predictions for observations with age 
# and Thin model predictions as final predictions for observations without age

titanic_train <- titanic_train_with_predictions %>%
  dplyr::mutate(final_preds = ifelse(is.na(full_model_preds), thin_model_preds, 
                                     full_model_preds)) %>% 
  dplyr::select(-c(full_model_preds, thin_model_preds))


titanic_test <- titanic_test_with_predictions %>%
  dplyr::mutate(final_preds = ifelse(is.na(full_model_preds), thin_model_preds, 
                                     full_model_preds)) %>% 
  dplyr::select(-c(full_model_preds, thin_model_preds))

titanic_train
#> # A tibble: 891 x 5
#>      Age  Fare Sex    Survived final_preds
#>    <dbl> <dbl> <chr>     <int>       <dbl>
#>  1    22  7.25 male          0       0.152
#>  2    38 71.3  female        1       0.817
#>  3    26  7.92 female        1       0.687
#>  4    35 53.1  female        1       0.756
#>  5    35  8.05 male          0       0.136
#>  6    NA  8.46 male          0       0.157
#>  7    54 51.9  male          0       0.277
#>  8     2 21.1  male          0       0.261
#>  9    27 11.1  female        1       0.686
#> 10    14 30.1  female        1       0.742
#> # ... with 881 more rows

titanic_test
#> # A tibble: 418 x 4
#>      Age  Fare Sex    final_preds
#>    <dbl> <dbl> <chr>        <dbl>
#>  1  34.5  7.83 male         0.135
#>  2  47    7    female       0.629
#>  3  62    9.69 male         0.107
#>  4  27    8.66 male         0.150
#>  5  22   12.3  female       0.700
#>  6  14    9.22 male         0.174
#>  7  30    7.63 female       0.676
#>  8  26   29    male         0.247
#>  9  18    7.23 female       0.707
#> 10  21   24.2  male         0.234
#> # ... with 408 more rows
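
The five steps above can also be wrapped into one reusable function. What follows is only a minimal sketch in base R and not part of the original reprex: the names `fit_thin_full`, `predict_final` and `toy` are hypothetical, and the data is simulated (with Age missing completely at random, which is an assumption made purely for illustration) so that the block runs without the titanic package.

```r
# Hypothetical helper (not from the thread): fits the "thin" and "full"
# logistic regressions and returns a function that makes final predictions.
fit_thin_full <- function(train, outcome, complete_vars, partial_var) {
  # Step 1: thin model on all rows, using only fully observed predictors
  thin_formula <- reformulate(complete_vars, response = outcome)
  thin_model <- glm(thin_formula, data = train, family = binomial())

  # Step 2: thin-model predictions become an extra predictor
  train$thin_model_preds <- predict(thin_model, newdata = train,
                                    type = "response")

  # Step 3: full model; with the default na.action (na.omit), glm() drops
  # the rows where partial_var is NA
  full_formula <- reformulate(
    c(complete_vars, partial_var, "thin_model_preds"),
    response = outcome
  )
  full_model <- glm(full_formula, data = train, family = binomial())

  # Steps 4-5: full-model predictions are NA where partial_var is NA,
  # so fall back to the thin-model predictions for those rows
  function(new_data) {
    new_data$thin_model_preds <- predict(thin_model, newdata = new_data,
                                         type = "response")
    full_preds <- predict(full_model, newdata = new_data, type = "response")
    unname(ifelse(is.na(full_preds), new_data$thin_model_preds, full_preds))
  }
}

# Simulated stand-in for the titanic data (an assumption for illustration)
set.seed(1)
n <- 200
toy <- data.frame(
  fare = rexp(n, rate = 1 / 30),
  sex  = factor(sample(c("female", "male"), n, replace = TRUE)),
  age  = round(runif(n, 1, 70))
)
toy$survived <- rbinom(
  n, 1,
  plogis(1 - 1.5 * (toy$sex == "male") + 0.01 * toy$fare - 0.02 * toy$age)
)
toy$age[sample(n, 40)] <- NA  # 40 rows lose their age value

predict_final <- fit_thin_full(toy, "survived", c("fare", "sex"), "age")
final_preds <- predict_final(toy)
```

Returning a closure keeps both fitted models together, so the same `predict_final()` can be applied to the train set and to any new data that has the same columns.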

uriahf avatar Mar 08 '22 09:03 uriahf

I'm going to move this to our planning repo so we can collect/track interest in an approach like this.

juliasilge avatar Mar 09 '22 16:03 juliasilge

Thank you so much!

uriahf avatar Mar 09 '22 16:03 uriahf

I notice that there is already an R package for implementing transfer learning: glmtrans. I'm not sure if this would make incorporating transfer learning in tidymodels a bit easier.

caimiao0714 avatar Mar 10 '23 14:03 caimiao0714