hamiltorch
Conversion of model to functional breaks with internal methods
Hey,
thanks so much for providing the repo. At the moment, it seems the code can't handle PyTorch models that call internal helper methods inside their forward method. I think it would be a great feature if such models could be supported in the future.
Here is an example model that doesn't work (assuming I am using the code correctly):
```python
class SimpleMLP(torch.nn.Module):
    def __init__(self, n_in=1, n_out=1, layers=()):
        super(SimpleMLP, self).__init__()
        ### Hidden layers.
        prev_dim = n_in
        self._hidden = torch.nn.Sequential()
        for i in range(len(layers)):
            self._hidden.add_module('hidden_layer_%d' % i,
                                    torch.nn.Linear(prev_dim, layers[i]))
            self._hidden.add_module('relu_%d' % i, torch.nn.ReLU())
            prev_dim = layers[i]
        ### Generate linear output layer.
        h_size = prev_dim
        self._out_layer = torch.nn.Linear(h_size, n_out)

    def forward(self, x):
        return self._int_forward(x)

    def _int_forward(self, x):
        h = self._hidden(x)
        return self._out_layer(h)
```
The reference to the method `_int_forward` is not picked up by the function `hamiltorch.util.make_functional`.
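In the meantime, a possible workaround (a sketch, not an official fix; the layer widths and the usage at the bottom are made up for illustration) is to inline the helper's body into `forward`, so that `forward` only calls registered submodules directly and no private method reference is involved:

```python
import torch


class SimpleMLP(torch.nn.Module):
    """Same architecture as above, but without the private _int_forward helper."""

    def __init__(self, n_in=1, n_out=1, layers=()):
        super(SimpleMLP, self).__init__()
        ### Hidden layers.
        prev_dim = n_in
        self._hidden = torch.nn.Sequential()
        for i in range(len(layers)):
            self._hidden.add_module('hidden_layer_%d' % i,
                                    torch.nn.Linear(prev_dim, layers[i]))
            self._hidden.add_module('relu_%d' % i, torch.nn.ReLU())
            prev_dim = layers[i]
        ### Generate linear output layer.
        self._out_layer = torch.nn.Linear(prev_dim, n_out)

    def forward(self, x):
        # Former body of _int_forward, inlined so forward only touches
        # registered submodules.
        h = self._hidden(x)
        return self._out_layer(h)


# Hypothetical usage: two hidden layers of width 8, batch of 4 inputs.
model = SimpleMLP(n_in=1, n_out=1, layers=(8, 8))
y = model(torch.randn(4, 1))
print(y.shape)  # torch.Size([4, 1])
```

Since the model's behaviour is unchanged, this restructuring only affects how the forward pass is expressed, which may be enough for the functional conversion to succeed.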
Again, thanks for the great repo, Christian
Hi Christian,
Thanks very much for bringing this example to my attention! I will make sure that I look into that soon.
I plan to push a few code updates soon and will test out your example as part of those changes.
Thanks very much for using the repo!
All the best,
Adam