torchmd-net
TorchScript compatibility for Ensemble
I could not make `Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor, Tensor]]` work.
It will `jit.script`, but if I then try to use the model as `energy, _ = model(...)` it complains:
```
RuntimeError:
Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor, Tensor]] cannot be used as a tuple:
    def forward(self, positions):
        positions = pt.index_select(positions, 0, self.all_atom_indices).to(pt.float32) * 10  # nm --> A
        energies, _ = self.model(self.atomic_numbers, positions, batch=self.batch, q=self.total_charges)
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
```
With this I can do:

```python
energies, *_ = self.model(self.atomic_numbers, positions, batch=self.batch, q=self.total_charges)
```

and I can switch between an ensemble and a single model without any code changes.
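As an aside, the reason the starred form works at the call site is plain Python unpacking semantics: `*_` absorbs however many trailing outputs there are. A minimal torch-free sketch (the output names here are illustrative placeholders, not the real model outputs):

```python
# Starred unpacking absorbs any number of trailing values, so one
# call site handles both a 2-tuple (single model: y, neg_dy) and a
# 4-tuple (ensemble: y, neg_dy, y_std, neg_dy_std).
def single_model():
    return -3.14, "neg_dy"

def ensemble_model():
    return -3.14, "neg_dy", "y_std", "neg_dy_std"

energy, *rest = single_model()    # rest == ["neg_dy"]
energy, *rest = ensemble_model()  # rest == ["neg_dy", "y_std", "neg_dy_std"]
```

Whether the scripted model accepts this still depends on TorchScript supporting starred unpacking at that call site, which is what is reported above.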
Please add a test to check for TorchScript compatibility. Is `Ensemble` supposed to be able to mimic `TorchMD_Net`? In that case it should really return just `y, neg_dy` by default, without `None`s.
No, I think at this point we have given up trying to mimic the `TorchMD_Net` class outputs.
This works:

```python
import torch
from torch import Tensor
from typing import Union, Tuple


class Mymod(torch.nn.Module):
    def __init__(self, return3=False):
        super(Mymod, self).__init__()
        self.return3 = return3

    def forward(
        self, z: Tensor
    ) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]]:
        if self.return3:
            return z, z, z
        else:
            return z, z


mymod = torch.jit.script(Mymod())
z = torch.ones(10)
o1, o2 = mymod(z)

mymod = torch.jit.script(Mymod(return3=True))
o1, o2, o3 = mymod(z)

mymod = torch.jit.script(Mymod())
o1, _ = mymod(z)
```
Yes, but I can't make this work:
```python
import torch
from torch import Tensor
from typing import Union, Tuple


class Mymod(torch.nn.Module):
    def __init__(self, return3=False):
        super(Mymod, self).__init__()
        self.return3 = return3

    def forward(
        self, z: Tensor
    ) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]]:
        if self.return3:
            return z, z, z
        else:
            return z, z


class MymodWrapper(torch.nn.Module):
    def __init__(self):
        super(MymodWrapper, self).__init__()
        self.model = Mymod()

    def forward(self, z: Tensor) -> Tensor:
        o1, _ = self.model(z)
        return o1


mymod = torch.jit.script(Mymod())
z = torch.ones(10)
o1, o2 = mymod(z)

mymod = torch.jit.script(Mymod(return3=True))
o1, o2, o3 = mymod(z)

mymod = torch.jit.script(Mymod())
o1, _ = mymod(z)

mymodwrapper = MymodWrapper()
o1 = mymodwrapper(z)

mymodwrapper = torch.jit.script(MymodWrapper())  # fails here
o1 = mymodwrapper(z)
```
```
Traceback (most recent call last):
  File "/home/sfarr/torchmd-net/tests/temp.py", line 48, in <module>
    mymodwrapper = torch.jit.script(MymodWrapper())
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sfarr/miniconda3/envs/tn_atm_dev/lib/python3.11/site-packages/torch/jit/_script.py", line 1324, in script
    return torch.jit._recursive.create_script_module(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sfarr/miniconda3/envs/tn_atm_dev/lib/python3.11/site-packages/torch/jit/_recursive.py", line 559, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sfarr/miniconda3/envs/tn_atm_dev/lib/python3.11/site-packages/torch/jit/_recursive.py", line 636, in create_script_module_impl
    create_methods_and_properties_from_stubs(
  File "/home/sfarr/miniconda3/envs/tn_atm_dev/lib/python3.11/site-packages/torch/jit/_recursive.py", line 469, in create_methods_and_properties_from_stubs
    concrete_type._create_methods_and_properties(
RuntimeError:
Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]] cannot be used as a tuple:
  File "/home/sfarr/torchmd-net/tests/temp.py", line 31
    ) -> Tensor:
        o1,_ = self.model(z)
        ~~~~~~~~~~~~ <--- HERE
        return o1
```
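One possible workaround sketch, if the Union return type turns out to be the blocker: give the inner module a fixed-arity return type and pad the missing outputs with `Optional[Tensor]` placeholders (this is an assumption on my part, not the approach the PR takes, and it reintroduces the `None`s discussed above):

```python
import torch
from torch import Tensor
from typing import Optional, Tuple


class Mymod(torch.nn.Module):
    def __init__(self, return3: bool = False):
        super(Mymod, self).__init__()
        self.return3 = return3

    # Fixed-arity return type: no Union, so scripted callers can unpack it.
    def forward(self, z: Tensor) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
        if self.return3:
            return z, z, z
        return z, z, None


class MymodWrapper(torch.nn.Module):
    def __init__(self):
        super(MymodWrapper, self).__init__()
        self.model = Mymod()

    def forward(self, z: Tensor) -> Tensor:
        # Plain tuple unpacking now works inside a scripted module.
        o1, _, _ = self.model(z)
        return o1


wrapper = torch.jit.script(MymodWrapper())  # scripts without error
energy = wrapper(torch.ones(10))
```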
For OpenMM-Torch compatibility there will be wrapper code that needs to be TorchScript-compiled too.
Dang! That looks like a torch bug; I do not see why one would work and the other not. OK, let's drop `TorchMD_Net` compatibility for the moment. EDIT: I opened this https://github.com/pytorch/pytorch/issues/123168
No need for an urgent merge here; we will see if it works in production simulations.