sparseml
FLOPs (floating-point operations) module for ConvNets
This code updates an existing but unused model analyzer (AnalyzerModule) that computes forward FLOPs, total parameters, prunable parameters, and zeroed parameters model-wide. Note that this is somewhat redundant with the ModelPruningAnalyzer, which is invoked by pruning modifiers. However, there are some key differences:
- The AnalyzerModule counts FLOPs as well as sparsities.
- The AnalyzerModule, if enabled, also counts cumulative FLOPs, and works at the granularity of the model and its submodules. This means it must be enabled from the very first training step until the last. ModelPruningAnalyzer, by contrast, works at the granularity of a specific pruner: it is active only while that pruner is active, and only on the parameters that pruner operates on.
- For the moment, the AnalyzerModule only supports module types that appear in ConvNets and should not be used with transformers.
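The bookkeeping described above can be sketched as a minimal tracker. This is an illustrative sketch only, not sparseml's actual API; the class and method names (`FlopTracker`, `record_forward`) are hypothetical:

```python
class FlopTracker:
    """Hypothetical sketch of per-step and cumulative FLOP bookkeeping.

    Mirrors the idea that forward FLOPs are recorded every step, so the
    cumulative total is only correct if tracking is enabled from the very
    first step until the last.
    """

    def __init__(self):
        self.forward_flops = 0  # FLOPs of the most recent forward pass
        self.total_flops = 0    # cumulative FLOPs across all steps

    def record_forward(self, flops):
        # Called once per training step with that step's forward FLOPs.
        self.forward_flops = flops
        self.total_flops += flops


tracker = FlopTracker()
tracker.record_forward(1_000)
tracker.record_forward(1_000)
print(tracker.forward_flops, tracker.total_flops)  # 1000 2000
```

In the real AnalyzerModule this accounting happens per submodule as well as model-wide, so each layer carries its own counters.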
However, as a future TODO, the AnalyzerModule could easily subsume the operations of the ModelPruningAnalyzer. Being training-wide rather than pruner-specific, it is arguably better positioned for this role, so it may be worth removing the latter in a future iteration.
Since the AnalyzerModule was not used before this change, it was updated to better fit the needs of the benchmarking project. It no longer tracks details like parameter shapes, only the metrics that are actually useful, such as forward FLOPs and sparsities. The "total FLOPs" field is now used to track cumulative FLOPs, which were not tracked previously. The analyzer does not track backward FLOPs: the standard practice is to approximate backward FLOPs ≈ 2 × forward FLOPs. This calculation is not part of the code and would have to be performed manually, possibly with adjustments for methods like Top-KAST, where the backward sparsity does not match the forward sparsity. The code is completely agnostic to this sort of nuance, which keeps it simpler.
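The manual backward-FLOP calculation mentioned above might look like the following. This is a hedged sketch, not part of the PR's code; the function name and the `backward_density_ratio` knob (for Top-KAST-style methods whose backward sparsity differs from the forward) are hypothetical:

```python
def approx_backward_flops(forward_flops, backward_density_ratio=1.0):
    """Approximate backward FLOPs from forward FLOPs.

    Uses the standard rule of thumb backward ~= 2 * forward. The
    backward_density_ratio is an illustrative correction factor for
    pruners like Top-KAST, where the backward pass touches a different
    (usually denser) set of weights than the forward pass.
    """
    return 2 * forward_flops * backward_density_ratio


print(approx_backward_flops(1_000_000))       # 2000000
print(approx_backward_flops(1_000_000, 1.5))  # 3000000.0
```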
Note that the definitions of the FLOP calculations have been changed to match those of AC/DC and RigL. In the vast majority of cases this did not affect the results, as can be seen in the tests.
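For a ConvNet layer, the convention used by RigL and AC/DC scales the dense layer FLOPs by the weight density (fraction of nonzero weights). A sketch of that definition for a Conv2d layer, with a hypothetical function name (conventions differ on whether a multiply-add counts as one or two FLOPs; one FLOP per multiply-add is assumed here):

```python
def conv2d_forward_flops(c_in, c_out, k_h, k_w, h_out, w_out, density=1.0):
    """Forward FLOPs of a Conv2d layer under the dense * density convention.

    Each of the c_out * h_out * w_out output elements needs
    k_h * k_w * c_in multiply-adds; sparse FLOPs are the dense count
    scaled by the fraction of nonzero weights.
    """
    dense = k_h * k_w * c_in * c_out * h_out * w_out
    return dense * density


# Example: a 3x3 conv, 3 -> 64 channels, 32x32 output.
print(conv2d_forward_flops(3, 64, 3, 3, 32, 32))               # 1769472
print(conv2d_forward_flops(3, 64, 3, 3, 32, 32, density=0.5))  # 884736.0
```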
Testing: updated/expanded the existing tests, and ran a few epochs to check agreement with the numbers published for AC/DC and RigL.