[Question] Tutorial 5 (JAX): Max-pool branch for Inception
Thank you for your great tutorials!
I have a question about the code in cell 11 of the JAX version of Inception_ResNet_DenseNet.ipynb.
The max-pool branch effectively acts as just another 1x1 convolution branch, because the output of nn.max_pool() is never used:
x_max = nn.max_pool(x, (3, 3), strides=(2, 2))
x_max = nn.Conv(self.c_out["max"], kernel_size=(1, 1), kernel_init=googlenet_kernel_init, use_bias=False)(x)
I think it should be:
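x_max = nn.max_pool(x, (3, 3), strides=(1, 1), padding="SAME")
x_max = nn.Conv(self.c_out["max"], kernel_size=(1, 1), kernel_init=googlenet_kernel_init, use_bias=False)(x_max)
With strides=(2, 2), the feature map is roughly halved in each spatial dimension, so the stride should be (1, 1). padding="SAME" is needed as well: Flax's nn.max_pool defaults to padding="VALID", which would shrink the map by 2 pixels per dimension with a 3x3 window, and the branch could then no longer be concatenated with the other branches.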
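A quick way to check the shape argument is to compute each branch's spatial output size. The sketch below is plain Python and assumes the usual XLA/TensorFlow padding conventions that Flax follows ('VALID' means no padding, 'SAME' gives an output size of ceil(n / stride)). The helper name pool_output_size and the 32x32 input size are illustrative assumptions, not taken from the notebook.

```python
import math


def pool_output_size(n: int, window: int, stride: int, padding: str) -> int:
    """Spatial output size along one axis for a pooling/conv window.

    Assumes XLA/TensorFlow-style conventions:
    'VALID' applies no padding, 'SAME' pads so that out = ceil(n / stride).
    """
    if padding == "VALID":
        return (n - window) // stride + 1
    if padding == "SAME":
        return math.ceil(n / stride)
    raise ValueError(f"unknown padding: {padding!r}")


# Hypothetical example: a 32x32 feature map entering the Inception block.
H = 32
# The conv branches use stride 1 with 'SAME' padding, so they keep the size.
conv_branch = pool_output_size(H, window=3, stride=1, padding="SAME")

# Candidate settings for the 3x3 max-pool branch.
candidates = {
    "strides=(2, 2), VALID": pool_output_size(H, 3, 2, "VALID"),
    "strides=(1, 1), VALID": pool_output_size(H, 3, 1, "VALID"),
    "strides=(1, 1), SAME": pool_output_size(H, 3, 1, "SAME"),
}

for name, size in candidates.items():
    verdict = "can" if size == conv_branch else "cannot"
    print(f"max_pool {name}: {size}x{size} -> {verdict} be concatenated "
          f"with the {conv_branch}x{conv_branch} conv branches")
```

Only the stride-1 'SAME' setting reproduces the 32x32 size of the convolution branches, which is what channel-wise concatenation requires.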
Thank you.
Hi, thanks for pointing that out! This is indeed a typo: the branch should use a stride of 1 and pass x_max as the input to the convolution. I'll leave this issue open, since we'll need to retrain the models to fix this. Thanks again :)
Thank you for your reply!
I also retrained the model and found very little difference between the versions with and without the pooling layer.