pymc-experimental

Creating function to transform Graphviz DPG model into Causal DAG

Status: Open · cetagostini opened this issue 4 months ago · 13 comments

Hey team — following up on https://github.com/pymc-devs/pymc/issues/6716.

The main idea is to allow users to derive a quick causal-style DAG from a PyMC model’s Graphviz diagram (the DPG, i.e., the directed probabilistic/model graph from pm.model_to_graphviz). This is feasible under clear assumptions and with explicit limitations.

Different regression specifications can correspond to the same underlying causal DAG. For example, in a simple chain A → B → C, depending on the estimand one might fit C ~ B + ε or C ~ A + ε. Both are compatible with the same causal story, yet pm.model_to_graphviz will (rightly) produce different model graphs for those different PyMC specifications.

What this is (and isn’t)

  • Is: a visualization helper that maps a PyMC model graph to a compact, causal-style DAG by:

    • keeping only pm.Data nodes,
    • optionally showing selected unobserved modeled effects (e.g., priors) as dashed ellipses,
    • drawing edges to the first data node(s) encountered along any directed path (don’t traverse past a data node).
  • Is not: causal discovery or identification. It does not infer causal structure; it presents a causal-style view implied by a specific PyMC model.

Using @drbenvincent’s blog post “Causal inference: have you been doing science wrong all this time?” as inspiration, here are some examples based on the following DAG.

[Image: the causal DAG used in the examples]

Example 1: Adjustment model Y ~ X + Q

Suppose we care about the effect of X on Y. Given the causal DAG, we must adjust for the confounder Q:

import pymc as pm

# df: a pandas DataFrame with the observed columns Q, X, Y (plus P in Example 2)
with pm.Model() as x_y_model:
    _Q = pm.Data("_Q", df["Q"])
    _X = pm.Data("_X", df["X"])
    _Y = pm.Data("_Y", df["Y"])

    beta_q = pm.Normal("beta_q")
    beta_x = pm.Normal("beta_x")
    sigma_y = pm.HalfNormal("sigma")

    Y = pm.Normal("Y", mu=beta_x * _X + beta_q * _Q, sigma=sigma_y, observed=_Y)

# pymc_dpg_to_causal_dag: the helper proposed in this issue (draft in the Colab notebook)
causal_src = pymc_dpg_to_causal_dag(pm.model_to_graphviz(x_y_model).source)

Result (causal-style): _X → _Y and _Q → _Y, exactly the edges we expect for this adjustment set.

[Image: causal-style DAG for Example 1]

[!Note] To capture other effects, such as Q on Y or Q on P, we could fit more regressions with other adjustment sets, and this function would produce a different C-DAG for each (each DPG maps to its own causal-style DAG).

Example 2: Fully specified “luxury” model

If we know the true causal DAG and build a fully specified PyMC model that mirrors it, the helper reproduces that DAG:

with pm.Model() as full_luxury_model:
    _Q = pm.Data("_Q", df["Q"])
    _X = pm.Data("_X", df["X"])
    _Y = pm.Data("_Y", df["Y"])
    _P = pm.Data("_P", df["P"])

    # slopes
    qx = pm.Normal("qx")      # X ~ Q
    xy = pm.Normal("xy")      # Y ~ X
    qy = pm.Normal("qy")      # Y ~ Q
    xp = pm.Normal("xp")      # P ~ X
    yp = pm.Normal("yp")      # P ~ Y

    # scales
    sigma_x = pm.HalfNormal("sigma_x")
    sigma_y = pm.HalfNormal("sigma_y")
    sigma_p = pm.HalfNormal("sigma_p")

    Q = pm.Normal("Q", observed=_Q)
    X = pm.Normal("X", mu=qx * _Q,                 sigma=sigma_x, observed=_X)
    Y = pm.Normal("Y", mu=xy * _X + qy * _Q,       sigma=sigma_y, observed=_Y)
    P = pm.Normal("P", mu=xp * _X + yp * _Y,       sigma=sigma_p, observed=_P)

causal_src = pymc_dpg_to_causal_dag(pm.model_to_graphviz(full_luxury_model).source)

Output: The original causal diagram 🔥

[Image: causal-style DAG for Example 2]
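A minimal sketch of the first-hit rule, checked against this example in plain Python. Assumptions (not taken from the draft notebook): the model graph is available as a hand-written list of (parent, child) edges, and each observed RV points to the pm.Data container it is observed on (e.g. X → _X). `first_hit_edges` is a hypothetical name, not an existing PyMC function.

```python
def first_hit_edges(edges, data_nodes):
    """Hypothetical helper: collapse a model graph onto its data nodes.

    From each data node, walk forward along directed edges through non-data
    nodes; on reaching a data node, record an edge and stop (first-hit rule).
    """
    children = {}
    for parent, child in edges:
        children.setdefault(parent, []).append(child)

    data = set(data_nodes)
    result = set()
    for src in data_nodes:
        stack, seen = list(children.get(src, [])), set()
        while stack:
            node = stack.pop()
            if node in seen:
                continue
            seen.add(node)
            if node in data:
                result.add((src, node))  # first hit: do not traverse past it
            else:
                stack.extend(children.get(node, []))
    return result


# Assumed model graph for full_luxury_model as (parent, child) edges.
# Observed RVs point to the pm.Data container they are observed on.
full_luxury_edges = [
    ("Q", "_Q"),
    ("_Q", "X"), ("qx", "X"), ("sigma_x", "X"), ("X", "_X"),
    ("_X", "Y"), ("_Q", "Y"), ("xy", "Y"), ("qy", "Y"), ("sigma_y", "Y"), ("Y", "_Y"),
    ("_X", "P"), ("_Y", "P"), ("xp", "P"), ("yp", "P"), ("sigma_p", "P"), ("P", "_P"),
]

collapsed = first_hit_edges(full_luxury_edges, ["_Q", "_X", "_Y", "_P"])

# Intended causal DAG: X ~ Q, Y ~ X + Q, P ~ X + Y.
causal_dag = {("Q", "X"), ("Q", "Y"), ("X", "Y"), ("X", "P"), ("Y", "P")}
print({(a.lstrip("_"), b.lstrip("_")) for a, b in collapsed} == causal_dag)  # True
print(("_Q", "_P") in collapsed)  # False
```

Stopping at the first data node is what keeps `_Q → _P` out of the result, even though the path `_Q → Y → _Y → P → _P` exists in the model graph.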

Proposed API

pymc_dpg_to_causal_dag(
    model_or_dot,                    # pm.Model | graphviz.Digraph | DOT string
    *,
    first_hit_only=True,             # connect to closest downstream Data
    node_style='style="filled"',     # Data node style
    unobserved_vars=None,            # e.g. ["intercept", "sigma"]; dashed ellipses
) -> graphviz.Source                 # render with .render() or display in notebooks
  • First-hit rule: From each source node, walk forward; when you hit any pm.Data node, draw an edge and stop (don’t traverse beyond that data node). This avoids dense transitive edges and respects mediators (e.g., _Q → _Y, not _Q → _P via _Y).
  • Unobserved variables: If provided, render the listed node IDs (that exist in the graph) as dashed ellipses and connect them to their closest downstream data nodes.
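A sketch of the DOT-in, DOT-out shape of the proposed API, under explicit assumptions: it only understands single `a -> b` edge statements (parsed with a regex, so chained edges and quoted names with spaces are not handled), it takes the data node names as an argument instead of detecting pm.Data nodes in the DOT source, and it returns a DOT string rather than a `graphviz.Source`. `collapse_dot` and the hand-written DOT input for Example 1 are hypothetical, not PyMC's exact output.

```python
import re

# Matches simple edge statements such as `_Q -> Y` or `"_Q" -> "Y"`.
EDGE_RE = re.compile(r'"?([\w.]+)"?\s*->\s*"?([\w.]+)"?')


def collapse_dot(dot_source, data_nodes, unobserved_vars=None):
    """Hypothetical sketch: collapse a DOT model graph onto its data nodes (first-hit rule)."""
    children = {}
    for parent, child in EDGE_RE.findall(dot_source):
        children.setdefault(parent, []).append(child)

    data = set(data_nodes)
    # Only keep unobserved variables that actually exist in the graph.
    known = set(children) | {c for cs in children.values() for c in cs}
    extra = [v for v in (unobserved_vars or []) if v in known and v not in data]

    edges = set()
    for src in list(data_nodes) + extra:
        stack, seen = list(children.get(src, [])), set()
        while stack:
            node = stack.pop()
            if node in seen:
                continue
            seen.add(node)
            if node in data:
                if node != src:
                    edges.add((src, node))  # first hit: stop here
            else:
                stack.extend(children.get(node, []))

    lines = ["digraph {"]
    lines += [f'  "{n}" [shape=box style=filled]' for n in data_nodes]
    lines += [f'  "{n}" [shape=ellipse style=dashed]' for n in extra]
    lines += [f'  "{a}" -> "{b}"' for a, b in sorted(edges)]
    lines.append("}")
    return "\n".join(lines)


# Simplified, hand-written stand-in for the model graph of x_y_model.
dot = """digraph {
  _Q -> Y
  _X -> Y
  beta_q -> Y
  beta_x -> Y
  sigma -> Y
  Y -> _Y
}"""
print(collapse_dot(dot, ["_Q", "_X", "_Y"], unobserved_vars=["beta_x", "intercept"]))
```

Names in `unobserved_vars` that are not in the graph (here `intercept`) are skipped, mirroring the "that exist in the graph" condition above; passing the returned string to `graphviz.Source(...)` would make it renderable.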

[!Note] We could add this to pymc-extras first, see how the community uses it, and then decide whether it belongs in the main PyMC repo.

Limitations & guidance

  • This helper does not discover causal structure or validate identification.
  • Multiple PyMC model specifications that implement valid adjustment sets can render to the same or to different causal-style DAGs; that's intended, because the mapping is defined per DPG (each DPG maps to exactly one causal-style DAG).
  • Best used when the causal DAG is known and you want to check that the PyMC model mirrors it, or to communicate the causal story more cleanly than the full probabilistic graph does.

If there’s interest, I’m happy to open a PR adding this as a documented recipe (with tests) or a small utility in an examples module. Check my draft in Google Colab.

cc: @drbenvincent @jessegrabowski @ricardoV94 @cluhmann @twiecki

cetagostini · Aug 28 '25 22:08

I don't understand the introductory remarks here. They seem to confuse the issue.

> Different regression specifications can correspond to the same underlying causal DAG. For example, in a simple chain A → B → C, depending on the estimand one might fit C ~ B + ε or C ~ A + ε. Both are compatible with the same causal story, yet pm.model_to_graphviz will (rightly) produce different model graphs for those different PyMC specifications.

So the existing graphviz tools do the correct thing causally in this example. So what's the point of this example?

> If we naïvely transform each DPG into a causal DAG, we risk showing different causal diagrams for models that implement the same causal structure (e.g., when using minimal adjustment sets). The proposed helper instead collapses through non-data nodes and connects variables to their closest downstream data nodes (“first-hit” rule), yielding a stable causal-style view.

How would we "naïvely transform each DPG into a causal DAG"? Highlighting risks of an unspecified method seems distracting.

This introduction seems distracting at best and misleading at worst.

For the examples of the new method, can we see the graphviz representation so that we can compare the "traditional" graphical representation with the proposed new representation?

cluhmann · Aug 28 '25 22:08

@cluhmann Yes, I took out that initial part; it was really badly phrased!

cetagostini · Aug 28 '25 22:08

Examples here!

Estimating the effect of X on Y, conditioning on Q

The PyMC DPG.

[Image: PyMC model graph (DPG)]

The PyMC causal-style DAG.

[Image: causal-style DAG]

Full luxury Bayes model

The PyMC DPG.

[Image: old version of the model graph, from Ben's post] The new version is too big for a screenshot; see [my draft in Google Colab](https://colab.research.google.com/drive/1f4e9BC60Y8-rb2FQZojcZXK8zSifwZw2#scrollTo=4OAq1BTQn9Rv).

The PyMC causal-style DAG.

[Image: causal-style DAG]

cetagostini · Aug 28 '25 22:08

> Old version from Ben. The new version is too big for a screenshot; take a look at the notebook.

I don't believe this graph corresponds to the model you have written here. The edge from Q to _Q seems backwards, for example. And there shouldn't be any direct edge from Q to X.

cluhmann · Aug 28 '25 23:08

> I don't believe this graph corresponds to the model you have written here. The edge from Q to _Q seems backwards, for example. And there shouldn't be any direct edge from Q to X.

Yes, take a look at the notebook; this was an equivalent but older model from Ben's post.

cetagostini · Aug 29 '25 07:08

Does it always just make a graph with only the data nodes? If so, something like this would do it, except that the edges are all dropped:

import pymc as pm

# model: an existing pm.Model; plot only its pm.Data nodes
data_nodes = [var.name for var in model.data_vars]
pm.model_to_graphviz(model, var_names=data_nodes)

williambdean · Aug 29 '25 07:08

@williambdean That's an option, but it can't add selected random variables or deterministics as unobserved variables in the C-DAG. And, as you mentioned, the edges are missing because data nodes don't point directly to other data nodes (I think); they point to random variables, and this approach has no way to propagate an edge through those variables to the next data node.
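A small illustration of the point above, using a hand-written edge list for Example 1's model graph (an assumption, not extracted from PyMC). Assuming a data-only view keeps just the edges whose two endpoints are both data nodes, nothing survives, because every path between data nodes runs through the likelihood node Y. Contracting through Y (a one-hop contraction, which is enough for this particular graph) recovers the data-to-data links.

```python
# Assumed model graph for x_y_model (Example 1) as (parent, child) edges.
x_y_edges = [
    ("_Q", "Y"), ("_X", "Y"),
    ("beta_q", "Y"), ("beta_x", "Y"), ("sigma", "Y"),
    ("Y", "_Y"),  # assumed: observed RV -> the pm.Data container it is observed on
]
data_nodes = {"_Q", "_X", "_Y"}

# Data-only view: keep only direct data -> data edges.
induced = [(a, b) for a, b in x_y_edges if a in data_nodes and b in data_nodes]
print(induced)  # []

# One-hop contraction through a non-data node: data -> node -> data becomes data -> data.
through_y = sorted(
    (a, c)
    for a, b in x_y_edges if a in data_nodes
    for b2, c in x_y_edges if b2 == b and c in data_nodes
)
print(through_y)  # [('_Q', '_Y'), ('_X', '_Y')]
```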

cetagostini · Aug 29 '25 11:08

> I don't believe this graph corresponds to the model you have written here. The edge from Q to _Q seems backwards, for example. And there shouldn't be any direct edge from Q to X.

> Yes, take a look at the notebook; this was an equivalent but older model from Ben's post.

Ben's notebook doesn't appear to have any Q or X in it. Nor does it seem to have any PyMC model specified in it.

cluhmann · Aug 29 '25 13:08

@cluhmann Not sure we're looking at the same thing, but this is @drbenvincent's initial model:

with pm.Model() as model:
    # data
    _Q = pm.MutableData("_Q", df["Q"])
    _X = pm.MutableData("_X", df["X"])
    _Y = pm.MutableData("_Y", df["Y"])
    _P = pm.MutableData("_P", df["P"])

    # priors on slopes
    # x ~ q
    qx = pm.Normal("qx")
    # y ~ x + q
    xy = pm.Normal("xy")
    qy = pm.Normal("qy")
    # p ~ x + y
    xp = pm.Normal("xp")
    yp = pm.Normal("yp")

    # priors on sd's
    sigma_x = pm.HalfNormal("sigma_x")
    sigma_y = pm.HalfNormal("sigma_y")
    sigma_p = pm.HalfNormal("sigma_p")
    
    # model
    Q = pm.Normal("Q", observed=_Q)
    X = pm.Normal("X", mu=qx*Q, sigma=sigma_x, observed=_X)
    Y = pm.Normal("Y", mu=xy*X + qy*Q, sigma=sigma_y, observed=_Y)
    P = pm.Normal("P", mu=xp*X + yp*Y, sigma=sigma_p, observed=_P)

What do you mean by "doesn't appear to have any Q or X in it"? This is comparable to the model I built above in the description and in the draft notebook.

cetagostini · Aug 29 '25 13:08

That's not the same model. For example, you have:

X = pm.Normal("X", mu=qx*_Q, sigma=sigma_x, observed=_X)

whereas @drbenvincent has:

X = pm.Normal("X", mu=qx*Q, sigma=sigma_x, observed=_X)

Please just write a model and then present both the graphviz representation and your new representation.

cluhmann · Aug 29 '25 14:08

@cluhmann If you want to see it, fine, I can add it to the notebook. The reason for the change, though, is that @drbenvincent's model includes things that aren't needed. For predictors such as Q and X, we aren't interested in modeling them (Q, for example); we just want to condition on their observed values.

That might make sense for some users, but not for all. For example, Q = pm.Normal("Q", observed=_Q) doesn't affect inference at all: _Q is exogenous and that likelihood term doesn't depend on any parameters, so it is redundant and only adds complexity for the sampler. Conditioning on the data directly, as I shared above, makes the model closer to the DGP, which is the real point here.

Again, I can add it to the notebook so you can take a look, but it will produce a different DPG, which maps to its own C-DAG-style view. In this case, probably just nodes without connections, because the model is written in a way where no relationships between the data nodes are stated explicitly. That means we can't infer what the user wants to condition on.
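A sketch of why the blog-post version would collapse to isolated nodes under the first-hit rule. The edge list is an assumption: it follows the model code above and the convention, visible in the Q → _Q edge discussed earlier, that an observed RV points to its pm.Data container. Because the predictors are the RVs Q, X and Y rather than the data containers, every data node is a sink, so a walk that starts from a data node never reaches another one.

```python
# Assumed model graph for the blog-post model as (parent, child) edges.
ben_edges = [
    ("Q", "_Q"),
    ("Q", "X"), ("qx", "X"), ("sigma_x", "X"), ("X", "_X"),
    ("X", "Y"), ("Q", "Y"), ("xy", "Y"), ("qy", "Y"), ("sigma_y", "Y"), ("Y", "_Y"),
    ("X", "P"), ("Y", "P"), ("xp", "P"), ("yp", "P"), ("sigma_p", "P"), ("P", "_P"),
]
data_nodes = ["_Q", "_X", "_Y", "_P"]

# No data node has an outgoing edge, so first-hit traversal from data finds nothing.
outgoing = {d: [c for p, c in ben_edges if p == d] for d in data_nodes}
print(outgoing)  # {'_Q': [], '_X': [], '_Y': [], '_P': []}

# The causal structure lives among the RV nodes instead.
rv_nodes = {"Q", "X", "Y", "P"}
rv_edges = sorted((p, c) for p, c in ben_edges if p in rv_nodes and c in rv_nodes)
print(rv_edges)  # [('Q', 'X'), ('Q', 'Y'), ('X', 'P'), ('X', 'Y'), ('Y', 'P')]
```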

cetagostini · Aug 29 '25 20:08

@cluhmann I added the example with the old model from the article; notebook here.

cetagostini · Aug 29 '25 20:08

...no access

cluhmann · Aug 29 '25 21:08