functorch
make_functional + parameter sharing with a list doesn't work
Initially reported by @michaelarbel here: https://github.com/pytorch/functorch/pull/620
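`make_functional` does not work when parameters are shared by assigning a plain Python list of parameters as a module attribute. The functional call substitutes the module's registered parameters, but the list still holds references to the original `Parameter` objects, so any code that reads parameters through the list sees the old values.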
Repro from @michaelarbel:
import torch
import torch.nn as nn
from functorch import make_functional, make_functional_with_buffers
# Seed the RNG so nn.Linear's random initialization (and therefore the printed
# "current output" reported below) is reproducible.
torch.manual_seed(0)
class MyModule(nn.Module):
    """Wrap a layer and also expose its parameters through a plain Python list.

    This is the repro for the bug: ``self.weights`` is an ordinary ``list``,
    not an ``nn.ParameterList``, so ``nn.Module`` does not register it as a
    parameter container. Its entries are the very same ``Parameter`` objects
    owned by ``self.linear`` (they alias them, they are not copies).

    Args:
        linear: The wrapped layer. Its first parameter (for ``nn.Linear``, the
            weight) is additionally used in a regularization term.
    """

    def __init__(self, linear):
        super().__init__()
        self.linear = linear
        # Share the wrapped layer's parameters through a plain list.
        # NOTE(review): because this list is not registered with nn.Module,
        # functorch's make_functional cannot substitute these entries. That is
        # presumably the root cause of the reported bug; kept as-is on purpose.
        self.weights = list(self.linear.parameters())

    def forward(self, x):
        """Return ``sum(linear(x) ** 2) + sum(weights[0] ** 2)``."""
        loss = torch.sum(self.linear(x) ** 2)
        # Reads the first parameter via the list, not via self.linear.
        regularization = torch.sum(self.weights[0] ** 2)
        return loss + regularization
# Build the module. The Linear's parameters are drawn from the seeded RNG, so the
# order of statements here determines the exact "current output" printed below.
linear = nn.Linear(20, 30)
module = MyModule(linear)
# Per functorch's API: a stateless function plus the module's registered
# parameters and buffers. MyModule.weights is not among the registered params.
func, params, buffers = make_functional_with_buffers(module)
data = torch.ones([20])
# Substitute all-zero tensors for every registered parameter. With zero weight
# and bias, linear(data) is zero, so the expected loss is 0.
zeros_params = [torch.zeros_like(p, requires_grad=True) for p in params]
loss = func(zeros_params,buffers,data)
# A non-zero value indicates the regularization term still read the original
# weight through MyModule.weights.
print(loss)
Expected output: 0. Current output: 10.3369.