
make_functional + parameter sharing with a list doesn't work

Status: Open. Opened by zou3519.

Initially reported by @michaelarbel here: https://github.com/pytorch/functorch/pull/620

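make_functional does not work with parameter sharing when the sharing is set up by assigning a plain Python list of parameters to a module attribute. Parameters reached through the list are not replaced by the parameters passed to the functional version of the module.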
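A minimal diagnostic sketch of why this happens, assuming (as functorch's implementation appears to) that the functional version swaps in new tensors by rebinding module attributes. It defines a trimmed-down copy of MyModule from the repro and checks two things: the list contributes no registered parameters, and its entries keep pointing at the original Parameter objects after the attribute is rebound.

```python
import torch
import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self, linear):
        super().__init__()
        self.linear = linear  # nn.Module attribute: registered as a submodule
        self.weights = list(self.linear.parameters())  # plain list: not registered


module = MyModule(nn.Linear(20, 30))

# Only the Linear's own parameters are registered; the list adds nothing.
names = [name for name, _ in module.named_parameters()]
print(names)  # ['linear.weight', 'linear.bias']

# The list entries are the same Parameter objects the Linear owns.
original = module.linear.weight
print(module.weights[0] is original)  # True

# Swap in a new tensor by rebinding the attribute (assumed to mirror what a
# functional call does internally). The list still points at the old object.
del module.linear.weight
module.linear.weight = torch.zeros(30, 20)
print(module.weights[0] is original)  # True
print(module.weights[0] is module.linear.weight)  # False
```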

Repro from @michaelarbel:

```python
import torch
import torch.nn as nn
from functorch import make_functional, make_functional_with_buffers

torch.manual_seed(0)


class MyModule(nn.Module):
    def __init__(self, linear):
        super(MyModule, self).__init__()
        self.linear = linear
        # Plain Python list holding references to linear's parameters
        self.weights = list(self.linear.parameters())

    def forward(self, x):
        loss = torch.sum(self.linear(x) ** 2)
        # Regularizes linear.weight through the list alias
        regularization = torch.sum(self.weights[0] ** 2)
        return loss + regularization


linear = nn.Linear(20, 30)
module = MyModule(linear)

func, params, buffers = make_functional_with_buffers(module)

data = torch.ones([20])
# Replace every parameter with zeros; the loss should then be exactly 0
zeros_params = [torch.zeros_like(p, requires_grad=True) for p in params]

loss = func(zeros_params, buffers, data)

print(loss)
```

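Expected output: 0
Current output: 10.3369

With every parameter zeroed, both self.linear(x) and self.weights[0] should be zero, so the loss should be 0. The nonzero result indicates that the regularization term still reads the original linear.weight through the list.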
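One possible workaround, offered as a sketch rather than a fix in functorch: replace the cached list with a property that looks the parameters up at call time, so whatever tensors the functional call has swapped in are the ones the regularizer sees. This sketch swaps in torch.func.functional_call (PyTorch 2.0 and later, where functorch's APIs moved into torch.func) in place of make_functional_with_buffers; the same property should also help with functorch, but that is an assumption not tested here.

```python
import torch
import torch.nn as nn
from torch.func import functional_call  # available in PyTorch >= 2.0


class MyModule(nn.Module):
    def __init__(self, linear):
        super().__init__()
        self.linear = linear

    @property
    def weights(self):
        # Look the parameters up on every access instead of caching them in a
        # list, so tensors swapped in by functional_call are picked up.
        return [self.linear.weight, self.linear.bias]

    def forward(self, x):
        loss = torch.sum(self.linear(x) ** 2)
        regularization = torch.sum(self.weights[0] ** 2)
        return loss + regularization


torch.manual_seed(0)
module = MyModule(nn.Linear(20, 30))
data = torch.ones(20)

# Zero out every registered parameter for the functional call.
zero_params = {
    name: torch.zeros_like(p, requires_grad=True)
    for name, p in module.named_parameters()
}

loss = functional_call(module, zero_params, (data,))
print(loss)
```

Because the property holds no state of its own, the module still registers only linear.weight and linear.bias, and the shared weight is always resolved through the Linear submodule.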

zou3519, Jun 22 '22 18:06