Add function to mermaid diagram

Open williambdean opened this issue 6 months ago • 11 comments

Description

I explored using variables directly, which is definitely doable. Some helpful functions were:

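def _get_edges(variable):
    # Root variables (inputs, constants) have no owner, so no incoming edges.
    if variable.owner is None:
        return

    # Each edge is a (child, parent) pair: the variable and one of its inputs.
    yield from ((variable, input_var) for input_var in variable.owner.inputs)

    # Recurse into every input; a shared sub-graph is walked once per use.
    for input_var in variable.owner.inputs:
        yield from _get_edges(input_var)


def get_edges(variable):
    return list(_get_edges(variable))


def get_nodes(variable):
    edges = get_edges(variable)

    nodes = set()
    for child, parent in edges:
        nodes.add(child)
        nodes.add(parent)

    return nodes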

The pydot formatter that already exists in pytensor can provide some version of this already.

Some examples of this:

import pytensor
import pytensor.tensor as pt
from pytensor.mermaid import function_to_mermaid

alpha = pt.scalar("alpha")
beta = pt.vector("beta")
noise = pt.scalar("noise")

X = pt.matrix("X")

y = pt.dot(X, beta) + alpha + noise
y.name = "y"

fn = pytensor.function([X, alpha, beta, noise], y)
mermaid_code = function_to_mermaid(fn)

print(mermaid_code)
graph TD
%% Nodes:
n1["DimShuffle"]
n1@{ shape: rounded }
n2["noise"]
n2@{ shape: rect }
style n2 fill:#32CD32
n4["DimShuffle"]
n4@{ shape: rounded }
n5["alpha"]
n5@{ shape: rect }
style n5 fill:#32CD32
n7["Shape_i"]
n7@{ shape: rounded }
style n7 fill:#00FFFF
n8["X"]
n8@{ shape: rect }
style n8 fill:#32CD32
n8["X"]
n8@{ shape: rect }
style n8 fill:#32CD32
n10["AllocEmpty"]
n10@{ shape: rounded }
n12["CGemv"]
n12@{ shape: rounded }
n13["1.0"]
n13@{ shape: rect }
style n13 fill:#00FF7F
n14["beta"]
n14@{ shape: rect }
style n14 fill:#32CD32
n15["0.0"]
n15@{ shape: rect }
style n15 fill:#00FF7F
n17["Elemwise"]
n17@{ shape: rounded }
n18["y"]
n18@{ shape: rect }
style n18 fill:#1E90FF

%% Edges:
n2 --> n1
n5 --> n4
n8 --> n7
n7 --> n10
n10 --> n12
n13 --> n12
n8 --> n12
n14 --> n12
n15 --> n12
n12 --> n17
n4 --> n17
n1 --> n17
n17 --> n18
import pytensor.tensor as pt
import pytensor
from pytensor.mermaid import function_to_mermaid

x, y, z = pt.scalars('xyz')
e = x * y
op = pytensor.compile.builders.OpFromGraph([x, y], [e])
e2 = op(x, y) + z
op2 = pytensor.compile.builders.OpFromGraph([x, y, z], [e2])
e3 = op2(x, y, z) + z
f = pytensor.function([x, y, z], [e3])

print(function_to_mermaid(f))
graph TD
%% Nodes:
n1["OpFromGraph"]
n1@{ shape: rounded }
n2["x"]
n2@{ shape: rect }
style n2 fill:#32CD32
n3["y"]
n3@{ shape: rect }
style n3 fill:#32CD32
n4["z"]
n4@{ shape: rect }
style n4 fill:#32CD32
n4["z"]
n4@{ shape: rect }
style n4 fill:#32CD32
n6["Elemwise"]
n6@{ shape: rounded }
n7["dscalar"]
n7@{ shape: rect }
style n7 fill:#1E90FF

%% Edges:
n2 --> n1
n3 --> n1
n4 --> n1
n1 --> n6
n4 --> n6
n6 --> n7

Related Issue

  • [ ] Closes #
  • [ ] Related to #

Checklist

Type of change

  • [ ] New feature / enhancement
  • [ ] Bug fix
  • [ ] Documentation
  • [ ] Maintenance
  • [ ] Other (please specify):

williambdean • Jun 20 '25 21:06

Seeing that tests are failing because of no pydot. Might be able to mock that behavior...

williambdean • Jun 20 '25 21:06

Rather install and test properly. Elemwise isn't a great name though; we have to check the function that extracts the name. It should use str(op).

ricardoV94 • Jun 20 '25 21:06

Rather install and test properly.

Do you like the route of using pydot? Don't think it would be too hard to implement a custom formatter. We just need a light graph representation.
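To make the "light graph representation" idea concrete, here is a minimal sketch of a formatter that needs nothing but a list of edges. `edges_to_mermaid` and its `label` parameter are hypothetical names, not part of pytensor or of this PR; the output only imitates the `%% Nodes:` / `%% Edges:` layout shown in the description and omits the shape and style lines.

```python
def edges_to_mermaid(edges, label=str):
    """Render (source, target) pairs as Mermaid flowchart text.

    Hypothetical helper: nodes can be any hashable objects and `label`
    turns a node into its display text (for an op, this could be `str`).
    """
    ids = {}  # node -> Mermaid id, assigned in first-seen order

    def node_id(node):
        if node not in ids:
            ids[node] = f"n{len(ids) + 1}"
        return ids[node]

    # Build the edge lines first so that every node receives an id.
    edge_lines = [f"{node_id(src)} --> {node_id(dst)}" for src, dst in edges]
    node_lines = [f'{nid}["{label(node)}"]' for node, nid in ids.items()]
    return "\n".join(
        ["graph TD", "%% Nodes:", *node_lines, "", "%% Edges:", *edge_lines]
    )


mermaid_code = edges_to_mermaid([("x", "Exp"), ("Exp", "y")])
print(mermaid_code)
```

Keying the id table on the node object rather than on its label keeps two distinct nodes with the same label (for example two `Elemwise` ops) from collapsing into one box.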

Elemwise isn't a great name though; we have to check the function that extracts the name. It should use str(op).

Sure. Shall I change that in the formatter then, or go a different route?

williambdean • Jun 20 '25 22:06

I think pydot is overkill but I wouldn't mock if you're using it. Your original function was fine except it will iterate some edges repeatedly.

I'm sure you can repurpose FunctionGraph or something from graph.basic

ricardoV94 • Jun 20 '25 22:06

Your original function was fine except it will iterate some edges repeatedly.

Can you explain the case that is missing, or give an example that will fail?

I'm sure you can repurpose FunctionGraph or something from graph.basic

I will explore. Is the dprint logic general enough to support this?

williambdean • Jun 20 '25 22:06

Can you explain the case that is missing, or give an example that will fail?

I don't know if it will fail, but in a graph like this you may end up navigating x -> exp_x twice:

x = pt.scalar("x")
y = pt.exp(x)
out = y + y * 2
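One way to avoid the repeated walk is to remember which variables have already been expanded. The following is only a sketch: `Var` and `Apply` are hypothetical stand-ins that mimic the `owner` / `inputs` attributes of graph variables so the snippet runs without pytensor, and `get_unique_edges` is an illustrative name.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)  # eq=False keeps identity-based hashing
class Apply:
    inputs: list = field(default_factory=list)


@dataclass(eq=False)
class Var:
    name: str
    owner: "Apply | None" = None


def get_unique_edges(output):
    """Return (child, parent) edges, expanding each variable only once."""
    edges, seen, stack = [], set(), [output]
    while stack:
        var = stack.pop()
        # Skip roots and anything whose inputs were already emitted.
        if var in seen or var.owner is None:
            continue
        seen.add(var)
        for parent in var.owner.inputs:
            edges.append((var, parent))
            stack.append(parent)
    return edges


# Same structure as the graph above: out = y + y * 2, with y = exp(x).
x = Var("x")
y = Var("y", Apply([x]))
y_times_2 = Var("y*2", Apply([y, Var("2")]))
out = Var("out", Apply([y, y_times_2]))

names = [(child.name, parent.name) for child, parent in get_unique_edges(out)]
```

Here `y` is reached from both `out` and `y*2`, but its edge to `x` is emitted a single time. The explicit stack also sidesteps Python's recursion limit on deep graphs.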

I will explore. Is the dprint logic general enough to support this?

Not sure what you mean. I meant that in those modules you have utilities to iterate over a graph. It's unlikely you have to invent something new for that goal.

ricardoV94 • Jun 21 '25 08:06

The d3viz output for

x = pt.scalar("x")
y = pt.exp(x)
out = y + y * 2

looks like this (I constructed it by hand):

graph TD
A["x"]
style A fill:#32CD32
B["Elemwise"]
B@{ shape: rounded }
style B fill:#FF00FF
C["dscalar"]
style C fill:#1E90FF

A --> B
B --> C

Is that as expected? If not, the PydotFormatter is not correct.

williambdean • Jun 21 '25 13:06

A recent change has it looking like this:

        graph TD
        %% Nodes:
        n1["Composite"]
        n1@{ shape: rounded }
        n2["x"]
        n2@{ shape: rect }
        style n2 fill:#32CD32
        n3["out"]
        n3@{ shape: rect }
        style n3 fill:#1E90FF

        %% Edges:
        n2 --> n1
        n1 --> n3

williambdean • Jun 21 '25 13:06

Seems like you are compiling/rewriting the graph; otherwise the Composite wouldn't be introduced.

ricardoV94 • Jun 21 '25 14:06

Well, the d3 stuff only works for pytensor.function. I think it might be useful to deviate from that and also support TensorVariable. Any thoughts on whether the two should go in different directions?

williambdean • Jul 14 '25 03:07

We can make d3 work easily for variables as well, no? Just need a FunctionGraph? fg = FunctionGraph(outputs=[var], clone=False) and visualize that?

Also, don't narrow anything to TensorVariables, but try to work with any Variables (for instance XTensorVariables, but even stuff like Slice and RNGs).
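As a rough sketch of what "work with any Variables" could mean in practice: the helper below touches only `fg.toposort()`, `node.op`, `node.inputs` and `node.outputs`, and never checks the variable type. `fgraph_edges` is a hypothetical name, and the objects at the bottom are stand-ins (an assumption, so the snippet runs without pytensor); with pytensor one would pass the `FunctionGraph(outputs=[var], clone=False)` mentioned above instead.

```python
from types import SimpleNamespace as NS


def fgraph_edges(fg):
    """Collect (source, target) label pairs from a FunctionGraph-like object.

    Labels come from str(), so nothing here assumes the variables are tensors.
    """
    edges = []
    for node in fg.toposort():
        op_label = str(node.op)
        edges += [(str(inp), op_label) for inp in node.inputs]
        edges += [(op_label, str(out)) for out in node.outputs]
    return edges


# Stand-in graph for y = exp(x); out = y + 2 (plain strings act as variables).
exp_node = NS(op="Exp", inputs=["x"], outputs=["y"])
add_node = NS(op="Add", inputs=["y", "2"], outputs=["out"])
fake_fg = NS(toposort=lambda: [exp_node, add_node])

edges = fgraph_edges(fake_fg)
```

Using `str()` for labels is a simplification: unnamed variables or repeated ops with the same string would need a per-object id to stay distinct in the diagram.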

ricardoV94 • Jul 14 '25 06:07