tidypredict
tidypredict copied to clipboard
Ranger produces list of trees
Hello, thanks for your work on this package, it is very exciting! I was trying to follow the docs on using a ranger RF model, but it seems to return a list of trees (`case_when()` expressions) rather than one statement. Is it intended that we execute all the trees on the DB and then calculate the prediction from the results? I don't get that impression from the docs. Thanks!
# Reprex: fit a 100-tree ranger classifier on iris and translate it with
# tidypredict_fit().
# NOTE(review): no seed is set before ranger(), so the trees printed below
# will presumably differ from run to run.
library(ranger)
library(tidypredict)
library(dplyr, warn.conflicts = FALSE)
test_mod <- ranger(Species ~ ., iris, num.trees = 100)
trees <- tidypredict_fit(test_mod)
# The result is a list with one case_when() expression per tree (100 here),
# not a single combined expression.
str(trees, max.level = 1, list.len = 3)
#> List of 100
#> $ : language case_when(Petal.Width < 0.8 ~ "setosa", Sepal.Length < 5.75 & Petal.Width >= 0.8 ~ "versicolor", Petal.Width| __truncated__ ...
#> $ : language case_when(Petal.Length < 2.45 ~ "setosa", Petal.Width >= 1.7 & Petal.Length >= 2.45 ~ "virginica", Petal.Len| __truncated__ ...
#> $ : language case_when(Petal.Width < 0.8 ~ "setosa", Petal.Length < 4.9 & Petal.Width < 1.75 & Petal.Width >= 0.8 ~ "vers| __truncated__ ...
#> [list output truncated]
# Print the expression generated for the first tree.
trees[[1]]
#> case_when(Petal.Width < 0.8 ~ "setosa", Sepal.Length < 5.75 &
#> Petal.Width >= 0.8 ~ "versicolor", Petal.Width >= 1.75 &
#> Sepal.Length >= 5.75 & Petal.Width >= 0.8 ~ "virginica",
#> Petal.Length < 4.75 & Sepal.Width < 2.25 & Petal.Width <
#> 1.75 & Sepal.Length >= 5.75 & Petal.Width >= 0.8 ~ "versicolor",
#> Petal.Length >= 4.75 & Sepal.Width < 2.25 & Petal.Width <
#> 1.75 & Sepal.Length >= 5.75 & Petal.Width >= 0.8 ~ "virginica",
#> Petal.Width < 1.55 & Sepal.Width >= 2.25 & Petal.Width <
#> 1.75 & Sepal.Length >= 5.75 & Petal.Width >= 0.8 ~ "versicolor",
#> Petal.Width >= 1.65 & Petal.Width >= 1.55 & Sepal.Width >=
#> 2.25 & Petal.Width < 1.75 & Sepal.Length >= 5.75 & Petal.Width >=
#> 0.8 ~ "versicolor", Petal.Length < 5.45 & Petal.Width <
#> 1.65 & Petal.Width >= 1.55 & Sepal.Width >= 2.25 & Petal.Width <
#> 1.75 & Sepal.Length >= 5.75 & Petal.Width >= 0.8 ~ "versicolor",
#> Petal.Length >= 5.45 & Petal.Width < 1.65 & Petal.Width >=
#> 1.55 & Sepal.Width >= 2.25 & Petal.Width < 1.75 & Sepal.Length >=
#> 5.75 & Petal.Width >= 0.8 ~ "virginica")
# The workaround suggested in an older issue fails here:
# tidypredict_to_column() rejects tree-based models (see the error below).
iris %>%
tidypredict_to_column(test_mod)
#> Error in tidypredict_to_column(., test_mod): tidypredict_to_column does not support tree based models
Created on 2020-08-23 by the reprex package (v0.3.0)
I looked at the documentation and agree that it needs to be revised.
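I think the intention was to do some dplyr work afterwards to get the predictions into the format you want. In other words, yes: each tree's expression is evaluated on the new data, which gives one vote per tree, and the votes are then aggregated. Here's some code that does this using dplyr, purrr, and tidyr: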
# Score new samples with every tree expression returned by tidypredict_fit(),
# then aggregate the per-tree votes into class predictions and probabilities.
library(ranger)
library(tidypredict)
library(dplyr, warn.conflicts = FALSE)
test_mod <- ranger(Species ~ ., iris, num.trees = 100)
# A list with one case_when() expression per tree.
trees <- tidypredict_fit(test_mod)
n_trees <- length(trees)
new_samples <- iris[c(1, 51, 101), ]
# seq_len() rather than 1:nrow(): stays correct if new_samples has no rows.
row_ids <- seq_len(nrow(new_samples))
# One row per (tree, sample) holding the class that tree votes for.
# map_dfr() is exported, so use `::` -- `:::` is only for internal objects.
votes <-
  purrr::map_dfr(
    trees,
    function(tree) {
      tibble(.pred = rlang::eval_tidy(tree, new_samples), .row = row_ids)
    }
  )
# Number of trees voting for each class, per sample. The count column is
# named explicitly so it is not confused with slice_max()'s `n` argument.
vote_counts <-
  votes %>%
  count(.row, .pred, name = "n_votes")
# Majority vote. with_ties = FALSE guarantees exactly one prediction per
# sample; on a tie, the class listed first in vote_counts wins.
class_pred <-
  vote_counts %>%
  group_by(.row) %>%
  slice_max(n_votes, n = 1, with_ties = FALSE) %>%
  ungroup() %>%
  select(-n_votes)
class_pred
#> # A tibble: 3 x 2
#> .row .pred
#> <int> <chr>
#> 1 1 setosa
#> 2 2 versicolor
#> 3 3 virginica
# Class probability = share of trees voting for that class. Dividing by
# n_trees instead of a literal 100 keeps this correct if num.trees changes.
class_prob <-
  vote_counts %>%
  mutate(prob = n_votes / n_trees) %>%
  select(-n_votes) %>%
  tidyr::pivot_wider(
    id_cols = ".row",
    names_from = ".pred",
    values_from = "prob",
    values_fill = 0
  )
class_prob
#> # A tibble: 3 x 4
#> .row setosa versicolor virginica
#> <int> <dbl> <dbl> <dbl>
#> 1 1 1 0 0
#> 2 2 0 0.98 0.02
#> 3 3 0 0 1
Created on 2020-12-04 by the reprex package (v0.3.0)