Structure difference between PyTorch ResNet and JAX ResNet (at layer 4)
Hello Nicholas, while working with a pretrained ResNet-101 I am comparing output sizes after layer 4 (i.e. taking the output just before the average pooling). In PyTorch, for a single 224x224 RGB input, the output is torch.Size([1, 2048, 28, 28]). However, I get a different result from your JAX/Flax ResNet. To get the output before the average pooling (the equivalent of PyTorch's layer4), I commented out these two lines in the ResNet function:
def ResNet(
    block_cls: ModuleDef,
    *,
    stage_sizes: Sequence[int],
    n_classes: int,
    hidden_sizes: Sequence[int] = (64, 128, 256, 512),
    conv_cls: ModuleDef = nn.Conv,
    norm_cls: Optional[ModuleDef] = partial(nn.BatchNorm, momentum=0.9),
    conv_block_cls: ModuleDef = ConvBlock,
    stem_cls: ModuleDef = ResNetStem,
    pool_fn: Callable = partial(nn.max_pool,
                                window_shape=(3, 3),
                                strides=(2, 2),
                                padding=((1, 1), (1, 1))),
) -> Sequential:
    conv_block_cls = partial(conv_block_cls, conv_cls=conv_cls, norm_cls=norm_cls)
    stem_cls = partial(stem_cls, conv_block_cls=conv_block_cls)
    block_cls = partial(block_cls, conv_block_cls=conv_block_cls)
    layers = [stem_cls(), pool_fn]
    for i, (hsize, n_blocks) in enumerate(zip(hidden_sizes, stage_sizes)):
        for b in range(n_blocks):
            strides = (1, 1) if i == 0 or b != 0 else (2, 2)
            layers.append(block_cls(n_hidden=hsize, strides=strides))
    # ------------------------------------------------------------------
    # layers.append(partial(jnp.mean, axis=(1, 2)))  # global average pool
    # layers.append(nn.Dense(n_classes))
    # ------------------------------------------------------------------
    return Sequential(layers)
It produces a different output shape for the same input size, (1, 224, 224, 3):
RESNET100, variables = pretrained_resnet(101)
RESNET = RESNET100()
model_out = RESNET.apply(variables, jnp.ones((1, 224, 224, 3)), mutable=False)
print("pretrained resnet100 size:", jax.tree_map(lambda x: x.shape, model_out))
pretrained resnet100 size: (1, 7, 7, 2048)
So, what happened at this stage in the ResNet layer structure? Kindly reply if you have any explanation or recommendations.
Hi @sarahelsherif, thanks for raising this issue! Could you also paste in the PyTorch code that gives you torch.Size([1, 2048, 28, 28]) for comparison?
Thank you @n2cholas, here is the PyTorch code:
class RESNET_Layer_4(nn.Module):
    def __init__(self, backbone: nn.Module) -> None:
        super(RESNET_Layer_4, self).__init__()
        self.backbone = backbone

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        input_shape = x.shape[-2:]
        backbone_resnet = self.backbone(x)
        print("backbone resnet output shape", backbone_resnet["out"].shape)
        return backbone_resnet


def resnet4(
    backbone: ResNet,
) -> RESNET_Layer_4:
    return_layers = {"layer4": "out"}
    backbone = create_feature_extractor(backbone, return_layers)
    return RESNET_Layer_4(backbone)
The input_batch shape is torch.Size([1, 3, 224, 224]):
pretrained_resnet = resnet101(pretrained=False, replace_stride_with_dilation=[False, True, True])
r4 = resnet4(pretrained_resnet)
r4 = r4.cuda()
out_resnet4 = r4(inp_batch)
The output is: backbone resnet output shape torch.Size([1, 2048, 28, 28])
Hi @sarahelsherif, I wasn't able to run the code you sent directly since I do not have create_feature_extractor. Instead, see this example of extracting the backbone in both JAX and PyTorch here. As you can see, they have the same output shape.
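In case that link ever goes stale, the JAX side boils down to slicing the Sequential and its variables. Here is a sketch along the lines of the jax_resnet README; the end index that drops the pooling and Dense head is an assumption you should verify against model.layers:

import jax.numpy as jnp
from jax_resnet import pretrained_resnet, slice_variables, Sequential

ResNet101, variables = pretrained_resnet(101)
model = ResNet101()
# Drop the last two entries (global average pool and Dense head); the
# -2 index is an assumption to check against the actual model.layers.
start, end = 0, len(model.layers) - 2
backbone = Sequential(model.layers[start:end])
backbone_variables = slice_variables(variables, start, end)
out = backbone.apply(backbone_variables, jnp.ones((1, 224, 224, 3)), mutable=False)
print(out.shape)  # (1, 7, 7, 2048) for the stock (undilated) ResNet-101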
Does that help?
Hey @n2cholas, first of all, thank you so much for your help.
About create_feature_extractor: it is a utility from TorchVision, which can be imported like this:
from torchvision.models.feature_extraction import create_feature_extractor
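For completeness, here is a minimal usage sketch (the "layer4" -> "out" node mapping matches my code above):

import torch
from torchvision.models import resnet101
from torchvision.models.feature_extraction import create_feature_extractor

model = resnet101(pretrained=False)
# Return the output of layer4 under the key "out".
extractor = create_feature_extractor(model, return_nodes={"layer4": "out"})
features = extractor(torch.ones(1, 3, 224, 224))
print(features["out"].shape)  # torch.Size([1, 2048, 7, 7]) without dilation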
And thank you for your example, it helped. I now see why the output shape is different: the pretrained resnet replaces strides with dilation:
pretrained_resnet = resnet101(pretrained=False, replace_stride_with_dilation=[False, True, True])
So my issue about the different output shapes is solved: with the strides of layer3 and layer4 replaced by dilation, the overall output stride drops from 32 to 8, giving 224/8 = 28 instead of 224/32 = 7. On the other hand, I would be grateful if you could suggest a way to apply replace-stride-with-dilation in JAX.
This can definitely be supported; essentially, we would need to apply the logic from torchvision's _make_layer to the ResNetBottleneckBlock, as sketched below. I won't have the bandwidth to work on this for a few weeks, but I would be happy to review a PR if you decide to implement it.
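Roughly, the stage loop inside ResNet would grow a dilation rate instead of striding when a stage is flagged. This is a hypothetical, untested sketch: replace_stride_with_dilation is a new parameter, and it assumes block_cls is extended to accept a dilation argument, which ResNetBottleneckBlock currently does not:

# Hypothetical sketch, mirroring torchvision's _make_layer: when a stage
# is flagged, fold its stride into a growing dilation rate instead.
# Assumes block_cls accepts a `dilation` kwarg (not currently the case).
dilation = 1
layers = [stem_cls(), pool_fn]
for i, (hsize, n_blocks) in enumerate(zip(hidden_sizes, stage_sizes)):
    for b in range(n_blocks):
        strides = (1, 1) if i == 0 or b != 0 else (2, 2)
        # replace_stride_with_dilation has one flag per strided stage,
        # i.e. index i - 1 for stages 1..3 (as in torchvision).
        if strides == (2, 2) and replace_stride_with_dilation[i - 1]:
            dilation *= 2  # keep spatial resolution, grow receptive field
            strides = (1, 1)
        layers.append(block_cls(n_hidden=hsize, strides=strides,
                                dilation=(dilation, dilation)))

Flax's nn.Conv already accepts a kernel_dilation argument, so most of the remaining work is threading the rate through ConvBlock.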
Yes, sure. Thank you so much for your help! I will update you when I implement it.