Variations-of-SFANet-for-Crowd-Counting

Could not reproduce the results in the paper

Open knightyxp opened this issue 3 years ago • 8 comments

The published code does not have the CANet branch, which is inconsistent with what the paper reports. In my experiments, the SFANet baseline reaches 59 MAE; however, when I add ASPP (i.e., use the M-SFANet model per the author's code), the MAE on SHA is only 61, which makes no sense. The results in the paper cannot be reproduced (I do not know whether the ICPR reviewers were aware of this).

knightyxp · May 20 '21 10:05

@knightyxp In the ScalePyramidModule class, defined in M-SFANet.py, there is self.can = ContextualModule(512, 512) as the CANet branch. If the CAN module is removed, the reported performance is 62.41 (MAE) on SHA and 7.40 (MAE) on SHB. Please make sure the module is included in your forward pass as well, as in the sketch below.
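For clarity, a minimal sketch of how the CAN branch sits inside the module. The class and attribute names follow M-SFANet.py, but the import path, the ASPP constructor arguments, and the fusion in forward are simplified assumptions, not the repo's exact wiring:

```python
import torch.nn as nn

# ASPP and ContextualModule are the repo's own classes (defined in M-SFANet.py);
# this import path is an assumption.
from models.M_SFANet import ASPP, ContextualModule

class ScalePyramidModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.assp = ASPP(512, 512)             # ASPP branch (constructor args assumed)
        self.can = ContextualModule(512, 512)  # CANet branch, as defined in M-SFANet.py

    def forward(self, x):
        # Both branches must be invoked here; skipping self.can silently
        # turns the model into the "w/o CAN" ablation. Fusion is simplified.
        return self.can(self.assp(x))
```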

Pongpisit-Thanasutives · May 20 '21 13:05

Does this mean that simply adding CAN to SFANet is not suitable?

knightyxp · May 20 '21 16:05

What's more, in your experiment M-SFANet w/o CAN (i.e., just SFANet) gets 62.41 MAE on SHA, yet SFANet on SHA can reach 59 in my experiments (60 as reported), so I do not know how you got this result.

[screenshot: results table, 2021-05-21]

knightyxp · May 20 '21 16:05

@knightyxp I see. M-SFANet w/o CAN means removing self.can from ScalePyramidModule. Could you share your training code? Otherwise it is hard to spot the difference in implementation.

Pongpisit-Thanasutives · May 21 '21 03:05

It is the same as train.py in SFANet.

knightyxp · May 21 '21 08:05

Thank you. I have looked at your code and spotted some inconsistencies in your implementation:

(1) I did not use the SSM loss.
(2) I did not use Adam. In the SHA experiment, I used LookaheadAdam(model.parameters(), lr=5e-4) (see ./models); a sketch follows after this list.
(3) Please train for up to 1000 epochs, not 500, because the M-SFANet model is more complex in terms of the number of parameters.
(4) The reproduced weights of M_SFANet (SHA) with MAE=59.69 and MSE=95.64 are provided via the Google Drive link, so you can check the saved epoch and the optimizer's state_dict.
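A minimal sketch of the optimizer setup in point (2). The LookaheadAdam call is quoted from the comment above; the exact import paths are assumptions, since the repo only says the wrapper lives under ./models:

```python
from models.M_SFANet import Model           # assumed import path
from models.lookahead import LookaheadAdam  # assumed; see ./models in the repo

model = Model()
# SHA setting per the comment above: Lookahead-wrapped Adam, lr = 5e-4
optimizer = LookaheadAdam(model.parameters(), lr=5e-4)
```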

P.S. I cannot check your preprocessing code, which is also important.

Pongpisit-Thanasutives · May 21 '21 09:05

Cannot load your pretrained .pth file; the error is as follows:

```
Traceback (most recent call last):
  File "test.py", line 40, in <module>
    model.load_state_dict(torch.load(model_path), device)
  File "/opt/conda/envs/torch17/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1052, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Model:
    Missing key(s) in state_dict: "vgg.conv1_1.conv.weight", "vgg.conv1_1.conv.bias", "vgg.conv1_1.bn.weight", ...
    [every parameter and BatchNorm buffer of the model is reported missing: all vgg.*, spm.assp.*, spm.can.*, amp.*, dmp.*, conv_att.*, and conv_out.* keys]
    Unexpected key(s) in state_dict: "epoch", "model", "optimizer", "mae", "mse".
```

knightyxp · May 25 '21 10:05

@knightyxp The weights are stored under torch.load(model_path)["model"], like this:

[screenshot: checkpoint contents, 2021-05-25]
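In other words, the saved .pth is a checkpoint dictionary rather than a bare state_dict, which explains the "Unexpected key(s)" in the traceback above. A minimal loading sketch, where the file name and import path are assumptions:

```python
import torch
from models.M_SFANet import Model  # assumed import path

model = Model()
checkpoint = torch.load("M_SFANet_SHA.pth", map_location="cpu")  # file name assumed
model.load_state_dict(checkpoint["model"])  # the weights live under the "model" key
# The remaining keys hold training state and metrics:
print(checkpoint["epoch"], checkpoint["mae"], checkpoint["mse"])
```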

Pongpisit-Thanasutives · May 25 '21 11:05