torchstat
stat doesn't generalize for non-image inputs
`torchstat.stat` does not work with arbitrarily shaped input; it only accepts inputs with shape `(channels, height, width)`.
I tried running it with an AutoEncoder for time series and it failed.
A simple example where it fails:
import torch
import torchstat
torchstat.stat(torch.nn.LSTM(10, 10, batch_first=True), (50, 10))
It is also not a good idea to assume that the batch dimension is the first one, as this may not be the case for recurrent nets.
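To illustrate the point about the batch dimension, here is a minimal sketch using plain PyTorch (it does not call torchstat). It runs the same LSTM configuration in both layouts; the sizes `batch`, `seq_len`, and `features` are arbitrary values chosen for the illustration.

```python
import torch

# Same LSTM configuration as in the failing example, in both layouts.
lstm_batch_first = torch.nn.LSTM(10, 10, batch_first=True)
lstm_seq_first = torch.nn.LSTM(10, 10)  # batch_first defaults to False

# Illustrative sizes only (assumptions for this sketch).
batch, seq_len, features = 4, 50, 10

# batch_first=True expects input of shape (batch, seq_len, features).
out_bf, _ = lstm_batch_first(torch.zeros(batch, seq_len, features))

# batch_first=False expects input of shape (seq_len, batch, features).
out_sf, _ = lstm_seq_first(torch.zeros(seq_len, batch, features))

shape_bf = tuple(out_bf.shape)
shape_sf = tuple(out_sf.shape)
print(shape_bf)  # (4, 50, 10): batch is dimension 0
print(shape_sf)  # (50, 4, 10): batch is dimension 1
```

A tool that always prepends the batch dimension at position 0 would therefore feed the second model a tensor in the wrong layout.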
Hi @miguelvr, thank you for the detailed information. Initially, this tool was intended to compute the computational cost of neural networks, especially CNNs. Therefore the expected input shape for `stat` is `(channels, height, width)`.
But I think it is important for `stat` to accept input of any shape, so I'm considering extending `stat` to support arbitrarily shaped input.
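As a rough idea of what shape-agnostic statistics could look like, here is a minimal sketch built on standard PyTorch forward hooks. It is not torchstat's implementation: `summarize` is a hypothetical helper, it only counts parameters and records output shapes (no FLOPs or memory), and it assumes the caller passes the full input shape, batch dimension included, so that no layout is imposed.

```python
import torch


def summarize(model, full_input_shape):
    """Hypothetical helper, not part of torchstat.

    Counts parameters and records the output shape of each leaf module
    for an input of any shape (batch dimension included by the caller).
    """
    records = []

    def hook(module, inputs, output):
        # Recurrent layers return a tuple; its first element is the output tensor.
        tensor = output[0] if isinstance(output, tuple) else output
        records.append((type(module).__name__, tuple(tensor.shape)))

    # Attach the hook to leaf modules only (modules without children).
    handles = [
        m.register_forward_hook(hook)
        for m in model.modules()
        if len(list(m.children())) == 0
    ]
    try:
        with torch.no_grad():
            model(torch.zeros(*full_input_shape))
    finally:
        # Always remove the hooks so the model is left unchanged.
        for h in handles:
            h.remove()

    n_params = sum(p.numel() for p in model.parameters())
    return n_params, records


# The model from the failing example, with an explicit batch dimension of 1.
n_params, records = summarize(torch.nn.LSTM(10, 10, batch_first=True), (1, 50, 10))
print(n_params)  # 880
print(records)   # [('LSTM', (1, 50, 10))]
```

Passing the complete shape leaves the choice of batch position to the user, which sidesteps the batch-first assumption raised above for recurrent nets.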
It is not clear to me what the input shape for `stat` is. Do you mean that `stat` is currently designed for image-specific CNNs? (I ask because channels, height, and width are not general terms for PyTorch or for CNNs...)
https://github.com/Swall0w/torchstat/blob/b52a3b06c2c54c2d09ade1a18cf6c4ca5dc27510/torchstat/main.py#L14
Thanks for such cool work, it helps me a lot. I wonder whether torchstat now supports non-image input?