functorch
Use functional models inside usual nn.Module
Hi, thanks for adding functional features to PyTorch. I want to use an nn.Module converted into functional form inside a regular, stateful nn.Module. However, the code below does not correctly register the parameters of the functional module. Is there currently a way to do this?
```python
import torch
import torch.nn as nn
from functorch import make_functional

x = torch.randn(4, 10)

class TinyModel(torch.nn.Module):
    def __init__(self):
        super(TinyModel, self).__init__()
        # Stateless function for nn.Linear plus its parameters as a tuple.
        self.func_l, self.params_l = make_functional(nn.Linear(10, 10))
        # Register each parameter on the outer module under the names "0", "1".
        for i, ele in enumerate(self.params_l):
            self.register_parameter(str(i), ele)

    def forward(self, inputs):
        return self.func_l(self.params_l, inputs)

model = TinyModel()
func, params = make_functional(model)
```
This is useful for me because I want to apply functional transforms (such as vmap, jvp, vjp) to an inner nn.Module inside the forward pass of an outer nn.Module. The idea is to have lifted versions of vjp, jvp, etc., similar to Flax (https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.vjp.html).
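I found a way to do this: wrap the parameters returned by make_functional in a torch.nn.ParameterList. That registers them on the outer module, and forward reads them from that registered container instead of from a plain tuple. Here is sample code: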
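```python
import torch
import functorch

class LinearModule(torch.nn.Module):
    def __init__(self):
        super(LinearModule, self).__init__()
        # Stateless function for nn.Linear plus its parameters as a tuple.
        self.model, params = functorch.make_functional(torch.nn.Linear(10, 20))
        # ParameterList registers the parameters on this module.
        self.params = torch.nn.ParameterList(params)

    def forward(self, inputs):
        return self.model(self.params, inputs)
```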
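On PyTorch 2.0 and later, functorch.make_functional is deprecated in favour of torch.func.functional_call. The sketch below is a hypothetical variant, not part of the answer above: FuncLinearModule keeps nn.Linear as an ordinary submodule, so its parameters are registered without a ParameterList, and calls it through functional_call inside forward (the spot where a transform such as vmap or vjp would go). It then makes the outer module functional by passing all-zero parameters; the all-zero output shows that the swapped-in parameters reach the inner call.

```python
import torch
import torch.nn as nn
from torch.func import functional_call


class FuncLinearModule(nn.Module):
    """Hypothetical variant of LinearModule built on torch.func."""

    def __init__(self):
        super().__init__()
        # Registered automatically as "linear.weight" and "linear.bias".
        self.linear = nn.Linear(10, 20)

    def forward(self, inputs):
        # Call the inner module functionally with its current parameters.
        params = dict(self.linear.named_parameters())
        return functional_call(self.linear, params, (inputs,))


model = FuncLinearModule()
x = torch.randn(4, 10)

# Make the outer module functional, as make_functional(model) was meant to.
params = {name: torch.zeros_like(p) for name, p in model.named_parameters()}
out = functional_call(model, params, (x,))
```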