functorch

Use functional models inside a regular nn.Module

Open · subho406 opened this issue 1 year ago · 1 comment

Hi, thanks for adding functional features to PyTorch. I want to use an nn.Module converted into functional form inside a regular, stateful nn.Module. However, the code below does not correctly register the parameters of the functional module. Is there currently a way to do this?

```python
import torch
import torch.nn as nn
from functorch import make_functional

x = torch.randn(4, 10)

class TinyModel(torch.nn.Module):

    def __init__(self):
        super().__init__()
        # Convert an inner Linear layer to functional form: a stateless
        # callable plus a tuple of its parameters.
        self.func_l, self.params_l = make_functional(nn.Linear(10, 10))
        # Attempt to register the extracted parameters on the outer module.
        for i, ele in enumerate(self.params_l):
            self.register_parameter(str(i), ele)

    def forward(self, inputs):
        return self.func_l(self.params_l, inputs)

model = TinyModel()
func, params = make_functional(model)
```

This is useful for me because I want to apply functional transforms (such as vmap, jvp, and vjp) to an inner nn.Module inside the forward pass of an outer nn.Module. The idea is to have lifted versions of vjp, jvp, etc., similar to Flax (https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.vjp.html).
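A minimal sketch of such a lifted vjp, assuming PyTorch 2.x, where `torch.func.functional_call` and `torch.func.vjp` are available (neither appears in the original question, which uses `functorch.make_functional`). `OuterModel` and `inner_fn` are hypothetical names, and the all-ones cotangent is an arbitrary choice for illustration. The inner module stays an ordinary registered submodule; inside `forward`, its parameters are passed explicitly to a pure function so that a transform can act on them.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, vjp


class OuterModel(nn.Module):
    """Hypothetical outer module that applies vjp to an inner module in forward()."""

    def __init__(self):
        super().__init__()
        # Ordinary submodule: its weight and bias show up in
        # OuterModel.parameters() and state_dict() with no extra bookkeeping.
        self.inner = nn.Linear(10, 10)

    def forward(self, x):
        # Current parameters of the inner module, keyed by name ("weight", "bias").
        params = dict(self.inner.named_parameters())

        # Pure function of the parameters; the input x is closed over.
        def inner_fn(p):
            return functional_call(self.inner, p, (x,))

        # vjp returns the output and a function mapping an output cotangent
        # to cotangents with the same structure as the primal (here a dict).
        out, vjp_fn = vjp(inner_fn, params)
        (param_cotangents,) = vjp_fn(torch.ones_like(out))
        return out, param_cotangents


model = OuterModel()
x = torch.randn(4, 10)
out, grads = model(x)
print(out.shape, grads["weight"].shape, grads["bias"].shape)
```

Because the inner module is never converted with `make_functional`, there is no separate parameter tuple to keep in sync with the registered parameters.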

subho406 · Feb 15 '23 08:02

I figured out a way to do this. Here is some sample code:

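```python
import functorch
import torch

class LinearModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Functional version of an inner Linear layer plus its parameters.
        self.model, params = functorch.make_functional(torch.nn.Linear(10, 20))
        # ParameterList registers the parameters on this module (.parameters()).
        self.params = torch.nn.ParameterList(params)

    def forward(self, inputs):
        return self.model(self.params, inputs)
```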
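For comparison, a sketch of the same module without `make_functional`, assuming PyTorch 2.x, where `functorch.make_functional` is deprecated (since 2.0) in favor of `torch.func.functional_call`. This is an alternative, not the approach from the comment above: the Linear layer stays an ordinary submodule, and `forward` calls it statelessly under `torch.func.vmap` as an example of wrapping it in a transform. The helper name `single` is hypothetical.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, vmap


class LinearModule(nn.Module):
    """Alternative to the ParameterList pattern, assuming torch.func is available."""

    def __init__(self):
        super().__init__()
        # Ordinary submodule: no make_functional, no manual parameter registration.
        self.linear = nn.Linear(10, 20)

    def forward(self, inputs):
        params = dict(self.linear.named_parameters())

        # Stateless call on one sample; any torch.func transform can wrap it.
        def single(sample):
            return functional_call(self.linear, params, (sample,))

        # vmap maps the stateless call over the batch dimension of `inputs`.
        return vmap(single)(inputs)


model = LinearModule()
x = torch.randn(4, 10)
model(x).sum().backward()  # gradients reach model.linear.weight and .bias
print([name for name, _ in model.named_parameters()])
```

Since the parameters belong to `self.linear`, they are registered, saved in `state_dict()`, and updated by an optimizer built from `model.parameters()` as usual.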

subho406 · Feb 18 '23 09:02