DeepSpeed
Flops profiler does not account for overloaded operators
I was getting some odd profiling results, so I dug a bit deeper and found that overloaded Python operators (e.g. `*`, `+`) are not counted by the `get_model_profile` method. The following simple script reproduces the problem:
```python
import torch
from deepspeed.profiling.flops_profiler.profiler import get_model_profile
from torch import nn


class SimpleAdder(nn.Module):
    def forward(self, x):
        # Addition through the overloaded `+` operator.
        return x + x


class SimpleAdder2(nn.Module):
    def forward(self, x):
        # The same addition through an explicit torch.add call.
        return torch.add(x, x)


flops, macs, params = get_model_profile(SimpleAdder(), input_shape=(1, 2, 10), print_profile=False)
print(flops)
flops, macs, params = get_model_profile(SimpleAdder2(), input_shape=(1, 2, 10), print_profile=False)
print(flops)
```
Output (20 is what I would expect for both models):
```
0
20
```
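A plausible explanation for the 0 vs. 20 difference follows. It assumes the profiler hooks operations by swapping module-level functions such as `torch.add` for counting wrappers. Under that assumption, `x + x` is never seen, because the `+` operator dispatches to `torch.Tensor.__add__` and never looks up `torch.add`. The standalone sketch below uses plain PyTorch without DeepSpeed (`counting_add` is a hypothetical name) and demonstrates the effect:

```python
import torch

calls = []
original_add = torch.add


def counting_add(*args, **kwargs):
    # Hypothetical wrapper: record the call, then defer to the real torch.add.
    calls.append("torch.add")
    return original_add(*args, **kwargs)


torch.add = counting_add  # patch the module-level function only
try:
    x = torch.ones(1, 2, 10)
    torch.add(x, x)  # looked up via the torch module, so the wrapper runs
    x + x  # dispatches to torch.Tensor.__add__, so the wrapper is bypassed
finally:
    torch.add = original_add  # always restore the original function

print(calls)  # ['torch.add']
```

Only the explicit `torch.add` call is recorded, which matches the profiler reporting 0 FLOPs for `SimpleAdder`.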
Because overloaded operators read better, they are used widely, so many models are likely undercounted in the same way. This applies especially to models with residual connections written as `out = x + block(x)`.
Hi @djwessel, the FLOPs profiler currently captures only `torch.nn.functional` calls, so overloaded Python operators are not included in the count.
Are there any plans to support the overloaded operators, perhaps by patching them the same way the functionals are patched in the profiler?
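One possible shape for that suggestion is a minimal sketch that does not use DeepSpeed. It temporarily replaces overloaded operator methods on `torch.Tensor` with wrappers that count FLOPs, then restores the originals. `count_operator_flops` is a hypothetical helper. Counting one FLOP per output element is an assumption, chosen because it matches the 20 reported for `torch.add` above.

```python
import contextlib

import torch
from torch import nn

# Hypothetical list of overloaded operators to intercept (not exhaustive).
_PATCHED_OPS = ("__add__", "__mul__")


@contextlib.contextmanager
def count_operator_flops():
    """Hypothetical helper: count one FLOP per output element of `+` and `*`."""
    counter = {"flops": 0}
    originals = {name: getattr(torch.Tensor, name) for name in _PATCHED_OPS}

    def make_wrapper(orig):
        def wrapper(self, other):
            out = orig(self, other)
            if isinstance(out, torch.Tensor):
                # Elementwise op: one FLOP per element of the (broadcast) result.
                counter["flops"] += out.numel()
            return out

        return wrapper

    for name, orig in originals.items():
        setattr(torch.Tensor, name, make_wrapper(orig))
    try:
        yield counter
    finally:
        # Restore the original methods so code outside the context is unaffected.
        for name, orig in originals.items():
            setattr(torch.Tensor, name, orig)


class SimpleAdder(nn.Module):
    def forward(self, x):
        return x + x


with count_operator_flops() as counter:
    SimpleAdder()(torch.ones(1, 2, 10))
print(counter["flops"])  # 20
```

Patching `torch.Tensor` affects every tensor in the process while the context is active, so the methods must be restored afterwards, as the `finally` block does. A fuller version would also need wrappers for the in-place (`__iadd__`) and reflected (`__radd__`) variants and for other operators such as `-` and `/`.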