GraphNeuralNetworks.jl

Add EGNNConv support for HeteroGraphConv

Open · rbSparky opened this issue 1 year ago · 3 comments

Covers Issue #311

This is a work in progress; I just wanted to make sure I am on the right track.

Since EGNNConv also takes `h` as an input, I added another method:

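```julia
# Additional HeteroGraphConv method for layers such as EGNNConv that take
# a second input `h` alongside the per-node-type inputs `x`.
function (hgc::HeteroGraphConv)(g::GNNHeteroGraph, x::NamedTuple, h::AbstractMatrix)
    function forw(l, et)
        # Restrict the heterograph to the relation `et = (src_type, rel, dst_type)`
        sg = edge_type_subgraph(g, et)
        node1_t, _, node2_t = et

        # (source, destination) inputs for this relation
        x_features = (x[node1_t], x[node2_t])
        h_features = h # temporary: the same `h` is passed to every relation for now

        return l(sg, h_features, x_features)
    end
    # Apply each layer to its own relation subgraph
    outs = [forw(l, et) for (l, et) in zip(hgc.layers, hgc.etypes)]
    # Combine outputs that share a destination node type using `hgc.aggr`
    dst_ntypes = [et[3] for et in hgc.etypes]
    return _reduceby_node_t(hgc.aggr, outs, dst_ntypes)
end
```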
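If each layer returns a pair of updated features and coordinates rather than a single matrix (assumed here for EGNNConv-like layers), outputs that share a destination node type need to be combined component by component. Below is a minimal, self-contained Julia sketch of that idea with no GraphNeuralNetworks.jl dependency. `reduce_tuple_outs` is a hypothetical helper, not the package's `_reduceby_node_t`, and it assumes every layer returns an `(h, x)` tuple and that `aggr` is an element-wise function such as `+`.

```julia
# Hypothetical helper (not part of GraphNeuralNetworks.jl): reduce per-relation
# outputs by destination node type when each output is an (h, x) tuple.
# Assumes `aggr` is an element-wise binary function such as `+`.
function reduce_tuple_outs(aggr, outs::AbstractVector, dst_ntypes::AbstractVector{Symbol})
    result = Dict{Symbol, Any}()
    for ((h, x), t) in zip(outs, dst_ntypes)
        if haskey(result, t)
            # Combine features and coordinates separately
            h0, x0 = result[t]
            result[t] = (aggr(h0, h), aggr(x0, x))
        else
            result[t] = (h, x)
        end
    end
    # Return a NamedTuple keyed by destination node type
    ks = collect(keys(result))
    return NamedTuple{Tuple(ks)}(Tuple(result[k] for k in ks))
end

# Usage with dummy outputs: two relations end at :user, one at :item
outs = [(ones(2, 3), zeros(3, 3)),
        (ones(2, 3), ones(3, 3)),
        (fill(5.0, 2, 4), zeros(3, 4))]
res = reduce_tuple_outs(+, outs, [:user, :user, :item])
```

Handling the two components separately would also make it easy to use different reductions for `h` and `x` if that turns out to be needed.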

Let me know if there is a better alternative, such as passing `h` through the argument of the existing method (e.g. as a `Dict`); this approach just seemed more convenient.
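For comparison, a rough sketch of that alternative: instead of an extra positional argument, the single existing argument carries both inputs, bundled per node type. It is plain Julia with no GraphNeuralNetworks.jl dependency; `forward_bundled`, `mock_layer`, and the node/edge type names are made up for illustration, and `mock_layer` only mimics the shape of an EGNNConv-like call without doing any message passing.

```julia
# Hypothetical sketch: keep a single input argument by bundling, per node type,
# both inputs into one NamedTuple (a Dict keyed by node type works the same way).
# `mock_layer` only stands in for an EGNNConv-like layer; it does no message passing.
mock_layer(hpair, xpair) = (hpair[2] .+ 1, xpair[2] .* 2)

function forward_bundled(layers, etypes, inputs)
    return map(zip(layers, etypes)) do (l, et)
        src_t, _, dst_t = et
        # (source, destination) pairs for each of the two inputs
        hpair = (inputs[src_t].h, inputs[dst_t].h)
        xpair = (inputs[src_t].x, inputs[dst_t].x)
        l(hpair, xpair)
    end
end

# Usage: two node types, one relation in each direction
inputs = (user = (h = ones(2, 3), x = zeros(3, 3)),
          item = (h = zeros(2, 4), x = ones(3, 4)))
etypes = [(:user, :rates, :item), (:item, :rated_by, :user)]
outs = forward_bundled([mock_layer, mock_layer], etypes, inputs)
```

This keeps a single call signature, at the cost of requiring every node type to provide both `h` and `x`.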

I will add more updates and tests in the coming days and remove all debug statements when done.

rbSparky · Feb 22 '24 19:02