Flops profiler does not account for overloaded operators

Open djwessel opened this issue 2 years ago • 2 comments

I was getting some weird profiling results and dug a bit deeper, only to find that overloaded Python operators (e.g. `*`, `+`) were not being counted by `get_model_profile`. The following simple script reproduces my results.

```python
import torch
from deepspeed.profiling.flops_profiler.profiler import get_model_profile
from torch import nn


class SimpleAdder(nn.Module):
    def forward(self, x):
        # Elementwise add via the overloaded + operator
        return x + x


class SimpleAdder2(nn.Module):
    def forward(self, x):
        # The same elementwise add, called explicitly through torch.add
        return torch.add(x, x)


flops, macs, params = get_model_profile(SimpleAdder(), input_shape=(1, 2, 10), print_profile=False)
print(flops)

flops, macs, params = get_model_profile(SimpleAdder2(), input_shape=(1, 2, 10), print_profile=False)
print(flops)
```

Output:

```
0
20 # this is what I would expect
```
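The expected value of 20 follows from the input shape: a tensor of shape (1, 2, 10) has 1 × 2 × 10 = 20 elements, and an elementwise addition performs one addition per element. Assuming that one-FLOP-per-element convention (it matches the count reported for `torch.add` above), the expected count can be computed as below; `elementwise_binary_flops` is a hypothetical helper, not part of DeepSpeed.

```python
import math


def elementwise_binary_flops(shape):
    """Hypothetical helper: one FLOP per output element, e.g. for x + x."""
    return math.prod(shape)


print(elementwise_binary_flops((1, 2, 10)))  # 20
```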

Because overloaded operators make code more readable, many models probably use them and suffer from the same undercount, especially those with residual blocks (e.g. `x + self.block(x)`).

djwessel · Mar 23 '23

Hi @djwessel, the flops profiler currently only captures torch.nn.functionals, thus the overloaded python operators are not included in the count.

cli99 · Jun 13 '23
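The effect described in this reply can be illustrated with a toy model. A profiler that works by wrapping named functions only sees calls that go through the wrapped attribute. In PyTorch, `x + x` is resolved through `Tensor.__add__`, which is a separate attribute from `torch.add`, so a wrapper installed on the named function does not intercept the operator. The sketch below assumes that wrapping approach and uses a hypothetical `ToyTensor` class and a hypothetical `count_flops` hook (standard library only, not PyTorch or DeepSpeed code) to show the same gap.

```python
import functools

counter = {"flops": 0}  # shared FLOP tally updated by the wrapper


def _add_impl(self, other):
    """Elementwise addition shared by the named method and the operator."""
    return ToyTensor(a + b for a, b in zip(self.data, other.data))


class ToyTensor:
    """Hypothetical stand-in for torch.Tensor (not a real API)."""

    def __init__(self, data):
        self.data = list(data)

    # Two separate class attributes bound to the same implementation.
    add = _add_impl
    __add__ = _add_impl


def count_flops(fn):
    """Hypothetical profiler hook: count one FLOP per output element."""
    @functools.wraps(fn)
    def wrapper(self, other):
        out = fn(self, other)
        counter["flops"] += len(out.data)
        return out
    return wrapper


# Patch only the named method, as a function-wrapping profiler would.
ToyTensor.add = count_flops(ToyTensor.add)

x = ToyTensor(range(20))
x.add(x)                 # goes through the patched attribute
print(counter["flops"])  # 20
x + x                    # Python looks up type(x).__add__, which is unpatched
print(counter["flops"])  # still 20
```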

Are there any plans to support the overloaded operators? Maybe with an approach similar to how the functionals are patched in the profiler?

djwessel · Jun 14 '23
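One way the suggestion in the last comment could look is sketched below on a hypothetical `ToyTensor` class rather than `torch.Tensor`: temporarily replace the operator dunder methods with counting wrappers, then restore the originals on exit. `count_operator_flops` is a hypothetical name, and this is a sketch of a possible design under the one-FLOP-per-element assumption, not DeepSpeed's implementation. Applying the idea to PyTorch would presumably mean wrapping attributes such as `torch.Tensor.__add__` in the same way.

```python
import contextlib
import functools


class ToyTensor:
    """Hypothetical stand-in for torch.Tensor (not a real API)."""

    def __init__(self, data):
        self.data = list(data)

    def __add__(self, other):
        return ToyTensor(a + b for a, b in zip(self.data, other.data))

    def __mul__(self, other):
        return ToyTensor(a * b for a, b in zip(self.data, other.data))


@contextlib.contextmanager
def count_operator_flops(cls, names, counter):
    """Temporarily wrap the named operator methods of cls so that each call
    adds one FLOP per output element to counter["flops"]."""
    originals = {name: getattr(cls, name) for name in names}

    def make_wrapper(fn):
        @functools.wraps(fn)
        def wrapper(self, other):
            out = fn(self, other)
            counter["flops"] += len(out.data)
            return out
        return wrapper

    for name, fn in originals.items():
        setattr(cls, name, make_wrapper(fn))
    try:
        yield counter
    finally:
        # Restore the original operators so the patch does not leak.
        for name, fn in originals.items():
            setattr(cls, name, fn)


x = ToyTensor(range(20))
counter = {"flops": 0}
with count_operator_flops(ToyTensor, ["__add__", "__mul__"], counter):
    y = x + x            # residual-style addition: counted
    z = y * x            # elementwise product: counted
print(counter["flops"])  # 40
x + x                    # outside the context manager: not counted
print(counter["flops"])  # 40
```

Patching on the class works because Python looks up operator methods on the type, not the instance. A fuller version would also need the reflected and in-place variants: `x += y` uses `__iadd__` when the class defines it, and `2 * x` falls back to `__rmul__`.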