
Abstract Graph Iteration

Open williambdean opened this issue 1 year ago • 15 comments

Description

Pulled the graph-related information into two methods:

  1. create_plates: gets the plate meta information and the nodes associated with each plate
  2. edges: yields the edges between nodes as a generator

These two methods are now the core logic of make_graph and make_networkx, which people can exploit for their own use cases.

Coming with this are two new classes:

  1. PlateMeta, which stores the dim names and dim lengths (dlen) from before
  2. NodeMeta, which stores the variable and its NodeType in the graph (introduced in #7302)
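For orientation, a minimal sketch of what these two containers could look like (the field names and the NodeType import path are assumptions based on the description above, not the exact code in this PR):

from dataclasses import dataclass

from pytensor.tensor import TensorVariable

from pymc.model_graph import NodeType  # introduced in #7302; import path assumed


@dataclass
class PlateMeta:
    """Dim names and dim lengths shared by the variables on a plate."""

    names: tuple[str | None, ...]
    lengths: tuple[int, ...]


@dataclass
class NodeMeta:
    """A model variable together with its NodeType in the graph."""

    var: TensorVariable
    node_type: NodeType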

I also need to figure out an example where "{var_name}_dim{dlen}" was used. @ricardoV94, would you know of one? I think this assumes there is only one variable on the plate, since the variable name is part of the label?

Related Issue

  • [x] Closes #7319
  • [ ] Related to

Checklist

Type of change

  • [x] New feature / enhancement
  • [ ] Bug fix
  • [ ] Documentation
  • [ ] Maintenance
  • [ ] Other (please specify):

📚 Documentation preview 📚: https://pymc--7392.org.readthedocs.build/en/7392/

williambdean avatar Jun 26 '24 08:06 williambdean

The names Plate and PlateMeta come from the historical get_plates method of ModelGraph. However, get_plates also returns scalars, which were labeled "" before and are now Plate(meta=None, variables=[...]).

Is Plate still a good name? It is a collection of variables that all share the same dims. Plate, in my mind, comes from Bayesian graphical models and might not quite fit the scalars. PlateMeta might be better suited as DimsMeta, since the names and sizes are the dims of the variables.

Any thoughts here on terminology?

williambdean avatar Jun 27 '24 10:06 williambdean

I'm okay with Plate or Cluster. Why the Meta in it?

ricardoV94 avatar Jun 27 '24 10:06 ricardoV94

I'm okay with Plate or Cluster. Why the Meta in it?

Meta would be the information about the variables / plate needed to construct a plate label. Previously it was always " x ".join([f"{dname} ({dlen})" for ...]). Meta just provides the parts so the label can be constructed from the same components as before.
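With those parts exposed, the old label is still easy to reproduce, e.g. (a small sketch, not the code in this PR):

dim_names = ("obs", "covariates")
dim_lengths = (5, 3)

label = " x ".join(f"{dname} ({dlen})" for dname, dlen in zip(dim_names, dim_lengths))
# "obs (5) x covariates (3)"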

williambdean avatar Jun 27 '24 11:06 williambdean

I don't love the word meta; it's too abstract. Plate.dim_names, Plate.dim_lengths, Plate.vars? Or Plate.var_names, if that's what we are storing.

ricardoV94 avatar Jun 27 '24 11:06 ricardoV94

I don't love the word meta; it's too abstract. Plate.dim_names, Plate.dim_lengths, Plate.vars? Or Plate.var_names, if that's what we are storing.

I think it'd be nice to keep the names and sizes together since they are related. How about DimInfo?

williambdean avatar Jun 27 '24 11:06 williambdean

Is the question whether we represent a data structure that looks like (in terms of access) ((dim_names, dim_lengths), var_names) vs (dim_names, dim_lengths, var_names)? Seems like a tiny detail. I have a slight preference for keeping it flat, but it's up to you.

ricardoV94 avatar Jun 27 '24 11:06 ricardoV94

This PR refreshed my mind that #6485 and #7048 exist.

To summarize: we can have variables with entries in named_vars_to_dims of type tuple[str | None, ...]. We can also have variables that don't show up in named_vars_to_dims at all. Which is odd, since we already allow None to represent unknown dims, so all variables could conceivably have entries (or we would not allow None).
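To make those two cases concrete, a small example (the exact contents of named_vars_to_dims here are my assumption):

import numpy as np
import pymc as pm
import pytensor.tensor as pt

with pm.Model(coords={"obs": range(5)}) as model:
    data = pt.as_tensor_variable(np.ones((5, 3)), name="data")
    pm.Deterministic("C", data, dims=("obs", None))  # entry with a None dim
    pm.Deterministic("D", data)  # no dims, so possibly no entry at all

model.named_vars_to_dims  # e.g. {"C": ("obs", None)} -- "D" has no entry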

Then dims can have coords or not, but they always have dim_lengths, and the fast_eval for dim_lengths always works, so that's not a problem that shows up here. I think it doesn't matter for us here; just mentioning it in case I brought it up by mistake in my comments.

ricardoV94 avatar Jun 27 '24 11:06 ricardoV94

Is the question whether we represent a data structure that looks like (in terms of access) ((dim_names, dim_lengths), var_names) vs (dim_names, dim_lengths, var_names)? Seems like a tiny detail. I have a slight preference for keeping it flat, but it's up to you.

There is also the NodeType, which is why I went for the small dataclass wrapper that contains the TensorVariable and the preprocessed label. I think having a small data structure isn't the end of the world, and it also helps structure the problem a bit more. In my mind, the user can clearly see what is part of the new data structures.

williambdean avatar Jun 28 '24 08:06 williambdean

Need to

  • [x] still cover the {var_name}_dim{d} case from #6335. Unless the naming should be changed?
  • [ ] Fix the previous tests

The #6335 case comes up with this example:

# Current main branch
import numpy as np
import pymc as pm
import pytensor.tensor as pt

coords = {
    "obs": range(5),
}
with pm.Model(coords=coords) as model:
    data = pt.as_tensor_variable(
        np.ones((5, 3)),
        name="data",
    )
    pm.Deterministic("C", data, dims=("obs", None))
    pm.Deterministic("D", data, dims=("obs", None))
    pm.Deterministic("E", data, dims=("obs", None))

pm.model_to_graphviz(model)

Result: (image: previous-with-none)

It makes sense that they will not be on the same plate, right?

williambdean avatar Jun 28 '24 10:06 williambdean

I did just catch this bug: it comes from make_compute_graph, which produces a self loop.

import numpy as np
import pymc as pm
import pytensor.tensor as pt

from pymc.model_graph import ModelGraph

coords = {
    "obs": range(5),
}
with pm.Model(coords=coords) as model:
    data = pt.as_tensor_variable(
        np.ones((5, 3)),
        name="C",
    )
    pm.Deterministic("C", data, dims=("obs", None))

error_compute_graph = ModelGraph(model).make_compute_graph() # defaultdict(set, {"C": {"C"}})
# Visualize error:
pm.model_to_graphviz(model)

Result:

(image: compute-graph-bug)
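In the meantime, a quick sketch of a workaround (stripping self-references after the fact; not necessarily where the real fix belongs):

compute_graph = ModelGraph(model).make_compute_graph()
for var_name, neighbors in compute_graph.items():
    neighbors.discard(var_name)  # drops self-loops like "C" -> "C"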

Shall I make a separate issue?

williambdean avatar Jun 28 '24 10:06 williambdean

I think they should be in the same plate, because in the absence of dims, the shape is used to cluster RVs?
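Something like this clustering key would do that (just a sketch of the idea, not this PR's implementation):

def plate_key(dim_names, dim_lengths):
    # Fall back to the raw length whenever a dim name is missing, so C, D,
    # and E from the example above all share the same key.
    return tuple(
        dname if dname is not None else dlen
        for dname, dlen in zip(dim_names, dim_lengths)
    )

plate_key(("obs", None), (5, 3))  # ("obs", 3)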

ricardoV94 avatar Jun 28 '24 10:06 ricardoV94

Self loop is beautiful :)

ricardoV94 avatar Jun 28 '24 10:06 ricardoV94

I think they should be in the same plate, because in the absence of dims, the shape is used to cluster RVs?

How should the {var_name}_dim{d} case be handled then, to put them on the same plate?

Just "dim{d} ({dlen})"?

williambdean avatar Jun 29 '24 06:06 williambdean

Just the length? What does a plate without any dims look like?

I imagine the mix would be 50 x trial(30) or however the trial dim is usually displayed.

WDYT?

ricardoV94 avatar Jun 29 '24 06:06 ricardoV94

Just the length? What does a plate without any dims look like?

I imagine the mix would be 50 x trial(30) or however the trial dim is usually displayed.

WDYT?

This mixing of dlen and "{dname} ({dlen})" is what I had in mind. That is the current behavior.
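A rough sketch of how that mixed label could be built (the helper name and signature here are hypothetical):

def create_plate_label(dim_names, dim_lengths, include_dim_lengths=True):
    parts = []
    for dname, dlen in zip(dim_names, dim_lengths):
        if dname is None:
            parts.append(str(dlen))  # unnamed dim: just the length
        elif include_dim_lengths:
            parts.append(f"{dname} ({dlen})")
        else:
            parts.append(dname)
    return " x ".join(parts)

create_plate_label(("obs", None), (5, 3))  # "obs (5) x 3"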

Here are some examples:

import numpy as np
import pymc as pm
import pytensor.tensor as pt

coords = {
    "obs": range(5),
}
with pm.Model(coords=coords) as model:
    data = pt.as_tensor_variable(
        np.ones((5, 3)),
        name="data",
    )
    C = pm.Deterministic("C", data, dims=("obs", None))
    D = pm.Deterministic("D", data, dims=("obs", None))
    E = pm.Deterministic("E", data, dims=("obs", None))

pm.model_to_graphviz(model)

(image: same-plate)

# Same as above
pm.model_to_graphviz(model, include_dim_lengths=False)

(image: same-plate-without)

And a larger example with various items:

import numpy as np
import pymc as pm
import pytensor.tensor as pt

coords = {
    "obs": range(5),
    "covariates": ["X1", "X2", "X3"],
}
with pm.Model(coords=coords) as model: 
    data1 = pt.as_tensor_variable(
        np.ones((5, 3)),
        name="data1",
    )
    data2 = pt.as_tensor_variable(
        np.ones((5, 3)),
        name="data2",
    )
    C = pm.Deterministic("C", data1, dims=("obs", None))
    CT = pm.Deterministic("CT", C.T, dims=(None, "obs"))
    D = pm.Deterministic("D", C @ CT, dims=("obs", "obs"))

    E = pm.Deterministic("E", data2, dims=("obs", None))
    beta = pm.Normal("beta", dims="covariates")
    pm.Deterministic("product", E[:, None, :] * beta[:, None], dims=("obs", None, "covariates"))

pm.model_to_graphviz(model)

(image: larger-example)

williambdean avatar Jun 30 '24 07:06 williambdean

Codecov Report

Attention: Patch coverage is 76.66667% with 28 lines in your changes missing coverage. Please review.

Project coverage is 92.18%. Comparing base (7af0a87) to head (e30f6d9). Report is 86 commits behind head on main.

Files with missing lines   Patch %   Lines
pymc/model_graph.py        76.66%    28 Missing :warning:
Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #7392      +/-   ##
==========================================
- Coverage   92.19%   92.18%   -0.01%     
==========================================
  Files         103      103              
  Lines       17214    17249      +35     
==========================================
+ Hits        15870    15901      +31     
- Misses       1344     1348       +4     
Files with missing lines   Coverage Δ
pymc/model_graph.py        87.25% <76.66%> (-0.13%) :arrow_down:

... and 5 files with indirect coverage changes

codecov[bot] avatar Jul 03 '24 12:07 codecov[bot]

Thanks @wd60622

ricardoV94 avatar Jul 03 '24 12:07 ricardoV94