
Best way to learn adjacency matrix for a graph?

christopher-beckham opened this issue 4 years ago • 6 comments

❓ Questions & Help

Hi,

Apologies if this has already been posted (I spent a good half an hour trying to find a question like this). I am trying to figure out the best way to learn a parameterisation of a graph, i.e. have a neural net predict, from some input, the nodes, their features, and the adjacency matrix.

I see that many of the graph conv layers take a 2D tensor of edge indices as edge_index, but we would not be able to backprop through that. It seems like one would either have to (a) define a fully-connected graph and infer the edge weights instead (where a weight of 0 between nodes (i, j) would effectively simulate the two nodes not being connected), or (b), if it is possible, directly pass in the adjacency matrix as one dense (n, n) matrix (though I assume this can only be binary, which may also be problematic).

Any thoughts? Thanks in advance.

christopher-beckham avatar Jun 23 '20 05:06 christopher-beckham

The general consensus for a graph autoencoder (GAE) is to train against the dense adjacency matrix. However, only the output needs to be dense; the input graph can stay sparse. We have an example of this, see examples/autoencoder.py.

Note that, as you correctly mentioned, it is not possible to train against a sparse adjacency matrix. This stems mostly from the fact that you need a fixed output dimension with a fixed ordering, and that requirement cannot be fulfilled by sparse matrices.

However, there is some literature on this topic, e.g., GraphRNN, which generates graphs in an auto-regressive fashion.
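
Roughly, the pattern in that example looks like this (just a sketch with placeholder sizes and a toy random graph, not the file verbatim):

import torch
from torch_geometric.nn import GAE, GCNConv

# Toy stand-in for a real dataset: 100 nodes, 16 features, random sparse edges.
num_nodes, in_channels = 100, 16
x = torch.randn(num_nodes, in_channels)
edge_index = torch.randint(0, num_nodes, (2, 500))

class Encoder(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        return self.conv2(self.conv1(x, edge_index).relu(), edge_index)

model = GAE(Encoder(in_channels, 32, 16))
z = model.encode(x, edge_index)            # the encoder only ever sees the sparse graph
dense_adj = model.decoder.forward_all(z)   # dense [N, N] reconstruction via inner products
loss = model.recon_loss(z, edge_index)     # BCE on positive edges plus sampled negatives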

rusty1s avatar Jun 23 '20 07:06 rusty1s

Hi,

Thanks for your response!

In my case, I'd want to use the inferred outputs in a downstream manner (i.e., both the nodes' features and the adjacency matrix) and have that all be backproppable, e.g.:

input -> [mlp] -> {X, E} -> [GNNs] -> output

where E is the adjacency matrix and X are the node features. I assume, however, that E needs to be sparse in order to work with the GNNs later in the network.

In the case of the autoencoder, its output (a dense adjacency matrix) just happens to also be the end of the network, which is convenient. In my case, it still seems like the most plausible option would be to fix the adjacency matrix so that the graph is fully-connected and have the network infer the edge weights instead (sketched below). Let me know if you agree with this line of thinking.
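
Concretely, something like this is what I have in mind (just a sketch; the sizes are placeholders and the random tensors stand in for the upstream MLP outputs):

import torch
from torch_geometric.nn import GCNConv

n, feat_dim = 10, 8
# In practice X and E would come from the upstream MLP; random tensors stand in here.
X = torch.randn(n, feat_dim, requires_grad=True)   # predicted node features
E = torch.rand(n, n, requires_grad=True)           # predicted "soft" adjacency / edge weights

# Fixed fully-connected edge_index over all (i, j) pairs (self-loops included for simplicity).
pairs = torch.cartesian_prod(torch.arange(n), torch.arange(n))
full_edge_index = pairs.t().contiguous()

conv = GCNConv(feat_dim, feat_dim)
edge_weight = E[full_edge_index[0], full_edge_index[1]]
out = conv(X, full_edge_index, edge_weight)        # differentiable w.r.t. both X and E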

Thanks again!

christopher-beckham avatar Jun 23 '20 17:06 christopher-beckham

Note that we also provide GNNs that can operate on dense input. For example, this is done in the DiffPool model. An alternative way would be to sparsify your dense adjacency matrix based on a user-defined threshold (similar to a ReLU activation):

edge_index = (adj > 0.5).nonzero().t()           # [2, num_kept_edges]
edge_weight = adj[edge_index[0], edge_index[1]]  # values of the surviving entries

If you utilize both edge_index and edge_weight in your follow-up GNN, your graph generation is fully-trainable (except for the values you remove).
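
For example, a self-contained toy version of the above (sizes are placeholders):

import torch
from torch_geometric.nn import GCNConv

n, feat_dim = 10, 8
x = torch.randn(n, feat_dim)                       # node features produced upstream
adj = torch.rand(n, n, requires_grad=True)         # dense adjacency predicted upstream

edge_index = (adj > 0.5).nonzero().t()             # sparsify by thresholding
edge_weight = adj[edge_index[0], edge_index[1]]    # surviving values, still differentiable

out = GCNConv(feat_dim, feat_dim)(x, edge_index, edge_weight)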

rusty1s avatar Jun 24 '20 06:06 rusty1s

Thanks! I will be sure to try it out

christopher-beckham avatar Jun 25 '20 16:06 christopher-beckham

The output of nonzero() breaks the computation graph, but the underlying tensor still requires grad, and it still does after being indexed with the indices returned by nonzero().
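
A quick way to convince yourself (toy snippet, not from the repo):

import torch

adj = torch.rand(4, 4, requires_grad=True)
edge_index = (adj > 0.5).nonzero().t()             # the indices themselves carry no gradient
edge_weight = adj[edge_index[0], edge_index[1]]    # but the gathered values do
edge_weight.sum().backward()
print(adj.grad)                                    # ones at the kept entries, zeros elsewhere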

rusty1s avatar Jun 26 '20 18:06 rusty1s

I'm not sure if I understand the question correctly, but I think you do not have to use a dense adjacency matrix as input. Node features by themselves are enough to predict edge connectivity (or weights). I ran a small experiment recently, and it turns out that pairwise concatenation of node features works well for edge prediction.

After 2 or 3 epochs of training, the network can learn the exact adjacency matrix.

[Figure: ground-truth adjacency matrix (left) vs. fitted matrix (right) after training]

Here is the full code.

import os.path as osp
import random

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_geometric.nn as gnn
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.data import Batch as PygBatch
from torch_geometric.data import Data as PygData

node_num = 20
feat_dim = 8  # The feature dimension should not be too small.
device = "cuda" if torch.cuda.is_available() else "cpu"


class TestModel(nn.Module):
    def __init__(self, node_num: int, feat_dim: int) -> None:
        super().__init__()
        self._node_num = node_num
        self._feat_dim = feat_dim

        self._x_em = nn.Embedding(num_embeddings=node_num, embedding_dim=feat_dim)

        self._gc_list = nn.ModuleList(
            [
                gnn.GCNConv(in_channels=feat_dim, out_channels=feat_dim),
                gnn.GCNConv(in_channels=feat_dim, out_channels=feat_dim),
                gnn.GCNConv(in_channels=feat_dim, out_channels=feat_dim),
            ]
        )
        self._last_gc = gnn.GCNConv(in_channels=feat_dim, out_channels=feat_dim)

        # The MLP acts as a "combiner". A too-shallow MLP gives poor outputs.
        self._mlp = nn.Sequential(
            nn.Linear(in_features=2 * feat_dim, out_features=2 * feat_dim),
            nn.LeakyReLU(),
            nn.Linear(in_features=2 * feat_dim, out_features=feat_dim),
            nn.LeakyReLU(),
            nn.Linear(in_features=feat_dim, out_features=feat_dim),
            nn.LeakyReLU(),
            nn.Linear(in_features=feat_dim, out_features=1),
        )

    def forward(self, x, edge_index):
        n = self._node_num
        f = self._feat_dim

        x = self._x_em(x)

        # The residual `+ x` is important: it helps the network retain features from the different convolution stages.
        for conv in self._gc_list:
            x = F.leaky_relu(conv(x, edge_index)) + x
        x = self._last_gc(x, edge_index) + x
        x = x.view(-1, n, f)  # [B, N, F]

        # Pairwise concatenation
        idx_pairs = torch.cartesian_prod(
            torch.arange(x.shape[-2]), torch.arange(x.shape[-2])
        )
        x = x[:, idx_pairs]  # [B, N * N, 2, F]
        x = x.view(-1, n, n, 2 * f)
        x = self._mlp(x)  # [B, N, N, 1]
        x = x.view(-1, n, n)
        x = (x + x.transpose(-1, -2)) / 2
        return x


loss_fn = nn.SmoothL1Loss()
net = TestModel(node_num=node_num, feat_dim=feat_dim).to(device=device)

optimizer = optim.RAdam(net.parameters(), lr=0.005)

log_dir = osp.dirname(osp.abspath(__file__))
log_dir = osp.join(log_dir, "torch_runs")
log_dir = osp.join(log_dir, "adj_learner")
summary_writer = SummaryWriter(log_dir=log_dir)

x = torch.arange(node_num)
dataset = []

for _ in range(5000):
    edge_index = []
    adj_mat = torch.zeros(node_num, node_num, dtype=torch.float)
    for _ in range(10):
        u, v = -1, -1
        while u == v:
            u = random.choice(range(node_num))
            v = random.choice(range(node_num))
        edge_index.append((u, v))
        edge_index.append((v, u))
        adj_mat[u][v] = 1.0
        adj_mat[v][u] = 1.0
    edge_index = torch.tensor(edge_index, dtype=torch.long).T.contiguous()
    dataset.append((PygData(x=x, edge_index=edge_index), adj_mat))

test_data = random.sample(dataset, k=2)


def show_status(epoch_id: int = None):
    for tag, (data, adj_mat) in enumerate(test_data):
        batch = PygBatch.from_data_list([data]).to(device)
        with torch.no_grad():
            out = net(batch.x, batch.edge_index)

        out = out.cpu().squeeze()

        fig, (ax1, ax2) = plt.subplots(ncols=2)
        im1 = ax1.matshow(adj_mat, interpolation=None)
        im2 = ax2.matshow(out, interpolation=None)
        ax1.set_title("Adjacency Matrix")
        ax2.set_title("Fitted Matrix")
        fig.colorbar(im1, ax=ax1)
        fig.colorbar(im2, ax=ax2)
        summary_writer.add_figure(f"Example status {tag}", fig, epoch_id)


running_loss = 0.0
obs_period = 200
iter_per_epoch = 2000
for epoch_id in range(100):
    print(f"Epoch :{epoch_id+1}")
    net.train()
    batch_size = 32
    for it in range(iter_per_epoch):
        data_list, adj_mat_batch = zip(*random.sample(dataset, k=batch_size))
        adj_mat_batch = torch.stack(adj_mat_batch).to(device)
        batch = PygBatch.from_data_list(data_list).to(device)

        out = net(batch.x, batch.edge_index)
        optimizer.zero_grad()
        loss = loss_fn(out, adj_mat_batch)
        loss.backward()
        cur_loss = loss.item()
        optimizer.step()
        running_loss += cur_loss

        summary_writer.add_scalar(
            "Training loss", cur_loss, epoch_id * iter_per_epoch + it
        )
        if (it + 1) % obs_period == 0:
            running_loss /= obs_period
            print(f"    [{it+1:4}] running loss: {running_loss:0.4f}")
            running_loss = 0.0
    net.train(False)
    show_status(epoch_id)

LinHeLurking avatar Sep 09 '22 02:09 LinHeLurking