model-zoo icon indicating copy to clipboard operation
model-zoo copied to clipboard

MethodError using the ResNet transfer learning tutorial

Open rgreilly opened this issue 3 years ago • 5 comments

Julia: 1.8.1 (on Mac M1), Flux: 0.13.4, Metalhead: 0.7.3, Pluto: 0.19.11

I've run into a problem running the transfer learning tutorial in the Model Zoo.

This is a grab from the model definition in Pluto: Screenshot 2022-09-19 at 10 27 00

This is the error I'm getting: Screenshot 2022-09-19 at 10 27 28

I've tried various solutions: clearing the cache, reinstalling Julia, updating Flux, Metalhead, and Pluto, but all to no avail. Any advice would be appreciated.

Ronan

rgreilly avatar Sep 19 '22 10:09 rgreilly

It looks like resnet[1:end-2] is producing an empty Chain in your model. That's why taking the gradient throws an error.

Can you post the cells that generate resnet?

darsnack avatar Sep 20 '22 16:09 darsnack

Hi,

I used the resnet version in Metalhead:

Screenshot 2022-09-22 at 00 21 28

However, the structure of the network is not the flat one assumed by the resnet[1:end-2] expression. As you suspected, this slice gives an empty chain. I'm still trying to work out how to extract all but the last two layers of the network for transfer learning.

Ronan

Chain(
  Chain([
    Conv((7, 7), 3 => 64, pad=3, stride=2, bias=false),  # 9_408 parameters
    BatchNorm(64, relu),                # 128 parameters, plus 128
    MaxPool((3, 3), pad=1, stride=2),
    Parallel(
      Metalhead.addrelu,
      Chain(
        Conv((1, 1), 64 => 64, bias=false),  # 4_096 parameters
        BatchNorm(64, relu),            # 128 parameters, plus 128
        Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
        BatchNorm(64, relu),            # 128 parameters, plus 128
        Conv((1, 1), 64 => 256, bias=false),  # 16_384 parameters
        BatchNorm(256),                 # 512 parameters, plus 512
      ),
      Chain([
        Conv((1, 1), 64 => 256, bias=false),  # 16_384 parameters
        BatchNorm(256),                 # 512 parameters, plus 512
      ]),
    ),
    Parallel(
      Metalhead.addrelu,
      Chain(
        Conv((1, 1), 256 => 64, bias=false),  # 16_384 parameters
        BatchNorm(64, relu),            # 128 parameters, plus 128
        Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
        BatchNorm(64, relu),            # 128 parameters, plus 128
        Conv((1, 1), 64 => 256, bias=false),  # 16_384 parameters
        BatchNorm(256),                 # 512 parameters, plus 512
      ),
      identity,
    ),
    Parallel(
      Metalhead.addrelu,
      Chain(
        Conv((1, 1), 256 => 64, bias=false),  # 16_384 parameters
        BatchNorm(64, relu),            # 128 parameters, plus 128
        Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
        BatchNorm(64, relu),            # 128 parameters, plus 128
        Conv((1, 1), 64 => 256, bias=false),  # 16_384 parameters
        BatchNorm(256),                 # 512 parameters, plus 512
      ),
      identity,
    ),
    Parallel(
      Metalhead.addrelu,
      Chain(
        Conv((1, 1), 256 => 128, bias=false),  # 32_768 parameters
        BatchNorm(128, relu),           # 256 parameters, plus 256
        Conv((3, 3), 128 => 128, pad=1, stride=2, bias=false),  # 147_456 parameters
        BatchNorm(128, relu),           # 256 parameters, plus 256
        Conv((1, 1), 128 => 512, bias=false),  # 65_536 parameters
        BatchNorm(512),                 # 1_024 parameters, plus 1_024
      ),
      Chain([
        Conv((1, 1), 256 => 512, stride=2, bias=false),  # 131_072 parameters
        BatchNorm(512),                 # 1_024 parameters, plus 1_024
      ]),
    ),
    Parallel(
      Metalhead.addrelu,
      Chain(
        Conv((1, 1), 512 => 128, bias=false),  # 65_536 parameters
        BatchNorm(128, relu),           # 256 parameters, plus 256
        Conv((3, 3), 128 => 128, pad=1, bias=false),  # 147_456 parameters
        BatchNorm(128, relu),           # 256 parameters, plus 256
        Conv((1, 1), 128 => 512, bias=false),  # 65_536 parameters
        BatchNorm(512),                 # 1_024 parameters, plus 1_024
      ),
      identity,
    ),
    Parallel(
      Metalhead.addrelu,
      Chain(
        Conv((1, 1), 512 => 128, bias=false),  # 65_536 parameters
        BatchNorm(128, relu),           # 256 parameters, plus 256
        Conv((3, 3), 128 => 128, pad=1, bias=false),  # 147_456 parameters
        BatchNorm(128, relu),           # 256 parameters, plus 256
        Conv((1, 1), 128 => 512, bias=false),  # 65_536 parameters
        BatchNorm(512),                 # 1_024 parameters, plus 1_024
      ),
      identity,
    ),
    Parallel(
      Metalhead.addrelu,
      Chain(
        Conv((1, 1), 512 => 128, bias=false),  # 65_536 parameters
        BatchNorm(128, relu),           # 256 parameters, plus 256
        Conv((3, 3), 128 => 128, pad=1, bias=false),  # 147_456 parameters
        BatchNorm(128, relu),           # 256 parameters, plus 256
        Conv((1, 1), 128 => 512, bias=false),  # 65_536 parameters
        BatchNorm(512),                 # 1_024 parameters, plus 1_024
      ),
      identity,
    ),
    Parallel(
      Metalhead.addrelu,
      Chain(
        Conv((1, 1), 512 => 256, bias=false),  # 131_072 parameters
        BatchNorm(256, relu),           # 512 parameters, plus 512
        Conv((3, 3), 256 => 256, pad=1, stride=2, bias=false),  # 589_824 parameters
        BatchNorm(256, relu),           # 512 parameters, plus 512
        Conv((1, 1), 256 => 1024, bias=false),  # 262_144 parameters
        BatchNorm(1024),                # 2_048 parameters, plus 2_048
      ),
      Chain([
        Conv((1, 1), 512 => 1024, stride=2, bias=false),  # 524_288 parameters
        BatchNorm(1024),                # 2_048 parameters, plus 2_048
      ]),
    ),
    Parallel(
      Metalhead.addrelu,
      Chain(
        Conv((1, 1), 1024 => 256, bias=false),  # 262_144 parameters
        BatchNorm(256, relu),           # 512 parameters, plus 512
        Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
        BatchNorm(256, relu),           # 512 parameters, plus 512
        Conv((1, 1), 256 => 1024, bias=false),  # 262_144 parameters
        BatchNorm(1024),                # 2_048 parameters, plus 2_048
      ),
      identity,
    ),
    Parallel(
      Metalhead.addrelu,
      Chain(
        Conv((1, 1), 1024 => 256, bias=false),  # 262_144 parameters
        BatchNorm(256, relu),           # 512 parameters, plus 512
        Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
        BatchNorm(256, relu),           # 512 parameters, plus 512
        Conv((1, 1), 256 => 1024, bias=false),  # 262_144 parameters
        BatchNorm(1024),                # 2_048 parameters, plus 2_048
      ),
      identity,
    ),
    Parallel(
      Metalhead.addrelu,
      Chain(
        Conv((1, 1), 1024 => 256, bias=false),  # 262_144 parameters
        BatchNorm(256, relu),           # 512 parameters, plus 512
        Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
        BatchNorm(256, relu),           # 512 parameters, plus 512
        Conv((1, 1), 256 => 1024, bias=false),  # 262_144 parameters
        BatchNorm(1024),                # 2_048 parameters, plus 2_048
      ),
      identity,
    ),
    Parallel(
      Metalhead.addrelu,
      Chain(
        Conv((1, 1), 1024 => 256, bias=false),  # 262_144 parameters
        BatchNorm(256, relu),           # 512 parameters, plus 512
        Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
        BatchNorm(256, relu),           # 512 parameters, plus 512
        Conv((1, 1), 256 => 1024, bias=false),  # 262_144 parameters
        BatchNorm(1024),                # 2_048 parameters, plus 2_048
      ),
      identity,
    ),
    Parallel(
      Metalhead.addrelu,
      Chain(
        Conv((1, 1), 1024 => 256, bias=false),  # 262_144 parameters
        BatchNorm(256, relu),           # 512 parameters, plus 512
        Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
        BatchNorm(256, relu),           # 512 parameters, plus 512
        Conv((1, 1), 256 => 1024, bias=false),  # 262_144 parameters
        BatchNorm(1024),                # 2_048 parameters, plus 2_048
      ),
      identity,
    ),
    Parallel(
      Metalhead.addrelu,
      Chain(
        Conv((1, 1), 1024 => 512, bias=false),  # 524_288 parameters
        BatchNorm(512, relu),           # 1_024 parameters, plus 1_024
        Conv((3, 3), 512 => 512, pad=1, stride=2, bias=false),  # 2_359_296 parameters
        BatchNorm(512, relu),           # 1_024 parameters, plus 1_024
        Conv((1, 1), 512 => 2048, bias=false),  # 1_048_576 parameters
        BatchNorm(2048),                # 4_096 parameters, plus 4_096
      ),
      Chain([
        Conv((1, 1), 1024 => 2048, stride=2, bias=false),  # 2_097_152 parameters
        BatchNorm(2048),                # 4_096 parameters, plus 4_096
      ]),
    ),
    Parallel(
      Metalhead.addrelu,
      Chain(
        Conv((1, 1), 2048 => 512, bias=false),  # 1_048_576 parameters
        BatchNorm(512, relu),           # 1_024 parameters, plus 1_024
        Conv((3, 3), 512 => 512, pad=1, bias=false),  # 2_359_296 parameters
        BatchNorm(512, relu),           # 1_024 parameters, plus 1_024
        Conv((1, 1), 512 => 2048, bias=false),  # 1_048_576 parameters
        BatchNorm(2048),                # 4_096 parameters, plus 4_096
      ),
      identity,
    ),
    Parallel(
      Metalhead.addrelu,
      Chain(
        Conv((1, 1), 2048 => 512, bias=false),  # 1_048_576 parameters
        BatchNorm(512, relu),           # 1_024 parameters, plus 1_024
        Conv((3, 3), 512 => 512, pad=1, bias=false),  # 2_359_296 parameters
        BatchNorm(512, relu),           # 1_024 parameters, plus 1_024
        Conv((1, 1), 512 => 2048, bias=false),  # 1_048_576 parameters
        BatchNorm(2048),                # 4_096 parameters, plus 4_096
      ),
      identity,
    ),
  ]),
  Chain(
    AdaptiveMeanPool((1, 1)),
    MLUtils.flatten,
    Dense(2048 => 1000),                # 2_049_000 parameters
  ),
)         # Total: 161 trainable arrays, 25_557_032 parameters,
          # plus 106 non-trainable, 53_120 parameters, summarysize 97.732 MiB.

rgreilly avatar Sep 21 '22 23:09 rgreilly

Ah, it looks like the tutorial is outdated. Metalhead now splits the model into a 2-element Chain of sub-Chains. The first element is the backbone (which is what you want for transfer learning), and the second element is the classifier. The old Metalhead just had one flat chain.

The "official" way to get what you want is Metalhead.backbone(ResNet()).

darsnack avatar Sep 22 '22 01:09 darsnack

Thanks!

Following your suggestion, I've modified the model structure as follows, but am still having trouble adding to the backbone.

Screenshot 2022-09-22 at 17 52 55

This gives the following structure:

Chain(
  Chain([
    Conv((7, 7), 3 => 64, pad=3, stride=2, bias=false),  # 9_408 parameters
    BatchNorm(64, relu),                # 128 parameters, plus 128
    MaxPool((3, 3), pad=1, stride=2),
    Parallel(
      Metalhead.addrelu,
      Chain(
        Conv((1, 1), 64 => 64, bias=false),  # 4_096 parameters
        BatchNorm(64, relu),            # 128 parameters, plus 128
        Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
        BatchNorm(64, relu),            # 128 parameters, plus 128
        Conv((1, 1), 64 => 256, bias=false),  # 16_384 parameters
        BatchNorm(256),                 # 512 parameters, plus 512
      ),
      Chain([
        Conv((1, 1), 64 => 256, bias=false),  # 16_384 parameters
        BatchNorm(256),                 # 512 parameters, plus 512
      ]),
    ),
    Parallel(
      Metalhead.addrelu,
      Chain(
        Conv((1, 1), 256 => 64, bias=false),  # 16_384 parameters
        BatchNorm(64, relu),            # 128 parameters, plus 128
        Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
        BatchNorm(64, relu),            # 128 parameters, plus 128
        Conv((1, 1), 64 => 256, bias=false),  # 16_384 parameters
        BatchNorm(256),                 # 512 parameters, plus 512
      ),
      identity,
    ),
    Parallel(
      Metalhead.addrelu,
      Chain(
        Conv((1, 1), 256 => 64, bias=false),  # 16_384 parameters
        BatchNorm(64, relu),            # 128 parameters, plus 128
        Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
        BatchNorm(64, relu),            # 128 parameters, plus 128
        Conv((1, 1), 64 => 256, bias=false),  # 16_384 parameters
        BatchNorm(256),                 # 512 parameters, plus 512
      ),
      identity,
    ),
    Parallel(
      Metalhead.addrelu,
      Chain(
        Conv((1, 1), 256 => 128, bias=false),  # 32_768 parameters
        BatchNorm(128, relu),           # 256 parameters, plus 256
        Conv((3, 3), 128 => 128, pad=1, stride=2, bias=false),  # 147_456 parameters
        BatchNorm(128, relu),           # 256 parameters, plus 256
        Conv((1, 1), 128 => 512, bias=false),  # 65_536 parameters
        BatchNorm(512),                 # 1_024 parameters, plus 1_024
      ),
      Chain([
        Conv((1, 1), 256 => 512, stride=2, bias=false),  # 131_072 parameters
        BatchNorm(512),                 # 1_024 parameters, plus 1_024
      ]),
    ),
    Parallel(
      Metalhead.addrelu,
      Chain(
        Conv((1, 1), 512 => 128, bias=false),  # 65_536 parameters
        BatchNorm(128, relu),           # 256 parameters, plus 256
        Conv((3, 3), 128 => 128, pad=1, bias=false),  # 147_456 parameters
        BatchNorm(128, relu),           # 256 parameters, plus 256
        Conv((1, 1), 128 => 512, bias=false),  # 65_536 parameters
        BatchNorm(512),                 # 1_024 parameters, plus 1_024
      ),
      identity,
    ),
    Parallel(
      Metalhead.addrelu,
      Chain(
        Conv((1, 1), 512 => 128, bias=false),  # 65_536 parameters
        BatchNorm(128, relu),           # 256 parameters, plus 256
        Conv((3, 3), 128 => 128, pad=1, bias=false),  # 147_456 parameters
        BatchNorm(128, relu),           # 256 parameters, plus 256
        Conv((1, 1), 128 => 512, bias=false),  # 65_536 parameters
        BatchNorm(512),                 # 1_024 parameters, plus 1_024
      ),
      identity,
    ),
    Parallel(
      Metalhead.addrelu,
      Chain(
        Conv((1, 1), 512 => 128, bias=false),  # 65_536 parameters
        BatchNorm(128, relu),           # 256 parameters, plus 256
        Conv((3, 3), 128 => 128, pad=1, bias=false),  # 147_456 parameters
        BatchNorm(128, relu),           # 256 parameters, plus 256
        Conv((1, 1), 128 => 512, bias=false),  # 65_536 parameters
        BatchNorm(512),                 # 1_024 parameters, plus 1_024
      ),
      identity,
    ),
    Parallel(
      Metalhead.addrelu,
      Chain(
        Conv((1, 1), 512 => 256, bias=false),  # 131_072 parameters
        BatchNorm(256, relu),           # 512 parameters, plus 512
        Conv((3, 3), 256 => 256, pad=1, stride=2, bias=false),  # 589_824 parameters
        BatchNorm(256, relu),           # 512 parameters, plus 512
        Conv((1, 1), 256 => 1024, bias=false),  # 262_144 parameters
        BatchNorm(1024),                # 2_048 parameters, plus 2_048
      ),
      Chain([
        Conv((1, 1), 512 => 1024, stride=2, bias=false),  # 524_288 parameters
        BatchNorm(1024),                # 2_048 parameters, plus 2_048
      ]),
    ),
    Parallel(
      Metalhead.addrelu,
      Chain(
        Conv((1, 1), 1024 => 256, bias=false),  # 262_144 parameters
        BatchNorm(256, relu),           # 512 parameters, plus 512
        Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
        BatchNorm(256, relu),           # 512 parameters, plus 512
        Conv((1, 1), 256 => 1024, bias=false),  # 262_144 parameters
        BatchNorm(1024),                # 2_048 parameters, plus 2_048
      ),
      identity,
    ),
    Parallel(
      Metalhead.addrelu,
      Chain(
        Conv((1, 1), 1024 => 256, bias=false),  # 262_144 parameters
        BatchNorm(256, relu),           # 512 parameters, plus 512
        Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
        BatchNorm(256, relu),           # 512 parameters, plus 512
        Conv((1, 1), 256 => 1024, bias=false),  # 262_144 parameters
        BatchNorm(1024),                # 2_048 parameters, plus 2_048
      ),
      identity,
    ),
    Parallel(
      Metalhead.addrelu,
      Chain(
        Conv((1, 1), 1024 => 256, bias=false),  # 262_144 parameters
        BatchNorm(256, relu),           # 512 parameters, plus 512
        Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
        BatchNorm(256, relu),           # 512 parameters, plus 512
        Conv((1, 1), 256 => 1024, bias=false),  # 262_144 parameters
        BatchNorm(1024),                # 2_048 parameters, plus 2_048
      ),
      identity,
    ),
    Parallel(
      Metalhead.addrelu,
      Chain(
        Conv((1, 1), 1024 => 256, bias=false),  # 262_144 parameters
        BatchNorm(256, relu),           # 512 parameters, plus 512
        Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
        BatchNorm(256, relu),           # 512 parameters, plus 512
        Conv((1, 1), 256 => 1024, bias=false),  # 262_144 parameters
        BatchNorm(1024),                # 2_048 parameters, plus 2_048
      ),
      identity,
    ),
    Parallel(
      Metalhead.addrelu,
      Chain(
        Conv((1, 1), 1024 => 256, bias=false),  # 262_144 parameters
        BatchNorm(256, relu),           # 512 parameters, plus 512
        Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
        BatchNorm(256, relu),           # 512 parameters, plus 512
        Conv((1, 1), 256 => 1024, bias=false),  # 262_144 parameters
        BatchNorm(1024),                # 2_048 parameters, plus 2_048
      ),
      identity,
    ),
    Parallel(
      Metalhead.addrelu,
      Chain(
        Conv((1, 1), 1024 => 512, bias=false),  # 524_288 parameters
        BatchNorm(512, relu),           # 1_024 parameters, plus 1_024
        Conv((3, 3), 512 => 512, pad=1, stride=2, bias=false),  # 2_359_296 parameters
        BatchNorm(512, relu),           # 1_024 parameters, plus 1_024
        Conv((1, 1), 512 => 2048, bias=false),  # 1_048_576 parameters
        BatchNorm(2048),                # 4_096 parameters, plus 4_096
      ),
      Chain([
        Conv((1, 1), 1024 => 2048, stride=2, bias=false),  # 2_097_152 parameters
        BatchNorm(2048),                # 4_096 parameters, plus 4_096
      ]),
    ),
    Parallel(
      Metalhead.addrelu,
      Chain(
        Conv((1, 1), 2048 => 512, bias=false),  # 1_048_576 parameters
        BatchNorm(512, relu),           # 1_024 parameters, plus 1_024
        Conv((3, 3), 512 => 512, pad=1, bias=false),  # 2_359_296 parameters
        BatchNorm(512, relu),           # 1_024 parameters, plus 1_024
        Conv((1, 1), 512 => 2048, bias=false),  # 1_048_576 parameters
        BatchNorm(2048),                # 4_096 parameters, plus 4_096
      ),
      identity,
    ),
    Parallel(
      Metalhead.addrelu,
      Chain(
        Conv((1, 1), 2048 => 512, bias=false),  # 1_048_576 parameters
        BatchNorm(512, relu),           # 1_024 parameters, plus 1_024
        Conv((3, 3), 512 => 512, pad=1, bias=false),  # 2_359_296 parameters
        BatchNorm(512, relu),           # 1_024 parameters, plus 1_024
        Conv((1, 1), 512 => 2048, bias=false),  # 1_048_576 parameters
        BatchNorm(2048),                # 4_096 parameters, plus 4_096
      ),
      identity,
    ),
  ]),
  Chain(
    Dense(2048 => 1000),                # 2_049_000 parameters
    Dense(1000 => 256),                 # 256_256 parameters
    Dense(256 => 2),                    # 514 parameters
  ),
) 

However, this throws a dimension mismatch error, and I can't work out where the B dimensions in the error message come from:

Screenshot 2022-09-22 at 17 56 00

Ronan

rgreilly avatar Sep 22 '22 17:09 rgreilly

You need a pooling layer or other reduction, followed by a flatten, like the default classification head uses, in order to eliminate the spatial dimensions from the backbone output. For example, put AdaptiveMeanPool((1, 1)) and MLUtils.flatten before the first Dense layer, as in the original ResNet head. Running the forward pass of the network should give a clearer error, and Flux.outputsize is always useful when debugging model size mismatches.

ToucheSir avatar Sep 22 '22 20:09 ToucheSir

Great! That did the trick.

Many thanks to everybody for all the help on this.

Ronan

rgreilly avatar Sep 23 '22 10:09 rgreilly