openmm-ml

Add support for NequIP models

Open sef43 opened this issue 1 year ago • 22 comments

This PR adds support for NequIP models to openmm-ml. There are no pre-trained models available, but the model framework is well defined, so this will allow users to run their own trained NequIP models in OpenMM simulations.

It also adds code to compute neighbor lists with PyTorch, which will be used for MACE models too. (The NNPOps neighbor list can be added later.)
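The PR's PyTorch neighbor-list code is not reproduced in this thread. As a rough illustration of what such code computes, here is a brute-force O(N²) neighbor list in NumPy using the minimum-image convention. It assumes an orthorhombic box given by its three edge lengths and a cutoff smaller than half the shortest edge; the function name `neighbor_pairs` is hypothetical and not part of the openmm-ml API.

```python
import numpy as np


def neighbor_pairs(positions, cutoff, box_lengths=None):
    """Return (pairs, distances) for all i < j closer than `cutoff`.

    Hypothetical helper for illustration only.
    positions: (N, 3) coordinates; box_lengths: optional (3,) edge lengths
    of an orthorhombic periodic box (cutoff must be < half the shortest edge).
    """
    pos = np.asarray(positions, dtype=float)
    # All pairwise displacement vectors, shape (N, N, 3).
    delta = pos[:, None, :] - pos[None, :, :]
    if box_lengths is not None:
        box = np.asarray(box_lengths, dtype=float)
        # Minimum-image convention: shift each component by a whole box length.
        delta -= box * np.round(delta / box)
    dist = np.linalg.norm(delta, axis=-1)
    # Keep each unordered pair once: strict upper triangle excludes i == j.
    i, j = np.nonzero(np.triu(dist < cutoff, k=1))
    return np.stack([i, j], axis=1), dist[i, j]


pairs, distances = neighbor_pairs(
    [[0.0, 0.0, 0.0], [0.9, 0.0, 0.0], [5.0, 0.0, 0.0]],
    cutoff=1.0,
    box_lengths=[5.5, 5.5, 5.5],
)
print(pairs, distances)
```

With the periodic box, the first and third atoms are neighbors across the boundary (0.5 apart); without `box_lengths` only the first pair is found. A production version would use a cell list or a library kernel (such as the NNPOps one mentioned above) instead of the O(N²) distance matrix.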

Addresses #48 and see https://github.com/mir-group/nequip/issues/288 for further discussion.

TODO: Add tests. I'm not sure how to do this cleanly in CI, since NequIP needs to be installed via pip.
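One common way to handle an optional, pip-only dependency is to skip the NequIP-specific tests when the package is absent, and add a `pip install nequip` step only to the CI job that should run them. A minimal sketch with the standard-library `unittest` module follows; the test class and its contents are hypothetical placeholders. With pytest, `pytest.importorskip("nequip")` achieves the same thing.

```python
import importlib.util
import unittest


def nequip_available() -> bool:
    """True if the `nequip` package can be found in this environment."""
    return importlib.util.find_spec("nequip") is not None


@unittest.skipUnless(nequip_available(), "NequIP is not installed (pip install nequip)")
class TestNequIPPotential(unittest.TestCase):
    """Hypothetical placeholder for NequIP-specific openmm-ml tests."""

    def test_import(self):
        # Only reached when the package exists; real tests would build an
        # MLPotential from a small deployed model here.
        import nequip  # noqa: F401
```

Tests guarded this way report as skipped rather than failed on CI runners that install only the conda dependencies.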

sef43, Oct 04 '23

Can we train a NequIP model on SPICE and make it usable through openmm-ml?

jchodera, Nov 11 '23

Hello,

Has there been any further progress on this? I have used NequIP in LAMMPS, but I would like to use OpenMM instead because it is more compatible with the enhanced sampling packages that I use.

I have tried running simulations with a NequIP potential using openmm-ml in its current state; however, it is significantly slower than LAMMPS. Both simulations run on a single GPU, although in LAMMPS I also use 32 CPU threads and Kokkos.

I am not sure whether I am running openmm-ml incorrectly, but currently it is unusable for my rather simple system of 645 atoms. Is it expected to be this slow for a system of this size in its current state?

I can provide further information if needed. Thank you so much in advance!

Best, Sam

svarner9, Apr 10 '24

@svarner9, could you try the current implementation available here? It uses the NNPOps neighbor list, so I anticipate it might be slightly faster for a system of the size you're working with. You can create the MLPotential using something along these lines:

potential = MLPotential('nequip', modelPath='model.pth', lengthScale=0.1, energyScale=4.184)
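`lengthScale` and `energyScale` are the factors that convert the model's native units to OpenMM's nm and kJ/mol. Judging from the examples in this thread, `lengthScale=0.1` corresponds to a model trained in Å, and `energyScale` is 4.184 for kcal/mol or about 96.485 for eV. The sketch below derives these factors from exact SI constants; the dictionary names are illustrative, not part of openmm-ml.

```python
# Exact SI values (2019 redefinition of the SI).
ELEMENTARY_CHARGE = 1.602176634e-19  # C, so 1 eV = 1.602176634e-19 J
AVOGADRO = 6.02214076e23             # 1/mol

# Multiply a model's native unit by these to get nm and kJ/mol.
LENGTH_SCALE = {"angstrom": 0.1, "nm": 1.0}
ENERGY_SCALE = {
    "kJ/mol": 1.0,
    "kcal/mol": 4.184,                            # thermochemical calorie
    "eV": ELEMENTARY_CHARGE * AVOGADRO / 1000.0,  # J/mol -> kJ/mol, about 96.485
}

for unit_name, factor in ENERGY_SCALE.items():
    print(f"energyScale for a model in {unit_name}: {factor:.6f}")
```

For example, a model trained in Å and eV would use `lengthScale=LENGTH_SCALE["angstrom"]` and `energyScale=ENERGY_SCALE["eV"]`.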

What speed-up did you observe in your LAMMPS simulations compared to OpenMM/OpenMM-ML?

JMorado, Apr 23 '24

For the record, here is a comparison between the energies I get with this OpenMM-ML interface and with the ASE-based NequIPCalculator. The script I'm using to calculate the energies is the following:

import openmm as mm
import openmm.app as app
import openmm.unit as unit
from ase import Atoms
from nequip.ase.nequip_calculator import NequIPCalculator
from openmmml import MLPotential

lengthScale = 0.1  # Angstrom to nm
energyScale = 96.4853075  # eV to kJ/mol
model = "si-deployed.pth"
pdb_file = "si.pdb"

# Calculate the energy using a NequIPCalculator
pdb = app.PDBFile(pdb_file)
calculator = NequIPCalculator.from_deployed_model(model)
symbols = [atom.element.symbol for atom in pdb.topology.atoms()]
positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.angstrom)
box = pdb.topology.getPeriodicBoxVectors()  # None if the PDB has no unit cell
cell = box.value_in_unit(unit.angstrom) if box is not None else None
atoms = Atoms(symbols, positions=positions, cell=cell, pbc=cell is not None)
calculator.calculate(atoms)
pot_energy = calculator.get_potential_energy()  # in eV
print("NequIPCalculator energy: {}".format(pot_energy * energyScale))

# Calculate the energy using OpenMM
potential = MLPotential(
    "nequip",
    modelPath=model,
    lengthScale=lengthScale,
    energyScale=energyScale,
)

system = potential.createSystem(pdb.topology)
integrator = mm.LangevinIntegrator(
    300 * unit.kelvin, 1.0 / unit.picoseconds, 1.0 * unit.femtosecond
)
simulation = app.Simulation(pdb.topology, system, integrator)
simulation.context.setPositions(pdb.positions)

print(
    "OpenMM-ML energy: {}".format(
        simulation.context.getState(getEnergy=True)
        .getPotentialEnergy()
        .value_in_unit(unit.kilojoules_per_mole)
    )
)

Toluene (NequIP, No PBC)

NequIPCalculator energy: -710491.18525 
OpenMM-ML energy: -710491.1875

Si (Allegro, PBC)

NequIPCalculator energy: -802582.8787801563
OpenMM-ML energy: -802582.875

Values are in kJ/mol. They seem to disagree from the 3rd decimal place onwards. I have checked, and the same input data is being passed to the model.

input_data.zip

JMorado, Apr 24 '24

They agree to eight significant digits, which is the accuracy of single precision. Do the forces have similar agreement? If so, I think it's fine.
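To make the single-precision point concrete: at magnitudes between 2¹⁹ and 2²⁰ (about 5.2 × 10⁵ to 1.05 × 10⁶), adjacent float32 values are 2⁻⁴ = 0.0625 apart. If the discrepancy is only float32 rounding of the final energy, rounding each NequIPCalculator value to float32 should reproduce the OpenMM-ML value. A small NumPy check using the numbers posted above (the helper name is illustrative):

```python
import numpy as np

# Energies in kJ/mol as posted above: (NequIPCalculator, OpenMM-ML).
ENERGIES = {
    "toluene": (-710491.18525, -710491.1875),
    "Si": (-802582.8787801563, -802582.875),
}


def float32_comparison(reference):
    """Return (float32 spacing at this magnitude, reference rounded to float32)."""
    spacing = float(np.spacing(np.float32(abs(reference))))
    return spacing, float(np.float32(reference))


for name, (ase_energy, omm_energy) in ENERGIES.items():
    spacing, rounded = float32_comparison(ase_energy)
    print(f"{name}: float32 spacing {spacing} kJ/mol, "
          f"reference as float32 {rounded}, matches OpenMM-ML: {rounded == omm_energy}")
```

Both pairs come out identical after rounding, which is consistent with the energies differing only by float32 rounding.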

Is there an option to predict a formation energy instead of total energy, or to subtract off per-atom mean energies? That leads to a much smaller output value and better accuracy.
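The benefit of subtracting per-atom reference energies shows up directly in the float32 spacing. The sketch below uses made-up per-element reference energies (the `REFERENCE_KJ_MOL` values are purely hypothetical, not taken from any trained model) to show how much the magnitude, and hence the float32 rounding step, shrinks for toluene (C7H8).

```python
from collections import Counter

import numpy as np


def subtract_reference_energies(total_energy, symbols, reference):
    """Total energy minus a per-element reference energy for every atom.

    `reference` maps element symbol -> energy per atom, in the same units as
    `total_energy`. Done in float64 so the subtraction itself loses nothing.
    """
    counts = Counter(symbols)
    offset = sum(n * reference[element] for element, n in counts.items())
    return float(total_energy) - offset


# Hypothetical reference energies in kJ/mol, chosen only for illustration.
REFERENCE_KJ_MOL = {"C": -99000.0, "H": -2060.0}
toluene = ["C"] * 7 + ["H"] * 8
total = -710491.18525  # NequIPCalculator total energy for toluene, from above

residual = subtract_reference_energies(total, toluene, REFERENCE_KJ_MOL)
for label, value in [("total", total), ("residual", residual)]:
    step = float(np.spacing(np.float32(abs(value))))
    print(f"{label}: {value:.5f} kJ/mol, float32 spacing {step:.2e} kJ/mol")
```

Because the offset depends only on composition, it leaves forces unchanged. The accuracy gain only materializes if the model itself outputs the smaller number: subtracting after a float32 total has already been formed cannot recover the lost digits.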

peastman, Apr 24 '24

That's true, thanks for pointing that out. Regarding the forces, this is what I get (values in kJ/mol/nm):

Toluene
NequIPCalculator forces:
 [[  318.61404   -1153.1539      783.135    ]
 [  379.53235     455.95953    -261.0112   ]
 [ 1114.3433     1182.2357      163.06062  ]
 [ -266.91818    1380.0348      146.71664  ]
 [ -857.644       244.97173     -12.799915 ]
 [ -966.69995   -1690.4469     -168.81912  ]
 [  257.23373    -497.74158     -25.056156 ]
 [  224.72466     211.90524    -195.86229  ]
 [  495.38205     300.3262     -238.29967  ]
 [   71.34846     423.93893     -45.857525 ]
 [ -183.62227     130.88791      -7.647993 ]
 [ -107.9876     -119.39723     -18.707212 ]
 [ -366.71957    -163.53647     -29.577057 ]
 [ -106.718575    -82.88132     -16.851692 ]
 [   -4.8681235  -623.1028      -72.42232  ]]

OpenMM-ML forces:
 [[  318.61157227 -1153.1529541    783.13665771]
 [  379.53060913   455.96350098  -261.01147461]
 [ 1114.33947754  1182.234375     163.06036377]
 [ -266.9078064   1380.03112793   146.71669006]
 [ -857.64770508   244.96878052   -12.80025387]
 [ -966.69647217 -1690.44335938  -168.81869507]
 [  257.23568726  -497.74768066   -25.05669785]
 [  224.72613525   211.90498352  -195.86283875]
 [  495.38165283   300.32611084  -238.29972839]
 [   71.35070801   423.93704224   -45.85838318]
 [ -183.62208557   130.88764954    -7.64784527]
 [ -107.99130249  -119.38985443   -18.70673943]
 [ -366.71969604  -163.53736877   -29.57713509]
 [ -106.72071075   -82.88322449   -16.85195541]
 [   -4.87014723  -623.09869385   -72.42201233]]

Difference:
[[ 2.47192383e-03 -9.76562500e-04 -1.64794922e-03]
 [ 1.73950195e-03 -3.96728516e-03  2.74658203e-04]
 [ 3.78417969e-03  1.34277344e-03  2.59399414e-04]
 [-1.03759766e-02  3.66210938e-03 -4.57763672e-05]
 [ 3.72314453e-03  2.94494629e-03  3.38554382e-04]
 [-3.47900391e-03 -3.54003906e-03 -4.27246094e-04]
 [-1.95312500e-03  6.10351562e-03  5.41687012e-04]
 [-1.48010254e-03  2.59399414e-04  5.49316406e-04]
 [ 3.96728516e-04  9.15527344e-05  6.10351562e-05]
 [-2.25067139e-03  1.89208984e-03  8.58306885e-04]
 [-1.83105469e-04  2.59399414e-04 -1.47819519e-04]
 [ 3.70025635e-03 -7.37762451e-03 -4.73022461e-04]
 [ 1.22070312e-04  9.00268555e-04  7.82012939e-05]
 [ 2.13623047e-03  1.90734863e-03  2.63214111e-04]
 [ 2.02369690e-03 -4.08935547e-03 -3.05175781e-04]]
Si
NequIPCalculator forces:
 [[ 67.79043   -28.809738   11.53376  ]
 [ -4.886628  -25.532581  -15.155592 ]
 [  1.4763676 -14.482292   19.402218 ]
 [ 17.322676  -69.07709   -25.55506  ]
 [ 19.82048    33.217815  -20.82479  ]
 [-45.851116   17.135185   -9.709675 ]
 [  4.5035777 -24.29101   -11.166489 ]
 [  1.9369186  -1.6387768  -1.4257112]
 [ 13.536634  -22.138472   31.538412 ]
 [-14.554014   11.717597  -19.121832 ]
 [-15.295826   35.69589    -3.7766256]
 [ 27.920511  -61.219616  -35.173405 ]
 [ 31.38432    44.064106  -10.46437  ]
 [-15.414515   15.491039    6.3312597]
 [-10.714798   -2.390285    7.777393 ]
 [  9.10093    21.255102   -8.837459 ]
 [-12.911689   -5.2226615  65.043106 ]
 [ 35.767906   11.211081  -12.875771 ]
 [ 60.736385  -18.289862    3.730354 ]
 [ -8.458305    3.356562   -4.8178754]
 [-13.878986   18.876963   17.74003  ]
 [-24.694405  -32.99745    20.24441  ]
 [-34.090267  -15.701595   10.985336 ]
 [-16.814713  -11.162299   42.942413 ]
 [ -3.4840176  -5.062717  -13.725371 ]
 [-33.165398    5.9761963 -18.1375   ]
 [ 48.521038  -13.241893   17.688929 ]
 [-10.354681   -2.148305  -25.099829 ]
 [-37.356796   44.274803  -34.508373 ]
 [ 44.269165   -5.420011  -13.364778 ]
 [  2.4421818  96.35251    26.062864 ]
 [ -9.241314   -2.5163426 -11.17031  ]
 [101.58642     9.112973    6.6917353]
 [ 45.18976   -18.290195  -23.636248 ]
 [-68.591415  -39.544487   61.845932 ]
 [-37.37915   -17.486694   51.09612  ]
 [-19.860252    0.7089257 -34.855164 ]
 [-34.38795    43.093174  -24.385368 ]
 [-96.76243    33.55584   -40.637005 ]
 [ 25.740808  -18.035166  -53.522533 ]
 [-17.347645  -14.707738   20.247072 ]
 [-27.524693  -50.164726   47.998127 ]
 [ 78.78394     5.531705   24.59482  ]
 [ 42.468536   46.616627  -52.593685 ]
 [ -6.6523476  16.796291  -87.7328   ]
 [  3.4838438 -38.520428    8.806853 ]
 [  2.1332002  37.900658   39.479454 ]
 [ 36.56353   -38.588394   19.898565 ]
 [ 27.694416  -80.263596   15.579612 ]
 [ 35.44082    12.952968  -50.18059  ]
 [-49.62549    26.632977   29.234938 ]
 [ 49.63715   -61.17182    70.60961  ]
 [  7.865634  -16.822647  -13.332666 ]
 [-86.48338    88.669395  -49.876156 ]
 [ 47.17594    18.837576   -2.4321811]
 [ -6.342099   15.388432   21.146124 ]
 [ 38.588196   27.882034   34.625492 ]
 [-20.583471   14.237654   -1.0932204]
 [-23.4871     72.94298    -3.7524729]
 [-11.276121   39.70276   -30.83238  ]
 [-14.0973425 -10.143854  -37.7538   ]
 [  7.0738263 -61.706738  -17.025831 ]
 [-60.275654  -20.867666   60.231632 ]
 [-44.111538  -21.53065    25.44634  ]]

OpenMM-ML forces:
 [[ 67.79095459 -28.80946541  11.53389549]
 [ -4.88648033 -25.5328598  -15.15541744]
 [  1.47649062 -14.48234749  19.40254593]
 [ 17.32279015 -69.07706451 -25.55541039]
 [ 19.82018471  33.21785736 -20.82455444]
 [-45.85109329  17.13524628  -9.70964622]
 [  4.50365496 -24.29071236 -11.16631031]
 [  1.9371053   -1.63878345  -1.4259665 ]
 [ 13.53680134 -22.13811111  31.53835106]
 [-14.55395412  11.71772671 -19.12192345]
 [-15.29564571  35.69577789  -3.77657413]
 [ 27.92069817 -61.21953964 -35.17324066]
 [ 31.38408089  44.06418228 -10.46452713]
 [-15.41459084  15.49117565   6.33090544]
 [-10.71499538  -2.39058208   7.77688837]
 [  9.10089684  21.25516891  -8.83750343]
 [-12.91157341  -5.22240162  65.04360199]
 [ 35.76815414  11.2110014  -12.87581348]
 [ 60.73667526 -18.28949165   3.73050475]
 [ -8.45816994   3.35656095  -4.81774998]
 [-13.87901878  18.87675095  17.73980713]
 [-24.69457817 -32.99786758  20.24431419]
 [-34.09024811 -15.70182705  10.98557568]
 [-16.81512642 -11.16264534  42.94252014]
 [ -3.48382878  -5.06294775 -13.72527027]
 [-33.1651001    5.97644854 -18.13700104]
 [ 48.52108383 -13.24161148  17.68883514]
 [-10.35446072  -2.14804649 -25.0998497 ]
 [-37.35691833  44.27435303 -34.50836182]
 [ 44.26918411  -5.42039633 -13.36488342]
 [  2.44198561  96.35207367  26.06266022]
 [ -9.24124241  -2.51659489 -11.17035961]
 [101.58667755   9.11304665   6.69179535]
 [ 45.1896286  -18.29023361 -23.63647842]
 [-68.59138489 -39.54423523  61.84625626]
 [-37.37944412 -17.48635101  51.09622955]
 [-19.86007118   0.70900166 -34.85506821]
 [-34.38783264  43.09328079 -24.38536835]
 [-96.76257324  33.55570602 -40.63653564]
 [ 25.74025154 -18.03463173 -53.52237701]
 [-17.34791756 -14.70756245  20.24714851]
 [-27.52493286 -50.16454697  47.99803543]
 [ 78.78383636   5.53193188  24.59463501]
 [ 42.46936798  46.61732101 -52.59399033]
 [ -6.65289164  16.7964077  -87.73320007]
 [  3.48389125 -38.52040863   8.80679893]
 [  2.1330018   37.9006958   39.47941971]
 [ 36.56356049 -38.58841705  19.89873314]
 [ 27.69441986 -80.26367188  15.57967281]
 [ 35.44096756  12.95310497 -50.180439  ]
 [-49.62584305  26.63298416  29.23537064]
 [ 49.63778305 -61.17185593  70.609375  ]
 [  7.86521959 -16.82258034 -13.33259678]
 [-86.48326874  88.66952515 -49.87618256]
 [ 47.17575455  18.83693504  -2.43251538]
 [ -6.34226656  15.38792515  21.14606094]
 [ 38.58839035  27.88191795  34.62520218]
 [-20.58321762  14.23798943  -1.09339654]
 [-23.48657036  72.9430542   -3.75274563]
 [-11.27591801  39.70261765 -30.83240891]
 [-14.0974226  -10.14392471 -37.75414658]
 [  7.07348776 -61.70677948 -17.02573586]
 [-60.27626038 -20.86861038  60.23205185]
 [-44.11212921 -21.53063011  25.44634056]]

Difference:
[[-5.26428223e-04 -2.72750854e-04 -1.35421753e-04]
 [-1.47819519e-04  2.78472900e-04 -1.74522400e-04]
 [-1.23023987e-04  5.53131104e-05 -3.28063965e-04]
 [-1.14440918e-04 -2.28881836e-05  3.50952148e-04]
 [ 2.95639038e-04 -4.19616699e-05 -2.34603882e-04]
 [-2.28881836e-05 -6.10351562e-05 -2.86102295e-05]
 [-7.72476196e-05 -2.97546387e-04 -1.78337097e-04]
 [-1.86681747e-04  6.67572021e-06  2.55346298e-04]
 [-1.66893005e-04 -3.60488892e-04  6.10351562e-05]
 [-6.00814819e-05 -1.29699707e-04  9.15527344e-05]
 [-1.80244446e-04  1.10626221e-04 -5.14984131e-05]
 [-1.86920166e-04 -7.62939453e-05 -1.64031982e-04]
 [ 2.38418579e-04 -7.62939453e-05  1.57356262e-04]
 [ 7.62939453e-05 -1.36375427e-04  3.54290009e-04]
 [ 1.97410583e-04  2.97069550e-04  5.04493713e-04]
 [ 3.33786011e-05 -6.67572021e-05  4.48226929e-05]
 [-1.15394592e-04 -2.59876251e-04 -4.95910645e-04]
 [-2.47955322e-04  7.91549683e-05  4.29153442e-05]
 [-2.89916992e-04 -3.70025635e-04 -1.50680542e-04]
 [-1.35421753e-04  9.53674316e-07 -1.25408173e-04]
 [ 3.24249268e-05  2.11715698e-04  2.23159790e-04]
 [ 1.73568726e-04  4.15802002e-04  9.53674316e-05]
 [-1.90734863e-05  2.31742859e-04 -2.39372253e-04]
 [ 4.13894653e-04  3.46183777e-04 -1.06811523e-04]
 [-1.88827515e-04  2.30789185e-04 -1.01089478e-04]
 [-2.97546387e-04 -2.52246857e-04 -4.99725342e-04]
 [-4.57763672e-05 -2.81333923e-04  9.34600830e-05]
 [-2.20298767e-04 -2.58445740e-04  2.09808350e-05]
 [ 1.22070312e-04  4.50134277e-04 -1.14440918e-05]
 [-1.90734863e-05  3.85284424e-04  1.05857849e-04]
 [ 1.96218491e-04  4.34875488e-04  2.04086304e-04]
 [-7.15255737e-05  2.52246857e-04  4.95910645e-05]
 [-2.59399414e-04 -7.34329224e-05 -6.00814819e-05]
 [ 1.29699707e-04  3.81469727e-05  2.30789185e-04]
 [-3.05175781e-05 -2.51770020e-04 -3.24249268e-04]
 [ 2.93731689e-04 -3.43322754e-04 -1.10626221e-04]
 [-1.81198120e-04 -7.59363174e-05 -9.53674316e-05]
 [-1.18255615e-04 -1.06811523e-04  0.00000000e+00]
 [ 1.44958496e-04  1.33514404e-04 -4.69207764e-04]
 [ 5.56945801e-04 -5.34057617e-04 -1.56402588e-04]
 [ 2.72750854e-04 -1.75476074e-04 -7.62939453e-05]
 [ 2.40325928e-04 -1.79290771e-04  9.15527344e-05]
 [ 1.06811523e-04 -2.26974487e-04  1.85012817e-04]
 [-8.31604004e-04 -6.94274902e-04  3.05175781e-04]
 [ 5.44071198e-04 -1.16348267e-04  3.96728516e-04]
 [-4.74452972e-05 -1.90734863e-05  5.43594360e-05]
 [ 1.98364258e-04 -3.81469727e-05  3.43322754e-05]
 [-3.05175781e-05  2.28881836e-05 -1.67846680e-04]
 [-3.81469727e-06  7.62939453e-05 -6.10351562e-05]
 [-1.48773193e-04 -1.37329102e-04 -1.52587891e-04]
 [ 3.54766846e-04 -7.62939453e-06 -4.32968140e-04]
 [-6.33239746e-04  3.43322754e-05  2.36511230e-04]
 [ 4.14371490e-04 -6.67572021e-05 -6.96182251e-05]
 [-1.14440918e-04 -1.29699707e-04  2.67028809e-05]
 [ 1.86920166e-04  6.40869141e-04  3.34262848e-04]
 [ 1.67369843e-04  5.06401062e-04  6.29425049e-05]
 [-1.94549561e-04  1.16348267e-04  2.89916992e-04]
 [-2.53677368e-04 -3.35693359e-04  1.76191330e-04]
 [-5.30242920e-04 -7.62939453e-05  2.72750854e-04]
 [-2.03132629e-04  1.41143799e-04  2.86102295e-05]
 [ 8.01086426e-05  7.05718994e-05  3.47137451e-04]
 [ 3.38554382e-04  4.19616699e-05 -9.53674316e-05]
 [ 6.06536865e-04  9.44137573e-04 -4.19616699e-04]
 [ 5.91278076e-04 -1.90734863e-05  0.00000000e+00]]
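A compact way to summarize comparisons like these is the maximum absolute and relative force deviation. The NumPy helper below is a sketch (the name `force_deviation` is hypothetical); the usage lines take the first two toluene rows from the arrays above.

```python
import numpy as np


def force_deviation(reference, test):
    """Max absolute deviation, and that deviation relative to the largest force.

    Both arguments are (N, 3) arrays in the same units (here kJ/mol/nm).
    """
    reference = np.asarray(reference, dtype=float)
    test = np.asarray(test, dtype=float)
    max_abs = float(np.abs(reference - test).max())
    # Scale by the largest reference force component for a relative measure.
    max_rel = max_abs / float(np.abs(reference).max())
    return max_abs, max_rel


nequip_forces = [[318.61404, -1153.1539, 783.135],
                 [379.53235, 455.95953, -261.0112]]
openmm_forces = [[318.61157227, -1153.1529541, 783.13665771],
                 [379.53060913, 455.96350098, -261.01147461]]
max_abs, max_rel = force_deviation(nequip_forces, openmm_forces)
print(f"max |dF| = {max_abs:.2e} kJ/mol/nm, relative = {max_rel:.1e}")
```

Read off the full difference arrays above, the largest deviations are about 1.0e-2 kJ/mol/nm for toluene and 9.4e-4 kJ/mol/nm for Si, both below 1e-5 of the largest force component in each system.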

As far as I know, there's no option to predict an interaction energy or to subtract per-atom mean energies; the model predicts per-atom energies, which sum to the total energy.

JMorado, Apr 24 '24

This is ready from my side. If someone could take a look and review the changes, that would be great. Performance benchmarks on test models can be found here.

Many thanks!

JMorado, May 07 '24

@JMorado I went ahead and tested the version on the nequip branch, but I am unable to get it to run on a GPU. When I specify the potential and the platform in the following way,

potential = MLPotential("nequip",
                            modelPath='model.pth',
                            lengthScale=0.1,
                            energyScale=96.48)
...

plat = openmm.Platform.getPlatformByName("CUDA")
properties = {"Precision": "double", "DeviceIndex": "0",
              "UseBlockingSync": "false"}
simulation = app.Simulation(topology, system, integrator, plat, properties)

I get the following set of warnings and errors:

/home/svarner/miniconda3/envs/practicum/lib/python3.11/site-packages/torchani/aev.py:16: UserWarning: cuaev not installed
  warnings.warn("cuaev not installed")
/home/svarner/miniconda3/envs/practicum/lib/python3.11/site-packages/nequip/scripts/deploy.py:138: UserWarning: Models deployed before v0.6.0 don't contain information about their default_dtype or model_dtype; assuming the old default of float32 for both, but this might not be right if you had explicitly set default_dtype=float64.
  warnings.warn(
/home/svarner/miniconda3/envs/practicum/lib/python3.11/site-packages/nequip/utils/_global_options.py:59: UserWarning: !! Upstream issues in PyTorch versions >1.11 have been seen to cause unusual performance degredations on some CUDA systems that become worse over time; see https://github.com/mir-group/nequip/discussions/311. At present we *strongly* recommend the use of PyTorch 1.11 if using CUDA devices; while using other versions if you observe this problem, an unexpected lack of this problem, or other strange behavior, please post in the linked GitHub issue.
  warnings.warn(
/home/svarner/miniconda3/envs/practicum/lib/python3.11/site-packages/nequip/utils/_global_options.py:70: UserWarning: Setting the GLOBAL value for jit fusion strategy to `[('DYNAMIC', 3)]` which is different than the previous value of `[('STATIC', 2), ('DYNAMIC', 10)]`
  warnings.warn(
Traceback (most recent call last):
  File "/home/svarner/Practicum/sim.py", line 174, in <module>
    run(1,1,1,1,1)
  File "/home/svarner/Practicum/sim.py", line 145, in run
    simulation = app.Simulation(topology, system, integrator, plat, properties)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/svarner/miniconda3/envs/practicum/lib/python3.11/site-packages/openmm/app/simulation.py", line 106, in __init__
    self.context = mm.Context(self.system, self.integrator, platform, platformProperties)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/svarner/miniconda3/envs/practicum/lib/python3.11/site-packages/openmm/openmm.py", line 12171, in __init__
    _openmm.Context_swiginit(self, _openmm.new_Context(*args))
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^
openmm.OpenMMException: Specified a Platform for a Context which does not support all required kernels

Here is my mamba list:

# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                       2_gnu    conda-forge
ase                       3.22.1             pyhd8ed1ab_1    conda-forge
blinker                   1.8.2              pyhd8ed1ab_0    conda-forge
brotli                    1.1.0                hd590300_1    conda-forge
brotli-bin                1.1.0                hd590300_1    conda-forge
brotli-python             1.1.0           py311hb755f60_1    conda-forge
bzip2                     1.0.8                hd590300_5    conda-forge
c-ares                    1.28.1               hd590300_0    conda-forge
ca-certificates           2024.2.2             hbcca054_0    conda-forge
cached-property           1.5.2                hd8ed1ab_1    conda-forge
cached_property           1.5.2              pyha770c72_1    conda-forge
certifi                   2024.2.2           pyhd8ed1ab_0    conda-forge
charset-normalizer        3.3.2              pyhd8ed1ab_0    conda-forge
click                     8.1.7           unix_pyh707e725_0    conda-forge
contourpy                 1.2.1           py311h9547e67_0    conda-forge
cudatoolkit               11.5.2              hbdc67f6_13    conda-forge
cycler                    0.12.1             pyhd8ed1ab_0    conda-forge
e3nn                      0.5.1                    pypi_0    pypi
filelock                  3.14.0             pyhd8ed1ab_0    conda-forge
flask                     3.0.3              pyhd8ed1ab_0    conda-forge
fonttools                 4.51.0          py311h459d7ec_0    conda-forge
freetype                  2.12.1               h267a509_2    conda-forge
fsspec                    2024.3.1           pyhca7485f_0    conda-forge
gmp                       6.3.0                h59595ed_1    conda-forge
gmpy2                     2.1.5           py311he48d604_0    conda-forge
h5py                      3.11.0          nompi_py311hebc2b07_100    conda-forge
hdf5                      1.14.3          nompi_h4f84152_101    conda-forge
idna                      3.7                pyhd8ed1ab_0    conda-forge
importlib-metadata        7.1.0              pyha770c72_0    conda-forge
importlib_metadata        7.1.0                hd8ed1ab_0    conda-forge
itsdangerous              2.2.0              pyhd8ed1ab_0    conda-forge
jinja2                    3.1.3              pyhd8ed1ab_0    conda-forge
keyutils                  1.6.1                h166bdaf_0    conda-forge
kiwisolver                1.4.5           py311h9547e67_1    conda-forge
krb5                      1.21.2               h659d440_0    conda-forge
lark-parser               0.12.0             pyhd8ed1ab_0    conda-forge
lcms2                     2.16                 hb7c19ff_0    conda-forge
ld_impl_linux-64          2.40                 h55db66e_0    conda-forge
lerc                      4.0.0                h27087fc_0    conda-forge
libabseil                 20230802.1      cxx17_h59595ed_0    conda-forge
libaec                    1.1.3                h59595ed_0    conda-forge
libblas                   3.9.0           22_linux64_openblas    conda-forge
libbrotlicommon           1.1.0                hd590300_1    conda-forge
libbrotlidec              1.1.0                hd590300_1    conda-forge
libbrotlienc              1.1.0                hd590300_1    conda-forge
libcblas                  3.9.0           22_linux64_openblas    conda-forge
libcurl                   8.7.1                hca28451_0    conda-forge
libdeflate                1.20                 hd590300_0    conda-forge
libedit                   3.1.20191231         he28a2e2_2    conda-forge
libev                     4.33                 hd590300_2    conda-forge
libexpat                  2.6.2                h59595ed_0    conda-forge
libffi                    3.4.2                h7f98852_5    conda-forge
libgcc-ng                 13.2.0               h77fa898_7    conda-forge
libgfortran-ng            13.2.0               h69a702a_7    conda-forge
libgfortran5              13.2.0               hca663fb_7    conda-forge
libgomp                   13.2.0               h77fa898_7    conda-forge
libjpeg-turbo             3.0.0                hd590300_1    conda-forge
liblapack                 3.9.0           22_linux64_openblas    conda-forge
libnghttp2                1.58.0               h47da74e_1    conda-forge
libnsl                    2.0.1                hd590300_0    conda-forge
libopenblas               0.3.27          pthreads_h413a1c8_0    conda-forge
libpng                    1.6.43               h2797004_0    conda-forge
libprotobuf               4.25.1               hf27288f_2    conda-forge
libsqlite                 3.45.3               h2797004_0    conda-forge
libssh2                   1.11.0               h0841786_0    conda-forge
libstdcxx-ng              13.2.0               hc0a3c3a_7    conda-forge
libtiff                   4.6.0                h1dd3fc0_3    conda-forge
libtorch                  2.1.2           cpu_generic_ha017de0_3    conda-forge
libuuid                   2.38.1               h0b41bf4_0    conda-forge
libuv                     1.48.0               hd590300_0    conda-forge
libwebp-base              1.4.0                hd590300_0    conda-forge
libxcb                    1.15                 h0b41bf4_0    conda-forge
libxcrypt                 4.4.36               hd590300_1    conda-forge
libzlib                   1.2.13               hd590300_5    conda-forge
markupsafe                2.1.5           py311h459d7ec_0    conda-forge
matplotlib-base           3.8.4           py311h54ef318_0    conda-forge
mpc                       1.3.1                hfe3b2da_0    conda-forge
mpfr                      4.2.1                h9458935_1    conda-forge
mpmath                    1.3.0              pyhd8ed1ab_0    conda-forge
munkres                   1.1.4              pyh9f0ad1d_0    conda-forge
ncurses                   6.4.20240210         h59595ed_0    conda-forge
nequip                    0.6.0                    pypi_0    pypi
networkx                  3.3                pyhd8ed1ab_1    conda-forge
nnpops                    0.6             cpu_py311h7697b17_7    conda-forge
nomkl                     1.0                  h5ca1d4c_0    conda-forge
numpy                     1.26.4          py311h64a7726_0    conda-forge
ocl-icd                   2.3.2                hd590300_1    conda-forge
ocl-icd-system            1.0.0                         1    conda-forge
openjpeg                  2.5.2                h488ebb8_0    conda-forge
openmm                    8.1.1           py311h28d7ac7_1    conda-forge
openmm-torch              1.4             cpu_py311h446247e_4    conda-forge
openmmml                  1.1                      pypi_0    pypi
openssl                   3.3.0                hd590300_0    conda-forge
opt-einsum                3.3.0                    pypi_0    pypi
opt-einsum-fx             0.1.4                    pypi_0    pypi
packaging                 24.0               pyhd8ed1ab_0    conda-forge
pillow                    10.3.0          py311h18e6fac_0    conda-forge
pip                       24.0               pyhd8ed1ab_0    conda-forge
pthread-stubs             0.4               h36c2ea0_1001    conda-forge
pyparsing                 3.1.2              pyhd8ed1ab_0    conda-forge
pysocks                   1.7.1              pyha2e5f31_6    conda-forge
python                    3.11.9          hb806964_0_cpython    conda-forge
python-dateutil           2.9.0              pyhd8ed1ab_0    conda-forge
python_abi                3.11                    4_cp311    conda-forge
pytorch                   2.1.2           cpu_generic_py311h1584bb0_3    conda-forge
pyyaml                    6.0.1                    pypi_0    pypi
readline                  8.2                  h8228510_1    conda-forge
requests                  2.31.0             pyhd8ed1ab_0    conda-forge
scipy                     1.13.0          py311h517d4fd_1    conda-forge
setuptools                65.3.0             pyhd8ed1ab_1    conda-forge
setuptools-scm            6.3.2              pyhd8ed1ab_0    conda-forge
setuptools_scm            6.3.2                hd8ed1ab_0    conda-forge
six                       1.16.0             pyh6c4a22f_0    conda-forge
sleef                     3.5.1                h9b69904_2    conda-forge
sympy                     1.12            pypyh9d50eac_103    conda-forge
tk                        8.6.13          noxft_h4845f30_101    conda-forge
tomli                     2.0.1              pyhd8ed1ab_0    conda-forge
torch-ema                 0.3                      pypi_0    pypi
torch-runstats            0.2.0                    pypi_0    pypi
torchani                  2.2.4           cpu_py311h12a0d1d_3    conda-forge
tqdm                      4.66.4                   pypi_0    pypi
typing_extensions         4.11.0             pyha770c72_0    conda-forge
tzdata                    2024a                h0c530f3_0    conda-forge
urllib3                   2.2.1              pyhd8ed1ab_0    conda-forge
werkzeug                  3.0.3              pyhd8ed1ab_0    conda-forge
wheel                     0.43.0             pyhd8ed1ab_1    conda-forge
xorg-libxau               1.0.11               hd590300_0    conda-forge
xorg-libxdmcp             1.1.3                h7f98852_0    conda-forge
xz                        5.2.6                h166bdaf_0    conda-forge
zipp                      3.17.0             pyhd8ed1ab_0    conda-forge
zstd                      1.5.6                ha6fb4c9_0    conda-forge

If I don't specify a platform, the simulation runs, but extremely slowly, since it runs on the CPU.

Thank you so much in advance!

Best, Sam

svarner9, May 08 '24

That means a plugin couldn't be loaded. Try printing the value of Platform.getPluginLoadFailures(). It will tell you which ones failed, and what the errors were.

Usually it's because some library the plugin depends on couldn't be found, which can be fixed by adding the directory containing that library to LD_LIBRARY_PATH.
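A small diagnostic along those lines, listing the platforms OpenMM registered, any plugin load failures, and which OpenMM-Torch plugins were loaded, might look like this. `Platform.getNumPlatforms`, `Platform.getPlatform`, `Platform.getPluginLoadFailures`, `Platform.getDefaultPluginsDirectory` and `openmm.pluginLoadedLibNames` are part of OpenMM's Python API; the wrapper functions are hypothetical, and the `libOpenMMTorch<Platform>` file-name pattern is assumed from the plugin list printed in this thread.

```python
import os


def torch_plugins(lib_names):
    """Map platform suffix -> path for each loaded OpenMM-Torch plugin.

    Assumes plugin files are named like libOpenMMTorch<Platform>.so.
    """
    found = {}
    for path in lib_names:
        name = os.path.basename(path)
        if name.startswith("libOpenMMTorch"):
            suffix = name[len("libOpenMMTorch"):].split(".")[0]
            found[suffix] = path
    return found


def openmm_report():
    """Print platforms, plugin failures and Torch plugins; None without OpenMM."""
    try:
        import openmm
    except ImportError:
        return None
    platforms = [openmm.Platform.getPlatform(i).getName()
                 for i in range(openmm.Platform.getNumPlatforms())]
    report = {
        "plugins_dir": openmm.Platform.getDefaultPluginsDirectory(),
        "platforms": platforms,
        "load_failures": list(openmm.Platform.getPluginLoadFailures()),
        "torch_plugins": sorted(torch_plugins(openmm.pluginLoadedLibNames)),
    }
    for key, value in report.items():
        print(f"{key}: {value}")
    return report


openmm_report()
```

If the CUDA platform is listed but no CUDA variant of the Torch plugin was loaded, a `TorchForce` cannot run on CUDA even though there are no load failures. The plugin list posted in this thread contains `libOpenMMTorchOpenCL.so` and `libOpenMMTorchReference.so` but no CUDA counterpart.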

peastman, May 08 '24

Thank you for the quick response!

I tried that based on some previous replies of yours that I found. I ran the following:

print(pluginLoadedLibNames)
print(Platform.getPluginLoadFailures())

and the output was:

('/home/svarner/miniconda3/envs/practicum/lib/plugins/libOpenMMPME.so', '/home/svarner/miniconda3/envs/practicum/lib/plugins/libOpenMMCPU.so', '/home/svarner/miniconda3/envs/practicum/lib/plugins/libOpenMMCUDA.so', '/home/svarner/miniconda3/envs/practicum/lib/plugins/libOpenMMOpenCL.so', '/home/svarner/miniconda3/envs/practicum/lib/plugins/libOpenMMRPMDCUDA.so', '/home/svarner/miniconda3/envs/practicum/lib/plugins/libOpenMMDrudeCUDA.so', '/home/svarner/miniconda3/envs/practicum/lib/plugins/libOpenMMAmoebaCUDA.so', '/home/svarner/miniconda3/envs/practicum/lib/plugins/libOpenMMRPMDOpenCL.so', '/home/svarner/miniconda3/envs/practicum/lib/plugins/libOpenMMTorchOpenCL.so', '/home/svarner/miniconda3/envs/practicum/lib/plugins/libOpenMMDrudeOpenCL.so', '/home/svarner/miniconda3/envs/practicum/lib/plugins/libOpenMMAmoebaOpenCL.so', '/home/svarner/miniconda3/envs/practicum/lib/plugins/libOpenMMRPMDReference.so', '/home/svarner/miniconda3/envs/practicum/lib/plugins/libOpenMMTorchReference.so', '/home/svarner/miniconda3/envs/practicum/lib/plugins/libOpenMMDrudeReference.so', '/home/svarner/miniconda3/envs/practicum/lib/plugins/libOpenMMAmoebaReference.so')

()

The failures command returned an empty tuple.

Best, Sam

svarner9, May 08 '24

The versions of PyTorch and OpenMM-Torch you have installed are CPU only:

openmm-torch              1.4             cpu_py311h446247e_4    conda-forge
pytorch                   2.1.2           cpu_generic_py311h1584bb0_3    conda-forge

That might be because you have an older version of cudatoolkit:

cudatoolkit               11.5.2              hbdc67f6_13    conda-forge

If you upgrade it to 11.8, you might be able to get it to install the CUDA version of PyTorch. Conda installation issues like this tend to be frustrating and hard to figure out. They often depend on the precise order you install packages in.
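When scanning a long `conda list` / `mamba list` output like the one above, the build string is the quickest tell: the CPU-only conda-forge packages in this environment all carry a `cpu` prefix (e.g. `cpu_generic_py311h1584bb0_3`). A small sketch that flags such packages from the plain-text listing; the `cpu` prefix convention is inferred from this environment and is not a documented guarantee.

```python
def cpu_only_packages(conda_list_text):
    """Return names of packages whose build string starts with 'cpu'.

    Expects the plain-text output of `conda list` / `mamba list`
    (columns: name, version, build, channel); comment lines are skipped.
    """
    flagged = []
    for line in conda_list_text.splitlines():
        fields = line.split()
        if len(fields) < 3 or fields[0].startswith("#"):
            continue
        name, _version, build = fields[:3]
        if build.startswith("cpu"):
            flagged.append(name)
    return flagged


example = """\
# Name                    Version                   Build  Channel
cudatoolkit               11.5.2              hbdc67f6_13    conda-forge
nnpops                    0.6             cpu_py311h7697b17_7    conda-forge
openmm-torch              1.4             cpu_py311h446247e_4    conda-forge
pytorch                   2.1.2           cpu_generic_py311h1584bb0_3    conda-forge
"""
print(cpu_only_packages(example))
```

Run on the full list above, it would also flag `libtorch` and `torchani`.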

peastman avatar May 08 '24 04:05 peastman

Ahhh I see. Thank you!

I went ahead and uninstalled openmm-torch and PyTorch. I upgraded the cudatoolkit and then installed the CUDA version of PyTorch:

mamba install pytorch pytorch-cuda=11.8 -c pytorch -c nvidia

Installing openmm-torch downgraded PyTorch back to the CPU version, but then installing NNPOps upgraded it back to the CUDA version. I agree, conda installations are very frustrating.

It is working on the GPU now, but I'm only getting about 0.2 ns/day, whereas with LAMMPS I was getting 1.5 ns/day. To your knowledge, could any of the following warnings be causing the slowdown?

/home/svarner/miniconda3/envs/practicum/lib/python3.11/site-packages/nequip/scripts/deploy.py:138: UserWarning: Models deployed before v0.6.0 don't contain information about their default_dtype or model_dtype; assuming the old default of float32 for both, but this might not be right if you had explicitly set default_dtype=float64.
  warnings.warn(
/home/svarner/miniconda3/envs/practicum/lib/python3.11/site-packages/nequip/utils/_global_options.py:59: UserWarning: !! Upstream issues in PyTorch versions >1.11 have been seen to cause unusual performance degredations on some CUDA systems that become worse over time; see https://github.com/mir-group/nequip/discussions/311. At present we *strongly* recommend the use of PyTorch 1.11 if using CUDA devices; while using other versions if you observe this problem, an unexpected lack of this problem, or other strange behavior, please post in the linked GitHub issue.
  warnings.warn(
/home/svarner/miniconda3/envs/practicum/lib/python3.11/site-packages/nequip/utils/_global_options.py:70: UserWarning: Setting the GLOBAL value for jit fusion strategy to `[('DYNAMIC', 3)]` which is different than the previous value of `[('STATIC', 2), ('DYNAMIC', 10)]`
  warnings.warn(

I tried to install the packages in a way that would let me use PyTorch 1.11.0 (which, according to the warning, is the recommended version for NequIP on CUDA). However, as far as I can tell, there is no way to use PyTorch 1.11.0 with openmm-torch: every time I installed openmm-torch, it installed PyTorch 2.1.2.

This is the order in which I did everything:

mamba create -n env
mamba activate env
mamba install python=3.10
mamba install -c conda-forge openmm cudatoolkit=11.8
pip install git+https://github.com/mir-group/nequip@develop
pip install git+https://github.com/sef43/openmm-ml@nequip
mamba install pytorch=1.11 pytorch-cuda=11.8 -c pytorch -c nvidia
mamba install -c conda-forge openmm-torch nnpops

svarner9 avatar May 08 '24 05:05 svarner9

Many thanks for the thorough review, @peastman! Most of it should now be resolved.

Thanks for testing, @svarner9. I don't think the slow performance you're seeing is related to that warning, whose underlying issue is described here. You could check whether that issue is actually present by looking for a drop in performance over time. I ran some benchmarks on systems much smaller than yours and saw no decrease in performance over time, and the simulation speed was about what I would expect.
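One way to look for that kind of gradual slowdown is to time consecutive chunks of steps and compare the throughput of the first and last chunks. The helpers below are a minimal, hypothetical sketch; the 20% tolerance and the chunk sizes are arbitrary choices, not values from openmm-ml or NequIP.

```python
import time
from typing import Callable, List


def ns_per_day(n_steps: int, timestep_fs: float, wall_seconds: float) -> float:
    """Simulated nanoseconds per wall-clock day for a timed run."""
    simulated_ns = n_steps * timestep_fs * 1e-6  # 1 fs = 1e-6 ns
    return simulated_ns * 86400.0 / wall_seconds


def throughput_over_time(step: Callable[[int], None], n_chunks: int,
                         steps_per_chunk: int, timestep_fs: float) -> List[float]:
    """Run `n_chunks` consecutive chunks through `step` and return ns/day for each."""
    rates = []
    for _ in range(n_chunks):
        start = time.perf_counter()
        step(steps_per_chunk)
        # Guard against a zero interval on very coarse clocks.
        elapsed = max(time.perf_counter() - start, 1e-9)
        rates.append(ns_per_day(steps_per_chunk, timestep_fs, elapsed))
    return rates


def slows_down(rates: List[float], tolerance: float = 0.2) -> bool:
    """True if the last chunk ran more than `tolerance` (as a fraction) slower than the first."""
    return rates[-1] < (1.0 - tolerance) * rates[0]
```

With an OpenMM `Simulation`, `step` could be a small wrapper that calls `simulation.step(n)` followed by `simulation.context.getState(getEnergy=True)`, which should make the host wait for queued GPU work before the timer stops.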

If that is your baseline OpenMM performance, I wonder what could be causing it. Do you happen to remember the performance you were getting with the previous neighbor list? Does anyone have ideas about whether it's possible to improve performance here?

JMorado avatar May 08 '24 12:05 JMorado

Yes, many thanks @peastman for the help!

@JMorado I'm not sure, but I can think of a few things that might be the issue. I'm not an expert and haven't looked through the code, so these may be a bit naive:

  1. In LAMMPS, the NequIP pair style works with Kokkos, so there I was using 1 GPU + 32 CPU threads:
mpiexec -n 1 ./lmp -in in.script -k on g 1 t 32 -sf kk -pk kokkos newton on neigh full
  2. The LAMMPS NequIP pair style uses libtorch instead of PyTorch; could that make a difference?
  3. When the model is read in, is the cutoff set to the cutoff of the MLP? Most MLPs have very short cutoffs of around 5 Å, so if that cutoff is not being used for the neighbor list, that could be slowing things down. Is that something that should be set separately?
  4. I am getting this warning about JIT, but I am not sure whether it is important or could be affecting performance. I have seen the NequIP devs say that it can usually be safely ignored.
/home/svarner/miniconda3/envs/practicum/lib/python3.10/site-packages/nequip/utils/_global_options.py:70: UserWarning: Setting the GLOBAL value for jit fusion strategy to `[('DYNAMIC', 3)]` which is different than the previous value of `[('STATIC', 2), ('DYNAMIC', 10)]`
  warnings.warn(
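To make point 3 concrete: the cost of evaluating a local model grows with the number of atom pairs inside its cutoff, so the cutoff used to build the neighbor list matters. The sketch below is a deliberately naive O(N²) neighbor list with the minimum-image convention in an orthorhombic box, in OpenMM's nm units (5 Å = 0.5 nm). It is illustrative only and is not how openmm-ml or NNPOps build their lists. The "full" variant stores each pair in both directions, which is what `neigh full` in the LAMMPS command above requests and what message-passing models typically consume as directed edges.

```python
import itertools
import math
from typing import List, Sequence, Tuple

Vec3 = Tuple[float, float, float]


def minimum_image_distance(a: Vec3, b: Vec3, box: Vec3) -> float:
    """Distance between two points in an orthorhombic periodic box (minimum image)."""
    d2 = 0.0
    for ai, bi, length in zip(a, b, box):
        dx = ai - bi
        dx -= length * round(dx / length)  # wrap the separation into [-L/2, L/2]
        d2 += dx * dx
    return math.sqrt(d2)


def half_neighbor_list(positions: Sequence[Vec3], box: Vec3,
                       cutoff: float) -> List[Tuple[int, int]]:
    """Brute-force list of pairs (i, j), i < j, closer than `cutoff`.

    Only valid when the cutoff is at most half the shortest box edge.
    """
    if cutoff > 0.5 * min(box):
        raise ValueError("cutoff must not exceed half the shortest box edge")
    return [(i, j) for i, j in itertools.combinations(range(len(positions)), 2)
            if minimum_image_distance(positions[i], positions[j], box) < cutoff]


def full_neighbor_list(pairs: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
    """Directed edge list: every pair appears in both directions."""
    return pairs + [(j, i) for i, j in pairs]
```

Counting `len(half_neighbor_list(...))` for your system at 0.5 nm versus a larger cutoff gives a rough feel for how much extra work an oversized cutoff would add.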

Best, Sam

svarner9 avatar May 08 '24 16:05 svarner9

Is there an option to predict a formation energy instead of total energy, or to subtract off per-atom mean energies? That leads to a much smaller output value and better accuracy.

We actually do this internally, at least from develop onward: single-precision calculations are done in a more numerically favorable range, and the final energy scalings, shifts, and sums are done in float64, regardless of the precision of the weights. The final predictions you get should be float64; if they aren't, something might be off.
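The toy example below illustrates why this helps. It is not NequIP's code: float32 is emulated with `struct`, and the per-atom shift, scale, and outputs are invented numbers (645 atoms, to match the system discussed in this thread). Summing un-normalized per-atom energies of about -1000 in float32 leaves a total of about -645,000, where float32 values are spaced 0.0625 apart, so the result cannot be more precise than that. Keeping only the small, zero-centred network outputs in float32 and doing the scale, shift, and sum in float64 keeps the error far below that spacing.

```python
import struct
from typing import Sequence


def to_f32(x: float) -> float:
    """Round a Python float (IEEE float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def total_all_f32(scaled: Sequence[float], scale: float, shift: float) -> float:
    """Naive: un-normalize and accumulate every per-atom energy in float32."""
    total = 0.0
    for y in scaled:
        e = to_f32(to_f32(y) * scale + shift)  # per-atom energy rounded to float32
        total = to_f32(total + e)              # float32 accumulation
    return total


def total_mixed(scaled: Sequence[float], scale: float, shift: float) -> float:
    """Network outputs stay float32; the final scale, shift and sum are float64."""
    return sum(to_f32(y) * scale + shift for y in scaled)


# Invented, zero-centred per-atom network outputs for a 645-atom system.
outputs = [((i % 7) - 3) * 0.013 for i in range(645)]
SCALE, SHIFT = 0.5, -1000.123456  # hypothetical per-atom scale and mean energy
reference = sum(y * SCALE + SHIFT for y in outputs)
```

Comparing `abs(total_all_f32(outputs, SCALE, SHIFT) - reference)` with `abs(total_mixed(outputs, SCALE, SHIFT) - reference)` shows the naive error is several orders of magnitude larger.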

Regarding the reproducibility of energies between ASE and OpenMM: you can try turning off TF32 or, even better, use a fully float64 model (default_dtype: float64 and model_dtype: float64) as a sanity check that the difference is just numerics.

Linux-cpp-lisp avatar May 10 '24 17:05 Linux-cpp-lisp

@svarner9 a few questions on performance:

  • What are the actual LAMMPS vs. OpenMM numbers? I'm not sure where they were in this thread.
  • Yes, there will be additional Python overhead and doubled neighbor-list overhead in OpenMM, both of which are absent in pair_allegro. This should matter more for smaller models and smaller systems.
  • You can safely ignore that particular warning about the fusion strategy; it is just there to ensure that nequip never silently sets global state when called from someone else's program.

Linux-cpp-lisp avatar May 10 '24 17:05 Linux-cpp-lisp

There shouldn't be any overhead from Python. The model gets compiled to TorchScript, and the simulation is run by C++ code.

peastman avatar May 10 '24 17:05 peastman

Do you call TorchScript from Python here, or directly from C++? Not that I'd expect a round trip through Python to matter much; just curious.

Linux-cpp-lisp avatar May 10 '24 18:05 Linux-cpp-lisp

It's called directly from C++.

peastman avatar May 10 '24 18:05 peastman

@peastman @Linux-cpp-lisp, I've trained a model with these settings:

default_dtype: float64
model_dtype: float64
allow_tf32: true  

and the energy and force differences between ASE and OpenMM are indeed very small, on the order of $10^{-10}$, when combined with {"Precision": "double"} in the simulation settings.
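For a comparison like this, both sets of numbers have to be in the same units first. The helper below is a hypothetical sketch that assumes the ASE values are in the model's native units and that, as with the `lengthScale`/`energyScale` arguments in the MLPotential call earlier in this thread, the factors convert model lengths to nm and model energies to kJ/mol. Forces are energy per length, so they scale by `energy_scale / length_scale`.

```python
from typing import List, Sequence, Tuple

Vec3 = Tuple[float, float, float]


def to_openmm_units(energy: float, forces: Sequence[Vec3],
                    energy_scale: float, length_scale: float) -> Tuple[float, List[Vec3]]:
    """Convert model-unit energy and forces to kJ/mol and kJ/mol/nm.

    energy_scale: kJ/mol per model energy unit (e.g. 4.184 for kcal/mol).
    length_scale: nm per model length unit (e.g. 0.1 for Angstrom).
    """
    force_scale = energy_scale / length_scale
    return (energy * energy_scale,
            [tuple(component * force_scale for component in f) for f in forces])


def max_abs_diffs(e_a: float, f_a: Sequence[Vec3],
                  e_b: float, f_b: Sequence[Vec3]) -> Tuple[float, float]:
    """Largest absolute energy difference and largest force-component difference."""
    de = abs(e_a - e_b)
    df = max(abs(x - y) for fa, fb in zip(f_a, f_b) for x, y in zip(fa, fb))
    return de, df
```

For example, 1 kcal/mol/Å converts to 41.84 kJ/mol/nm with `energy_scale=4.184` and `length_scale=0.1`.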

JMorado avatar May 14 '24 16:05 JMorado

@JMorado great!

(Note that allow_tf32: true is a no-op when model_dtype: float64, and we should probably error on this configuration, but that doesn't change the results.)
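A check for that configuration could look like the sketch below. The function is hypothetical and does not exist in NequIP; it only encodes the rule stated here (TF32 has no effect on a float64 model) against a config represented as a plain dict with the same keys as the YAML above.

```python
from typing import Dict, List


def check_precision_config(config: Dict[str, object]) -> List[str]:
    """Return warnings for precision settings that have no effect (hypothetical helper)."""
    problems = []
    # TF32 only affects float32 matrix math, so it is a no-op for a float64 model.
    if config.get("model_dtype") == "float64" and config.get("allow_tf32", False):
        problems.append("allow_tf32 has no effect when model_dtype is float64")
    return problems
```

Whether such a case should warn or raise is a design choice; raising would match the "we should probably error" suggestion.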

Linux-cpp-lisp avatar May 15 '24 03:05 Linux-cpp-lisp

@Linux-cpp-lisp I was getting 1.5 ns/day with LAMMPS and 0.2 ns/day with OpenMM for a system of 645 atoms.
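Those throughputs imply OpenMM is 1.5 / 0.2 = 7.5 times slower. Converting ns/day to wall-clock time per step makes the gap easier to reason about. The timestep is not stated in the thread, so the sketch below assumes a hypothetical 1 fs; the 7.5× ratio holds for any timestep as long as both runs use the same one.

```python
def wall_ms_per_step(ns_per_day: float, timestep_fs: float) -> float:
    """Wall-clock milliseconds per MD step implied by a throughput in ns/day."""
    steps_per_day = ns_per_day * 1e6 / timestep_fs  # 1 ns = 1e6 fs
    return 86400.0 / steps_per_day * 1e3            # seconds per day -> ms per step


# Hypothetical 1 fs timestep; the thread does not say which timestep was used.
lammps_ms = wall_ms_per_step(1.5, 1.0)  # about 57.6 ms per step
openmm_ms = wall_ms_per_step(0.2, 1.0)  # about 432 ms per step
```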

svarner9 avatar May 22 '24 20:05 svarner9