pytorch_geometric
radius_graph for periodic structure
❓ Questions & Help
Do you have a radius_graph function for periodic structures? Otherwise I will think about how I could implement one and propose a PR. Any suggestions?
So far we are using pymatgen.core.structure.Structure.get_all_neighbors to do that.
Can you give some pointers as to why the (standard) radius_graph procedure is not applicable to those structures? What needs to be changed in order to achieve the desired output?
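The following code (not tested) is inefficient, returns duplicate edges (pairs inside every one of the 27 images are found, not only those touching the central cell), and works only if the cutoff is small enough, i.e. no larger than the shortest distance between opposite faces of the cell.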
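import torch
from torch_cluster import radius_graph

n = 10  # number of atoms (example value)
cutoff = 1.0  # example value
lattice = torch.randn(3, 3)  # 3 vectors spanning a parallelepiped in 3d
pos = torch.randn(n, 3)  # usually inside the cell
# copies of the atoms in the central cell and its 26 neighboring images: (27 * n, 3)
pos = torch.cat([
    pos + i * lattice[0] + j * lattice[1] + k * lattice[2]
    for i in [-1, 0, 1]
    for j in [-1, 0, 1]
    for k in [-1, 0, 1]
])
index = radius_graph(pos, cutoff)
index = index % n  # map image indices back to the original atom indices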
This would be nice to have. I know the folks over at https://github.com/Open-Catalyst-Project have an example in https://github.com/Open-Catalyst-Project/ocp/blob/master/ocpmodels/common/utils.py. I haven't had the chance to implement my own.
Thanks for the pointers. It looks like OpenCatalyst provides a dense version of it. I'm open to integrating a CUDA version of this in torch-cluster. @mariogeiger Please let me know if you have any plans to contribute.
Yes, I will try to make something. I plan to look at it this week or maybe the week after. I'd love to contribute.
I would like to implement a SchNet network with a periodic cell, and therefore need radius_graph with periodic boundary conditions. Is it possible to implement these periodic boundary conditions?
@mariogeiger @keano130 Did you make any progress?
As I have little to no experience with CUDA (or writing code in C), I am afraid I cannot write (efficient) code to implement this. Ideally, something along the lines of https://manual.gromacs.org/2021/reference-manual/algorithms/molecular-dynamics.html#simple-search could be implemented efficiently; it assumes that the cell matrix is lower triangular, which can be achieved through simple transformations (rotations). Otherwise I will try to implement an inefficient Python script instead.
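A minimal PyTorch sketch of the GROMACS simple-search idea, under these assumptions: the rows of the cell matrix are the cell vectors a, b, c, and the matrix is lower triangular, so a = (ax, 0, 0), b = (bx, by, 0), c = (cx, cy, cz). The helper names to_lower_triangular and triclinic_displacement are hypothetical. The first brings a general cell into that form with a QR decomposition (an orthogonal transformation, i.e. a rotation possibly combined with a reflection, so all distances are preserved); the second corrects a displacement by whole cell vectors in the order c, b, a. This is not an exact minimum-image search for arbitrary cells: as I read the linked manual section, GROMACS relies on restrictions on the box shape and on a cutoff that is small relative to the box for this to find the nearest image.

```python
import torch


def to_lower_triangular(cell: torch.Tensor, pos: torch.Tensor):
    """Transform a cell (rows = cell vectors) and positions so the cell is lower triangular.

    With cell.T = Q @ R (QR decomposition), cell @ Q = R.T, which is lower triangular.
    Applying the same orthogonal Q to the positions preserves all interatomic distances.
    """
    q, _ = torch.linalg.qr(cell.T)
    return cell @ q, pos @ q


def triclinic_displacement(x_i: torch.Tensor, x_j: torch.Tensor, box: torch.Tensor) -> torch.Tensor:
    """x_i - x_j corrected by whole box vectors, in the style of the GROMACS simple search.

    box must be lower triangular with the cell vectors as rows (assumption).
    Works on single vectors (3,) or on broadcast batches (..., 3).
    """
    d = x_i - x_j
    # c is the only vector with a z component, so fix z first, then y with b, then x with a
    d = d - torch.round(d[..., 2:3] / box[2, 2]) * box[2]
    d = d - torch.round(d[..., 1:2] / box[1, 1]) * box[1]
    d = d - torch.round(d[..., 0:1] / box[0, 0]) * box[0]
    return d
```

In a unit cube, for example, the displacement between x = 0.9 and x = 0.1 along the first axis becomes -0.2 instead of 0.8.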
@keano130 Why is the ASE implementation not good enough for your use case?
In order to use the model in OpenMM, I need to convert it to TorchScript, which is not possible with the ASE neighbor list. Furthermore, the ASE implementation is less efficient.
SchNet has previously been implemented in C++ in https://github.com/openmm/NNPOps/tree/master/schnet. In https://github.com/openmm/NNPOps/blob/master/schnet/CudaCFConv.cu, the algorithm I linked earlier is used in the computeDisplacement function written for CUDA. This might help in adapting the code to periodic boundary conditions, but there is still too much in the code that I don't understand for me to adapt it myself.
Any update on this issue? I would also like to create a radius graph that takes periodic boundary conditions into account. It is possible to do so by shifting the original positions several times, remapping them to the original domain, and computing the edges for each shift.
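A rough plain-PyTorch sketch of this shifting approach: the atoms are shifted by every combination of -1, 0 and +1 cell vectors, torch.cdist gives the distances from the original atoms to each shifted copy, and the shift of every edge is kept so the edge vectors can be rebuilt later. The function name shifted_radius_graph is hypothetical. It assumes all atoms lie inside the cell and that the cutoff is no larger than the shortest distance between opposite cell faces (otherwise more than one layer of images is needed), and its dense distance matrices make it practical only for small systems.

```python
import itertools

import torch


def shifted_radius_graph(pos: torch.Tensor, cell: torch.Tensor, cutoff: float):
    """Periodic radius graph from explicit shifts by -1, 0, +1 cell vectors.

    pos:  (N, 3) Cartesian positions, assumed to lie inside the cell.
    cell: (3, 3) cell vectors as rows, same dtype as pos.
    Returns edge_index (2, E) and shifts (E, 3) in units of cell vectors:
    atom edge_index[1] translated by shifts @ cell lies within cutoff of atom edge_index[0].
    """
    rows, cols, shifts = [], [], []
    no_self = ~torch.eye(pos.size(0), dtype=torch.bool, device=pos.device)
    for s in itertools.product((-1, 0, 1), repeat=3):
        shift = torch.tensor(s, device=pos.device)  # integer shift in units of cell vectors
        offset = shift.to(pos.dtype) @ cell  # Cartesian translation s0*a + s1*b + s2*c
        # dist[p, q] = |pos[p] - (pos[q] + offset)|
        dist = torch.cdist(pos, pos + offset)
        mask = dist < cutoff
        if s == (0, 0, 0):
            mask &= no_self  # an atom is not its own neighbor in the unshifted cell
        center, neighbor = mask.nonzero(as_tuple=True)
        rows.append(center)
        cols.append(neighbor)
        shifts.append(shift.expand(center.numel(), 3))
    edge_index = torch.stack([torch.cat(rows), torch.cat(cols)])
    return edge_index, torch.cat(shifts)
```

For a cubic unit cell with atoms at x = 0.05 and x = 0.95 and a cutoff of 0.2, this returns the two directed edges (0, 1) and (1, 0) with shifts (-1, 0, 0) and (1, 0, 0), which the unshifted distance of 0.9 alone would miss.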
A more elegant solution, perhaps with unrelated applications as well, would be the option to provide a custom difference function to radius_graph. For example, below is a difference function for a box with periodicity in all directions.
I don't know if that's feasible, though, as I don't know how radius_graph works internally, and I found it very difficult to track down, since there are several imports back and forth between torch_geometric and torch_cluster.
import torch


def periodic_difference(x_i, x_j, domain):
    """
    Compute x_i - x_j taking into account the periodic boundary conditions.
    domain holds the box length along each axis.
    """
    diff = x_i - x_j
    smaller_one = x_i < x_j  # component-wise check which is bigger
    # shift by +domain where x_i >= x_j and by -domain where x_i < x_j
    domain_shift = (1 - 2 * smaller_one) * domain
    diff_shifted = diff - domain_shift
    # boolean indicating in which component to use the original difference
    use_original = torch.abs(diff) < torch.abs(diff_shifted)
    return torch.where(use_original, diff, diff_shifted)
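A dense prototype of the custom-difference idea that takes the difference function as an argument, so the interface can be tried out without changing radius_graph itself. Both function names are hypothetical and not part of torch_geometric or torch_cluster, and minimum_image_difference is a round-based minimum-image difference for an orthorhombic box, swapped in for periodic_difference so the block is self-contained. The O(N^2) pairwise tensor makes it suitable only for small systems.

```python
from typing import Callable

import torch


def minimum_image_difference(domain: torch.Tensor) -> Callable:
    """Return a function computing x_i - x_j in an orthorhombic periodic box.

    domain holds the box length along each axis; each component of the result
    is wrapped into [-domain / 2, domain / 2].
    """
    def difference(x_i: torch.Tensor, x_j: torch.Tensor) -> torch.Tensor:
        diff = x_i - x_j
        return diff - domain * torch.round(diff / domain)

    return difference


def radius_graph_with_difference(pos: torch.Tensor, cutoff: float, difference: Callable):
    """Dense radius graph whose pair vectors come from a user-supplied difference function.

    Returns edge_index (2, E) and vec (E, 3), where vec[e] = difference(pos[j], pos[i])
    for the edge (i, j) = edge_index[:, e].
    """
    # broadcasting gives vec_all[i, j] = difference(pos[j], pos[i])
    vec_all = difference(pos.unsqueeze(0), pos.unsqueeze(1))
    mask = vec_all.norm(dim=-1) < cutoff
    mask &= ~torch.eye(pos.size(0), dtype=torch.bool, device=pos.device)  # no self-loops
    edge_index = mask.nonzero().t()
    return edge_index, vec_all[mask]
```

Because the difference function is an ordinary Python callable, the same search also covers open boundaries by passing lambda x_i, x_j: x_i - x_j.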
No progress on this issue? I am hoping to find an efficient way to do this.
import torch
from torch_cluster import radius
from torch_geometric.nn import radius_graph


# no batch, no self-loops, only source_to_target flow
def material_radius_graph(pos, lattice, cutoff):
    # an infinite lattice vector marks a non-periodic structure (e.g. a molecule)
    if torch.isinf(lattice).any():
        edge_index = radius_graph(pos, cutoff, batch=None, loop=False, max_num_neighbors=9999, flow='source_to_target')
        R = torch.zeros((edge_index.size(1), 3), dtype=pos.dtype, device=pos.device)
        return edge_index, R
    else:
        assert lattice.norm(dim=-1).max() < 1e4  # due to the precision problem
        lattice = lattice.squeeze()
        vol = lattice.det().abs()  # abs() so a left-handed cell does not give negative heights
        # area of the face spanned by the two lattice vectors other than vector i
        area = torch.cross(
            lattice.roll(shifts=1, dims=0), lattice.roll(shifts=2, dims=0), dim=1
        ).norm(dim=1)
        height = vol / area
        # integer cell index of every atom, to consider atoms outside the lattice
        extra_R = (pos @ lattice.inverse()).floor()
        bound = (cutoff / height).ceil()
        l, m, n = -bound + extra_R.min(dim=0).values
        L, M, N = bound + extra_R.max(dim=0).values + 1.0  # plus 1 due to the boundary [,) in torch.arange below
        grid_l = torch.arange(l.item(), L.item(), dtype=lattice.dtype, device=pos.device)
        grid_m = torch.arange(m.item(), M.item(), dtype=lattice.dtype, device=pos.device)
        grid_n = torch.arange(n.item(), N.item(), dtype=lattice.dtype, device=pos.device)
        mesh_lmn = torch.stack(
            torch.meshgrid(grid_l, grid_m, grid_n, indexing='ij')
        ).view(3, -1).transpose(0, 1)
        R = mesh_lmn @ lattice
        # (num_R, num_pos, 3) -> (num_R * num_pos, 3): flat index = r * num_pos + p
        R_pos = (R.unsqueeze(1) + pos.unsqueeze(0)).view(-1, 3)
        row, col = radius(pos, R_pos, cutoff, None, None, max_num_neighbors=9999)  # row: R_pos, col: pos
        row, lmn_row = row.remainder(pos.size(0)), torch.div(row, pos.size(0), rounding_mode='floor')
        # drop self-loops: the same atom with a zero lattice shift
        mask = (row != col) | (mesh_lmn[lmn_row] != 0).any(dim=1)
        row, col, lmn_row = row[mask], col[mask], lmn_row[mask]
        edge_index = torch.stack([row, col], dim=0)
        R = mesh_lmn[lmn_row] @ lattice
        return edge_index, R
Given pos, lattice, and cutoff, the edge vectors are obtained as follows:
edge_index, R = material_radius_graph(pos, lattice, cutoff)
row, col = edge_index
source_to_target_vector = pos[col] - (pos[row] + R)
An efficient method for handling material graphs, including both molecules and crystals, hinges on how the torch.meshgrid(grid_l, grid_m, grid_n, indexing='ij') and R_pos = (R.unsqueeze(1) + pos.unsqueeze(0)).view(-1, 3) parts can be improved: together they materialize a copy of every atom for every lattice image in the search range before radius is called.
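To see why these two lines dominate memory, here is a rough estimate of the size of R_pos using the same bound = ceil(cutoff / height) rule as material_radius_graph. It assumes every atom lies inside the cell (so extra_R is zero and each axis spans 2 * bound + 1 images) and float32 storage; the helper name r_pos_footprint is hypothetical.

```python
import torch


def r_pos_footprint(lattice: torch.Tensor, num_atoms: int, cutoff: float):
    """Estimate the number of lattice images and the bytes taken by R_pos.

    Mirrors the height/bound computation of material_radius_graph, assuming all
    atoms lie inside the cell (extra_R == 0) and float32 storage.
    """
    vol = lattice.det().abs()
    # area of the face spanned by the two lattice vectors other than vector i
    area = torch.linalg.cross(
        lattice.roll(shifts=1, dims=0), lattice.roll(shifts=2, dims=0), dim=1
    ).norm(dim=1)
    height = vol / area
    bound = torch.ceil(cutoff / height)
    num_images = int(torch.prod(2 * bound + 1).item())
    # R_pos holds one 3-vector of 4-byte floats per (image, atom) pair
    num_bytes = num_images * num_atoms * 3 * 4
    return num_images, num_bytes
```

For a cubic cell with edge length 5, 100 atoms and a cutoff of 6, this gives 5 images per axis, 125 images in total, and 150,000 bytes for R_pos; the image count grows roughly with the cube of cutoff / height.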
Because of the GPU memory limit, I've concluded that processing structures individually, rather than as batches, works better.
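A minimal sketch of processing structures one at a time and then merging the results into a single batched graph by offsetting each structure's edge indices by the number of atoms that precede it, the same way PyG concatenates graphs into a batch. The function name per_structure_graph and the neighbor_fn argument are hypothetical; neighbor_fn stands in for any per-structure graph builder (for example material_radius_graph) returning edge_index and per-edge shift vectors R.

```python
from typing import Callable, List, Tuple

import torch


def per_structure_graph(
    structures: List[Tuple[torch.Tensor, torch.Tensor]],
    cutoff: float,
    neighbor_fn: Callable,
):
    """Build the graph of each (pos, lattice) pair separately, then merge them into one batch.

    neighbor_fn(pos, lattice, cutoff) must return (edge_index, R) for a single structure.
    Only one structure's neighbor-search temporaries exist at a time.
    """
    edge_indices, shifts, batch = [], [], []
    offset = 0
    for graph_id, (pos, lattice) in enumerate(structures):
        edge_index, R = neighbor_fn(pos, lattice, cutoff)
        edge_indices.append(edge_index + offset)  # local -> global atom indices
        shifts.append(R)
        batch.append(torch.full((pos.size(0),), graph_id, dtype=torch.long, device=pos.device))
        offset += pos.size(0)
    return torch.cat(edge_indices, dim=1), torch.cat(shifts), torch.cat(batch)
```

The last returned tensor assigns each atom to its structure, playing the same role as the batch vector in PyG.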