NNPOps
Benchmarking SchNet
I've been trying to figure out how to write a benchmarking script for SchNet. Here's what I have so far with SchNetPack. It loads a PDB file and computes the energy 1000 times with one of the pre-trained QM9 models. I haven't figured out yet how to get it to compute forces, so any advice on that would be appreciated. There probably are other ways this could be improved too.
```python
import sys
import time

import ase.io
import torch
import schnetpack as spk
import schnetpack.md.calculators

device = torch.device('cuda')
model = torch.load("trained_schnet_models/qm9_energy_U0/best_model", map_location=device)

# Load the molecule (e.g. a PDB file) given on the command line.
atoms = ase.io.read(sys.argv[1])
system = spk.md.System(1, device=device)
system.load_molecules([atoms])

# The calculator is only used here to build the model's input dictionary.
calculator = spk.md.calculators.SchnetPackCalculator(
    model,
    required_properties=['energy_U0'],
    force_handle=spk.Properties.forces,
    position_conversion='A',
    force_conversion='kcal/mol/A'
)
inputs = calculator._generate_input(system)

# One warm-up call, then time 1000 energy evaluations.
model(inputs)
torch.cuda.synchronize()
t1 = time.time()
for i in range(1000):
    results = model(inputs)
torch.cuda.synchronize()  # CUDA kernels run asynchronously; wait for them before stopping the clock
print(time.time()-t1)
print(results)
```
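Because CUDA kernels are launched asynchronously, a wall-clock timer can stop before the GPU has finished the queued work. The sketch below is a small generic timing helper that warms up, synchronizes before each clock read, and reports milliseconds per call. `benchmark` is a hypothetical name, not part of SchNetPack or PyTorch; the synchronization function is passed in so the helper itself has no GPU dependency.

```python
import time
from typing import Callable, Optional


def benchmark(fn: Callable[[], object], n_iter: int = 1000, warmup: int = 1,
              sync: Optional[Callable[[], None]] = None) -> float:
    """Return the mean wall-clock time per call of fn, in milliseconds.

    sync (hypothetical hook), if given, is called right before each clock read
    so that queued asynchronous work is included in the measurement.
    """
    for _ in range(warmup):
        fn()  # first calls pay for lazy initialization and memory allocation
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(n_iter):
        fn()
    if sync is not None:
        sync()
    elapsed = time.perf_counter() - start
    return 1000.0 * elapsed / n_iter
```

With PyTorch, a call such as `benchmark(lambda: model(inputs), sync=torch.cuda.synchronize)` should time the same energy evaluation as the script above.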
Testing a 60 atom system on a Titan V, it takes about 3.6 ms per energy evaluation. With a 2269 atom system, it runs out of GPU memory and crashes.
While the test is running, `nvidia-smi` shows that the GPU is only 28% busy. `nvvp` shows a lot of short kernels with larger gaps between them. The two most significant kernels are `volta_sgemm_32x128_tn` (19.8% of GPU time) and `volta_sgemm_32x32_sliced1x4_tn` (16% of GPU time). After those come a large number of kernels with uninformative names like `_ZN2at6native6legacy18elementwise_kernelILi128ELi4EZNS0_15gpu_kernel_implIZZZNS0_15add_kernel_cudaERNS_14TensorIteratorEN3c106ScalarEENKUlvE_clEvENKUlvE2_clEvEUlffE_EEvS5_RKT_EUliE2_EEviT1_`.
The speed of SchNet is comparable to ANI; I expected these graph convolutions to be much slower. So it seems the ultimate bottleneck is the matrix multiplications, as it is for ANI.
Any idea how I can get it to compute forces? The documentation implies I should just be able to call `calculate()` on the Calculator, but that throws an exception because the property names in this model don't match what it's expecting.
@stefdoerr might know.
Sorry, I don't have experience with SchNetPack. Maybe @giadefa knows.
Maybe check the `available_properties` attribute of the class?
We don't use spk.md.calculator. Check this: https://github.com/torchmd/torchmd-cg/blob/master/torchmd_cg/nnp/calculators/torchmdcalc.py
That code assumes the model has an output called "forces". The pretrained QM9 model only has a single output called "energy_U0".
I think I managed to correctly get it to compute forces, though I don't know if I'm doing it in the best way. After loading the system I added the line
```python
system.positions.requires_grad_()
```
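Then I changed the loop to:

```python
for i in range(1000):
    # Clear the gradient accumulated by the previous iteration.
    if system.positions.grad is not None:
        system.positions.grad.zero_()
    results = model(inputs)
    # Backpropagate dE/dr into system.positions.grad (forces are its negative).
    results['energy_U0'].backward(retain_graph=True)
```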
It now takes 7.7 ms per iteration, which is still quite respectable compared to TorchANI. `nvidia-smi` now shows the GPU as 46% busy, and `nvvp` still shows lots of gaps between kernels, so there ought to be plenty of room for speedups.
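Forces are the negative gradient of the energy, $F_i = -\partial E / \partial \mathbf{r}_i$, so the gradient that `backward()` leaves in `system.positions.grad` has to be negated to obtain forces (the sign does not matter for timing). A cheap sanity check for any gradient-based force is a comparison against central finite differences. In the sketch below, a toy harmonic pair energy stands in for the model's `energy_U0` output; `energy`, `analytic_forces`, and `numerical_forces` are hypothetical helpers, not SchNetPack functions.

```python
import math
from typing import Callable, List, Sequence

Vec = List[float]


def energy(pos: Sequence[Vec], k: float = 1.0, r0: float = 1.0) -> float:
    """Toy stand-in for a model energy: harmonic springs between all pairs."""
    e = 0.0
    for i in range(len(pos)):
        for j in range(i + 1, len(pos)):
            r = math.dist(pos[i], pos[j])
            e += 0.5 * k * (r - r0) ** 2
    return e


def analytic_forces(pos: Sequence[Vec], k: float = 1.0, r0: float = 1.0) -> List[Vec]:
    """F_i = -dE/dr_i for the toy energy above."""
    forces = [[0.0, 0.0, 0.0] for _ in pos]
    for i in range(len(pos)):
        for j in range(i + 1, len(pos)):
            r = math.dist(pos[i], pos[j])
            coef = -k * (r - r0) / r  # -(dE/dr) times the unit vector factor 1/r
            for d in range(3):
                f = coef * (pos[i][d] - pos[j][d])
                forces[i][d] += f  # force on i
                forces[j][d] -= f  # equal and opposite force on j
    return forces


def numerical_forces(energy_fn: Callable[[List[Vec]], float],
                     pos: Sequence[Vec], h: float = 1e-5) -> List[Vec]:
    """Central-difference estimate of F = -grad E, one coordinate at a time."""
    forces = []
    for i in range(len(pos)):
        row = []
        for d in range(3):
            plus = [list(p) for p in pos]
            minus = [list(p) for p in pos]
            plus[i][d] += h
            minus[i][d] -= h
            row.append(-(energy_fn(plus) - energy_fn(minus)) / (2 * h))
        forces.append(row)
    return forces
```

With a real model, the same comparison would use `-system.positions.grad` in place of `analytic_forces` and a wrapper that evaluates the model energy at perturbed positions.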
Which labels are available is set in the constructor.
It appears SchNetPack does support neighbor lists (specified with the System's `neighborlist` argument), but the only implementation it provides just lists every atom as a neighbor of every other. A proper implementation might improve performance, and should also help with running out of memory on large molecules.
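Filtering an all-pairs list by distance still computes all N(N-1) distances. A common way to build a cutoff neighbor list without that quadratic step is a cell list: bin atoms into cubes whose side equals the cutoff, then compare each atom only with atoms in its own and the 26 adjacent cells. The pure-Python sketch below assumes non-periodic coordinates; `cell_list_neighbors` and `brute_force_neighbors` are hypothetical helpers, not part of SchNetPack, and a GPU implementation would vectorize the same idea.

```python
import itertools
import math
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def brute_force_neighbors(pos: Sequence[Sequence[float]], cutoff: float) -> List[List[int]]:
    """O(N^2) reference: neighbors[i] = sorted j != i with |r_i - r_j| < cutoff."""
    n = len(pos)
    return [sorted(j for j in range(n) if j != i and math.dist(pos[i], pos[j]) < cutoff)
            for i in range(n)]


def cell_list_neighbors(pos: Sequence[Sequence[float]], cutoff: float) -> List[List[int]]:
    """Cell-list search: atoms closer than the cutoff always lie in the same
    or adjacent cells when the cell side equals the cutoff."""
    cells: Dict[Tuple[int, ...], List[int]] = defaultdict(list)
    keys = []
    for i, p in enumerate(pos):
        key = tuple(math.floor(c / cutoff) for c in p)
        cells[key].append(i)
        keys.append(key)
    neighbors: List[List[int]] = [[] for _ in pos]
    for i, (cx, cy, cz) in enumerate(keys):
        for dx, dy, dz in itertools.product((-1, 0, 1), repeat=3):
            # .get does not insert empty cells into the defaultdict
            for j in cells.get((cx + dx, cy + dy, cz + dz), ()):
                if j != i and math.dist(pos[i], pos[j]) < cutoff:
                    neighbors[i].append(j)
    return [sorted(nbrs) for nbrs in neighbors]
```

The cost of the search now scales with the number of atoms times the typical number of atoms per neighborhood, rather than with N².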
They recently tried an implementation from TorchANI, which was a bit faster but also consumed too much memory.
I wrote a proper neighbor list implementation. There's probably a more efficient way of building it, but it works. I can now run the 2269 atom system. It takes 76 ms per iteration. On the 60 atom system, there's no change in speed.
```python
import torch
import schnetpack as spk


class NeighborList(spk.md.neighbor_lists.MDNeighborList):
    """Neighbor list that keeps only pairs closer than the cutoff."""

    def __init__(self, system, cutoff=None):
        # All-pairs list that gets filtered by distance below.
        self.simple = spk.md.neighbor_lists.SimpleNeighborList(system, cutoff)
        super(NeighborList, self).__init__(system, cutoff)

    def _construct_neighbor_list(self):
        # Start from the all-pairs list: every atom neighbors every other atom.
        self.simple._construct_neighbor_list()
        neighbors = self.simple.neighbor_list.view(-1, self.system.max_n_atoms, self.system.max_n_atoms-1)
        positions = self.system.positions.view(-1, self.system.max_n_atoms, 3)
        n_copies = neighbors.shape[0]
        n_atoms = neighbors.shape[1]
        r_ij = spk.nn.neighbors.atom_distances(positions, neighbors, None)

        # Keep only the neighbors within the cutoff, one list per atom.
        lists = []
        for i in range(n_copies):
            for j in range(n_atoms):
                lists.append(neighbors[i,j][r_ij[i,j]<self.cutoff])

        # Pad every list to the same length; the mask marks the real entries.
        max_neighbors = max(len(l) for l in lists)
        self.neighbor_list = torch.zeros(n_copies, n_atoms, max_neighbors, device=self.system.device, dtype=torch.int64)
        self.neighbor_mask = torch.zeros(n_copies, n_atoms, max_neighbors, device=self.system.device)
        iterator = iter(lists)
        for i in range(n_copies):
            for j in range(n_atoms):
                l = next(iterator)
                size = len(l)
                self.neighbor_list[i,j,:size] = l
                self.neighbor_mask[i,j,:size] = 1

        # Restore the (replicas, molecules, atoms, neighbors) shape.
        n_replicas = self.simple.neighbor_list.shape[0]
        n_molecules = self.simple.neighbor_list.shape[1]
        self.neighbor_list = self.neighbor_list.view(n_replicas, n_molecules, n_atoms, max_neighbors)
        self.neighbor_mask = self.neighbor_mask.view(n_replicas, n_molecules, n_atoms, max_neighbors)

    def update_neighbors(self):
        # Rebuild the list from scratch on every update.
        self._construct_neighbor_list()
```
When running the large system, the GPU is 99% busy. That's compared to only 43% when running the small system.
Never mind my comments (which I deleted). They were irrelevant to this project; I got confused by another discussion, sorry.
I realized there was a mistake in the numbers above: I wasn't rebuilding the neighbor list for each iteration. With the default implementation you don't need to because it just includes every interaction, but if you want to use a real neighbor list you would need to rebuild it for every step of a simulation. The code above is very slow, so I came up with a much faster implementation.
```python
import torch
import schnetpack as spk


class NeighborList(spk.md.neighbor_lists.MDNeighborList):
    """Neighbor list that keeps only pairs closer than the cutoff (vectorized)."""

    def __init__(self, system, cutoff=None):
        # All-pairs list that gets filtered by distance below.
        self.simple = spk.md.neighbor_lists.SimpleNeighborList(system, cutoff)
        super(NeighborList, self).__init__(system, cutoff)

    def _construct_neighbor_list(self):
        # Start from the all-pairs list and compute every pair distance.
        self.simple._construct_neighbor_list()
        neighbors = self.simple.neighbor_list.view(-1, self.system.max_n_atoms, self.system.max_n_atoms-1)
        positions = self.system.positions.view(-1, self.system.max_n_atoms, 3)
        n_copies = neighbors.shape[0]
        n_atoms = neighbors.shape[1]
        r_ij = spk.nn.neighbors.atom_distances(positions, neighbors, None)

        # Pairs within the cutoff, and the longest per-atom neighbor list.
        mask = r_ij < self.cutoff
        max_neighbors = int(torch.count_nonzero(mask, dim=2).max())

        # Indices of every kept pair, and the slot each one moves to:
        # the running count of kept neighbors along the last axis, minus one.
        copy_index, atom_index, neighbor_index = torch.nonzero(mask, as_tuple=True)
        cumsum = torch.cumsum(mask, dim=2)-1
        target_index = cumsum[copy_index, atom_index, neighbor_index]

        # Scatter the kept neighbors into padded arrays in a single step.
        self.neighbor_list = torch.zeros(n_copies, n_atoms, max_neighbors, device=self.system.device, dtype=torch.int64)
        self.neighbor_mask = torch.zeros(n_copies, n_atoms, max_neighbors, device=self.system.device)
        self.neighbor_list[copy_index, atom_index, target_index] = neighbors[copy_index, atom_index, neighbor_index]
        self.neighbor_mask[copy_index, atom_index, target_index] = 1

        # Restore the (replicas, molecules, atoms, neighbors) shape.
        n_replicas = self.simple.neighbor_list.shape[0]
        n_molecules = self.simple.neighbor_list.shape[1]
        self.neighbor_list = self.neighbor_list.view(n_replicas, n_molecules, n_atoms, max_neighbors)
        self.neighbor_mask = self.neighbor_mask.view(n_replicas, n_molecules, n_atoms, max_neighbors)

    def update_neighbors(self):
        # Rebuild the list from scratch on every update.
        self._construct_neighbor_list()
```
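The vectorized version replaces the per-atom Python loops with one index computation: for each kept candidate, its slot in the padded list is the running count of kept candidates in its row minus one, which is what `torch.cumsum(mask, dim=2)-1` computes. The plain-Python sketch below reproduces that arithmetic for a single molecule (the copy dimension is dropped) so the result can be inspected by hand; `compact_rows` is a hypothetical illustration, not the code used in the thread.

```python
from itertools import accumulate
from typing import List, Tuple


def compact_rows(candidates: List[List[int]],
                 keep: List[List[bool]]) -> Tuple[List[List[int]], List[List[int]]]:
    """Pack the kept entries of each row to the front of a padded array.

    Returns (packed, mask), both of shape (n_rows, max_kept); unused slots
    hold 0 in packed and 0 in mask.
    """
    max_kept = max(sum(row) for row in keep)
    packed = [[0] * max_kept for _ in candidates]
    mask = [[0] * max_kept for _ in candidates]
    for a, (cand_row, keep_row) in enumerate(zip(candidates, keep)):
        # Target slot = running count of kept entries - 1 (the cumsum trick).
        target = [c - 1 for c in accumulate(int(k) for k in keep_row)]
        for n, kept in enumerate(keep_row):
            if kept:  # only kept entries are scattered, like torch.nonzero(mask)
                packed[a][target[n]] = cand_row[n]
                mask[a][target[n]] = 1
    return packed, mask
```

For example, with candidates `[0, 1, 2]` and keep flags `[False, True, False]`, the running count is `[0, 1, 1]`, so the single kept neighbor `1` lands in slot 0.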
Here are some new benchmarks for the 60 atom system.
Default neighbor list, don't rebuild it: 8.2 ms
Default neighbor list, rebuild each iteration: 8.4 ms
"Real" neighbor list, rebuild each iteration: 9.0 ms
For the 2269 atom system:
Default neighbor list: runs out of memory
Real neighbor list: 81 ms
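A back-of-the-envelope estimate of why the default list runs out of memory: with every atom listed as a neighbor of every other, each per-pair tensor has N(N-1) rows per molecule. The sketch below assumes a per-pair feature width of 128 float32 values; that width is an assumption for illustration, not read from the pretrained QM9 model, and `pair_tensor_bytes` is a hypothetical helper.

```python
def pair_tensor_bytes(n_atoms: int, width: int = 128, bytes_per_value: int = 4) -> int:
    """Size in bytes of one (n_atoms, n_atoms - 1, width) tensor, i.e. one
    per-pair intermediate when all atoms are neighbors of all others.
    width=128 and float32 (4 bytes) are illustrative assumptions."""
    n_pairs = n_atoms * (n_atoms - 1)
    return n_pairs * width * bytes_per_value


for n in (60, 2269):
    mib = pair_tensor_bytes(n) / 2**20
    print(f"{n} atoms: {n * (n - 1)} pairs, {mib:.1f} MiB per tensor")
```

For 2269 atoms this is about 5.1 million pairs and roughly 2.5 GiB for a single such tensor, versus under 2 MiB for 60 atoms. A model that keeps several per-pair intermediates for each of its six interaction layers can exhaust GPU memory quickly, whereas a cutoff-based list keeps the pair count roughly proportional to N.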
I have benchmarks for the implementation in #18. I tried to make it as close as possible to the SchNetPack QM9 model benchmarked above. I use the same cutoff distance, number of Gaussians, and output width. Each iteration builds the neighbor list then computes the value and gradient six times to match the six layers in the model. This still isn't exactly comparable, since the real model includes other calculations in addition to the cfconv layers. Those are per-atom rather than per-interaction, though, so they should be much faster and only account for a small fraction of the computation time.
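For readers unfamiliar with the cfconv layer, the sketch below shows the forward pass of one continuous-filter convolution in plain Python: each neighbor distance is expanded in Gaussians, passed through a filter, multiplied by a smooth cutoff, and used to weight the neighbor's features, $x_i' = \sum_j x_j \circ W(r_{ij})$. It assumes a SchNetPack-style cosine cutoff and evenly spaced Gaussian centers; the `gamma` value and the single linear `weights` map (a real model learns a small MLP here) are illustrative assumptions, and the gradient computation that the benchmark also performs is not shown.

```python
import math
from typing import List, Sequence


def gaussian_rbf(d: float, cutoff: float, n_gaussians: int, gamma: float = 10.0) -> List[float]:
    """Expand a distance in Gaussians with centers evenly spaced on [0, cutoff]."""
    step = cutoff / (n_gaussians - 1)
    return [math.exp(-gamma * (d - k * step) ** 2) for k in range(n_gaussians)]


def cosine_cutoff(d: float, cutoff: float) -> float:
    """Goes smoothly from 1 at d = 0 to 0 at d = cutoff, and is 0 beyond."""
    return 0.5 * (math.cos(math.pi * d / cutoff) + 1.0) if d < cutoff else 0.0


def cfconv(features: Sequence[Sequence[float]], neighbors: Sequence[Sequence[int]],
           distances: Sequence[Sequence[float]], weights: Sequence[Sequence[float]],
           cutoff: float) -> List[List[float]]:
    """x_i' = sum_j x_j * W(r_ij), with W(r) = cutoff(r) * (weights @ rbf(r)).

    features:  per-atom feature vectors of length width
    neighbors: neighbors[i] = neighbor indices j of atom i
    distances: distances[i][n] = |r_i - r_j| for j = neighbors[i][n]
    weights:   width x n_gaussians matrix (hypothetical stand-in for a learned filter)
    """
    width, n_gaussians = len(weights), len(weights[0])
    out = []
    for i, nbrs in enumerate(neighbors):
        acc = [0.0] * width
        for j, d in zip(nbrs, distances[i]):
            rbf = gaussian_rbf(d, cutoff, n_gaussians)
            fc = cosine_cutoff(d, cutoff)
            for f in range(width):
                w = fc * sum(weights[f][k] * rbf[k] for k in range(n_gaussians))
                acc[f] += features[j][f] * w
        out.append(acc)
    return out
```

The work is per interaction (per neighbor pair), which is why the neighbor list size dominates the cost, while the per-atom layers around it are comparatively cheap.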
For the 60 atom system, it takes 2 ms/iteration, roughly four times faster than SchNetPack. For the 2269 atom system, it takes 86 ms/iteration, slightly slower than SchNetPack (using the neighbor list implementation above, not the standard one).
With my latest optimizations, the 60 atom system is down to only 0.82 ms/iteration. The 2269 atom system is basically unchanged, 88 ms/iteration. Possibly there are ways I could speed that up, but it's also possible the optimal way to structure the calculation is just different for small systems than for larger ones.