pytorch_geometric
Best way to learn adjacency matrix for a graph?
❓ Questions & Help
Hi,
Apologies if this has already been posted (though I spent a good half an hour trying to find a question like this). I am trying to figure out what the best way is to learn a parameterisation of a graph (i.e. have a neural net predict from some input: the nodes, their features, and the adjacency matrix).
I see that many of the graph conv layers take in a 2D tensor of edge indices for edge_index, though we would not be able to backprop through this. It seems like one would have to either (a) define a fully-connected graph and instead infer the edge weights (where a weight of 0 between nodes (i,j) would effectively simulate two nodes not being connected), or (b) if it's possible, directly pass in the adjacency matrix as one dense (n,n) matrix (though I assume this can only be binary, so that may also be problematic).
Any thoughts? Thanks in advance.
The general consensus for a Graph-AE is to train against the dense adjacency matrix. However, only the output needs to be dense; the input graph can be sparse. We have an example of this, see examples/autoencoder.py.
Note that, as you correctly mentioned, it is not possible to train against a sparse adjacency matrix. This stems mostly from the fact that you need a fixed output dimension with a fixed ordering, and that requirement cannot be fulfilled by sparse matrices.
However, there is some literature on this topic, e.g., Graph-RNN, which generates graphs in an auto-regressive fashion.
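The dense-target setup described above can be sketched roughly as follows. This is a minimal illustration, assuming an inner-product decoder (a common choice for graph autoencoders, and not necessarily what examples/autoencoder.py does). The free embedding tensor z is a stand-in for an encoder's output, and all names are made up for the example.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_nodes, emb_dim = 6, 4

# Dense target adjacency (a symmetric ring graph), used only as the training target.
adj_target = torch.zeros(num_nodes, num_nodes)
for i in range(num_nodes):
    j = (i + 1) % num_nodes
    adj_target[i, j] = 1.0
    adj_target[j, i] = 1.0

# Assumption for brevity: free node embeddings instead of a real (sparse-input) encoder.
z = torch.randn(num_nodes, emb_dim, requires_grad=True)
optimizer = torch.optim.Adam([z], lr=0.1)

losses = []
for _ in range(200):
    optimizer.zero_grad()
    logits = z @ z.t()  # dense [N, N] edge scores with a fixed node ordering
    loss = F.binary_cross_entropy_with_logits(logits, adj_target)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```

The point is only that the output has a fixed (N, N) shape and ordering, which is what makes a per-entry loss possible.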
Hi,
Thanks for your response!
In my case, I'd want to use the inferred outputs in a downstream manner (i.e., both the nodes' features and the adjacency matrix) and have that all be backproppable, e.g.:
input -> [mlp] -> {X, E} -> [GNNs] -> output
where E is the adjacency matrix and X are the node features. I assume, however, that E needs to be sparse in order for it to work with the GNNs later on in the network.
In the case of the autoencoder, its output (a dense adjacency matrix) just happens to also be the end of the network, which is convenient. In my case, it still seems like the most plausible option would be to fix the adjacency matrix so that the graph is fully-connected, and have the network infer edge weights instead. Let me know if you agree with this line of thinking.
Thanks again!
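For reference, the fully-connected option with inferred edge weights could look roughly like the sketch below. It uses plain PyTorch only. The tensors x and edge_scores are hypothetical stand-ins for what the upstream MLP would predict, and the hand-written weighted sum stands in for a GNN layer that accepts edge weights.

```python
import torch

torch.manual_seed(0)
num_nodes, feat_dim = 4, 3

# Hypothetical upstream predictions: node features X and dense edge scores E.
x = torch.randn(num_nodes, feat_dim, requires_grad=True)
edge_scores = torch.randn(num_nodes, num_nodes, requires_grad=True)

# Fixed fully-connected connectivity: every ordered pair (i, j), self-loops included.
row, col = torch.meshgrid(
    torch.arange(num_nodes), torch.arange(num_nodes), indexing="ij"
)
edge_index = torch.stack([row.reshape(-1), col.reshape(-1)])  # [2, N * N]

# Learned weights in (0, 1); a weight near 0 approximates "no edge".
edge_weight = torch.sigmoid(edge_scores).reshape(-1)  # [N * N]

# One weighted-sum message-passing step: out[j] = sum_i w_ij * x[i].
messages = edge_weight.unsqueeze(-1) * x[edge_index[0]]
out = torch.zeros_like(x).index_add(0, edge_index[1], messages)

# Gradients reach both the node features and the edge scores.
out.sum().backward()
```

The integer edge_index is constant here, so nothing needs to be backpropagated through it; all of the learning happens in edge_weight.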
Note that we also provide GNNs that can operate on dense input. For example, this is done in the DiffPool model. An alternative way would be to sparsify your dense adjacency matrix based on a user-defined threshold (similar to a ReLU activation):
edge_index = (adj > 0.5).nonzero().t()
edge_weight = adj[edge_index[0], edge_index[1]]
If you utilize both edge_index and edge_weight in your follow-up GNN, your graph generation is fully-trainable (except for the values you remove).
Thanks! I will be sure to try it out
The output of nonzero() breaks the computation graph, but the actual tensor still requires grad, and so does the result of indexing it with the indices returned by nonzero().
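A small check of this behaviour, as a sketch with a made-up 3x3 adjacency matrix: the indices carry no gradient, while the values gathered with them remain attached to the original tensor.

```python
import torch

# Hypothetical predicted dense adjacency with values in [0, 1].
adj = torch.tensor(
    [[0.9, 0.2, 0.7],
     [0.1, 0.8, 0.4],
     [0.6, 0.3, 0.95]],
    requires_grad=True,
)

# Integer indices carry no gradient; they only select which entries survive.
edge_index = (adj > 0.5).nonzero().t()  # [2, num_kept_edges]

# Indexing the original tensor keeps the kept values attached to the graph.
edge_weight = adj[edge_index[0], edge_index[1]]

edge_weight.sum().backward()
# adj.grad is now 1 for kept entries and 0 for entries below the threshold.
```

Entries below the threshold receive zero gradient, so they cannot be "switched back on" by training through this path alone.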
I'm not sure if I understand the question correctly, but I think you do not have to use a dense adjacency matrix as input. Node features themselves are enough to predict edge connectivity (or weights). I ran a small experiment recently, and it turns out that pairwise concatenation of node features is suitable for generating an "edge prediction".
After 2 or 3 epochs of training, the network can learn the exact adjacency matrix.
Here is the full code.
import os.path as osp
import random

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_geometric.nn as gnn
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.data import Batch as PygBatch
from torch_geometric.data import Data as PygData

node_num = 20
feat_dim = 8  # The feature dimension cannot be too small.
device = "cuda" if torch.cuda.is_available() else "cpu"


class TestModel(nn.Module):
    def __init__(self, node_num: int, feat_dim: int) -> None:
        super().__init__()
        self._node_num = node_num
        self._feat_dim = feat_dim
        self._x_em = nn.Embedding(num_embeddings=node_num, embedding_dim=feat_dim)
        self._gc_list = nn.ModuleList(
            [
                gnn.GCNConv(in_channels=feat_dim, out_channels=feat_dim),
                gnn.GCNConv(in_channels=feat_dim, out_channels=feat_dim),
                gnn.GCNConv(in_channels=feat_dim, out_channels=feat_dim),
            ]
        )
        self._last_gc = gnn.GCNConv(in_channels=feat_dim, out_channels=feat_dim)
        # MLP fits as a "Combiner". Too shallow MLP would give a bad output.
        self._mlp = nn.Sequential(
            nn.Linear(in_features=2 * feat_dim, out_features=2 * feat_dim),
            nn.LeakyReLU(),
            nn.Linear(in_features=2 * feat_dim, out_features=feat_dim),
            nn.LeakyReLU(),
            nn.Linear(in_features=feat_dim, out_features=feat_dim),
            nn.LeakyReLU(),
            nn.Linear(in_features=feat_dim, out_features=1),
        )

    def forward(self, x, edge_index):
        n = self._node_num
        f = self._feat_dim
        x = self._x_em(x)
        # The `+x` part is important. It helps the network capture features during different convolution stages.
        for conv in self._gc_list:
            x = F.leaky_relu(conv(x, edge_index)) + x
        x = self._last_gc(x, edge_index) + x
        x = x.view(-1, n, f)  # [B, N, F]
        # Pairwise concatenation
        idx_pairs = torch.cartesian_prod(
            torch.arange(x.shape[-2]), torch.arange(x.shape[-2])
        )
        x = x[:, idx_pairs]  # [B, N * N, 2, F]
        x = x.view(-1, n, n, 2 * f)
        x = self._mlp(x)  # [B, N, N, 1]
        x = x.view(-1, n, n)
        x = (x + x.transpose(-1, -2)) / 2
        return x


loss_fn = nn.SmoothL1Loss()
net = TestModel(node_num=node_num, feat_dim=feat_dim).to(device=device)
optimizer = optim.RAdam(net.parameters(), lr=0.005)

log_dir = osp.dirname(osp.abspath(__file__))
log_dir = osp.join(log_dir, "torch_runs")
log_dir = osp.join(log_dir, "adj_learner")
summary_writer = SummaryWriter(log_dir=log_dir)

x = torch.arange(node_num)
dataset = []
for _ in range(5000):
    edge_index = []
    adj_mat = torch.zeros(node_num, node_num, dtype=torch.float)
    for _ in range(10):
        u, v = -1, -1
        while u == v:
            u = random.choice(range(node_num))
            v = random.choice(range(node_num))
        edge_index.append((u, v))
        edge_index.append((v, u))
        adj_mat[u][v] = 1.0
        adj_mat[v][u] = 1.0
    edge_index = torch.tensor(edge_index, dtype=torch.long).T.contiguous()
    dataset.append((PygData(x=x, edge_index=edge_index), adj_mat))

test_data = random.sample(dataset, k=2)


def show_status(epoch_id: int = None):
    for tag, (data, adj_mat) in enumerate(test_data):
        batch = PygBatch.from_data_list([data]).to(device)
        with torch.no_grad():
            out = net(batch.x, batch.edge_index)
        out = out.cpu().squeeze()
        fig, (ax1, ax2) = plt.subplots(ncols=2)
        im1 = ax1.matshow(adj_mat, interpolation=None)
        im2 = ax2.matshow(out, interpolation=None)
        ax1.set_title("Adjacency Matrix")
        ax2.set_title("Fitted Matrix")
        fig.colorbar(im1, ax=ax1)
        fig.colorbar(im2, ax=ax2)
        summary_writer.add_figure(f"Example status {tag}", fig, epoch_id)


running_loss = 0.0
obs_period = 200
iter_per_epoch = 2000
for epoch_id in range(100):
    print(f"Epoch :{epoch_id+1}")
    net.train()
    batch_size = 32
    for it in range(iter_per_epoch):
        data_list, adj_mat_batch = zip(*random.sample(dataset, k=batch_size))
        adj_mat_batch = torch.stack(adj_mat_batch).to(device)
        batch = PygBatch.from_data_list(data_list).to(device)
        out = net(batch.x, batch.edge_index)
        optimizer.zero_grad()
        loss = loss_fn(out, adj_mat_batch)
        loss.backward()
        cur_loss = loss.item()
        optimizer.step()
        running_loss += cur_loss
        summary_writer.add_scalar(
            "Training loss", cur_loss, epoch_id * iter_per_epoch + it
        )
        if (it + 1) % obs_period == 0:
            running_loss /= obs_period
            print(f" [{it+1:4}] running loss: {running_loss:0.4f}")
            running_loss = 0.0
    net.train(False)
    show_status(epoch_id)
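To make the pairwise concatenation step in forward easier to follow, here is a small standalone sketch of the same indexing trick on a toy tensor, compared against a broadcasting-based construction. The toy sizes and the names pairs and pairs_alt are illustrative only.

```python
import torch

batch_size, n, f = 2, 3, 4
x = torch.arange(batch_size * n * f, dtype=torch.float).view(batch_size, n, f)

# All ordered node pairs (i, j) in row-major order: (0, 0), (0, 1), ..., (n-1, n-1).
idx_pairs = torch.cartesian_prod(torch.arange(n), torch.arange(n))  # [N * N, 2]

pairs = x[:, idx_pairs]                      # [B, N * N, 2, F]
pairs = pairs.view(batch_size, n, n, 2 * f)  # [B, N, N, 2 * F]

# Equivalent construction via broadcasting, for comparison.
x_i = x.unsqueeze(2).expand(-1, -1, n, -1)  # features of node i at position [b, i, j]
x_j = x.unsqueeze(1).expand(-1, n, -1, -1)  # features of node j at position [b, i, j]
pairs_alt = torch.cat([x_i, x_j], dim=-1)   # [B, N, N, 2 * F]
```

Entry [b, i, j] holds the features of node i followed by the features of node j, which is why the MLP output is symmetrised afterwards in the model.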
Your node features are just the node ordinals passed through nn.Embedding, and the adjacency matrix is randomly generated. Why does this show that "Node features themselves are enough to predict edge connectivity"? This result seems to only indicate that the neural network can fit the adjacency matrix, while the node features appear to be completely useless.
Because the nn.Embedding is trainable.