machine-learning-book
Chapter 14, MNIST test set plot
Page 482
The given code does not work as expected for me. I think it should use mnist_test_dataset.data instead.
fig = plt.figure(figsize=(12, 4))
for i in range(12):
    ax = fig.add_subplot(2, 6, i+1)
    ax.set_xticks([]); ax.set_yticks([])
    img = mnist_test_dataset[i][0][0, :, :]
    pred = model(img.unsqueeze(0).unsqueeze(1))
    y_pred = torch.argmax(pred)
    ax.imshow(img, cmap='gray_r')
    ax.text(0.9, 0.1, y_pred.item(),
            size=15, color='blue',
            horizontalalignment='center',
            verticalalignment='center',
            transform=ax.transAxes)
#plt.savefig('figures/14_14.png')
plt.show()
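For reference, here is how the shapes work out in the book's version. This assumes `mnist_test_dataset` was created with the `transforms.ToTensor()` transform, so that `mnist_test_dataset[i][0]` is a float tensor of shape (1, 28, 28). The sketch uses a random stand-in tensor (`sample`, a hypothetical name) instead of the real dataset so it runs without downloading MNIST.

```python
import torch

# Hypothetical stand-in for mnist_test_dataset[i][0]: with ToTensor(), each image
# is a float32 tensor of shape (channels, height, width) = (1, 28, 28) in [0, 1].
sample = torch.rand(1, 28, 28)

img = sample[0, :, :]                  # drop the channel axis -> (28, 28), as imshow expects
batch = img.unsqueeze(0).unsqueeze(1)  # add batch and channel axes -> (1, 1, 28, 28) for the model

print(img.shape)    # torch.Size([28, 28])
print(batch.shape)  # torch.Size([1, 1, 28, 28])
print(batch.dtype)  # torch.float32
```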
I changed a few things, and it works:
fig = plt.figure(figsize=(12, 4))
for i in range(12):
    ax = fig.add_subplot(2, 6, i+1)
    ax.set_xticks([])
    ax.set_yticks([])
    img = mnist_test_dataset.data[i, :, :]  # 28x28, uint8
    # Add batch and channel dimensions (unsqueeze(0).unsqueeze(0) and
    # unsqueeze(0).unsqueeze(1) both give 1x1x28x28) and convert to float32,
    # scaling the pixels to [0, 1] as ToTensor() does.
    pred = model(img.unsqueeze(0).unsqueeze(0) / 255.)
    y_pred = torch.argmax(pred)
    ax.imshow(img, cmap='gray_r')
    ax.text(0.9, 0.1, y_pred.item(), size=15, color='blue',
            horizontalalignment='center', verticalalignment='center',
            transform=ax.transAxes)
plt.show()
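One caveat with `.data`: it holds the raw pixel values as `uint8` in the range 0 to 255, whereas `ToTensor()` (assuming the book's transform) returns `float32` values scaled to [0, 1]. Converting with `.float()` or dividing by `1.` only changes the dtype; dividing by `255.` also reproduces the input range the model saw during training. A minimal check, using a random `uint8` stand-in (`raw`, a hypothetical name) for `mnist_test_dataset.data[i]`:

```python
import torch

# Hypothetical stand-in for mnist_test_dataset.data[i]: raw MNIST pixels, uint8 in [0, 255].
raw = torch.randint(0, 256, (28, 28), dtype=torch.uint8)

as_float = raw / 1.   # true division promotes uint8 to float32; values stay in [0, 255]
scaled = raw / 255.   # float32 in [0, 1], the same range ToTensor() produces

print(as_float.dtype, as_float.max().item())
print(scaled.dtype, scaled.max().item())
```

Whether unscaled inputs still yield the right predictions depends on the trained model, so dividing by 255 is the safer match.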
Thanks for the note. I just tried both versions, and both work for me. But yes, you could use .data instead. That is, instead of
img = mnist_test_dataset[i][0][0, :, :]
you could use
img = mnist_test_dataset.data[i].float()
I added this as an alternative code line to Ch 14 in case others have the same issue.
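If only the predicted labels are needed, the `.data` route also makes it easy to classify all 12 images in a single forward pass. The sketch below assumes the Chapter 14 model takes a (batch, 1, 28, 28) float input with values in [0, 1] and returns one score per class; a tiny untrained `nn.Sequential` stands in for it, and `images` is a random stand-in for `mnist_test_dataset.data[:12]`. Both names are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(1)

# Hypothetical stand-in for the trained CNN from Chapter 14:
# maps (batch, 1, 28, 28) inputs to (batch, 10) class scores.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
model.eval()

# Hypothetical stand-in for mnist_test_dataset.data[:12]: uint8 images, shape (12, 28, 28).
images = torch.randint(0, 256, (12, 28, 28), dtype=torch.uint8)

with torch.no_grad():                            # inference only, no gradients needed
    batch = images.unsqueeze(1).float() / 255.   # (12, 1, 28, 28), float32 in [0, 1]
    preds = model(batch).argmax(dim=1)           # one predicted label per image

print(batch.shape)  # torch.Size([12, 1, 28, 28])
print(preds.shape)  # torch.Size([12])
```

In the plotting loop, `preds[i].item()` could then serve as the label text, while `images[i]` is passed to `ax.imshow` as before.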