pytorch_geometric
[WIP] `EquivariantConv`
Equivariant conv as proposed in this paper. Two things to note:
- Equivariant conv updates a node's embedding as well as its position, so the functions `message`, `aggregate` and `update` return a tuple of tensors.
- jit seems to have some issues, as `propagate` returns a tuple of tensors and not a tensor. One potential workaround would be to have `update` return a tensor for the node embedding, while the positional update is saved in a class variable (see the sketch below). Let me know if that change needs to be made or if you think of a better fix.
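A minimal sketch of that workaround, with a stripped-down message/aggregate and ignoring the extra TorchScript plumbing (not the final implementation):

from typing import Tuple

import torch
from torch import Tensor
from torch_scatter import scatter
from torch_geometric.nn.conv import MessagePassing


class EquivariantConvSketch(MessagePassing):
    """Sketch: only the node embedding flows through `propagate`;
    the positional update is stashed in a class variable."""

    def message(self, x_j: Tensor, pos_i: Tensor, pos_j: Tensor) -> Tuple[Tensor, Tensor]:
        dist = (pos_i - pos_j).square().sum(dim=-1, keepdim=True)
        return torch.cat([x_j, dist], dim=-1), pos_i - pos_j

    def aggregate(self, inputs: Tuple[Tensor, Tensor], index: Tensor) -> Tensor:
        # aggregate the node messages as usual ...
        out_x = scatter(inputs[0], index, dim=0, reduce=self.aggr)
        # ... but keep the positional update on the module instead of returning it
        self._pos_update = scatter(inputs[1], index, dim=0, reduce='mean')
        return out_x

    def update(self, inputs: Tensor) -> Tensor:
        return inputs

    def forward(self, x: Tensor, pos: Tensor, edge_index: Tensor) -> Tuple[Tensor, Tensor]:
        # `propagate` now returns a plain tensor again
        out_x = self.propagate(edge_index, x=x, pos=pos)
        out_pos = pos + self._pos_update
        return out_x, out_pos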
Codecov Report
Merging #2824 (a3118b4) into master (dc8f503) will increase coverage by 0.09%. The diff coverage is 100.00%.
@@ Coverage Diff @@
## master #2824 +/- ##
==========================================
+ Coverage 84.23% 84.33% +0.09%
==========================================
Files 209 210 +1
Lines 9524 9579 +55
==========================================
+ Hits 8023 8078 +55
Misses 1501 1501
| Impacted Files | Coverage Δ |
|---|---|
| torch_geometric/nn/conv/__init__.py | 100.00% <100.00%> (ø) |
| torch_geometric/nn/conv/equivariant_conv.py | 100.00% <100.00%> (ø) |
@rusty1s thanks for updating jittable in message passing (673f94729b6a520b994699da5aa8dd3d1a1f670b). With that, `EquivariantConv` is jittable.
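For reference, scripting the layer should now work roughly along these lines (assuming this branch is installed so that `EquivariantConv` is importable from `torch_geometric.nn`):

import torch
from torch_geometric.nn import EquivariantConv  # available on this branch

conv = EquivariantConv()
# `jittable()` generates a TorchScript-friendly version of the MessagePassing module;
# with the fix above, the tuple returned by `propagate` no longer breaks scripting.
jit_conv = torch.jit.script(conv.jittable())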
Hi @wsad1, thanks for your implementation of the EGNN layer. This implementation updates the positional features based on the local neighbourhood only, right? https://github.com/rusty1s/pytorch_geometric/blob/a3118b4406e12c817e20c3efc8210b228e886d76/torch_geometric/nn/conv/equivariant_conv.py#L111-L145 In the original paper, the positions x_i^{l+1} are computed by iterating over all nodes in the graph. I guess this makes it hard to combine with the MessagePassing framework, as the index provided is based on the given adjacency matrix.
@tuanle618, thanks for bringing this up. You are absolutely right: the positional embedding pos should be updated based on all nodes in the graph, while the current implementation updates it based on node neighbours only.
One way to implement this would be to
- create a `fully_connected_edge_index` for each graph, pass it to `propagate` as `edge_index`, and pass the original `edge_index` as another argument `original_index`;
- in `message`, `x_i`, `x_j`, `pos_i` and `pos_j` would be created based on the fully connected edge index;
- in `aggregate`, `msg` needs to be aggregated based on `original_index[1]` and `pos` based on `index`.
The above approach would not support `edge_attr`. It's also not very straightforward or memory efficient, so I might have to think about it more.
Hi @wsad1, thanks a lot for your quick response and implementation suggestions. I would like to support you in this matter and have created some code for steps (1) and (2).
For step (3), however, I am currently not certain how we would provide the edge_index and fc_edge_index to the self.aggregate function, as the final goal is actually to just call self.propagate once, right?
For supporting edge_attr, I was thinking of zero-padding the edge_attr tensor along dim=0.
In general, my steps currently include:
Let's say the number of edges is `E = edge_index.size(1)`, and the number of fully-connected edges, which is also bounded by the number of nodes in our batch, is `E_fc = fc_edge_index.size(1)`.
- (a) create the fully-connected edge index from the batch, i.e. `fc_edge_index`
- (b) get the "fake" indices within `fc_edge_index` that are not present in the "true" `edge_index`, and update `fc_edge_index` accordingly, so that `fc_edge_index.size(1) = E_fc - E`
- (c) now concatenate (the edge index along `dim=-1`, the edge attributes along `dim=0`) so that the ordering "matches" the `edge_attr`, i.e. `fc_edge_index = torch.cat([edge_index, fc_edge_index], dim=-1)` and `edge_attr = torch.cat([edge_attr, torch.zeros(size=(E_fc - E, edge_attr.size(-1)), device=x.device)], dim=0)`
Then do step (3) from your suggestion, by providing `edge_index` and `fc_edge_index`, where `fc_edge_index` is used to update the positional embeddings and `edge_index` just to gather the incoming messages that update the node embeddings.
import torch
from torch_scatter import scatter_add
from torch_geometric.data import Batch
# assumes data is of type `torch_geometric.data.batch.Batch`
x, batch, ptr = data.x, data.batch, data.ptr
batch_size = batch.max().item() + 1
edge_index, edge_attr = data.edge_index, data.edge_attr
def get_fully_connected_edges(n_nodes: int, add_self_loops: bool = False):
rows, cols = [], []
for i in range(n_nodes):
for j in range(n_nodes):
if i != j or (i == j and add_self_loops):
rows.append(i)
cols.append(j)
edges = [rows, cols]
edges = torch.tensor(edges, dtype=torch.long).contiguous()
return edges
batch_num_nodes = scatter_add(src=batch.new_ones(x.size(0)), index=batch, dim=0, dim_size=batch_size)
edge_index = data.edge_index
fc_edge_index = torch.cat([get_fully_connected_edges(n) + p for n, p in zip(batch_num_nodes, ptr)], dim=-1)
# a memory-inefficient and maybe slower version
# adjs = torch.block_diag(*[torch.ones(n, n).fill_diagonal_(0.0) for n in batch_num_nodes]).nonzero().t().contiguous()
# torch.allclose(fc_edge_index, adjs)
# find positions of true edge_index
source, target = edge_index[0].cpu().numpy().tolist(), edge_index[1].cpu().numpy().tolist()
source_target_to_edge_idx = {str([s, t]): i for s, t, i in zip(source, target, range(len(source)))}
edge_idx_to_source_target = {v: k for k, v in source_target_to_edge_idx.items()}
# positions of fake edge_index
source_fc, target_fc = fc_edge_index[0].cpu().numpy().tolist(), fc_edge_index[1].cpu().numpy().tolist()
source_target_to_fc_edge_idx = {str([s, t]): i for s, t, i in zip(source_fc, target_fc, range(len(source_fc)))}
fc_edge_idx_to_source_target = {v: k for k, v in source_target_to_fc_edge_idx.items()}
fake_edges = [s for s in source_target_to_fc_edge_idx.keys() if s not in source_target_to_edge_idx.keys()]
fake_edges_ids = [source_target_to_fc_edge_idx[k] for k in fake_edges]
E_fc = fc_edge_index.shape[1]
E = edge_index.shape[1]
assert len(fake_edges) == E_fc - E
fake_edge_index = fc_edge_index.t()[fake_edges_ids].t()
fake_edge_attr = torch.zeros(size=(fake_edge_index.size(1), edge_attr.size(-1)),
device=x.device)
all_edge_index = torch.cat([edge_index, fake_edge_index], dim=-1)
all_edge_attr = torch.cat([edge_attr, fake_edge_attr], dim=0)
I've modified your implementation of the EquivariantConv, @wsad1.
Right now, self.propagate is called twice.
The first call constructs all fully-connected messages based on the fc_edge_index; after that call, the intermediate messages are saved internally as a tuple in self.__calculated_msgs.
The second call is made with the true edge_index. Right now, one unnecessary aggregation is still performed for x in the first call and one for pos in the second call.
Additionally, I removed the add_self_loops argument: edge features for self-loops could be included via zero-padded tensors, but the self-message does not make much sense in my opinion when constructing m_ij, as the positional distance is 0, and the input to the local_nn is then just a concatenation of the node's own features with zero vectors for the distance and edge_attr, respectively.
Find below the code, I've slightly modified from your version:
The Conv:
from typing import Optional, Callable, Tuple
from torch_geometric.typing import OptTensor, Adj
import torch
from torch import Tensor
from torch.nn import Linear
from torch_scatter import scatter
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import reset
class EquivariantConv(MessagePassing):
r"""The Equivariant graph neural network operator form the
`"E(n) Equivariant Graph Neural Networks"
<https://arxiv.org/pdf/2102.09844.pdf>`_ paper.
.. math::
    \mathbf{m}_{ij} = h_{\mathbf{\Theta}}(\mathbf{x}_i, \mathbf{x}_j,
    \|\mathbf{pos}_i - \mathbf{pos}_j\|^2_2, \mathbf{e}_{ij})

    \mathbf{x}^{\prime}_i = \gamma_{\mathbf{\Theta}}(\mathbf{x}_i,
    \sum_{j \in \mathcal{N}(i)} \mathbf{m}_{ij})

    \mathbf{vel}^{\prime}_i = \phi_{\mathbf{\Theta}}(\mathbf{x}_i) \mathbf{vel}_i
    + \frac{1}{|\mathcal{N}(i)|} \sum_{j \in \mathcal{N}(i)}
    (\mathbf{pos}_i - \mathbf{pos}_j) \rho_{\mathbf{\Theta}}(\mathbf{m}_{ij})

    \mathbf{pos}^{\prime}_i = \mathbf{pos}_i + \mathbf{vel}^{\prime}_i
where :math:`\gamma_{\mathbf{\Theta}}`,
:math:`h_{\mathbf{\Theta}}`, :math:`\rho_{\mathbf{\Theta}}`
and :math:`\phi_{\mathbf{\Theta}}` denote neural
networks, *i.e.* MLPs. :math:`\mathbf{P} \in \mathbb{R}^{N \times D}`
and :math:`\mathbf{V} \in \mathbb{R}^{N \times D}`
define the position and velocity of each point respectively.
Args:
local_nn (torch.nn.Module, optional): A neural network
:math:`h_{\mathbf{\Theta}}` that maps node features :obj:`x`,
squared distance :math:`\|{\mathbf{pos}_i-\mathbf{pos}_j}\|^2_2`
and edge_features :obj:`edge_attr`
of shape :obj:`[-1, 2*in_channels + 1 + edge_dim]`
to shape :obj:`[-1, hidden_channels]`, *e.g.*, defined by
:class:`torch.nn.Sequential`. (default: :obj:`None`)
pos_nn (torch.nn.Module, optional): A neural network
:math:`\rho_{\mathbf{\Theta}}` that
maps message :obj:`m` of shape
:obj:`[-1, hidden_channels]`,
to shape :obj:`[-1, 1]`, *e.g.*, defined by
:class:`torch.nn.Sequential`. (default: :obj:`None`)
vel_nn (torch.nn.Module, optional): A neural network
:math:`\phi_{\mathbf{\Theta}}` that
maps node features :obj:`x` of shape :obj:`[-1, in_channels]`,
to shape :obj:`[-1, 1]`, *e.g.*, defined by
:class:`torch.nn.Sequential`. (default: :obj:`None`)
global_nn (torch.nn.Module, optional): A neural network
:math:`\gamma_{\mathbf{\Theta}}` that maps
message :obj:`m` after aggregation
and node features :obj:`x` of shape
:obj:`[-1, hidden_channels + in_channels]`
to shape :obj:`[-1, out_channels]`, *e.g.*, defined by
:class:`torch.nn.Sequential`. (default: :obj:`None`)
aggr (string, optional): The operator used to aggregate message
:obj:`m` (:obj:`"add"`, :obj:`"mean"`).
(default: :obj:`"mean"`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.MessagePassing`.
"""
def __init__(self, local_nn: Optional[Callable] = None,
pos_nn: Optional[Callable] = None,
vel_nn: Optional[Callable] = None,
global_nn: Optional[Callable] = None,
aggr="mean", **kwargs):
super(EquivariantConv, self).__init__(aggr=aggr, **kwargs)
self.local_nn = local_nn
self.pos_nn = pos_nn
self.vel_nn = vel_nn
self.global_nn = global_nn
self.reset_parameters()
def reset_parameters(self):
reset(self.local_nn)
reset(self.pos_nn)
reset(self.vel_nn)
reset(self.global_nn)
def forward(self, x: OptTensor,
pos: Tensor,
edge_index: Adj,
fc_edge_index: Adj,
vel: OptTensor = None,
edge_attr: OptTensor = None
) -> Tuple[Tensor, Tuple[Tensor, OptTensor]]:
""""""
self.__calculated_msgs = (None, None)
self.__E = edge_index.size(1)
# propagate_type: (x: OptTensor, pos: Tensor, edge_attr: OptTensor) -> Tuple[Tensor,Tensor] # noqa
_, out_pos = self.propagate(fc_edge_index, x=x, pos=pos,
edge_attr=edge_attr, size=None)
out_x, _ = self.propagate(edge_index, x=x, pos=pos,
edge_attr=edge_attr, size=None)
out_x = out_x if x is None else torch.cat([x, out_x], dim=1)
if self.global_nn is not None:
out_x = self.global_nn(out_x)
if vel is None:
out_pos += pos
out_vel = None
else:
out_vel = (vel if self.vel_nn is None or x is None else
self.vel_nn(x) * vel) + out_pos
out_pos = pos + out_vel
self.__calculated_msgs = (None, None)
return (out_x, (out_pos, out_vel))
def message(self, x_i: OptTensor, x_j: OptTensor, pos_i: Tensor,
pos_j: Tensor,
edge_attr: OptTensor = None) -> Tuple[Tensor, Tensor]:
# only do this calculation once
if self.__calculated_msgs[0] is None and self.__calculated_msgs[1] is None:
msg = torch.sum((pos_i - pos_j).square(), dim=1, keepdim=True)
msg = msg if x_j is None else torch.cat([x_j, msg], dim=1)
msg = msg if x_i is None else torch.cat([x_i, msg], dim=1)
msg = msg if edge_attr is None else torch.cat([msg, edge_attr], dim=1)
msg = msg if self.local_nn is None else self.local_nn(msg)
pos_msg = ((pos_i - pos_j) if self.pos_nn is None else
(pos_i - pos_j) * self.pos_nn(msg))
self.__calculated_msgs = (msg, pos_msg)
return (msg, pos_msg)
else:
return (self.__calculated_msgs[0][:self.__E], self.__calculated_msgs[1][:self.__E])
def aggregate(self, inputs: Tuple[Tensor, Tensor],
index: Tensor) -> Tuple[Tensor, Tensor]:
return (scatter(inputs[0], index, 0, reduce=self.aggr),
scatter(inputs[1], index, 0, reduce="mean"))
def update(self, inputs: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
return inputs
def __repr__(self):
return ("{}(local_nn={}, pos_nn={},"
" vel_nn={},"
" global_nn={})").format(self.__class__.__name__,
self.local_nn, self.pos_nn,
self.vel_nn, self.global_nn)
A minimal example using networkx-generated random graphs, without velocities.
from torch_geometric.utils import from_networkx
import networkx as nx
from torch_geometric.data import DataLoader
from torch_scatter import scatter_add
def get_fully_connected_get_edges(n_nodes: int, add_self_loops: bool = False):
rows, cols = [], []
for i in range(n_nodes):
for j in range(n_nodes):
if i != j or (i == j and add_self_loops):
rows.append(i)
cols.append(j)
edges = [rows, cols]
edges = torch.tensor(edges, dtype=torch.long).contiguous()
return edges
def get_fc_edge_index(batch_num_nodes: list, ptr: torch.Tensor,
edge_index: torch.Tensor, edge_attr: torch.Tensor) -> (torch.Tensor, torch.Tensor):
fc_edge_index = torch.cat([get_fully_connected_get_edges(n) + p for n, p in zip(batch_num_nodes, ptr)], dim=-1)
source, target = edge_index[0].cpu().numpy().tolist(), edge_index[1].cpu().numpy().tolist()
source_target_to_edge_idx = {str([s, t]): i for s, t, i in zip(source, target, range(len(source)))}
# edge_idx_to_source_target = {v: k for k, v in source_target_to_edge_idx.items()}
# positions of fake edge_index
source_fc, target_fc = fc_edge_index[0].cpu().numpy().tolist(), fc_edge_index[1].cpu().numpy().tolist()
source_target_to_fc_edge_idx = {str([s, t]): i for s, t, i in zip(source_fc, target_fc, range(len(source_fc)))}
# fc_edge_idx_to_source_target = {v: k for k, v in source_target_to_fc_edge_idx.items()}
fake_edges = [s for s in source_target_to_fc_edge_idx.keys() if s not in source_target_to_edge_idx.keys()]
fake_edges_ids = [source_target_to_fc_edge_idx[k] for k in fake_edges]
E_fc = fc_edge_index.shape[1]
E = edge_index.shape[1]
assert len(fake_edges) == E_fc - E
fake_edge_index = fc_edge_index.t()[fake_edges_ids].t()
fake_edge_attr = torch.zeros(size=(fake_edge_index.size(1), edge_attr.size(-1)),
                             device=edge_attr.device)
all_edge_index = torch.cat([edge_index, fake_edge_index], dim=-1)
all_edge_attr = torch.cat([edge_attr, fake_edge_attr], dim=0)
return all_edge_index, all_edge_attr
def create_random_graph_pyg(n: int, seed: int):
G = nx.random_geometric_graph(n=n, radius=0.125, dim=3, seed=seed)
data = from_networkx(G)
data.x = torch.randn(data.num_nodes, 16)
data.edge_attr = torch.randn(data.edge_index.size(1), 8)
return data
batch_size = 16
max_num_nodes = 100
seed = 42
batch_num_nodes = torch.randint(low=20, high=max_num_nodes, size=(batch_size, ), dtype=torch.long)
datalist = [create_random_graph_pyg(n, seed + i) for i, n in enumerate(batch_num_nodes)]
loader = DataLoader(datalist, 16)
data = next(iter(loader))
x, pos, batch, ptr, edge_index, edge_attr = data.x, data.pos, data.batch, data.ptr, data.edge_index, data.edge_attr
batch_num_nodes = scatter_add(src=batch.new_ones(x.size(0)), index=batch, dim=0, dim_size=batch_size).cpu().numpy().tolist()
fc_edge_index, fc_edge_attr = get_fc_edge_index(batch_num_nodes=batch_num_nodes, ptr=ptr,
edge_index=edge_index, edge_attr=edge_attr)
node_in_channels = 16
edge_in_channels = 8
pos_in_channels = 3
local_nn = Linear(2 * node_in_channels + 1 + edge_in_channels, node_in_channels, bias=False)
pos_nn = Linear(node_in_channels, 1, bias=True)
global_nn = Linear(2 * node_in_channels, node_in_channels, bias=True)
conv = EquivariantConv(local_nn=local_nn, pos_nn=pos_nn, global_nn=global_nn)
x_out, (pos_out, _) = conv(x=x, pos=pos, edge_index=edge_index, fc_edge_index=fc_edge_index, edge_attr=fc_edge_attr)
# test without edge-attrs
node_in_channels = 16
edge_in_channels = 0
pos_in_channels = 3
local_nn = Linear(2 * node_in_channels + 1 + edge_in_channels, node_in_channels, bias=False)
pos_nn = Linear(node_in_channels, 1, bias=True)
global_nn = Linear(2 * node_in_channels, node_in_channels, bias=True)
conv = EquivariantConv(local_nn=local_nn, pos_nn=pos_nn, global_nn=global_nn)
x_out, (pos_out, _) = conv(x=x, pos=pos, edge_index=edge_index, fc_edge_index=fc_edge_index, edge_attr=None)
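If helpful, a quick sanity check on the output shapes for the calls above (`global_nn` maps the node embeddings back to `node_in_channels`, and the positions stay `pos_in_channels`-dimensional):

assert x_out.shape == (x.size(0), node_in_channels)
assert pos_out.shape == (x.size(0), pos_in_channels)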
@tuanle618, first of all, I would be happy to collaborate with you on this. Please feel free to send PRs with your code edits to this branch.
Second, I appreciate your effort to fix EquivariantConv. Some questions and thoughts:
1. I think `propagate` could be called just once; let me know if I am missing something here. `aggregate` receives any argument passed to `propagate`, so I believe something like this should work:
out_x, out_pos = self.propagate(edge_index = fc_edge_index, x=x, pos=pos,
edge_attr=edge_attr, orig_index = edge_index[1], size=None)
def aggregate(self, inputs: Tuple[Tensor, Tensor],
index: Tensor, orig_index: Tensor) -> Tuple[Tensor, Tensor]:
return (scatter(inputs[0], orig_index, 0, reduce=self.aggr), # aggregate on original edges
scatter(inputs[1], index, 0, reduce="mean")) # aggregate on fc edges
2. `get_fully_connected_get_edges` could be simplified to:
def get_fully_connected_get_edges(n_nodes: int, add_self_loops: bool = False):
    edge_index = torch.cartesian_prod(torch.arange(n_nodes), torch.arange(n_nodes)).t()
    if not add_self_loops:
        edge_index = edge_index[:, edge_index[0] != edge_index[1]]
    return edge_index
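For example, both versions produce the same row-major ordering:

get_fully_connected_get_edges(3)
# tensor([[0, 0, 1, 1, 2, 2],
#         [1, 2, 0, 2, 0, 1]])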
Thanks for your suggestions @wsad1.
I'm going to make a PR on the enn branch of your fork soon. I will need to add some further tests to make sure the aggregations on the node embeddings x (based on edge_index) and on pos (based on fc_edge_index) are also correct. To your point (1): I tried to use just one call to self.propagate with your
out_x, out_pos = self.propagate(edge_index = fc_edge_index, x=x, pos=pos,
edge_attr=edge_attr, orig_index = edge_index[1], size=None)
def aggregate(self, inputs: Tuple[Tensor, Tensor],
index: Tensor, orig_index: Tensor) -> Tuple[Tensor, Tensor]:
return (scatter(inputs[0], orig_index, 0, reduce=self.aggr), # aggregate on original edges
scatter(inputs[1], index, 0, reduce="mean")) # aggregate on fc edges
For aggregate, however, I need to slice the tensor inputs[0] so that it matches orig_index, since inputs[0] (created based on the fc_edge_index) is much longer than orig_index, i.e.,
scatter(inputs[0][:len(orig_index)], orig_index, 0, reduce=self.aggr)
The code currently runs without errors, but I'd like to add some more tests to make sure the aggregations are done as intended; a rough sketch of what I have in mind is below. I'll ping you once I've made the PR.
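For example (just a sketch, assuming the conv, `fc_edge_index` and `fc_edge_attr` from the example above, and 3-dimensional positions):

import torch

# The first E columns of fc_edge_index must match edge_index, otherwise slicing
# inputs[0][:len(orig_index)] in aggregate would pick the wrong messages.
E = edge_index.size(1)
assert torch.equal(fc_edge_index[:, :E], edge_index)

# Rough E(n)-equivariance check: node embeddings should stay invariant, while the
# output positions should transform together with the input positions.
Q, _ = torch.linalg.qr(torch.randn(3, 3))  # random orthogonal matrix
t = torch.randn(1, 3)                      # random translation

x_out, (pos_out, _) = conv(x=x, pos=pos, edge_index=edge_index,
                           fc_edge_index=fc_edge_index, edge_attr=fc_edge_attr)
x_out2, (pos_out2, _) = conv(x=x, pos=pos @ Q.T + t, edge_index=edge_index,
                             fc_edge_index=fc_edge_index, edge_attr=fc_edge_attr)

assert torch.allclose(x_out, x_out2, atol=1e-4)
assert torch.allclose(pos_out @ Q.T + t, pos_out2, atol=1e-4)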
Best regards, Tuan
Hi, just a question about the example:
node_in_channels = 16
edge_in_channels = 8
pos_in_channels = 3
local_nn = Linear(2 * node_in_channels + 1 + edge_in_channels, node_in_channels, bias=False)
pos_nn = Linear(node_in_channels, 1, bias=True)
global_nn = Linear(2 * node_in_channels, node_in_channels, bias=True)
conv = EquivariantConv(local_nn=local_nn, pos_nn=pos_nn, global_nn=global_nn)
x_out, (pos_out, _) = conv(x=x, pos=pos, edge_index=edge_index, fc_edge_index=fc_edge_index, edge_attr=fc_edge_attr)
I want to make sure that the first argument of pos_nn is indeed node_in_channels and not pos_in_channels.
Best regards, Kevin
@wsad1 What is the status of this PR? I'd be happy to help if there is still work to be done.