[FEA]: clarify cugraph-equivariant's src vs dst semantics in weight-sharing bidirectional message passing

Open DejunL opened this issue 8 months ago • 4 comments

Is this a new feature, an improvement, or a change to existing functionality?

Improvement

How would you describe the priority of this feature request

High

Please provide a clear description of problem this feature solves

In cugraph-equivariant's FullyConnectedTensorProductConv layer (FCTPConv), the first layer of the underlying MLP is decomposed into edge-embedding, src-scalar and dst-scalar weight blocks, in that order (see code here). In one of our applications, DiffDock-PP, the same FCTPConv layer is called bidirectionally, sharing one set of MLP weights (see the original code here for one direction and here for the other). These links point to code that has not yet been replaced with FCTPConv, but consider the situation once it is: if the two directions keep sharing the same set of weights, one of them has to violate the src vs dst assignment.

The crux is that cugraph-equivariant's FCTPConv layer is inherently unidirectional, even though a user could shoehorn it into a bidirectional use case by paying attention to how the underlying MLP computation is carried out, e.g.:

# first direction
a2b_edge_emb = torch.hstack([edge_emb, a_scalars, b_scalars])
cross_conv(a_node_attrs, edge_sh, a2b_edge_emb, (a2b_edge_index, (size_a, size_b)))

# second direction sharing the same weight as the first direction
b2a_edge_index = a2b_edge_index.flip(dims=(0,))
cross_conv(b_node_attrs, edge_sh, a2b_edge_emb, (b2a_edge_index, (size_b, size_a)))

where the semantic violation is that a2b_edge_emb, whose scalar blocks are ordered for the a->b direction, is also used for the b->a convolution.
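To make the violation concrete, here is a minimal toy sketch in plain Python. It does not use the cugraph-equivariant API: `concat_edge_features` and `first_layer` are hypothetical helpers, the layer has a single output unit, and all sizes and values are made up. It only assumes, as described above, that the first MLP layer sees per-edge rows laid out as [edge_emb | src scalars | dst scalars].

```python
# Toy model of the first MLP layer (hypothetical helpers, not library code).

def concat_edge_features(edge_emb, src_scalars, dst_scalars, edge_index):
    """Build per-edge rows laid out as [edge_emb | src scalars | dst scalars]."""
    src_ids, dst_ids = edge_index
    return [
        edge_emb[e] + src_scalars[s] + dst_scalars[d]  # list concatenation
        for e, (s, d) in enumerate(zip(src_ids, dst_ids))
    ]

def first_layer(rows, weight):
    """Linear layer with one output unit: one pre-activation per edge."""
    return [sum(x * w for x, w in zip(row, weight)) for row in rows]

# Weight blocks in the order edge, src, dst (illustrative integer values).
w_edge, w_src, w_dst = [1], [10, 20], [100, 200]
weight = w_edge + w_src + w_dst

a_scalars = [[1, 0], [0, 1]]           # 2 nodes in graph a
b_scalars = [[2, 2], [3, 1], [0, 4]]   # 3 nodes in graph b
edge_emb = [[1], [2], [3]]             # 3 cross edges
a2b_edge_index = ([0, 1, 1], [0, 1, 2])
b2a_edge_index = (a2b_edge_index[1], a2b_edge_index[0])  # flipped

# a -> b: a is src, b is dst, so the block order matches the weights.
a2b_rows = concat_edge_features(edge_emb, a_scalars, b_scalars, a2b_edge_index)
a2b = first_layer(a2b_rows, weight)

# b -> a while reusing the a->b rows: a's scalars still meet w_src
# even though a is now the destination.
b2a_reused = first_layer(a2b_rows, weight)

# b -> a with strict src/dst semantics: b's scalars meet w_src.
b2a_rows = concat_edge_features(edge_emb, b_scalars, a_scalars, b2a_edge_index)
b2a_strict = first_layer(b2a_rows, weight)

print(a2b, b2a_reused, b2a_strict)
```

In this toy setup `b2a_reused` equals `a2b` and differs from `b2a_strict`, which is the trade-off in question: sharing weights across directions keeps the scalar blocks in a->b order, at the cost of breaking the src vs dst reading for the b->a call.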

Describe your ideal solution

Not sure if there is any

Describe any alternatives you have considered

A not-so-ideal solution would be to clarify in the documentation how the underlying MLP is decomposed into these blocks, and how src_scalars and dst_scalars are indexed and broadcast using edge_index[0] and edge_index[1] respectively, so that users are at least warned of the potential semantic violation.
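As an illustration of what such documentation could show, the following plain-Python sketch checks that projecting per-node scalars first and then gathering them with edge_index[0] and edge_index[1] gives the same result as applying the full weight to the concatenated per-edge input. The `project` helper, the single output unit and all values are assumptions made for this example and are not taken from the library.

```python
# Toy check of the block decomposition (hypothetical helper, not library code).

def project(rows, weight):
    """Dot each row with a weight vector (linear layer with one output unit)."""
    return [sum(x * w for x, w in zip(row, weight)) for row in rows]

# Weight blocks in the order edge, src, dst (illustrative integer values).
w_edge, w_src, w_dst = [1], [10, 20], [100, 200]

edge_emb = [[1], [2], [3]]               # one row per edge
src_scalars = [[1, 0], [0, 1]]           # one row per source node
dst_scalars = [[2, 2], [3, 1], [0, 4]]   # one row per destination node
edge_index = ([0, 1, 1], [0, 1, 2])      # (src ids, dst ids)

# Decomposed form: project once per node, then gather per edge.
edge_proj = project(edge_emb, w_edge)
src_proj = project(src_scalars, w_src)   # indexed by edge_index[0]
dst_proj = project(dst_scalars, w_dst)   # indexed by edge_index[1]
decomposed = [
    edge_proj[e] + src_proj[s] + dst_proj[d]
    for e, (s, d) in enumerate(zip(*edge_index))
]

# Concatenated form: gather per edge first, then apply the full weight.
rows = [
    edge_emb[e] + src_scalars[s] + dst_scalars[d]
    for e, (s, d) in enumerate(zip(*edge_index))
]
concatenated = project(rows, w_edge + w_src + w_dst)

print(decomposed == concatenated)
```

The two forms agree, which is why the order of the src and dst blocks is tied to which node set is indexed by edge_index[0] and which by edge_index[1].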

Additional context

No response

Code of Conduct

  • [X] I agree to follow cuGraph's Code of Conduct
  • [X] I have searched the open feature requests and have found no duplicates for this feature request

DejunL, Jun 03 '24 16:06