machine-learning-book

Chapter 14, MNIST test set plot

Status: Open. acmoudleysa opened this issue 2 years ago.

Page 482: The code given does not work as expected. I think it should use `mnist_test_dataset.data`.

```python
fig = plt.figure(figsize=(12, 4))
for i in range(12):
    ax = fig.add_subplot(2, 6, i+1)
    ax.set_xticks([]); ax.set_yticks([])
    img = mnist_test_dataset[i][0][0, :, :]
    pred = model(img.unsqueeze(0).unsqueeze(1))
    y_pred = torch.argmax(pred)
    ax.imshow(img, cmap='gray_r')
    ax.text(0.9, 0.1, y_pred.item(), 
            size=15, color='blue',
            horizontalalignment='center',
            verticalalignment='center', 
            transform=ax.transAxes)

#plt.savefig('figures/14_14.png')
plt.show()
```

I changed a few things, and now it works:

```python
import matplotlib.pyplot as plt
import torch

fig = plt.figure(figsize=(12, 4))
for i in range(12):
    ax = fig.add_subplot(2, 6, i+1)
    ax.set_xticks([])
    ax.set_yticks([])
    img = mnist_test_dataset.data[i, :, :]  # raw 28x28 image
    # Add two leading dimensions, (28, 28) -> (1, 1, 28, 28), with either
    # unsqueeze(0).unsqueeze(0) or unsqueeze(0).unsqueeze(1), and divide
    # by 1. to convert the uint8 pixels to float32.
    pred = model(img.unsqueeze(0).unsqueeze(0)/1.)
    y_pred = torch.argmax(pred)
    ax.imshow(img, cmap='gray_r')
    ax.text(0.9, 0.1, y_pred.item(),
            size=15, color='blue',
            horizontalalignment='center',
            verticalalignment='center',
            transform=ax.transAxes)
plt.show()
```
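The reshaping and dtype change described in the code comment can be checked in isolation. This is a minimal sketch, assuming only that `mnist_test_dataset.data[i]` is a 28x28 `uint8` tensor; a random tensor stands in for it here, so no download or trained model is needed. It shows that both `unsqueeze` orders produce the same `(1, 1, 28, 28)` batch-channel-height-width tensor, and that dividing by `1.` gives the same `float32` result as calling `.float()`.

```python
import torch

# Hypothetical stand-in for mnist_test_dataset.data[i]: a random 28x28 uint8 image.
img = torch.randint(0, 256, (28, 28), dtype=torch.uint8)

# Both orders insert a batch and a channel dimension:
a = img.unsqueeze(0).unsqueeze(0)  # (28, 28) -> (1, 28, 28) -> (1, 1, 28, 28)
b = img.unsqueeze(0).unsqueeze(1)  # (28, 28) -> (1, 28, 28) -> (1, 1, 28, 28)
print(a.shape)             # torch.Size([1, 1, 28, 28])
print(torch.equal(a, b))   # True

# True division by a Python float promotes uint8 to the default float dtype,
# which is what .float() does explicitly.
x1 = a / 1.
x2 = a.float()
print(x1.dtype, x2.dtype)  # torch.float32 torch.float32
```

`.float()` states the intent more directly than `/1.`, which is why it is the form used in the reply below.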

acmoudleysa, Sep 30 '23

Thanks for the note. I just tried it, and both versions work for me. But yes, you could use `.data` instead. That is, instead of

```python
img = mnist_test_dataset[i][0][0, :, :]
```

you could use

```python
img = mnist_test_dataset.data[i].float()
```
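One difference between the two lines is worth keeping in mind. Assuming the dataset was created with `transform=transforms.ToTensor()` (which is what makes `mnist_test_dataset[i][0][0, :, :]` a float tensor in the first place), indexing the dataset returns pixels scaled to [0, 1], while `.data` holds the raw `uint8` pixels, so `.float()` alone yields values in [0, 255]. The sketch below emulates `ToTensor()` with plain `torch` ops on a random stand-in image (so it runs without torchvision or a download) and shows that dividing the `.data` version by 255 recovers the indexed version.

```python
import torch

# Hypothetical stand-in for mnist_test_dataset.data[i]: raw 28x28 uint8 pixels.
raw = torch.randint(0, 256, (28, 28), dtype=torch.uint8)

# The .data[i].float() alternative keeps the original 0-255 range.
img_data = raw.float()

# Emulation (assumption) of mnist_test_dataset[i][0][0, :, :] for a dataset built
# with transforms.ToTensor(): a grayscale image becomes a (1, 28, 28) float tensor
# divided by 255, and [0, :, :] then drops the channel dimension.
img_item = (raw.float() / 255.).unsqueeze(0)[0, :, :]

print(img_data.min().item(), img_data.max().item())  # values within [0, 255]
print(img_item.min().item(), img_item.max().item())  # values within [0, 1]

# Rescaling the .data version by 255 reproduces the ToTensor() version.
print(torch.allclose(img_data / 255., img_item))     # True
```

If the model was trained on `ToTensor()` inputs, `mnist_test_dataset.data[i].float() / 255.` is therefore the closer match to what it saw during training; the unscaled version may still classify many digits correctly, but its inputs are on a different scale.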

rasbt, Dec 28 '23

I added this as an alternative code line to Ch 14 in case others have the same issue.

rasbt, Apr 30 '24