GraphNeuralNetworks.jl
Add EGNNConv support for HeteroGraphConv
Covers Issue #311
This is a work in progress; I just wanted to make sure I am on the right track.
Since EGNNConv also takes h as an input, I added another method:
function (hgc::HeteroGraphConv)(g::GNNHeteroGraph, x::NamedTuple, h::AbstractMatrix)
    # Run layer `l` on the subgraph that contains only edges of type `et`
    function forw(l, et)
        sg = edge_type_subgraph(g, et)
        node1_t, _, node2_t = et
        # (source, destination) node features for this edge type
        x_features = (x[node1_t], x[node2_t])
        h_features = h # temporary
        return l(sg, h_features, x_features)
    end
    outs = [forw(l, et) for (l, et) in zip(hgc.layers, hgc.etypes)]
    dst_ntypes = [et[3] for et in hgc.etypes]
    # Combine the outputs of all edge types that share a destination node type
    return _reduceby_node_t(hgc.aggr, outs, dst_ntypes)
end
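For context, the method above keeps the existing per-edge-type pattern: run one layer per edge type on its (source, destination) features, then merge the outputs that land on the same destination node type. Below is a minimal, dependency-free Julia sketch of that pattern. `hetero_forward`, `layer`, and the toy data are hypothetical stand-ins, not GraphNeuralNetworks.jl API, and the Dict-based merge is only a simplified stand-in for the reduction step; as in the current "temporary" version, a single h is shared by every edge type.

```julia
# Hypothetical sketch, NOT GraphNeuralNetworks.jl API.
# Each edge type is (src_type, relation, dst_type); each "layer" is any
# callable taking (h, (x_src, x_dst)) and returning features for dst nodes.
function hetero_forward(layers, etypes, x::NamedTuple, h; aggr = +)
    # One forward pass per edge type, on (source, destination) features
    outs = [l(h, (x[src], x[dst])) for (l, (src, _, dst)) in zip(layers, etypes)]
    dst_types = [et[3] for et in etypes]
    # Merge outputs that target the same destination node type with `aggr`
    merged = Dict{Symbol,Any}()
    for (out, t) in zip(outs, dst_types)
        merged[t] = haskey(merged, t) ? aggr(merged[t], out) : out
    end
    ks = Tuple(unique(dst_types))
    return NamedTuple{ks}(Tuple(merged[k] for k in ks))
end

# Toy layer: destination features + column-sum of source features + shared h
layer(h, (xs, xd)) = xd .+ sum(xs; dims = 2) .+ h

# Toy data: 2 nodes of type :A, 3 nodes of type :B, feature dimension 4
x = (A = ones(4, 2), B = ones(4, 3))
h = zeros(4, 3)  # shared extra input, sized for the :B destination nodes
etypes = [(:A, :to, :B), (:B, :self, :B)]

out = hetero_forward([layer, layer], etypes, x, h)
```

Both edge types target :B, so their outputs (3.0 and 4.0 everywhere) are summed into a single 4×3 matrix under the key :B.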
Let me know if there is a better alternative, such as passing h through the argument of the existing method (e.g. as a Dict); adding a new method just seemed more convenient.
I will push more updates and tests in the coming days, and will remove all debug statements when done.