
Add permutation importance for GRFs

Open dswatson opened this issue 7 years ago • 6 comments

Following randomForest and ranger, it would be nice to implement Breiman & Cutler's permutation importance metric. This would just require calculating, for each tree, the difference in OOB error between a feature's permuted and real observations, and then averaging that difference across all trees. This value may optionally be scaled by the standard deviation of the per-tree differences to get something like a z-score.
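To make the idea concrete, here is a minimal base R sketch. It is not grf code: it uses a plain `lm` fit and a held-out set as a stand-in for per-tree OOB error, and `perm_importance` and `mse` are hypothetical helper names. The "scaled" row divides by the standard deviation over repeated permutations, which is only analogous to scaling over trees.

```r
# Model-agnostic permutation importance on held-out data
# (a stand-in for OOB error; assumption, not the grf or randomForest implementation).
set.seed(1)
n <- 400
X <- data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n))
y <- 2 * X$x1 + 0.5 * X$x2 + rnorm(n)   # x3 is pure noise

train <- 1:200
test  <- 201:400
fit <- lm(y ~ ., data = cbind(X, y = y)[train, ])

mse <- function(model, newdata, truth) mean((truth - predict(model, newdata))^2)

perm_importance <- function(model, newdata, truth, n_rep = 50) {
  base_err <- mse(model, newdata, truth)
  sapply(names(newdata), function(v) {
    diffs <- replicate(n_rep, {
      shuffled <- newdata
      shuffled[[v]] <- sample(shuffled[[v]])   # break the link between v and y
      mse(model, shuffled, truth) - base_err   # increase in error after permuting v
    })
    c(unscaled = mean(diffs), scaled = mean(diffs) / sd(diffs))
  })
}

imp <- perm_importance(fit, X[test, ], y[test])
print(round(imp, 3))
```

In this toy setup x1 should dominate and x3 should sit near zero on the unscaled row.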

dswatson • Nov 08 '17 16:11

For a little background: Breiman & Cutler introduced permutation importance in their random forest classification manual. Strobl & Zeileis wrote an interesting article a few years later advocating the unscaled version over the scaled.

It would be fairly simple to code this in R if we could make predictions with a grf_tree object, but I don't see how to do that. In any event, it would certainly be more efficient in C++. The authors of ranger appear to have done this in their package, but I'm not sure how easy it would be to incorporate that work into what's already been built here. If anyone's curious, you can view their implementation by searching for "IMP_PERM" in Forest.cpp.

dswatson • Nov 09 '17 14:11

An alternative is to permute the dependent variable a number of times to obtain a null distribution for the importance measure, and then obtain a p-value. This approach is used by Altmann et al. (2010) for random forests, and by Bleich et al. (2014) for BART.
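As a rough sketch of the final step, the p-value is just the share of null importances that reach the observed one. The helper name `perm_pvalue` and the toy numbers are made up for illustration. The optional `add_one` correction is a common convention that keeps the p-value away from exactly zero; it is my assumption, not something the simulation code in this thread does.

```r
# observed: named vector of importances from the real outcome.
# null_mat: one row per variable, one column per refit on a permuted outcome.
perm_pvalue <- function(observed, null_mat, add_one = TRUE) {
  stopifnot(length(observed) == nrow(null_mat))
  # observed is recycled down each column, so row i is compared with observed[i]
  exceed <- rowSums(null_mat >= observed)
  if (add_one) (exceed + 1) / (ncol(null_mat) + 1) else exceed / ncol(null_mat)
}

# Toy null distribution: 99 permutations, importances uniform on [0, 1].
set.seed(2)
null_mat <- matrix(runif(5 * 99), nrow = 5,
                   dimnames = list(paste0("x", 1:5), NULL))
observed <- c(x1 = 0.5, x2 = 2, x3 = 0.1, x4 = 0.9, x5 = 0.3)

p <- perm_pvalue(observed, null_mat)
print(p)   # x2 exceeds every null draw, so it gets 1 / 100 = 0.01
```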

I include a simple example for grf below, where none of the variables are truly important. The continuous variable generally has a greater variable importance than the binary variable, because it has more potential splitting points. Similarly, the variable with four categories generally has a greater variable importance than the binary variable: a categorical variable is entered as one binary indicator per category, so more categories means more potential splitting variables. However, for the variable with 20 categories, relatively few observations fall in any one category, so a split on the indicator for one of these categories is unlikely to lead to a large improvement in the splitting criterion. The variable importances are therefore small for the variable with 20 categories. This is related to issue #310. If grf supported categorical variables, then x5 would have greater variable importance than x4.
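The indicator coding and the per-variable summing described above can be sketched without hard-coded column ranges. This is only an illustration: `col_imp` is a random placeholder where the per-column importances from a fitted forest would go, and the `blocks` and `group` names are my own.

```r
set.seed(3)
n <- 200
df <- data.frame(
  x1 = rnorm(n),
  x2 = factor(sample(0:1, n, replace = TRUE)),
  x3 = factor(sample(0:3, n, replace = TRUE))
)

# One indicator column per factor level (no reference level dropped).
blocks <- lapply(names(df), function(v) {
  if (is.factor(df[[v]])) {
    m <- model.matrix(~ f + 0, data = data.frame(f = df[[v]]))
    colnames(m) <- paste0(v, "_", levels(df[[v]]))
    m
  } else {
    matrix(df[[v]], ncol = 1, dimnames = list(NULL, v))
  }
})
X_cov <- do.call(cbind, blocks)

# Record which original variable each column of X_cov came from.
group <- rep(names(df), times = vapply(blocks, ncol, integer(1)))

# Placeholder importances, one per column of X_cov (assumption: in practice
# these would come from the fitted forest's variable importance output).
col_imp <- runif(ncol(X_cov))
col_imp <- col_imp / sum(col_imp)

# Sum the indicator-level importances back to the original variables.
var_imp <- tapply(col_imp, factor(group, levels = names(df)), sum)
print(var_imp)
```

Tracking `group` alongside the design matrix avoids index ranges such as 2:3 or 18:37 going stale when the number of categories changes.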

The p-values appear to be unaffected by these biases. I include the code for obtaining the p-values for three different simulations in a separate comment below. Note that this code can take several hours to run.

Altmann, A., Toloşi, L., Sander, O., & Lengauer, T. (2010). Permutation importance: a corrected feature importance measure. Bioinformatics, 26(10), 1340-1347.

Bleich, J., Kapelner, A., George, E. I., & Jensen, S. T. (2014). Variable selection for BART: an application to gene regulation. The Annals of Applied Statistics, 1750-1781.


library(grf)
library(data.table)
library(ggplot2)
set.seed(11)

true_var_split_imps_null <- list()
######
outer_iterations <- 100
start.time <- Sys.time()
for (k in 1:outer_iterations){
  #########################################################################################
  N <- 500
  
  x1 <- rnorm(n = N) 
  x2 <- sample( 0:1, size= N, replace = TRUE, prob = rep(1,2)/2)
  x3 <- sample( 0:3, size= N, replace = TRUE, prob = rep(1,4)/4)
  x4 <- sample( 0:9, size= N, replace = TRUE, prob = rep(1,10)/10)
  x5 <- sample( 0:19, size= N, replace = TRUE, prob = rep(1,20)/20)
  x2_factor <- as.factor(x2)
  x3_factor <- as.factor(x3)
  x4_factor <- as.factor(x4)
  x5_factor <- as.factor(x5)
  
  Treatment_var <- sample( 0:1, size= N, replace = TRUE, prob = rep(1,2)/2)
  
  y_null<- rnorm(N)

  X_cov <- cbind(x1,
                 model.matrix( ~ x2_factor +0 ),
                 model.matrix( ~ x3_factor +0 ),
                 model.matrix( ~ x4_factor +0 ),
                 model.matrix( ~ x5_factor +0 ))

############################################################################################################################

  tau.forest = causal_forest(as.matrix(X_cov),
                             y_null,
                             Treatment_var, 
                             num.trees = 5000,
                             sample.fraction = 0.5,
                             mtry = floor(ncol(X_cov)/3),
                             min.node.size = 5,
                             honesty = TRUE,
                             ci.group.size = 2)
  
  var_split_imps <-   variable_importance(tau.forest)
  rownames(var_split_imps) <- colnames(X_cov)
  
  var_split_imps_sum_cats <- c(var_split_imps[1], 
                               sum(var_split_imps[2:3 ]),
                               sum(var_split_imps[4:7 ]),
                               sum(var_split_imps[8:17 ]),
                               sum(var_split_imps[18:37 ]))
  
  names(var_split_imps_sum_cats) <- c("x1","x2","x3","x4","x5")
  true_var_split_imps_null[[k]] <- var_split_imps_sum_cats
  
  ###############################################################
}

end.time <- Sys.time()
time.taken <- end.time - start.time
time.taken

#convert lists to matrices
mat_true_var_split_imps_null <- do.call(cbind, true_var_split_imps_null )

boxplot(mat_true_var_split_imps_null , use.cols = FALSE)
title("grf var. imp., simulation 1, 100 iterations")

EoghanONeill • Oct 18 '18 16:10

Below is code for a simulation study of variable importance tests based on permutation of the dependent variable. It takes several hours to run. Perhaps the authors of the grf package can provide a faster implementation of this test? Further simulation studies might be required. In simulation 1, none of the variables are truly important. In simulations 2 and 3, the binary variable is important.


library(grf)
library(data.table)
library(ggplot2)
set.seed(11)

#create lists in which to put p-values
pvals_var_split_imps_null <- list()
pvals_var_split_imps_AI_1 <- list()
pvals_var_split_imps_AI_2 <- list()

true_var_split_imps_null <- list()
true_var_split_imps_AI_1 <- list()
true_var_split_imps_AI_2 <- list()
######
#Outer for loop (obtain outer_iterations p-values for boxplot)
outer_iterations <- 100
start.time <- Sys.time()

for (k in 1:outer_iterations){
  #########################################################################################
  N <- 500
  
  x1 <- rnorm(n = N) 
  x2 <- sample( 0:1, size= N, replace = TRUE, prob = rep(1,2)/2)
  x3 <- sample( 0:3, size= N, replace = TRUE, prob = rep(1,4)/4)
  x4 <- sample( 0:9, size= N, replace = TRUE, prob = rep(1,10)/10)
  x5 <- sample( 0:19, size= N, replace = TRUE, prob = rep(1,20)/20)
  x2_factor <- as.factor(x2)
  x3_factor <- as.factor(x3)
  x4_factor <- as.factor(x4)
  x5_factor <- as.factor(x5)
  
  Treatment_var <- sample( 0:1, size= N, replace = TRUE, prob = rep(1,2)/2)
  
  y_null<- rnorm(N)
  y_AI_1 <- 0.5*(2*Treatment_var-1)*x2+rnorm(N)
  y_AI_2 <- 0.5*x1+x2+ 0.5*(2*Treatment_var-1)*x2 + rnorm(N)
 
 
  X_cov <- cbind(x1,
                 model.matrix( ~ x2_factor +0 ),
                 model.matrix( ~ x3_factor +0 ),
                 model.matrix( ~ x4_factor +0 ),
                 model.matrix( ~ x5_factor +0 ))
  
  
  ######################################################################################################################
  #####################################################################################################################
  ############################################################################################################################

  tau.forest = causal_forest(as.matrix(X_cov),
                             y_null,
                             Treatment_var, 
                             num.trees = 5000,
                             sample.fraction = 0.5,
                             mtry = floor(ncol(X_cov)/3),
                             min.node.size = 5,
                             honesty = TRUE,
                             ci.group.size = 2)
  
  var_split_imps <-   variable_importance(tau.forest)
  rownames(var_split_imps) <- colnames(X_cov)
  
  var_split_imps_sum_cats <- c(var_split_imps[1], 
                               sum(var_split_imps[2:3 ]),
                               sum(var_split_imps[4:7 ]),
                               sum(var_split_imps[8:17 ]),
                               sum(var_split_imps[18:37 ]))
  
  names(var_split_imps_sum_cats) <- c("x1","x2","x3","x4","x5")
  true_var_split_imps_null[[k]] <- var_split_imps_sum_cats
  
  ###############################################################
  
  num_iter <- 100
  perm_var_split_imps <- list()
  
  for (j in 1:num_iter){
    permutedY <- sample(x = y_null, size = length(y_null), replace = FALSE, prob = NULL)
    
    tau.forest = causal_forest(as.matrix(X_cov),
                               permutedY,
                               Treatment_var, 
                               num.trees = 5000,
                               sample.fraction = 0.5,
                               mtry = floor(ncol(X_cov)/3),
                               min.node.size = 5,
                               honesty = TRUE,
                               ci.group.size = 2
    )
    perm_var_split_imps[[j]] <-   variable_importance(tau.forest)
    rownames(perm_var_split_imps[[j]]) <- colnames(X_cov)
    gc()
  }
  
  mat_perm_var_split_imps <- do.call(cbind, perm_var_split_imps)
  
  mat_perm_var_split_imps_sum_cats <- rbind(mat_perm_var_split_imps[1, ], 
                                            colSums(mat_perm_var_split_imps[2:3, ], na.rm = FALSE, dims = 1),
                                            colSums(mat_perm_var_split_imps[4:7, ], na.rm = FALSE, dims = 1),
                                            colSums(mat_perm_var_split_imps[8:17, ], na.rm = FALSE, dims = 1),
                                            colSums(mat_perm_var_split_imps[18:37, ], na.rm = FALSE, dims = 1))
  
  rownames(mat_perm_var_split_imps_sum_cats) <- c("x1","x2","x3","x4","x5")
  
  test_bind <- cbind(var_split_imps_sum_cats,mat_perm_var_split_imps_sum_cats)
  pvals_var_split_imps_null[[k]] <-  apply(test_bind, 1, function(x) sum(x[1] < x[2:ncol(test_bind)])/num_iter)
  
  ######################################################################################################################
  #####################################################################################################################
  ############################################################################################################################

  tau.forest = causal_forest(as.matrix(X_cov),
                             y_AI_1,
                             Treatment_var, 
                             num.trees = 5000,
                             sample.fraction = 0.5,
                             mtry = floor(ncol(X_cov)/3),
                             min.node.size = 5,
                             honesty = TRUE,
                             ci.group.size = 2)
  
  var_split_imps <-   variable_importance(tau.forest)
  rownames(var_split_imps) <- colnames(X_cov)
  
  var_split_imps_sum_cats <- c(var_split_imps[1], 
                               sum(var_split_imps[2:3 ]),
                               sum(var_split_imps[4:7 ]),
                               sum(var_split_imps[8:17 ]),
                               sum(var_split_imps[18:37 ]))
  
  names(var_split_imps_sum_cats) <- c("x1","x2","x3","x4","x5")
  true_var_split_imps_AI_1[[k]] <- var_split_imps_sum_cats
  
  
  ###############################################################
  
  num_iter <- 100
  perm_var_split_imps <- list()
  
  for (j in 1:num_iter){
    permutedY <- sample(x = y_AI_1, size = length(y_AI_1), replace = FALSE, prob = NULL)
    
    tau.forest = causal_forest(as.matrix(X_cov),
                               permutedY,
                               Treatment_var, 
                               num.trees = 5000,
                               sample.fraction = 0.5,
                               mtry = floor(ncol(X_cov)/3),
                               min.node.size = 5,
                               honesty = TRUE,
                               ci.group.size = 2
    )
    perm_var_split_imps[[j]] <-   variable_importance(tau.forest)
    rownames(perm_var_split_imps[[j]]) <- colnames(X_cov)
    gc()
  }
  
  
  mat_perm_var_split_imps <- do.call(cbind, perm_var_split_imps)
  
  mat_perm_var_split_imps_sum_cats <- rbind(mat_perm_var_split_imps[1, ], 
                                            colSums(mat_perm_var_split_imps[2:3, ], na.rm = FALSE, dims = 1),
                                            colSums(mat_perm_var_split_imps[4:7, ], na.rm = FALSE, dims = 1),
                                            colSums(mat_perm_var_split_imps[8:17, ], na.rm = FALSE, dims = 1),
                                            colSums(mat_perm_var_split_imps[18:37, ], na.rm = FALSE, dims = 1))
  
  rownames(mat_perm_var_split_imps_sum_cats) <- c("x1","x2","x3","x4","x5")
  
  test_bind <- cbind(var_split_imps_sum_cats,mat_perm_var_split_imps_sum_cats)
  pvals_var_split_imps_AI_1[[k]] <-  apply(test_bind, 1, function(x) sum(x[1] < x[2:ncol(test_bind)])/num_iter)
  
  ######################################################################################################################
  #####################################################################################################################
  ############################################################################################################################
  
  tau.forest = causal_forest(as.matrix(X_cov),
                             y_AI_2,
                             Treatment_var, 
                             num.trees = 5000,
                             sample.fraction = 0.5,
                             mtry = floor(ncol(X_cov)/3),
                             min.node.size = 5,
                             honesty = TRUE,
                             ci.group.size = 2)
  
  var_split_imps <-   variable_importance(tau.forest)
  rownames(var_split_imps) <- colnames(X_cov)
  
  var_split_imps_sum_cats <- c(var_split_imps[1], 
                               sum(var_split_imps[2:3 ]),
                               sum(var_split_imps[4:7 ]),
                               sum(var_split_imps[8:17 ]),
                               sum(var_split_imps[18:37 ]))
  
  names(var_split_imps_sum_cats) <- c("x1","x2","x3","x4","x5")
  true_var_split_imps_AI_2[[k]] <- var_split_imps_sum_cats
  
  
  ###############################################################
  
  num_iter <- 100
  perm_var_split_imps <- list()
  
  for (j in 1:num_iter){
    permutedY <- sample(x = y_AI_2, size = length(y_AI_2), replace = FALSE, prob = NULL)
    
    tau.forest = causal_forest(as.matrix(X_cov),
                               permutedY,
                               Treatment_var, 
                               num.trees = 5000,
                               sample.fraction = 0.5,
                               mtry = floor(ncol(X_cov)/3),
                               min.node.size = 5,
                               honesty = TRUE,
                               ci.group.size = 2
    )
    perm_var_split_imps[[j]] <-   variable_importance(tau.forest)
    rownames(perm_var_split_imps[[j]]) <- colnames(X_cov)
    gc()
  }
  
  
  mat_perm_var_split_imps <- do.call(cbind, perm_var_split_imps)
  
  mat_perm_var_split_imps_sum_cats <- rbind(mat_perm_var_split_imps[1, ], 
                                            colSums(mat_perm_var_split_imps[2:3, ], na.rm = FALSE, dims = 1),
                                            colSums(mat_perm_var_split_imps[4:7, ], na.rm = FALSE, dims = 1),
                                            colSums(mat_perm_var_split_imps[8:17, ], na.rm = FALSE, dims = 1),
                                            colSums(mat_perm_var_split_imps[18:37, ], na.rm = FALSE, dims = 1))
  
  rownames(mat_perm_var_split_imps_sum_cats) <- c("x1","x2","x3","x4","x5")
  
  test_bind <- cbind(var_split_imps_sum_cats,mat_perm_var_split_imps_sum_cats)
  pvals_var_split_imps_AI_2[[k]]  <-  apply(test_bind, 1, function(x) sum(x[1] < x[2:ncol(test_bind)])/num_iter)
  
  ######################################################################################################################
  #####################################################################################################################
  ############################################################################################################################
  
  
}

end.time <- Sys.time()
time.taken <- end.time - start.time
time.taken

#convert lists to matrices
mat_pvals_var_split_imps_null  <- do.call(cbind, pvals_var_split_imps_null )
mat_pvals_var_split_imps_AI_1  <- do.call(cbind, pvals_var_split_imps_AI_1 )
mat_pvals_var_split_imps_AI_2  <- do.call(cbind, pvals_var_split_imps_AI_2 )
mat_true_var_split_imps_null <- do.call(cbind, true_var_split_imps_null )
mat_true_var_split_imps_AI_1 <- do.call(cbind, true_var_split_imps_AI_1 )
mat_true_var_split_imps_AI_2 <- do.call(cbind, true_var_split_imps_AI_2 )

# Each matrix has one column per outer iteration, so divide by outer_iterations
# (not num_iter, which is the number of permutations inside each iteration).
rate_correct_pvals_var_split_imps_null <- sum(apply(mat_pvals_var_split_imps_null, 2, FUN = function(x){ if(which.min(x)==2) 1 else 0}))/outer_iterations 
rate_correct_pvals_var_split_imps_AI_1 <- sum(apply(mat_pvals_var_split_imps_AI_1, 2, FUN = function(x){ if(which.min(x)==2) 1 else 0}))/outer_iterations 
rate_correct_pvals_var_split_imps_AI_2 <- sum(apply(mat_pvals_var_split_imps_AI_2, 2, FUN = function(x){ if(which.min(x)==2) 1 else 0}))/outer_iterations 

rate_correct_true_var_split_imps_null <- sum(apply(mat_true_var_split_imps_null, 2, FUN = function(x){ if(which.max(x)==2) 1 else 0}))/outer_iterations 
rate_correct_true_var_split_imps_AI_1 <- sum(apply(mat_true_var_split_imps_AI_1, 2, FUN = function(x){ if(which.max(x)==2) 1 else 0}))/outer_iterations 
rate_correct_true_var_split_imps_AI_2 <- sum(apply(mat_true_var_split_imps_AI_2, 2, FUN = function(x){ if(which.max(x)==2) 1 else 0}))/outer_iterations 

boxplot(mat_pvals_var_split_imps_null, use.cols = FALSE)
title("Bleich local test p-values, simulation 1")

boxplot(mat_pvals_var_split_imps_AI_1, use.cols = FALSE)
title("Bleich local test p-values, simulation 2")

boxplot(mat_pvals_var_split_imps_AI_2, use.cols = FALSE)
title("Bleich local test p-values, simulation 3")

#################################

boxplot(mat_true_var_split_imps_null , use.cols = FALSE)
title("grf var. imp., simulation 1, 100 iterations")

boxplot(mat_true_var_split_imps_AI_1 , use.cols = FALSE)
title("grf var. imp., simulation 2, 100 iterations")

boxplot(mat_true_var_split_imps_AI_2 , use.cols = FALSE)
title("grf var. imp., simulation 3, 100 iterations")

EoghanONeill • Oct 18 '18 16:10

Has any progress been made on this front? It would be immensely helpful to have a variable importance measure that isn't sensitive to the total number of cut-points a variable has.

DeFilippis • Feb 04 '20 05:02

This is purely based on intuition and not any sort of rigorous mathematical insight, but given that the splitting criterion for a causal forest is based on the estimated treatment effect rather than the mean of the outcome, would permuting treatment assignment be more relevant here, i.e. calculating variable importance under the sharp null of no effect?

boyercb • Mar 19 '20 22:03

I think permuting treatment assignment is a good suggestion. It gives similar results to the simulations above.
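For anyone who wants to try it, here is a rough sketch of the treatment-permutation variant. To keep it runnable without refitting thousands of trees, the importance measure is passed in as a function: `perm_treatment_null` and `toy_importance` are hypothetical names, and the toy measure (absolute treatment-by-covariate interaction from `lm`) is only a stand-in for the forest-based importance.

```r
# importance_fn: user-supplied function(X, Y, W) returning one importance per
# column of X. With grf one might plug in (assumption, not run here):
#   function(X, Y, W) as.numeric(variable_importance(causal_forest(X, Y, W)))
perm_treatment_null <- function(X, Y, W, importance_fn, n_perm = 100) {
  observed <- importance_fn(X, Y, W)
  # Refit with shuffled treatment labels; covariates and outcome stay fixed.
  null_mat <- matrix(
    vapply(seq_len(n_perm),
           function(i) importance_fn(X, Y, sample(W)),
           numeric(ncol(X))),
    nrow = ncol(X)
  )
  # Share of permuted-treatment importances at least as large as the observed one.
  corrected <- setNames(rowMeans(null_mat >= observed), names(observed))
  list(observed = observed, corrected = corrected)
}

# Toy stand-in importance: size of the W-by-covariate interaction in a linear model.
toy_importance <- function(X, Y, W) {
  apply(X, 2, function(x) abs(coef(lm(Y ~ W * x))[["W:x"]]))
}

set.seed(4)
n <- 300
X <- matrix(rnorm(n * 3), ncol = 3, dimnames = list(NULL, c("x1", "x2", "x3")))
W <- rbinom(n, 1, 0.5)
Y <- 1.5 * W * X[, "x2"] + rnorm(n)   # only x2 modifies the treatment effect

res <- perm_treatment_null(X, Y, W, toy_importance, n_perm = 50)
print(res)
```

Only `W` is shuffled here, so any relationship between the covariates and the outcome that does not run through treatment is left intact.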

Also, the measure based on permutation of the dependent variable should be interpreted as a corrected importance measure, and not a valid p-value for a hypothesis test. See Nembrini (2019) for more details.

Nembrini, S. (2019). On what to permute in test-based approaches for variable importance measures in Random Forests. Bioinformatics, 35(15), 2701-2705.

EoghanONeill • Jul 10 '20 19:07