
KeyError: 'roi_fmap.0.weight' in train_models_sgcls.sh

rrryan2016 opened this issue 5 years ago • 2 comments

Hey there,

I came across a problem similar to #41. When I ran `bash train_models_sgcls.sh`, I got the error message listed below.

I renamed the downloaded file vgrel-motifnet-sgcls.tar to vg-24.tar and put it in the correct folder. Its size is 1.7 GB (1,676,245,156 bytes).

Besides, both `pretrain_detector` and `refine_for_detection` run fine.

I guess the main problem lies in `ckpt`, but I have no idea what to do next.

Error message:

(neural_motifs) nopanic@ghostInSh3ll:/media/nopanic/DATA/AnExperiment/NeuralMotifs/neural-motifs-master/scripts$ bash train_models_sgcls.sh
TRAINING THE BASELINE

coco : False
ckpt : checkpoints/vgdet/vg-24.tar
det_ckpt : 
save_dir : checkpoints/baseline2
num_gpus : 1
num_workers : 1
lr : 0.001
batch_size : 6
val_size : 5000
l2 : 0.0001
clip : 5.0
print_interval : 100
mode : sgcls
model : motifnet
old_feats : False
order : confidence
cache : 
gt_box : False
adam : False
test : False
multi_pred : False
num_epochs : 50
use_resnet : False
use_proposals : False
nl_obj : 0
nl_edge : 0
hidden_dim : 256
pooling_dim : 4096
pass_in_obj_feats_to_decoder : False
pass_in_obj_feats_to_edge : False
rec_dropout : 0.1
use_bias : True
use_tanh : False
limit_vision : False
loading word vectors from /media/nopanic/DATA/AnExperiment/NeuralMotifs/neural-motifs-master/data/glove.6B.200d.pt
__background__ -> __background__ 
fail on __background__

 385.5M total parameters 
 ----- 
 
detector.roi_fmap.0.weight                        : [4096,25088]    (102760448) (    )
roi_fmap.1.0.weight                               : [4096,25088]    (102760448) (grad)
roi_fmap_obj.0.weight                             : [4096,25088]    (102760448) (grad)
detector.roi_fmap.3.weight                        : [4096,4096]     (16777216) (    )
roi_fmap.1.3.weight                               : [4096,4096]     (16777216) (grad)
roi_fmap_obj.3.weight                             : [4096,4096]     (16777216) (grad)
detector.bbox_fc.weight                           : [604,4096]      ( 2473984) (    )
detector.features.19.weight                       : [512,512,3,3]   ( 2359296) (    )
detector.features.21.weight                       : [512,512,3,3]   ( 2359296) (    )
detector.features.24.weight                       : [512,512,3,3]   ( 2359296) (    )
detector.features.26.weight                       : [512,512,3,3]   ( 2359296) (    )
detector.features.28.weight                       : [512,512,3,3]   ( 2359296) (    )
detector.rpn_head.conv.0.weight                   : [512,512,3,3]   ( 2359296) (    )
post_lstm.weight                                  : [8192,256]      ( 2097152) (grad)
post_emb.weight                                   : [151,8192]      ( 1236992) (grad)
detector.features.17.weight                       : [512,256,3,3]   ( 1179648) (    )
union_boxes.conv.4.weight                         : [512,256,3,3]   ( 1179648) (grad)
freq_bias.obj_baseline.weight                     : [22801,51]      ( 1162851) (grad)
context.decoder_lin.weight                        : [151,4424]      (  668024) (grad)
detector.score_fc.weight                          : [151,4096]      (  618496) (    )
detector.features.12.weight                       : [256,256,3,3]   (  589824) (    )
detector.features.14.weight                       : [256,256,3,3]   (  589824) (    )
detector.features.10.weight                       : [256,128,3,3]   (  294912) (    )
rel_compress.weight                               : [51,4096]       (  208896) (grad)
detector.features.7.weight                        : [128,128,3,3]   (  147456) (    )
detector.features.5.weight                        : [128,64,3,3]    (   73728) (    )
detector.rpn_head.conv.2.weight                   : [120,512,1,1]   (   61440) (    )
detector.features.2.weight                        : [64,64,3,3]     (   36864) (    )
context.obj_embed.weight                          : [151,200]       (   30200) (grad)
context.obj_embed2.weight                         : [151,200]       (   30200) (grad)
union_boxes.conv.0.weight                         : [256,2,7,7]     (   25088) (grad)
detector.features.0.weight                        : [64,3,3,3]      (    1728) (    )
context.pos_embed.1.weight                        : [128,4]         (     512) (grad)
union_boxes.conv.6.weight                         : [512]           (     512) (grad)
union_boxes.conv.2.weight                         : [256]           (     256) (grad)
context.pos_embed.0.weight                        : [4]             (       4) (grad)
Unexpected key detector.features.0.weight in state_dict with size torch.Size([64, 3, 3, 3])
Unexpected key detector.features.0.bias in state_dict with size torch.Size([64])
Unexpected key detector.features.2.weight in state_dict with size torch.Size([64, 64, 3, 3])
Unexpected key detector.features.2.bias in state_dict with size torch.Size([64])
Unexpected key detector.features.5.weight in state_dict with size torch.Size([128, 64, 3, 3])
Unexpected key detector.features.5.bias in state_dict with size torch.Size([128])
Unexpected key detector.features.7.weight in state_dict with size torch.Size([128, 128, 3, 3])
Unexpected key detector.features.7.bias in state_dict with size torch.Size([128])
Unexpected key detector.features.10.weight in state_dict with size torch.Size([256, 128, 3, 3])
Unexpected key detector.features.10.bias in state_dict with size torch.Size([256])
Unexpected key detector.features.12.weight in state_dict with size torch.Size([256, 256, 3, 3])
Unexpected key detector.features.12.bias in state_dict with size torch.Size([256])
Unexpected key detector.features.14.weight in state_dict with size torch.Size([256, 256, 3, 3])
Unexpected key detector.features.14.bias in state_dict with size torch.Size([256])
Unexpected key detector.features.17.weight in state_dict with size torch.Size([512, 256, 3, 3])
Unexpected key detector.features.17.bias in state_dict with size torch.Size([512])
Unexpected key detector.features.19.weight in state_dict with size torch.Size([512, 512, 3, 3])
Unexpected key detector.features.19.bias in state_dict with size torch.Size([512])
Unexpected key detector.features.21.weight in state_dict with size torch.Size([512, 512, 3, 3])
Unexpected key detector.features.21.bias in state_dict with size torch.Size([512])
Unexpected key detector.features.24.weight in state_dict with size torch.Size([512, 512, 3, 3])
Unexpected key detector.features.24.bias in state_dict with size torch.Size([512])
Unexpected key detector.features.26.weight in state_dict with size torch.Size([512, 512, 3, 3])
Unexpected key detector.features.26.bias in state_dict with size torch.Size([512])
Unexpected key detector.features.28.weight in state_dict with size torch.Size([512, 512, 3, 3])
Unexpected key detector.features.28.bias in state_dict with size torch.Size([512])
Unexpected key detector.roi_fmap.0.weight in state_dict with size torch.Size([4096, 25088])
Unexpected key detector.roi_fmap.0.bias in state_dict with size torch.Size([4096])
Unexpected key detector.roi_fmap.3.weight in state_dict with size torch.Size([4096, 4096])
Unexpected key detector.roi_fmap.3.bias in state_dict with size torch.Size([4096])
Unexpected key detector.score_fc.weight in state_dict with size torch.Size([151, 4096])
Unexpected key detector.score_fc.bias in state_dict with size torch.Size([151])
Unexpected key detector.bbox_fc.weight in state_dict with size torch.Size([604, 4096])
Unexpected key detector.bbox_fc.bias in state_dict with size torch.Size([604])
Unexpected key detector.rpn_head.anchors in state_dict with size torch.Size([37, 37, 20, 4])
Unexpected key detector.rpn_head.conv.0.weight in state_dict with size torch.Size([512, 512, 3, 3])
Unexpected key detector.rpn_head.conv.0.bias in state_dict with size torch.Size([512])
Unexpected key detector.rpn_head.conv.2.weight in state_dict with size torch.Size([120, 512, 1, 1])
Unexpected key detector.rpn_head.conv.2.bias in state_dict with size torch.Size([120])
Unexpected key context.obj_embed.weight in state_dict with size torch.Size([151, 200])
Unexpected key context.obj_embed2.weight in state_dict with size torch.Size([151, 200])
Unexpected key context.pos_embed.0.weight in state_dict with size torch.Size([4])
Unexpected key context.pos_embed.0.bias in state_dict with size torch.Size([4])
Unexpected key context.pos_embed.0.running_mean in state_dict with size torch.Size([4])
Unexpected key context.pos_embed.0.running_var in state_dict with size torch.Size([4])
Unexpected key context.pos_embed.1.weight in state_dict with size torch.Size([128, 4])
Unexpected key context.pos_embed.1.bias in state_dict with size torch.Size([128])
Unexpected key context.obj_ctx_rnn.weight in state_dict with size torch.Size([17784832])
Unexpected key context.obj_ctx_rnn.bias in state_dict with size torch.Size([5120])
Unexpected key context.decoder_rnn.obj_embed.weight in state_dict with size torch.Size([152, 100])
Unexpected key context.decoder_rnn.input_linearity.weight in state_dict with size torch.Size([3072, 612])
Unexpected key context.decoder_rnn.input_linearity.bias in state_dict with size torch.Size([3072])
Unexpected key context.decoder_rnn.state_linearity.weight in state_dict with size torch.Size([2560, 512])
Unexpected key context.decoder_rnn.state_linearity.bias in state_dict with size torch.Size([2560])
Unexpected key context.decoder_rnn.out.weight in state_dict with size torch.Size([151, 512])
Unexpected key context.decoder_rnn.out.bias in state_dict with size torch.Size([151])
Unexpected key context.edge_ctx_rnn.weight in state_dict with size torch.Size([12148736])
Unexpected key context.edge_ctx_rnn.bias in state_dict with size torch.Size([10240])
Unexpected key union_boxes.conv.0.weight in state_dict with size torch.Size([256, 2, 7, 7])
Unexpected key union_boxes.conv.0.bias in state_dict with size torch.Size([256])
Unexpected key union_boxes.conv.2.weight in state_dict with size torch.Size([256])
Unexpected key union_boxes.conv.2.bias in state_dict with size torch.Size([256])
Unexpected key union_boxes.conv.2.running_mean in state_dict with size torch.Size([256])
Unexpected key union_boxes.conv.2.running_var in state_dict with size torch.Size([256])
Unexpected key union_boxes.conv.4.weight in state_dict with size torch.Size([512, 256, 3, 3])
Unexpected key union_boxes.conv.4.bias in state_dict with size torch.Size([512])
Unexpected key union_boxes.conv.6.weight in state_dict with size torch.Size([512])
Unexpected key union_boxes.conv.6.bias in state_dict with size torch.Size([512])
Unexpected key union_boxes.conv.6.running_mean in state_dict with size torch.Size([512])
Unexpected key union_boxes.conv.6.running_var in state_dict with size torch.Size([512])
Unexpected key roi_fmap.1.0.weight in state_dict with size torch.Size([4096, 25088])
Unexpected key roi_fmap.1.0.bias in state_dict with size torch.Size([4096])
Unexpected key roi_fmap.1.3.weight in state_dict with size torch.Size([4096, 4096])
Unexpected key roi_fmap.1.3.bias in state_dict with size torch.Size([4096])
Unexpected key roi_fmap_obj.0.weight in state_dict with size torch.Size([4096, 25088])
Unexpected key roi_fmap_obj.0.bias in state_dict with size torch.Size([4096])
Unexpected key roi_fmap_obj.3.weight in state_dict with size torch.Size([4096, 4096])
Unexpected key roi_fmap_obj.3.bias in state_dict with size torch.Size([4096])
Unexpected key post_lstm.weight in state_dict with size torch.Size([8192, 512])
Unexpected key post_lstm.bias in state_dict with size torch.Size([8192])
Unexpected key rel_compress.weight in state_dict with size torch.Size([51, 4096])
Unexpected key rel_compress.bias in state_dict with size torch.Size([51])
Unexpected key freq_bias.obj_baseline.weight in state_dict with size torch.Size([22801, 51])
We couldn't find features.19.bias,roi_fmap.3.weight,rpn_head.conv.2.weight,features.10.bias,features.17.bias,roi_fmap.3.bias,score_fc.bias,roi_fmap.0.bias,features.21.bias,features.24.weight,features.12.bias,features.26.bias,features.21.weight,features.10.weight,features.28.bias,features.7.bias,rpn_head.conv.2.bias,features.17.weight,features.12.weight,features.5.bias,rpn_head.anchors,rpn_head.conv.0.bias,bbox_fc.bias,features.14.weight,bbox_fc.weight,features.14.bias,features.19.weight,features.0.weight,features.7.weight,features.24.bias,score_fc.weight,features.5.weight,features.26.weight,features.28.weight,features.2.bias,roi_fmap.0.weight,rpn_head.conv.0.weight,features.0.bias,features.2.weight
Traceback (most recent call last):
  File "/media/nopanic/DATA/AnExperiment/NeuralMotifs/neural-motifs-master/models/train_rels.py", line 87, in <module>
    detector.roi_fmap[1][0].weight.data.copy_(ckpt['state_dict']['roi_fmap.0.weight'])
KeyError: 'roi_fmap.0.weight'

rrryan2016 avatar Mar 25 '19 11:03 rrryan2016

I think the problem is that you have a checkpoint for the sgcls model that you're treating like a detector checkpoint. If you download the provided detector checkpoint, that should work!
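One quick way to tell the two checkpoint types apart is to inspect the state_dict keys: the detector checkpoint that train_rels.py expects has top-level keys like roi_fmap.0.weight, while a full sgcls model checkpoint prefixes them with detector. (exactly the "Unexpected key detector.…" lines in the log above). A minimal sketch of that check, using plain dicts to stand in for torch.load("vg-24.tar", map_location="cpu")["state_dict"] so it runs without the actual files:

```python
# Sketch: guess whether a state_dict came from the standalone detector
# (usable as -ckpt for train_rels.py) or from a full sgcls/sgdet model.
# A plain dict stands in for the real torch.load(...)["state_dict"].

def checkpoint_kind(state_dict):
    """Classify a checkpoint by the shape of its keys."""
    if "roi_fmap.0.weight" in state_dict:
        return "detector"        # what train_rels.py line 87 indexes directly
    if any(k.startswith("detector.") for k in state_dict):
        return "full model"      # e.g. vgrel-motifnet-sgcls.tar
    return "unknown"

# Keys as they appear in the error log above:
sgcls_keys = {"detector.roi_fmap.0.weight": None, "rel_compress.weight": None}
detector_keys = {"roi_fmap.0.weight": None, "score_fc.weight": None}

print(checkpoint_kind(sgcls_keys))     # full model
print(checkpoint_kind(detector_keys))  # detector
```

If the check prints "full model" for the file you passed as `-ckpt`, you have the sgcls checkpoint in the detector slot, which reproduces this exact KeyError.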

rowanz avatar Mar 25 '19 17:03 rowanz

@rowanz Hello! I have the same problem. I downloaded the vgrel-motifnet-sgcls.tar file provided by the author on GitHub and stored it under checkpoints/vgdet/.

TP0609 avatar Nov 25 '19 07:11 TP0609