equiformer_v2

Incorporating vector node features

Open • patriksimurka opened this issue 11 months ago • 1 comment

Hi, nice work! I was wondering what it would take to accommodate systems whose nodes have additional non-scalar features. Any hints or snippets would be greatly appreciated. Thanks.

patriksimurka, Mar 26 '24 12:03

Hi @patriksimurka

The node embeddings have the shape (num_nodes, (1 + lmax) ** 2, num_channels). To add scalars to the node embeddings, first expand the scalar properties to num_channels channels, and then add the resulting scalar features to the type-0 part (index 0 along the second dimension):

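import torch

num_channels = 128
lmax = 6
num_nodes = 128
# Node embeddings: (num_nodes, (1 + lmax) ** 2, num_channels)
node_embedding = torch.randn(num_nodes, (1 + lmax) ** 2, num_channels)
# Scalar (type-0) features, already expanded to num_channels channels
input_scalar = torch.randn(num_nodes, 1, num_channels)
# The type-0 part is the first entry along the second dimension
node_embedding[:, 0:1, :] = node_embedding[:, 0:1, :] + input_scalar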
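
The snippet above starts from scalars that already have num_channels channels. As a minimal sketch of the "expand" step, assuming the raw scalar properties come as a (num_nodes, num_scalar_features) tensor, a learnable linear layer is one way to do it. The names num_scalar_features, raw_scalars and scalar_proj are hypothetical and not part of the repository:

```python
import torch
import torch.nn as nn

num_nodes, lmax, num_channels = 128, 6, 128
num_scalar_features = 4  # hypothetical number of raw scalar properties per node

node_embedding = torch.randn(num_nodes, (1 + lmax) ** 2, num_channels)
raw_scalars = torch.randn(num_nodes, num_scalar_features)

# Scalars are rotation invariant, so an ordinary linear layer (with bias) is fine here.
scalar_proj = nn.Linear(num_scalar_features, num_channels)
input_scalar = scalar_proj(raw_scalars).unsqueeze(1)  # (num_nodes, 1, num_channels)

# Add to the type-0 slice only; all higher-degree entries stay unchanged.
updated = node_embedding.clone()
updated[:, 0:1, :] = updated[:, 0:1, :] + input_scalar
```

Working on a clone keeps the original embedding intact, which makes it easy to check that only the type-0 entries changed.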

For type-1 vectors, you can do the following:

# Type-1 features: 3 components per channel, stored at indices 1..3
type_1_vectors = torch.randn(num_nodes, 3, num_channels)
node_embedding[:, 1:4, :] = node_embedding[:, 1:4, :] + type_1_vectors
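
If the vector feature starts out as a single 3-vector per node, it also has to be expanded to num_channels channels. A minimal sketch, assuming a raw_vectors tensor of shape (num_nodes, 3): scale the vector by one learnable weight per channel, without a bias, since adding a constant to vector components would break rotation equivariance. The names raw_vectors and vector_proj are hypothetical. The component order of the type-1 slice may also differ from (x, y, z), so it is worth checking against the spherical harmonics convention the model uses:

```python
import torch
import torch.nn as nn

num_nodes, lmax, num_channels = 128, 6, 128
node_embedding = torch.randn(num_nodes, (1 + lmax) ** 2, num_channels)
raw_vectors = torch.randn(num_nodes, 3)  # hypothetical: one vector feature per node

# Bias-free linear map from 1 input channel to num_channels output channels.
# Each output channel is the input vector times a learnable scalar weight.
vector_proj = nn.Linear(1, num_channels, bias=False)
type_1_vectors = vector_proj(raw_vectors.unsqueeze(-1))  # (num_nodes, 3, num_channels)

# Add to the type-1 slice (indices 1..3) only.
updated = node_embedding.clone()
updated[:, 1:4, :] = updated[:, 1:4, :] + type_1_vectors
```

Because the map is linear and has no bias, rotating raw_vectors rotates every channel of type_1_vectors in the same way.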

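Basically, the second dimension of node_embedding in the above examples is the concatenation [type-0 part, type-1 part, ..., type-lmax part], where the type-L part has 2 * L + 1 components. That is why the sizes sum to (1 + lmax) ** 2, and the type-L part sits at indices L ** 2 through (L + 1) ** 2 - 1 (0:1 for type-0, 1:4 for type-1, 4:9 for type-2, and so on). So you can just add a type-L feature of shape (num_nodes, 2 * L + 1, num_channels) to the corresponding slice.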
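
As a rough sketch of that indexing, assuming the degrees are stored contiguously in increasing order as in the examples above, a small helper can compute the slice for any degree. The helper name type_l_slice is hypothetical and not part of the repository:

```python
import torch


def type_l_slice(l: int) -> slice:
    """Slice of the second dimension holding the 2 * l + 1 components of degree l."""
    return slice(l * l, (l + 1) * (l + 1))


num_nodes, lmax, num_channels = 128, 6, 128
node_embedding = torch.randn(num_nodes, (1 + lmax) ** 2, num_channels)

# Example: add a type-2 feature (5 components per channel) to its slice.
type_2_features = torch.randn(num_nodes, 5, num_channels)
updated = node_embedding.clone()
updated[:, type_l_slice(2), :] = updated[:, type_l_slice(2), :] + type_2_features
```

The same helper reproduces the 0:1 and 1:4 slices used for the type-0 and type-1 cases.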

yilunliao, Mar 27 '24 11:03