torch-dftd
torch-dftd copied to clipboard
Add Matscipy Neighborlist Support
In my test case, this implementation was slightly faster than the pymatgen neighbor list, but I have not benchmarked it extensively.
from matscipy.neighbours import neighbour_list
def calc_neighbor_by_matscipy(
    pos: Tensor, cell: Tensor, pbc: Tensor, cutoff: float
) -> Tuple[Tensor, Tensor]:
    """Compute a neighbor list using matscipy's ``neighbour_list``.

    The input tensors are detached, copied to CPU and converted to NumPy
    arrays before being passed to matscipy. The results are returned as
    new tensors placed on ``pos.device``.

    Args:
        pos: Atomic positions, forwarded to matscipy as ``positions``.
            Presumably shaped (n_atoms, 3) -- confirm against callers.
        cell: Simulation cell, forwarded to matscipy as ``cell``.
        pbc: Periodic boundary flags, forwarded to matscipy as ``pbc``.
        cutoff: Cutoff value, forwarded to matscipy unchanged.

    Returns:
        A tuple ``(edge_index, S)``:

        * ``edge_index`` -- int64 tensor built by stacking the ``i`` and
          ``j`` index arrays along the first axis.
        * ``S`` -- the ``S`` quantity returned by matscipy, cast to
          ``pos.dtype``.
    """
    target_device = pos.device

    # matscipy operates on NumPy arrays, so leave autograd and the
    # accelerator behind for the neighbor search itself.
    pbc_np = pbc.detach().cpu().numpy()
    cell_np = cell.detach().cpu().numpy()
    pos_np = pos.detach().cpu().numpy()

    first_idx, second_idx, shifts = neighbour_list(
        quantities="ijS",
        pbc=pbc_np,
        cell=cell_np,
        positions=pos_np,
        cutoff=cutoff,
    )

    pair_indices = np.stack((first_idx, second_idx))
    edge_index = torch.tensor(
        pair_indices, dtype=torch.int64, device=target_device
    )

    # The shift array is cast to the same dtype as the positions so that
    # it can be combined with them downstream without a further cast.
    shift_tensor = torch.tensor(
        shifts, dtype=pos.dtype, device=target_device
    )
    return edge_index, shift_tensor