
Performant way to initialize an ensemble of models

Open zou3519 opened this issue 2 years ago • 2 comments

Right now, to initialize an ensemble of, say, 350 models, we first create 350 individual models and then combine their states with combine_state_for_ensemble. This leaves some performance on the table; the fastest thing we could do is initialize the combined (stacked) state in one go.
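For reference, here is a minimal sketch of the current two-step workflow the issue describes (the nn.Linear architecture and sizes are illustrative, not from the issue):

```python
import torch
import torch.nn as nn
from functorch import combine_state_for_ensemble

num_models = 350  # ensemble size mentioned in the issue

# Step 1: materialize every model individually (the part this issue wants to avoid).
models = [nn.Linear(128, 10) for _ in range(num_models)]

# Step 2: stack their parameters and buffers into single batched tensors.
# fmodel is a functional version of the model; each tensor in `params` has a
# leading dimension of size num_models.
fmodel, params, buffers = combine_state_for_ensemble(models)
```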

This might not be too difficult to do. Idea from discussion with @Chillee is:

  • we could have torch.empty, torch.tensor, etc. automatically return a repeated tensor when running under vmap; this could be a flag on vmap
  • then we would just be able to use vmap with randomness="different" to initialize a model (see the sketch after this list)
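A hedged sketch of what that could look like. The randomness="different" argument to vmap already exists; the hypothetical piece is having factory functions like torch.empty return a batched tensor under vmap, so the snippet below only exercises the part that works today (per-instance random values under vmap):

```python
import torch
from functorch import vmap

num_models = 350

def init_weight(_):
    # Under randomness="different", each vmapped instance draws its own random
    # values, so the stacked result holds num_models distinct weight matrices.
    return torch.randn(10, 128) * 0.02

# The dummy batched input only fixes the batch (ensemble) size.
stacked_weight = vmap(init_weight, randomness="different")(torch.zeros(num_models))
# stacked_weight has shape (num_models, 10, 128)
```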

zou3519 · Jun 24 '22

@zou3519 I'm curious if you have any thoughts on how scheduling would work then. If all the models are being combined into a single output, then you're going to be waiting for the slowest model, and if you have a DAG of models, then again you'll be bottlenecked by the slower ones.

Is there any way to associate a given model with specific resources, so you could say something like

combine_state_for_ensemble([m1, m2, m3], [0.2, 0.8, 3]), which would mean: combine into a single model where m1 has 20% of a GPU and m3 has 3 GPUs available?

msaroufim · Aug 16 '22

In order to combine the models together for ensembling with vmap, each model must call exactly the same sequence of PyTorch operations (otherwise, vmap will not work). So in that case there isn't a "slowest model": when run separately, the models are expected to take roughly the same amount of time.
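To illustrate, here is a small self-contained sketch (a toy nn.Linear ensemble) of how the combined model runs: after combine_state_for_ensemble, the whole ensemble executes as a single vmapped call, so every member runs the same kernels together rather than being scheduled independently.

```python
import torch
import torch.nn as nn
from functorch import combine_state_for_ensemble, vmap

# A small ensemble of identical architectures (required for vmap ensembling).
models = [nn.Linear(128, 10) for _ in range(4)]
fmodel, params, buffers = combine_state_for_ensemble(models)

minibatch = torch.randn(64, 128)
# in_dims=(0, 0, None): map over the stacked model dimension, share the minibatch.
predictions = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)
print(predictions.shape)  # torch.Size([4, 64, 10])
```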

zou3519 · Aug 16 '22