🚀 Add FLOPs count to model summary

Open · pietrolesci opened this issue 2 years ago · 8 comments

🚀 Feature

Add FLOPs count in model summary.

Motivation

Improvements in model development are increasingly evaluated by FLOPs count (e.g., Training Compute-Optimal Large Language Models). However, there is no standardized way to compute FLOPs, although many libraries exist (e.g., 1, 2). It would be great to add this functionality to the model summary in PyTorch Lightning.

cc @borda @kaushikb11 @awaelchli @rohitgr7 @akihironitta

pietrolesci · Apr 01 '22 13:04

Prior issue: https://github.com/PyTorchLightning/pytorch-lightning/issues/3337

ananthsub · Apr 02 '22 05:04

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

stale[bot] · May 02 '22 10:05

This could be implemented by adapting https://dev-discuss.pytorch.org/t/the-ideal-pytorch-flop-counter-with-torch-dispatch/505, but PyTorch itself does not provide an upstream solution: https://github.com/pytorch/pytorch/issues/5013
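A minimal sketch of the `__torch_dispatch__` approach from that post, assuming PyTorch 2.x. It subclasses `TorchDispatchMode` (from the private module `torch.utils._python_dispatch`, which may change between releases) to observe every ATen op, and counts only `aten.mm` and `aten.addmm` at 2·m·k·n FLOPs per (m, k) × (k, n) product. `MatmulFlopCounter` and `FLOP_RULES` are hypothetical names for illustration; a usable counter would also need rules for `bmm`, convolutions, attention, and so on.

```python
import torch
from torch import nn
from torch.utils._python_dispatch import TorchDispatchMode  # private API, may change

aten = torch.ops.aten


def _mm_flops(a: torch.Tensor, b: torch.Tensor) -> int:
    # (m, k) @ (k, n): m * n dot products of length k, i.e. k multiplies + k adds each
    m, k = a.shape
    n = b.shape[1]
    return 2 * m * k * n


# Hypothetical rule table keyed by ATen overload; the bias add in addmm is ignored
FLOP_RULES = {
    aten.mm.default: lambda args: _mm_flops(args[0], args[1]),
    aten.addmm.default: lambda args: _mm_flops(args[1], args[2]),  # addmm(bias, mat1, mat2)
}


class MatmulFlopCounter(TorchDispatchMode):
    """Accumulates FLOPs of the matmul ops executed while the mode is active."""

    def __init__(self):
        super().__init__()
        self.flops = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        rule = FLOP_RULES.get(func)
        if rule is not None:
            self.flops += rule(args)
        # Run the real op; the mode is not re-entered inside this call
        return func(*args, **kwargs)


layer = nn.Linear(32, 2)  # created outside the mode so init ops are not traced
counter = MatmulFlopCounter()
with counter:
    layer(torch.randn(2, 32))  # nn.Linear with bias decomposes to aten.addmm
print(counter.flops)
```

Because the mode sits below autograd, ops such as `nn.Linear` arrive already decomposed into primitives like `addmm`, which is what makes counting at this level practical.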

carmocca · Aug 31 '22 17:08

An implementation of the ideal FLOP counter is here:

https://github.com/pytorch-labs/torcheval/blob/main/torcheval/tools/module_summary.py

https://github.com/pytorch-labs/torcheval/blob/main/torcheval/tools/flops.py

ananthsub · Aug 31 '22 22:08

Thanks for the link @ananthsub. For anybody reading, this is how you would use it:

import torch
from torcheval.tools.module_summary import get_module_summary
from pytorch_lightning.demos.boring_classes import BoringModel

model = BoringModel()
# module_args is a tuple of positional arguments passed to model.forward()
summary = get_module_summary(model, module_args=(torch.randn(2, 32),))
print(summary.flops_forward, summary.flops_backward)

Still, we'll want to add FLOPs support to our ModelSummary class (torcheval's ModuleSummary looks quite similar 👀), so I'm leaving this issue open.

carmocca · Aug 31 '22 23:08

A FLOP counter was added to PyTorch: https://github.com/pytorch/pytorch/pull/95751

carmocca · Mar 03 '23 01:03

#18848 added this small utility (to be released with 2.2):

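import torch
from lightning.fabric.utilities import measure_flops

# Instantiate on the meta device so no real memory is allocated
# (MyModel stands for your own nn.Module)
with torch.device("meta"):
    model = MyModel()
    x = torch.randn(2, 32)

# FLOPs of a forward pass only
model_fwd = lambda: model(x)
fwd_flops = measure_flops(model, model_fwd)

# FLOPs of forward + backward, using a scalar loss to backpropagate
model_loss = lambda y: y.sum()
fwd_and_bwd_flops = measure_flops(model, model_fwd, model_loss)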

carmocca · Oct 30 '23 16:10

How do you go about counting FLOPs for billion-parameter models that hit OOM when run on the meta device?

championsnet · Apr 18 '24 16:04