MathOptAI.jl

Non-sequential connections for NN

Open • odow opened this issue

People want skip connections, e.g., for ICNNs (input convex neural networks).

There's nothing really blocking this. We just need to design the right approach.
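To make "non-sequential" concrete, here is one possible design, sketched purely as an illustration. The graph format and every name in it are hypothetical, not MathOptAI API. The network is stored as a directed acyclic graph in which each node lists the earlier outputs it consumes. A skip connection is then just a node with more than one input (in an ICNN, for example, the raw input feeds every layer), and evaluation walks the nodes in topological order.

```python
# Hypothetical DAG-of-layers representation with a skip connection.
# None of these names come from MathOptAI; this only illustrates the idea.

def affine(W, b):
    """Return a function computing W @ v + b (W is a list of rows)."""
    def f(v):
        return [sum(w * a for w, a in zip(row, v)) + bi for row, bi in zip(W, b)]
    return f

def relu(v):
    return [max(0.0, a) for a in v]

def add(u, v):
    return [a + c for a, c in zip(u, v)]

layer1 = affine([[1.0, -1.0], [0.5, 2.0]], [0.0, -1.0])
layer2 = affine([[2.0, 0.0], [0.0, 3.0]], [1.0, 0.0])

# Each node: (name, function, names of its inputs), in topological order.
# "x" is the network input.
graph = [
    ("h1", lambda v: relu(layer1(v)), ["x"]),
    ("h2", layer2, ["h1"]),
    ("y", add, ["h2", "x"]),  # skip connection: reuses the raw input x
]

def evaluate(graph, x, output="y"):
    """Evaluate every node in order, feeding each one its named inputs."""
    values = {"x": x}
    for name, fn, inputs in graph:
        values[name] = fn(*(values[i] for i in inputs))
    return values[output]

print(evaluate(graph, [1.0, 2.0]))  # [2.0, 12.5]
```

A formulation could walk the same graph in the same order, creating variables and constraints for each node instead of computing numbers.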

odow (Sep 23 '24)

@mjgarc do you have a small example in torch of a model you'd like to add that has skip connections?

odow (Sep 23 '24)

I'm going to close this until we have a motivation. In the short term, the VectorNonlinearOracle/GrayBox formulations can trivially handle these, so it's only an issue if we want to combine them with a global solver.
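The oracle route is architecture-agnostic because the solver only ever asks for the value and the Jacobian of the whole input-output map; the layer structure never has to be exposed. Below is a minimal sketch of that contract in plain Python. The tiny network and the central finite differences are assumptions for illustration; they stand in for the automatic differentiation a real oracle would use.

```python
import math

# A small network with a skip connection: y = tanh(W x) + x.
W = [[0.5, -0.2], [0.1, 0.3]]

def f(x):
    h = [math.tanh(sum(w * a for w, a in zip(row, x))) for row in W]
    return [hi + xi for hi, xi in zip(h, x)]

def jacobian(f, x, eps=1e-6):
    """Central finite-difference Jacobian of f at x (rows = outputs)."""
    m = len(f(x))
    J = [[0.0] * len(x) for _ in range(m)]
    for j in range(len(x)):
        xp, xm = list(x), list(x)
        xp[j] += eps
        xm[j] -= eps
        fp, fm = f(xp), f(xm)
        for i in range(m):
            J[i][j] = (fp[i] - fm[i]) / (2 * eps)
    return J

# At x = 0, tanh'(0) = 1, so the exact Jacobian is W + I.
J = jacobian(f, [0.0, 0.0])
```

Nothing in `jacobian` depends on how `f` is wired internally, which is why skip connections pose no problem for this approach.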

odow (Sep 30 '25)

I'd like to do this with the CANOS GNN model for power flow: https://github.com/MOSSLab-MIT/pfdelta/tree/main/core/models

Robbybp (Nov 19 '25)

Here's a PyTorch example constructing the model and running the forward pass:

import os
from core.datasets.opfdata import OPFData
from core.models.canos_opf import CANOS_OPF
dataset = OPFData(
    split="test",
    case_name="pglib_opf_case14_ieee",
    num_groups=1,
    topological_perturbations=True,
    pre_transform="mean_zero_variance_one",
    root=os.path.join("data", "opfdata"),
)
data = dataset[0]
canos = CANOS_OPF(
    dataset=dataset,
    hidden_dim=64,
    include_sent_messages=False,
    k_steps=3,
)
out = canos(data)
print("\nInput data:")
print(data)
print("\nCANOS GNN:")
print(canos)
print("\nOutput data:")
print(out)
Output:

Input data:
HeteroData(
  x=[1],
  objective=[1],
  bus={
    x=[14, 4],
    y=[14, 2],
    v_lims=[14, 2],
  },
  generator={
    x=[4, 11],
    y=[4, 2],
    p_lims=[4, 2],
    q_lims=[4, 2],
  },
  load={
    x=[11, 2],
    unnormalized=[11, 2],
  },
  shunt={
    x=[1, 2],
    unnormalized=[1, 2],
  },
  (bus, ac_line, bus)={
    edge_index=[2, 17],
    edge_attr=[17, 9],
    edge_label=[17, 4],
    branch_vals=[17, 9],
  },
  (bus, transformer, bus)={
    edge_index=[2, 3],
    edge_attr=[3, 11],
    edge_label=[3, 4],
    branch_vals=[3, 11],
  },
  (generator, generator_link, bus)={ edge_index=[2, 4] },
  (bus, generator_link, generator)={ edge_index=[2, 4] },
  (load, load_link, bus)={ edge_index=[2, 11] },
  (bus, load_link, load)={ edge_index=[2, 11] },
  (shunt, shunt_link, bus)={ edge_index=[2, 1] },
  (bus, shunt_link, shunt)={ edge_index=[2, 1] }
)

CANOS GNN:
CANOS_OPF(
  (encoder): Encoder(
    (node_projections): ModuleDict(
      (bus): Linear(in_features=4, out_features=64, bias=True)
      (generator): Linear(in_features=11, out_features=64, bias=True)
      (load): Linear(in_features=2, out_features=64, bias=True)
      (shunt): Linear(in_features=2, out_features=64, bias=True)
    )
    (edge_projections): ModuleDict(
      (('bus', 'ac_line', 'bus')): Linear(in_features=9, out_features=64, bias=True)
      (('bus', 'transformer', 'bus')): Linear(in_features=11, out_features=64, bias=True)
    )
  )
  (message_passing_layers): ModuleList(
    (0-2): 3 x InteractionNetwork(
      (edge_update): EdgeUpdate(
        (mlps): ModuleDict(
          (('bus', 'ac_line', 'bus')): Sequential(
            (0): Linear(in_features=192, out_features=64, bias=True)
            (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
            (2): ReLU()
            (3): Linear(in_features=64, out_features=64, bias=True)
          )
          (('bus', 'transformer', 'bus')): Sequential(
            (0): Linear(in_features=192, out_features=64, bias=True)
            (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
            (2): ReLU()
            (3): Linear(in_features=64, out_features=64, bias=True)
          )
          (('bus', 'generator_link', 'generator')): Sequential(
            (0): Linear(in_features=192, out_features=64, bias=True)
            (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
            (2): ReLU()
            (3): Linear(in_features=64, out_features=64, bias=True)
          )
          (('bus', 'load_link', 'load')): Sequential(
            (0): Linear(in_features=192, out_features=64, bias=True)
            (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
            (2): ReLU()
            (3): Linear(in_features=64, out_features=64, bias=True)
          )
          (('bus', 'shunt_link', 'shunt')): Sequential(
            (0): Linear(in_features=192, out_features=64, bias=True)
            (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
            (2): ReLU()
            (3): Linear(in_features=64, out_features=64, bias=True)
          )
        )
      )
      (node_update): NodeUpdate(
        (mlps): ModuleDict(
          (bus): Sequential(
            (0): Linear(in_features=128, out_features=64, bias=True)
            (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
            (2): ReLU()
            (3): Linear(in_features=64, out_features=64, bias=True)
          )
          (generator): Sequential(
            (0): Linear(in_features=128, out_features=64, bias=True)
            (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
            (2): ReLU()
            (3): Linear(in_features=64, out_features=64, bias=True)
          )
          (load): Sequential(
            (0): Linear(in_features=128, out_features=64, bias=True)
            (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
            (2): ReLU()
            (3): Linear(in_features=64, out_features=64, bias=True)
          )
          (shunt): Sequential(
            (0): Linear(in_features=128, out_features=64, bias=True)
            (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
            (2): ReLU()
            (3): Linear(in_features=64, out_features=64, bias=True)
          )
        )
      )
    )
  )
  (decoder): DecoderOPF(
    (node_decodings): ModuleDict(
      (bus): Sequential(
        (0): Linear(in_features=64, out_features=256, bias=True)
        (1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (2): ReLU()
        (3): Linear(in_features=256, out_features=256, bias=True)
        (4): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (5): ReLU()
        (6): Linear(in_features=256, out_features=2, bias=True)
      )
      (generator): Sequential(
        (0): Linear(in_features=64, out_features=256, bias=True)
        (1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (2): ReLU()
        (3): Linear(in_features=256, out_features=256, bias=True)
        (4): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (5): ReLU()
        (6): Linear(in_features=256, out_features=2, bias=True)
      )
    )
  )
)

Output data:
{'bus': tensor([[-0.3656,  0.9765],
        [-0.4713,  0.9684],
        [-0.4009,  0.9681],
        [ 0.0204,  0.9710],
        [ 0.3946,  0.9632],
        [-0.6469,  0.9731],
        [-0.8063,  0.9776],
        [-0.5907,  0.9679],
        [-0.6201,  0.9635],
        [-0.3241,  0.9674],
        [-0.0125,  0.9735],
        [-0.4481,  0.9725],
        [-0.3422,  0.9723],
        [-0.3960,  0.9805]], grad_fn=<StackBackward0>), 'generator': tensor([[1.8152, 0.0460],
        [0.3150, 0.0429],
        [0.0000, 0.2093],
        [0.0000, 0.0945]], grad_fn=<StackBackward0>), 'edge_preds': tensor([[-1.5357,  0.4356,  1.5888, -0.3234],
        [ 2.9970,  0.3546, -2.4655,  1.7929],
        [ 0.3178, -0.0859, -0.3125,  0.0671],
        [ 2.4634, -0.1819, -2.0877,  1.2900],
        [ 4.2400,  0.4547, -3.1228,  2.9244],
        [ 2.1173, -0.3385, -1.7908,  1.1598],
        [ 7.7669, -1.1045, -6.8812,  3.8982],
        [ 2.6597, -0.3414, -1.9390,  1.8504],
        [ 0.6210, -0.2276, -0.5642,  0.3459],
        [ 1.8648, -0.6184, -1.5947,  1.1504],
        [ 1.1493,  0.0707, -1.1493,  0.1786],
        [ 1.5848,  0.0239, -1.5848,  0.2738],
        [ 2.9910, -0.6017, -2.6746,  1.4423],
        [ 0.6935, -0.1769, -0.6258,  0.3210],
        [ 1.3678, -0.3171, -1.1971,  0.7166],
        [ 0.2377, -0.2376, -0.2113,  0.2615],
        [-0.1079,  0.0801,  0.1111, -0.0736],
        [-3.4147,  1.4269,  3.4147,  1.5697],
        [-1.0374,  0.2773,  1.0374,  0.4135],
        [-3.4440,  1.7422,  3.4440,  2.2227]], grad_fn=<StackBackward0>)}

Compared to the NNs we currently support, the input here is a dict-like HeteroData object from PyTorch Geometric, the output is a dict, and the PyTorch module itself has several submodules.
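Flattening the dict-like input and output comes down to fixing an order of the fields and recording an offset and shape for each. This only works if the shapes are the same for every sample. The following plain-Python sketch shows that bookkeeping, using the shapes printed above. The dotted field names are shorthand chosen here, not PyTorch Geometric keys.

```python
from math import prod

def layout(shapes):
    """Map each named field to (offset, size, shape) in one flat vector."""
    table, offset = {}, 0
    for name, shape in shapes:
        size = prod(shape)
        table[name] = (offset, size, shape)
        offset += size
    return table, offset

# Shapes copied from the printed HeteroData and output dict above.
input_shapes = [
    ("bus.x", (14, 4)),
    ("generator.x", (4, 11)),
    ("load.x", (11, 2)),
    ("shunt.x", (1, 2)),
    ("ac_line.edge_attr", (17, 9)),
    ("transformer.edge_attr", (3, 11)),
]
output_shapes = [
    ("bus", (14, 2)),
    ("generator", (4, 2)),
    ("edge_preds", (20, 4)),  # 20 rows = 17 ac_line + 3 transformer edges
]

in_layout, input_dim = layout(input_shapes)
out_layout, output_dim = layout(output_shapes)
print(input_dim, output_dim)  # 310 116
```

Because `input_dim` depends on the sample used as the template, the `N` in the JuMP model has to match the `input_dim` printed when the wrapper is built.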

Robbybp (Dec 01 '25)

Here's one way to wrap the NN so it accepts and returns vectors:

vectorcanos.py
import os
import torch
from torch import nn
from core.datasets.opfdata import OPFData
from core.models.canos_opf import CANOS_OPF


class VectorCanos(nn.Module):
    """
    The input vector contains:
      - bus features
      - generator features
      - load features
      - shunt features
      - (bus, ac_line, bus).edge_attr
      - (bus, transformer, bus).edge_attr

    The output vector contains:
      - out["bus"], out["generator"], out["edge_preds"]
    """

    def __init__(self, model: nn.Module, template):
        super().__init__()
        self.model = model
        self.template = template

        # Shapes used for slicing/unflattening
        self.bus_x_shape = tuple(template["bus"]["x"].shape)
        self.gen_x_shape = tuple(template["generator"]["x"].shape)
        self.load_x_shape = tuple(template["load"]["x"].shape)
        self.shunt_x_shape = tuple(template["shunt"]["x"].shape)
        self.line_ea_shape = tuple(template["bus", "ac_line", "bus"]["edge_attr"].shape)
        self.tr_ea_shape = tuple(template["bus", "transformer", "bus"]["edge_attr"].shape)

        self.sizes = [
            self.bus_x_shape[0] * self.bus_x_shape[1],
            self.gen_x_shape[0] * self.gen_x_shape[1],
            self.load_x_shape[0] * self.load_x_shape[1],
            self.shunt_x_shape[0] * self.shunt_x_shape[1],
            self.line_ea_shape[0] * self.line_ea_shape[1],
            self.tr_ea_shape[0] * self.tr_ea_shape[1],
        ]
        self.input_dim = sum(self.sizes)

        # Output shapes determined via a dry run
        out = model(template)
        self.out_bus_shape = tuple(out["bus"].shape)
        self.out_gen_shape = tuple(out["generator"].shape)
        self.out_edge_shape = tuple(out["edge_preds"].shape)
        self.output_dim = (
            self.out_bus_shape[0] * self.out_bus_shape[1]
            + self.out_gen_shape[0] * self.out_gen_shape[1]
            + self.out_edge_shape[0] * self.out_edge_shape[1]
        )

    def flatten_input(self, data) -> torch.Tensor:
        parts = [
            data["bus"]["x"].reshape(-1),
            data["generator"]["x"].reshape(-1),
            data["load"]["x"].reshape(-1),
            data["shunt"]["x"].reshape(-1),
            data["bus", "ac_line", "bus"]["edge_attr"].reshape(-1),
            data["bus", "transformer", "bus"]["edge_attr"].reshape(-1),
        ]
        return torch.cat(parts, dim=0)

    def unflatten_input(self, x_flat: torch.Tensor):
        x_flat = x_flat.to(dtype=torch.float32)
        chunks = torch.split(x_flat, self.sizes)

        data = self.template.clone()
        b0, b1 = self.bus_x_shape
        g0, g1 = self.gen_x_shape
        l0, l1 = self.load_x_shape
        s0, s1 = self.shunt_x_shape
        e0, e1 = self.line_ea_shape
        t0, t1 = self.tr_ea_shape

        data["bus"]["x"] = chunks[0].view(b0, b1)
        data["generator"]["x"] = chunks[1].view(g0, g1)
        data["load"]["x"] = chunks[2].view(l0, l1)
        data["shunt"]["x"] = chunks[3].view(s0, s1)
        data["bus", "ac_line", "bus"]["edge_attr"] = chunks[4].view(e0, e1)
        data["bus", "transformer", "bus"]["edge_attr"] = chunks[5].view(t0, t1)

        # Keep limits from template; set branch_vals = edge_attr
        data["bus"]["v_lims"] = self.template["bus"]["v_lims"].clone()
        data["generator"]["p_lims"] = self.template["generator"]["p_lims"].clone()
        data["generator"]["q_lims"] = self.template["generator"]["q_lims"].clone()
        data["bus", "ac_line", "bus"]["branch_vals"] = data["bus", "ac_line", "bus"]["edge_attr"]
        data["bus", "transformer", "bus"]["branch_vals"] = data["bus", "transformer", "bus"]["edge_attr"]
        return data

    def flatten_output(self, out: dict) -> torch.Tensor:
        parts = [
            out["bus"].reshape(-1),
            out["generator"].reshape(-1),
            out["edge_preds"].reshape(-1),
        ]
        return torch.cat(parts, dim=0)

    def forward(self, x_flat: torch.Tensor) -> torch.Tensor:
        data = self.unflatten_input(x_flat)
        out = self.model(data)
        return self.flatten_output(out)


dataset = OPFData(
    split="train",
    case_name="pglib_opf_case14_ieee",
    num_groups=1,
    topological_perturbations=True,
    pre_transform="mean_zero_variance_one",
    root=os.path.join("data", "opfdata"),
)
sample = dataset[0]
canos = CANOS_OPF(dataset=dataset, hidden_dim=64, include_sent_messages=False, k_steps=3)
wrapper = VectorCanos(canos, sample)
x_flat = wrapper.flatten_input(sample)
y_flat = wrapper(x_flat)
print(f"input_dim={wrapper.input_dim}  output_dim={wrapper.output_dim}")
torch.save(wrapper, "vector-canos.pt")

Then we can do something like this:

import JuMP
import Ipopt
import MathOptInterface as MOI
import PythonCall
import MathOptAI as MOAI

# Requires torch and torch_geometric
PythonCall.pyimport("sys").path.append(pwd())
PythonCall.pyimport("vectorcanos")
predictor = MOAI.PytorchModel("vector-canos.pt")

N = 312  # must equal the input_dim printed by vectorcanos.py
model = JuMP.Model(Ipopt.Optimizer)
JuMP.@variable(model, x[1:N], start = 0.5)
y, formulation = MOAI.add_predictor(model, predictor, x; vector_nonlinear_oracle = true)
xref = ones(N)
JuMP.@objective(model, Min, sum((x .- xref).^2))
# I just chose random numbers here, but we'd want to constrain a voltage or something
# to go over its limit.
JuMP.@constraint(model, y[2] >= 2.0)
JuMP.optimize!(model)

Our gray-box predictors can be used here, but the NN is nonsmooth (it contains ReLU activations), so I don't expect them to do very well at actually solving the optimization problem. I would like to be able to do this with the full-space formulation.
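In a full-space formulation, every hidden activation becomes a decision variable tied to its inputs by explicit constraints. For ReLU, one standard mixed-integer choice is the big-M encoding below. It is valid when the pre-activation z is known to lie in [L, U] with L < 0 < U. This plain-Python check of the algebra is a generic sketch, not MathOptAI's formulation code. The LayerNorm layers in CANOS are also nonlinear and would need a formulation of their own.

```python
# Big-M encoding of y = max(0, z) for z in [L, U], L < 0 < U:
#   y >= 0, y >= z, y <= z - L*(1 - s), y <= U*s, s in {0, 1}

def bigm_feasible(z, y, s, L, U, tol=1e-9):
    """True if (z, y, s) satisfies all four big-M constraints."""
    return (
        y >= -tol
        and y >= z - tol
        and y <= z - L * (1 - s) + tol
        and y <= U * s + tol
    )

def relu_with_indicator(z):
    """The (y, s) pair that the constraints should admit for a given z."""
    return (max(0.0, z), 1 if z > 0 else 0)

for z in [-3.0, -0.5, 0.0, 0.5, 3.0]:
    y, s = relu_with_indicator(z)
    print(z, y, s, bigm_feasible(z, y, s, L=-5.0, U=5.0))
```

With s = 1 the constraints force y = z (so z >= 0); with s = 0 they force y = 0 (so z <= 0). This is why no smoothness is needed.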

Robbybp (Dec 01 '25)

Hmm. This is pretty complicated. I like the flatten/unflatten output and the oracle approach.

But it seems hard for us to parse that into a full-space data structure if it can be arbitrary Python code...

odow (Dec 02 '25)

Yeah, we'd probably have to find a way to transform or refactor this into a more structured form. OMLT seems to support sequential networks with GCNConv "layers". I wonder if we can rewrite these networks into something similar.
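For reference, PyTorch Geometric's `GCNConv` computes X' = D̂^{-1/2} Â D̂^{-1/2} X Θ plus a bias. Here Â = A + I and D̂ is the degree matrix of Â. For a fixed graph this is linear in the node features X, which is what would let such layers slot into a full-space formulation. The dense plain-Python sketch below ignores the bias and edge weights.

```python
import math

def gcn_layer(A, X, Theta):
    """Dense GCN propagation: D^-1/2 (A + I) D^-1/2 X Theta (no bias)."""
    n = len(A)
    # Add self-loops and compute the symmetric normalization.
    A_hat = [[A[i][j] + (1.0 if i == j else 0.0) for j in range(n)] for i in range(n)]
    deg = [sum(row) for row in A_hat]
    norm = [[A_hat[i][j] / math.sqrt(deg[i] * deg[j]) for j in range(n)]
            for i in range(n)]
    # Aggregate neighbor features, then apply the weight matrix Theta.
    n_feat, n_out = len(Theta), len(Theta[0])
    agg = [[sum(norm[i][k] * X[k][f] for k in range(n)) for f in range(n_feat)]
           for i in range(n)]
    return [[sum(agg[i][f] * Theta[f][o] for f in range(n_feat)) for o in range(n_out)]
            for i in range(n)]

# Two connected nodes with one feature each.
A = [[0.0, 1.0], [1.0, 0.0]]
print(gcn_layer(A, [[1.0], [3.0]], [[2.0]]))  # [[4.0], [4.0]]
```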

Robbybp (Dec 03 '25)

Sequential networks with convolutional layers like GCNConv and Conv2d would cover most of the use cases I have encountered.

pulsipher (Dec 04 '25)

@pulsipher do you have a small example? There are so many parameters for Conv2d, do we need them all?

odow (Dec 05 '25)

We do frequently use different settings for kernel size, in/out channels, padding, and stride. I haven't personally needed to mess with the dilation or groups. Here is a simple example from Flux that exemplifies a common serial CNN structure:

using Flux

cnn = Chain(
    Conv((5, 5), 1=>6, relu, pad = 2),
    MaxPool((2, 2)),
    Conv((5, 5), 6=>16, relu, pad = 2),
    MaxPool((2, 2)),
    Flux.flatten,
    Dense(784 => 120, relu),  # 7 * 7 * 16 = 784 after two pad-2 conv + pool stages
    Dense(120 => 84, relu), 
    Dense(84 => 10),
)

In MathOptAI, this might look like:

using JuMP, MathOptAI

model = Model()
@variable(model, x[1:28, 1:28])
y, _ = add_predictor(model, cnn, x)
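The spatial size after each layer follows out = floor((n + 2p - k) / s) + 1. Flux's `Conv` uses stride 1 by default, and `MaxPool((2, 2))` uses a stride equal to its window. Here is a quick Python trace of a 28x28 single-channel input through the chain above:

```python
def out_size(n, k, p=0, s=1):
    """Output length along one spatial dimension for conv/pooling."""
    return (n + 2 * p - k) // s + 1

n = 28
n = out_size(n, 5, p=2)  # Conv((5, 5), 1=>6, relu, pad = 2)  -> 28
n = out_size(n, 2, s=2)  # MaxPool((2, 2))                     -> 14
n = out_size(n, 5, p=2)  # Conv((5, 5), 6=>16, relu, pad = 2) -> 14
n = out_size(n, 2, s=2)  # MaxPool((2, 2))                     -> 7
flat = n * n * 16        # Flux.flatten over 16 channels
print(flat)  # 784
```

With `pad = 2` on both convolutions, 784 features enter the first `Dense` layer. Without padding, the sizes would go 28, 24, 12, 8, 4, giving 4 * 4 * 16 = 256.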

pulsipher (Dec 05 '25)

The matrix input is a problem. I really didn't want to have to get into this. It makes thinking about matrix shapes so much more complicated.

odow (Dec 05 '25)

I'm afraid 2D convolutional layers fundamentally depend on matrix inputs; naively vectorizing loses information (i.e., the correlations/patterns among neighboring elements).

pulsipher (Dec 06 '25)

Oh sure. But I was hoping we could have some reshaping so that at the MathOptAI level all inputs and outputs were vectors. I need to experiment with some things.
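Reshaping need not lose the spatial structure, as long as the layer keeps track of the index map. The sketch below assumes Julia-style column-major vectorization: with 0-based indices, entry (i, j) of an m x n matrix sits at flat index i + j*m. It shows that a 'valid' 2D cross-correlation (the operation PyTorch's `Conv2d` computes), evaluated directly on the flat vector, gives exactly the vectorized matrix result.

```python
def xcorr2d(X, K):
    """'Valid' 2D cross-correlation of matrix X with kernel K (lists of rows)."""
    m, n, kh, kw = len(X), len(X[0]), len(K), len(K[0])
    return [[sum(K[a][b] * X[i + a][j + b] for a in range(kh) for b in range(kw))
             for j in range(n - kw + 1)] for i in range(m - kh + 1)]

def vec(X):
    """Column-major flattening, as Julia's vec does."""
    return [X[i][j] for j in range(len(X[0])) for i in range(len(X))]

def xcorr2d_flat(x, m, n, K):
    """Same operation on the flat vector x of an m x n matrix; column-major output."""
    kh, kw = len(K), len(K[0])
    oh, ow = m - kh + 1, n - kw + 1
    return [sum(K[a][b] * x[(i + a) + (j + b) * m] for a in range(kh) for b in range(kw))
            for j in range(ow) for i in range(oh)]

X = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]
K = [[1, 0], [-1, 2]]
print(xcorr2d_flat(vec(X), 3, 4, K) == vec(xcorr2d(X, K)))  # True
```

Padding and stride only change the index arithmetic, so inputs and outputs can stay vectors at the MathOptAI level while the layer carries the shape.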

odow (Dec 06 '25)