deep-graph-matching-consensus

Sorting requirement for __include_gt__

Open · Nifury opened this issue on Oct 13 '23 · 0 comments

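Hello, I noticed that if the source indices of the ground truth, y[0], are not sorted in ascending order, __include_gt__ does not behave properly: the ground-truth targets do not end up in the candidate lists of their corresponding source nodes. It might be worth mentioning this requirement in the documentation.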
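Until the requirement is documented, one caller-side workaround is to reorder the ground-truth pairs so that y[0] is ascending before y reaches __include_gt__. This is a minimal sketch, assuming y is a [2, num_gt] LongTensor of (source, target) columns as in the repro; the helper name sort_ground_truth is hypothetical and not part of the library.

```python
import torch


def sort_ground_truth(y: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: reorder the columns of a [2, num_gt]
    correspondence tensor so that the source indices y[0] are ascending.

    Each column (y[0, i], y[1, i]) is one source -> target pair, so
    permuting whole columns keeps every pair intact.
    """
    perm = torch.argsort(y[0])
    return y[:, perm]


y = torch.as_tensor([[2, 0, 1], [3, 4, 5]])
y_sorted = sort_ground_truth(y)
# y_sorted is [[0, 1, 2], [4, 5, 3]]: same pairs, sources now ascending.
```

Permuting columns rather than sorting each row independently is what keeps source 2 paired with target 3.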

Code to reproduce:

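```python
import torch


# Written as a method on a DGMC model instance (self), so it can call the
# __top_k__ and __include_gt__ helpers directly.
def test(self):
    h_s = torch.randn(1, 10, 20)
    h_t = torch.randn(1, 10, 20)
    s_mask = torch.ones(1, 10, dtype=torch.bool)
    # Ground-truth pairs y[0][i] -> y[1][i]; note that y[0] is NOT sorted.
    y = torch.as_tensor([[2, 0, 1], [3, 4, 5]])
    # Make sure the top-k candidates don't include the ground truth.
    h_s[0, y[0]] = 100
    h_t[0, y[1]] = -100
    self.k = 1
    S_idx = self.__top_k__(h_s, h_t)
    # Stand-in for the extra random candidates (fixed to index 0 here).
    S_rnd_idx = torch.zeros(1, 10, 1, dtype=torch.long)
    S_idx = torch.cat([S_idx, S_rnd_idx], dim=-1)
    S_idx = self.__include_gt__(S_idx, s_mask, y)
    # True where source y[0][i] has target y[1][i] among its candidates;
    # every entry should be True once the ground truth has been included.
    mask = S_idx[0, y[0]] == y[1].view(-1, 1)
    print(mask.any(dim=-1))
```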
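A plausible cause, assuming __include_gt__ writes the ground-truth targets with torch.Tensor.masked_scatter: masked_scatter fills the selected slots in row-major order (that is, by increasing source index) while consuming the values in the order they appear in y[1], so the pairing only holds when y[0] is already ascending. The standalone sketch below is a simplified, hypothetical stand-in for that step, not the library's code; N_s, k, naive and fixed are illustrative names.

```python
import torch

# Hypothetical stand-in for the ground-truth insertion step (ASSUMPTION: it
# uses masked_scatter): write each target into the last candidate slot of
# its source node.
N_s, k = 4, 2
S_idx = torch.zeros(N_s, k, dtype=torch.long)  # candidate targets per source
y = torch.as_tensor([[2, 0, 1], [3, 4, 5]])    # unsorted y[0]
row, col = y

# Select the last candidate slot of every source node that has a pair.
mask = torch.zeros(N_s, k, dtype=torch.bool)
mask[row, -1] = True

# masked_scatter consumes col in the given order (3, 4, 5) but visits the
# masked slots by source index (0, 1, 2), so the pairing is scrambled.
naive = S_idx.masked_scatter(mask, col)
print(naive[row, -1])  # tensor([5, 3, 4]): none match col

# Sorting the pairs by source index first restores the intended pairing.
perm = torch.argsort(row)
fixed = S_idx.masked_scatter(mask, col[perm])
print(fixed[row, -1])  # tensor([3, 4, 5]): matches col
```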

Nifury, Oct 13 '23 02:10