When I use torchstat to calculate Params and FLOPs for ViT, an error occurs.
After debugging, I found that some ops are not supported.
class Attention(nn.Module):
    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
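The error is raised on the `chunk` line, so I assumed `chunk` is not supported. The traceback, however, ends in torchstat's `compute_Linear_madd`, which asserts that the `to_qkv` `nn.Linear` layer has 2-D input and output; here that layer receives a 4-D tensor (`b p n d`).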
Which tool should I use to calculate Params and FLOPs? Thanks!
Code:
from torchstat import stat

# MobileViT is defined earlier in the same script (mobile_vit.py)
mbvit_xs = MobileViT(
    image_size=(256, 256),
    dims=[96, 120, 144],
    channels=[16, 32, 48, 48, 64, 64, 64, 32, 16, 2],
    num_classes=1000,
)
stat(mbvit_xs, (3, 256, 256))
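As a possible workaround for the `stat(...)` call above, parameters can be summed directly and the MACs of `nn.Linear` and `nn.Conv2d` layers accumulated with forward hooks, which works for inputs of any rank. This is a minimal sketch, not torchstat's or any library's method: `count_params_and_macs` is a hypothetical name, and it ignores attention matmuls, normalization, activations and bias additions. Tracing-based counters such as fvcore's `FlopCountAnalysis` may be worth comparing against.

```python
import torch
import torch.nn as nn


def count_params_and_macs(model: nn.Module, input_size):
    """Count parameters and Linear/Conv2d MACs with forward hooks.

    Hypothetical helper: only nn.Linear and nn.Conv2d are counted;
    matmuls inside attention, norms, activations and biases are ignored.
    """
    params = sum(p.numel() for p in model.parameters())
    macs = 0

    def linear_hook(module, inputs, output):
        nonlocal macs
        # nn.Linear acts on the last dim; each output element costs in_features MACs,
        # so leading dims (batch, patch, token, ...) are covered by output.numel()
        macs += output.numel() * module.in_features

    def conv_hook(module, inputs, output):
        nonlocal macs
        # each output element costs (in_channels / groups) * kh * kw MACs
        kh, kw = module.kernel_size
        macs += output.numel() * (module.in_channels // module.groups) * kh * kw

    handles = []
    for m in model.modules():
        if isinstance(m, nn.Linear):
            handles.append(m.register_forward_hook(linear_hook))
        elif isinstance(m, nn.Conv2d):
            handles.append(m.register_forward_hook(conv_hook))

    was_training = model.training
    model.eval()
    with torch.no_grad():
        model(torch.zeros(1, *input_size))  # batch size 1, like stat()
    model.train(was_training)

    for h in handles:
        h.remove()
    return params, macs
```

For the model in the question this would be called as `params, macs = count_params_and_macs(mbvit_xs, (3, 256, 256))`.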
Error output:
  File "mobile_vit.py", line 410, in <module>
    stat(mbvit_xs, (3, 256, 256))
  File "/home2/zmy/anaconda3/envs/py38/lib/python3.8/site-packages/torchstat/statistics.py", line 71, in stat
    ms.show_report()
  File "/home2/zmy/anaconda3/envs/py38/lib/python3.8/site-packages/torchstat/statistics.py", line 64, in show_report
    collected_nodes = self._analyze_model()
  File "/home2/zmy/anaconda3/envs/py38/lib/python3.8/site-packages/torchstat/statistics.py", line 57, in _analyze_model
    model_hook = ModelHook(self._model, self._input_size)
  File "/home2/zmy/anaconda3/envs/py38/lib/python3.8/site-packages/torchstat/model_hook.py", line 24, in __init__
    self._model(x)
  File "/home2/zmy/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "mobile_vit.py", line 383, in forward
    h5 = self.trunk12(h4)
  File "/home2/zmy/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home2/zmy/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/home2/zmy/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "mobile_vit.py", line 186, in forward
    x = self.transformer(x)
  File "/home2/zmy/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "mobile_vit.py", line 111, in forward
    x = attn(x) + x
  File "/home2/zmy/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "mobile_vit.py", line 32, in forward
    return self.fn(self.norm(x), **kwargs)
  File "/home2/zmy/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "mobile_vit.py", line 71, in forward
    qkv = self.to_qkv(x).chunk(3, dim=-1)
  File "/home2/zmy/anaconda3/envs/py38/lib/python3.8/site-packages/torchstat/model_hook.py", line 76, in wrap_call
    madd = compute_madd(module, input[0], output)
  File "/home2/zmy/anaconda3/envs/py38/lib/python3.8/site-packages/torchstat/compute_madd.py", line 156, in compute_madd
    return compute_Linear_madd(module, inp, out)
  File "/home2/zmy/anaconda3/envs/py38/lib/python3.8/site-packages/torchstat/compute_madd.py", line 117, in compute_Linear_madd
    assert len(inp.size()) == 2 and len(out.size()) == 2
AssertionError
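The last frames show where it breaks: `nn.Linear` only acts on the last dimension, so it accepts inputs of any rank, while torchstat's `compute_Linear_madd` asserts a 2-D input. For an N-D input, every leading dimension simply multiplies the cost. A minimal sketch of that counting rule on plain shapes; `linear_macs` and the example sizes are hypothetical, not torchstat's code or MobileViT's actual layer sizes:

```python
from math import prod


def linear_macs(input_shape, in_features, out_features):
    """Multiply-accumulates for nn.Linear applied to an input of any rank.

    nn.Linear maps the last dimension from in_features to out_features;
    every leading dimension (batch, patch, token, ...) just repeats that work.
    Bias additions are not counted.
    """
    if input_shape[-1] != in_features:
        raise ValueError(
            f"last dim {input_shape[-1]} does not match in_features {in_features}"
        )
    rows = prod(input_shape[:-1])  # number of independent input vectors
    return rows * in_features * out_features


# Hypothetical 4-D input in the 'b p n d' layout: 1 image, 4 patch positions,
# 256 tokens, 96 channels, projected to 3 * 96 = 288 for q, k and v.
print(linear_macs((1, 4, 256, 96), 96, 288))  # 28311552
```

For a 2-D input `(batch, in_features)` this reduces to `batch * in_features * out_features`, so the same rule covers the case torchstat already handles.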
I have encountered the same problem. Have you found a solution?
Sorry, I haven't found a good solution to this problem so far.
What version of TorchStat are you using? - Clerkie (https://clerkie.co/)
cc: @lucidrains trying to gather some context for future debugging - hope that's okay!
@myzhuang you can look at https://github.com/Swall0w/torchstat/issues/18. With that, the code runs successfully, but the MACs result doesn't seem right compared with other methods.
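When MAC counts from different tools disagree, two common causes may be worth ruling out. First, torchstat computes costs per module (the traceback goes through `wrap_call` and `compute_madd(module, ...)`), so products computed with plain tensor ops, such as `q @ k^T` and `attn @ v` in a typical attention forward, are not counted. Second, some tools report FLOPs, which are usually about 2 × MACs. A minimal sketch that estimates the attention matmul MACs, assuming the `b p h n d` layout of q, k and v from the `Attention.forward` snippet in the question; the function name and the example sizes are hypothetical:

```python
def attention_matmul_macs(b, p, h, n, d):
    """MACs of the two matmuls in scaled dot-product attention.

    Assumes q, k and v have shape (b, p, h, n, d), as produced by the
    rearrange(..., 'b p n (h d) -> b p h n d') call in the question.
    These products are tensor ops, not nn.Module calls, so a per-module
    counter does not see them.
    """
    # q @ k^T: (b, p, h, n, d) x (b, p, h, d, n) -> (b, p, h, n, n)
    qk = b * p * h * n * n * d
    # attn @ v: (b, p, h, n, n) x (b, p, h, n, d) -> (b, p, h, n, d)
    av = b * p * h * n * n * d
    return qk + av


# Hypothetical sizes: 1 image, 4 patch positions, 4 heads, 256 tokens, head dim 8.
print(attention_matmul_macs(1, 4, 4, 256, 8))  # 16777216
```

Adding this per attention block to a per-module total, and checking whether the other tool reports FLOPs rather than MACs, should help narrow down where the numbers diverge.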