pytorch-deep-learning

Convolutional - Section 7 model issue

Open SquareGraph opened this issue 3 years ago • 3 comments

Hey, something weird is happening. In section 7.2, when we're setting up the CNN and have to figure out what shape the input features should be, I get a shape mismatch. Below is a screenshot from Colab. I first hit this mismatch while writing my own code, but then I tried your code and got the same problem. It looks like the inputs somehow aren't flattened, so instead of 1x490 we have 10x49.

Screenshot 2022-09-2 at 15 23 49

Anyone here with similar problems?

SquareGraph avatar Sep 02 '22 13:09 SquareGraph

It looks like the issue is around the nn.Flatten layer. Passing nn.Flatten(0,2) transforms the shape of x into (490), but that's still not 1x490.
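A minimal sketch of that observation, assuming the tensor reaching the flatten layer has shape (10, 7, 7), i.e. 10 feature maps of 7x7 with no batch dimension (the variable names here are illustrative, not taken from the notebook). Flattening dims 0 through 2 collapses everything into a single 1-D tensor, so the batch dimension has to be added back separately:

```python
import torch
from torch import nn

# Assumed stand-in for the conv block output on an unbatched image:
# 10 feature maps of 7x7 (10 * 7 * 7 = 490 values in total).
x = torch.randn(10, 7, 7)

# Flatten dims 0..2 into one dimension: the result is 1-D.
flat = nn.Flatten(start_dim=0, end_dim=2)(x)
print(flat.shape)  # torch.Size([490])

# Adding a leading dimension gives the 1x490 row that was expected.
row = flat.unsqueeze(0)
print(row.shape)  # torch.Size([1, 490])
```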

SquareGraph avatar Sep 02 '22 13:09 SquareGraph

Hey @SquareGraph,

I can't seem to reproduce this error on my end.

I just ran notebook 03 end-to-end on Google Colab and didn't get any errors.

You could try running the notebook yourself and see what happens. Maybe there's an issue with how the shapes are created in your code?

Did you work out a fix?

For reference my Colab instance is using the following versions:

PyTorch version: 1.12.1+cu113
torchvision version: 0.13.1+cu113

mrdbourke avatar Sep 07 '22 01:09 mrdbourke

Hey,

PyTorch: 1.12.1+cu113 Torchvision: 0.13.1+cu113.

Every time I run my version or your version of the script in Colab with nn.Flatten(), I hit the same runtime error. The fix that finally worked is to pass two arguments to Flatten (nn.Flatten(1,3)), following the documentation: https://pytorch.org/docs/stable/generated/torch.nn.Flatten.html With those arguments we keep the first dimension of the shape (the number of samples) and flatten channels, height and width. That gives a shape of (1,490), which we can easily @ with a shape of (490,10).
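A minimal sketch of the shapes described above, assuming a batched 4-D input of shape (1, 10, 7, 7); the `weights` tensor is a hypothetical stand-in for the classifier's weight matrix. Per the linked documentation, nn.Flatten() defaults to start_dim=1 and end_dim=-1, so on a 4-D input the default and nn.Flatten(1,3) should give the same result:

```python
import torch
from torch import nn

# Assumed batched conv output: (samples, channels, height, width).
batch = torch.randn(1, 10, 7, 7)

# Keep dim 0 (samples) and flatten channels, height and width.
explicit = nn.Flatten(1, 3)(batch)
print(explicit.shape)  # torch.Size([1, 490])

# The default arguments (start_dim=1, end_dim=-1) do the same on 4-D input.
default = nn.Flatten()(batch)
print(torch.equal(explicit, default))  # True

# Hypothetical weight matrix of shape (490, 10) for the matmul.
weights = torch.randn(490, 10)
logits = explicit @ weights
print(logits.shape)  # torch.Size([1, 10])
```

If both variants behave the same on a 4-D input, the original error more likely came from an input that was missing its batch dimension than from the Flatten arguments themselves.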

I think that's a working solution, so this is resolved. But I don't know yet why this happened.

SquareGraph avatar Sep 07 '22 09:09 SquareGraph

I had the same issue when I tried passing an image of shape (1, 28, 28). When I passed an image of shape (1, 1, 28, 28) using unsqueeze, the issue was solved.
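A rough sketch of this, using a small hypothetical CNN (not the notebook's exact model) whose classifier expects 490 input features. With an unbatched (1, 28, 28) image, the conv and pooling layers treat the input as (channels, height, width), so the default nn.Flatten() sees (10, 7, 7) and produces (10, 49), which the linear layer rejects. Adding a batch dimension with unsqueeze(0) avoids that:

```python
import torch
from torch import nn

# Hypothetical model for illustration only: two conv/pool stages reduce
# a 28x28 image to 10 feature maps of 7x7, then a linear classifier.
model = nn.Sequential(
    nn.Conv2d(1, 10, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(10, 10, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),  # defaults: start_dim=1, end_dim=-1
    nn.Linear(10 * 7 * 7, 10),
)

image = torch.randn(1, 28, 28)  # (channels, height, width), no batch dim

# Unbatched input: Flatten yields (10, 49), which does not match Linear(490, 10).
try:
    model(image)
    unbatched_failed = False
except RuntimeError as err:
    unbatched_failed = True
    print(f"Unbatched input failed: {err}")

# Batched input: (1, 1, 28, 28) flattens to (1, 490) as intended.
logits = model(image.unsqueeze(0))
print(logits.shape)  # torch.Size([1, 10])
```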

sankarvinayak avatar Nov 14 '22 14:11 sankarvinayak

@SquareGraph glad to hear you found a fix!

I'm not 100% sure what went wrong either.

Edit: closing this now, feel free to reopen if needed.

mrdbourke avatar Nov 15 '22 06:11 mrdbourke