pytorch_ema
How to use it in pytorch-lightning?
Hi, I would like to know how to use pytorch-ema for training in pytorch-lightning.
Hi @Devoe-97,
I think you'd want to implement a Lightning Callback that runs the appropriate ema.update(), ema.store(); ema.copy_to(), and ema.restore() operations at the right points in the training loop. I'm not super familiar with Lightning, so I don't know which callback methods those would be. You might be able to take some inspiration from their StochasticWeightAveraging callback, which is similar in spirit to the EMA implemented in this package. (That callback may also work for you, depending on what you are trying to achieve.)
If you have any luck with this, please share your work if you are willing!
Thanks!
Hi, could you provide a demo in DDP mode? I ran into an OOM error while broadcasting in the validation step.
I second the original question. I think having this functionality working within PyTorch Lightning could greatly increase the visibility of this awesome repository!
A demo would be great!