
Structure difference between PyTorch ResNet and JAX ResNet (at layer 4)

Open · sarahelsherif opened this issue · 6 comments

Hello Nicholas, while using the pretrained ResNet-101 I am comparing the output sizes of the PyTorch and JAX models after layer 4 (rendering the output just before the average pooling). In PyTorch, running on an input batch of shape [1, 3, 224, 224] gives torch.Size([1, 2048, 28, 28]). However, I tried to render the output at the same point in your JAX/Flax ResNet, after commenting out these two lines in the ResNet function to stop before the average pooling (the equivalent of PyTorch's layer4):

def ResNet(
    block_cls: ModuleDef,
    *,
    stage_sizes: Sequence[int],
    n_classes: int,
    hidden_sizes: Sequence[int] = (64, 128, 256, 512),
    conv_cls: ModuleDef = nn.Conv,
    norm_cls: Optional[ModuleDef] = partial(nn.BatchNorm, momentum=0.9),
    conv_block_cls: ModuleDef = ConvBlock,
    stem_cls: ModuleDef = ResNetStem,
    pool_fn: Callable = partial(nn.max_pool,
                                window_shape=(3, 3),
                                strides=(2, 2),
                                padding=((1, 1), (1, 1))),
) -> Sequential:
    conv_block_cls = partial(conv_block_cls, conv_cls=conv_cls, norm_cls=norm_cls)
    stem_cls = partial(stem_cls, conv_block_cls=conv_block_cls)
    block_cls = partial(block_cls, conv_block_cls=conv_block_cls)
   
    layers = [stem_cls(), pool_fn]

    for i, (hsize, n_blocks) in enumerate(zip(hidden_sizes, stage_sizes)):
        for b in range(n_blocks):
            strides = (1, 1) if i == 0 or b != 0 else (2, 2)
            layers.append(block_cls(n_hidden=hsize, strides=strides))

    # --------------------------------------------------------------------
    # layers.append(partial(jnp.mean, axis=(1, 2)))  # global average pool
    # layers.append(nn.Dense(n_classes))
    # --------------------------------------------------------------------
    return Sequential(layers)

The output has a different shape (for the same input batch shape, (1, 224, 224, 3)):

import jax
import jax.numpy as jnp
from jax_resnet import pretrained_resnet

ResNet101, variables = pretrained_resnet(101)
model = ResNet101()
model_out = model.apply(variables, jnp.ones((1, 224, 224, 3)), mutable=False)
print("pretrained resnet101 output shape:", jax.tree_map(lambda x: x.shape, model_out))

The output is: pretrained resnet101 output shape: (1, 7, 7, 2048). So, what happened at this stage in the ResNet layer structure? I would appreciate any explanation or recommendations.
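As an aside, the same truncation can likely be done without editing the library source by slicing the Sequential's layers. A minimal sketch, assuming jax_resnet exposes the slice_variables helper (with start/end keywords) and that the built model keeps a layers attribute as in the source above:

import jax.numpy as jnp
from jax_resnet import Sequential, pretrained_resnet, slice_variables

ResNet101, variables = pretrained_resnet(101)
model = ResNet101()

# Drop the last two entries (global average pool and the Dense head),
# keeping everything up to and including stage 4.
n_keep = len(model.layers) - 2
backbone = Sequential(model.layers[:n_keep])
backbone_vars = slice_variables(variables, end=n_keep)

out = backbone.apply(backbone_vars, jnp.ones((1, 224, 224, 3)), mutable=False)
print(out.shape)  # expected: (1, 7, 7, 2048) with the standard strides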

sarahelsherif · Nov 25 '21

Hi @sarahelsherif, thanks for raising this issue! Could you also paste in the PyTorch code that gives you torch.Size([1, 2048, 28, 28]) for comparison?

n2cholas · Nov 26 '21

Thank you @n2cholas, here is the PyTorch code:

from typing import Dict

import torch.nn as nn
from torch import Tensor
from torchvision.models import ResNet
from torchvision.models.feature_extraction import create_feature_extractor  # torchvision >= 0.11

class RESNET_Layer_4(nn.Module):

    def __init__(self, backbone: nn.Module) -> None:
        super().__init__()
        self.backbone = backbone

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        backbone_resnet = self.backbone(x)
        print("backbone resnet output shape", backbone_resnet["out"].shape)
        return backbone_resnet

def resnet4(backbone: ResNet) -> RESNET_Layer_4:
    # Extract the activations of "layer4" under the key "out".
    return_layers = {"layer4": "out"}
    backbone = create_feature_extractor(backbone, return_layers)
    return RESNET_Layer_4(backbone)

The input batch shape is torch.Size([1, 3, 224, 224]):

import torch
from torchvision.models import resnet101

pretrained_resnet = resnet101(pretrained=False, replace_stride_with_dilation=[False, True, True])
r4 = resnet4(pretrained_resnet).cuda()
inp_batch = torch.ones(1, 3, 224, 224).cuda()  # dummy batch with the shape noted above
out_resnet4 = r4(inp_batch)

The output is: backbone resnet output shape torch.Size([1, 2048, 28, 28])
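Note for anyone comparing across frameworks: PyTorch reports NCHW while JAX/Flax uses NHWC, so the channel axis sits in a different position even when the spatial resolution matches. A quick illustration with plain NumPy:

import numpy as np

nchw = np.ones((1, 2048, 28, 28))        # PyTorch layout: (N, C, H, W)
nhwc = np.transpose(nchw, (0, 2, 3, 1))  # JAX/Flax layout: (N, H, W, C)
print(nhwc.shape)                        # (1, 28, 28, 2048)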

sarahelsherif · Nov 27 '21

Hi @sarahelsherif, I wasn't able to directly use the code that you sent since I do not have create_feature_extractor. Instead, see this example of extracting the backbone in both JAX and PyTorch here. As you can see, they have the same output shape.

Does that help?

n2cholas · Nov 28 '21

Hey @n2cholas, first of all, thank you so much for the help. create_feature_extractor is a utility from torchvision, which can be imported like this:

from torchvision.models.feature_extraction import create_feature_extractor

And thank you for your example, it helped. I now see why the output shape is different: strides are replaced with dilation in the pretrained ResNet:

pretrained_resnet = resnet101(pretrained=False, replace_stride_with_dilation=[False, True, True])
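To spell out the arithmetic (my summary): a standard ResNet-101 halves the spatial resolution five times (stem conv, max pool, and the first block of stages 2-4), a total stride of 32, while replace_stride_with_dilation=[False, True, True] turns the last two of those strides into dilation:

# Standard network: total stride 32
print(224 // 2**5)  # 7  -> (1, 7, 7, 2048) in the JAX model (NHWC)

# Stages 3 and 4 dilated instead of strided: only 3 halvings remain
print(224 // 2**3)  # 28 -> (1, 2048, 28, 28) in the PyTorch model (NCHW)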

So my issue with the different output shapes is solved now. On the other hand, I would be grateful if you could suggest a way to apply replace-stride-with-dilation in JAX.

sarahelsherif · Nov 29 '21

This can definitely be supported; essentially, we would need to apply the logic in _make_layer to the ResNetBottleneckBlock. I won't have the bandwidth to work on this for a few weeks, but I would be happy to review a PR if you decide to implement this.
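To sketch what I mean (purely illustrative, not the library's API): Flax's nn.Conv already accepts a kernel_dilation argument, so a dilation-aware 3x3 conv could look roughly like this, with torchvision's _make_layer bookkeeping (dilation *= stride; stride = 1 for a dilated stage) carried over into the stage loop:

from typing import Tuple

import flax.linen as nn

class DilatedConv3x3(nn.Module):
    # Hypothetical block: a stage's stride can be traded for dilation.
    n_filters: int
    strides: Tuple[int, int] = (1, 1)
    dilation: int = 1  # accumulated dilation for this stage

    @nn.compact
    def __call__(self, x):
        d = self.dilation
        return nn.Conv(self.n_filters, (3, 3),
                       strides=self.strides,
                       kernel_dilation=(d, d),
                       padding=((d, d), (d, d)))(x)

Each dilated stage would then pass strides=(1, 1) and its accumulated dilation to every block, right where the strides are currently chosen in the ResNet function's stage loop.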

n2cholas · Dec 01 '21

Yes, sure, thank you so much for the help. I will update you when I implement it.

sarahelsherif · Dec 02 '21