deep_gcns_torch

Dilated layer takes more than `k` neighbours

Open zademn opened this issue 2 years ago • 5 comments

The `Dilated` layer doesn't take `k` into account, so it can return more neighbours than intended.

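import torch
from gcn_lib.sparse.torch_edge import Dilated  # module path assumed from the repository layout

t = torch.tensor([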
    [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
    [0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5]
])

res = Dilated(k=2, dilation=2)(t)
print(res) # here 3 neighbours are taken even though the constructor specified 2.
# tensor([[0, 0, 0, 1, 1, 1],
#         [0, 2, 4, 0, 2, 4]])

zademn, Jun 05 '22 13:06

The `Dilated` class only dilates the kNN `edge_index` it is given; it does not build the kNN graph itself. You first need to find the kNN graph and then dilate it, as `DilatedKnnGraph` does. Sorry for the confusion.

class DilatedKnnGraph(nn.Module):
    """
    Find the neighbors' indices based on dilated knn
    """
    def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0, knn='matrix'):
        super(DilatedKnnGraph, self).__init__()
        self.dilation = dilation
        self.stochastic = stochastic
        self.epsilon = epsilon
        self.k = k
        self._dilated = Dilated(k, dilation, stochastic, epsilon)
        if knn == 'matrix':
            self.knn = knn_graph_matrix
        else:
            self.knn = knn_graph

    def forward(self, x, batch):
        edge_index = self.knn(x, self.k * self.dilation, batch)
        return self._dilated(edge_index, batch)

lightaime, Jun 05 '22 14:06

I'm not even sure taking [::d] is the right way to go. Consider the following example, where each node has 5 neighbours instead of k*d = 4:

t = torch.tensor(
    [
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
        [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
    ]
)
res = Dilated(k=2, dilation=2)(t)
print(res)
# tensor([[0, 0, 0, 1, 1, 2, 2, 2],
#         [0, 2, 4, 1, 3, 0, 2, 4]])

For node 1, [1, 0] and [1, 2] are the expected edges, but we get [1, 1] and [1, 3].
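
The drift can be reproduced without the repository. The sketch below assumes that the non-stochastic path of `Dilated` reduces to a plain column stride over the flat edge list (the `[::d]` discussed in this thread); `stride_dilate` is a hypothetical stand-in, not the library function.

```python
import torch


def stride_dilate(edge_index: torch.Tensor, d: int) -> torch.Tensor:
    """Hypothetical stand-in: keep every d-th column of the flat edge list."""
    return edge_index[:, ::d]


# 3 nodes with 5 neighbours each; 5 is not a multiple of d = 2.
n_nodes, n_neigh, d = 3, 5, 2
src = torch.arange(n_nodes).repeat_interleave(n_neigh)  # 0,0,0,0,0,1,1,...
dst = torch.arange(n_neigh).repeat(n_nodes)             # 0,1,2,3,4,0,1,...
edge_index = torch.stack([src, dst])

dilated = stride_dilate(edge_index, d)
for node in range(n_nodes):
    kept = dilated[1, dilated[0] == node].tolist()
    print(node, kept)
```

This prints `0 [0, 2, 4]`, `1 [1, 3]` and `2 [0, 2, 4]`. The stride runs over the whole flattened list, so its phase is not reset at node boundaries: node 1's block starts at column 5, an odd offset, and the stride therefore lands on its second and fourth neighbours.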

zademn, Jun 05 '22 14:06

We should first build a kNN graph that has k*d neighbors for each node, then use [::d] to get the dilated graph, as in `edge_index = self.knn(x, self.k * self.dilation, batch)`. So this case won’t happen. But you are right: it is not ideal if the provided graph doesn’t have k*d neighbors for each node.
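
A minimal sketch of that invariant, using only PyTorch. `knn_edge_index` is a hypothetical dense kNN helper written for illustration (it is not the repository's `knn_graph_matrix`), and putting the centre node in row 0 follows the examples in this thread rather than any documented convention.

```python
import torch


def knn_edge_index(x: torch.Tensor, num_neighbors: int) -> torch.Tensor:
    """Hypothetical dense kNN: returns (2, N * num_neighbors) edges,
    grouped by centre node and ordered by increasing distance."""
    dist = torch.cdist(x, x)                                  # (N, N) pairwise distances
    nn_idx = dist.topk(num_neighbors, largest=False).indices  # (N, num_neighbors)
    center = torch.arange(x.size(0)).repeat_interleave(num_neighbors)
    return torch.stack([center, nn_idx.reshape(-1)])


k, d = 2, 2
x = torch.arange(6, dtype=torch.float32).unsqueeze(1)  # 6 points on a line
edge_index = knn_edge_index(x, k * d)  # exactly k * d neighbours per node
dilated = edge_index[:, ::d]           # every d-th neighbour

# Each node block has k * d columns, so the stride stays aligned
# with node boundaries and every node keeps exactly k neighbours.
counts = torch.bincount(dilated[0], minlength=x.size(0))
print(counts)
```

Because every block has a length that is a multiple of `d`, the stride always restarts at the first neighbour of each node, which is why the misalignment cannot occur on this path.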

lightaime, Jun 05 '22 14:06

A possible solution would be (using einops):

import torch
from einops import rearrange

t = torch.tensor(
    [
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
        [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
    ]
)
k = 2
d = 2

u, counts = torch.unique(t[0], return_counts=True)
k_constructed = int(counts[0])  # assumes every node has the same number of neighbours; could also be passed as a parameter
res1 = rearrange(t, "e (n2 k_constructed) -> e n2 k_constructed", k_constructed=k_constructed)
print(res1)

# tensor([[[0, 0, 0, 0, 0],
#          [1, 1, 1, 1, 1],
#          [2, 2, 2, 2, 2]],

#         [[0, 1, 2, 3, 4],
#          [0, 1, 2, 3, 4],
#          [0, 1, 2, 3, 4]]])
res2 = res1[:, :, ::d]  # Res dilated
print(res2)
# tensor([[[0, 0, 0],
#          [1, 1, 1],
#          [2, 2, 2]],

#         [[0, 2, 4],
#          [0, 2, 4],
#          [0, 2, 4]]])
res3 = res2[:, :, :k] # Take first k neighbours
print(res3)
# tensor([[[0, 0],
#          [1, 1],
#          [2, 2]],

#         [[0, 2],
#          [0, 2],
#          [0, 2]]])
res4 = rearrange(res3, "e d1 d2 -> e (d1 d2)")
print(res4)
# tensor([[0, 0, 1, 1, 2, 2],
#         [0, 2, 0, 2, 0, 2]])
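
The same steps can be sketched without the einops dependency. `dilate_per_node` is a hypothetical helper name, and it assumes edges are grouped by node in row 0 with an identical neighbour count per node; it raises instead of silently misaligning when that does not hold.

```python
import torch


def dilate_per_node(edge_index: torch.Tensor, k: int, d: int) -> torch.Tensor:
    """Hypothetical helper: dilate inside each node's block, then keep k neighbours."""
    _, counts = torch.unique_consecutive(edge_index[0], return_counts=True)
    if not bool((counts == counts[0]).all()):
        raise ValueError("every node must have the same number of neighbours")
    n_neigh = int(counts[0])
    blocks = edge_index.reshape(2, -1, n_neigh)  # (2, num_nodes, n_neigh)
    # Stride within each block, then truncate to the first k neighbours.
    return blocks[:, :, ::d][:, :, :k].reshape(2, -1)


t = torch.tensor(
    [
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
        [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
    ]
)
res = dilate_per_node(t, k=2, d=2)
print(res)
# tensor([[0, 0, 1, 1, 2, 2],
#         [0, 2, 0, 2, 0, 2]])
```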

zademn, Jun 05 '22 14:06

Thanks for the suggestion, @zademn. That is definitely a good idea for a more complex case. But in our code, we always build kNN graphs with k*d neighbors, so to keep it simple we prefer to leave it as it is.

lightaime, Jun 05 '22 14:06