deep_gcns_torch
Dilated layer takes more than `k` neighbours
The `Dilated` layer doesn't take `k` into account. This can lead to taking more neighbours than intended.
import torch
# `Dilated` comes from this repo's gcn_lib; the exact import path depends on your checkout.

t = torch.tensor([
    [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
    [0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5]
])
res = Dilated(k=2, dilation=2)(t)
print(res)  # 3 neighbours per node are kept even though the constructor specified k=2
# tensor([[0, 0, 0, 1, 1, 1],
#         [0, 2, 4, 0, 2, 4]])
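From the output it looks like the non-stochastic path simply strides the whole `edge_index` by the dilation and never looks at `k`; plain indexing reproduces the result (a quick check, not the repo's code):

print(t[:, ::2])  # same result as Dilated(k=2, dilation=2)(t)
# tensor([[0, 0, 0, 1, 1, 1],
#         [0, 2, 4, 0, 2, 4]])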
The `Dilated` class will only dilate the provided knn `edge_index`. You would first need to find the knn graph and then dilate it, like `DilatedKnnGraph` does. The `Dilated` class itself doesn't build the knn graph. Sorry for the confusion.
import torch.nn as nn
# `Dilated`, `knn_graph` and `knn_graph_matrix` are defined alongside this class in gcn_lib.

class DilatedKnnGraph(nn.Module):
    """
    Find the neighbors' indices based on dilated knn
    """
    def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0, knn='matrix'):
        super(DilatedKnnGraph, self).__init__()
        self.dilation = dilation
        self.stochastic = stochastic
        self.epsilon = epsilon
        self.k = k
        self._dilated = Dilated(k, dilation, stochastic, epsilon)
        if knn == 'matrix':
            self.knn = knn_graph_matrix
        else:
            self.knn = knn_graph

    def forward(self, x, batch):
        # build a knn graph with k * dilation neighbours per node, then dilate it
        edge_index = self.knn(x, self.k * self.dilation, batch)
        return self._dilated(edge_index, batch)
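A minimal usage sketch, assuming the repo's gcn_lib is importable (the import path below is a guess, adjust it to your checkout):

import torch
from gcn_lib.sparse.torch_edge import DilatedKnnGraph  # import path is an assumption

x = torch.randn(8, 3)                     # 8 nodes with 3-d features
batch = torch.zeros(8, dtype=torch.long)  # all nodes belong to one graph
knn = DilatedKnnGraph(k=2, dilation=2)
edge_index = knn(x, batch)                # builds a k*d = 4 NN graph, then strides by d
print(edge_index.shape)                   # should be [2, 16], i.e. 2 neighbours per node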
I'm not even sure taking `[::d]` is the right way to go. The following example shows the problem:
t = torch.tensor([
    [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
    [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
])
res = Dilated(k=2, dilation=2)(t)
print(res)
# tensor([[0, 0, 0, 1, 1, 2, 2, 2],
#         [0, 2, 4, 1, 3, 0, 2, 4]])
For node 1, `[1, 0]` and `[1, 2]` are the expected edges, but we get `[1, 1]` and `[1, 3]`.
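The stride is applied to the flat `edge_index`, so which neighbours survive depends on where each node's block of columns starts. Here every node has 5 neighbours, an odd number, so the stride drifts out of phase with the per-node blocks (a quick check with plain indexing):

kept = torch.arange(t.shape[1])[::2]
print(kept)  # tensor([ 0,  2,  4,  6,  8, 10, 12, 14])
# node 1 owns columns 5..9, and only columns 6 and 8 fall on the stride,
# i.e. its 2nd and 4th nearest neighbours (1 and 3) instead of its 1st and 3rd (0 and 2)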
We should first build a knn graph that has k*d neighbors for each node and then use [::d] to get the dilated graph.
edge_index = self.knn(x, self.k * self.dilation, batch)
So this case won't happen. But you are right, it is not ideal if the provided graph doesn't have k*d neighbors for each node.
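For illustration, a small self-contained check (plain torch and a hand-built edge_index, not the repo's code): when every node comes with k*d neighbours stored contiguously and nearest-first, the stride stays in phase and each node keeps exactly k of them.

import torch

k, d = 2, 2
# 3 nodes, each with k*d = 4 neighbours, nearest first
edge_index = torch.tensor([
    [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2],
    [1, 2, 3, 4, 0, 2, 3, 4, 0, 1, 3, 4],
])
print(edge_index[:, ::d])
# tensor([[0, 0, 1, 1, 2, 2],
#         [1, 3, 0, 3, 0, 3]])
# each node keeps exactly k = 2 neighbours: its 1st and 3rd nearest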
A possible solution would be (using einops):
import torch
from einops import rearrange

t = torch.tensor([
    [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
    [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
])
k = 2
d = 2
u, counts = torch.unique(t[0], return_counts=True)
# assume every node has the same number of neighbours; this could also be passed as a parameter
k_constructed = int(counts[0])
res1 = rearrange(t, "e (n2 k_constructed) -> e n2 k_constructed", k_constructed=k_constructed)
# tensor([[[0, 0, 0, 0, 0],
# [1, 1, 1, 1, 1],
# [2, 2, 2, 2, 2]],
# [[0, 1, 2, 3, 4],
# [0, 1, 2, 3, 4],
# [0, 1, 2, 3, 4]]])
res2 = res1[:, :, ::d]  # take every d-th neighbour (dilation)
print(res2)
# tensor([[[0, 0, 0],
# [1, 1, 1],
# [2, 2, 2]],
# [[0, 2, 4],
# [0, 2, 4],
# [0, 2, 4]]])
res3 = res2[:, :, :k] # Take first k neighbours
print(res3)
# tensor([[[0, 0],
# [1, 1],
# [2, 2]],
# [[0, 2],
# [0, 2],
# [0, 2]]])
res4 = rearrange(res3, "e d1 d2 -> e (d1 d2)")
print(res4)
# tensor([[0, 0, 1, 1, 2, 2],
# [0, 2, 0, 2, 0, 2]])
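If it helps, the same steps could be packaged into a small helper; dilate_per_node is a hypothetical name, and it assumes every node has the same number of neighbours, stored contiguously and nearest-first:

def dilate_per_node(edge_index, k, d):
    # hypothetical helper wrapping the steps above
    _, counts = torch.unique(edge_index[0], return_counts=True)
    n_neigh = int(counts[0])                  # neighbours per node (assumed uniform)
    per_node = rearrange(edge_index, "e (n m) -> e n m", m=n_neigh)
    per_node = per_node[:, :, ::d][:, :, :k]  # every d-th neighbour, keep the first k
    return rearrange(per_node, "e n m -> e (n m)")

print(dilate_per_node(t, k=2, d=2))
# tensor([[0, 0, 1, 1, 2, 2],
#         [0, 2, 0, 2, 0, 2]])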
Thanks for the suggestion @zademn. That is definitely a good idea if we are dealing with a more complex case. But in our example, we always build knn graphs with k*d
neighbors. To keep it simple, we prefer to leave it as it is.