MethodError using the ResNet transfer learning tutorial
Julia 1.8.1 on Mac M1, Flux 0.13.4, Metalhead 0.7.3, Pluto 0.19.11
I've run into a problem with the transfer learning tutorial in the Flux Model Zoo.
This is a screen grab of the model definition in Pluto:

This is the error I'm getting:

I've tried various fixes: clearing the package cache, reinstalling Julia, and updating Flux, Metalhead, and Pluto, but all to no avail. Any advice would be appreciated.
Ronan
It looks like resnet[1:end-2] is producing an empty Chain in your model. That's why taking the gradient throws an error.
Can you post the cells that generate resnet?
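For context: with current Metalhead the outer Chain only has two entries, so that slice comes up empty. A quick check, assuming resnet came from ResNet().layers:

using Flux, Metalhead

resnet = ResNet().layers    # with Metalhead 0.7 this is a two-element Chain(backbone, classifier)
length(resnet)              # 2
resnet[1:end-2]             # Chain(), empty, so there is nothing to take a gradient of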
Hi,
I used the ResNet model provided by Metalhead:

However, the structure of the network is not the flat one assumed by the resnet[1:end-2] expression. As you suspected, this slice gives an empty chain. I'm still trying to work out how to extract all but the last two layers of the network for transfer learning.
Ronan
Chain(
Chain([
Conv((7, 7), 3 => 64, pad=3, stride=2, bias=false), # 9_408 parameters
BatchNorm(64, relu), # 128 parameters, plus 128
MaxPool((3, 3), pad=1, stride=2),
Parallel(
Metalhead.addrelu,
Chain(
Conv((1, 1), 64 => 64, bias=false), # 4_096 parameters
BatchNorm(64, relu), # 128 parameters, plus 128
Conv((3, 3), 64 => 64, pad=1, bias=false), # 36_864 parameters
BatchNorm(64, relu), # 128 parameters, plus 128
Conv((1, 1), 64 => 256, bias=false), # 16_384 parameters
BatchNorm(256), # 512 parameters, plus 512
),
Chain([
Conv((1, 1), 64 => 256, bias=false), # 16_384 parameters
BatchNorm(256), # 512 parameters, plus 512
]),
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((1, 1), 256 => 64, bias=false), # 16_384 parameters
BatchNorm(64, relu), # 128 parameters, plus 128
Conv((3, 3), 64 => 64, pad=1, bias=false), # 36_864 parameters
BatchNorm(64, relu), # 128 parameters, plus 128
Conv((1, 1), 64 => 256, bias=false), # 16_384 parameters
BatchNorm(256), # 512 parameters, plus 512
),
identity,
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((1, 1), 256 => 64, bias=false), # 16_384 parameters
BatchNorm(64, relu), # 128 parameters, plus 128
Conv((3, 3), 64 => 64, pad=1, bias=false), # 36_864 parameters
BatchNorm(64, relu), # 128 parameters, plus 128
Conv((1, 1), 64 => 256, bias=false), # 16_384 parameters
BatchNorm(256), # 512 parameters, plus 512
),
identity,
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((1, 1), 256 => 128, bias=false), # 32_768 parameters
BatchNorm(128, relu), # 256 parameters, plus 256
Conv((3, 3), 128 => 128, pad=1, stride=2, bias=false), # 147_456 parameters
BatchNorm(128, relu), # 256 parameters, plus 256
Conv((1, 1), 128 => 512, bias=false), # 65_536 parameters
BatchNorm(512), # 1_024 parameters, plus 1_024
),
Chain([
Conv((1, 1), 256 => 512, stride=2, bias=false), # 131_072 parameters
BatchNorm(512), # 1_024 parameters, plus 1_024
]),
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((1, 1), 512 => 128, bias=false), # 65_536 parameters
BatchNorm(128, relu), # 256 parameters, plus 256
Conv((3, 3), 128 => 128, pad=1, bias=false), # 147_456 parameters
BatchNorm(128, relu), # 256 parameters, plus 256
Conv((1, 1), 128 => 512, bias=false), # 65_536 parameters
BatchNorm(512), # 1_024 parameters, plus 1_024
),
identity,
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((1, 1), 512 => 128, bias=false), # 65_536 parameters
BatchNorm(128, relu), # 256 parameters, plus 256
Conv((3, 3), 128 => 128, pad=1, bias=false), # 147_456 parameters
BatchNorm(128, relu), # 256 parameters, plus 256
Conv((1, 1), 128 => 512, bias=false), # 65_536 parameters
BatchNorm(512), # 1_024 parameters, plus 1_024
),
identity,
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((1, 1), 512 => 128, bias=false), # 65_536 parameters
BatchNorm(128, relu), # 256 parameters, plus 256
Conv((3, 3), 128 => 128, pad=1, bias=false), # 147_456 parameters
BatchNorm(128, relu), # 256 parameters, plus 256
Conv((1, 1), 128 => 512, bias=false), # 65_536 parameters
BatchNorm(512), # 1_024 parameters, plus 1_024
),
identity,
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((1, 1), 512 => 256, bias=false), # 131_072 parameters
BatchNorm(256, relu), # 512 parameters, plus 512
Conv((3, 3), 256 => 256, pad=1, stride=2, bias=false), # 589_824 parameters
BatchNorm(256, relu), # 512 parameters, plus 512
Conv((1, 1), 256 => 1024, bias=false), # 262_144 parameters
BatchNorm(1024), # 2_048 parameters, plus 2_048
),
Chain([
Conv((1, 1), 512 => 1024, stride=2, bias=false), # 524_288 parameters
BatchNorm(1024), # 2_048 parameters, plus 2_048
]),
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((1, 1), 1024 => 256, bias=false), # 262_144 parameters
BatchNorm(256, relu), # 512 parameters, plus 512
Conv((3, 3), 256 => 256, pad=1, bias=false), # 589_824 parameters
BatchNorm(256, relu), # 512 parameters, plus 512
Conv((1, 1), 256 => 1024, bias=false), # 262_144 parameters
BatchNorm(1024), # 2_048 parameters, plus 2_048
),
identity,
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((1, 1), 1024 => 256, bias=false), # 262_144 parameters
BatchNorm(256, relu), # 512 parameters, plus 512
Conv((3, 3), 256 => 256, pad=1, bias=false), # 589_824 parameters
BatchNorm(256, relu), # 512 parameters, plus 512
Conv((1, 1), 256 => 1024, bias=false), # 262_144 parameters
BatchNorm(1024), # 2_048 parameters, plus 2_048
),
identity,
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((1, 1), 1024 => 256, bias=false), # 262_144 parameters
BatchNorm(256, relu), # 512 parameters, plus 512
Conv((3, 3), 256 => 256, pad=1, bias=false), # 589_824 parameters
BatchNorm(256, relu), # 512 parameters, plus 512
Conv((1, 1), 256 => 1024, bias=false), # 262_144 parameters
BatchNorm(1024), # 2_048 parameters, plus 2_048
),
identity,
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((1, 1), 1024 => 256, bias=false), # 262_144 parameters
BatchNorm(256, relu), # 512 parameters, plus 512
Conv((3, 3), 256 => 256, pad=1, bias=false), # 589_824 parameters
BatchNorm(256, relu), # 512 parameters, plus 512
Conv((1, 1), 256 => 1024, bias=false), # 262_144 parameters
BatchNorm(1024), # 2_048 parameters, plus 2_048
),
identity,
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((1, 1), 1024 => 256, bias=false), # 262_144 parameters
BatchNorm(256, relu), # 512 parameters, plus 512
Conv((3, 3), 256 => 256, pad=1, bias=false), # 589_824 parameters
BatchNorm(256, relu), # 512 parameters, plus 512
Conv((1, 1), 256 => 1024, bias=false), # 262_144 parameters
BatchNorm(1024), # 2_048 parameters, plus 2_048
),
identity,
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((1, 1), 1024 => 512, bias=false), # 524_288 parameters
BatchNorm(512, relu), # 1_024 parameters, plus 1_024
Conv((3, 3), 512 => 512, pad=1, stride=2, bias=false), # 2_359_296 parameters
BatchNorm(512, relu), # 1_024 parameters, plus 1_024
Conv((1, 1), 512 => 2048, bias=false), # 1_048_576 parameters
BatchNorm(2048), # 4_096 parameters, plus 4_096
),
Chain([
Conv((1, 1), 1024 => 2048, stride=2, bias=false), # 2_097_152 parameters
BatchNorm(2048), # 4_096 parameters, plus 4_096
]),
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((1, 1), 2048 => 512, bias=false), # 1_048_576 parameters
BatchNorm(512, relu), # 1_024 parameters, plus 1_024
Conv((3, 3), 512 => 512, pad=1, bias=false), # 2_359_296 parameters
BatchNorm(512, relu), # 1_024 parameters, plus 1_024
Conv((1, 1), 512 => 2048, bias=false), # 1_048_576 parameters
BatchNorm(2048), # 4_096 parameters, plus 4_096
),
identity,
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((1, 1), 2048 => 512, bias=false), # 1_048_576 parameters
BatchNorm(512, relu), # 1_024 parameters, plus 1_024
Conv((3, 3), 512 => 512, pad=1, bias=false), # 2_359_296 parameters
BatchNorm(512, relu), # 1_024 parameters, plus 1_024
Conv((1, 1), 512 => 2048, bias=false), # 1_048_576 parameters
BatchNorm(2048), # 4_096 parameters, plus 4_096
),
identity,
),
]),
Chain(
AdaptiveMeanPool((1, 1)),
MLUtils.flatten,
Dense(2048 => 1000), # 2_049_000 parameters
),
) # Total: 161 trainable arrays, 25_557_032 parameters,
# plus 106 non-trainable, 53_120 parameters, summarysize 97.732 MiB.
Ah, it looks like the tutorial is outdated. Current Metalhead (0.7) splits the model into a two-element Chain of sub-Chains: the first element is the backbone (the part you want for transfer learning) and the second is the classifier. The old Metalhead just had one flat Chain.
The "official" way to get what you want is Metalhead.backbone(ResNet()).
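For example (a minimal sketch; check the docstrings for the exact keyword names):

using Flux, Metalhead

model = ResNet()                       # pass pretrain = true if you want the pretrained weights
backbone = Metalhead.backbone(model)   # everything up to, but not including, the classification head
head = Metalhead.classifier(model)     # the default AdaptiveMeanPool + flatten + Dense head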
Thanks!
Following your suggestion, I've modified the model structure as follows, but I'm still having trouble adding new layers on top of the backbone.
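Roughly, the model is now assembled like this (a sketch of the cell shown in the grab):

using Flux, Metalhead

model = Chain(
    Metalhead.backbone(ResNet()),
    Chain(
        Dense(2048 => 1000),
        Dense(1000 => 256),
        Dense(256 => 2),
    ),
)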

This gives the following structure:
Chain(
Chain([
Conv((7, 7), 3 => 64, pad=3, stride=2, bias=false), # 9_408 parameters
BatchNorm(64, relu), # 128 parameters, plus 128
MaxPool((3, 3), pad=1, stride=2),
Parallel(
Metalhead.addrelu,
Chain(
Conv((1, 1), 64 => 64, bias=false), # 4_096 parameters
BatchNorm(64, relu), # 128 parameters, plus 128
Conv((3, 3), 64 => 64, pad=1, bias=false), # 36_864 parameters
BatchNorm(64, relu), # 128 parameters, plus 128
Conv((1, 1), 64 => 256, bias=false), # 16_384 parameters
BatchNorm(256), # 512 parameters, plus 512
),
Chain([
Conv((1, 1), 64 => 256, bias=false), # 16_384 parameters
BatchNorm(256), # 512 parameters, plus 512
]),
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((1, 1), 256 => 64, bias=false), # 16_384 parameters
BatchNorm(64, relu), # 128 parameters, plus 128
Conv((3, 3), 64 => 64, pad=1, bias=false), # 36_864 parameters
BatchNorm(64, relu), # 128 parameters, plus 128
Conv((1, 1), 64 => 256, bias=false), # 16_384 parameters
BatchNorm(256), # 512 parameters, plus 512
),
identity,
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((1, 1), 256 => 64, bias=false), # 16_384 parameters
BatchNorm(64, relu), # 128 parameters, plus 128
Conv((3, 3), 64 => 64, pad=1, bias=false), # 36_864 parameters
BatchNorm(64, relu), # 128 parameters, plus 128
Conv((1, 1), 64 => 256, bias=false), # 16_384 parameters
BatchNorm(256), # 512 parameters, plus 512
),
identity,
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((1, 1), 256 => 128, bias=false), # 32_768 parameters
BatchNorm(128, relu), # 256 parameters, plus 256
Conv((3, 3), 128 => 128, pad=1, stride=2, bias=false), # 147_456 parameters
BatchNorm(128, relu), # 256 parameters, plus 256
Conv((1, 1), 128 => 512, bias=false), # 65_536 parameters
BatchNorm(512), # 1_024 parameters, plus 1_024
),
Chain([
Conv((1, 1), 256 => 512, stride=2, bias=false), # 131_072 parameters
BatchNorm(512), # 1_024 parameters, plus 1_024
]),
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((1, 1), 512 => 128, bias=false), # 65_536 parameters
BatchNorm(128, relu), # 256 parameters, plus 256
Conv((3, 3), 128 => 128, pad=1, bias=false), # 147_456 parameters
BatchNorm(128, relu), # 256 parameters, plus 256
Conv((1, 1), 128 => 512, bias=false), # 65_536 parameters
BatchNorm(512), # 1_024 parameters, plus 1_024
),
identity,
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((1, 1), 512 => 128, bias=false), # 65_536 parameters
BatchNorm(128, relu), # 256 parameters, plus 256
Conv((3, 3), 128 => 128, pad=1, bias=false), # 147_456 parameters
BatchNorm(128, relu), # 256 parameters, plus 256
Conv((1, 1), 128 => 512, bias=false), # 65_536 parameters
BatchNorm(512), # 1_024 parameters, plus 1_024
),
identity,
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((1, 1), 512 => 128, bias=false), # 65_536 parameters
BatchNorm(128, relu), # 256 parameters, plus 256
Conv((3, 3), 128 => 128, pad=1, bias=false), # 147_456 parameters
BatchNorm(128, relu), # 256 parameters, plus 256
Conv((1, 1), 128 => 512, bias=false), # 65_536 parameters
BatchNorm(512), # 1_024 parameters, plus 1_024
),
identity,
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((1, 1), 512 => 256, bias=false), # 131_072 parameters
BatchNorm(256, relu), # 512 parameters, plus 512
Conv((3, 3), 256 => 256, pad=1, stride=2, bias=false), # 589_824 parameters
BatchNorm(256, relu), # 512 parameters, plus 512
Conv((1, 1), 256 => 1024, bias=false), # 262_144 parameters
BatchNorm(1024), # 2_048 parameters, plus 2_048
),
Chain([
Conv((1, 1), 512 => 1024, stride=2, bias=false), # 524_288 parameters
BatchNorm(1024), # 2_048 parameters, plus 2_048
]),
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((1, 1), 1024 => 256, bias=false), # 262_144 parameters
BatchNorm(256, relu), # 512 parameters, plus 512
Conv((3, 3), 256 => 256, pad=1, bias=false), # 589_824 parameters
BatchNorm(256, relu), # 512 parameters, plus 512
Conv((1, 1), 256 => 1024, bias=false), # 262_144 parameters
BatchNorm(1024), # 2_048 parameters, plus 2_048
),
identity,
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((1, 1), 1024 => 256, bias=false), # 262_144 parameters
BatchNorm(256, relu), # 512 parameters, plus 512
Conv((3, 3), 256 => 256, pad=1, bias=false), # 589_824 parameters
BatchNorm(256, relu), # 512 parameters, plus 512
Conv((1, 1), 256 => 1024, bias=false), # 262_144 parameters
BatchNorm(1024), # 2_048 parameters, plus 2_048
),
identity,
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((1, 1), 1024 => 256, bias=false), # 262_144 parameters
BatchNorm(256, relu), # 512 parameters, plus 512
Conv((3, 3), 256 => 256, pad=1, bias=false), # 589_824 parameters
BatchNorm(256, relu), # 512 parameters, plus 512
Conv((1, 1), 256 => 1024, bias=false), # 262_144 parameters
BatchNorm(1024), # 2_048 parameters, plus 2_048
),
identity,
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((1, 1), 1024 => 256, bias=false), # 262_144 parameters
BatchNorm(256, relu), # 512 parameters, plus 512
Conv((3, 3), 256 => 256, pad=1, bias=false), # 589_824 parameters
BatchNorm(256, relu), # 512 parameters, plus 512
Conv((1, 1), 256 => 1024, bias=false), # 262_144 parameters
BatchNorm(1024), # 2_048 parameters, plus 2_048
),
identity,
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((1, 1), 1024 => 256, bias=false), # 262_144 parameters
BatchNorm(256, relu), # 512 parameters, plus 512
Conv((3, 3), 256 => 256, pad=1, bias=false), # 589_824 parameters
BatchNorm(256, relu), # 512 parameters, plus 512
Conv((1, 1), 256 => 1024, bias=false), # 262_144 parameters
BatchNorm(1024), # 2_048 parameters, plus 2_048
),
identity,
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((1, 1), 1024 => 512, bias=false), # 524_288 parameters
BatchNorm(512, relu), # 1_024 parameters, plus 1_024
Conv((3, 3), 512 => 512, pad=1, stride=2, bias=false), # 2_359_296 parameters
BatchNorm(512, relu), # 1_024 parameters, plus 1_024
Conv((1, 1), 512 => 2048, bias=false), # 1_048_576 parameters
BatchNorm(2048), # 4_096 parameters, plus 4_096
),
Chain([
Conv((1, 1), 1024 => 2048, stride=2, bias=false), # 2_097_152 parameters
BatchNorm(2048), # 4_096 parameters, plus 4_096
]),
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((1, 1), 2048 => 512, bias=false), # 1_048_576 parameters
BatchNorm(512, relu), # 1_024 parameters, plus 1_024
Conv((3, 3), 512 => 512, pad=1, bias=false), # 2_359_296 parameters
BatchNorm(512, relu), # 1_024 parameters, plus 1_024
Conv((1, 1), 512 => 2048, bias=false), # 1_048_576 parameters
BatchNorm(2048), # 4_096 parameters, plus 4_096
),
identity,
),
Parallel(
Metalhead.addrelu,
Chain(
Conv((1, 1), 2048 => 512, bias=false), # 1_048_576 parameters
BatchNorm(512, relu), # 1_024 parameters, plus 1_024
Conv((3, 3), 512 => 512, pad=1, bias=false), # 2_359_296 parameters
BatchNorm(512, relu), # 1_024 parameters, plus 1_024
Conv((1, 1), 512 => 2048, bias=false), # 1_048_576 parameters
BatchNorm(2048), # 4_096 parameters, plus 4_096
),
identity,
),
]),
Chain(
Dense(2048 => 1000), # 2_049_000 parameters
Dense(1000 => 256), # 256_256 parameters
Dense(256 => 2), # 514 parameters
),
)
However, this throws a DimensionMismatch error, and I can't work out where the B dimensions in the error message come from:

Ronan
You need a pooling layer or other reduction (plus flatten), like the default classification head uses, to eliminate the spatial dimensions from the backbone output. The B in the error message is the second operand of the matrix multiplication inside the first Dense layer, i.e. your data, which is still a 4-d array (typically 7 × 7 × 2048 × batch for 224 × 224 inputs) rather than the 2048 × batch matrix that Dense expects. Running the forward pass of the network on its own should give a clearer error, and Flux.outputsize is always useful when debugging model size mismatches.
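For example, keeping your Dense head but restoring the pooling and flatten steps from the default classifier (a sketch; the sizes in the comments assume a 224 × 224 input to ResNet-50):

using Flux, Metalhead

model = Chain(
    Metalhead.backbone(ResNet()),
    AdaptiveMeanPool((1, 1)),   # 7 × 7 × 2048 × N -> 1 × 1 × 2048 × N
    Flux.flatten,               # 1 × 1 × 2048 × N -> 2048 × N
    Dense(2048 => 1000),
    Dense(1000 => 256),
    Dense(256 => 2),
)

# Flux.outputsize pushes a dummy size through the model without running real data:
Flux.outputsize(model, (224, 224, 3, 1))   # should come back as (2, 1)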
Great! That did the trick.
Many thanks to everybody for all the help on this.
Ronan