ResNetV1.java Network error

Open mymagicpower opened this issue 4 years ago • 6 comments

https://github.com/awslabs/djl/blob/master/model-zoo/src/main/java/ai/djl/basicmodelzoo/cv/classification/ResNetV1.java

   public static Block residualUnit(
            int numFilters,
            final Shape stride,
            final boolean dimMatch,
            boolean bottleneck,
            float batchNormMomentum) {
        SequentialBlock resUnit = new SequentialBlock();
        if (bottleneck) {
            resUnit.add(
                            Conv2d.builder()
                                    .setKernelShape(new Shape(1, 1))
                                    .setFilters(numFilters / 4)
                                    .optStride(stride) // <-- reported issue: should be new Shape(1, 1)
                                    .optPadding(new Shape(0, 0))
                                    .optBias(true)
                                    .build())
                    .add(
                            BatchNorm.builder()
                                    .optEpsilon(1e-5f)
                                    .optMomentum(batchNormMomentum)
                                    .build())
                    .add(Activation::relu)
                    .add(
                            Conv2d.builder()
                                    .setKernelShape(new Shape(3, 3))
                                    .setFilters(numFilters / 4)
                                    .optStride(new Shape(1, 1)) // <-- reported issue: should be stride
                                    .optPadding(new Shape(1, 1))
                                    .optBias(false)
                                    .build())

Reference link: https://github.com/tornadomeet/ResNet/blob/master/symbol_resnet.py
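
To make the two stride placements concrete, here is a minimal sketch in plain Java (no DJL dependency) that applies the standard convolution output-size formula to the first two layers of the bottleneck unit. The class name, method names, and the 56-pixel input with stride 2 are hypothetical values chosen for illustration; they are not taken from ResNetV1.java.

```java
class StridePlacementSketch {
    // Standard convolution output size: floor((in + 2*pad - kernel) / stride) + 1.
    // Java integer division truncates, which equals floor for non-negative values.
    static int convOut(int in, int kernel, int stride, int pad) {
        return (in + 2 * pad - kernel) / stride + 1;
    }

    // Spatial size after the 1x1 conv (padding 0) followed by the 3x3 conv (padding 1),
    // matching the kernel and padding settings in the excerpt above.
    static int afterFirstTwoConvs(int in, int stride1x1, int stride3x3) {
        int afterOneByOne = convOut(in, 1, stride1x1, 0);
        return convOut(afterOneByOne, 3, stride3x3, 1);
    }

    public static void main(String[] args) {
        int in = 56;    // hypothetical feature-map width/height
        int stride = 2; // hypothetical downsampling stride
        System.out.println("stride on 1x1 conv: " + afterFirstTwoConvs(in, stride, 1));
        System.out.println("stride on 3x3 conv: " + afterFirstTwoConvs(in, 1, stride));
    }
}
```

Both placements produce the same spatial size (28 for these values), so the change would not alter tensor shapes. What differs is which layer downsamples: a strided 1x1 convolution reads only every second position in each dimension, while a strided 3x3 convolution with padding 1 still covers every input position.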

mymagicpower avatar Feb 07 '21 04:02 mymagicpower

@mymagicpower Thanks for finding that. Would you mind opening a PR for this?

lanking520 avatar Feb 16 '21 22:02 lanking520

FYI, the Gluon implementation uses stride for the first Conv2D and 1 for the second, which is different from the Symbol implementation:

https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/gluon/model_zoo/vision/resnet.py#L90

roywei avatar Feb 17 '21 23:02 roywei

@mymagicpower Even in the PyTorch implementation, ResNet uses stride for the first Conv2D and 1 for the second. Here is a link for your reference:

https://missinglink.ai/guides/pytorch/pytorch-resnet-building-training-scaling-residual-networks-pytorch/

ghost avatar Feb 18 '21 02:02 ghost

  1. For resnet18 and resnet34: the implementation uses stride for the first Conv2D (kernel size 3) and 1 for the second (kernel size 3).
  2. For resnet50, resnet101, and resnet152: the implementation uses 1 for the first Conv2D (kernel size 1) and stride for the second (kernel size 3).
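
The rule in the two points above can be sketched as a small lookup. This is only an illustration of the rule as stated in this comment; the class and method names are hypothetical, and other implementations linked earlier in the thread place the stride differently.

```java
import java.util.Arrays;

class ResNetStrideRule {
    // Returns {stride of first conv, stride of second conv} for one residual unit,
    // following the rule described in this comment (not read from DJL's code).
    static int[] convStrides(boolean bottleneck, int unitStride) {
        return bottleneck
                ? new int[] {1, unitStride}   // resnet50/101/152: 1x1 conv keeps 1, 3x3 conv takes the stride
                : new int[] {unitStride, 1};  // resnet18/34: first 3x3 conv takes the stride, second keeps 1
    }

    public static void main(String[] args) {
        int unitStride = 2; // hypothetical downsampling stride
        System.out.println("basic unit:      " + Arrays.toString(convStrides(false, unitStride)));
        System.out.println("bottleneck unit: " + Arrays.toString(convStrides(true, unitStride)));
    }
}
```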

mymagicpower avatar Feb 20 '21 03:02 mymagicpower

image

mymagicpower avatar Feb 20 '21 03:02 mymagicpower

@mymagicpower From the looks of it, you are pointing at a problem in the bottleneck section. Is that what you are suggesting?

aksrajvanshi avatar Feb 22 '21 18:02 aksrajvanshi