pytorch_geometric

radius_graph for periodic structure

Open · mariogeiger opened this issue 4 years ago · 13 comments

❓ Questions & Help

Is there a radius_graph function for periodic structures? If not, I will think about how to implement one and propose a PR. Any suggestions? So far we have been using pymatgen.core.structure.Structure.get_all_neighbors for this.

mariogeiger, Nov 25 '20

Can you give some pointers on why the (standard) radius_graph procedure is not applicable to those structures? What needs to be changed to achieve the desired output?

rusty1s, Nov 27 '20

The following code (not tested) is inefficient and works only if the cutoff is small enough, because only the 26 cells surrounding the original one are replicated.

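import torch
from torch_cluster import radius_graph

n, cutoff = 10, 1.0  # example values
lattice = torch.randn(3, 3)  # 3 vectors (rows) spanning a parallelepiped in 3D
pos = torch.randn(n, 3)  # usually inside the cell

# replicate the atoms into the 3x3x3 block of cells around the original one
pos = torch.cat([
    pos + i * lattice[0] + j * lattice[1] + k * lattice[2]
    for i in [-1, 0, 1]
    for j in [-1, 0, 1]
    for k in [-1, 0, 1]
])  # (27 * n, 3); the original cell (i = j = k = 0) is block 13

index = radius_graph(pos, cutoff)
# keep only edges whose target lies in the original cell,
# then map every image back to the index of its original atom
central = (index[1] >= 13 * n) & (index[1] < 14 * n)
index = index[:, central] % n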
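The "small enough" condition can be made concrete: if the atoms lie inside the cell, the 26 neighboring images suffice as long as the cutoff does not exceed the smallest height of the cell (the distance between two opposite faces), and in general offsets up to ceil(cutoff / height) are needed in each direction. A minimal sketch of that check; `cell_heights` and `num_image_shells` are hypothetical helper names, and the height formula (volume divided by face area) is the same one used in material_radius_graph further down the thread.

```python
import torch

def cell_heights(lattice: torch.Tensor) -> torch.Tensor:
    """Distances between opposite faces of the cell (hypothetical helper).

    `lattice` holds one lattice vector per row (a, b, c); the distance between
    the two faces spanned by b and c is volume / |b x c|, and likewise for the others.
    """
    a, b, c = lattice
    volume = torch.linalg.det(lattice).abs()
    areas = torch.stack([
        torch.linalg.cross(b, c).norm(),
        torch.linalg.cross(c, a).norm(),
        torch.linalg.cross(a, b).norm(),
    ])
    return volume / areas

def num_image_shells(lattice: torch.Tensor, cutoff: float) -> torch.Tensor:
    """Cell offsets needed per direction so that, for atoms inside the cell,
    every neighbor within `cutoff` is found (1 means offsets -1, 0, 1)."""
    return torch.ceil(cutoff / cell_heights(lattice)).long()

cubic = torch.eye(3) * 2.0                     # cubic cell with edge length 2
print(cell_heights(cubic))                     # tensor([2., 2., 2.])
print(num_image_shells(cubic, 1.5).tolist())   # [1, 1, 1]
print(num_image_shells(cubic, 3.0).tolist())   # [2, 2, 2]
```

For a skewed cell the heights are smaller than the lattice vector lengths, so checking the cutoff against the vector lengths alone can wrongly suggest that one shell of images is enough.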

mariogeiger, Nov 27 '20

This would be nice to have. I know the folks at https://github.com/Open-Catalyst-Project have an example in https://github.com/Open-Catalyst-Project/ocp/blob/master/ocpmodels/common/utils.py. I haven't had the chance to implement my own.

vxfung, Nov 30 '20

Thanks for the pointers. It looks like OpenCatalyst provides a dense version of it. I'm open to integrating a CUDA version of this into torch-cluster. @mariogeiger Please let me know if you have any plans to contribute.
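To illustrate what a dense version can look like (a sketch, not the Open Catalyst code): build every integer cell offset in a fixed range, shift all atoms by each offset, compute all distances between the shifted copies and the original atoms, and keep the pairs within the cutoff. Memory grows as K * N^2 for K offsets and N atoms. `dense_periodic_radius_graph` and its `n_shells` parameter (offsets from -n_shells to n_shells in each direction) are hypothetical; for atoms inside the cell, `n_shells` has to be at least ceil(cutoff / smallest cell height).

```python
import itertools
import torch

def dense_periodic_radius_graph(pos, lattice, cutoff, n_shells=1):
    """Dense periodic neighbor search (hypothetical helper, not a library API).

    pos: (N, 3); lattice: (3, 3), one lattice vector per row.
    Returns edge_index (2, E) with row 0 = source, row 1 = target, and
    shifts (E, 3) such that the source image sits at pos[src] + shift.
    """
    rng = range(-n_shells, n_shells + 1)
    offsets = torch.tensor(list(itertools.product(rng, rng, rng)),
                           dtype=pos.dtype, device=pos.device)  # (K, 3) integer cell offsets
    shifts = offsets @ lattice                                   # (K, 3) Cartesian shifts
    images = pos.unsqueeze(0) + shifts.unsqueeze(1)              # (K, N, 3) shifted copies
    # vec[k, i, j] = pos[j] - (pos[i] + shifts[k]); memory is O(K * N^2)
    vec = pos[None, None, :, :] - images[:, :, None, :]
    dist = vec.norm(dim=-1)
    k, src, dst = torch.nonzero(dist <= cutoff, as_tuple=True)
    # drop each atom's pairing with itself in the zero-offset cell
    keep = ~((src == dst) & (offsets[k].abs().sum(dim=1) == 0))
    k, src, dst = k[keep], src[keep], dst[keep]
    return torch.stack([src, dst]), shifts[k]
```

Each edge vector is pos[dst] - (pos[src] + shift), so the returned shifts play the role of per-edge periodic offsets; an atom can also be its own neighbor through a nonzero offset when the cell is small compared with the cutoff.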

rusty1s, Nov 30 '20

Yes, I will try to put something together. I plan to look at it this week or the week after. I'd love to contribute.

mariogeiger, Nov 30 '20

I would like to implement a SchNet network for a periodic cell, which requires a radius_graph with periodic boundary conditions. Is it possible to implement these periodic boundary conditions?

keano130, Aug 06 '21

@mariogeiger @keano130 Did you make any progress?

rusty1s, Aug 09 '21

No, we gave up on the idea of implementing it in CUDA. We use ASE instead.

mariogeiger, Aug 09 '21

As I have little to no experience working with CUDA (or writing code in C), I am afraid I cannot write (efficient) code to implement this. Ideally, something along the lines of https://manual.gromacs.org/2021/reference-manual/algorithms/molecular-dynamics.html#simple-search could be implemented efficiently. This assumes that the cell matrix is lower triangular, which can be achieved with a simple transformation (a rotation). Otherwise, I will try to implement an inefficient Python script instead.
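A sketch of the two ingredients of that scheme in PyTorch, under stated assumptions: `to_lower_triangular` rotates the cell into lower-triangular form with a QR decomposition (the transform is orthogonal, so interatomic distances are unchanged), and `triclinic_displacement` then wraps a displacement vector by subtracting whole multiples of c, b and a in turn, using only the z, y and x component respectively. Both names are hypothetical. GROMACS documents restrictions on the box shape and cutoff under which a single pass like this yields the nearest image; for strongly skewed cells it may not.

```python
import torch

def to_lower_triangular(lattice, pos):
    """Rotate the cell (rows a, b, c) and positions so the cell is lower triangular.

    With lattice.T = Q R (QR decomposition), lattice @ Q = R.T is lower
    triangular. Flipping column signs of Q makes the diagonal positive.
    Q is orthogonal (a proper rotation for a right-handed cell), so all
    interatomic distances are preserved.
    """
    q, r = torch.linalg.qr(lattice.T)
    q = q * torch.sign(torch.diagonal(r))  # scale column j of Q by sign(r[j, j])
    return lattice @ q, pos @ q

def triclinic_displacement(d, box):
    """Wrap displacement vectors d (..., 3) for a lower-triangular box,
    subtracting whole multiples of c, then b, then a (single pass)."""
    d = d - torch.round(d[..., 2:3] / box[2, 2]) * box[2]  # only c has a z component
    d = d - torch.round(d[..., 1:2] / box[1, 1]) * box[1]  # b has no z component
    d = d - torch.round(d[..., 0:1] / box[0, 0]) * box[0]  # a lies along x
    return d
```

The order matters: subtracting multiples of b cannot undo the z wrapping because b has no z component, and subtracting multiples of a touches only x.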

keano130, Aug 09 '21

@keano130 Why is the ASE implementation not good enough for your use case?

mariogeiger, Aug 09 '21

To use the model in OpenMM, I need to convert it to TorchScript, which is not possible with the ASE neighbor list. The ASE implementation is also less efficient.

keano130, Aug 09 '21

SchNet has previously been implemented in C++ in https://github.com/openmm/NNPOps/tree/master/schnet. The algorithm I linked earlier is used in the CUDA computeDisplacement function in https://github.com/openmm/NNPOps/blob/master/schnet/CudaCFConv.cu, which might help with adapting the code to periodic boundary conditions. However, there is still too much in that code I don't understand to adapt it myself.

keano130, Aug 09 '21

Any update on this issue? I would also like to create a radius graph that takes periodic boundary conditions into account. It is possible to do so by shifting the original positions several times, remapping them into the original domain, and computing the edges for each shift.
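One way to realize the shift-and-remap approach, sketched for an orthogonal box with side lengths `box` and assuming the cutoff is below half of every side: per direction, moving the periodic boundary to either 0 or half a box length is enough, because the short path between two atoms closer than half a box length cannot cross both boundary positions. Trying all 2^3 = 8 combinations, wrapping the shifted positions back into the box with torch.remainder and doing an ordinary (non-periodic) distance search on each copy therefore finds every periodic neighbor pair; pairs found more than once are merged with torch.unique. The function name is hypothetical, and the dense distance matrix stands in for radius_graph.

```python
import itertools
import torch

def shift_and_wrap_radius_graph(pos, box, cutoff):
    """Periodic radius graph for an orthogonal box from shifted, wrapped copies.

    pos: (N, 3) positions; box: (3,) side lengths; requires cutoff < box / 2.
    Returns edge_index (2, E) with every ordered pair (i, j), i != j, once.
    """
    assert bool((cutoff < box / 2).all()), "cutoff must be below half of every side"
    not_self = ~torch.eye(pos.size(0), dtype=torch.bool, device=pos.device)
    edges = []
    # move the periodic boundary by 0 or half a box length in each direction
    for halves in itertools.product([0.0, 0.5], repeat=3):
        shift = torch.tensor(halves, dtype=pos.dtype, device=pos.device) * box
        wrapped = torch.remainder(pos + shift, box)                       # remap into [0, box)
        dist = (wrapped[:, None, :] - wrapped[None, :, :]).norm(dim=-1)  # plain distances
        edges.append(torch.nonzero((dist <= cutoff) & not_self).T)
    # the same pair is usually found under several shifts: keep one copy
    return torch.unique(torch.cat(edges, dim=1), dim=1)
```

No spurious edges appear: in any wrapped copy, the difference between two atoms is one of their periodic images, so its length is never shorter than the true periodic distance.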

A more elegant solution, perhaps with other applications as well, would be the ability to pass a custom difference function to radius_graph. For example, below is a difference function for a box that is periodic in all directions.

I don't know whether that is feasible, though, because I don't know how radius_graph works internally, and I found it very difficult to track down, as there are several imports between torch_geometric and torch_cluster and vice versa.

def periodic_difference(x_i, x_j, domain):
    """
    Compute x_i - x_j taking into account the periodic boundary conditions.
    """
    diff = x_i - x_j
    smaller_one = x_i < x_j  # component-wise check which is bigger
    domain_shift = (1 - 2 * smaller_one) * domain
    diff_shifted = diff - domain_shift
    # boolean indicating in which component to use the original difference
    use_original = torch.abs(diff) < torch.abs(diff_shifted)
    periodic_diff = use_original * diff + ~use_original * diff_shifted
    return periodic_diff
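For positions inside the box, periodic_difference can also be written as `diff - domain * torch.round(diff / domain)`, i.e. subtracting the nearest whole multiple of the box length; the two agree except possibly when a component is exactly half a box length, where both images are equally close. A sketch that uses this to build a dense periodic radius graph for a box periodic in all directions (O(N^2) memory, cutoff assumed below half of every side). Function names are hypothetical.

```python
import torch

def minimum_image_difference(x_i: torch.Tensor, x_j: torch.Tensor,
                             domain: torch.Tensor) -> torch.Tensor:
    """x_i - x_j in a box of side lengths `domain`, periodic in every direction."""
    diff = x_i - x_j
    return diff - domain * torch.round(diff / domain)

def periodic_radius_graph_dense(pos: torch.Tensor, domain: torch.Tensor,
                                cutoff: float) -> torch.Tensor:
    """Edges (source j, target i), i != j, with minimum-image distance <= cutoff.

    Dense, O(N^2) memory; assumes cutoff < domain / 2 so that only the
    nearest image of each atom can be within range.
    """
    diff = minimum_image_difference(pos.unsqueeze(1), pos.unsqueeze(0), domain)  # (N, N, 3)
    dist = torch.sqrt((diff * diff).sum(dim=-1))
    not_self = ~torch.eye(pos.size(0), dtype=torch.bool, device=pos.device)
    idx = torch.nonzero((dist <= cutoff) & not_self)
    target, source = idx[:, 0], idx[:, 1]
    return torch.stack([source, target])
```

Because both functions use only tensor operations and carry type annotations, they should be compilable with `torch.jit.script(periodic_radius_graph_dense)`, which is relevant to the TorchScript/OpenMM requirement raised earlier in the thread; it is worth verifying on the PyTorch version in use.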

APJansen, Aug 19 '22

Has there been any progress on this issue? I am hoping to find an efficient way to do this.

Nokimann, Feb 16 '23

import torch
from torch_cluster import radius
from torch_geometric.nn import radius_graph

# no batch, no loop, only source_to_target
def material_radius_graph(pos, lattice, cutoff):
    if torch.inf in lattice:
        # non-periodic system (e.g. a molecule): plain radius graph with zero shifts
        edge_index = radius_graph(pos, cutoff, batch=None, loop=False, max_num_neighbors=9999, flow='source_to_target')
        R = torch.zeros((edge_index.size(1), 3), dtype=pos.dtype, device=pos.device)
        return edge_index, R
    else:
        assert lattice.norm(dim=-1).max() < 1e4 # due to the precision problem

        lattice = lattice.squeeze()

        vol = lattice.det().abs()  # abs() so that left-handed lattices also work
        area = torch.cross(
            lattice.roll(shifts=1, dims=0), lattice.roll(shifts=2, dims=0), dim=1
        ).norm(dim=1)
        height = vol / area  # distances between opposite faces of the cell

        # integer cell offsets, to consider atoms outside the lattice
        extra_R = (pos @ lattice.inverse()).floor()

        bound = (cutoff / height).ceil()
        l, m, n = -bound + extra_R.min(dim=0).values  # per-direction minimum
        L, M, N = bound + extra_R.max(dim=0).values + 1.0 # plus 1 due to the half-open range [,) of torch.arange below

        grid_l = torch.arange(l.item(), L.item(), dtype=lattice.dtype, device=pos.device)
        grid_m = torch.arange(m.item(), M.item(), dtype=lattice.dtype, device=pos.device)
        grid_n = torch.arange(n.item(), N.item(), dtype=lattice.dtype, device=pos.device)
        mesh_lmn = torch.stack(
            torch.meshgrid(grid_l, grid_m, grid_n, indexing='ij')
        ).view(3, -1).transpose(0, 1)

        R = mesh_lmn @ lattice
        R_pos = (R.unsqueeze(1) + pos.unsqueeze(0)).view(-1, 3) # (num_R, num_pos, 3) -> (num_R*num_pos, 3), flat index = r * num_pos + p

        row, col = radius(pos, R_pos, cutoff, None, None, max_num_neighbors=9999) # row: R_pos, col: pos
        pos_row, pos_col = R_pos[row], pos[col]
        row, lmn_row = row.remainder(pos.size(0)), row.div(pos.size(0), rounding_mode='floor')

        # drop each atom's edge to itself in the zero-offset cell
        mask = (row != col) | (pos_row != pos_col).any(dim=1)
        row, col, lmn_row = row[mask], col[mask], lmn_row[mask]

        edge_index = torch.stack([row, col], dim=0)
        R = mesh_lmn[lmn_row] @ lattice
        return edge_index, R

Given pos, lattice, and cutoff, the source-to-target edge vectors are:

edge_index, R = material_radius_graph(pos, lattice, cutoff)
row, col = edge_index
source_to_target_vector = pos[col] - (pos[row] + R)

How efficiently this handles material graphs, including both molecules and crystals, hinges on how the torch.meshgrid(grid_l, grid_m, grid_n, indexing='ij') and R_pos = (R.unsqueeze(1) + pos.unsqueeze(0)).view(-1, 3) parts can be improved, since together they create one copy of every atom for every cell offset in the search range.

Because of the GPU memory limit, I've concluded that processing structures one at a time works better than executing batched operations.
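Processing structures one at a time can still produce a single batched graph, by offsetting each structure's local node indices by the number of atoms that precede it. A sketch, assuming a sorted PyG-style `batch` vector, one (3, 3) lattice per structure, and a per-structure function `graph_fn(pos, lattice, cutoff)` returning `(edge_index, R)` with local indices, such as material_radius_graph above; `batched_periodic_graph` is a hypothetical name.

```python
import torch

def batched_periodic_graph(pos, batch, lattices, cutoff, graph_fn):
    """Run `graph_fn` on each structure separately and merge the results.

    pos: (N, 3); batch: (N,) sorted structure index per atom; lattices: (B, 3, 3).
    graph_fn(pos_b, lattice_b, cutoff) must return (edge_index_b, R_b) with
    node indices local to that structure.
    """
    counts = torch.bincount(batch, minlength=lattices.size(0))
    starts = torch.cumsum(counts, dim=0) - counts  # index of the first atom of each structure
    edge_indices, shifts = [], []
    for b in range(lattices.size(0)):
        start, count = int(starts[b]), int(counts[b])
        ei, R = graph_fn(pos[start:start + count], lattices[b], cutoff)
        edge_indices.append(ei + start)  # local -> global node indices
        shifts.append(R)
    return torch.cat(edge_indices, dim=1), torch.cat(shifts, dim=0)
```

The large intermediate tensors (the meshgrid and R_pos in material_radius_graph) are then only built for one structure at a time, at the cost of a Python loop over the batch.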

Nokimann, Feb 23 '23