[FIX] `torch.half` warning only applies to size estimation
Hello dear maintainer(s)!
Could you please give me a bit more explanation of the following?
https://github.com/TylerYep/torchinfo/blob/e67e74885ce103d7508b71c20e279aee1d9eb4ce/torchinfo/torchinfo.py#L398C1-L412C14
Is it still relevant today? The warnings do pop up, as expected, when using
summary(net, input_size=input_size, dtypes=[torch.half], device=torch.device("cpu"))
Thanks in advance!
The results won't be correct if input_size is used instead of input_data, so torchinfo errors out instead of silently producing incorrect results.
OK yes, I think this is exactly what I understood 👍 However, I was wondering if you could explain a bit why the results would not be correct. I'm trying to support sub-32-bit floats in the project for which I use torchinfo (link, in case you are interested in seeing what your great work makes possible!), and I would like to be sure I'm not missing any subtleties of half and bfloat16!
Thanks!
I believe there was a bug report about this result being incorrect, but I don't quite recall anymore.
With input_size, torchinfo constructs an input for the model using the given size, but it likely won't account for the half-precision element sizes correctly.
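For illustration, a minimal sketch of the two call styles, using a toy model and shapes of my own (not from torchinfo's tests):

```python
import torch
from torch import nn
from torchinfo import summary

net = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4))  # toy model

# The call style from the issue: torchinfo fabricates the input from input_size,
# and (per the point above) the size estimates may not reflect 2-byte elements.
# summary(net, input_size=(1, 16), dtypes=[torch.half], device=torch.device("cpu"))

# The input_data style: the tensor you pass, with its real dtype, is what gets
# traced. bfloat16 is used here only because float16 matmuls are not available
# on CPU in every PyTorch build; the same pattern applies to torch.half on GPU.
x = torch.randn(1, 16, dtype=torch.bfloat16)
summary(net.to(torch.bfloat16), input_data=x, device=torch.device("cpu"))
```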
PRs fixing this or adding tests that ensure this behavior works are welcome!
Thanks for your answer and sorry for the delay.
Thanks to your answer, I found this PR message; is that what you were referring to?
So, from what I understand, the "wrong result" mentioned in the warning only applies to the size estimation, is that correct? My project relies heavily on the "architecture understanding" your work provides. Is it safe to say that part is not affected? If you confirm, and if you would like, I can propose a rewording of the warning to make things clearer.
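To make concrete what I mean, here is a minimal sketch with a toy model of my own (not from either project): the per-layer columns are the architecture information I rely on, while the "size (MB)" lines in the summary footer are the size estimation the warning is about.

```python
import torch
from torch import nn
from torchinfo import summary

net = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4))  # toy model

# Per-layer structure: layer names, output shapes, parameter counts.
stats = summary(net, input_size=(1, 16), col_names=("output_size", "num_params"))

# The parameter count is exact, while the "size (MB)" estimates printed in the
# footer are the part affected by the dtype question.
print(stats.total_params)
```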
All the best. Élie
Correct. Feel free to submit a PR making the warning clearer.
Done. I do not really understand the second warning about cpu, so I left it as is.