Add meta-estimator for causal trees

Open adam2392 opened this issue 2 years ago • 10 comments
Is your feature request related to a problem? Please describe. Given any tree model, we can fit a causal tree by expanding the fit API to allow a treatment group to be passed in.

Describe the solution you'd like Something similar in functionality to https://github.com/microsoft/EconML/blob/main/econml/dml/causal_forest.py, but with a much simpler implementation.


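from sklearn.base import clone

class CausalTree:
    def __init__(self, tree_model):
        # instantiated tree model, used as a template (other parameters elided)
        self.tree_model = tree_model

    def fit(self, X, y, T):
        # clone per fit so that one fit does not overwrite another;
        # T is assumed to be 2D wherever it is used as the feature matrix
        self.model_t_ = clone(self.tree_model).fit(X, T)
        self.model_yt_ = clone(self.tree_model).fit(T, y)
        self.model_y_ = clone(self.tree_model).fit(X, y)
        # TODO: combine them to get the estimates for P(y | do(X))
        return self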

Additional context Once this works, we could PR to econml. I think we should be wary and make sure all necessary functionality of causal trees is supported.

Note that instantiating a CausalTree will be very similar to instantiating an sklearn Pipeline: the tree_model should be instantiated outside of the CausalTree. This just makes for a simpler API.

https://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html
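A minimal sketch of the Pipeline-style pattern described above, assuming a plain DecisionTreeRegressor as the tree model. The variable names are illustrative only. The point is that the estimator is built by the caller, and a wrapper can then use sklearn's clone to obtain unfitted copies with the same parameters.

```python
from sklearn.base import clone
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeRegressor

# The caller instantiates the tree; the wrapper only receives the object.
tree = DecisionTreeRegressor(max_depth=3, random_state=0)
pipe = Pipeline([("scale", StandardScaler()), ("tree", tree)])

# A meta-estimator can treat the passed-in tree as a template and
# clone it whenever it needs a fresh, unfitted copy.
tree_copy = clone(tree)
print(tree_copy.get_params()["max_depth"])
```

Pipeline stores the exact object it was given, whereas clone returns a new estimator that shares only the constructor parameters.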

adam2392 avatar Mar 02 '23 20:03 adam2392

Ideally this causal tree class allows all the functionality that econml currently supports:

  • grf (probably just for sake of continuity)
  • doubly robust trees
  • dml trees
  • orf tree (note this is entirely in Python, so we can easily cythonize it)

https://github.com/microsoft/EconML/tree/main/econml. Tbh I'm not entirely sure about the differences between the dml tree, grf and the causal tree implementation in econml.

adam2392 avatar Mar 06 '23 19:03 adam2392

IIUC, GRF are just regular trees with:

  • honesty
  • some fancy criterion that solves a moment equation
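
To make the honesty item concrete, here is a rough sketch on synthetic data. It assumes honesty means that one half of the samples chooses the splits and the other half estimates the leaf values. The helper name honest_predict is hypothetical, and this is not how EconML or GRF implement it internally.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(0)
X = rng.normal(size=(400, 3))
y = X[:, 0] + rng.normal(scale=0.1, size=400)

# Split the data: one half picks the splits, the other half fills the leaves.
idx = rng.permutation(len(X))
split_idx, est_idx = idx[:200], idx[200:]

tree = DecisionTreeRegressor(max_depth=3, random_state=0)
tree.fit(X[split_idx], y[split_idx])

# Re-estimate each leaf value from the held-out estimation half.
leaves_est = tree.apply(X[est_idx])
honest_values = {
    leaf: y[est_idx][leaves_est == leaf].mean() for leaf in np.unique(leaves_est)
}

def honest_predict(X_new):
    """Hypothetical helper: predict with the honest leaf values."""
    leaves = tree.apply(X_new)
    # Fall back to the tree's own value for leaves with no estimation samples.
    fallback = tree.predict(X_new)
    return np.array([honest_values.get(l, f) for l, f in zip(leaves, fallback)])

preds = honest_predict(X[:5])
```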

I don't think we need to implement this as a high priority, since the DML trees (a model fit for the propensity, then a model fit for the outcome) seem like the better bet anyway, right?

WDYT @sampan501

adam2392 avatar Mar 08 '23 05:03 adam2392

Agreed seems low priority to me

sampan501 avatar Mar 08 '23 13:03 sampan501

Do you know if the other types of trees need to use more exotic criteria? E.g. in econml, the 'het' and 'mse' criteria are these moment-equation solvers. I guess for continuity we should ideally simplify their implementation and have it ourselves as well, but can the other normal sklearn criteria be used?

I don't see why not? but maybe I missed something.

adam2392 avatar Mar 08 '23 16:03 adam2392

Not that I know of, but we can add a parameter when building the object

sampan501 avatar Mar 08 '23 16:03 sampan501

Okay that's good to know. The parameter would just be the normal criterion keyword argument, which can be 'gini', 'poisson', etc.
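
As a small illustration of the criterion keyword, assuming stock sklearn trees as the templates: the criterion is set when the tree is built, and it carries over when a wrapper clones the template.

```python
from sklearn.base import clone
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

# The criterion is chosen when the template tree is built.
clf_template = DecisionTreeClassifier(criterion="gini")
reg_template = DecisionTreeRegressor(criterion="poisson")

# clone() copies constructor parameters, so the criterion carries over.
clf_copy = clone(clf_template)
reg_copy = clone(reg_template)
print(clf_copy.criterion, reg_copy.criterion)
```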

adam2392 avatar Mar 08 '23 21:03 adam2392

Re 'het' and the 'mse' (note this mse is not the same as the MSE currently in sklearn, so we should probably call it something else...) criterion in econml, we'll have to replicate the functionality here: https://github.com/py-why/EconML/blob/main/econml/grf/_criterion.pyx. Doesn't look too bad.

adam2392 avatar Mar 08 '23 21:03 adam2392

Some API issues to figure out.

In sklearn, we have:

  • fit(X, y, sample_weight)
  • predict(X)
  • score(X, y)

Those are probably the main API we want to "override". In causal land, we want something like:

  • fit(X, y, sample_weight, T=None, W=None, Z=None)
  • predict(X) -> what does this even mean? Predict CATE? Predict average treatment effect (ATE)? Something else?
  • score(X) -> w/o having a good definition of predict, not sure either

Econml also exposes a class API for getting confidence intervals. I think this is not necessary and overcomplicates the classes. We should just provide a function to get confidence intervals for the predicted causal effects.

Based on the answers above for causal, we might want to expose API for getting specific types of effects like they do in econml, such as CATE, ATE, marginal_CATE, etc.

Some possible solutions

  • fit still takes X, y, sample_weight in that order, but also adds optional kwargs, such as T and Z, which are checked for depending on the meta-estimator. An error is raised if they are required (e.g. in DML) but not present.
  • Same with predict and score.
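
The optional-kwarg idea above could look roughly like this. The class name TreatmentAwareEstimator is hypothetical, and the fitting step (outcome regressed on X and T stacked together) is only a placeholder so the example runs. It is not a causal estimator and not EconML's API.

```python
import numpy as np
from sklearn.base import BaseEstimator, clone
from sklearn.tree import DecisionTreeRegressor

class TreatmentAwareEstimator(BaseEstimator):
    """Hypothetical meta-estimator; name and behavior are illustrative only."""

    def __init__(self, estimator=None):
        self.estimator = estimator

    def fit(self, X, y, sample_weight=None, T=None, Z=None):
        # X, y, sample_weight keep the sklearn order; T and Z are optional.
        if T is None:
            raise ValueError("T (treatment) is required by this meta-estimator.")
        base = DecisionTreeRegressor() if self.estimator is None else self.estimator
        X = np.asarray(X)
        T = np.asarray(T).reshape(len(X), -1)
        # Placeholder step: fit the outcome on [X, T] so the sketch is runnable.
        self.estimator_ = clone(base).fit(
            np.hstack([X, T]), y, sample_weight=sample_weight
        )
        return self

rng = np.random.default_rng(0)
X = rng.normal(size=(20, 2))
y = X[:, 0]
T = (X[:, 1] > 0).astype(int)

fitted = TreatmentAwareEstimator().fit(X, y, T=T)
try:
    TreatmentAwareEstimator().fit(X, y)
    missing_t_raised = False
except ValueError:
    missing_t_raised = True
```

Calling fit(X, y) alone still matches the sklearn signature; it simply fails early with a clear message when the meta-estimator needs a treatment.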

Just have to verify that at least this works with sklearn's testing function: parametrize_with_checks. If so, then at least we can be assured that the trees are pretty compatible with the rest of the sklearn codebase.
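
For reference, the usual shape of a parametrize_with_checks test is sketched below. It uses a stock DecisionTreeRegressor as a stand-in, since the causal meta-estimator does not exist yet, and it assumes pytest is installed. The test function name is arbitrary.

```python
from sklearn.tree import DecisionTreeRegressor
from sklearn.utils.estimator_checks import parametrize_with_checks

# Swap the stand-in for the causal tree class once it exists.
@parametrize_with_checks([DecisionTreeRegressor()])
def test_sklearn_compatible_estimator(estimator, check):
    # pytest calls this once per (estimator, check) pair.
    check(estimator)
```

Running the file under pytest then reports each sklearn API check as a separate test case.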

Questions to Resolve

Open questions are:

  1. How does econml use W, the "controls"? It is honestly not clear to me what the difference between W and X is, or how it impacts fitting/predicting.
  2. What does econml currently use predict for? What does it use score for?

adam2392 avatar Mar 14 '23 18:03 adam2392

Along the lines of causal trees, adding some notes to be aware of:

I've finished perusing and understanding the code in EconML about GRF and the general GRF paper. At an implementation level, this means we have two distinctly "different" kinds of trees that would approach the problems of:

  • general regression (i.e. (X, y))
  • CATE estimation (i.e. (X, y, T))
  • LATE estimation (i.e. (X, y, T, Z))

The first is using GRF, which diverges from the sklearn semantics because this is explicitly a "gradient-based" tree. We are solving local gradients at every split node using the GRFCriterion.

The second is using the DML/DR multiple-fitting approaches with an honest regression tree as its basic ingredient.

The GRF implementation is quite a bit more complex, and if it were up to us I would be inclined to leave it out, but to achieve feature parity with econml we should have a refactored version. The DML/DR approaches are more straightforward, as they can be implemented as meta-estimators, as we have discussed in this issue and #52.
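
To illustrate why the DML route fits the meta-estimator pattern, here is a rough sketch of the partialling-out idea on synthetic data. It assumes a continuous treatment and a constant treatment effect, and it uses cross_val_predict for the cross-fitting. This is not EconML's implementation, and the final stage here is a single coefficient rather than a tree.

```python
import numpy as np
from sklearn.model_selection import cross_val_predict
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(0)
n = 2000
X = rng.normal(size=(n, 3))
T = 0.5 * X[:, 0] + rng.normal(size=n)                  # treatment depends on X
y = 2.0 * T + X[:, 1] + rng.normal(scale=0.5, size=n)   # simulated effect: 2.0

tree = DecisionTreeRegressor(min_samples_leaf=50, random_state=0)

# Out-of-fold predictions for the two nuisance models E[T | X] and E[y | X].
t_hat = cross_val_predict(tree, X, T, cv=5)
y_hat = cross_val_predict(tree, X, y, cv=5)

# Residual-on-residual regression gives the (constant) effect estimate.
t_res, y_res = T - t_hat, y - y_hat
theta = np.sum(t_res * y_res) / np.sum(t_res ** 2)
print(theta)
```

Both nuisance fits use the same unfitted tree template, which is what a meta-estimator wrapping an arbitrary tree model would do internally.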

adam2392 avatar Mar 23 '23 20:03 adam2392

Besides the GRF paper, https://arxiv.org/pdf/1510.04342.pdf is a good paper.

adam2392 avatar Jun 12 '23 18:06 adam2392