
Train ensemble models with vmap

kxhit opened this issue 2 years ago · 3 comments

Thanks a lot for developing functorch!

I checked the tutorial and tried vmap to ensemble multiple models. I'm curious: can I use vmap to do multi-model training? For example, can I use the batch_pred from vmap and directly call loss.backward() and optimiser.step() to train the ensemble model? Is there anything I should take care of? Thanks a lot!

kxhit avatar May 05 '22 22:05 kxhit

> I'm curious: can I use vmap to do multi-model training?

Yes!

> For example, can I use the batch_pred from vmap and directly call loss.backward() and optimiser.step() to train the ensemble model?

In this case, you're not actually training N models in parallel - you're effectively creating one massive "ensemble" model that you train as a single unit (presumably you sum the outputs or losses before calling backward). I believe you get identical gradients compared to calling backward once per model, but there could be ramifications for the optimizer (since you're only using one of them). For example, you'd probably be forced to use the same LR schedule for every model with our optimizers today.
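The "identical gradients" point follows from the models being independent in the autograd graph: when the per-model losses are summed, each model's parameters only receive gradient from their own term. A minimal sketch of that reasoning (the tensors here are illustrative stand-ins, not from this thread):

```python
import torch

# Two independent "models", reduced to a single weight tensor each.
w1 = torch.randn(3, requires_grad=True)
w2 = torch.randn(3, requires_grad=True)
x = torch.randn(3)

# Sum the per-model losses and call backward() once.
loss = (w1 * x).sum() + (w2 * x).sum()
loss.backward()

# Each weight only sees the gradient of its own term, exactly as if the
# two losses had been backpropagated separately.
assert torch.allclose(w1.grad, x) and torch.allclose(w2.grad, x)
```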

Chillee avatar May 06 '22 01:05 Chillee

Thanks a lot for the quick reply! Yes, we can do batch training with vmap. I found that we can easily set different learning rates for different layers' parameters using optim.add_param_group(), but it seems hard to give each model its own learning rate: the parameters of the batched models are grouped into single batched tensors (B_model x [tensor_size] -> [B_model, tensor_size]). If I want to train the batched models with a different learning rate per model, is there a way to achieve that? Thanks a lot! Great work!
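For context, per-layer learning rates via param groups look roughly like the sketch below (the model and values are illustrative assumptions). Once combine_state_for_ensemble has stacked the parameters, each entry of `params` is a single [num_models, ...] tensor, which is why splitting the models into separate param groups is no longer straightforward:

```python
import torch
from torch import nn

# Illustrative two-layer model: per-layer learning rates are easy to set
# by putting each layer's parameters into its own param group.
model = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 1))
optimiser = torch.optim.SGD(model[0].parameters(), lr=1e-3)
optimiser.add_param_group({"params": model[2].parameters(), "lr": 1e-4})

# After combine_state_for_ensemble(models), each element of `params` is one
# stacked [num_models, ...] tensor covering all models at once, so the models
# cannot be separated into different param groups in the same way.
```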

kxhit avatar May 09 '22 14:05 kxhit

Hi! I have a quick question. Can functorch train a dynamic ensemble in a batched way? For example, I have several models with the same architecture but different params and training data. I wish to train them with vmap, but at each step only n of the models are involved, and n changes from step to step. Looking forward to your guidance. Thanks!

kxhit avatar Jun 15 '22 10:06 kxhit

Dear @kxhit, I am trying to do something similar to what you described. Would you mind giving me a hint about how you do the backward call? In my case, the network parameters do not receive any gradients after calling backward().

bkoyuncu avatar Dec 01 '22 15:12 bkoyuncu

@bkoyuncu Hi, here is what I did.

  1. Initialise the batch of models, where `models` is a list of models with exactly the same structure: `fmodel, params, buffers = combine_state_for_ensemble(models)`, then `[p.requires_grad_() for p in params]` and `optimiser.add_param_group({"params": params})`.
  2. Get the batch of predictions, where `batch_input` stacks the per-model inputs along the first dim: `batch_pred = vmap(fmodel)(params, buffers, batch_input)`.
  3. Backprop: `batch_loss = loss(batch_pred, batch_gt)`, then `batch_loss.backward()`, `optimiser.step()`, `optimiser.zero_grad(set_to_none=True)` (a consolidated sketch of steps 1-3 is given below).
  4. Update the original models' network params (see the linked issue).
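For reference, here is a consolidated, runnable sketch of steps 1-3 (the TinyMLP, input shapes, Adam optimiser, and MSE loss are illustrative assumptions; the optimiser is built directly from the batched params rather than via add_param_group):

```python
import torch
from torch import nn
from functorch import combine_state_for_ensemble, vmap

# Illustrative tiny model; the real models just need identical structure.
class TinyMLP(nn.Module):
    def __init__(self, d_in=8, d_hidden=32, d_out=1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_in, d_hidden), nn.ReLU(), nn.Linear(d_hidden, d_out)
        )

    def forward(self, x):
        return self.net(x)

num_models, batch_size, d_in = 4, 16, 8
models = [TinyMLP() for _ in range(num_models)]

# 1. Combine per-model states into stacked (batched) parameters and buffers.
fmodel, params, buffers = combine_state_for_ensemble(models)
[p.requires_grad_() for p in params]
optimiser = torch.optim.Adam(params, lr=1e-3)

# 2. One vmapped forward pass over all models; inputs are stacked per model.
batch_input = torch.randn(num_models, batch_size, d_in)
batch_gt = torch.randn(num_models, batch_size, 1)
batch_pred = vmap(fmodel)(params, buffers, batch_input)

# 3. A single scalar loss over all models, then the usual backward/step.
batch_loss = nn.functional.mse_loss(batch_pred, batch_gt)
batch_loss.backward()
optimiser.step()
optimiser.zero_grad(set_to_none=True)
```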

Hope it helps!

kxhit avatar Dec 05 '22 22:12 kxhit

A link to my project vMAP, where I used the functorch vmap function for vectorised training of multiple tiny MLPs (NeRF). Thanks a lot to @zou3519 and the amazing developer team of functorch!

kxhit avatar Mar 26 '23 12:03 kxhit

Thanks for sharing @kxhit and awesome to hear that you've found functorch to be helpful!

zou3519 avatar Mar 27 '23 14:03 zou3519