
predict from ranger (more modern RF)

Open: m-r-munroe opened this issue 3 years ago (0 comments)

I'm trying to approximately reproduce Figure 1 from the arXiv paper. The CCF seems to work well with 35 trees.

This is what the reference figure looks like:
[image: reference figure, Figure 1 of the arXiv paper]

A decent random forest (RF) implementation that isn't 20 years old is ranger.

To get a decent plot using the canonical correlation forest I use this:

require(pacman)

p_load_gh("jandob/ccf")
p_load(tidyverse, 
       ranger,
       ggplot2,
       GGally)

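# parameter: number of trees in each forest
num_trees <- 35

# load sample dataset
data(spirals)

# split data: first 1000 rows for training, last 1000 for testing
# (the two sets only avoid overlapping if spirals has at least 2000 rows)
d_train <- spirals[1:1000, ]
d_test <- tail(spirals, 1000)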



# fit a canonical correlation forest on the training data
ccf_f <- canonical_correlation_forest(class ~ ., d_train, ntree = num_trees)

# plot the decision surface of the classifier
ccf_plot <- plot_decision_surface(model = ccf_f, 
                                  X = d_test[, c("x", "y")], 
                                  Y = d_test$class,
                                  title = str_c("CCF with ", num_trees, " trees"))

print(ccf_plot)

to make this:
[image: CCF decision surface, "CCF with 35 trees"]

When I use the following code with ranger, it fails:

rngr_fit <- ranger(class ~ ., data = d_train, num.trees = num_trees, classification = TRUE)
rngr_plot <- plot_decision_surface(model = rngr_fit,
                                   X = d_test[, c("x", "y")],
                                   Y = d_test$class,
                                   title = str_c("Ranger with ", num_trees, " trees"))

print(rngr_plot)

The output is this:

> rngr_fit <- ranger(class ~ ., data = d_train, num.tree = num_trees, classification = TRUE)
> rngr_plot <- plot_decision_surface(model = rngr_fit, 
+                                   X = d_test[, c("x", "y")], 
+                                   Y = d_test$class,
+                                   title = str_c("Ranger with ", num_trees, " trees"))
Error in as.data.frame.default(x[[i]], optional = TRUE, stringsAsFactors = stringsAsFactors) : 
  cannot coerce class ‘"ranger.prediction"’ to a data.frame
> print(rngr_plot)
Error in print(rngr_plot) : object 'rngr_plot' not found

plot_decision_surface fails on ranger's prediction output, which is a list of class "ranger.prediction" rather than a vector of labels. This is the structure of the prediction object:

> str(predict(rngr_fit, data = d_test))
List of 5
 $ predictions              : num [1:1000] 3 3 2 3 2 3 1 2 2 1 ...
 $ num.trees                : num 35
 $ num.independent.variables: num 2
 $ num.samples              : int 1000
 $ treetype                 : chr "Classification"
 - attr(*, "class")= chr "ranger.prediction"

Its $predictions element, converted to character, looks like this:

> str(predict(rngr_fit, data = d_test)$predictions %>% as.character())
 chr [1:1000] "3" "3" "2" "3" "2" "3" "1" "2" "2" "1" "1" "3" "2" "3" "3" "3" "1" "2" "3" "1" "1" "2" "2" "1" "3" "1" "3" ...

That has the same form as what predict() returns for the ccf model:

> str(predict(ccf_f, newdata = d_test))
 chr [1:1000] "3" "3" "2" "3" "2" "3" "1" "2" "2" "1" "1" "3" "2" "3" "3" "3" "1" "2" "3" "1" "1" "2" "2" "1" "3" "1" "3" ...
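Since both forests end up giving one label per row, a small helper can hide the difference between them. This is a sketch only: as_labels is a hypothetical function (not part of ccf or ranger), and it assumes the labels live in $predictions, as the str() output above shows.

```r
# Hypothetical helper (not part of ccf or ranger): turn the output of
# predict() into a plain character vector of class labels.
as_labels <- function(pred) {
  # ranger returns a "ranger.prediction" list; the labels are in $predictions
  if (inherits(pred, "ranger.prediction")) {
    pred <- pred$predictions
  }
  # factors, numerics and characters all become character labels
  as.character(pred)
}

# Usage, with the fitted models and test set from above:
# rngr_labels <- as_labels(predict(rngr_fit, data = d_test))
# ccf_labels  <- as_labels(predict(ccf_f, newdata = d_test))
# mean(rngr_labels == ccf_labels)  # share of test points where the forests agree
```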

Printing plot_decision_surface shows what it does:

function (model, X, Y, title = NULL, interpolate = FALSE, ...) 
{
   data <- data.frame(x = X[, 1], y = X[, 2], z = Y)
   x_min <- min(data$x) * 1.2
   x_max <- max(data$x) * 1.2
   y_min <- min(data$y) * 1.2
   y_max <- max(data$y) * 1.2
   resolution <- 400
   grid <- expand.grid(x = seq(x_min, x_max, length.out = resolution), 
      y = seq(y_min, y_max, length.out = resolution))
   predictions <- predict(model, grid, ...)
   data_raster <- data.frame(x = grid$x, y = grid$y, z = predictions)
   plot_object <- generate_2d_data_plot(data, data_raster, 
      interpolate = interpolate, title = title)
   return(plot_object)
}
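The only thing this function needs from the model is that predict(model, grid, ...) returns one label per grid row as a plain vector; the ranger.prediction list is what data.frame() cannot coerce. A minimal sketch of the same grid step, with the prediction call passed in as a function, makes that contract explicit. decision_grid and predict_fun are hypothetical names, and the 1.2 padding and 400-point resolution simply mirror the printed source.

```r
# Hypothetical re-implementation of the grid step, with prediction
# delegated to a user-supplied function predict_fun(grid).
decision_grid <- function(X, predict_fun, resolution = 400) {
  # same padding rule as the printed source: multiply min and max by 1.2
  x_rng <- range(X[, 1]) * 1.2
  y_rng <- range(X[, 2]) * 1.2
  grid <- expand.grid(x = seq(x_rng[1], x_rng[2], length.out = resolution),
                      y = seq(y_rng[1], y_rng[2], length.out = resolution))
  z <- predict_fun(grid)
  # the z column must be a plain vector with one label per grid point
  stopifnot(is.atomic(z), length(z) == nrow(grid))
  data.frame(x = grid$x, y = grid$y, z = z)
}

# e.g. with the ranger fit from above:
# z_grid <- decision_grid(d_test[, c("x", "y")],
#                         function(g) as.character(predict(rngr_fit, data = g)$predictions))
```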

If the predictions line in plot_decision_surface is changed to the following, it might work:

   if (inherits(model, "ranger")) {
      # ranger's predict() returns a "ranger.prediction" list; keep only the labels
      predictions <- as.character(predict(model, grid, ...)$predictions)
   } else {
      predictions <- predict(model, grid, ...)
   }
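An alternative that avoids editing the package is to wrap the ranger fit in a small S3 class whose predict() method returns labels, so the unmodified plot_decision_surface can keep calling predict(model, grid, ...). This is a sketch that relies on S3 dispatch from inside the ccf namespace finding a method defined in the global environment; the class ranger_labels and the constructor as_ranger_labels are hypothetical names.

```r
# Hypothetical wrapper class (not part of ccf or ranger).
as_ranger_labels <- function(fit) {
  structure(list(fit = fit), class = "ranger_labels")
}

# S3 method that predict(model, grid, ...) inside plot_decision_surface()
# dispatches to; grid arrives as the second (newdata) argument.
predict.ranger_labels <- function(object, newdata, ...) {
  # predict.ranger() takes `data`, not `newdata`, and returns a list
  pred <- predict(object$fit, data = newdata, ...)
  as.character(pred$predictions)
}

# Usage, with rngr_fit and d_test from above:
# rngr_plot <- plot_decision_surface(model = as_ranger_labels(rngr_fit),
#                                    X = d_test[, c("x", "y")],
#                                    Y = d_test$class,
#                                    title = str_c("Ranger with ", num_trees, " trees"))
```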

If this produces a ggplot object, I think I can use GGally to arrange a grid like Figure 1. Decision surfaces could also be an interesting and useful diagnostic in a scatterplot-matrix layout, perhaps similar to GGally::ggpairs.

m-r-munroe, Dec 08 '21 20:12