flops-counter.pytorch

Request to include FLOP count for Graph Convolutions

Open · pranavgundewar opened this issue 4 years ago · 3 comments

pranavgundewar · Jun 15 '21 22:06

There is no standard module in torch.nn that represents graph convolutions, and ptflops can only account for PyTorch's built-in modules. However, you can write a custom hook for your GCN implementation and pass it to ptflops.

sovrasov · Jun 24 '21 06:06
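As a rough guide to what such a custom hook would need to add, here is a multiply-accumulate (MAC) count for one Kipf & Welling-style GCN layer. This assumes the layer computes $H' = \hat{A} H W$ with $N$ nodes, a sparse normalized adjacency $\hat{A}$, input width $F_{\mathrm{in}}$ and output width $F_{\mathrm{out}}$, and it ignores bias and activation; these are our assumptions, not something ptflops defines. The feature transform $HW$ is a dense matrix product, and each non-zero entry of $\hat{A}$ scales and accumulates one row of length $F_{\mathrm{out}}$:

$$
\mathrm{MACs} = \underbrace{N \, F_{\mathrm{in}} \, F_{\mathrm{out}}}_{H W} + \underbrace{\operatorname{nnz}(\hat{A}) \, F_{\mathrm{out}}}_{\hat{A}(H W)}
$$

For example, with $N = 4$, $F_{\mathrm{in}} = 3$, $F_{\mathrm{out}} = 2$ and $\operatorname{nnz}(\hat{A}) = 6$ (four self-loops plus one undirected edge stored in both directions), this gives $4 \cdot 3 \cdot 2 + 6 \cdot 2 = 24 + 12 = 36$ MACs.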

@sovrasov Can you share some examples of writing a custom hook for a GCN implementation?

Thank you!

pranavgundewar · Aug 09 '22 22:08

Hi! Here is a brief example:

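```python
import torch.nn as nn
from ptflops import get_model_complexity_info


class MyModule(nn.Module):
    def forward(self, x):
        return x


def my_module_flops_counter_hook(module, input, output):
    # Add the number of MACs this module performs in one forward pass.
    # An identity module performs none, hence 0.
    module.__flops__ += 0


model = MyModule()
macs, params = get_model_complexity_info(model, (3, 224, 224), as_strings=True,
                                         print_per_layer_stat=True,
                                         verbose=True,
                                         custom_modules_hooks={MyModule: my_module_flops_counter_hook})
```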

Replace MyModule with your GCN implementation, and make the hook add the number of operations that implementation performs to module.__flops__.

sovrasov · Aug 11 '22 16:08
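A slightly more concrete sketch of the same pattern for a graph convolution, following the hook signature and the module.__flops__ counter from the example above. GCNLayer and gcn_flops_counter_hook are hypothetical names, not part of ptflops or PyTorch. The layer assumes a Kipf & Welling-style update out = adj @ (x @ weight), with adj given as a dense or sparse COO tensor. The hook counts N * F_in * F_out MACs for the feature transform plus nnz(adj) * F_out MACs for the aggregation, ignoring bias and activation.

```python
import torch
import torch.nn as nn


class GCNLayer(nn.Module):
    """Hypothetical GCN layer computing out = adj @ (x @ weight)."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        # Weight held directly as a Parameter (no nn.Linear child module).
        self.weight = nn.Parameter(torch.empty(in_features, out_features))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        support = x @ self.weight  # (N, F_in) @ (F_in, F_out) -> (N, F_out)
        if adj.is_sparse:  # sparse COO adjacency
            return torch.sparse.mm(adj, support)
        return adj @ support  # dense adjacency


def gcn_flops_counter_hook(module, input, output):
    # Forward hooks receive positional inputs only, so this assumes the
    # layer was called as layer(x, adj).
    x, adj = input
    num_nodes, in_features = x.shape
    out_features = module.weight.shape[1]
    # Number of stored non-zeros in the adjacency (sparse COO or dense).
    nnz = adj._nnz() if adj.is_sparse else int(torch.count_nonzero(adj))
    dense_macs = num_nodes * in_features * out_features  # x @ weight
    aggregation_macs = nnz * out_features  # adj @ support
    module.__flops__ += int(dense_macs + aggregation_macs)
```

It would be registered the same way as before, via custom_modules_hooks={GCNLayer: gcn_flops_counter_hook}. Because the hook reads x and adj from the positional inputs, the layer has to be called as layer(x, adj). If the whole model needs both tensors as input, ptflops' input_constructor argument is one possible way to supply them, with an outer model that calls the layer positionally.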