scikit-tree
Add meta-estimator for causal trees
Is your feature request related to a problem? Please describe.
Given any tree model, we can fit a causal tree by expanding the fit API to allow a treatment group to be passed in.
Describe the solution you'd like
Something similar in functionality to https://github.com/microsoft/EconML/blob/main/econml/dml/causal_forest.py, but with a much simpler implementation.
```python
from sklearn.base import clone

class CausalTree:
    def __init__(self, tree_model, ...):
        # instantiated tree model, passed in Pipeline-style
        self.tree_model = tree_model

    def fit(self, X, y, T):
        # nuisance fits on separate clones (refitting one model would overwrite it)
        self.propensity_model_ = clone(self.tree_model).fit(X, T)  # E[T | X]
        self.effect_model_ = clone(self.tree_model).fit(T, y)      # E[y | T]
        self.outcome_model_ = clone(self.tree_model).fit(X, y)     # E[y | X]
        # combine them to get the estimates for P(y | do(T))
```
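To make the "combine" step concrete, here is a hedged, self-contained sketch of a DML-style residual-on-residual estimate using plain sklearn trees. The data-generating process, depths, and seeds are all illustrative assumptions, not part of the proposal:

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 3))
T = 0.5 * X[:, 0] + rng.normal(size=500)       # treatment depends on X
y = 2.0 * T + X[:, 1] + rng.normal(size=500)   # true effect of T on y is 2.0

# nuisance fits: E[T | X] and E[y | X]
t_hat = DecisionTreeRegressor(max_depth=4, random_state=0).fit(X, T).predict(X)
y_hat = DecisionTreeRegressor(max_depth=4, random_state=0).fit(X, y).predict(X)

# partial X out of both treatment and outcome, then regress residual on residual
t_res, y_res = T - t_hat, y - y_hat
effect = float(t_res @ y_res / (t_res @ t_res))
```

A real implementation would cross-fit the nuisance models rather than predicting in-sample, but this shows the shape of the combining step.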
Additional context
Once this works, we could PR it to econml. I think we should be wary and make sure all the necessary functionality of causal trees is supported.
Note the instantiation process for a CausalTree will be very similar to that of an sklearn Pipeline, where the tree_model is instantiated outside of the CausalTree and passed in. This just makes for a simpler API.
https://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html
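For comparison, this is the Pipeline pattern the proposed API would mirror: the inner estimator is built by the user and handed to the composite object.

```python
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeRegressor

# the estimator is instantiated outside and passed in,
# the same instantiation style proposed for CausalTree
pipe = Pipeline([
    ("scale", StandardScaler()),
    ("tree", DecisionTreeRegressor(max_depth=3)),
])
```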
Ideally this causal tree class allows all the functionality that econml currently supports:
- grf (probably just for the sake of continuity)
- doubly robust trees
- dml trees
- orf tree (note this is entirely in Python, so we can easily cythonize it)
https://github.com/microsoft/EconML/tree/main/econml. Tbh I'm not entirely sure of the differences between the dml tree, grf, and the causal tree implementation in econml.
IIUC, GRF are just regular trees with:
- honesty
- some fancy criterion that solves a moment equation
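For reference, "honesty" can be sketched with plain sklearn pieces: one sample half learns the tree structure, the other half populates the leaf estimates. This is an illustrative assumption about the mechanism, not econml's actual implementation:

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(1)
X = rng.normal(size=(400, 2))
y = X[:, 0] + rng.normal(size=400)

# honest split: the structure sample grows the tree,
# the estimation sample supplies the leaf-level values
X_s, y_s = X[:200], y[:200]
X_e, y_e = X[200:], y[200:]

tree = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X_s, y_s)
leaves = tree.apply(X_e)  # leaf index for each estimation point
honest_means = {leaf: y_e[leaves == leaf].mean() for leaf in np.unique(leaves)}

def honest_predict(X_new):
    # route new points through the structure tree, read off honest leaf means
    return np.array([honest_means[leaf] for leaf in tree.apply(X_new)])
```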
I don't think we need to implement this as a high priority, since the DML approach (one model fit for the propensity and then one model fit for the outcome) seems like the better bet anyways, right?
WDYT @sampan501
Agreed seems low priority to me
Do you know if the other types of trees need to use more exotic criteria? E.g. in econml, 'het' and 'mse' are these moment-equation solvers. I guess for continuity we should ideally simplify their implementation and have it ourselves as well, but can the other normal sklearn criteria be used?
I don't see why not? but maybe I missed something.
Not that I know of, but we can add a parameter when building the object
Okay that's good to know. The parameter would just be the normal criterion keyword argument, which can be 'gini', 'poisson', etc.
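i.e. the existing sklearn criteria would pass straight through as a constructor kwarg on the wrapped tree:

```python
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

# classification criteria: "gini", "entropy", ...
clf = DecisionTreeClassifier(criterion="gini")
# regression criteria: "squared_error", "poisson", ...
reg = DecisionTreeRegressor(criterion="poisson")
```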
Re the 'het' and 'mse' criteria in econml (note this 'mse' is not the same as the MSE currently in sklearn, so we should probably call it something else...), we'll have to replicate the functionality here: https://github.com/py-why/EconML/blob/main/econml/grf/_criterion.pyx. Doesn't look too bad.
Some API issues to figure out.
In sklearn, we have:
- fit(X, y, sample_weight)
- predict(X)
- score(X, y)
Those are probably the main API we want to "override". In causal land, we want something like:
- fit(X, y, sample_weight, t=None, W=None, Z=None)
- predict(X) -> what does this even mean? Predict CATE? Predict average treatment effect (ATE)? Something else?
- score(X) -> w/o having a good definition of predict, not sure either
Econml also exposes a class API for getting confidence intervals. I think this is not necessary and overcomplicates the classes. We should just provide a function to get confidence intervals for the predicted causal effects.
Based on the answers above for causal, we might want to expose API for getting specific types of effects like they do in econml, such as CATE, ATE, marginal_CATE, etc.
Some possible solutions
- fit still takes X, y, sample_weight in that order, but also adds optional kwargs, such as T and Z, which are checked for depending on the meta-estimator. An error is raised if they are not present but required, e.g. in DML.
- Same with predict and score.
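A hypothetical sketch of that kwargs-checking; the class name, attribute names, and error message are all made up for illustration:

```python
import numpy as np

class DMLTreeSketch:
    """Illustrative only: shows the optional-kwarg checking, not a real estimator."""

    def fit(self, X, y, sample_weight=None, T=None, Z=None):
        # DML needs the treatment array; fail loudly if the caller omits it
        if T is None:
            raise ValueError("DML requires the treatment array T to be passed to fit.")
        # ... nuisance fits would go here ...
        self.n_features_in_ = np.asarray(X).shape[1]
        return self
```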
We just have to verify that at least this works with sklearn's testing function parametrize_with_checks. If so, then at least we can be assured that the trees are pretty compatible with the rest of the sklearn codebase.
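For example, the same compliance suite can be run directly on a base estimator via check_estimator, the non-pytest sibling of parametrize_with_checks (shown here on a stock sklearn tree, not our hypothetical meta-estimator):

```python
from sklearn.tree import DecisionTreeRegressor
from sklearn.utils.estimator_checks import check_estimator

# runs the full sklearn API-compliance suite; raises on the first failing check
check_estimator(DecisionTreeRegressor())
```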
Questions to Resolve
Open questions are:
- How does econml use W, the "controls"? It is honestly not super clear to me what the difference between W and X is, and moreover, how this impacts fitting/predicting.
- What does econml currently use predict for? What does it use score for?
Along the lines of causal trees, adding some notes to be aware of:
I've finished perusing and understanding the code in EconML about GRF and the general GRF paper. At an implementation level, this means we have two distinctly "different" kinds of trees that would approach the problems of:
- general regression (i.e. (X, y))
- CATE estimation (i.e. (X, y, T))
- LATE estimation (i.e. (X, y, T, Z))
The first is using GRF, which diverges from the sklearn semantics because this is explicitly a "gradient-based" tree: we solve for local gradients at every split node using the GRFCriterion.
The second is using the DML/DR multiple-fitting approaches with an honest regression tree as the basic ingredient.
The GRF implementation is quite a bit more complex, and if it were up to us I would be inclined to leave it out; but to achieve feature parity with econml, we should have a refactored version. The DML/DR approaches are considerably more straightforward, as they can be implemented as meta-estimators, as we have discussed in this issue and #52.
Besides the GRF paper, https://arxiv.org/pdf/1510.04342.pdf is also a good read.