torchmd-net

TorchScript compatibility for Ensemble

Open · sef43 opened this issue 10 months ago · 6 comments

I could not make Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor, Tensor]] work as a return type. The model will torch.jit.script, but if I then try to use it as energy, _ = model(...) inside another scripted module, it complains:


RuntimeError: 
Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor, Tensor]] cannot be used as a tuple:

    def forward(self, positions):
        positions = pt.index_select(positions, 0, self.all_atom_indices).to(pt.float32) * 10 # nm --> A
        energies, _ = self.model(self.atomic_numbers, positions, batch=self.batch, q=self.total_charges)
                      ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

With this change, I can do:

energies, *_ = self.model(self.atomic_numbers, positions, batch=self.batch, q=self.total_charges)

and I can switch between an ensemble and a single model without changing any code.

(sef43, Apr 02 '24 13:04)
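A minimal sketch of why one shared tuple type makes the two kinds of model interchangeable, assuming both return the same fixed-arity Tuple with Optional[Tensor] slots instead of a Union of differently sized tuples. SingleModel, EnsembleModel and EnergyWrapper are hypothetical toy classes, not torchmd-net code, and the sketch unpacks with explicit placeholders rather than the starred form above.

```python
from typing import Optional, Tuple

import torch
from torch import Tensor


class SingleModel(torch.nn.Module):
    """Hypothetical single model: only the energy slot is filled, the rest are None."""

    def forward(
        self, z: Tensor
    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
        return z.sum(), None, None, None


class EnsembleModel(torch.nn.Module):
    """Hypothetical ensemble: mean energy over two toy members plus their spread."""

    def forward(
        self, z: Tensor
    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
        members = torch.stack([z.sum(), z.sum() * 2.0])
        return members.mean(), None, members.std(), None


class EnergyWrapper(torch.nn.Module):
    """Hypothetical caller whose code is identical for either model."""

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(self, z: Tensor) -> Tensor:
        # The model returns one concrete Tuple type, so TorchScript can unpack it
        # (a Union of tuples cannot be unpacked, as the error above shows).
        energy, _, _, _ = self.model(z)
        return energy


z = torch.ones(10)
single = torch.jit.script(EnergyWrapper(SingleModel()))
ensemble = torch.jit.script(EnergyWrapper(EnsembleModel()))
print(single(z), ensemble(z))
```

Because the wrapper only ever sees one concrete tuple type, the same unpacking line compiles for either model; the trade-off is that callers receive None in the slots a given model does not fill.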

Please add a test to check for TorchScript compatibility. Is Ensemble supposed to be able to mimic TorchMD_Net? If so, it should return just y and neg_dy by default, without the Nones.

(RaulPPelaez, Apr 02 '24 14:04)

No, I think at this point we have given up on trying to mimic the TorchMD_Net class outputs.

(stefdoerr, Apr 02 '24 14:04)
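A sketch of the requested TorchScript compatibility test, written as a pytest-style function around hypothetical toy modules (ToyEnsemble, ToyWrapper) rather than the real Ensemble class. It scripts both the model and a wrapper that unpacks the model's output, because the error reported at the top of this thread only appears when a scripted caller unpacks the result.

```python
from typing import List, Tuple

import torch
from torch import Tensor


class ToyEnsemble(torch.nn.Module):
    """Hypothetical stand-in for Ensemble: averages several small linear heads."""

    def __init__(self, n_models: int = 3, n_features: int = 4):
        super().__init__()
        self.heads = torch.nn.ModuleList(
            [torch.nn.Linear(n_features, 1) for _ in range(n_models)]
        )

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        preds: List[Tensor] = []
        for head in self.heads:
            preds.append(head(x))
        stacked = torch.stack(preds)  # shape: (n_models, n_samples, 1)
        return stacked.mean(dim=0), stacked.std(dim=0)


class ToyWrapper(torch.nn.Module):
    """Hypothetical caller that unpacks the model output inside TorchScript."""

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x: Tensor) -> Tensor:
        energy, _ = self.model(x)
        return energy


def test_ensemble_torchscript():
    torch.manual_seed(0)
    model = ToyEnsemble()
    x = torch.randn(5, 4)

    # Scripting the model alone is not enough: also script a caller that unpacks it.
    scripted_model = torch.jit.script(model)
    scripted_wrapper = torch.jit.script(ToyWrapper(model))

    # Scripted and eager outputs should agree.
    mean_eager, std_eager = model(x)
    mean_script, std_script = scripted_model(x)
    assert torch.allclose(mean_eager, mean_script)
    assert torch.allclose(std_eager, std_script)
    assert torch.allclose(scripted_wrapper(x), mean_eager)
```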

This works:

import torch
from torch import Tensor
from typing import Union, Tuple


class Mymod(torch.nn.Module):

    def __init__(self, return3=False):
        super(Mymod, self).__init__()
        self.return3 = return3

    def forward(
        self, z: Tensor
    ) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]]:
        if self.return3:
            return z, z, z
        else:
            return z, z

mymod = torch.jit.script(Mymod())
z = torch.ones(10)
o1, o2 = mymod(z)
mymod = torch.jit.script(Mymod(return3=True))
o1, o2, o3 = mymod(z)
mymod = torch.jit.script(Mymod())
o1, _ = mymod(z)

(RaulPPelaez, Apr 02 '24 14:04)

Yes, but I can't make this work:

import torch
from torch import Tensor
from typing import Union, Tuple


class Mymod(torch.nn.Module):

    def __init__(self, return3=False):
        super(Mymod, self).__init__()
        self.return3 = return3

    def forward(
        self, z: Tensor
    ) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]]:
        if self.return3:
            return z, z, z
        else:
            return z, z


class MymodWrapper(torch.nn.Module):
    def __init__(self):
        super(MymodWrapper, self).__init__()

        self.model = Mymod()
        
    def forward(
        self, z: Tensor
    ) -> Tensor:

        o1,_ = self.model(z)

        return o1


mymod = torch.jit.script(Mymod())
z = torch.ones(10)
o1, o2 = mymod(z)
mymod = torch.jit.script(Mymod(return3=True))
o1, o2, o3 = mymod(z)
mymod = torch.jit.script(Mymod())
o1, _ = mymod(z)

mymodwrapper = MymodWrapper()

o1 = mymodwrapper(z)

mymodwrapper = torch.jit.script(MymodWrapper())

o1 = mymodwrapper(z)

Traceback (most recent call last):
  File "/home/sfarr/torchmd-net/tests/temp.py", line 48, in <module>
    mymodwrapper = torch.jit.script(MymodWrapper())
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sfarr/miniconda3/envs/tn_atm_dev/lib/python3.11/site-packages/torch/jit/_script.py", line 1324, in script
    return torch.jit._recursive.create_script_module(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sfarr/miniconda3/envs/tn_atm_dev/lib/python3.11/site-packages/torch/jit/_recursive.py", line 559, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sfarr/miniconda3/envs/tn_atm_dev/lib/python3.11/site-packages/torch/jit/_recursive.py", line 636, in create_script_module_impl
    create_methods_and_properties_from_stubs(
  File "/home/sfarr/miniconda3/envs/tn_atm_dev/lib/python3.11/site-packages/torch/jit/_recursive.py", line 469, in create_methods_and_properties_from_stubs
    concrete_type._create_methods_and_properties(
RuntimeError: 
Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]] cannot be used as a tuple:
  File "/home/sfarr/torchmd-net/tests/temp.py", line 31
    ) -> Tensor:
    
        o1,_ = self.model(z)
               ~~~~~~~~~~~~ <--- HERE
    
        return o1

For OpenMM-Torch compatibility, there will be wrapper code that needs to be TorchScripted too.

(sef43, Apr 02 '24 14:04)

Dang! That looks like a PyTorch bug; I do not see why one would work and the other not. OK, let's drop TorchMD_Net compatibility for the moment. EDIT: I opened https://github.com/pytorch/pytorch/issues/123168

(RaulPPelaez, Apr 02 '24 14:04)

No need for an urgent merge here; we will see if it works in production simulations.

(sef43, Apr 02 '24 14:04)
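For the OpenMM-Torch wrapper mentioned earlier in the thread, a minimal sketch assuming the model returns a single fixed-arity tuple type rather than a Union. ToyPotential and ForceWrapper are hypothetical; the nm to Angstrom factor of 10 mirrors the forward shown in the first error message. The sketch only checks that the wrapper scripts and survives a torch.jit.save/torch.jit.load round trip through an in-memory buffer; it does not exercise OpenMM itself.

```python
import io
from typing import Optional, Tuple

import torch
from torch import Tensor


class ToyPotential(torch.nn.Module):
    """Hypothetical model returning (energy, optional extra) as one fixed Tuple type."""

    def forward(self, positions: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
        # Toy energy: sum of squared coordinates (positions in Angstrom).
        energy = (positions ** 2).sum()
        return energy, None


class ForceWrapper(torch.nn.Module):
    """Hypothetical OpenMM-style wrapper: positions in nm in, energy out."""

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(self, positions: Tensor) -> Tensor:
        positions = positions.to(torch.float32) * 10  # nm --> Angstrom
        # Unpacking compiles because the return type is a single Tuple, not a Union.
        energy, _ = self.model(positions)
        return energy


scripted = torch.jit.script(ForceWrapper(ToyPotential()))

# Round-trip through an in-memory buffer instead of a file on disk.
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)
reloaded = torch.jit.load(buffer)

positions_nm = torch.full((3, 3), 0.1)
print(reloaded(positions_nm))
```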