
Add an Ensemble Module that is constructed from a list of Modules and encapsulates the necessary state

Open sinking-point opened this issue 2 years ago • 6 comments

Most of the examples I've seen use vmap at the top level, to create an 'outer' ensemble of models or to factor out the batch dimension. My use case, however, is 'inner' ensembles of modules within a larger model. This means I have to register the parameters and buffers from combine_state_for_ensemble with the parent module myself, which is annoying and messy.
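For context, a minimal sketch of the current workflow being described (the module sizes and names here are made up for illustration):

```python
import torch
import torch.nn as nn
from functorch import combine_state_for_ensemble, vmap

# Five ensemble members with identical architecture.
models = [nn.Linear(4, 3) for _ in range(5)]

# Stack their state into per-parameter tensors with a leading ensemble dim,
# and get a single functional model that takes (params, buffers, input).
fmodel, params, buffers = combine_state_for_ensemble(models)

# in_dims=(0, 0, None): map over the stacked state, share the input x.
x = torch.randn(4)
out = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, x)
# out stacks the five member outputs along dim 0.
```

If this ensemble lives inside a larger nn.Module, the stacked `params` and `buffers` must additionally be registered on the parent by hand, which is the boilerplate this issue asks to eliminate.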

An obvious solution is to create an Ensemble module which internally calls combine_state_for_ensemble and vmap along with storing the necessary state:

self.ens = Ensemble(my_modules, in_dims=(0, 0, 2), out_dims=(0, 0, 2))
...
x = self.ens(x)

Even if registering the state weren't an issue, I still think this would be a popular feature. It's more intuitive than the current method of creating ensembles.

sinking-point avatar Aug 01 '22 18:08 sinking-point

Something like this, perhaps:

import torch.nn as nn
from functorch import combine_state_for_ensemble, vmap

class Ensemble(nn.Module):
    def __init__(self, modules, **kwargs):
        super().__init__()

        # Stack the per-module state and get a single functional model.
        fmodel, self.params, self.buffers = combine_state_for_ensemble(modules)

        self.vmap_model = vmap(fmodel, **kwargs)

        # Register the stacked tensors so the parent module sees them in
        # .parameters(), .buffers(), state_dict(), .to(), etc.
        for i, param in enumerate(self.params):
            self.register_parameter('param_' + str(i), nn.Parameter(param))

        for i, buffer in enumerate(self.buffers):
            # register_buffer takes a plain tensor; torch.nn has no nn.Buffer
            self.register_buffer('buffer_' + str(i), buffer)

    def forward(self, *args, **kwargs):
        return self.vmap_model(self.params, self.buffers, *args, **kwargs)
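For reference, a runnable restatement of this sketch with a quick usage check. It assumes `register_buffer` with a plain tensor (torch.nn has no `nn.Buffer` class here), `in_dims=(0, 0, None)` so all members share the input, and arbitrary Linear sizes:

```python
import torch
import torch.nn as nn
from functorch import combine_state_for_ensemble, vmap

class Ensemble(nn.Module):
    """Sketch only: wraps combine_state_for_ensemble + vmap and registers
    the stacked state with the parent module."""

    def __init__(self, modules, **kwargs):
        super().__init__()
        fmodel, self.params, self.buffers = combine_state_for_ensemble(modules)
        self.vmap_model = vmap(fmodel, **kwargs)
        for i, param in enumerate(self.params):
            self.register_parameter('param_' + str(i), nn.Parameter(param))
        for i, buffer in enumerate(self.buffers):
            self.register_buffer('buffer_' + str(i), buffer)

    def forward(self, *args, **kwargs):
        return self.vmap_model(self.params, self.buffers, *args, **kwargs)

# An 'inner' ensemble, as in the use case from the issue: three Linear
# members sharing one input.
ens = Ensemble([nn.Linear(8, 4) for _ in range(3)], in_dims=(0, 0, None))
x = torch.randn(8)
y = ens(x)  # shape (3, 4): one row per ensemble member

# The stacked state is now visible to any parent module that owns `ens`:
# one registered parameter per base-model parameter (stacked weight, bias).
n_params = len(list(ens.parameters()))
```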

sinking-point avatar Aug 01 '22 18:08 sinking-point

This seems convenient to have. I am not sure if this would go into functorch or in torch.nn in the long-term state, but we can certainly toss something like this into functorch to start. cc @samdow who is thinking about functional modules. Also curious to hear @jbschlosser and @albanD's opinions as torch.nn maintainers.

zou3519 avatar Aug 02 '22 14:08 zou3519

This would need to be part of a bigger plan to move things like combine_state_for_ensemble as well? Also this seems to be very vmap specific?

albanD avatar Aug 02 '22 14:08 albanD

> Also this seems to be very vmap specific?

Are you suggesting that we should put the nn.Ensemble API into functorch because it is vmap specific?

zou3519 avatar Aug 02 '22 14:08 zou3519

I did wonder about this because my suggestion is not really functional. It doesn't fit with the theme of this package. However, this is the only place it can go since torch can't have functorch as a dependency. Unless we create a new package for this, I guess.

sinking-point avatar Aug 03 '22 00:08 sinking-point

> Are you suggesting that we should put the nn.Ensemble API into functorch because it is vmap specific?

Not necessarily but it does sound much "higher level" than things currently in torch.nn. So not sure where it should live.

albanD avatar Aug 03 '22 15:08 albanD