grf
Summarizing HTE Outputs in a Multi-Arm Experiment
Hello, I need some guidance, direction, or suggestions on how I can use the estimated HTE outputs from multi_arm_causal_forest to create an insightful summary. After going through this paper, I can think of some approaches. But I am a bit confused, since these resources discuss binary treatment only, whereas my use case is a “multi-arm treatment”.
Let's consider a reproducible example to discuss the approaches.
Setup
library(grf)
library(dplyr)
library(ggplot2)
set.seed(1344)
Helper Functions
predict_effect_and_ci <- function(multi_arm_causal_forest_model, newdata = NULL) {
  if (!inherits(multi_arm_causal_forest_model, "multi_arm_causal_forest")) {
    stop('This function only supports model objects of class "multi_arm_causal_forest".')
  }
  # Predict each contrast (treatment arm vs. reference arm) with variance estimates;
  # drop = TRUE simplifies the result to one column per contrast
  tau_hat <- predict(
    multi_arm_causal_forest_model,
    newdata = newdata,
    estimate.variance = TRUE,
    drop = TRUE
  )
  # Point estimates: one column per contrast (e.g. "B - A", "C - A")
  effect_estimate_df <- as.data.frame(tau_hat$predictions)
  contrasts_name <- colnames(effect_estimate_df)
  # Map generic names (contrast_1, contrast_2, ...) to the original contrast labels
  contrast_generic_name <- paste0("contrast_", seq_along(contrasts_name))
  contrast_info <- setNames(contrasts_name, contrast_generic_name)
  colnames(effect_estimate_df) <- paste0(contrast_generic_name, "_estimate")
  # Variance estimates, in the same contrast order
  effect_estimate_var_df <- as.data.frame(tau_hat$variance.estimates)
  colnames(effect_estimate_var_df) <- paste0(contrast_generic_name, "_var")
  effect_est_df <- bind_cols(effect_estimate_df, effect_estimate_var_df)
  return(list(
    contrast_info = contrast_info,
    data = effect_est_df
  ))
}
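# Names of the n covariates with the highest grf variable importance;
# X must have the same column order as the matrix the forest was trained on
get_top_n_vars <- function(forest, X, n = 3) {
  varimp <- grf::variable_importance(forest)
  ranked_variables <- order(varimp, decreasing = TRUE)
  top_varnames <- colnames(X)[ranked_variables[1:n]]
  return(top_varnames)
}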
n <- 3000
p <- 10
X <- matrix(rnorm(n * p), n, p)
W <- as.factor(sample(c("A", "B", "C"), n, replace = TRUE))
Y <- X[, 1] + X[, 2] * (W == "B") - 1.5 * X[, 2] * (W == "C") + rnorm(n)
exp_df <- data.frame(Y = Y, W = W, X)
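Because the data are simulated, the true contrasts can be read off the definition of Y (with A as the reference arm), which gives a check on the estimates that follow:

$$
\begin{aligned}
\mathbb{E}[Y \mid X = x, W = \text{A}] &= x_1, \\
\mathbb{E}[Y \mid X = x, W = \text{B}] &= x_1 + x_2, \\
\mathbb{E}[Y \mid X = x, W = \text{C}] &= x_1 - 1.5\,x_2, \\
\tau_{b-a}(x) &= x_2, \qquad \tau_{c-a}(x) = -1.5\,x_2 .
\end{aligned}
$$

Both contrasts depend only on $x_2$, with opposite signs, so ranking units by $\hat{\tau}_{b-a}$ should also rank them, in reverse order, by $\hat{\tau}_{c-a}$ (up to estimation noise).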
Splitting Data into Training and Test Sets
train = sample(nrow(X), 0.6 * nrow(X))
test = -train
Fit Forest Model on Training Set
mc.forest <- multi_arm_causal_forest(X[train, ], Y[train], W[train], seed = 1344)
Predict HTEs on Test Set
tau_hat_est <- predict_effect_and_ci(mc.forest, newdata = X[test, ])
# Attach predictions to the test rows and add 95% normal-approximation CIs
tau_hat_est_df <- bind_cols(tau_hat_est$data, exp_df[test, ]) %>%
  mutate(
    c1_ci_low = contrast_1_estimate - 1.96 * sqrt(contrast_1_var),
    c1_ci_high = contrast_1_estimate + 1.96 * sqrt(contrast_1_var),
    c2_ci_low = contrast_2_estimate - 1.96 * sqrt(contrast_2_var),
    c2_ci_high = contrast_2_estimate + 1.96 * sqrt(contrast_2_var)
  )
head(tau_hat_est_df, 3)
contrast_1_estimate contrast_2_estimate contrast_1_var contrast_2_var
1 -0.06310611 -0.1661815 0.01636501 0.01322940
2 -0.56899703 0.9466801 0.02781945 0.05492243
3 -0.49811420 0.9974250 0.02565718 0.06470612
Y W X1 X2 X3 X4 X5
1 0.9849406 A 0.54756844 -0.1014569 0.21716754 -2.0556520 -0.04809347
2 -2.1574977 A 0.08431498 -0.6160837 -0.46033781 -0.1537932 0.08784540
3 1.0177696 A -0.50059754 -0.6376908 -0.08594392 0.4529726 -1.98854317
X6 X7 X8 X9 X10 c1_ci_low c1_ci_high
1 1.8194223 -0.04598789 -0.3885001 0.45111597 -1.9751646 -0.3138407 0.1876284
2 -0.8248944 -1.42140442 -0.8348958 0.06918902 0.8410156 -0.8959086 -0.2420854
3 -0.5929628 0.08853166 0.1790741 0.92633845 0.8261464 -0.8120642 -0.1841642
c2_ci_low c2_ci_high
1 -0.3916190 0.05925604
2 0.4873436 1.40601655
3 0.4988520 1.49599794
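The mutate() call above writes out the interval columns for each contrast by hand, which has to be edited whenever the number of arms changes. A sketch of a generic alternative, assuming the contrast_k_estimate / contrast_k_var naming produced by predict_effect_and_ci(); add_normal_ci() is a hypothetical helper, it uses qnorm() instead of the rounded 1.96, and it names the new columns contrast_k_ci_low / contrast_k_ci_high rather than c1_ci_low etc. The toy row is the first row of the output above.

```r
# Hypothetical helper: add normal-approximation CIs for every contrast_k_estimate
# column, using the matching contrast_k_var column.
add_normal_ci <- function(df, level = 0.95) {
  z <- qnorm(1 - (1 - level) / 2)  # about 1.96 for a 95% interval
  est_cols <- grep("^contrast_[0-9]+_estimate$", names(df), value = TRUE)
  for (est_col in est_cols) {
    prefix <- sub("_estimate$", "", est_col)  # e.g. "contrast_1"
    se <- sqrt(df[[paste0(prefix, "_var")]])
    df[[paste0(prefix, "_ci_low")]] <- df[[est_col]] - z * se
    df[[paste0(prefix, "_ci_high")]] <- df[[est_col]] + z * se
  }
  df
}

# Toy usage: first row of tau_hat_est_df printed above
toy <- data.frame(
  contrast_1_estimate = -0.06310611, contrast_2_estimate = -0.1661815,
  contrast_1_var = 0.01636501, contrast_2_var = 0.01322940
)
res <- add_normal_ci(toy)
res
```

Applied to the walkthrough, add_normal_ci(tau_hat_est$data) would produce the same intervals (up to the 1.96 rounding) for any number of contrasts.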
Creating HTE Quartile Groups
The tau_hat_est_df data frame contains two HTE estimates: $\hat{\tau}_{b-a}$, comparing treatment “B” with “A”, and $\hat{\tau}_{c-a}$, comparing treatment “C” with “A”. We can first create quartile groups based on $\hat{\tau}_{b-a}$.
num.groups = 4
# Quartiles of the predicted B - A contrast on the test set
quartile = cut(
  tau_hat_est_df$contrast_1_estimate,
  quantile(tau_hat_est_df$contrast_1_estimate, seq(0, 1, by = 1 / num.groups)),
  labels = 1:num.groups,
  include.lowest = TRUE
)
samples.by.quartile = split(seq_along(quartile), quartile)
# Evaluation forest fit on the held-out test set only
eval.forest = multi_arm_causal_forest(X[test, ], Y[test], W[test], seed = 1345)
# ATE for every contrast (B - A and C - A) within each quartile
ate.by.quartile = lapply(samples.by.quartile, function(samples) {
  average_treatment_effect(eval.forest, subset = samples)
})
df.plot.ate = bind_rows(ate.by.quartile, .id = "group") %>%
  mutate(
    group = paste0("Q", group)
  ) %>%
  select(group, contrast, estimate, std.err)
rownames(df.plot.ate) <- NULL
head(df.plot.ate, 10)
group contrast estimate std.err
1 Q1 B - A -1.1571475 0.1582629
2 Q1 C - A 2.0199825 0.1652687
3 Q2 B - A -0.3894375 0.1472359
4 Q2 C - A 0.5131891 0.1584275
5 Q3 B - A 0.3887451 0.1407779
6 Q3 C - A -0.4980314 0.1409159
7 Q4 B - A 1.2435839 0.1512048
8 Q4 C - A -1.9012597 0.1586695
tau_BA_ate <- df.plot.ate %>%
  filter(contrast == "B - A")
tau_BA_ate %>%
  ggplot(aes(x = group, y = estimate)) +
  geom_hline(yintercept = 0, linetype = 2, linewidth = 0.5) +
  geom_errorbar(
    aes(
      ymin = estimate - 1.96 * std.err,
      ymax = estimate + 1.96 * std.err
    ),
    width = 0.09, color = "#4E79A7", linewidth = 0.7
  ) +
  geom_point(color = "#E15759", size = 3) +
  xlab("Estimated CATE Quartile") +
  ylab("Average treatment effect") +
  theme_minimal() +
  theme(
    plot.title = element_text(size = 12, face = "bold", lineheight = 1.1),
    axis.text = element_text(size = 11),
    axis.title.x = element_text(margin = margin(t = 10))
  )
Note that, since I created the quartile groups based on $\hat{\tau}_{b-a}$, I have only plotted the ATE estimates (and their standard errors) for the “B - A” contrast, ignoring the values for the “C - A” contrast. Likewise, if $\hat{\tau}_{c-a}$ were used to create the quartile groups, only the ATE estimates for the “C - A” contrast would be shown. At least, that is what I am thinking. So my question is: am I on the right track? Are there any better ways?
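One option, offered as a sketch rather than a recommended method, is to keep both contrasts in the plot instead of filtering to “B - A”: the “C - A” estimates within the B - A quartiles show how the other arm behaves in the same subgroups (here they move in the opposite direction, consistent with both simulated effects depending on X2 with opposite signs). In the walkthrough you would simply drop the filter(contrast == "B - A") step; so that the block runs on its own, the table below is retyped from the printed df.plot.ate output, and df.plot.ate.both is a hypothetical name.

```r
library(ggplot2)

# ATE by B - A quartile for both contrasts, copied from the df.plot.ate output
df.plot.ate.both <- data.frame(
  group = rep(paste0("Q", 1:4), each = 2),
  contrast = rep(c("B - A", "C - A"), times = 4),
  estimate = c(-1.1571475, 2.0199825, -0.3894375, 0.5131891,
               0.3887451, -0.4980314, 1.2435839, -1.9012597),
  std.err = c(0.1582629, 0.1652687, 0.1472359, 0.1584275,
              0.1407779, 0.1409159, 0.1512048, 0.1586695)
)

# One panel per contrast; the quartiles are always those of the B - A predictions
p_both <- ggplot(df.plot.ate.both, aes(x = group, y = estimate)) +
  geom_hline(yintercept = 0, linetype = 2, linewidth = 0.5) +
  geom_errorbar(
    aes(ymin = estimate - 1.96 * std.err, ymax = estimate + 1.96 * std.err),
    width = 0.09, color = "#4E79A7", linewidth = 0.7
  ) +
  geom_point(color = "#E15759", size = 3) +
  facet_wrap(~ contrast) +
  xlab("Quartile of estimated B - A CATE") +
  ylab("Average treatment effect") +
  theme_minimal()
p_both
```

Grouping by $\hat{\tau}_{c-a}$ instead would give a second figure with the same two panels, so each set of quartiles can be read against both contrasts.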
Covariate Profiles for Quartile Groups
top_2_vars <- get_top_n_vars(
  mc.forest,
  exp_df %>% select(starts_with("X")),
  n = 2
)
top_2_vars
[1] "X2" "X5"
tau_hat_est_df %>%
  mutate(
    Q = quartile,
    group = paste0("Q", Q)
  ) %>%
  group_by(group) %>%
  summarise(
    across(.cols = all_of(top_2_vars), .fns = mean, .names = "mean_{.col}")
  ) %>%
  left_join(tau_BA_ate, by = "group")
# A tibble: 4 × 6
group mean_X2 mean_X5 contrast estimate std.err
<chr> <dbl> <dbl> <chr> <dbl> <dbl>
1 Q1 -1.28 0.124 B - A -1.16 0.158
2 Q2 -0.301 -0.0852 B - A -0.389 0.147
3 Q3 0.366 -0.138 B - A 0.389 0.141
4 Q4 1.30 -0.0303 B - A 1.24 0.151
Is the above summary representation valid? Are there any better ways?
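One way to make such a covariate profile easier to judge is to report a standard error next to each group mean, so differences across quartiles can be compared with sampling noise. A minimal base-R sketch; covariate_profile() is a hypothetical helper, and the toy data stand in for tau_hat_est_df.

```r
# Hypothetical helper: mean and standard error of each covariate within each group.
# X: numeric matrix with column names; group: factor (or vector) of group labels.
covariate_profile <- function(X, group) {
  group <- as.factor(group)
  rows <- lapply(colnames(X), function(v) {
    x <- X[, v]
    data.frame(
      variable = v,
      group = levels(group),
      mean = as.numeric(tapply(x, group, mean)),
      std.err = as.numeric(tapply(x, group, function(z) sd(z) / sqrt(length(z))))
    )
  })
  do.call(rbind, rows)
}

# Toy data: two covariates and four equal-sized groups
set.seed(1)
X_toy <- matrix(rnorm(400 * 2), nrow = 400, ncol = 2,
                dimnames = list(NULL, c("X2", "X5")))
g_toy <- factor(rep(paste0("Q", 1:4), each = 100))
prof <- covariate_profile(X_toy, g_toy)
prof
```

With the objects from the question, something like covariate_profile(as.matrix(tau_hat_est_df[, top_2_vars]), paste0("Q", quartile)) should reproduce the mean_X2 and mean_X5 columns above, with standard errors added.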
Additional Questions
- Is it incorrect to average the $\hat{\tau}_{b-a}$ values within each quartile, rather than fitting eval.forest to each quartile group separately to get the ATE estimates?
Hi @shafayetShafee,
> After going through this paper, I can think of some approaches. But I am a bit confused, since these resources discuss binary treatment only, whereas my use case is a “multi-arm treatment”.
The approaches described in that paper apply to any given treatment, and if you have several arms, you can perform the same kind of analysis separately for each arm. In principle, you could fit a causal forest for each treatment arm and do this kind of analysis; you can think of a multi-arm causal forest as a way to estimate all these CATEs jointly instead of separately, which may increase estimation power if some HTE signal is shared across arms.
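A minimal sketch of the separate-estimation alternative, assuming grf is installed and reusing the simulated data from the question: one binary causal_forest per treated arm, each trained only on units in that arm and the reference arm A. fit_arm_vs_control() is a hypothetical helper, and num.trees = 500 is only there to keep the example fast.

```r
library(grf)

# Simulated data from the question
set.seed(1344)
n <- 3000
p <- 10
X <- matrix(rnorm(n * p), n, p)
W <- as.factor(sample(c("A", "B", "C"), n, replace = TRUE))
Y <- X[, 1] + X[, 2] * (W == "B") - 1.5 * X[, 2] * (W == "C") + rnorm(n)

# Hypothetical helper: a binary causal forest for `arm` vs. `control`,
# trained only on the units assigned to one of those two arms
fit_arm_vs_control <- function(arm, control = "A") {
  idx <- which(W %in% c(control, arm))
  causal_forest(X[idx, ], Y[idx], as.numeric(W[idx] == arm), num.trees = 500)
}

forests <- lapply(c(B = "B", C = "C"), fit_arm_vs_control)

# Out-of-bag CATE predictions for the units each forest was trained on
tau_b_a <- predict(forests$B)$predictions
tau_c_a <- predict(forests$C)$predictions
```

Each forest here sees only about two thirds of the sample, which is one reason the joint multi-arm fit can be more efficient when the arms share HTE structure.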
For an analysis tailored to multiple arms where deploying each treatment has a different cost, this paper/package (https://github.com/grf-labs/maq) extends the Qini curve (Figure 5) to that setting. It essentially lets you translate predictions for multiple treatments into a treatment allocation policy that satisfies a budget constraint, and then plot the value of that policy.
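maq handles costs and a budget. As a much simpler illustration of the underlying idea (not maq's algorithm), the base-R sketch below covers the unconstrained case: give each unit the arm with the largest positive predicted effect relative to control, then estimate that policy's gain over control by averaging the doubly robust score of the assigned contrast. argmax_policy() and policy_gain() are hypothetical helpers and the toy numbers are made up. In practice the predictions and the scores should come from separate data, e.g. predictions from mc.forest and scores from the evaluation forest (grf's get_scores() can provide these; check its documentation for the output shape).

```r
# Hypothetical helper: assign each unit the treated arm with the largest
# predicted effect vs. control, or 0 (stay on control) if none is positive.
# tau_hat: n x (K - 1) matrix of predicted contrasts vs. control.
argmax_policy <- function(tau_hat) {
  best <- max.col(tau_hat, ties.method = "first")
  best_value <- tau_hat[cbind(seq_len(nrow(tau_hat)), best)]
  ifelse(best_value > 0, best, 0L)
}

# Hypothetical helper: estimated gain of a policy over treating nobody,
# averaging the doubly robust score of the assigned contrast (0 on control).
# dr_scores: n x (K - 1) matrix of doubly robust scores, same column order.
policy_gain <- function(policy, dr_scores) {
  gain <- numeric(length(policy))
  treated <- policy > 0
  gain[treated] <- dr_scores[cbind(which(treated), policy[treated])]
  c(estimate = mean(gain), std.err = sd(gain) / sqrt(length(gain)))
}

# Toy inputs: 4 units, 2 contrasts ("B - A", "C - A"); numbers are made up
tau_hat <- matrix(c(-1, 0.5, 2, -0.2,
                    0.3, -0.4, 1, -0.1), ncol = 2)
dr_scores <- matrix(c(0.1, 0.6, 1.8, 0.2,
                      0.5, -0.3, 0.9, 0.4), ncol = 2)

policy <- argmax_policy(tau_hat)  # 2, 1, 1, 0: C, B, B, control
policy_gain(policy, dr_scores)    # estimate: (0.5 + 0.6 + 1.8 + 0) / 4 = 0.725
```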
1: You should use the average_treatment_effect function to compute an ATE.
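For intuition on why average_treatment_effect is preferable to averaging $\hat{\tau}_{b-a}$ within a group: to our understanding, for a multi-arm forest it averages doubly robust (AIPW) scores, and the spread of those scores gives a standard error, whereas a plain average of CATE predictions has no simple valid standard error. The base-R sketch below mimics that idea for the “B - A” contrast under stated assumptions: propensities are known to be 1/3 (as in the simulation), per-arm linear regressions replace grf's forest-based nuisance estimates, and there is no cross-fitting. aipw_scores() and ate_subset() are hypothetical helpers.

```r
# Simulated data from the question
set.seed(1344)
n <- 3000
X <- matrix(rnorm(n * 10), n, 10)
W <- as.factor(sample(c("A", "B", "C"), n, replace = TRUE))
Y <- X[, 1] + X[, 2] * (W == "B") - 1.5 * X[, 2] * (W == "C") + rnorm(n)

# Hypothetical helper: AIPW scores for the `treat` vs. `control` contrast.
# Assumes known, equal propensities e = 1/3 and per-arm linear outcome models
# (no cross-fitting), standing in for grf's forest-based nuisance estimates.
aipw_scores <- function(X, Y, W, treat = "B", control = "A", e = 1 / 3) {
  df <- data.frame(Y = Y, X)
  fit_t <- lm(Y ~ ., data = df[W == treat, ])
  fit_c <- lm(Y ~ ., data = df[W == control, ])
  mu_t <- predict(fit_t, newdata = df)
  mu_c <- predict(fit_c, newdata = df)
  (mu_t - mu_c) +
    (W == treat) * (Y - mu_t) / e -
    (W == control) * (Y - mu_c) / e
}

# Hypothetical helper: ATE over a subset = mean of the scores, SE from their spread
ate_subset <- function(scores, subset = seq_along(scores)) {
  s <- scores[subset]
  c(estimate = mean(s), std.err = sd(s) / sqrt(length(s)))
}

scores <- aipw_scores(X, Y, W)
top_x2 <- which(X[, 2] > quantile(X[, 2], 0.75))
ate_subset(scores)          # true overall B - A effect: E[x2] = 0
ate_subset(scores, top_x2)  # true effect in this group: about 1.27
```

Since the true B - A effect is $x_2$, the ATE among units in the top quartile of $x_2$ is $\mathbb{E}[x_2 \mid x_2 > q_{0.75}] \approx 1.27$ for a standard normal covariate, which the subset estimate should recover within its standard error.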