torchstat
torchstat copied to clipboard
torchstat analysis of nn.Linear() is incorrect
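I think that when bias is True, MAdd for the nn.Linear(12, 3) layer below should be 72: each of the 3 outputs needs 12 multiplications plus 12 additions (11 to accumulate the products and 1 for the bias), so 3 × 24 = 72.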
import torch
from torch import nn


class StatModel(nn.Module):
    """Minimal repro: two ``nn.Linear(12, 3)`` layers that differ only in bias.

    Both layers receive the same flattened input. Any difference between their
    profiled MAdd counts can therefore only come from the bias term. The report
    expects 72 MAdd for the bias=True layer: 3 outputs x (12 multiplications
    + 12 additions, i.e. 11 to accumulate the products plus 1 for the bias).
    """

    def __init__(self):
        super().__init__()
        self.layers_bias_false = nn.Linear(12, 3, bias=False)
        self.layers_bias_true = nn.Linear(12, 3, bias=True)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Flatten all non-batch dims of ``x`` and run it through both layers.

        Args:
            x: tensor whose dim 0 is treated as the batch. The remaining dims
                must multiply out to 12 (``in_features``), e.g. ``(N, 3, 2, 2)``.

        Returns:
            ``(x_false, x_true)``: outputs of the bias-free and the biased
            layer, each of shape ``(N, 3)``.
        """
        x = x.reshape(x.shape[0], -1)
        x_false = self.layers_bias_false(x)
        x_true = self.layers_bias_true(x)
        return x_false, x_true


if __name__ == '__main__':
    # Imported here rather than at module level, so the model can be imported
    # (e.g. by tests) without torchstat installed.
    from torchstat import stat

    # (3, 2, 2) flattens to 12 features, matching in_features. This presumably
    # relies on stat() prepending the batch dimension itself (confirm against
    # torchstat).
    print(stat(StatModel(), (3, 2, 2)))