temporal-shift-module
Online model always predicts index 21
I added some code to online_demo/main.py, as shown below:
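import code
import os

import cv2
import torch
from PIL import Image
# MobileNetV2 and get_transform are already available in online_demo/main.py

def main(model_path, d_path):
    # 27-way classifier, one output per Jester gesture class
    net = MobileNetV2(n_class=27)
    net.load_state_dict(torch.load(model_path))
    transform = get_transform()
    # one zero-initialized buffer per temporally shifted block
    shift_buffer = [torch.zeros([1, 3, 56, 56]),
                    torch.zeros([1, 4, 28, 28]),
                    torch.zeros([1, 4, 28, 28]),
                    torch.zeros([1, 8, 14, 14]),
                    torch.zeros([1, 8, 14, 14]),
                    torch.zeros([1, 8, 14, 14]),
                    torch.zeros([1, 12, 14, 14]),
                    torch.zeros([1, 12, 14, 14]),
                    torch.zeros([1, 20, 7, 7]),
                    torch.zeros([1, 20, 7, 7])]
    fnames = os.listdir(d_path)
    # feed the frames in order, carrying the shift buffers from one frame to the next
    for fname in sorted(fnames):
        fpath = os.path.join(d_path, fname)
        image = cv2.imread(fpath)  # note: OpenCV returns BGR; .convert('RGB') does not swap channels
        image = transform([Image.fromarray(image).convert('RGB')])
        image = torch.autograd.Variable(image.view(1, 3, image.size(1), image.size(2)))
        p, *shift_buffer = net(image, *shift_buffer)
        idx = torch.argmax(p.squeeze()).item()
        code.interact(local=locals())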
I tested it on Jester data, but every predicted index is 21. Could you help me solve this problem? Any help would be appreciated.
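The line `p, *shift_buffer = net(image, *shift_buffer)` threads state through the frames: the network returns its prediction together with updated buffers, which are passed back in on the next call. In the TSM online model, each shifted block caches the first 1/8 of its channels for the next frame and replaces the current frame's first 1/8 with the cache from the previous one; the shapes above fit that (3 = 24/8, 4 = 32/8, 8 = 64/8, 12 = 96/8, 20 = 160/8). The pure-Python sketch below illustrates only this caching pattern for a single block; `online_shift` and `run_online` are hypothetical helpers, each "channel" is a scalar, and it is not the repository's code.

```python
def online_shift(frame_channels, shift_buffer):
    """Replace the first len(shift_buffer) channels of the current frame with
    the channels cached from the previous frame; return (shifted, new_buffer)."""
    fold = len(shift_buffer)
    new_buffer = frame_channels[:fold]               # cache for the next frame
    shifted = shift_buffer + frame_channels[fold:]   # past features enter the present
    return shifted, new_buffer


def run_online(frames):
    fold = len(frames[0]) // 8          # shift 1/8 of the channels
    buffer = [0.0] * fold               # zeros before the first frame, like torch.zeros(...)
    outputs = []
    for frame in frames:
        # the returned buffer must be fed back in, as with *shift_buffer in the demo
        shifted, buffer = online_shift(frame, buffer)
        outputs.append(shifted)
    return outputs


# three frames with 8 scalar "channels" each: frame t holds 10*t + 1 ... 10*t + 8
frames = [[float(10 * t + c + 1) for c in range(8)] for t in range(3)]
print([out[0] for out in run_online(frames)])  # [0.0, 1.0, 11.0]
```

The first frame receives zeros, and every later frame receives the first channel of the frame before it, which is why the buffers must be carried across calls rather than re-created per frame.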
Try adding net.eval() to switch the model to evaluation mode.
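Why this plausibly matters: a PyTorch module is in training mode after construction, and `load_state_dict` does not change that. In training mode every BatchNorm layer normalizes with the statistics of the current batch (here a single frame) instead of the running statistics loaded from the checkpoint, and it overwrites those running statistics as a side effect; Dropout layers, if present, also stay active. The toy below is a hypothetical single-channel BatchNorm in pure Python, not PyTorch's implementation (learnable scale/shift are omitted and the running-variance update is simplified to the biased variance), but it shows both effects: in training mode two very different inputs come out the same, while in eval mode they do not.

```python
import math


class ToyBatchNorm:
    """Hypothetical single-channel BatchNorm mimicking the two modes of
    torch.nn.BatchNorm2d (no learnable scale/shift, simplified variance update)."""

    def __init__(self, running_mean=0.0, running_var=1.0, momentum=0.1, eps=1e-5):
        self.running_mean = running_mean  # statistics "learned" during training
        self.running_var = running_var
        self.momentum = momentum
        self.eps = eps
        self.training = True  # like nn.Module, start in training mode

    def eval(self):
        self.training = False
        return self

    def __call__(self, values):
        if self.training:
            # Training mode: use the mean/variance of the current input...
            n = len(values)
            mean = sum(values) / n
            var = sum((v - mean) ** 2 for v in values) / n
            # ...and overwrite the stored statistics as a side effect.
            m = self.momentum
            self.running_mean = (1 - m) * self.running_mean + m * mean
            self.running_var = (1 - m) * self.running_var + m * var
        else:
            # Eval mode: use the stored running statistics.
            mean, var = self.running_mean, self.running_var
        return [(v - mean) / math.sqrt(var + self.eps) for v in values]


def close(a, b, tol=1e-3):
    return all(abs(x - y) < tol for x, y in zip(a, b))


frame = [0.2, 0.4, 0.6, 0.8]            # one channel of one frame
brighter = [10 * v + 5 for v in frame]  # a very different input

bn = ToyBatchNorm(running_mean=0.5, running_var=0.05)
print(close(bn(frame), bn(brighter)))   # True: training mode normalizes the difference away
print(bn.running_mean)                  # no longer 0.5: the stored statistics drifted
bn.eval()
print(close(bn(frame), bn(brighter)))   # False: eval mode keeps the inputs distinguishable
```

In the real network the effect is spread across many layers, so this only illustrates the mechanism rather than proving why the prediction settles on index 21 in particular.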
Oh my god, I forgot it! I will try it.
Hi, did you resolve the issue with the above suggestion? Thanks!