flops-counter.pytorch
Request to include FLOP count for Graph Convolutions
There is no standard module in torch.nn for graph convolutions, and ptflops can only account for PyTorch's built-in modules. However, you can write a custom hook for your GCN implementation and pass it to ptflops.
@sovrasov Can you share some examples of writing a custom hook for a GCN implementation?
Thank you!
Hi! Here is a brief example:
```python
import torch.nn as nn
from ptflops import get_model_complexity_info


class MyModule(nn.Module):
    def forward(self, x):
        return x


def my_module_flops_counter_hook(module, input, output):
    # Accumulate this module's cost for one forward pass (0 for an identity op).
    module.__flops__ += 0


model = MyModule()
macs, params = get_model_complexity_info(model, (3, 224, 224), as_strings=True,
                                         print_per_layer_stat=True,
                                         verbose=True,
                                         custom_modules_hooks={MyModule: my_module_flops_counter_hook})
```
Replace MyModule with your GCN implementation, and make the hook add that layer's actual cost instead of 0.
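For a GCN layer, the hook body could look roughly like the sketch below. It assumes a hypothetical layer that is called as `layer(x, adj)` (both positional, so the forward hook receives `input = (x, adj)`), takes node features `x` of shape `(..., num_nodes, in_features)`, computes `adj @ (x @ W)`, and ignores bias and activation. The names `gcn_layer_macs` and `gcn_flops_counter_hook` are illustrative, not part of ptflops. Since ptflops reports its totals as MACs, the helper counts multiply-accumulates: `num_nodes * in_features * out_features` for the feature transform, plus `num_nodes**2 * out_features` for a dense adjacency or `num_edges * out_features` for a sparse one.

```python
from math import prod


def gcn_layer_macs(batch, num_nodes, in_features, out_features, num_edges=None):
    """Hypothetical helper: MACs for one GCN layer computing adj @ (x @ W).

    Bias and activation are ignored. If num_edges is None, adj is treated
    as dense (num_nodes x num_nodes); otherwise it has num_edges nonzeros.
    """
    transform = batch * num_nodes * in_features * out_features  # x @ W
    if num_edges is None:
        aggregate = batch * num_nodes * num_nodes * out_features  # dense adj @ (xW)
    else:
        aggregate = batch * num_edges * out_features  # one MAC per nonzero per output feature
    return transform + aggregate


def gcn_flops_counter_hook(module, input, output):
    # Assumes the (hypothetical) GCN layer is called as layer(x, adj),
    # so the forward hook sees input == (x, adj).
    x, adj = input[0], input[1]
    *lead, num_nodes, in_features = x.shape  # leading dims (if any) act as batch
    out_features = output.shape[-1]
    # Only sparse COO tensors report is_sparse=True; everything else counts as dense.
    num_edges = adj._nnz() if getattr(adj, "is_sparse", False) else None
    module.__flops__ += int(
        gcn_layer_macs(prod(lead), num_nodes, in_features, out_features, num_edges)
    )
```

You would then register it like the example above, e.g. `custom_modules_hooks={MyGCNLayer: gcn_flops_counter_hook}`, where `MyGCNLayer` stands in for your own class. Because a GCN usually needs both features and an adjacency matrix, you will probably also need ptflops' `input_constructor` argument to build the model inputs.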