policytree
Policy tree - double robust scores and rewards
In layman's terms, could you please explain how the policy_tree function takes the double_robust_scores as input (i.e., the Gamma.matrix from the causal forest, which is interpreted as rewards) and uses them to generate its rules? Are positive or negative values considered desirable in these reward matrices? I would generally think that lower or negative treatment effects are desirable, but I wasn't clear on how these reward values are generated in a causal forest.
I am asking because, in my current project, I am using my causal forest's doubly robust scores for the policy tree, and I am getting suboptimal results when I evaluate the policy tree on the test subsample (compared with the truth in that same test subsample). Therefore, I wanted to confirm that the policy tree actually seeks to reduce the incidence of a binary outcome rather than increase it.
Furthermore, if someone is trying to improve the performance of their policy tree, do you have any recommendations besides increasing the depth of the tree (which hasn't worked for me so far)?
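Hi @njawadekar. As detailed in the policytree documentation, the policy_tree algorithm seeks to maximize rewards: it looks for the tree whose recommended actions have the largest total score in Gamma, so larger values are treated as more desirable. The classification example below, adapted from the policytree tests, illustrates this.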
If you want to minimize instead (for example, to reduce the incidence of a binary outcome), you should be able to do so by multiplying the doubly robust scores by -1 before running policy_tree.
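A minimal base-R sketch of what the sign flip does, using a small made-up score matrix (the values, and the control/treated column names that mimic a two-arm causal forest, are assumptions for illustration only). With unlimited depth, a reward-maximizing tree would simply give each unit the column with the largest score, so comparing row-wise `which.max` on `Gamma` and on `-Gamma` shows what the negation changes:

```r
# Made-up doubly robust scores for a binary outcome where Y = 1 is the event
# to avoid; each entry stands in for the estimated outcome (roughly, the
# event rate) of that unit under that arm. Values are hypothetical.
Gamma <- matrix(c(0.30, 0.10,
                  0.20, 0.45,
                  0.60, 0.55),
                ncol = 2, byrow = TRUE,
                dimnames = list(NULL, c("control", "treated")))

# A reward-maximizing rule fed Gamma as-is picks, for each unit, the arm
# with the HIGHER estimated event rate.
action.raw <- apply(Gamma, 1, which.max)

# Multiplying by -1 turns maximization into minimization: the same rule
# now picks the arm with the LOWER estimated event rate.
action.neg <- apply(-Gamma, 1, which.max)

# Sanity check: maximizing -Gamma is the same as minimizing Gamma.
stopifnot(all(action.neg == apply(Gamma, 1, which.min)))

# raw picks arms 1 2 1; negated picks arms 2 1 2
print(rbind(raw = action.raw, negated = action.neg))
```

With policytree itself, the corresponding call would be `policy_tree(X, -Gamma, depth = 2)`.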
```r
library(policytree)

n <- 400
p <- 4
d <- 3
depth <- 2
# Classification task taken from the policytree tests: each unit gets a
# positive reward for exactly one action, and those actions follow a known tree.
X <- round(matrix(rnorm(n * p), n, p), 2)
Gamma <- matrix(0, n, d)
best.tree <- policytree:::make_tree(X, depth = depth, d = d)
best.action <- policytree:::predict_test_tree(best.tree, X)
Gamma[cbind(1:n, best.action)] <- 100 * runif(n)
head(Gamma)
# Example output (no seed is set, so your values will differ):
#           [,1]      [,2] [,3]
# [1,]  0.000000  4.469589    0
# [2,]  0.000000 17.467142    0
# [3,] 58.383523  0.000000    0
# [4,]  9.159190  0.000000    0
# [5,]  2.882921  0.000000    0
# [6,] 58.587015  0.000000    0
tree <- policy_tree(X, Gamma, depth = depth)
print(tree)
# policy_tree object
# Tree depth: 2
# Actions: 1 2 3
# Variable splits:
# (1) split_variable: X4 split_value: 0.29
# (2) * action: 1
# (3) split_variable: X4 split_value: 0.38
# (4) * action: 3
# (5) * action: 2
# The fitted tree assigns every unit the action with the largest reward.
all(apply(Gamma, 1, which.max) == predict(tree, X))
# [1] TRUE
```
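To make the search itself concrete in layman's terms, here is a rough base-R sketch of a depth-1 version of the same idea. It is not policytree's implementation: policy_tree searches exhaustively over trees of the requested depth and has options such as `split.step` and `min.node.size` that this sketch ignores, and `depth1_policy` is a hypothetical helper. For every candidate split, it gives each side the single action with the largest summed reward and keeps the split with the highest total, which is why larger entries of Gamma count as better:

```r
# Hypothetical helper (not part of policytree): exhaustive depth-1 search.
# X: n x p numeric covariate matrix; Gamma: n x d reward matrix.
# Units with X[, j] <= s go left (a convention assumed for this sketch).
depth1_policy <- function(X, Gamma) {
  best <- list(value = -Inf)
  for (j in seq_len(ncol(X))) {
    for (s in sort(unique(X[, j]))) {
      left <- X[, j] <= s
      # Summed reward of each action within each leaf
      # (an empty leaf contributes zeros).
      left.sums  <- colSums(Gamma[left, , drop = FALSE])
      right.sums <- colSums(Gamma[!left, , drop = FALSE])
      # Each leaf takes its best action; the split is scored by the total.
      value <- max(left.sums) + max(right.sums)
      if (value > best$value) {
        best <- list(value = value, variable = j, threshold = s,
                     left.action = which.max(left.sums),
                     right.action = which.max(right.sums))
      }
    }
  }
  best
}

# Toy data: action 1 pays off when X1 <= 0, action 2 when X1 > 0;
# the second covariate is pure noise.
set.seed(1)
n <- 200
X <- matrix(round(rnorm(n * 2), 2), n, 2)
Gamma <- cbind(as.numeric(X[, 1] <= 0), as.numeric(X[, 1] > 0))
fit <- depth1_policy(X, Gamma)
str(fit)
```

On this toy data the search recovers a split on the first covariate at a threshold just at or below 0, with action 1 on the left and action 2 on the right, collecting the maximum possible total reward of `n`.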
Thanks for the reply! This seems to have resolved my issue.