equiformer_v2

e3nn tensors compatibility issue

Open liyy2 opened this issue 11 months ago • 3 comments

Hi, I am trying to integrate this with the e3nn package.

For the SO3Embedding class, how can I convert it to irreps that follow the e3nn convention? Here is my implementation (I am not sure whether it is right):

    def to_e3nn_embeddings(self):
        from e3nn import o3
        from e3nn.io import SphericalTensor

        # flatten everything after the first (length) dimension
        embedding = self.embedding.reshape(self.length, -1)

        # irreps up to lmax, with the multiplicity of each degree set to num_channels (multiple channels)
        l = o3.Irreps(str(SphericalTensor(self.lmax_list[-1], 1, -1)).replace('1x', f'{self.num_channels}x'))
        return l, embedding

liyy2, Mar 07 '24 20:03

Hi @liyy2

I am not familiar with SphericalTensor.

For tensors in e3nn, however, the layout is typically C_0x0e+C_1x1e+... (e.g., 128x0e+128x1e+...), where C_l is the number of channels of degree l. (Let me know if this is not clear.)
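To make the e3nn layout concrete, here is a small pure-Python sketch (no e3nn needed) that computes where each degree's block starts in a flattened C_0x0e+C_1x1e+... tensor, and where a given channel and component land. It assumes the usual e3nn ordering in which, inside a `C_l x l` block, channels are outermost and the 2l+1 components of each channel are contiguous. The helper names `e3nn_block_offsets` and `e3nn_flat_index` are hypothetical, not e3nn APIs.

```python
def e3nn_block_offsets(channels_per_degree):
    """Start offset of each degree's block in a flattened e3nn tensor.

    channels_per_degree[l] is C_l, the multiplicity of the degree-l irrep,
    e.g. [128, 128, 128] for '128x0e+128x1e+128x2e'.
    Returns (offsets, total_dim).
    """
    offsets = []
    start = 0
    for l, c in enumerate(channels_per_degree):
        offsets.append(start)
        start += c * (2 * l + 1)  # C_l copies of a (2l+1)-dimensional irrep
    return offsets, start


def e3nn_flat_index(channels_per_degree, l, channel, m):
    """Flat position of component m (0..2l) of `channel` for degree l.

    Assumption: channel-major inside a block, components contiguous.
    """
    offsets, _ = e3nn_block_offsets(channels_per_degree)
    return offsets[l] + channel * (2 * l + 1) + m


offsets, dim = e3nn_block_offsets([128, 128, 128])
print(offsets, dim)  # [0, 128, 512] 1152
print(e3nn_flat_index([128, 128, 128], l=1, channel=2, m=0))  # 134
```

Because C_l may differ from degree to degree in e3nn, the offsets have to be accumulated block by block rather than computed from a single channel count.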

For EquiformerV2, the tensors are in the form (0e+1e+..., C), with shape ((1+L_{max})**2, C). Here, every degree must have the same number of channels. (Let me know if that is not clear.)
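For comparison, a minimal sketch of the EquiformerV2-style indexing, assuming the coefficients are stored degree by degree (all of l=0, then l=1, and so on). Since degrees 0..l-1 together occupy 1 + 3 + ... + (2l-1) = l**2 rows, component m (counted 0..2l here, not -l..l) of degree l sits in row l**2 + m, and that row holds all C channels. The helper names are hypothetical.

```python
def num_coefficients(lmax):
    """Number of rows, (1 + lmax)**2; every row holds the same C channels."""
    return (lmax + 1) ** 2


def equiformer_row(l, m):
    """Row of component m (0..2l) of degree l in a ((1 + lmax)**2, C) tensor."""
    if not 0 <= m <= 2 * l:
        raise ValueError(f"m must be in [0, {2 * l}] for l={l}, got {m}")
    return l * l + m  # degrees 0..l-1 fill the first l**2 rows


print(num_coefficients(2))   # 9
print(equiformer_row(2, 0))  # 4: row 0 (l=0) and rows 1..3 (l=1) come first
```

So element (l, m, c) lives at `[equiformer_row(l, m), c]`, whereas in e3nn the same value is found inside the degree-l block at offset `c * (2l+1) + m`.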

So, to convert between these two formats, we can extract all the channels of each degree, rearrange each block between the (C, 2l+1) and (2l+1, C) orderings, and concatenate the blocks. Here is an example of converting e3nn tensors to tensors in EquiformerV2:

    import torch
    from e3nn import o3

    lmax = 2
    num_channels = 128
    irreps = o3.Irreps('128x0e+128x1e+128x2e')
    tensor_e3nn = irreps.randn(1, -1)  # shape: (1, 128 * (1 + 2) ** 2)

    out = []
    start_idx = 0
    for l in range(lmax + 1):
        length = (2 * l + 1) * num_channels
        feature = tensor_e3nn.narrow(1, start_idx, length)  # extract all the channels corresponding to degree l
        feature = feature.view(-1, num_channels, (2 * l + 1))  # (N, C, 2l+1): e3nn stores channels first
        feature = feature.transpose(1, 2).contiguous()  # (N, 2l+1, C)
        out.append(feature)
        start_idx = start_idx + length
    tensor_equiformer_v2 = torch.cat(out, dim=1)  # shape: (1, (1 + lmax) ** 2, num_channels)

You can follow the above example to do the reverse.
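Following that suggestion, here is a hedged sketch of the reverse direction (EquiformerV2 to e3nn) in plain PyTorch, together with the forward loop above wrapped as a function so the two can be checked as a round trip. It assumes the EquiformerV2 input has shape (N, (1 + lmax)**2, C); the function names `equiformer_to_e3nn` and `e3nn_to_equiformer` are hypothetical, not part of either library.

```python
import torch


def e3nn_to_equiformer(x: torch.Tensor, lmax: int, num_channels: int) -> torch.Tensor:
    """(N, C * (1 + lmax)**2) e3nn tensor -> (N, (1 + lmax)**2, C)."""
    out = []
    start_idx = 0
    for l in range(lmax + 1):
        length = (2 * l + 1) * num_channels
        feature = x.narrow(1, start_idx, length).view(-1, num_channels, 2 * l + 1)
        out.append(feature.transpose(1, 2))  # (N, 2l+1, C)
        start_idx += length
    return torch.cat(out, dim=1)


def equiformer_to_e3nn(x: torch.Tensor, lmax: int) -> torch.Tensor:
    """(N, (1 + lmax)**2, C) -> (N, C * (1 + lmax)**2) with irreps 'Cx0e+...+Cx{lmax}e'."""
    n, num_coeffs, num_channels = x.shape
    assert num_coeffs == (lmax + 1) ** 2, "coefficient dim must be (1 + lmax)**2"
    out = []
    start_idx = 0
    for l in range(lmax + 1):
        width = 2 * l + 1
        feature = x.narrow(1, start_idx, width)  # (N, 2l+1, C) for degree l
        feature = feature.transpose(1, 2)        # (N, C, 2l+1): channel-major, as in e3nn
        out.append(feature.reshape(n, num_channels * width))  # reshape copies; view would fail here
        start_idx += width
    return torch.cat(out, dim=1)


lmax, num_channels = 2, 128
x_e3nn = torch.randn(4, num_channels * (lmax + 1) ** 2)
x_eqv2 = e3nn_to_equiformer(x_e3nn, lmax, num_channels)
print(x_eqv2.shape)                                         # torch.Size([4, 9, 128])
print(torch.equal(equiformer_to_e3nn(x_eqv2, lmax), x_e3nn))  # True
```

The per-degree transpose is the step a plain `self.embedding.reshape(self.length, -1)` skips: assuming `self.embedding` is laid out as (length, (1 + lmax)**2, C), that reshape flattens each degree-l block component-major rather than channel-major, which only coincides with e3nn's order for l = 0.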

yilunliao, Mar 27 '24 11:03

Hi, thank you for the detailed response. My question is: does parity matter for the model here? Should I use o3.Irreps('128x0e+128x1e+128x2e') or o3.Irreps('128x0e+128x1o+128x2e')?

liyy2, Mar 27 '24 19:03

EquiformerV2 currently uses SE(3) (rotations and translations, no reflections), so parity is not tracked, and you should use '128x0e+128x1e+128x2e'.
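If you would rather build that irreps string directly than edit the output of `SphericalTensor`, a minimal plain-Python sketch follows. It assumes, as stated above, that every degree has even parity ('e') and the same number of channels; the helper name `equiformer_irreps_str` is hypothetical, and its result is meant to be passed to e3nn's `o3.Irreps`.

```python
def equiformer_irreps_str(lmax: int, num_channels: int) -> str:
    """Irreps string for SE(3) features: same channel count and parity 'e' for every l.

    Hypothetical helper; pass the result to o3.Irreps(...).
    """
    return "+".join(f"{num_channels}x{l}e" for l in range(lmax + 1))


print(equiformer_irreps_str(2, 128))  # 128x0e+128x1e+128x2e
```

By contrast, `SphericalTensor(lmax, 1, -1)` in the original snippet appears to assign parity (-1)**l to degree l (0e+1o+2e+...), which is where the `1o` in the question comes from.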

yilunliao, Mar 27 '24 20:03