tps_stn_pytorch

Error when running python mnist_visualize.py --model unbounded_stn --angle 90 --grid_size 4

Open · KakaVlasic opened this issue 6 years ago · 3 comments

Hello! When I run python mnist_visualize.py --model unbounded_stn --angle 90 --grid_size 4, I get the following output:

create model with STN
Traceback (most recent call last):
  File "mnist_visualize.py", line 47, in
    data_list = target2data_list[target]
KeyError: tensor(7)

Any help would be appreciated!

KakaVlasic, Feb 19 '19 05:02
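A minimal sketch of why the lookup fails. It assumes, as the KeyError suggests, that target2data_list is a dict keyed by plain Python ints; the table below is a hypothetical stand-in, not the code from mnist_visualize.py. torch.Tensor hashes by object identity rather than by value, so a dict lookup with tensor(7) never matches the int key 7:

```python
import torch

# Hypothetical stand-in for the script's table: a dict keyed by Python ints 0-9.
target2data_list = {i: [] for i in range(10)}

# Assumption: the label arrives as a 0-dim integer tensor, as in the traceback.
target = torch.tensor(7)

# torch.Tensor.__hash__ is based on id(), so hash(target) != hash(7)
# and the dict never finds the key 7.
try:
    data_list = target2data_list[target]
    error = None
except KeyError as err:
    error = err

print(type(error).__name__)  # KeyError
print(7 in target2data_list, target in target2data_list)  # True False
```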

I have the same problem.

w11m, Feb 19 '19 10:02

target2data_list = [list() for i in range(10)]

jackweiwang, Nov 27 '19 07:11
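A sketch of the list-based table, under the assumption that the labels are the ten MNIST digits. The sample data is made up for illustration. Python lists accept any index object that implements __index__, and one-element integer tensors do, so indexing with tensor(7) works here even though the dict lookup does not:

```python
import torch

# One bucket per digit, as in the comment above, instead of a dict.
target2data_list = [list() for i in range(10)]

# Hypothetical (data, target) pairs: dummy 28x28 images with 0-dim tensor labels.
samples = [(torch.zeros(1, 28, 28), torch.tensor(d)) for d in (7, 2, 7)]

for data, target in samples:
    # list indexing calls target.__index__(), which yields the plain int.
    target2data_list[target].append(data)

print([len(bucket) for bucket in target2data_list])  # [0, 0, 1, 0, 0, 0, 0, 2, 0, 0]
```

This relies on the labels being exactly 0 to 9; a dict-based table still needs the tensor converted to an int before lookup.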

I think the index shouldn't be a tensor; it should be an int, like this:

data_list = target2data_list[int(target)]

MingqiangNing, Jul 15 '21 04:07
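A sketch of the int(target) conversion applied to a hypothetical dict-based table. int() turns a one-element integer tensor into a plain Python int that hashes and compares like the dict's keys (target.item() does the same). The helper add_sample is made up for illustration:

```python
import torch

# Hypothetical dict-based table like the one the KeyError points to.
target2data_list = {i: [] for i in range(10)}

def add_sample(table, data, target):
    # Convert the 0-dim tensor label to a Python int before the dict lookup.
    table[int(target)].append(data)

add_sample(target2data_list, "sample-a", torch.tensor(7))
add_sample(target2data_list, "sample-b", torch.tensor(7))
print(len(target2data_list[7]))  # 2
```

int() only works on one-element tensors; if target were a whole batch of labels it would raise an error, so iterate over the batch (or call target.tolist()) first.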