pytorch_sparse

Neighbor Sampling without replacement doesn't sample uniformly

Open · aristizabal95 opened this issue 1 year ago · 2 comments

While exploring the behavior of PyTorch Geometric's LinkNeighborLoader, we found that neighbor sampling without replacement consistently oversampled the first possible combination of neighbors and never sampled the last possible combination. The error appears to come from the sampling logic implemented here.

How to reproduce

We tested with a toy heterogeneous graph:

from torch_geometric.data import HeteroData
from torch_geometric.transforms import ToUndirected
import torch

num_papers = 4
num_paper_features = 10
num_authors = 5
num_authors_features = 12
data = HeteroData()

# Create two node types "paper" and "author" holding a feature matrix:
data['paper'].x = torch.randn(num_papers, num_paper_features)
data['paper'].id = torch.arange(0, num_papers)
data['paper'].type = "paper"
data['author'].x = torch.randn(num_authors, num_authors_features)
data['author'].id = torch.arange(0, num_authors)
data['author'].type = "author"

# Create an edge type "(author, writes, paper)" and build the
# graph connectivity:
data['author', 'writes', 'paper'].edge_index = torch.tensor([
    [0, 1, 2, 2, 3, 4, 2, 2],
    [0, 0, 0, 1, 0, 1, 2, 3],
])


# Add the reverse edge type ('paper', 'rev_writes', 'author'):
transform = ToUndirected()
data = transform(data)
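
To see which sampling decisions the loader actually faces, it helps to list each paper's authors from the edge_index above. The sketch below does this in plain Python (no PyG required) by copying the two edge_index rows by hand; the helper name `neighbors_by_dst` is hypothetical and not part of PyG.

```python
from collections import defaultdict

# Rows of the ('author', 'writes', 'paper') edge_index from the toy graph.
src = [0, 1, 2, 2, 3, 4, 2, 2]  # author ids
dst = [0, 0, 0, 1, 0, 1, 2, 3]  # paper ids


def neighbors_by_dst(src, dst):
    """Group source nodes (authors) by destination node (paper)."""
    groups = defaultdict(list)
    for s, d in zip(src, dst):
        groups[d].append(s)
    return {d: sorted(ns) for d, ns in groups.items()}


paper_authors = neighbors_by_dst(src, dst)
print(paper_authors)  # {0: [0, 1, 2, 3], 1: [2, 4], 2: [2], 3: [2]}
```

Paper 0, the seed paper in the loader below, has four authors, so sampling two of them without replacement leaves C(4, 2) = 6 possible pairs.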

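With this toy graph, we created a LinkNeighborLoader that, for the single supervision edge (author 0, paper 0), samples two author neighbors of the paper and no paper neighbors of the author: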

from torch_geometric.loader import LinkNeighborLoader

num_neighbors = {
    ('author', 'writes', 'paper'): [2],
    ('paper', 'rev_writes', 'author'): [0],
}

edge_label_index = (('author', 'writes', 'paper'), torch.tensor([[0],[0]]))
edge_label = torch.tensor([[1, 1]])

loader = LinkNeighborLoader(
    data,
    num_neighbors=num_neighbors,
    edge_label_index=edge_label_index,
    edge_label=edge_label,
)

Lastly, to measure how often each pair of neighbors is sampled, we drew 10,000 samples and counted the frequency of each pair:

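from collections import defaultdict

counts = defaultdict(int)
num_samples = 10000

for _ in range(num_samples):
    # Draw a fresh mini-batch (batch_size=1, so the seed edge is always the same).
    sampled_data = next(iter(loader))
    # Source (author) row of the sampled ('author', 'writes', 'paper') edges,
    # expected to hold the two authors sampled for paper 0.
    edge = sampled_data[('author', 'writes', 'paper')].edge_index[0].tolist()
    # Map local author indices back to global author ids.
    author_id = sampled_data['author'].n_id[edge[0]].item()
    author2_id = sampled_data['author'].n_id[edge[1]].item()
    # Count unordered pairs.
    key = tuple(sorted((author_id, author2_id)))
    counts[key] += 1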

The output consistently showed the first combination being sampled about twice as often as any other, while the last possible combination was never sampled:

for k, v in dict(counts).items():
    print(k, v/num_samples)
# > (0, 1) 0.3382
# > (1, 2) 0.164
# > (0, 2) 0.1671
# > (0, 3) 0.1706
# > (1, 3) 0.1601

In this case, the combination (0, 1) was oversampled, appearing about twice as often as any other combination, while the combination (2, 3) was never sampled.
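
For comparison, a uniform sampler drawing 2 of paper 0's 4 authors without replacement would produce each of the C(4, 2) = 6 pairs with probability 1/6 ≈ 0.167. The sketch below compares that expectation with the frequencies reported above; the observed values are copied from the output, with the never-sampled pair (2, 3) entered as 0.

```python
from itertools import combinations
from math import comb

# Frequencies reported above (10000 samples); (2, 3) never appeared.
observed = {
    (0, 1): 0.3382, (1, 2): 0.1640, (0, 2): 0.1671,
    (0, 3): 0.1706, (1, 3): 0.1601, (2, 3): 0.0,
}

authors_of_paper_0 = [0, 1, 2, 3]
k = 2
# Probability of any single pair under uniform sampling without replacement.
expected_p = 1 / comb(len(authors_of_paper_0), k)

for pair in combinations(authors_of_paper_0, k):
    ratio = observed[pair] / expected_p
    print(pair, f"observed={observed[pair]:.4f}",
          f"expected={expected_p:.4f}", f"ratio={ratio:.2f}")
```

The ratio comes out close to 2 for (0, 1), close to 1 for the other observed pairs, and 0 for (2, 3).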

This happened for every number of neighbors and every set of root nodes we tried, as long as the number of sampled neighbors was smaller than the number of available neighbors.
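
One way to account for these exact numbers is an off-by-one error in Robert Floyd's algorithm for sampling k of n items without replacement. The sketch below is a hypothesis, not a confirmed diagnosis (the link to the implementation was not preserved): it assumes the sampler follows Floyd's scheme but draws the random index from [0, j) instead of the inclusive range [0, j], and that positions 0 to 3 in paper 0's neighbor list correspond to authors 0 to 3. It enumerates every random draw exactly with `fractions.Fraction`, so there is no sampling noise; the helper names `floyd_subset_distribution` and `show` are hypothetical.

```python
from fractions import Fraction
from itertools import combinations


def floyd_subset_distribution(n, k, inclusive):
    """Exact distribution of k-subsets of range(n) produced by Floyd's algorithm.

    inclusive=True draws t uniformly from [0, j] (the textbook algorithm);
    inclusive=False draws t from [0, j), the hypothesized off-by-one.
    Assumes 0 < k < n.
    """
    assert 0 < k < n
    dist = {frozenset(): Fraction(1)}
    for j in range(n - k, n):
        choices = range(j + 1) if inclusive else range(j)
        new_dist = {}
        for subset, p in dist.items():
            for t in choices:
                # On a collision, Floyd's algorithm takes j itself instead.
                pick = j if t in subset else t
                key = subset | {pick}
                new_dist[key] = new_dist.get(key, Fraction(0)) + p / len(choices)
        dist = new_dist
    return {tuple(sorted(s)): p for s, p in dist.items()}


def show(n, k, inclusive):
    """Print the exact probability of every k-subset, including zeros."""
    dist = floyd_subset_distribution(n, k, inclusive)
    for pair in combinations(range(n), k):
        print(pair, dist.get(pair, Fraction(0)))


print("off-by-one, t drawn from [0, j):")
show(4, 2, inclusive=False)
# (0, 1) 1/3
# (0, 2) 1/6
# (0, 3) 1/6
# (1, 2) 1/6
# (1, 3) 1/6
# (2, 3) 0
print("textbook, t drawn from [0, j]:")
show(4, 2, inclusive=True)  # every pair prints 1/6
```

Under these assumptions the off-by-one variant reproduces the observed pattern exactly, while the inclusive draw is uniform. It also explains why the last combination {n - k, ..., n - 1} can never appear: the first draw always lands below n - k, so every sampled subset contains at least one earlier index.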

aristizabal95, Oct 01 '24 22:10