[FR] Support Automatic Mixed Precision training
Issue Description
Better support for mixed precision training would be extremely helpful, at least for SVI. I can manually cast data to float16 or bfloat16, but I am unable to leverage PyTorch's automatic mixed precision (AMP) training, because AMP requires using the GradScaler class in the optimization loop to scale gradients in a mixed-precision-aware manner. See the documentation for more info: https://pytorch.org/docs/stable/amp.html
It would be nice to support using this class within Pyro's optimizers to enable AMP.
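For readers unfamiliar with AMP, here is a minimal sketch of a plain PyTorch training loop (no Pyro) showing where `torch.autocast` and `GradScaler` sit: the scaler wraps both the backward pass and the optimizer's `step()`, which is why it needs access to the optimization loop. The toy linear-regression data, learning rate, and step count are arbitrary assumptions. On CPU the sketch falls back to bfloat16 with the scaler disabled, since loss scaling is mainly needed for float16.

```python
import torch

# Assumption: float16 + loss scaling on CUDA, bfloat16 without scaling on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

torch.manual_seed(0)
x = torch.randn(256, 8, device=device)
true_w = torch.randn(8, 1, device=device)
y = x @ true_w  # toy regression targets

net = torch.nn.Linear(8, 1).to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=0.05)
# With enabled=False the scaler passes the loss and step() through unchanged.
# (Recent PyTorch releases may emit a deprecation warning for torch.cuda.amp.)
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

losses = []
for step in range(200):
    optimizer.zero_grad()
    with torch.autocast(device_type=device, dtype=amp_dtype):
        pred = net(x)  # the linear layer runs in reduced precision
    loss = torch.nn.functional.mse_loss(pred.float(), y)  # loss in float32
    scaler.scale(loss).backward()  # multiply loss by the scale factor (when enabled)
    scaler.step(optimizer)  # unscale grads; skip the step on inf/nan (when enabled)
    scaler.update()  # adjust the scale factor for the next iteration
    losses.append(loss.item())
```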
@fritzo I might be willing to try to tackle this. Do you have any opinions on how to expose the functionality to the end user?
Hi @austinv11, thanks for offering. I'd guess there are a few ways we could support AMP in Pyro:
1. Use Pyro's `ELBOModule` to construct a differentiable loss function as in the lightning tutorial, then do standard PyTorch training with AMP. I think Pyro's code already supports this; we'd just need improved documentation and maybe an example:
   - Add a docstring to `ELBOModule` explaining how it is created and why it is useful.
   - Add the `ELBO.__call__` method to Sphinx's `:special-members:` list.
   - Add an `examples/svi_amp.py` similar to `examples/svi_lightning.py`.
2. Do something similar, but with the `Trace_ELBO.differentiable_loss()` method.
3. Add more native AMP support to `pyro.optim`'s wrapper class. This seems intricate and more difficult to maintain, though.
Would you be interested in getting (1) or (2) working for yourself and then contributing docs to show how you did it? We're happy to answer any questions about Pyro, but I think you know more about AMP than we do 🙂
It looks like I might need to try option (3), since AMP-aware gradient scaling requires access to the optimizer's step() method.
I could try adding a boolean flag to PyroOptim that enables AMP. Even with that flag enabled, however, the user would still need to manually use PyTorch's autocast context manager within their models.
But I suspect most users would want to activate AMP for their entire model rather than for specific portions of code. Do you think it might be worth adding a new ELBO that autocasts the entire model for the user?
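One lightweight shape the "autocast the entire model" idea could take, without a new ELBO class, is a wrapper that runs a model or guide callable inside `torch.autocast`. `autocast_model` below is a hypothetical helper, not an existing Pyro API, and the toy function only demonstrates that ops inside the wrapper run in reduced precision.

```python
import functools
import torch

def autocast_model(fn, device_type="cpu", dtype=torch.bfloat16):
    """Hypothetical helper: run an entire model or guide under torch.autocast."""
    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        with torch.autocast(device_type=device_type, dtype=dtype):
            return fn(*args, **kwargs)
    return wrapped

def toy(x, w):
    return x @ w  # stands in for the body of a model

amp_toy = autocast_model(toy)
out_fp32 = toy(torch.randn(4, 3), torch.randn(3, 2))     # ordinary float32 matmul
out_amp = amp_toy(torch.randn(4, 3), torch.randn(3, 2))  # autocast to bfloat16 on CPU
```

With bfloat16, a wrapped model and guide could in principle be passed to `SVI` unchanged; with float16 the missing `GradScaler` in `pyro.optim` would remain the sticking point.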
Let me try again to persuade you towards options (1) or (2) 😄, admitting I don't know your details or how AMP works.
Back in the early days of Pyro, we decided to wrap PyTorch's optimizer classes so we could have more control over dynamically created parameters. In practice, this made Pyro's optimization idioms incompatible with other frameworks built on top of PyTorch, e.g. lightning, horovod, AMP, and new higher-order optimizers. To work around this incompatibility, we've since added ways to compute differentiable losses in Pyro so that optimization can be done entirely with torch idioms, without ever using pyro.optim.
For example, instead of the original Pyro-idiomatic optimization:
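```python
import pyro.optim
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal

def model(args):
    ...

guide = AutoNormal(model)
elbo = Trace_ELBO()
optim = pyro.optim.Adam(...)  # <---- pyro idioms
svi = SVI(model, guide, optim, elbo)
for step in range(...):
    svi.step(args)
```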
you can use torch-idiomatic optimizers:
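```python
import torch
from pyro.infer import Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.nn import PyroModule

class Model(PyroModule):
    def forward(self, args):
        ...

model = Model()
guide = AutoNormal(model)
elbo = Trace_ELBO()
loss_fn = elbo(model, guide)  # an ELBOModule, i.e. a torch.nn.Module
loss_fn(args)  # run once so the guide lazily creates its parameters
optimizer = torch.optim.Adam(loss_fn.parameters(), ...)  # <---- torch idioms
for step in range(...):
    optimizer.zero_grad()
    loss = loss_fn(args)
    loss.backward()
    optimizer.step()  # <---- Can we use AMP here?
```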
My hope is that switching to torch-native optimizers will make supporting AMP easy, or even trivial.
That said, we'd still be open to adding AMP support to pyro.optim if you can find a simple maintainable way to do so 🙂.
Ah, I see what you mean. Am I correct in understanding that this wouldn't be compatible with the SVI trainer and would require using PyroModules then?
That is also incompatible with models/guides that dynamically create parameters during training, if I understand correctly.
@austinv11 @ilia-kats correct.