pytorch-lightning
🚀 Add FLOPs count to model summary
🚀 Feature
Add FLOPs count in model summary.
Motivation
Improvements in model development are increasingly evaluated using FLOPs counts (e.g., Training Compute-Optimal Large Language Models). However, there is no standardized way to compute a FLOPs count, though many libraries exist (e.g., 1, 2). It would be great to add this functionality to the model summary in PyTorch Lightning.
cc @borda @kaushikb11 @awaelchli @rohitgr7 @akihironitta
Prior issue: https://github.com/PyTorchLightning/pytorch-lightning/issues/3337
It could be implemented by adapting https://dev-discuss.pytorch.org/t/the-ideal-pytorch-flop-counter-with-torch-dispatch/505 but PyTorch itself does not provide a solution upstream: https://github.com/pytorch/pytorch/issues/5013
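For reference, the approach described in that dev-discuss post can be sketched with a minimal TorchDispatchMode that intercepts every ATen op and accumulates FLOPs for the ops it recognizes. This is an illustrative sketch only (it handles just mm and addmm; a real counter would cover convolutions, attention, etc.):

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class SimpleFlopCounter(TorchDispatchMode):
    """Minimal __torch_dispatch__-based FLOP counter (sketch, not exhaustive)."""

    def __init__(self):
        self.flops = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        out = func(*args, **kwargs)
        name = func.overloadpacket.__name__  # e.g. "mm", "addmm"
        if name == "mm":
            m, k = args[0].shape
            _, n = args[1].shape
            self.flops += 2 * m * k * n  # one multiply + one add per MAC
        elif name == "addmm":
            # addmm(bias, mat1, mat2): FLOPs dominated by mat1 @ mat2
            m, k = args[1].shape
            _, n = args[2].shape
            self.flops += 2 * m * k * n
        return out

with SimpleFlopCounter() as counter:
    # nn.Linear forward dispatches to aten::addmm on 2D input
    torch.nn.Linear(32, 2)(torch.randn(4, 32))

print(counter.flops)  # 2 * 4 * 32 * 2 = 512
```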
An implementation of the ideal FLOP counter is here:
https://github.com/pytorch-labs/torcheval/blob/main/torcheval/tools/module_summary.py
https://github.com/pytorch-labs/torcheval/blob/main/torcheval/tools/flops.py
Thanks for the link @ananthsub. For anybody reading, this is how you would use it:
import torch
from torcheval.tools.module_summary import get_module_summary
from pytorch_lightning.demos.boring_classes import BoringModel
model = BoringModel()
summary = get_module_summary(model, torch.randn(2, 32))
print(summary.flops_forward, summary.flops_backward)
Still, we'll want to add FLOPs support to our ModelSummary class (torcheval's ModuleSummary looks quite similar :eyes:), so leaving this issue open.
A FLOP counter was added to PyTorch: https://github.com/pytorch/pytorch/pull/95751
#18848 added this small utility (to be released with 2.2):
import torch
from lightning.fabric.utilities import measure_flops

with torch.device("meta"):
    model = MyModel()
    x = torch.randn(2, 32)

model_fwd = lambda: model(x)
fwd_flops = measure_flops(model, model_fwd)

model_loss = lambda y: y.sum()
fwd_and_bwd_flops = measure_flops(model, model_fwd, model_loss)
How do you go about counting FLOPs for billion-parameter models that OOM when run on the meta device?