
Conversion of model to functional breaks with internal methods

Open · chrhenning opened this issue 3 years ago · 1 comment

Hey,

thanks so much for providing the repo. At the moment, it seems that the code can't handle PyTorch models whose forward method calls other (internal) methods of the model. I think it would be a great feature if such models could be supported in the future.

Here is an example model that doesn't work (assuming that I'm using the code correctly):

import torch


class SimpleMLP(torch.nn.Module):
    def __init__(self, n_in=1, n_out=1, layers=()):
        super().__init__()

        # Hidden layers: one Linear + ReLU pair per entry in `layers`.
        prev_dim = n_in
        self._hidden = torch.nn.Sequential()

        for i, width in enumerate(layers):
            self._hidden.add_module('hidden_layer_%d' % i,
                torch.nn.Linear(prev_dim, width))
            self._hidden.add_module('relu_%d' % i, torch.nn.ReLU())
            prev_dim = width

        # Linear output layer.
        self._out_layer = torch.nn.Linear(prev_dim, n_out)

    def forward(self, x):
        # Delegates to an internal method; this is the call that breaks.
        return self._int_forward(x)

    def _int_forward(self, x):
        h = self._hidden(x)
        return self._out_layer(h)

The function hamiltorch.util.make_functional does not account for the call to the method _int_forward, so the converted model can't resolve it.
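To illustrate what I suspect is going on, here is a minimal sketch in plain Python (no PyTorch needed). It rests on an assumption that I haven't verified against the hamiltorch source: that the conversion runs forward on a bare stand-in object that only carries the parameters. The names Model, Stub, naive_functional and method_aware_functional are all hypothetical and only serve the illustration.

```python
class Model:
    """Toy stand-in for a torch.nn.Module with an internal helper method."""

    def __init__(self, weight):
        self.weight = weight

    def forward(self, x):
        return self._int_forward(x)

    def _int_forward(self, x):
        return self.weight * x


class Stub:
    """Bare object that receives data attributes only, no methods."""


def naive_functional(model):
    # ASSUMPTION: mimics a conversion that calls forward on a bare stub.
    forward = type(model).forward

    def fmodel(x, weight):
        stub = Stub()
        stub.weight = weight
        return forward(stub, x)  # stub has no _int_forward

    return fmodel


def method_aware_functional(model):
    # Hypothetical fix: build the stand-in from the model's own class,
    # so that helper methods still resolve through normal attribute lookup.
    cls = type(model)

    def fmodel(x, weight):
        stub = cls.__new__(cls)  # skips __init__, keeps the class methods
        stub.weight = weight
        return cls.forward(stub, x)

    return fmodel


model = Model(2.0)

try:
    naive_functional(model)(3.0, weight=5.0)
    naive_error = None
except AttributeError as exc:
    naive_error = exc  # the helper method is missing on the stub

result = method_aware_functional(model)(3.0, weight=5.0)
```

In this toy version the naive conversion raises an AttributeError because the stub has no _int_forward, while the method-aware variant evaluates the model with the externally supplied weight. Whether the same mechanism applies to make_functional would need to be checked in the actual code.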

Again, thanks for the great repo, Christian

chrhenning · Aug 24 '20 16:08

Hi Christian,

Thanks very much for bringing this example to my attention! I will make sure that I look into that soon.

I plan on making a few code updates soon and will also try to test out your example when I push the next updates.

Thanks very much for using the repo!

All the best,

Adam

AdamCobb · Aug 24 '20 19:08