openmm-torch
OpenMMException when adding TorchForce to the system
I am trying to use TorchForce to bias a simulation of a box full of water molecules. The Torch model that computes the CV finds the nearest neighbors of a reference water molecule (within a cutoff) and then calculates the pairwise distances between them; these distances are my features. When I add this JIT-compiled model to the OpenMM system, it throws an OpenMMException. The issue is probably related to the gradient of the tensor returned by my TorchForce model.
The model used in TorchForce:
```python
import numpy as np
import torch
import torch.nn as nn


class Regres_CV2(nn.Module):
    def __init__(self):
        """
        input_dim : flattened input vector length
        hidden1 : node number of hidden layer 1
        hidden2 : node number of hidden layer 2
        code : dimension of latent space
        learning_rate : learning rate
        thresh : threshold to compare against for early stopping
        train_data : training dataset
        val_data : validation dataset
        """
        super(Regres_CV2, self).__init__()
        self.input_dim = 100
        self.hidden1 = 1024
        self.hidden2 = 512
        self.hidden3 = 100
        self.hidden4 = 2
        torch.manual_seed(1)
        np.random.seed(1)
        self.encoder = nn.Sequential(
            nn.Linear(self.input_dim, self.hidden1),
            nn.GELU(),
            nn.Linear(self.hidden1, self.hidden2),
            nn.GELU(),
            nn.Linear(self.hidden2, self.hidden3),
            nn.GELU(),
            nn.Linear(self.hidden3, self.hidden4),
        )

    def transform(self, x):
        mean = torch.tensor([2.71049179, 2.7858348 , 2.8624867 , 2.98495639, 3.7791522 ,
                             3.97462569, 4.10705615, 4.21363918, 4.30592099, 4.38936177,
                             4.46746004, 4.54203554, 4.61539552, 4.68864327, 4.76333205,
                             4.84088127, 4.92357368, 5.01360315, 5.10504332, 5.19608193,
                             5.28742064, 5.37936281, 5.4723322 , 5.56674506, 5.66513154,
                             5.77511278, 5.92270633, 6.02856036, 6.11510288, 6.19097639,
                             6.26000642, 6.32406077, 6.38447937, 6.44241717, 6.4985381 ,
                             6.55356604, 6.60778704, 6.6614738 , 6.71468639, 6.76745529,
                             6.81982973, 6.87181133, 6.92385802, 6.97626398, 7.02942503,
                             7.08359104, 7.139082 , 7.19567081, 7.25226343, 7.30630277,
                             7.35744459, 7.40608602, 7.45265641, 7.49777495, 7.54154908,
                             7.58434985, 7.62622659, 7.66738141, 7.7078045 , 7.74760026,
                             7.78690909, 7.82570807, 7.86406368, 7.90206851, 7.93969587,
                             7.97700164, 8.01400973, 8.05068654, 8.08717092, 8.12329801,
                             8.15917206, 8.1947259 , 8.22994791, 8.26485376, 8.29941423,
                             8.33366086, 8.36766181, 8.40139063, 8.43490873, 8.46820756,
                             8.50136546, 8.53428456, 8.56699235, 8.59918515, 8.6309053 ,
                             8.66224525, 8.69325112, 8.72396482, 8.75443437, 8.78466108,
                             8.81469268, 8.84469551, 8.87456353, 8.9043148 , 8.93397793,
                             8.96364157, 8.99321065, 9.02271019, 9.05217446, 9.08154678])/10.0
        var = torch.tensor([0.06644448, 0.0720813 , 0.09388647, 0.15218329, 0.37277633,
                            0.31781363, 0.27569099, 0.24100772, 0.21193125, 0.18754059,
                            0.16740371, 0.15137762, 0.13928688, 0.13143184, 0.12765353,
                            0.12734585, 0.1298747 , 0.13416642, 0.13855183, 0.14191056,
                            0.1447239 , 0.1472054 , 0.15001784, 0.1522338 , 0.15287467,
                            0.15215965, 0.16101845, 0.1660292 , 0.16650034, 0.16420291,
                            0.16062643, 0.15659865, 0.15245382, 0.14878753, 0.14553498,
                            0.1429001 , 0.14109884, 0.14015729, 0.14015134, 0.141089 ,
                            0.1427195 , 0.1450639 , 0.14823157, 0.15220138, 0.15710731,
                            0.16332683, 0.17076869, 0.17942144, 0.18831005, 0.19544865,
                            0.20027134, 0.2031551 , 0.20431226, 0.20431089, 0.20333588,
                            0.20176965, 0.19970677, 0.19729833, 0.194629 , 0.1918973 ,
                            0.18898873, 0.1861282 , 0.18330681, 0.18048504, 0.17777579,
                            0.1751567 , 0.17269368, 0.17036853, 0.1680993 , 0.16591258,
                            0.16380179, 0.16170763, 0.15970761, 0.15767716, 0.15568707,
                            0.15367787, 0.15177837, 0.14989855, 0.14813889, 0.1464694 ,
                            0.14496545, 0.14371142, 0.14268205, 0.14166094, 0.14057099,
                            0.13943382, 0.13808576, 0.13678063, 0.13545642, 0.13413571,
                            0.13285683, 0.1315521 , 0.13035399, 0.12925464, 0.12816313,
                            0.12720239, 0.12632232, 0.12558945, 0.12489357, 0.12431238])/10.0
        return (x - mean)/var

    def periodic_neighbours(self, pos, maxdist, L):
        """
        Finds periodic neighbors within a given maximum distance in a PyTorch tensor context.
        Args:
            pos (torch.Tensor): Positions of particles (N x D tensor).
            maxdist (float): Maximum distance for neighbors.
            L (torch.Tensor): Box dimensions (D-dimensional tensor).
        Returns:
            torch.Tensor: Indices of neighbor pairs (M x 2 tensor).
            torch.Tensor: Distances to neighbors (M-dimensional tensor).
        """
        maxdistsq = maxdist**2
        rL = 1. / L  # Inverse box dimensions (D-dimensional tensor)
        # Calculate pairwise minimum-image squared distances using broadcasting
        diff = pos.unsqueeze(1) - pos.unsqueeze(0)  # N x N x D
        diff_wrapped = diff - L.unsqueeze(0) * torch.floor(diff * rL.unsqueeze(0) + 0.5)
        distsq = torch.sum(diff_wrapped * diff_wrapped, dim=2)  # N x N
        # Keep pairs within the cutoff (this includes the i == j self pairs, distance 0)
        dists = torch.sqrt(distsq[distsq < maxdistsq])
        # Indices of the same pairs, in the same row-major order
        idx_1, idx_2 = torch.where(distsq < maxdistsq)
        bonds = torch.stack([idx_1, idx_2], dim=1)  # M x 2
        return bonds, dists

    def nnDistance(self, i, bonds, dist):
        # i is already a 0-d tensor (an element of torch.arange), so no conversion is needed
        i_tensor = i
        # Distances of all pairs involving molecule i; keep the 201 smallest, ascending
        neighbor_idx = torch.where((bonds[:, 0] == i_tensor) | (bonds[:, 1] == i_tensor))[0]
        distances = dist[neighbor_idx]
        sorted_distances, _ = torch.topk(distances, 201, dim=0, largest=False, sorted=True)
        # Index 0 is the self distance and each neighbor appears twice ((i, j) and (j, i)),
        # so [1::2] keeps the 100 nearest-neighbor distances once each
        return sorted_distances[1::2]

    def getForOneFrame(self, bonds, dist):
        num_frames = 2000
        indices = torch.arange(num_frames)
        # One row of neighbor distances per reference molecule
        features = torch.stack([self.nnDistance(i, bonds, dist) for i in indices])
        return features

    def forward(self, positions, boxvectors):
        box = torch.tensor([float(boxvectors[i][i]) for i in range(3)])
        pos = positions[::4].to("cpu")
        bonds, dist = self.periodic_neighbours(pos, torch.tensor([1.0]), box)
        # .detach() disconnects features from positions in the autograd graph
        features = self.getForOneFrame(bonds, dist).detach()
        x = self.transform(features).to(torch.float)
        y = self.encoder(x)[:,1].sum()
        return y
```
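As the openmm-torch README describes it, TorchForce loads a TorchScript file, calls `forward(positions, boxvectors)` (in nm, with the box passed when periodic boundary conditions are enabled), and obtains forces by back-propagating the returned scalar energy to the positions. A quick sanity check before building the `System` is therefore to script, save, reload and differentiate the model on dummy input. The sketch below uses a hypothetical `ToyCV` module (not the model above) in which every step stays differentiable. It works with squared distances because the gradient of `torch.sqrt` is infinite at zero; since the model above takes `torch.sqrt` of distances that include the zero self pairs, it would likely produce NaN gradients once its graph is connected to the positions.

```python
import os
import tempfile

import torch


class ToyCV(torch.nn.Module):
    """Hypothetical stand-in for a TorchForce model: scalar energy from positions."""

    def forward(self, positions: torch.Tensor, boxvectors: torch.Tensor) -> torch.Tensor:
        box = boxvectors.diag()                      # box lengths, kept as a tensor
        oxygens = positions[::4]                     # slicing keeps the autograd graph
        diff = oxygens.unsqueeze(1) - oxygens.unsqueeze(0)
        diff = diff - box * torch.round(diff / box)  # minimum-image convention
        # Squared distances avoid the infinite sqrt gradient at zero (self pairs)
        distsq = (diff * diff).sum(dim=-1)
        return distsq.sum()                          # scalar "energy"


def check_gradient(path: str) -> torch.Tensor:
    model = torch.jit.script(ToyCV())
    model.save(path)                                 # the kind of file TorchForce loads
    loaded = torch.jit.load(path)
    positions = torch.rand(16, 3, requires_grad=True)
    box = torch.eye(3) * 2.0
    energy = loaded(positions, box)
    energy.backward()
    return positions.grad


with tempfile.TemporaryDirectory() as tmp:
    grad = check_gradient(os.path.join(tmp, "toy_cv.pt"))
    print(grad is None, tuple(grad.shape))  # prints: False (16, 3)
```

If the same check on the real model prints `True` (the gradient is `None`), TorchForce has no position gradient to turn into forces, which matches the undefined-tensor error in the traceback.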
The OpenMM simulation (with metadynamics):
```python
import os, sys
import openmm as mm
import openmm.app as app
from openmm.app import GromacsGroFile, GromacsTopFile
from openmm.app import StateDataReporter
from openmm.app import XTCReporter
from openmm.app import PME, HBonds
from openmm.unit import nanometer, kelvin, picosecond, picoseconds, bar, kilojoules_per_mole
from openmmtorch import TorchForce
from openmm.app import BiasVariable
from openmm.app.metadynamics import Metadynamics

gro = GromacsGroFile('/home/dm/Dibyendu/Projects/ICE_AE/Phase_Space_Scaling/Liquid_test/md.gro')
top = GromacsTopFile('/home/dm/Dibyendu/Projects/ICE_AE/Phase_Space_Scaling/Liquid_test/topol.top',
                     periodicBoxVectors=gro.getPeriodicBoxVectors(),
                     includeDir='/home/dm/Soft/GMX22/share/gromacs/top/')
##### Create the OpenMM System based on the topology
system = top.createSystem(nonbondedMethod=PME, nonbondedCutoff=1*nanometer, constraints=HBonds)
##### Remove MM forces
while len(system.getForces()) > 0:
    system.removeForce(0)
force = TorchForce('forcemodel2.pt')
force.setUsesPeriodicBoundaryConditions(periodic=True)
psi = BiasVariable(force, -5, 5, 0.5, True)
meta = Metadynamics(system, [psi],
                    250*kelvin, 1.2, 1.2*kilojoules_per_mole, 100)
platform = mm.Platform.getPlatformByName('CPU')
# Specify properties for CUDA platform (e.g., mixed precision)
#prop = dict(CudaPrecision='mixed') # Use mixed single/double precision
# Add thermostat and barostat forces to the system
system.addForce(mm.AndersenThermostat(250*kelvin, 1/picosecond))
system.addForce(mm.MonteCarloBarostat(1*bar, 250*kelvin))
# Create Langevin integrator for molecular dynamics simulation
integrator = mm.LangevinMiddleIntegrator(250*kelvin, 1/picosecond, 0.0001*picoseconds)
# Create the OpenMM Simulation object
sim = app.Simulation(top.topology, system, integrator, platform)
config = gro
sim.context.setPositions(config.positions)
sim.minimizeEnergy()
meta.step(sim, 50000)
#sim.step(50000)
reporter = StateDataReporter(file=sys.stdout, reportInterval=100, step=True, time=True, potentialEnergy=True, temperature=True)
sim.reporters.append(reporter)
```
The error it throws:
---------------------------------------------------------------------------
OpenMMException Traceback (most recent call last)
/tmp/ipykernel_357599/4025831078.py in ?()
32 # Create the OpenMM Simulation object
33 sim = app.Simulation(top.topology, system, integrator, platform)
34 config = gro
35 sim.context.setPositions(config.positions)
---> 36 sim.minimizeEnergy()
37 meta.step(sim, 50000)
38 #sim.step(50000)
39 reporter = StateDataReporter(file=sys.stdout, reportInterval=100, step=True, time=True, potentialEnergy=True, temperature=True)
~/Soft/miniconda3/envs/torch/lib/python3.10/site-packages/openmm/app/simulation.py in ?(self, tolerance, maxIterations, reporter)
139 reporter : MinimizationReporter = None
140 an optional reporter to invoke after each iteration. This can be used to monitor the progress
141 of minimization or to stop minimization early.
142 """
--> 143 mm.LocalEnergyMinimizer.minimize(self.context, tolerance, maxIterations, reporter)
~/Soft/miniconda3/envs/torch/lib/python3.10/site-packages/openmm/openmm.py in ?(context, tolerance, maxIterations, reporter)
4421 if unit.is_quantity(tolerance):
4422 tolerance = tolerance.value_in_unit(unit.kilojoules_per_mole/unit.nanometer)
4423
4424
-> 4425 return _openmm.LocalEnergyMinimizer_minimize(context, tolerance, maxIterations, reporter)
OpenMMException: Expected a proper Tensor but got None (or an undefined Tensor in C++) for argument #0 'self'
Exception raised from checked_cast_variable at /home/conda/feedstock_root/build_artifacts/libtorch_1706712676143/work/torch/csrc/autograd/VariableTypeManual.cpp:60 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0xaa (0x77db8353587a in /home/dm/Soft/miniconda3/envs/torch/lib/plugins/../libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xf3 (0x77db834eac7e in /home/dm/Soft/miniconda3/envs/torch/lib/plugins/../libc10.so)
frame #2: <unknown function> + 0x4a04cda (0x77db2fa04cda in /home/dm/Soft/miniconda3/envs/torch/lib/plugins/../libtorch_cpu.so)
frame #3: <unknown function> + 0x36a1f5a (0x77db2e6a1f5a in /home/dm/Soft/miniconda3/envs/torch/lib/plugins/../libtorch_cpu.so)
frame #4: <unknown function> + 0x36a25d0 (0x77db2e6a25d0 in /home/dm/Soft/miniconda3/envs/torch/lib/plugins/../libtorch_cpu.so)
frame #5: at::_ops::_to_copy::call(at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>) + 0x1f7 (0x77db2cacd127 in /home/dm/Soft/miniconda3/envs/torch/lib/plugins/../libtorch_cpu.so)
frame #6: at::native::to(at::Tensor const&, c10::ScalarType, bool, bool, c10::optional<c10::MemoryFormat>) + 0xbc (0x77db2c5e311c in /home/dm/Soft/miniconda3/envs/torch/lib/plugins/../libtorch_cpu.so)
frame #7: <unknown function> + 0x2446ea5 (0x77db2d446ea5 in /home/dm/Soft/miniconda3/envs/torch/lib/plugins/../libtorch_cpu.so)
frame #8: at::_ops::to_dtype::call(at::Tensor const&, c10::ScalarType, bool, bool, c10::optional<c10::MemoryFormat>) + 0x188 (0x77db2cc44b98 in /home/dm/Soft/miniconda3/envs/torch/lib/plugins/../libtorch_cpu.so)
frame #9: TorchPlugin::ReferenceCalcTorchForceKernel::execute(OpenMM::ContextImpl&, bool, bool) + 0xac0 (0x77db82114790 in /home/dm/Soft/miniconda3/envs/torch/lib/plugins/libOpenMMTorchReference.so)
frame #10: OpenMM::ContextImpl::calcForcesAndEnergy(bool, bool, int) + 0xc9 (0x77db7bb20109 in /home/dm/Soft/miniconda3/envs/torch/lib/python3.10/site-packages/openmm/../../../libOpenMM.so.8.1)
frame #11: OpenMM::ReferenceCustomCVForce::calculateIxn(OpenMM::ContextImpl&, std::vector<OpenMM::Vec3, std::allocator<OpenMM::Vec3> >&, std::map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, double, std::less<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, double> > > const&, std::vector<OpenMM::Vec3, std::allocator<OpenMM::Vec3> >&, double*, std::map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, double, std::less<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, double> > >&) + 0xfc (0x77db7bc3363c in /home/dm/Soft/miniconda3/envs/torch/lib/python3.10/site-packages/openmm/../../../libOpenMM.so.8.1)
frame #12: OpenMM::ReferenceCalcCustomCVForceKernel::execute(OpenMM::ContextImpl&, OpenMM::ContextImpl&, bool, bool) + 0x2ff (0x77db7bc0b06f in /home/dm/Soft/miniconda3/envs/torch/lib/python3.10/site-packages/openmm/../../../libOpenMM.so.8.1)
frame #13: OpenMM::ContextImpl::calcForcesAndEnergy(bool, bool, int) + 0xc9 (0x77db7bb20109 in /home/dm/Soft/miniconda3/envs/torch/lib/python3.10/site-packages/openmm/../../../libOpenMM.so.8.1)
frame #14: OpenMM::Context::getState(int, bool, int) const + 0x122 (0x77db7bb1d772 in /home/dm/Soft/miniconda3/envs/torch/lib/python3.10/site-packages/openmm/../../../libOpenMM.so.8.1)
frame #15: <unknown function> + 0x18847b (0x77db7bb8847b in /home/dm/Soft/miniconda3/envs/torch/lib/python3.10/site-packages/openmm/../../../libOpenMM.so.8.1)
frame #16: <unknown function> + 0x188c26 (0x77db7bb88c26 in /home/dm/Soft/miniconda3/envs/torch/lib/python3.10/site-packages/openmm/../../../libOpenMM.so.8.1)
frame #17: lbfgs + 0x584 (0x77db7bbe8444 in /home/dm/Soft/miniconda3/envs/torch/lib/python3.10/site-packages/openmm/../../../libOpenMM.so.8.1)
frame #18: OpenMM::LocalEnergyMinimizer::minimize(OpenMM::Context&, double, int, OpenMM::MinimizationReporter*) + 0x7d9 (0x77db7bb89769 in /home/dm/Soft/miniconda3/envs/torch/lib/python3.10/site-packages/openmm/../../../libOpenMM.so.8.1)
frame #19: <unknown function> + 0x1293ae (0x77db887293ae in /home/dm/Soft/miniconda3/envs/torch/lib/python3.10/site-packages/openmm/_openmm.cpython-310-x86_64-linux-gnu.so)
frame #20: <unknown function> + 0x144468 (0x620ee873e468 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #21: _PyObject_MakeTpCall + 0x26b (0x620ee873797b in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #22: _PyEval_EvalFrameDefault + 0x54b6 (0x620ee87338c6 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #23: _PyFunction_Vectorcall + 0x6c (0x620ee873e8cc in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #24: _PyEval_EvalFrameDefault + 0x4c12 (0x620ee8733022 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #25: _PyFunction_Vectorcall + 0x6c (0x620ee873e8cc in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #26: _PyEval_EvalFrameDefault + 0x72c (0x620ee872eb3c in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #27: <unknown function> + 0x1d7870 (0x620ee87d1870 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #28: PyEval_EvalCode + 0x87 (0x620ee87d17b7 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #29: <unknown function> + 0x1de9ba (0x620ee87d89ba in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #30: <unknown function> + 0x144a93 (0x620ee873ea93 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #31: _PyEval_EvalFrameDefault + 0x320 (0x620ee872e730 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #32: <unknown function> + 0x1e0fe4 (0x620ee87dafe4 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #33: _PyEval_EvalFrameDefault + 0x1bc0 (0x620ee872ffd0 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #34: <unknown function> + 0x1e0fe4 (0x620ee87dafe4 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #35: _PyEval_EvalFrameDefault + 0x1bc0 (0x620ee872ffd0 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #36: <unknown function> + 0x1e0fe4 (0x620ee87dafe4 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #37: <unknown function> + 0x1f55f7 (0x620ee87ef5f7 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #38: <unknown function> + 0x14f3bd (0x620ee87493bd in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #39: _PyEval_EvalFrameDefault + 0x72c (0x620ee872eb3c in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #40: _PyFunction_Vectorcall + 0x6c (0x620ee873e8cc in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #41: _PyEval_EvalFrameDefault + 0x320 (0x620ee872e730 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #42: _PyFunction_Vectorcall + 0x6c (0x620ee873e8cc in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #43: _PyEval_EvalFrameDefault + 0x72c (0x620ee872eb3c in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #44: <unknown function> + 0x150402 (0x620ee874a402 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #45: PyObject_Call + 0xbc (0x620ee874ad9c in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #46: _PyEval_EvalFrameDefault + 0x2d84 (0x620ee8731194 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #47: <unknown function> + 0x150402 (0x620ee874a402 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #48: _PyEval_EvalFrameDefault + 0x13cc (0x620ee872f7dc in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #49: <unknown function> + 0x1e0fe4 (0x620ee87dafe4 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #50: _PyEval_EvalFrameDefault + 0x1bc0 (0x620ee872ffd0 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #51: <unknown function> + 0x1e0fe4 (0x620ee87dafe4 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #52: _PyEval_EvalFrameDefault + 0x1bc0 (0x620ee872ffd0 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #53: <unknown function> + 0x1e0fe4 (0x620ee87dafe4 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #54: _PyEval_EvalFrameDefault + 0x1bc0 (0x620ee872ffd0 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #55: <unknown function> + 0x1e0fe4 (0x620ee87dafe4 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #56: _PyEval_EvalFrameDefault + 0x1bc0 (0x620ee872ffd0 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #57: <unknown function> + 0x1e0fe4 (0x620ee87dafe4 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #58: _PyEval_EvalFrameDefault + 0x1bc0 (0x620ee872ffd0 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #59: <unknown function> + 0x1e0fe4 (0x620ee87dafe4 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #60: <unknown function> + 0x7bf6 (0x77db8ddf3bf6 in /home/dm/Soft/miniconda3/envs/torch/lib/python3.10/lib-dynload/_asyncio.cpython-310-x86_64-linux-gnu.so)
frame #61: <unknown function> + 0x143d2a (0x620ee873dd2a in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #62: <unknown function> + 0x25f22c (0x620ee885922c in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #63: <unknown function> + 0xfda7b (0x620ee86f7a7b in /home/dm/Soft/miniconda3/envs/torch/bin/python)
My conda environment:
conda list:
# packages in environment at /home/dm/Soft/miniconda3/envs/torch:
#
# Name Version Build Channel
_libgcc_mutex 0.1 conda_forge conda-forge
_openmp_mutex 4.5 2_kmp_llvm conda-forge
anyio 4.2.0 pyhd8ed1ab_0 conda-forge
argon2-cffi 23.1.0 pyhd8ed1ab_0 conda-forge
argon2-cffi-bindings 21.2.0 py310h2372a71_4 conda-forge
arrow 1.3.0 pyhd8ed1ab_0 conda-forge
asttokens 2.4.1 pyhd8ed1ab_0 conda-forge
async-lru 2.0.4 pyhd8ed1ab_0 conda-forge
attrs 23.2.0 pyh71513ae_0 conda-forge
babel 2.14.0 pyhd8ed1ab_0 conda-forge
beautifulsoup4 4.12.3 pyha770c72_0 conda-forge
blas 1.0 mkl
bleach 6.1.0 pyhd8ed1ab_0 conda-forge
brotli-python 1.1.0 py310hc6cd4ac_1 conda-forge
bzip2 1.0.8 hd590300_5 conda-forge
ca-certificates 2024.2.2 hbcca054_0 conda-forge
cached-property 1.5.2 hd8ed1ab_1 conda-forge
cached_property 1.5.2 pyha770c72_1 conda-forge
certifi 2024.2.2 pyhd8ed1ab_0 conda-forge
cffi 1.16.0 py310h2fee648_0 conda-forge
charset-normalizer 3.3.2 pyhd8ed1ab_0 conda-forge
comm 0.2.1 pyhd8ed1ab_0 conda-forge
cuda-version 11.8 h70ddcb2_2 conda-forge
cudatoolkit 11.8.0 h4ba93d1_13 conda-forge
cudnn 8.8.0.121 hcdd5f01_4 conda-forge
debugpy 1.8.1 py310hc6cd4ac_0 conda-forge
decorator 5.1.1 pyhd8ed1ab_0 conda-forge
defusedxml 0.7.1 pyhd8ed1ab_0 conda-forge
entrypoints 0.4 pyhd8ed1ab_0 conda-forge
exceptiongroup 1.2.0 pyhd8ed1ab_2 conda-forge
executing 2.0.1 pyhd8ed1ab_0 conda-forge
filelock 3.13.1 py310h06a4308_0
fqdn 1.5.1 pyhd8ed1ab_0 conda-forge
fsspec 2024.2.0 pyhca7485f_0 conda-forge
gmp 6.2.1 h295c915_3
gmpy2 2.1.2 py310heeb90bb_0
h11 0.14.0 pyhd8ed1ab_0 conda-forge
h2 4.1.0 pyhd8ed1ab_0 conda-forge
hpack 4.0.0 pyh9f0ad1d_0 conda-forge
httpcore 1.0.2 pyhd8ed1ab_0 conda-forge
httpx 0.26.0 pyhd8ed1ab_0 conda-forge
hyperframe 6.0.1 pyhd8ed1ab_0 conda-forge
icu 73.2 h59595ed_0 conda-forge
idna 3.6 pyhd8ed1ab_0 conda-forge
importlib-metadata 7.0.1 pyha770c72_0 conda-forge
importlib_metadata 7.0.1 hd8ed1ab_0 conda-forge
importlib_resources 6.1.1 pyhd8ed1ab_0 conda-forge
intel-openmp 2022.0.1 h06a4308_3633
ipykernel 6.29.2 pyhd33586a_0 conda-forge
ipython 8.21.0 pyh707e725_0 conda-forge
ipywidgets 8.1.2 pyhd8ed1ab_0 conda-forge
isoduration 20.11.0 pyhd8ed1ab_0 conda-forge
jedi 0.19.1 pyhd8ed1ab_0 conda-forge
jinja2 3.1.3 py310h06a4308_0
joblib 1.3.2 pyhd8ed1ab_0 conda-forge
json5 0.9.14 pyhd8ed1ab_0 conda-forge
jsonpointer 2.4 py310hff52083_3 conda-forge
jsonschema 4.21.1 pyhd8ed1ab_0 conda-forge
jsonschema-specifications 2023.12.1 pyhd8ed1ab_0 conda-forge
jsonschema-with-format-nongpl 4.21.1 pyhd8ed1ab_0 conda-forge
jupyter 1.0.0 pyhd8ed1ab_10 conda-forge
jupyter-lsp 2.2.2 pyhd8ed1ab_0 conda-forge
jupyter_client 8.6.0 pyhd8ed1ab_0 conda-forge
jupyter_console 6.6.3 pyhd8ed1ab_0 conda-forge
jupyter_core 5.7.1 py310hff52083_0 conda-forge
jupyter_events 0.9.0 pyhd8ed1ab_0 conda-forge
jupyter_server 2.12.5 pyhd8ed1ab_0 conda-forge
jupyter_server_terminals 0.5.2 pyhd8ed1ab_0 conda-forge
jupyterlab 4.1.0 pyhd8ed1ab_0 conda-forge
jupyterlab_pygments 0.3.0 pyhd8ed1ab_1 conda-forge
jupyterlab_server 2.25.2 pyhd8ed1ab_0 conda-forge
jupyterlab_widgets 3.0.10 pyhd8ed1ab_0 conda-forge
ld_impl_linux-64 2.40 h41732ed_0 conda-forge
libabseil 20230802.1 cxx17_h59595ed_0 conda-forge
libblas 3.9.0 1_h86c2bf4_netlib conda-forge
libcblas 3.9.0 5_h92ddd45_netlib conda-forge
libffi 3.4.2 h7f98852_5 conda-forge
libgcc-ng 13.2.0 h807b86a_5 conda-forge
libgfortran-ng 13.2.0 h69a702a_5 conda-forge
libgfortran5 13.2.0 ha4646dd_5 conda-forge
libgomp 13.2.0 h807b86a_5 conda-forge
libhwloc 2.9.3 default_h554bfaf_1009 conda-forge
libiconv 1.17 hd590300_2 conda-forge
liblapack 3.9.0 5_h92ddd45_netlib conda-forge
libmagma 2.7.2 h09b5827_2 conda-forge
libmagma_sparse 2.7.2 h09b5827_2 conda-forge
libnsl 2.0.1 hd590300_0 conda-forge
libprotobuf 4.25.1 hf27288f_1 conda-forge
libsodium 1.0.18 h36c2ea0_1 conda-forge
libsqlite 3.45.1 h2797004_0 conda-forge
libstdcxx-ng 13.2.0 h7e041cc_5 conda-forge
libtorch 2.1.2 cuda118_h12fe058_301 conda-forge
libuuid 2.38.1 h0b41bf4_0 conda-forge
libuv 1.47.0 hd590300_0 conda-forge
libxcrypt 4.4.36 hd590300_1 conda-forge
libxml2 2.12.5 h232c23b_0 conda-forge
libzlib 1.2.13 hd590300_5 conda-forge
llvm-openmp 17.0.6 h4dfa4b3_0 conda-forge
magma 2.7.2 h4aca40b_2 conda-forge
markupsafe 2.1.3 py310h5eee18b_0
matplotlib-inline 0.1.6 pyhd8ed1ab_0 conda-forge
mistune 3.0.2 pyhd8ed1ab_0 conda-forge
mkl 2023.2.0 h84fe81f_50496 conda-forge
mpc 1.1.0 h10f8cd9_1
mpfr 4.0.2 hb69a4c5_1
mpmath 1.3.0 py310h06a4308_0
nbclient 0.8.0 pyhd8ed1ab_0 conda-forge
nbconvert 7.16.0 pyhd8ed1ab_0 conda-forge
nbconvert-core 7.16.0 pyhd8ed1ab_0 conda-forge
nbconvert-pandoc 7.16.0 pyhd8ed1ab_0 conda-forge
nbformat 5.9.2 pyhd8ed1ab_0 conda-forge
nccl 2.19.4.1 h6103f9b_0 conda-forge
ncurses 6.4 h59595ed_2 conda-forge
nest-asyncio 1.6.0 pyhd8ed1ab_0 conda-forge
networkx 3.1 py310h06a4308_0
notebook 7.0.7 pyhd8ed1ab_0 conda-forge
notebook-shim 0.2.3 pyhd8ed1ab_0 conda-forge
numpy 1.26.4 py310hb13e2d6_0 conda-forge
ocl-icd 2.3.1 h7f98852_0 conda-forge
ocl-icd-system 1.0.0 1 conda-forge
openmm 8.1.1 py310h43b6314_1 conda-forge
openmm-torch 1.4 cuda118py310hde6f947_3 conda-forge
openssl 3.2.1 hd590300_0 conda-forge
overrides 7.7.0 pyhd8ed1ab_0 conda-forge
packaging 23.2 pyhd8ed1ab_0 conda-forge
pandoc 3.1.11.1 ha770c72_0 conda-forge
pandocfilters 1.5.0 pyhd8ed1ab_0 conda-forge
parso 0.8.3 pyhd8ed1ab_0 conda-forge
pexpect 4.9.0 pyhd8ed1ab_0 conda-forge
pickleshare 0.7.5 py_1003 conda-forge
pip 24.0 pyhd8ed1ab_0 conda-forge
pkgutil-resolve-name 1.3.10 pyhd8ed1ab_1 conda-forge
platformdirs 4.2.0 pyhd8ed1ab_0 conda-forge
prometheus_client 0.19.0 pyhd8ed1ab_0 conda-forge
prompt-toolkit 3.0.42 pyha770c72_0 conda-forge
prompt_toolkit 3.0.42 hd8ed1ab_0 conda-forge
psutil 5.9.8 py310h2372a71_0 conda-forge
ptyprocess 0.7.0 pyhd3deb0d_0 conda-forge
pure_eval 0.2.2 pyhd8ed1ab_0 conda-forge
pycparser 2.21 pyhd8ed1ab_0 conda-forge
pygments 2.17.2 pyhd8ed1ab_0 conda-forge
pysocks 1.7.1 pyha2e5f31_6 conda-forge
python 3.10.13 hd12c33a_1_cpython conda-forge
python-dateutil 2.8.2 pyhd8ed1ab_0 conda-forge
python-fastjsonschema 2.19.1 pyhd8ed1ab_0 conda-forge
python-json-logger 2.0.7 pyhd8ed1ab_0 conda-forge
python_abi 3.10 4_cp310 conda-forge
pytorch 2.1.2 cuda118_py310h59774e7_301 conda-forge
pytorch-mutex 1.0 cpu pytorch
pytz 2024.1 pyhd8ed1ab_0 conda-forge
pyyaml 6.0.1 py310h2372a71_1 conda-forge
pyzmq 25.1.2 py310h795f18f_0 conda-forge
qtconsole-base 5.5.1 pyha770c72_0 conda-forge
qtpy 2.4.1 pyhd8ed1ab_0 conda-forge
readline 8.2 h8228510_1 conda-forge
referencing 0.33.0 pyhd8ed1ab_0 conda-forge
requests 2.31.0 pyhd8ed1ab_0 conda-forge
rfc3339-validator 0.1.4 pyhd8ed1ab_0 conda-forge
rfc3986-validator 0.1.1 pyh9f0ad1d_0 conda-forge
rpds-py 0.17.1 py310hcb5633a_0 conda-forge
scikit-learn 1.4.0 py310h1fdf081_0 conda-forge
scipy 1.12.0 py310hb13e2d6_2 conda-forge
send2trash 1.8.2 pyh41d4057_0 conda-forge
setuptools 69.0.3 pyhd8ed1ab_0 conda-forge
six 1.16.0 pyh6c4a22f_0 conda-forge
sleef 3.5.1 h9b69904_2 conda-forge
sniffio 1.3.0 pyhd8ed1ab_0 conda-forge
soupsieve 2.5 pyhd8ed1ab_1 conda-forge
stack_data 0.6.2 pyhd8ed1ab_0 conda-forge
sympy 1.12 py310h06a4308_0
tbb 2021.11.0 h00ab1b0_1 conda-forge
terminado 0.18.0 pyh0d859eb_0 conda-forge
threadpoolctl 3.2.0 pyha21a80b_0 conda-forge
tinycss2 1.2.1 pyhd8ed1ab_0 conda-forge
tk 8.6.13 noxft_h4845f30_101 conda-forge
tomli 2.0.1 pyhd8ed1ab_0 conda-forge
tornado 6.3.3 py310h2372a71_1 conda-forge
traitlets 5.14.1 pyhd8ed1ab_0 conda-forge
types-python-dateutil 2.8.19.20240106 pyhd8ed1ab_0 conda-forge
typing-extensions 4.9.0 py310h06a4308_1
typing_extensions 4.9.0 py310h06a4308_1
typing_utils 0.1.0 pyhd8ed1ab_0 conda-forge
tzdata 2024a h0c530f3_0 conda-forge
uri-template 1.3.0 pyhd8ed1ab_0 conda-forge
urllib3 2.2.0 pyhd8ed1ab_0 conda-forge
wcwidth 0.2.13 pyhd8ed1ab_0 conda-forge
webcolors 1.13 pyhd8ed1ab_0 conda-forge
webencodings 0.5.1 pyhd8ed1ab_2 conda-forge
websocket-client 1.7.0 pyhd8ed1ab_0 conda-forge
wheel 0.42.0 pyhd8ed1ab_0 conda-forge
widgetsnbextension 4.0.10 pyhd8ed1ab_0 conda-forge
xz 5.2.6 h166bdaf_0 conda-forge
yaml 0.2.5 h7f98852_2 conda-forge
zeromq 4.3.5 h59595ed_0 conda-forge
zipp 3.17.0 pyhd8ed1ab_0 conda-forge
zstd 1.5.5 hfc55251_0 conda-forge
I used mamba to install openmm-torch. I suppose this is a question rather than a bug. It would be really helpful if you could point out where I am making a mistake!
Please let me know if you require any more information.
I believe this is an error in your model, in particular because of lines like these:
```python
pos = positions[::4].to("cpu")
...
y = self.encoder(x)[:,1].sum()
```
PyTorch does not like it when you run backward on only a subset of the output. To test this, try running backward on the model alone (without OpenMM or openmm-torch involved). My guess is that you will see a similar error. Something like this:
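```python
import torch
# Placeholder inputs; for this particular model, use coordinates and a box shaped like the real system
pos = torch.rand(10, 3)
box = torch.eye(3) * 10
model = torch.jit.load("model.pt")
pos.requires_grad_()
y = model(pos, box)
y.backward()  # compute gradients
print(pos.grad)
```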
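The same test can be made finer-grained by checking `requires_grad` after each step of `forward`. The toy pipeline below is a sketch with made-up random positions and a hypothetical `torch.nn.Linear` standing in for `self.encoder`; it mirrors only the structure of the model above. It suggests that slicing and `.to("cpu")` keep the autograd graph intact, while a `.detach()` like the one applied to `features` leaves `positions.grad` as `None` even though `backward()` itself succeeds (the encoder weights still require gradients).

```python
import torch

torch.manual_seed(0)
positions = torch.rand(8, 3, requires_grad=True)  # made-up coordinates
encoder = torch.nn.Linear(2, 1)                   # hypothetical stand-in for self.encoder

stages = {}
subset = positions[::4].to("cpu")                 # slicing and .to("cpu") keep the graph
stages["subset"] = subset
diff = subset.unsqueeze(1) - subset.unsqueeze(0)
distsq = (diff * diff).sum(dim=-1)                # 2 x 2 squared distances, differentiable
stages["distsq"] = distsq
features = distsq.detach()                        # mirrors .detach() in forward(): link to positions cut
stages["features"] = features

for name, tensor in stages.items():
    print(name, "requires_grad =", tensor.requires_grad)

energy = encoder(features).sum()                  # differentiable w.r.t. encoder weights only
energy.backward()
print("positions.grad is None:", positions.grad is None)
```

The first stage whose `requires_grad` turns `False` is where the connection to the positions is lost.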
As a side note, the box is already passed to your model as a 3x3 PyTorch tensor, so you should not need to convert it. You can extract its diagonal with `box.diag()`.
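Assuming a rectangular box (all off-diagonal entries zero), the box lengths are exactly the diagonal, so the list comprehension in `forward` can be replaced by one call. A small sketch comparing the two on a made-up box; `box_lengths` is a hypothetical helper name:

```python
import torch

boxvectors = torch.tensor([[3.0, 0.0, 0.0],
                           [0.0, 4.0, 0.0],
                           [0.0, 0.0, 5.0]])  # made-up rectangular box in nm

# Approach used in the model above: rebuild a new tensor from Python floats
box_rebuilt = torch.tensor([float(boxvectors[i][i]) for i in range(3)])

# Suggested alternative: take the diagonal directly (keeps the input's dtype and device)
box_diag = boxvectors.diag()


@torch.jit.script
def box_lengths(boxvectors: torch.Tensor) -> torch.Tensor:
    # Same call inside TorchScript, as it would appear in a scripted forward()
    return boxvectors.diag()


print(torch.equal(box_rebuilt, box_diag))  # prints: True
```

Besides being shorter, `diag()` avoids the round trip through Python floats, which always produces a default-dtype CPU tensor regardless of what openmm-torch passes in.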
Thanks, I see your point. It does return `None`, which makes sense. I'll try to modify the model to operate on the whole system. But is there any trick for applying certain operations to only a subset of the positions? Say I only want the oxygen atoms of my system!
I did notice `box.diag()` in the documentation too, but was too lazy to change it :).
You want your TorchForce to act only on a subset of the system? As an easy workaround, you can just multiply the ones you do not want by zero, right?
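A minimal sketch of the zero-multiplication workaround next to plain indexing, on made-up coordinates where every fourth atom is assumed to be an oxygen (as `positions[::4]` implies in the model above). Both routes keep the autograd graph intact, so the non-selected atoms simply receive a zero gradient, i.e. no force from the bias.

```python
import torch

torch.manual_seed(0)
n_atoms = 8
positions = torch.rand(n_atoms, 3, requires_grad=True)  # made-up coordinates

# Hypothetical selection: every 4th atom is an oxygen
is_oxygen = torch.zeros(n_atoms, 1)
is_oxygen[::4] = 1.0

# Option 1: multiply the unwanted atoms by zero, then apply a per-atom energy term
masked = positions * is_oxygen
energy_mask = (masked ** 2).sum()

# Option 2: select the subset by indexing
oxygens = positions[::4]
energy_index = (oxygens ** 2).sum()

# Both give the same gradient: 2 * x on oxygen rows, exactly zero on all other atoms
grad_mask = torch.autograd.grad(energy_mask, positions)[0]
grad_index = torch.autograd.grad(energy_index, positions)[0]
print(torch.allclose(grad_mask, grad_index))  # prints: True
```

One caveat: zeroed atoms all end up at the origin, so in a neighbor-distance feature like the one above they would be counted as spurious neighbors. For that kind of feature, selecting by indexing is the safer choice; for energy terms that are plain sums over atoms, either option works.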
Yes. Thanks! I'll try that.