torchstat
About analysis of a single layer
If I want to get the calculation analysis of an FC (fully connected) layer, how can I use this tool?
import torch.nn as nn
import torch
from torchstat import stat

class Net(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)
        self.fc = nn.Linear(144, 4096)

    def forward(self, x):
        x = self.fc(x)
        return x

net = Net()
stat(net, (1, 64, 144))
I wrote this code, but the check assert len(inp.size()) == 2 and len(out.size()) == 2 fails. How can I change the code?
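A likely cause, assuming torchstat prepends a batch dimension of 1 to input_size (its README examples call stat(model, (3, 224, 224)) without a batch axis): (1, 64, 144) becomes a 4-D tensor, while the Linear counter asserts 2-D input and output. Passing stat(net, (144,)) should give the layer a 2-D (1, 144) input. If you really need 64 rows, one option is to flatten the leading dimensions in forward before calling self.fc. Either way, the cost of a single Linear layer is easy to count by hand. The sketch below is a hypothetical helper, linear_layer_cost, that is not part of torchstat. It uses its own convention of one multiply-accumulate (MAC) per weight per input row, and it may not match torchstat's exact FLOPs/MAdd numbers.

```python
from math import prod


def linear_layer_cost(input_shape, in_features, out_features, bias=True):
    """Hand count for nn.Linear(in_features, out_features) on input_shape.

    Hypothetical helper, not a torchstat API. nn.Linear acts on the last
    dimension, so every leading dimension is treated as an independent row,
    i.e. a 2-D view of shape (rows, in_features).
    """
    if input_shape[-1] != in_features:
        raise ValueError("last input dimension must equal in_features")
    rows = prod(input_shape[:-1])  # prod(()) == 1 for a 1-D input
    params = in_features * out_features + (out_features if bias else 0)
    macs = rows * in_features * out_features  # one MAC per weight per row
    output_shape = tuple(input_shape[:-1]) + (out_features,)
    return {"params": params, "macs": macs, "output_shape": output_shape}


print(linear_layer_cost((1, 64, 144), 144, 4096))
```

For the layer in the question, this reports 593,920 parameters (144 × 4096 weights plus 4096 biases) and 64 × 144 × 4096 = 37,748,736 MACs. A single (1, 144) row costs 589,824 MACs.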