cugraph
[FEA]: clarify cugraph-equivariant's src vs dst semantics in weight-sharing bidirectional message passing
Is this a new feature, an improvement, or a change to existing functionality?
Improvement
How would you describe the priority of this feature request
High
Please provide a clear description of problem this feature solves
In cugraph-equivariant's `FullyConnectedTensorProductConv` layer (FCTPConv), the first layer of the underlying MLP is decomposed into edge-embedding, src-scalar and dst-scalar weight blocks, in that order (see code here). In one of our applications, DiffDock-PP, the same FCTPConv layer is called bidirectionally, with both directions sharing the same set of MLP weights (see the original code here for one direction and here for the other). These links point to code that has not yet been ported to FCTPConv, but once it is, one of the two directions would have to violate the src vs dst assignment if both directions keep sharing the same set of weights.
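To make the block decomposition concrete, here is a minimal sketch in plain PyTorch. It is not cugraph-equivariant's actual implementation; the sizes and the names `fc`, `W_edge`, `W_src` and `W_dst` are hypothetical. It splits the weight of a first `nn.Linear` layer column-wise into edge, src and dst blocks (in that order), projects the node scalars once per node, gathers them per edge with `edge_index[0]` and `edge_index[1]`, and checks that this matches applying the full layer to the concatenated per-edge input.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes: edge-embedding dim, src/dst scalar dims, hidden dim.
E, S, D, H = 4, 3, 3, 8
num_src, num_dst, num_edges = 5, 6, 10

# Hypothetical first MLP layer acting on [edge_emb | src_scalars | dst_scalars].
fc = torch.nn.Linear(E + S + D, H)

edge_emb = torch.randn(num_edges, E)
src_scalars = torch.randn(num_src, S)
dst_scalars = torch.randn(num_dst, D)
edge_index = torch.stack([
    torch.randint(0, num_src, (num_edges,)),  # row 0: src node of each edge
    torch.randint(0, num_dst, (num_edges,)),  # row 1: dst node of each edge
])
src, dst = edge_index

# Reference: build the per-edge input explicitly, then apply the full layer.
full = fc(torch.cat([edge_emb, src_scalars[src], dst_scalars[dst]], dim=-1))

# Decomposed: split the weight columns in (edge, src, dst) order, project each
# node's scalars once, then gather the projections per edge.
W_edge, W_src, W_dst = torch.split(fc.weight, [E, S, D], dim=1)
decomposed = (
    edge_emb @ W_edge.T
    + (src_scalars @ W_src.T)[src]
    + (dst_scalars @ W_dst.T)[dst]
    + fc.bias
)

print(torch.allclose(full, decomposed, atol=1e-5))
```

The decomposed form avoids materializing the concatenated per-edge input, but it also fixes which weight block sees which node set: whatever is passed as src is always multiplied by `W_src`.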
The crux is that cugraph-equivariant's FCTPConv layer is inherently unidirectional, even though the user could shoehorn it into a bidirectional use case by paying attention to how the underlying MLP computation is carried out, e.g.:
```python
import torch

# First direction (a -> b): a nodes are src, b nodes are dst.
a2b_edge_emb = torch.hstack([edge_emb, a_scalars, b_scalars])
cross_conv(a_node_attrs, edge_sh, a2b_edge_emb, (a2b_edge_index, (size_a, size_b)))
# Second direction (b -> a), sharing the same weights as the first direction.
b2a_edge_index = a2b_edge_index.flip(dims=(0,))
# Reusing a2b_edge_emb keeps a's scalars in the src columns, although a is now dst.
cross_conv(b_node_attrs, edge_sh, a2b_edge_emb, (b2a_edge_index, (size_b, size_a)))
```
where the semantic violation is visible: `a2b_edge_emb` is used in the b->a convolution.
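The violation can be illustrated with a small sketch in plain PyTorch; all tensors, sizes, the shared layer `fc` and the helper `first_layer_input` are hypothetical. If the b->a call honored the src vs dst semantics, b's scalars would land in the src block and a's in the dst block, so for the same physical edges the shared first layer sees its blocks swapped and produces different hidden features than in the a->b call. Keeping the two directions consistent (as the workaround above does by reusing `a2b_edge_emb`) therefore means one direction must ignore the src/dst meaning.

```python
import torch

torch.manual_seed(0)

E, S, H = 4, 3, 8                      # edge-emb dim, scalar dim (a and b), hidden dim
size_a, size_b, num_edges = 5, 6, 10

fc = torch.nn.Linear(E + 2 * S, H)     # hypothetical shared first layer: [edge | src | dst]

edge_emb = torch.randn(num_edges, E)
a_scalars = torch.randn(size_a, S)
b_scalars = torch.randn(size_b, S)
a2b_edge_index = torch.stack([
    torch.randint(0, size_a, (num_edges,)),  # row 0: a nodes (src for a->b)
    torch.randint(0, size_b, (num_edges,)),  # row 1: b nodes (dst for a->b)
])
b2a_edge_index = a2b_edge_index.flip(dims=(0,))  # row 0 now holds b nodes


def first_layer_input(edge_emb, src_scalars, dst_scalars, edge_index):
    """Hypothetical helper: per-edge input in (edge, src, dst) block order."""
    src, dst = edge_index
    return torch.cat([edge_emb, src_scalars[src], dst_scalars[dst]], dim=-1)


# a->b: a is src, b is dst, consistent with the block order.
a2b_in = first_layer_input(edge_emb, a_scalars, b_scalars, a2b_edge_index)
# b->a honoring the semantics: b's scalars go into the src block.
b2a_in = first_layer_input(edge_emb, b_scalars, a_scalars, b2a_edge_index)

# Same physical edges, but the src and dst blocks are swapped ...
print(torch.equal(b2a_in[:, E:E + S], a2b_in[:, E + S:]))
# ... so the shared layer produces different hidden features per direction.
print(torch.allclose(fc(a2b_in), fc(b2a_in)))
```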
Describe your ideal solution
Not sure there is one.
Describe any alternatives you have considered
A not-so-ideal solution would be to document how the first layer of the underlying MLP is decomposed into blocks, and how `src_scalars` and `dst_scalars` are indexed and broadcast using `edge_index[0]` and `edge_index[1]`, respectively, to at least warn users of the potential semantic violation.
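Such documentation could include a toy example like the following sketch. The helper `gather_scalars` and all values are hypothetical and only mirror the rule stated above (src scalars gathered by `edge_index[0]`, dst scalars by `edge_index[1]`). It also shows why a warning matters: after flipping the edge index, passing a's scalars as `src_scalars` may not raise any error, because b's node indices can happen to be valid rows of a's tensor, so the wrong nodes are read silently.

```python
import torch

a_scalars = torch.tensor([[10.0], [11.0], [12.0]])  # one scalar per a node (size_a = 3)
b_scalars = torch.tensor([[20.0], [21.0]])          # one scalar per b node (size_b = 2)

# a->b edges: (a0 -> b1), (a2 -> b0)
a2b_edge_index = torch.tensor([[0, 2],
                               [1, 0]])
b2a_edge_index = a2b_edge_index.flip(dims=(0,))     # [[1, 0], [0, 2]]


def gather_scalars(src_scalars, dst_scalars, edge_index):
    """Hypothetical illustration of the rule: src_scalars is indexed by
    edge_index[0] and dst_scalars by edge_index[1]."""
    return src_scalars[edge_index[0]], dst_scalars[edge_index[1]]


# Correct for b->a: src_scalars must hold b's scalars (num_src_nodes = size_b rows).
src_per_edge, dst_per_edge = gather_scalars(b_scalars, a_scalars, b2a_edge_index)
print(src_per_edge.flatten().tolist())  # [21.0, 20.0]
print(dst_per_edge.flatten().tolist())  # [10.0, 12.0]

# Wrong for b->a: gathering a's scalars with b's indices [1, 0] does not raise,
# it silently reads a1 and a0 instead of b1 and b0.
wrong_src = a_scalars[b2a_edge_index[0]]
print(wrong_src.flatten().tolist())  # [11.0, 10.0]
```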
Additional context
No response
Code of Conduct
- [X] I agree to follow cuGraph's Code of Conduct
- [X] I have searched the open feature requests and have found no duplicates for this feature request