pytorch_modelsize
Estimate size with arbitrary functions in `.forward()` using graph history
The original implementation of this tool relies on accessing tensor operations through `model.modules()`. This is simple, but cannot account for arbitrary dimensionality changes in the `.forward()` method. For instance, it's common to call `tensor.view(...)` or `torch.nn.functional.max_pool2d()` in `.forward()`.
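To illustrate the gap, here is a minimal sketch. `TinyNet` is a hypothetical model written for this example, not part of the tool. Its pooling and reshape steps happen only inside `.forward()`, so they never show up in `model.modules()`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyNet(nn.Module):
    """Hypothetical model: two registered layers plus functional ops in forward()."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 4, kernel_size=3, padding=1)
        self.fc = nn.Linear(4 * 4 * 4, 10)

    def forward(self, x):
        x = F.max_pool2d(self.conv(x), 2)  # functional call, not a registered module
        x = x.view(x.size(0), -1)          # reshape, not a registered module
        return self.fc(x)


model = TinyNet()

# modules() yields the model itself and its registered submodules only.
module_names = [type(m).__name__ for m in model.modules()]
print(module_names)
```

The printed list contains only `TinyNet`, `Conv2d` and `Linear`, so a size estimate built from it has no record of the halved spatial dimensions or the flattening step.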
We could alternatively extract the operation history from the graph of functions recorded by `autograd`.
See https://github.com/szagoruyko/pytorchviz/blob/master/torchviz/dot.py for an example where the graph is extracted.
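As a rough sketch of the idea (not the proposed implementation), the recorded graph can be walked backwards from an output tensor via `grad_fn` and `next_functions`. The helper name `collect_graph_ops` is hypothetical, and the exact node class names vary between PyTorch versions:

```python
import torch
import torch.nn.functional as F


def collect_graph_ops(output):
    """Return the class names of all autograd nodes reachable from `output`.

    Hypothetical helper: walks output.grad_fn backwards through
    next_functions, visiting each node once.
    """
    seen = set()
    names = []
    stack = [output.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        names.append(type(fn).__name__)
        # next_functions is a tuple of (node, input_index) pairs.
        for next_fn, _ in fn.next_functions:
            stack.append(next_fn)
    return names


x = torch.randn(1, 1, 8, 8, requires_grad=True)
y = F.max_pool2d(x, 2).view(1, -1).sum()

op_names = collect_graph_ops(y)
print(op_names)
```

Unlike `model.modules()`, the result includes nodes for the pooling and view operations. This sketch only recovers which operations ran; a size estimate would still need the output shape of each operation, which is not attempted here.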