pytorch-beginner
Issue with `pytorch-beginner/05-Recurrent Neural Network/recurrent_network.py`:

- Replace `.data[0]` with `.item()`.
- Add `model.train()` at the beginning of the epoch loop.

Only the training-loop code needs to be modified. Below is the fixed code that worked for me :)
# Train for `num_epoches` epochs, evaluating on the test set after each one.
# This snippet relies on names defined earlier in the script: model,
# criterion, optimizer, train_loader, test_loader, batch_size, num_epoches
# and use_gpu.
#
# NOTE: written for PyTorch >= 0.4, where tensors carry autograd state
# themselves. `Variable(...)` is a no-op wrapper there and
# `volatile=True` is ignored (with a warning), so evaluation is wrapped
# in `torch.no_grad()` instead.


def _prepare_batch(img, label):
    """Squeeze the channel dim of an image batch and move it to the device.

    Expects a 4-D image batch (batch, channel, height, width) with exactly
    one channel. Returns ``(img, label)`` where ``img`` has shape
    (batch, height, width). Both tensors are moved to the GPU when
    ``use_gpu`` is set.
    """
    # Unpacking all four sizes also rejects batches that are not 4-D.
    _, c, _, _ = img.size()
    assert c == 1, 'channel must be 1'
    # NOTE(review): the model presumably reads each image row as one time
    # step of the sequence -- confirm against the model definition.
    img = img.squeeze(1)
    if use_gpu:
        img, label = img.cuda(), label.cuda()
    return img, label


for epoch in range(num_epoches):
    model.train()
    print('epoch {}'.format(epoch + 1))
    print('*' * 10)
    running_loss = 0.0
    running_acc = 0.0
    # Count the samples actually seen. The last batch may be smaller than
    # batch_size, so batch_size * i would be only an approximation.
    num_seen = 0
    for i, (img, label) in enumerate(train_loader, 1):
        img, label = _prepare_batch(img, label)
        # Forward pass
        out = model(img)
        loss = criterion(out, label)
        # Assumes criterion averages over the batch (reduction='mean'), so
        # scale by the batch size to accumulate a per-sample sum.
        running_loss += loss.item() * label.size(0)
        _, pred = torch.max(out, 1)
        running_acc += (pred == label).sum().item()
        num_seen += label.size(0)
        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % 300 == 0:
            print('[{}/{}] Loss: {:.6f}, Acc: {:.6f}'.format(
                epoch + 1, num_epoches, running_loss / num_seen,
                running_acc / num_seen))
    print('Finish {} epoch, Loss: {:.6f}, Acc: {:.6f}'.format(
        epoch + 1, running_loss / num_seen, running_acc / num_seen))

    model.eval()
    eval_loss = 0.0
    eval_acc = 0.0
    num_eval = 0
    # Evaluation needs no autograd graph. This replaces the deprecated
    # Variable(..., volatile=True), which modern PyTorch ignores.
    with torch.no_grad():
        for img, label in test_loader:
            img, label = _prepare_batch(img, label)
            out = model(img)
            loss = criterion(out, label)
            eval_loss += loss.item() * label.size(0)
            _, pred = torch.max(out, 1)
            eval_acc += (pred == label).sum().item()
            num_eval += label.size(0)
    print('Test Loss: {:.6f}, Acc: {:.6f}'.format(
        eval_loss / num_eval, eval_acc / num_eval))
    print()