First draft GPS conv layer

This is only a first "mock" version of a GPSConv layer, to see whether we want it in the repo in this form.

  • Adds a DotProductAttention layer that uses NNlib.dot_product_attention()

  • Adds a GPSConv layer (a rough sketch follows this list)

    • has the DotProductAttention as its global attention layer
    • takes a conv layer for local message passing
  • I'm not sure about the GNNChain() implementation: should it stay where it is, or move into the struct?

  • JuliaFormatter got a bit too greedy and made some unrelated changes here and there; I can revert those, of course

  • I have not checked the implementation for correctness yet
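
A rough sketch of the idea, to make the shapes concrete. This is hypothetical code, not the actual diff from this PR; the names, fields, and the single-graph assumption are all placeholders:

```julia
using Flux, GraphNeuralNetworks, NNlib

# Hypothetical sketch, not the PR code. Assumes a single (non-batched)
# graph; attending over a batched graph would need per-graph masking.
struct DotProductAttention end

function (::DotProductAttention)(x::AbstractMatrix)
    q = reshape(x, size(x)..., 1)                 # (features, nodes, batch = 1)
    y, _ = NNlib.dot_product_attention(q, q, q)   # self-attention over all nodes
    return dropdims(y; dims = 3)
end

struct GPSConv{C, A, M}
    conv::C   # local message passing, e.g. GINConv
    attn::A   # global attention over all nodes
    mlp::M
end

Flux.@functor GPSConv

function (l::GPSConv)(g::GNNGraph, x::AbstractMatrix)
    xm = l.conv(g, x)       # local branch
    xt = l.attn(x)          # global branch
    return l.mlp(xm .+ xt)  # combine (norms / residuals omitted here)
end
```

A layer could then be built as e.g. `GPSConv(GINConv(Dense(64 => 64), 0), DotProductAttention(), Dense(64 => 64))` and called as `layer(g, x)`; whether the GNNChain part lives inside or outside the struct is the open question above.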

Let me know what you think and I can adjust / keep going from here.

Close #351

abieler · Dec 28 '23

Thanks, this is a nice start. A few comments:

  • No need to introduce the DotProductAttention type; we can use MultiHeadAttention from Flux. According to Tables A.2-A.5 in the paper, multi-head attention is the preferred choice, so we should have an nheads argument in the constructor (see the snippet after this list). See also https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GPSConv.html for the kind of flexibility we could try to achieve (over several PRs).

  • Part of the paper's contribution is its discussion of different types of embeddings, many of which this package lacks. I hope they will be added in the future, but in any case it is fine for this PR to implement only the layer.

  • I think the current order of operations is wrong; see the comment https://github.com/CarloLucibello/GraphNeuralNetworks.jl/pull/355#discussion_r1438471782 (the layer update is sketched below)

  • BatchNorm should be used instead of LayerNorm

  • It is not clear from the paper whether a residual connection should be applied after the MLP: Figure D.1 suggests there is one, but according to Eq. 11 there is none.
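
For concreteness, a minimal self-attention call with Flux's MultiHeadAttention; the dimensions below are arbitrary, and node features are assumed to be laid out as (features, nodes, batch):

```julia
using Flux

dim, nheads = 64, 4
mha = MultiHeadAttention(dim; nheads)

x = rand(Float32, dim, 10, 1)  # (features, num_nodes, batch)
y, α = mha(x)                  # self-attention; y has size (dim, 10, 1)
```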
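
And to spell out the order of operations I have in mind: my reading of the paper's layer update (a sketch of my understanding, not authoritative) is roughly

$$
\begin{aligned}
\hat{X}_M &= \mathrm{BatchNorm}\big(\mathrm{Dropout}(\mathrm{MPNN}(X, E, A)) + X\big) \\
\hat{X}_T &= \mathrm{BatchNorm}\big(\mathrm{Dropout}(\mathrm{GlobalAttn}(X)) + X\big) \\
X' &= \mathrm{MLP}\big(\hat{X}_M + \hat{X}_T\big)
\end{aligned}
$$

i.e. both branches are computed from the same input X, each with its own dropout, residual, and norm, before the MLP combines them. As noted above, whether a further residual follows the final MLP is ambiguous between Figure D.1 and Eq. 11.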

CarloLucibello · Dec 30 '23

Thanks for the comments. For the next version I'll go through the authors' codebase, then the paper, then the PyTorch implementation (in that order of priority) for the implementation details.

abieler · Jan 02 '24