Train ensemble models with vmap
Thanks a lot for developing functorch!
I checked the tutorial and tried vmap to ensemble multiple models. I'm curious: can I use vmap to do multi-model training? For example, can I take batch_pred from vmap and directly call loss.backward() and optimiser.step() to train the ensemble model? Is there anything I should take care of? Thanks a lot!
I'm curious: can I use vmap to do multi-model training?
Yes!
For example, can I take batch_pred from vmap and directly call loss.backward() and optimiser.step() to train the ensemble model?
In this case, you're not actually training N models in parallel; you're effectively creating one massive "ensemble" model that you train once (presumably you sum the outputs before calling backward). I believe you get gradients identical to training each model separately, but there could be ramifications for the optimizer (since you're only using one of them). For example, you'd probably be forced to use the same LR schedule with our optimizers today.
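To make that concrete, here is a minimal sketch of the "one big ensemble" setup described above. The model list, tensor shapes, and MSE loss are illustrative assumptions, not from the thread; the functorch calls follow the ensembling tutorial. `models` is assumed to be a list of identically structured nn.Modules, and `x`, `y` are stacked per-model inputs/targets of shape `[num_models, batch, ...]`.

```python
import torch
from functorch import combine_state_for_ensemble, vmap

# Stack the per-model parameters/buffers into tensors with a leading model
# dimension, plus a stateless "functional" version of the architecture.
fmodel, params, buffers = combine_state_for_ensemble(models)
[p.requires_grad_() for p in params]

# One optimizer over the stacked parameters: every model necessarily shares
# the same learning rate and schedule.
opt = torch.optim.Adam(params, lr=1e-3)

preds = vmap(fmodel)(params, buffers, x)              # [num_models, batch, out]
per_model_loss = ((preds - y) ** 2).mean(dim=(1, 2))  # one MSE per model
loss = per_model_loss.sum()  # summing keeps each model's gradients independent
loss.backward()
opt.step()
opt.zero_grad(set_to_none=True)
```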
Thanks a lot for the quick reply! Yeah, we can do batch training with vmap. I found we can easily set different learning rates for different layers' parameters by using optim.add_param_group(), but it seems hard to give each model its own learning rate: the parameters of the batched models are grouped into one batched tensor, B_model x [tensor_size] -> [B_model, tensor_size]. I'd like to train the batched model with a different learning rate for each model. Is there a way to achieve this? Thanks a lot! Great work!
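There isn't a per-model learning rate in the stock optimizers for stacked parameters, so as far as I know you would need a workaround. One possibility is a manual SGD-style update that broadcasts a vector of learning rates over the leading model dimension and is applied after `backward()` in place of `optimiser.step()`. This is only a sketch under the assumption that `params` are the stacked tensors from `combine_state_for_ensemble` (and it only mimics plain SGD; momentum or Adam-style optimizers would need extra bookkeeping):

```python
import torch

# Hypothetical per-model learning rates, one entry per model in the stack.
lrs = torch.tensor([1e-3, 3e-4, 1e-4])

with torch.no_grad():
    for p in params:  # each p has shape [num_models, ...]
        if p.grad is None:
            continue
        # Reshape lrs to [num_models, 1, 1, ...] so it broadcasts over the
        # remaining parameter dimensions.
        view = (-1,) + (1,) * (p.dim() - 1)
        p -= lrs.view(view) * p.grad
        p.grad = None
```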
Hi! I have a quick question. Is functorch able to train dynamic ensemble models in a batched way? For example, I have several models with the same architecture but different params and training data. I wish to train them with vmap, but each time only n models are involved, and n can change from one step to the next. Looking forward to your guidance. Thanks!
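I'm not sure what the recommended pattern is, but one thing that might work (purely a sketch, untested, assuming `fmodel`, `params`, `buffers` come from `combine_state_for_ensemble` and `batch_input`/`batch_gt` hold one batch per active model) is to index the stacked parameters with the ids of the currently active models before vmapping. The indexing is differentiable, so gradients still flow back to the full stacked tensors, and inactive models simply receive zero gradient:

```python
import torch
from functorch import vmap

active = torch.tensor([0, 2, 5])  # hypothetical ids of the n models to train this step

active_params = tuple(p[active] for p in params)    # differentiable gather along the model dim
active_buffers = tuple(b[active] for b in buffers)

batch_pred = vmap(fmodel)(active_params, active_buffers, batch_input)
loss = ((batch_pred - batch_gt) ** 2).mean()
loss.backward()  # fills .grad on the full stacked params; inactive rows stay zero
```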
Dear @kxhit, I am trying to do something similar to what you mentioned. Would you mind giving me a hint about how you are doing the backward call? In my case, the network parameters do not receive any gradients after calling backward().
@bkoyuncu Hi, here is what I did.
- Init a batch of models: models is a list of models with exactly the same structure.
from functorch import combine_state_for_ensemble, vmap
fmodel, params, buffers = combine_state_for_ensemble(models)
[p.requires_grad_() for p in params]  # mark the stacked params as trainable
optimiser.add_param_group({"params": params})  # optimiser is a torch.optim optimizer created earlier
- Get a batch of predictions: batch_input is a stack of the per-model batch inputs along the first dim.
batch_pred = vmap(fmodel)(params, buffers, batch_input)
- Backprop:
batch_loss = loss(batch_pred, batch_gt)
batch_loss.backward()
optimiser.step()
optimiser.zero_grad(set_to_none=True)
- Update the original models' network params (see the linked Issue). A rough sketch of one possible way to copy the stacked params back is shown below.
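For this last step, here is a rough sketch of one possible way to copy the trained stacked parameters back into the original model objects, assuming the stacked `params`/`buffers` follow the same ordering as `model.parameters()` / `model.buffers()`:

```python
import torch

# Copy each model's slice of the stacked params/buffers back into the
# original nn.Modules so they can be used or saved individually.
with torch.no_grad():
    for i, model in enumerate(models):
        for p_model, p_stacked in zip(model.parameters(), params):
            p_model.copy_(p_stacked[i])
        for b_model, b_stacked in zip(model.buffers(), buffers):
            b_model.copy_(b_stacked[i])
```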
Hope it helps!
A link to my project vMAP, where I used the functorch vmap function for vectorised training of multiple tiny MLPs (NeRF). Thanks a lot to @zou3519 and the amazing developer team of functorch!
Thanks for sharing @kxhit and awesome to hear that you've found functorch to be helpful!