Summarizing HTE Outputs in a Multi-Arm Experiment

Open shafayetShafee opened this issue 10 months ago • 1 comments

Hello, I need some guidance on how I can use the estimated HTE outputs from multi_arm_causal_forest to create an insightful summary. After going through this paper, I can think of some approaches, but I am a bit confused, since these resources discuss binary treatments only, whereas my use case is a multi-arm treatment.

Let's consider a reproducible example to discuss the approaches.

Setup

library(grf)
library(dplyr)
library(ggplot2)

set.seed(1344)

Helper functions

predict_effect_and_ci <- function(multi_arm_causal_forest_model, newdata = NULL) {

  if (!inherits(multi_arm_causal_forest_model, "multi_arm_causal_forest")) {
    stop('This function only supports model objects of class "multi_arm_causal_forest".')
  }

  tau_hat <- predict(
    multi_arm_causal_forest_model,
    newdata = newdata,
    estimate.variance = TRUE,
    drop = TRUE
  )

  effect_estimate_df <- as.data.frame(tau_hat$predictions)
  contrasts_name <- colnames(effect_estimate_df)
  contrast_generic_name <- paste0("contrast_", seq(1, length(contrasts_name)))
  contrast_info <- setNames(contrasts_name, contrast_generic_name)
  colnames(effect_estimate_df) <- paste0(contrast_generic_name, "_estimate")

  effect_estimate_var_df <- as.data.frame(tau_hat$variance.estimates)
  colnames(effect_estimate_var_df) <- paste0(contrast_generic_name, "_var")

  effect_est_df <- bind_cols(effect_estimate_df, effect_estimate_var_df)

  return(list(
    contrast_info = contrast_info,
    data = effect_est_df
  ))
}

get_top_n_vars <- function(forest, X, n = 3) {
  varimp <- grf::variable_importance(forest)
  ranked_variables <- order(varimp, decreasing = TRUE)
  top_varnames <- colnames(X)[ranked_variables[1:n]]
  return(top_varnames)
}
n <- 3000
p <- 10
X <- matrix(rnorm(n * p), n, p)
W <- as.factor(sample(c("A", "B", "C"), n, replace = TRUE))
Y <- X[, 1] + X[, 2] * (W == "B") - 1.5 * X[, 2] * (W == "C") + rnorm(n)

exp_df <- data.frame(Y = Y, W = W, X)

Splitting Data into Train-Test

train = sample(nrow(X), 0.6 * nrow(X))
test = -train

Fit Forest Model on Training Set

mc.forest <- multi_arm_causal_forest(X[train, ], Y[train], W[train], seed = 1344)

Predict HTEs on Test Set

tau_hat_est <- predict_effect_and_ci(mc.forest, newdata = X[test, ])
tau_hat_est_df <- bind_cols(tau_hat_est$data, exp_df[test, ]) %>% 
  mutate(
    c1_ci_low = contrast_1_estimate - 1.96 * sqrt(contrast_1_var),
    c1_ci_high = contrast_1_estimate + 1.96 * sqrt(contrast_1_var),
    c2_ci_low = contrast_2_estimate - 1.96 * sqrt(contrast_2_var),
    c2_ci_high = contrast_2_estimate + 1.96 * sqrt(contrast_2_var),
  )

head(tau_hat_est_df, 3)
  contrast_1_estimate contrast_2_estimate contrast_1_var contrast_2_var
1         -0.06310611          -0.1661815     0.01636501     0.01322940
2         -0.56899703           0.9466801     0.02781945     0.05492243
3         -0.49811420           0.9974250     0.02565718     0.06470612
           Y W          X1         X2          X3         X4          X5
1  0.9849406 A  0.54756844 -0.1014569  0.21716754 -2.0556520 -0.04809347
2 -2.1574977 A  0.08431498 -0.6160837 -0.46033781 -0.1537932  0.08784540
3  1.0177696 A -0.50059754 -0.6376908 -0.08594392  0.4529726 -1.98854317
          X6          X7         X8         X9        X10  c1_ci_low c1_ci_high
1  1.8194223 -0.04598789 -0.3885001 0.45111597 -1.9751646 -0.3138407  0.1876284
2 -0.8248944 -1.42140442 -0.8348958 0.06918902  0.8410156 -0.8959086 -0.2420854
3 -0.5929628  0.08853166  0.1790741 0.92633845  0.8261464 -0.8120642 -0.1841642
   c2_ci_low c2_ci_high
1 -0.3916190 0.05925604
2  0.4873436 1.40601655
3  0.4988520 1.49599794

Creating HTE Quartile Groups

The tau_hat_est_df contains two HTE estimates: $\hat{\tau}_{b-a}$, comparing treatment “B” with “A”, and $\hat{\tau}_{c-a}$, comparing treatment “C” with “A”. We can first create quartile groups based on $\hat{\tau}_{b-a}$.

num.groups = 4

quartile = cut(
  tau_hat_est_df$contrast_1_estimate,
  quantile(tau_hat_est_df$contrast_1_estimate, seq(0, 1, by = 1 / num.groups)),
  labels = 1:num.groups,
  include.lowest = TRUE
)

samples.by.quartile = split(seq_along(quartile), quartile)

eval.forest = multi_arm_causal_forest(X[test, ], Y[test], W[test], seed = 1345)

ate.by.quartile = lapply(samples.by.quartile, function(samples) {
  average_treatment_effect(eval.forest, subset = samples)
})

df.plot.ate = bind_rows(ate.by.quartile, .id = "group") %>% 
  mutate(
    group = paste0("Q", group)
  ) %>% 
  select(group, contrast, estimate, std.err)
  
rownames(df.plot.ate) <- NULL

head(df.plot.ate, 10)
  group contrast   estimate   std.err
1    Q1    B - A -1.1571475 0.1582629
2    Q1    C - A  2.0199825 0.1652687
3    Q2    B - A -0.3894375 0.1472359
4    Q2    C - A  0.5131891 0.1584275
5    Q3    B - A  0.3887451 0.1407779
6    Q3    C - A -0.4980314 0.1409159
7    Q4    B - A  1.2435839 0.1512048
8    Q4    C - A -1.9012597 0.1586695
tau_BA_ate <- df.plot.ate %>% 
  filter(contrast == "B - A")

tau_BA_ate %>% 
ggplot(aes(x = group, y = estimate)) +
  geom_hline(yintercept = 0, linetype = 2, linewidth = 0.5) +
  geom_errorbar(
    aes(
      ymin = estimate - 1.96 * std.err, 
      ymax = estimate + 1.96 * std.err
    ),
    width = 0.09, color = "#4E79A7", linewidth = 0.7
  ) +
  geom_point(color = "#E15759", size = 3) +
  xlab("Estimated CATE Quartile") +
  ylab("Average treatment effect") + 
  theme_minimal() +
  theme(
    plot.title = element_text(size = 12, face = "bold", lineheight = 1.1),
    axis.text = element_text(size = 11),
    axis.title.x = element_text(margin = margin(t = 10))
  ) 

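[Image: average treatment effect for the “B - A” contrast by estimated CATE quartile, with 95% confidence intervals]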

Note that, since I created the quartile groups based on $\hat{\tau}_{b-a}$, I only plotted the ATE estimates (and their standard errors) for the “B - A” contrast and ignored the values for the “C - A” contrast. When $\hat{\tau}_{c-a}$ is used to create the quartile groups instead, only the ATE estimates for the “C - A” contrast would be shown. At least, that is what I am thinking. So my question is: am I on the right track? Are there any better ways?
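The per-contrast grouping described above could be sketched as follows. This is only an illustration in base R: make_quantile_groups is a hypothetical helper (not part of grf), and toy_df is synthetic data standing in for the two estimate columns of tau_hat_est_df.

```r
# Hypothetical helper (not part of grf): split a vector of CATE estimates
# into `num.groups` quantile groups labelled Q1..Qk.
make_quantile_groups <- function(tau_hat, num.groups = 4) {
  breaks <- quantile(tau_hat, probs = seq(0, 1, length.out = num.groups + 1))
  cut(tau_hat, breaks = breaks,
      labels = paste0("Q", seq_len(num.groups)),
      include.lowest = TRUE)
}

# Synthetic stand-in for the estimate columns of tau_hat_est_df.
set.seed(1)
toy_df <- data.frame(
  contrast_1_estimate = rnorm(200),
  contrast_2_estimate = rnorm(200)
)

# One grouping per contrast: each contrast is ranked by its own estimates.
estimate_cols <- grep("_estimate$", names(toy_df), value = TRUE)
groups_by_contrast <- lapply(toy_df[estimate_cols], make_quantile_groups)

# Group sizes for each contrast.
lapply(groups_by_contrast, table)

# With the objects from the example, the "C - A" analysis would then follow
# the same pattern as for "B - A":
#   g <- groups_by_contrast$contrast_2_estimate
#   samples <- split(seq_along(g), g)
#   ate <- lapply(samples, function(s)
#     average_treatment_effect(eval.forest, subset = s))
#   ...keeping only the rows where contrast == "C - A".
```

Each contrast gets its own ranking, so a unit can fall in Q1 for one contrast and Q4 for the other.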

Covariate Profiles for Quartile Groups

top_2_vars <- get_top_n_vars(
  mc.forest, 
  exp_df %>% select(starts_with("X")), 
  n = 2
)

top_2_vars
[1] "X2" "X5"
tau_hat_est_df %>% 
  mutate(
    Q = quartile,
    group = paste0("Q", Q)
  ) %>% 
  group_by(group) %>% 
  summarise(
    across(.cols = all_of(top_2_vars), .fns = mean, .names = "mean_{.col}")
  ) %>% 
  left_join(tau_BA_ate, by = "group")
# A tibble: 4 × 6
  group mean_X2 mean_X5 contrast estimate std.err
  <chr>   <dbl>   <dbl> <chr>       <dbl>   <dbl>
1 Q1     -1.28   0.124  B - A      -1.16    0.158
2 Q2     -0.301 -0.0852 B - A      -0.389   0.147
3 Q3      0.366 -0.138  B - A       0.389   0.141
4 Q4      1.30  -0.0303 B - A       1.24    0.151

Is the above summary representation valid? Are there any better ways?

Additional Questions

  1. Is it incorrect to average the $\hat{\tau}_{b-a}$ for each quartile, rather than fitting eval.forest to each quartile group separately to get the ATE estimates?

shafayetShafee · Jan 12 '25 10:01

Hi @shafayetShafee,

> After going through this paper, I can think of some approaches, but I am a bit confused, since these resources discuss binary treatments only, whereas my use case is a multi-arm treatment.

The approaches described in that paper apply to any given treatment, and you could perform the same kind of separate analysis for each arm if you have several arms available. In principle you could fit a causal forest for each treatment arm and do this kind of analysis. You can think of a multi-arm causal forest as a way to estimate all these CATEs jointly instead of separately, which may increase estimation power if some HTE signal is shared across arms.
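A minimal sketch of the "one forest per arm" idea, assuming simulated data of the same form as in the question. fit_arm_vs_baseline is a hypothetical helper (not a grf function), and the small n and num.trees are chosen only to keep the example fast.

```r
library(grf)

set.seed(1344)
n <- 600
p <- 5
X <- matrix(rnorm(n * p), n, p)
W <- factor(sample(c("A", "B", "C"), n, replace = TRUE))
Y <- X[, 1] + X[, 2] * (W == "B") - 1.5 * X[, 2] * (W == "C") + rnorm(n)

# Hypothetical helper: keep only units in `arm` or `baseline`, then fit a
# binary causal forest with `arm` coded as the treated group.
fit_arm_vs_baseline <- function(arm, baseline = "A") {
  keep <- W %in% c(arm, baseline)
  causal_forest(
    X[keep, ], Y[keep], as.numeric(W[keep] == arm),
    num.trees = 500
  )
}

# One forest per non-baseline arm.
forests <- lapply(c(B = "B", C = "C"), fit_arm_vs_baseline)

# ATE estimate and standard error for each contrast, estimated separately.
lapply(forests, average_treatment_effect)
```

Each forest here only sees the units in its two arms, whereas multi_arm_causal_forest uses all units in a single fit.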

For an analysis tailored to multiple arms where there are different costs associated with deploying each treatment, this paper/package https://github.com/grf-labs/maq extends the Qini curve (Figure 5) to that setting. It essentially allows you to translate predictions from multiple treatments into a treatment allocation policy that satisfies some budget constraint, and then plot the value of that policy.

1: You should use the average_treatment_effect function to compute an ATE.
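To illustrate why averaging predictions and computing an ATE can differ, here is a stylized base-R sketch. It is not grf's implementation: it assumes a single binary contrast in a randomized experiment with a known assignment probability, and uses deliberately crude outcome predictions (mu0_hat, mu1_hat are made up for illustration). grf documents average_treatment_effect as a doubly robust (AIPW-style) estimator, and the sketch shows the general idea behind such scores.

```r
set.seed(42)
n <- 2000
x <- rnorm(n)
e <- 0.5                          # known assignment probability (assumption)
w <- rbinom(n, 1, e)
y <- x + w * (1 + x) + rnorm(n)   # true ATE is 1

# Deliberately crude outcome models, standing in for imperfect predictions.
mu0_hat <- rep(0, n)
mu1_hat <- rep(0.5, n)
tau_hat <- mu1_hat - mu0_hat      # plug-in CATE "predictions"

# Plug-in average: inherits whatever error the predictions have.
plug_in_ate <- mean(tau_hat)

# AIPW scores: predictions plus an inverse-propensity-weighted residual
# correction for each arm.
scores <- tau_hat +
  w / e * (y - mu1_hat) -
  (1 - w) / (1 - e) * (y - mu0_hat)

aipw_ate <- mean(scores)
aipw_se <- sd(scores) / sqrt(n)

c(plug_in = plug_in_ate, aipw = aipw_ate, aipw_se = aipw_se)
```

In this toy setup the plug-in average simply returns the error of the predictions, while the score-based average is corrected by the observed outcomes and comes with a standard error.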

erikcs · Jan 19 '25 04:01