torchstat
Are cuda() tensors not included in the isinstance check?
I got this error at runtime:

  File "col_cub_gan_gn.py", line 513, in main
    train(args)
  File "col_cub_gan_gn.py", line 232, in train
    stat(clss,temp)
  File "/home/xujia/anaconda3/envs/xujia-py37/lib/python3.7/site-packages/torchstat/statistics.py", line 70, in stat
    ms = ModelStat(model, input_size, query_granularity)
  File "/home/xujia/anaconda3/envs/xujia-py37/lib/python3.7/site-packages/torchstat/statistics.py", line 51, in __init__
    assert isinstance(input_size, (tuple, list)) and len(input_size) == 3
And here is how I define the variable "temp":

    temp = (batch_x.shape[1], batch_x.shape[2], batch_x.shape[3])
    # print(temp.type)
    temp = torch.tensor(temp).to(device)
    stat(clss, temp)
I hope you can help me. Thanks!
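You only need to provide the shape, e.g. temp = (3, 224, 224), not temp = torch.tensor(temp).to(device). The assertion in the traceback requires input_size to be a tuple or list of length 3, so converting it to a tensor (on any device) makes the check fail.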
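A minimal sketch of the fix, assuming batch_x is an (N, C, H, W) batch as in the question. The helper name input_size_from_batch is hypothetical, not part of torchstat: it just drops the batch dimension and returns a plain tuple of three ints, which is what the assertion in statistics.py checks for. torchstat's README calls it the same way, e.g. stat(model, (3, 224, 224)).

```python
from typing import Sequence, Tuple


def input_size_from_batch(batch_shape: Sequence[int]) -> Tuple[int, int, int]:
    """Hypothetical helper: turn an (N, C, H, W) shape into (C, H, W).

    torchstat's stat() asserts that input_size is a tuple or list of
    length 3, so this returns plain ints rather than a tensor.
    """
    if len(batch_shape) != 4:
        raise ValueError(f"expected an (N, C, H, W) shape, got {tuple(batch_shape)}")
    # Drop the batch dimension and coerce each entry to a Python int.
    c, h, w = (int(d) for d in batch_shape[1:])
    return (c, h, w)


if __name__ == "__main__":
    # Stand-in for batch_x.shape; a torch.Size is a tuple subclass, so it works too.
    size = input_size_from_batch((16, 3, 224, 224))
    print(size)  # (3, 224, 224)

    # Hypothetical usage with the objects from the question (needs torch + torchstat):
    #   from torchstat import stat
    #   stat(clss.cpu(), input_size_from_batch(batch_x.shape))
```

The .cpu() in the commented usage is an assumption: torchstat appears to build its own dummy input on the CPU from input_size, so a model that lives on the GPU may raise a device mismatch even once the shape is passed correctly.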