pytorch_geometric

Update negative sampling to work directly on GPU

Status: Open. danielecastellana22 opened this issue 1 year ago • 4 comments

@rusty1s I re-implemented negative edge sampling so that it works directly on GPUs. Below, I summarise the main idea of my implementation and then raise some questions that I hope you can help me answer.

Negative Sampling

The idea of negative edge sampling is to obtain a candidate list of edges and then discard those that already exist in the input graph. To perform the existence check, I use torch.searchsorted. This requires the input edge_index to be sorted, which is usually already the case.
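The existence check can be sketched as follows. This is a minimal sketch, not the code from this PR: it assumes each edge is encoded as the single integer `row * num_nodes + col`, so the edge list becomes one sorted 1-D tensor that `torch.searchsorted` can query. The helper name `edge_exists` is hypothetical, and it only sorts when the encoded ids are not already sorted.

```python
import torch


def edge_exists(edge_index: torch.Tensor, candidates: torch.Tensor,
                num_nodes: int) -> torch.Tensor:
    """Hypothetical helper: True where a candidate edge is already in edge_index."""
    # Encode each edge (row, col) as one integer id (assumed encoding).
    existing = edge_index[0] * num_nodes + edge_index[1]
    query = candidates[0] * num_nodes + candidates[1]
    if existing.numel() == 0:
        return torch.zeros_like(query, dtype=torch.bool)
    # torch.searchsorted needs a sorted 1-D input; sort only if necessary.
    if not bool((existing[1:] >= existing[:-1]).all()):
        existing = torch.sort(existing).values
    # Leftmost insertion position of each query id in the sorted ids.
    pos = torch.searchsorted(existing, query)
    # Ids larger than every existing id get pos == len; clamp before indexing.
    pos = pos.clamp(max=existing.numel() - 1)
    return existing[pos] == query


edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
candidates = torch.tensor([[0, 1, 2], [1, 0, 0]])
print(edge_exists(edge_index, candidates, num_nodes=3).tolist())
# [True, False, True]
```

On GPU, the `bool(...)` sortedness check forces a device-to-host synchronisation; dropping it and making sorting the caller's job is the trade-off raised in open question 1.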

The initial guess of negative edges can be obtained with either a dense or a sparse method. The function also supports an automatic mode that chooses the best method for the input graph.

Dense Method

The dense method is exact (it always returns the requested number of negative samples whenever that many negative edges exist), but it is costly because it enumerates all possible edges, so its cost is quadratic in the number of nodes. The sampling is made stochastic with torch.randperm.
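A minimal sketch of the dense idea, under stated assumptions: the function name is hypothetical, self-loops are excluded from the negatives, and existing edges are filtered with a boolean mask over all $N^2$ candidate ids rather than with `torch.searchsorted`. That keeps the example short while showing the same quadratic cost and the `torch.randperm`-based randomness.

```python
import torch


def dense_negative_sampling(edge_index: torch.Tensor, num_nodes: int,
                            num_neg_samples: int) -> torch.Tensor:
    """Hypothetical sketch: exact negative sampling by enumerating all N^2 pairs."""
    device = edge_index.device
    # is_free[i] is True if pair id i = row * N + col may be a negative edge.
    is_free = torch.ones(num_nodes * num_nodes, dtype=torch.bool, device=device)
    is_free[edge_index[0] * num_nodes + edge_index[1]] = False
    # Assumption: self-loops (ids 0, N+1, 2N+2, ...) are not valid negatives.
    is_free[torch.arange(num_nodes, device=device) * (num_nodes + 1)] = False
    # A random permutation of all pair ids makes the selection stochastic.
    perm = torch.randperm(num_nodes * num_nodes, device=device)
    chosen = perm[is_free[perm]][:num_neg_samples]
    row = torch.div(chosen, num_nodes, rounding_mode='floor')
    col = chosen % num_nodes
    return torch.stack([row, col], dim=0)


edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
neg = dense_negative_sampling(edge_index, num_nodes=3, num_neg_samples=10)
print(neg.size())  # torch.Size([2, 3]): only 3 negative edges exist
```

Asking for 10 samples returns all 3 negatives that exist, which is the "exact" behaviour described above; the $N^2$-sized `is_free` and `perm` tensors are where the quadratic cost comes from.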

Sparse Method

The sparse method is not exact (it may return fewer samples than requested), but it is more efficient because it does not enumerate all possible edges. To obtain the guess, we simply sample k edges using torch.multinomial. The choice of k is crucial for obtaining the desired number of negative edges, and it depends on the probability that a randomly sampled edge is negative.
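The sparse method can be sketched like this. It is an illustration, not the PR's code: the function name and the `oversample` factor are hypothetical, and $k$ is chosen so that the expected number of negatives among $k$ uniform pairs, $k \cdot p_{neg}$, reaches the request, with $p_{neg}$ approximated as $1 - (E + N)/N^2$ (existing edges and self-loops are the pairs to avoid). The final warning uses the 10% threshold suggested in open question 2.

```python
import math
import warnings

import torch


def sparse_negative_sampling(edge_index: torch.Tensor, num_nodes: int,
                             num_neg_samples: int,
                             oversample: float = 1.1) -> torch.Tensor:
    """Hypothetical sketch: approximate negative sampling without an N^2 enumeration."""
    device = edge_index.device
    num_edges = edge_index.size(1)
    if num_neg_samples <= 0:
        return torch.empty(2, 0, dtype=torch.long, device=device)
    # Approximate probability that a uniform (row, col) pair is a negative edge.
    p_neg = 1.0 - (num_edges + num_nodes) / (num_nodes * num_nodes)
    if p_neg <= 0.0:
        warnings.warn("No negative edges can be sampled from this graph.")
        return torch.empty(2, 0, dtype=torch.long, device=device)
    k = math.ceil(oversample * num_neg_samples / p_neg)
    # Draw k rows and k columns uniformly (with replacement).
    weights = torch.ones(num_nodes, device=device)
    row = torch.multinomial(weights, k, replacement=True)
    col = torch.multinomial(weights, k, replacement=True)
    ids = (row * num_nodes + col)[row != col]  # drop self-loops
    # Discard candidates that are existing edges (searchsorted on sorted ids).
    existing = torch.sort(edge_index[0] * num_nodes + edge_index[1]).values
    if existing.numel() > 0:
        pos = torch.searchsorted(existing, ids).clamp(max=existing.numel() - 1)
        ids = ids[existing[pos] != ids]
    # Remove duplicates, then shuffle so truncation is not biased by sort order.
    ids = torch.unique(ids)
    ids = ids[torch.randperm(ids.numel(), device=device)][:num_neg_samples]
    if ids.numel() < 0.9 * num_neg_samples:
        warnings.warn(f"Sparse negative sampling returned only {ids.numel()} "
                      f"of {num_neg_samples} requested edges.")
    row = torch.div(ids, num_nodes, rounding_mode='floor')
    return torch.stack([row, ids % num_nodes], dim=0)
```

Warning only after the fact, when the shortfall is known, avoids the false alarms that a warning based on a low $p_{neg}$ alone can produce.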

Structured Negative Sampling

Similarly, I implemented structured negative sampling. The main difference is that we want to sample a negative edge $(u,w)$ for each edge $(u,v)$ in the graph.

Dense Method

For each node $u$, we obtain a random permutation of all the nodes. Then, we select the first $\deg(u)$ nodes that are not linked to $u$.
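A vectorised sketch of this step, assuming no duplicate edges. Instead of calling `torch.randperm` once per node, it draws an $N \times N$ matrix of uniform scores and argsorts each row, which gives an independent random permutation per node. Neighbours and $u$ itself get a score offset so they sort last, and the $j$-th edge of $u$ takes the $j$-th entry of $u$'s permutation. The function name and the $-1$ fallback (borrowed from the sparse variant) are illustrative choices.

```python
import torch


def structured_negative_sampling_dense(edge_index: torch.Tensor, num_nodes: int):
    """Hypothetical sketch: one negative (u, w) per edge (u, v), O(N^2) memory."""
    device = edge_index.device
    # Sort edges by source node so each node's edges are contiguous.
    perm = torch.argsort(edge_index[0] * num_nodes + edge_index[1])
    row, col = edge_index[0][perm], edge_index[1][perm]
    # blocked[u, w] is True if w == u or (u, w) is an existing edge.
    blocked = torch.zeros(num_nodes, num_nodes, dtype=torch.bool, device=device)
    blocked[row, col] = True
    blocked.fill_diagonal_(True)
    # Free nodes score in [0, 1), blocked nodes in [2, 3): each row's argsort is
    # a random permutation with all free nodes first.
    scores = torch.rand(num_nodes, num_nodes, device=device) + 2.0 * blocked
    order = torch.argsort(scores, dim=1)
    # slot[e] = position of edge e among the edges of its source node.
    deg = torch.bincount(row, minlength=num_nodes)
    start = torch.cumsum(deg, dim=0) - deg
    slot = torch.arange(row.numel(), device=device) - start[row]
    neg = order[row, slot.clamp(max=num_nodes - 1)]
    # If u has fewer free nodes than edges, its last slots hit blocked nodes.
    neg = torch.where(blocked[row, neg], torch.full_like(neg, -1), neg)
    return torch.stack([row, col], dim=0), neg
```

Because each node's negatives come from one permutation, the negatives assigned to the edges of the same node are distinct.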

Sparse Method

We sample $k \cdot E$ edges, where $E$ is the number of edges in the input graph. Here the choice of $k$ is trickier, since it depends on the degree of each node. When the method cannot find a negative sample for an edge, it returns $-1$.
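To see why $k$ depends on the degree: if each of the $k$ candidates for an edge $(u,v)$ is drawn uniformly from the $N$ nodes, a candidate is rejected when it is $u$ itself or one of $u$'s $\deg(u)$ neighbours, so all $k$ candidates fail with probability $\left(\frac{1+\deg(u)}{N}\right)^k$ (assuming no duplicate edges). For example, with $N = 100$, $\deg(u) = 89$ and $k = 10$, this is $0.9^{10} \approx 0.35$. The sketch below is a hypothetical vectorised version with $k$ candidates per edge drawn with `torch.multinomial`; it keeps a valid candidate when one exists and falls back to $-1$ otherwise.

```python
import torch


def structured_negative_sampling_sparse(edge_index: torch.Tensor,
                                        num_nodes: int, k: int = 10) -> torch.Tensor:
    """Hypothetical sketch: k * E random candidates, -1 where none is valid."""
    device = edge_index.device
    row, col = edge_index[0], edge_index[1]
    num_edges = row.numel()
    if num_edges == 0:
        return torch.empty(0, dtype=torch.long, device=device)
    # k uniform candidate nodes per edge, k * E samples in total.
    weights = torch.ones(num_nodes, device=device)
    cand = torch.multinomial(weights, num_edges * k, replacement=True).view(num_edges, k)
    # Encode (u, w) as u * N + w and look the ids up in the sorted edge ids.
    existing = torch.sort(row * num_nodes + col).values
    query = (row.view(-1, 1) * num_nodes + cand).view(-1)
    pos = torch.searchsorted(existing, query).clamp(max=num_edges - 1)
    is_edge = (existing[pos] == query).view(num_edges, k)
    valid = ~is_edge & (cand != row.view(-1, 1))
    # argmax over 0/1 values points at a valid candidate whenever one exists.
    first = valid.float().argmax(dim=1)
    neg = cand[torch.arange(num_edges, device=device), first]
    neg[~valid.any(dim=1)] = -1
    return neg
```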

Open Questions

  1. Sorting the edges. I use the PyG function index_sort to sort the input edge_index. However, I believe that in most cases the input is already sorted. Hence, another option is to assume that the input edge_index is already sorted, which would make sorting the user's responsibility.

  2. How to manage the warnings. It would be useful to raise a warning when the sparse method may fail. Currently I raise a warning when the sampling probability is low, but it is probably better to raise it only when we are sure that the method has failed (e.g., the number of sampled edges is 10% less than the requested number).

  3. Determining the number of samples in structured negative sampling. I am struggling to find a way to determine $k$ in the sparse structured negative sampling. My approach is probably not the best one, since it is "global", which makes it difficult to find the right $k$. For now, I just set $k=10$.

  4. Feasibility of structured negative sampling. If I understand the code of structured_negative_sampling_feasible correctly, it returns True if no node is connected to all the others. I think this is wrong, since structured negative sampling should be feasible if and only if $\deg(u) < N/2$ for all nodes in the graph.
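The proposed condition follows from a counting argument, under the assumptions of the dense structured method (distinct negatives per node, $u$ itself is not a valid negative, no duplicate edges): node $u$ has $N - 1 - \deg(u)$ non-neighbours other than itself and needs $\deg(u)$ distinct ones, so $\deg(u) \le N - 1 - \deg(u)$, i.e. $2\deg(u) < N$. A hypothetical check based on that condition (the function name is illustrative, not PyG's API):

```python
import torch


def structured_negative_sampling_feasible_sketch(edge_index: torch.Tensor,
                                                 num_nodes: int) -> bool:
    """Hypothetical check: every node u satisfies deg(u) < N / 2."""
    # Out-degree of each node (edges are assumed to be unique).
    deg = torch.bincount(edge_index[0], minlength=num_nodes)
    # 2 * deg(u) < N avoids a float comparison with N / 2.
    return bool((2 * deg < num_nodes).all())


print(structured_negative_sampling_feasible_sketch(torch.tensor([[0, 0], [1, 2]]), 5))  # True
print(structured_negative_sampling_feasible_sketch(torch.tensor([[0, 0, 0], [1, 2, 3]]), 5))  # False
```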

Let me know what you think about this!

danielecastellana22, Aug 19 '24 17:08

wow this would be awesome!!

denadai2, Aug 20 '24 11:08

Wow, this is pretty cool. Can we fix the tests so that we are Python 3.8 compatible (no Python 3.10 type hints such as str | int)?

rusty1s, Aug 22 '24 02:08

Thank you, I am happy to hear this is a desired feature!

I updated the code to support Python 3.8 and made small changes to the negative sampling test functions. On my laptop, all the test cases affected by my changes pass; the remaining failures appear to come from other tests.

danielecastellana22, Aug 22 '24 17:08

Hello, I updated the code to pass all the pytests. However, pre-commit.ci raises a PEP8 error that I cannot fix (it is the only required check that fails). @wsad1 can you help me?

danielecastellana22, Sep 05 '24 12:09

@wsad1 @rusty1s are there any updates? Is there anything I can do to help?

danielecastellana22, Mar 05 '25 16:03