pytorch-image-models
Use in-place operations for EMA
Hi,
I noticed that the EMA used here is pretty slow, since it does not use in-place operations. Using in-place ops gives a ~50% faster EMA; however, it does not work with type promotion.
One workaround is to check whether a tensor from the model is floating point or not, with Tensor.is_floating_point().
I don't think there is any point in doing EMA on int tensors such as num_batches_tracked in BN.
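The workaround above can be sketched roughly as follows (a hypothetical illustration, not timm's actual `ModelEmaV2` code — the function name and signature are made up for this example):

```python
import torch

@torch.no_grad()
def ema_update(ema_params, model_params, decay=0.999):
    """In-place EMA that skips non-floating-point tensors.

    Int buffers such as BN's num_batches_tracked would raise a
    type-promotion error under in-place arithmetic, so they are
    simply copied instead of averaged.
    """
    for ema_p, model_p in zip(ema_params, model_params):
        if ema_p.is_floating_point():
            # in-place: ema = decay * ema + (1 - decay) * model
            ema_p.mul_(decay).add_(model_p, alpha=1.0 - decay)
        else:
            # int tensors (e.g. counters) are copied, not EMAed
            ema_p.copy_(model_p)
```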
Cheers!
@jeromerony thanks for the PR, noticed you commenting on the twitter thread re the foreach, any desire to throw that in as an option too? probably should be a bool flag, not sure for_each will work on PyTorch XLA (for TPU)
Sure! My initial PR tried to keep the same logic and code structure. However, adding the foreach functions requires some changes:
- How to handle buffers, e.g. BN `running_mean` and `running_var`? Should they be EMAed or simply copied? If EMAed, how should the int buffers (e.g. BN `num_batches_tracked`) be handled?
- How important is the `update`, `set` and `_update` structure? The `set` method does not seem to be used anywhere in this repo (even though it might be somewhere else). The `_update` method would become obsolete with the use of foreach.
- I am unsure about the limitations of XLA and do not have access to that hardware to test.
Quick testing on a first solution gives me a 2 to 3x improvement without foreach, and ~5x improvement with foreach.
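A foreach-based step might look like the sketch below (my own illustration, not the PR's code; note that `torch._foreach_mul_` / `torch._foreach_add_` are private APIs and may not be supported on all backends, e.g. XLA/TPU, hence the proposed bool flag):

```python
import torch

@torch.no_grad()
def ema_update_foreach(ema_params, model_params, decay=0.999):
    """EMA step using fused foreach ops over the whole tensor list.

    Floating-point tensors are EMAed with two fused kernel launches;
    int buffers (e.g. BN num_batches_tracked) are copied instead.
    """
    pairs = list(zip(ema_params, model_params))
    ema_f = [e for e, m in pairs if e.is_floating_point()]
    model_f = [m for e, m in pairs if e.is_floating_point()]
    # one fused op per call instead of one kernel launch per tensor
    torch._foreach_mul_(ema_f, decay)
    torch._foreach_add_(ema_f, model_f, alpha=1.0 - decay)
    for e, m in pairs:
        if not e.is_floating_point():
            e.copy_(m)
```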
Thanks for the additions, I've been trying to push a number of things over the finish line recently so haven't had a chance to dig into this, but will try to take a closer look, test, etc soon...