Optimize equivariant transformer

Open peastman opened this issue 2 years ago • 23 comments

I've started looking into how to improve the speed of the equivariant transformer. I'm opening this issue as a place to report my progress.

I'm mainly interested in the speed of running simulations with it: predicting energy and forces for a single molecule. I'm less concerned about training speed. As a starting point, I tried loading molecules of various sizes and computing forces. For the small molecules we're mainly interested in, the time is completely dominated by overhead and barely affected by the number of atoms. It ranged from about 11.5 ms for molecules with around 20 atoms, up to 13.1 ms for one with 94 atoms.

Viewing it with the PyTorch profiler shows the GPU running lots of tiny kernels with long gaps in between. The GPU is sitting idle most of the time. In contrast, if I profile it during training with a batch size of 64, the GPU usage is very high. There might still be ways to make it faster, but at least operating on batches keeps the GPU from being idle.
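For reference, this is roughly how I'm viewing it with the profiler (a sketch only; model, z, pos and batch are placeholders for the loaded model and its example inputs, not the actual benchmark script):

import torch
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    out = model(z, pos, batch)
    torch.cuda.synchronize()

# the timeline (open trace.json in chrome://tracing or Perfetto) shows the gaps between kernels
prof.export_chrome_trace("trace.json")
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))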

There are two possible approaches to making it faster: improve the PyTorch code, or just replace it with custom CUDA kernels. The latter is likely to be more effective, so that's what I'm inclined to do. The custom kernels would be added to NNPOps.

peastman avatar Nov 08 '21 21:11 peastman

Have you tried the new CUDA Graph API released in PyTorch 1.10 yet? This might already reduce the CPU overhead.
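Roughly, whole-model capture with the 1.10 API looks like this (just a sketch; model, z, pos and batch stand in for your model and example inputs, and capture requires static shapes and reusing the same input tensors):

import torch

# warm up on a side stream before capture, as recommended in the PyTorch docs
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(z, pos, batch)
torch.cuda.current_stream().wait_stream(s)

# capture a single forward pass into a graph
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out = model(z, pos, batch)

# to run it again: copy new data into z, pos and batch in place, then
g.replay()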

Otherwise, I saw a decent inference speedup when using PyTorch's JIT compiler. It is currently a bit complicated to use due to a bug in PyTorch and because TorchScript typing was lacking some features (which were added in PyTorch 1.10, though). I wrote a workaround for the PyTorch bug in the jit_trace branch.

In order to use this you will have to replace the PredictionNetwork in the model class with a JIT-traced version, like so:

model = create_model(...)
model.network = torch.jit.trace(model.network, (z, pos, batch))

where z, pos and batch are example inputs used to JIT compile the model. Unfortunately this is not compatible with models created using the main branch.

PhilippThoelke avatar Nov 08 '21 23:11 PhilippThoelke

I tried both of those, but they made little difference. Trying to wrap the call to propagate() in a CUDA graph produces an exception, so I couldn't do that. But I was able to create graphs for other parts of the calculation, like computing the embedding, and it still left large gaps between the kernels.

peastman avatar Nov 08 '21 23:11 peastman

To keep the GPU code simpler, I'd like to minimize the number of options that affect the code. TorchMD_ET has a lot of options. Is it reasonable to assume the following?

  • activation and attn_activation are "silu"
  • rbf_type is "expnorm"
  • cutoff_lower is 0
  • distance_influence is "both"

peastman avatar Nov 09 '21 23:11 peastman

As I work through the code, I'm finding some differences between how it's described in the paper and how it's implemented. I assume I should match the behavior of the existing code, but I just want to confirm that's really what's intended. Here are a couple of things I've found so far.

In ExpNormalSmearing, it multiplies by a variable called alpha that isn't mentioned in the paper. It's hardcoded to have the seemingly arbitrary value 5/(cutoff_upper-cutoff_lower). If you happen to be using the default values for the parameters (upper cutoff is 5, lower cutoff is 0), then it equals 1 and doesn't affect anything. For any nondefault values, it changes the shape of the radial basis functions. What is this for? A consequence is that the means are (I believe) initialized incorrectly. They don't actually cover the range you intend them to cover.

The code applies layer normalization at the start of each update layer, and again at the final output. This isn't present in the version described in the paper.

peastman avatar Nov 10 '21 00:11 peastman

To keep the GPU code simpler, I'd like to minimize the number of options that affect the code. TorchMD_ET has a lot of options. Is it reasonable to assume the following?

  • activation and attn_activation are "silu"
    • yes
  • rbf_type is "expnorm"
    • yes
  • cutoff_lower is 0
    • not always. We have an application where setting a lower cutoff > 0 is beneficial.
  • distance_influence is "both"
    • yes

In ExpNormalSmearing, it multiplies by a variable called alpha that isn't mentioned in the paper. It's hardcoded to have the seemingly arbitrary value 5/(cutoff_upper-cutoff_lower). If you happen to be using the default values for the parameters (upper cutoff is 5, lower cutoff is 0), then it equals 1 and doesn't affect anything. For any nondefault values, it changes the shape of the radial basis functions. What is this for? A consequence is that the means are (I believe) initialized incorrectly. They don't actually cover the range you intend them to cover.

I agree that the 5 seems a bit random here. The alpha parameter distributes the radial basis functions for nondefault cutoff values according to the distribution proposed in PhysNet. See #28 for a discussion on this.
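For reference, this is roughly the functional form in question (a paraphrase of ExpNormalSmearing that omits the cutoff envelope; treat the details as approximate):

import torch

def expnorm_rbf(dist, means, betas, cutoff_lower, cutoff_upper):
    # the factor under discussion; it equals 1 only for the default cutoffs (0, 5)
    alpha = 5.0 / (cutoff_upper - cutoff_lower)
    # with this alpha, exp(alpha * (cutoff_lower - dist)) always spans [exp(-5), 1]
    # over the cutoff range, independent of the actual cutoff values
    return torch.exp(-betas * (torch.exp(alpha * (cutoff_lower - dist)) - means) ** 2)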

The code applies layer normalization at the start of each update layer, and again at the final output. This isn't present in the version described in the paper.

Good point, this should be included in the paper!

PhilippThoelke avatar Nov 10 '21 08:11 PhilippThoelke

Thanks! I'm a bit confused about the role of the lower cutoff. The CosineCutoff function goes to 1 at the lower cutoff distance, and then abruptly jumps to zero. Doesn't that discontinuity cause problems? And if you have trainable RBFs (the default), there's nothing to prevent basis functions from moving to places where they significantly overlap the discontinuity.

peastman avatar Nov 10 '21 17:11 peastman

Where did you find the discontinuity? This is the cutoff function with lower cutoff 1 and upper cutoff 5: [figure: plot of the cosine cutoff function]

In case you were looking at the exponential normals with lower cutoff > 0: all basis functions are just shifted by the lower cutoff, so there is no discontinuity there either; the basis functions smoothly go to zero at the lower cutoff.

PhilippThoelke avatar Nov 11 '21 10:11 PhilippThoelke

Perhaps I was misreading the code. Is there a reference describing how the lower cutoff is supposed to be implemented? The paper only describes an upper cutoff.

peastman avatar Nov 11 '21 16:11 peastman

The cutoff function with lower cutoff > 0 is implemented as shown here: https://github.com/compsciencelab/torchmd-net/blob/28fdcb544af651d95c7e0e1ba1ac8babc21e9d65/torchmdnet/models/utils.py#L170-L184

It is not relevant for the paper as we never use a lower cutoff there.
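A sketch of that piecewise form, in case it is easier to read than the linked code (approximate, not copied verbatim):

import math
import torch

def cosine_cutoff(dist, cutoff_lower, cutoff_upper):
    if cutoff_lower > 0:
        # a window that rises smoothly from 0 at the lower cutoff, peaks at 1
        # halfway between the cutoffs, and falls back to 0 at the upper cutoff
        value = 0.5 * (torch.cos(math.pi * (2.0 * (dist - cutoff_lower)
                                            / (cutoff_upper - cutoff_lower) + 1.0)) + 1.0)
        return value * (dist > cutoff_lower).float() * (dist < cutoff_upper).float()
    # without a lower cutoff: 1 at zero distance, decaying smoothly to 0 at the upper cutoff
    value = 0.5 * (torch.cos(math.pi * dist / cutoff_upper) + 1.0)
    return value * (dist < cutoff_upper).float()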

PhilippThoelke avatar Nov 11 '21 18:11 PhilippThoelke

Thanks!

peastman avatar Nov 11 '21 19:11 peastman

@peastman you are working just on the optimization of the equivariant transformer, aren't you?

We also need the graph network, which is effectively the same architecture as SchNet. @PhilippThoelke correct? So, I just need to implement the PyTorch wrapper for https://github.com/openmm/NNPOps/tree/master/src/schnet.

raimis avatar Nov 15 '21 13:11 raimis

We also need the graph network, which is effectively the same architecture as SchNet.

Depending on some arguments, yes. The following corresponds to the original SchNet architecture:

TorchMD_GN(rbf_type="gauss", trainable_rbf=False, activation="ssp", neighbor_embedding=False)

The default arguments in the TorchMD_GN class correspond to an improved set of hyperparameters.

PhilippThoelke avatar Nov 15 '21 14:11 PhilippThoelke

Out of curiosity, when you said it sometimes helps to set a lower cutoff, what's the reason for that? Does it actually come from changing the functional form? Or is it just that you don't waste computation by putting RBFs at short distances where they'll never be used? If the latter, could you get the same benefit by changing only how you initialize the parameters?

peastman avatar Nov 16 '21 19:11 peastman

At close distances there is no data to train on. Having a lower cutoff produces nicer models.

giadefa avatar Nov 16 '21 19:11 giadefa

In that case, just initializing the parameters so you don't put RBFs at short distances should be sufficient. That would save computation time (the version with upper and lower cutoffs takes more operations) and make the code simpler.

peastman avatar Nov 16 '21 19:11 peastman

Any thoughts on this question? Is there actually a reason to change the functional form, or is it really just a matter of not putting RBFs at short distances?

I found that the lower cutoff also changes the model in a significant, unrelated way. When accumulating interactions, you multiply the attention weight by the cutoff function.

https://github.com/torchmd/torchmd-net/blob/57b4d51f222f3f66405f6b4c71511099144ebd19/torchmdnet/models/torchmd_et.py#L308

Without a lower cutoff, the cutoff function goes to 1 at short distances, which allows self interactions: you get contributions to the sum for each atom interacting with itself. With a lower cutoff, the cutoff function goes to 0, which eliminates self interactions. Did you really intend it to have that effect? If not, which behavior did you intend (with or without self interactions)?
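A small illustration of the difference (a simplified scalar version of my reading of CosineCutoff, not the actual code):

import math

def cutoff(d, lower, upper):
    if lower > 0:
        if not (lower < d < upper):
            return 0.0
        return 0.5 * (math.cos(math.pi * (2 * (d - lower) / (upper - lower) + 1)) + 1)
    return 0.5 * (math.cos(math.pi * d / upper) + 1) if d < upper else 0.0

print(cutoff(0.0, 0.0, 5.0))  # 1.0 -> a self interaction (d_ii = 0) contributes fully
print(cutoff(0.0, 1.0, 5.0))  # 0.0 -> the same self interaction is masked out entirely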

peastman avatar Nov 30 '21 21:11 peastman

We actually start distributing the RBFs only beyond the lower cutoff in GaussianSmearing and ExpNormalSmearing. This does not require changing the functional form and in fact we don't even use the cutoff function with a lower cutoff in these two classes.

The change in functional form is important where we want to discard interactions between atoms below the lower cutoff, as in the message passing step of the graph network: https://github.com/torchmd/torchmd-net/blob/57b4d51f222f3f66405f6b4c71511099144ebd19/torchmdnet/models/torchmd_gn.py#L247-L248. Here it is not enough to rely on the RBF distribution, as the filter-generating network (self.net) might produce non-zero entries in the filter even if all of its inputs are 0.

Regarding the ET, however, it is not intended that self interactions disappear when the lower cutoff is larger than 0. The model should definitely have self interactions. Maybe we could explicitly allow interactions if the distance r_ij == 0.

PhilippThoelke avatar Dec 01 '21 16:12 PhilippThoelke

It sounds like you're saying the lower cutoff should only be applied for certain purposes, not others? Here are all the ways it gets used in the ET model:

  • In the Distance class, to mask out any pairs (including self interactions) whose distance is less than the cutoff. This leads to them being ignored when computing embeddings and when accumulating interactions.
  • In ExpNormalSmearing to modify the value of alpha
  • In ExpNormalSmearing to modify where the RBFs initially get placed (but they can move during training)
  • To modify the functional form of computing the RBF values
  • In NeighborEmbedding to alter the functional form while summing over neighbors
  • In EquivariantMultiHeadAttention to alter the attention weights

Do you really want it to have all of these effects, or should it be restricted to a subset of them?

I still don't understand the reason for having the lower cutoff in the first place. @giadefa said

At close distances there is no data to train on. Having a lower cutoff produces nicer models.

That doesn't make sense to me. If you try to apply a model to cases far outside the training data, you probably won't get good results. That includes having atoms much closer together than ever happened in the training data. But having the model pretend those close contacts don't exist is not going to help in any way. When two atoms come very close together, that has a huge effect on the energy. Forcing it to instead have no effect at all won't produce anything the least bit realistic.

The correct solution, it seems to me, is just to make sure your training data includes close contacts so the model can accurately learn that they produce huge energies.

peastman avatar Dec 01 '21 17:12 peastman

We do not want to throw away self interactions. I believe the locations where this is currently happening with lower cutoffs larger than 0 are your 1st, 5th and 6th points. However, it is still important that we restrict the attention mechanism to interactions between the lower and upper cutoff inside EquivariantMultiHeadAttention if we want to keep the lower cutoff as an option.

I personally haven't used the lower cutoff so far, but as far as I know it has helped in other projects, although those used the graph network and not the ET. Maybe it makes sense to remove the lower cutoff option for the ET (or altogether)? @giadefa

The correct solution, it seems to me, is just to make sure your training data includes close contacts so the model can accurately learn that they produce huge energies.

I completely agree, but depending on the type of application it might be very difficult to obtain data that includes close contacts (e.g. existing MD data, which is too expensive to rebuild/augment with close contacts).

When two atoms come very close together, that has a huge effect on the energy. Forcing it to instead have no effect at all won't produce anything the least bit realistic.

That is only the case if the neural network potential is the only component in the energy calculation. Poorly sampled regions, like close contacts in MD, can be handled by a more robust but less accurate prior model.

PhilippThoelke avatar Dec 01 '21 18:12 PhilippThoelke

So when two beads come together, the NN energy can be essentially anything, and so can the forces: they can be strongly attractive at 1 Å and strongly repulsive at 1.2 Å. Ideally one would like to regularize the forces, but that involves the gradient of the gradient of the energy in the loss. A lower cutoff allows for zero forces in that region, and it is the only solution we found for coarse-graining simulations.

For quantum data the lower cutoff is less important, so you might just ignore it if that is simpler. We only use it for coarse-graining.

giadefa avatar Dec 01 '21 18:12 giadefa

For me it is fine to remove the lower cutoff for ET

giadefa avatar Dec 01 '21 18:12 giadefa

Poorly sampled regions, like close contacts in MD, can be handled by a more robust but less accurate prior model.

Perhaps a way of handling that would be to add a regularization term encouraging the NN energy contribution to go to zero at short distances? Or just create some artificial training samples that represent very close contacts and specify the energy as equal to the energy of the prior model?
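To sketch what I mean by the first option (entirely hypothetical code, including the names and weighting; nothing like this exists in TorchMD-Net):

import torch

def close_contact_penalty(nn_energy, pair_distances, r_low=1.0, weight=1e-2):
    # hypothetical regularizer: if a configuration contains any pair closer than
    # r_low, penalize the squared NN energy contribution so it is pushed toward
    # zero in that regime and the prior model takes over
    if (pair_distances < r_low).any():
        return weight * nn_energy.pow(2).mean()
    return nn_energy.new_zeros(())

# added on top of the usual energy/force loss during training:
# loss = loss + close_contact_penalty(nn_energy, pair_distances)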

peastman avatar Dec 01 '21 18:12 peastman

We tried adding training samples at low distances, but it did not work because the NN is many-body. You can fix the two-body interactions, but then the three-body interactions go wild.

giadefa avatar Dec 01 '21 18:12 giadefa