
How can I predict a single picture with trained models?

wondering516 opened this issue 2 years ago • 10 comments

I want to build an application that can predict a single picture. Where should I modify test.py? Another question: can I use a CPU to run test.py?

wondering516 avatar Aug 10 '22 01:08 wondering516

1. If you only need to run a single picture, the dataloader is not necessary. You can read the image with cv2.imread and then convert it to a tensor of shape [b, c, h, w], with b set to 1. The distributed inference can also be changed to a single CPU or GPU device.
2. The current code base does not support CPU inference. If you want to use the CPU, you should change the CUDA-related settings to their CPU counterparts.
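
Here is a minimal sketch of the single-image preprocessing from point 1, assuming the same ImageNet normalization that the repo's ToTensor transform applies (adjust to whatever your training dataloader actually used); load_single_image is just an illustrative helper name:

import cv2
import torch
from torchvision import transforms

def load_single_image(image_path, device="cpu"):
    """Read one image and convert it to a [1, c, h, w] float tensor."""
    img = cv2.imread(image_path)                     # H x W x C, BGR, uint8
    img = torch.from_numpy(img.transpose(2, 0, 1))   # C x H x W
    img = img.float().div(255)
    # ImageNet mean/std, matching the Normalize used in the training dataloader
    img = transforms.Normalize((.485, .456, .406), (.229, .224, .225))(img)
    return img.unsqueeze(0).to(device)               # b = 1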

zyxu1996 avatar Aug 10 '22 06:08 zyxu1996

## First question

I use

model = swinT(nclass=2, pretrained=False, aux=True, head="mlphead")
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

to load the model, but I get:

    model.load_state_dict(checkpoint)
  File "/home/Zhenxin.Wang/anaconda3/envs/effitrans/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1407, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for MultiEvalModule:
        Missing key(s) in state_dict: "module.backbone.patch_embed.proj.weight", ..........

## Second question
As you suggested, I read a single image without the dataloader and use the following code to reshape it to `[b, c, h, w]`:

images = np.array(cv2.imread(image_path))
images = images.transpose(2, 0, 1)
images = torch.tensor(images)
images = images.view(1, *images.size())

but I get:

Traceback (most recent call last):
  File "predict_single.py", line 422, in <module>
    predict_image(model, weight_dir, dataloader_val_full)
  File "predict_single.py", line 377, in predict_image
    logits = model(images)
  File "/home/Zhenxin.Wang/anaconda3/envs/effitrans/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "predict_single.py", line 193, in forward
    if long_size <= np.max(crop_size):
  File "<__array_function__ internals>", line 6, in amax
  File "/home/Zhenxin.Wang/anaconda3/envs/effitrans/lib/python3.7/site-packages/numpy/core/fromnumeric.py", line 2755, in amax
    keepdims=keepdims, initial=initial, where=where)
  File "/home/Zhenxin.Wang/anaconda3/envs/effitrans/lib/python3.7/site-packages/numpy/core/fromnumeric.py", line 86, in _wrapreduction
    return ufunc.reduce(obj, axis, dtype, out, **passkwargs)
ValueError: zero-size array to reduction operation maximum which has no identity

Please help me; I just want to predict a single image on my own PC without a GPU. Here is my `predict_single.py` file:

[predict_single.txt](https://github.com/zyxu1996/Efficient-Transformer/files/9307699/predict_single.txt)

wondering516 avatar Aug 11 '22 09:08 wondering516

I found the errors and fixed them; now it can predict a single image, but it still needs a GPU. I tried replacing the CUDA-related calls with CPU ones, but the precision dropped a lot. Where should I modify the code so that best_weight.pkl can predict an image on a CPU-only machine? Please help me.

wondering516 avatar Aug 11 '22 14:08 wondering516

If the CPU path is set up correctly, accuracy should not drop. Check whether the input tensor img is missing normalization.

zyxu1996 avatar Aug 11 '22 14:08 zyxu1996

Also, make sure best_weight.pkl is loaded correctly. The weight keys carry prefixes such as model.module; you can remove or add these prefixes so that they stay consistent with the keys the model expects.
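
Here is a minimal sketch of one way to line up those prefixes, assuming the checkpoint is loaded into the bare swinT model rather than into the MultiEvalModule wrapper; strip_wrapper_prefixes is an illustrative helper, and the prefixes in your checkpoint may differ:

import torch
from models.swinT import swin_tiny as swinT

model = swinT(nclass=2, pretrained=False, aux=True, head="mlphead")

checkpoint = torch.load("weights/best_weight.pkl", map_location="cpu")
if "state_dict" in checkpoint:
    checkpoint = checkpoint["state_dict"]

def strip_wrapper_prefixes(state_dict, prefixes=("model.", "module.")):
    """Remove leading wrapper prefixes (DDP, SWA, etc.) so the keys match the bare model."""
    cleaned = {}
    for key, value in state_dict.items():
        for p in prefixes:
            while key.startswith(p):
                key = key[len(p):]
        cleaned[key] = value
    return cleaned

model.load_state_dict(strip_wrapper_prefixes(checkpoint))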

zyxu1996 avatar Aug 11 '22 14:08 zyxu1996

The rest of the code is unchanged; I only commented out these two lines:

model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = nn.parallel.DistributedDataParallel(model, device_ids=[0, ], output_device=0)

The code then fails at

model.load_state_dict(checkpoint)

with:

RuntimeError: Error(s) in loading state_dict for MultiEvalModule:
        Missing key(s) in state_dict: "module.backbone.patch_embed.proj.weight", "module.backbone.patch_embed.proj.bias", "module.backbone.patch_embed.norm.weight", "module.backbone.patch_embed.norm.bias", "module.backbone.layers.0.blocks.0.norm1.weight", "module.backbone.layers.0.blocks.0.norm1.bias", "module.backbone.layers.0.blocks.0.attn.relative_position_bias_table", "module.backbone.layers.0.blocks.0.attn.relative_position_index", "module.backbone.layers.0.blocks.0.attn.qkv.weight", "module.backbone.layers.0.blocks.0.attn.qkv.bias", "module.backbone.layers.0.blocks.0.attn.proj.weight", "module.backbone.layers.0.blocks.0.attn.proj.bias", "module.backbone.layers.0.blocks.0.norm2.weight", "module.backbone.layers.0.blocks.0.norm2.bias", "module.backbone.layers.0.blocks.0.mlp.fc1.weight", "module.backbone.layers.0.blocks.0.mlp.fc1.bias", "module.backbone.layers.0.blocks.0.mlp.fc2.weight", "module.backbone.layers.0.blocks.0.mlp.fc2.bias", "module.backbone.layers.0.blocks.1.norm1.weight", "module.backbone.layers.0.blocks.1.norm1.bias", "module.backbone.layers.0.blocks.1.attn.relative_position_bias_table", "module.backbone.layers.0.blocks.1.attn.relative_position_index", "module.backbone.layers.0.blocks.1.attn.qkv.weight", "module.backbone.layers.0.blocks.1.attn.qkv.bias", "module.backbone.layers.0.blocks.1.attn.proj.weight", "module.backbone.layers.0.blocks.1.attn.proj.bias", "module.backbone.layers.0.blocks.1.norm2.weight", "module.backbone.layers.0.blocks.1.norm2.bias", "module.backbone.layers.0.blocks.1.mlp.fc1.weight", "module.backbone.layers.0.blocks.1.mlp.fc1.bias", "module.backbone.layers.0.blocks.1.mlp.fc2.weight", "module.backbone.layers.0.blocks.1.mlp.fc2.bias", "module.backbone.layers.0.downsample.reduction.weight", "module.backbone.layers.0.downsample.norm.weight", "module.backbone.layers.0.downsample.norm.bias", "module.backbone.layers.1.blocks.0.norm1.weight", "module.backbone.layers.1.blocks.0.norm1.bias", "module.backbone.layers.1.blocks.0.attn.relative_position_bias_table", "module.backbone.layers.1.blocks.0.attn.relative_position_index", "module.backbone.layers.1.blocks.0.attn.qkv.weight", "module.backbone.layers.1.blocks.0.attn.qkv.bias", "module.backbone.layers.1.blocks.0.attn.proj.weight", "module.backbone.layers.1.blocks.0.attn.proj.bias", "module.backbone.layers.1.blocks.0.norm2.weight", "module.backbone.layers.1.blocks.0.norm2.bias", "module.backbone.layers.1.blocks.0.mlp.fc1.weight", "module.backbone.layers.1.blocks.0.mlp.fc1.bias", "module.backbone.layers.1.blocks.0.mlp.fc2.weight", "module.backbone.layers.1.blocks.0.mlp.fc2.bias", "module.backbone.layers.1.blocks.1.norm1.weight", "module.backbone.layers.1.blocks.1.norm1.bias", "module.backbone.layers.1.blocks.1.attn.relative_position_bias_table", "module.backbone.layers.1.blocks.1.attn.relative_position_index", "module.backbone.layers.1.blocks.1.attn.qkv.weight", "module.backbone.layers.1.blocks.1.attn.qkv.bias", "module.backbone.layers.1.blocks.1.attn.proj.weight", "module.backbone.layers.1.blocks.1.attn.proj.bias", "module.backbone.layers.1.blocks.1.norm2.weight", "module.backbone.layers.1.blocks.1.norm2.bias", "module.backbone.layers.1.blocks.1.mlp.fc1.weight", "module.backbone.layers.1.blocks.1.mlp.fc1.bias", "module.backbone.layers.1.blocks.1.mlp.fc2.weight", "module.backbone.layers.1.blocks.1.mlp.fc2.bias", "module.backbone.layers.1.downsample.reduction.weight", "module.backbone.layers.1.downsample.norm.weight", "module.backbone.layers.1.downsample.norm.bias", 
"module.backbone.layers.2.blocks.0.norm1.weight", "module.backbone.layers.2.blocks.0.norm1.bias", "module.backbone.layers.2.blocks.0.attn.relative_position_bias_table", "module.backbone.layers.2.blocks.0.attn.relative_position_index", "module.backbone.layers.2.blocks.0.attn.qkv.weight", "module.backbone.layers.2.blocks.0.attn.qkv.bias", "module.backbone.layers.2.blocks.0.attn.proj.weight", "module.backbone.layers.2.blocks.0.attn.proj.bias", "module.backbone.layers.2.blocks.0.norm2.weight", "module.backbone.layers.2.blocks.0.norm2.bias", "module.backbone.layers.2.blocks.0.mlp.fc1.weight", "module.backbone.layers.2.blocks.0.mlp.fc1.bias", "module.backbone.layers.2.blocks.0.mlp.fc2.weight", "module.backbone.layers.2.blocks.0.mlp.fc2.bias", "module.backbone.layers.2.blocks.1.norm1.weight", "module.backbone.layers.2.blocks.1.norm1.bias", "module.backbone.layers.2.blocks.1.attn.relative_position_bias_table", "module.backbone.layers.2.blocks.1.attn.relative_position_index", "module.backbone.layers.2.blocks.1.attn.qkv.weight", "module.backbone.layers.2.blocks.1.attn.qkv.bias", "module.backbone.layers.2.blocks.1.attn.proj.weight", "module.backbone.layers.2.blocks.1.attn.proj.bias", "module.backbone.layers.2.blocks.1.norm2.weight", "module.backbone.layers.2.blocks.1.norm2.bias", "module.backbone.layers.2.blocks.1.mlp.fc1.weight", "module.backbone.layers.2.blocks.1.mlp.fc1.bias", "module.backbone.layers.2.blocks.1.mlp.fc2.weight", "module.backbone.layers.2.blocks.1.mlp.fc2.bias", "module.backbone.layers.2.blocks.2.norm1.weight", "module.backbone.layers.2.blocks.2.norm1.bias", "module.backbone.layers.2.blocks.2.attn.relative_position_bias_table", "module.backbone.layers.2.blocks.2.attn.relative_position_index", "module.backbone.layers.2.blocks.2.attn.qkv.weight", "module.backbone.layers.2.blocks.2.attn.qkv.bias", "module.backbone.layers.2.blocks.2.attn.proj.weight", "module.backbone.layers.2.blocks.2.attn.proj.bias", "module.backbone.layers.2.blocks.2.norm2.weight", "module.backbone.layers.2.blocks.2.norm2.bias", "module.backbone.layers.2.blocks.2.mlp.fc1.weight", "module.backbone.layers.2.blocks.2.mlp.fc1.bias", "module.backbone.layers.2.blocks.2.mlp.fc2.weight", "module.backbone.layers.2.blocks.2.mlp.fc2.bias", "module.backbone.layers.2.blocks.3.norm1.weight", "module.backbone.layers.2.blocks.3.norm1.bias", "module.backbone.layers.2.blocks.3.attn.relative_position_bias_table", "module.backbone.layers.2.blocks.3.attn.relative_position_index", "module.backbone.layers.2.blocks.3.attn.qkv.weight", "module.backbone.layers.2.blocks.3.attn.qkv.bias", "module.backbone.layers.2.blocks.3.attn.proj.weight", "module.backbone.layers.2.blocks.3.attn.proj.bias", "module.backbone.layers.2.blocks.3.norm2.weight", "module.backbone.layers.2.blocks.3.norm2.bias", "module.backbone.layers.2.blocks.3.mlp.fc1.weight", "module.backbone.layers.2.blocks.3.mlp.fc1.bias", "module.backbone.layers.2.blocks.3.mlp.fc2.weight", "module.backbone.layers.2.blocks.3.mlp.fc2.bias", "module.backbone.layers.2.blocks.4.norm1.weight", "module.backbone.layers.2.blocks.4.norm1.bias", "module.backbone.layers.2.blocks.4.attn.relative_position_bias_table", "module.backbone.layers.2.blocks.4.attn.relative_position_index", "module.backbone.layers.2.blocks.4.attn.qkv.weight", "module.backbone.layers.2.blocks.4.attn.qkv.bias", "module.backbone.layers.2.blocks.4.attn.proj.weight", "module.backbone.layers.2.blocks.4.attn.proj.bias", "module.backbone.layers.2.blocks.4.norm2.weight", "module.backbone.layers.2.blocks.4.norm2.bias", 
"module.backbone.layers.2.blocks.4.mlp.fc1.weight", "module.backbone.layers.2.blocks.4.mlp.fc1.bias", "module.backbone.layers.2.blocks.4.mlp.fc2.weight", "module.backbone.layers.2.blocks.4.mlp.fc2.bias", "module.backbone.layers.2.blocks.5.norm1.weight", "module.backbone.layers.2.blocks.5.norm1.bias", "module.backbone.layers.2.blocks.5.attn.relative_position_bias_table", "module.backbone.layers.2.blocks.5.attn.relative_position_index", "module.backbone.layers.2.blocks.5.attn.qkv.weight", "module.backbone.layers.2.blocks.5.attn.qkv.bias", "module.backbone.layers.2.blocks.5.attn.proj.weight", "module.backbone.layers.2.blocks.5.attn.proj.bias", "module.backbone.layers.2.blocks.5.norm2.weight", "module.backbone.layers.2.blocks.5.norm2.bias", "module.backbone.layers.2.blocks.5.mlp.fc1.weight", "module.backbone.layers.2.blocks.5.mlp.fc1.bias", "module.backbone.layers.2.blocks.5.mlp.fc2.weight", "module.backbone.layers.2.blocks.5.mlp.fc2.bias", "module.backbone.layers.2.downsample.reduction.weight", "module.backbone.layers.2.downsample.norm.weight", "module.backbone.layers.2.downsample.norm.bias", "module.backbone.layers.3.blocks.0.norm1.weight", "module.backbone.layers.3.blocks.0.norm1.bias", "module.backbone.layers.3.blocks.0.attn.relative_position_bias_table", "module.backbone.layers.3.blocks.0.attn.relative_position_index", "module.backbone.layers.3.blocks.0.attn.qkv.weight", "module.backbone.layers.3.blocks.0.attn.qkv.bias", "module.backbone.layers.3.blocks.0.attn.proj.weight", "module.backbone.layers.3.blocks.0.attn.proj.bias", "module.backbone.layers.3.blocks.0.norm2.weight", "module.backbone.layers.3.blocks.0.norm2.bias", "module.backbone.layers.3.blocks.0.mlp.fc1.weight", "module.backbone.layers.3.blocks.0.mlp.fc1.bias", "module.backbone.layers.3.blocks.0.mlp.fc2.weight", "module.backbone.layers.3.blocks.0.mlp.fc2.bias", "module.backbone.layers.3.blocks.1.norm1.weight", "module.backbone.layers.3.blocks.1.norm1.bias", "module.backbone.layers.3.blocks.1.attn.relative_position_bias_table", "module.backbone.layers.3.blocks.1.attn.relative_position_index", "module.backbone.layers.3.blocks.1.attn.qkv.weight", "module.backbone.layers.3.blocks.1.attn.qkv.bias", "module.backbone.layers.3.blocks.1.attn.proj.weight", "module.backbone.layers.3.blocks.1.attn.proj.bias", "module.backbone.layers.3.blocks.1.norm2.weight", "module.backbone.layers.3.blocks.1.norm2.bias", "module.backbone.layers.3.blocks.1.mlp.fc1.weight", "module.backbone.layers.3.blocks.1.mlp.fc1.bias", "module.backbone.layers.3.blocks.1.mlp.fc2.weight", "module.backbone.layers.3.blocks.1.mlp.fc2.bias", "module.backbone.norm0.weight", "module.backbone.norm0.bias", "module.backbone.norm1.weight", "module.backbone.norm1.bias", "module.backbone.norm2.weight", "module.backbone.norm2.bias", "module.backbone.norm3.weight", "module.backbone.norm3.bias", "module.decode_head.conv_seg.weight", "module.decode_head.conv_seg.bias", "module.decode_head.linear_c4.proj.weight", "module.decode_head.linear_c4.proj.bias", "module.decode_head.linear_c4.norm.weight", "module.decode_head.linear_c4.norm.bias", "module.decode_head.linear_c3.proj.weight", "module.decode_head.linear_c3.proj.bias", "module.decode_head.linear_c3.norm.weight", "module.decode_head.linear_c3.norm.bias", "module.decode_head.linear_c2.proj.weight", "module.decode_head.linear_c2.proj.bias", "module.decode_head.linear_c2.norm.weight", "module.decode_head.linear_c2.norm.bias", "module.decode_head.linear_c1.proj.weight", "module.decode_head.linear_c1.proj.bias", 
"module.decode_head.linear_c1.norm.weight", "module.decode_head.linear_c1.norm.bias", "module.decode_head.linear_c3_out.proj.weight", "module.decode_head.linear_c3_out.proj.bias", "module.decode_head.linear_c3_out.norm.weight", "module.decode_head.linear_c3_out.norm.bias", "module.decode_head.linear_c2_out.proj.weight", "module.decode_head.linear_c2_out.proj.bias", "module.decode_head.linear_c2_out.norm.weight", "module.decode_head.linear_c2_out.norm.bias", "module.decode_head.linear_c1_out.proj.weight", "module.decode_head.linear_c1_out.proj.bias", "module.decode_head.linear_c1_out.norm.weight", "module.decode_head.linear_c1_out.norm.bias", "module.decode_head.linear_fuse.proj.weight", "module.decode_head.linear_fuse.proj.bias", "module.decode_head.linear_fuse.norm.weight", "module.decode_head.linear_fuse.norm.bias", "module.decode_head.linear_pred.proj.weight", "module.decode_head.linear_pred.proj.bias", "module.auxiliary_head.conv_seg.weight", "module.auxiliary_head.conv_seg.bias", "module.auxiliary_head.convs.0.conv.weight", "module.auxiliary_head.convs.0.bn.weight", "module.auxiliary_head.convs.0.bn.bias", "module.auxiliary_head.convs.0.bn.running_mean", "module.auxiliary_head.convs.0.bn.running_var". 
        Unexpected key(s) in state_dict: "module.module.backbone.patch_embed.proj.weight", "module.module.backbone.patch_embed.proj.bias", "module.module.backbone.patch_embed.norm.weight", "module.module.backbone.patch_embed.norm.bias", "module.module.backbone.layers.0.blocks.0.norm1.weight", "module.module.backbone.layers.0.blocks.0.norm1.bias", "module.module.backbone.layers.0.blocks.0.attn.relative_position_bias_table", "module.module.backbone.layers.0.blocks.0.attn.relative_position_index", "module.module.backbone.layers.0.blocks.0.attn.qkv.weight", "module.module.backbone.layers.0.blocks.0.attn.qkv.bias", "module.module.backbone.layers.0.blocks.0.attn.proj.weight", "module.module.backbone.layers.0.blocks.0.attn.proj.bias", "module.module.backbone.layers.0.blocks.0.norm2.weight", "module.module.backbone.layers.0.blocks.0.norm2.bias", "module.module.backbone.layers.0.blocks.0.mlp.fc1.weight", "module.module.backbone.layers.0.blocks.0.mlp.fc1.bias", "module.module.backbone.layers.0.blocks.0.mlp.fc2.weight", "module.module.backbone.layers.0.blocks.0.mlp.fc2.bias", "module.module.backbone.layers.0.blocks.1.norm1.weight", "module.module.backbone.layers.0.blocks.1.norm1.bias", "module.module.backbone.layers.0.blocks.1.attn.relative_position_bias_table", "module.module.backbone.layers.0.blocks.1.attn.relative_position_index", "module.module.backbone.layers.0.blocks.1.attn.qkv.weight", "module.module.backbone.layers.0.blocks.1.attn.qkv.bias", "module.module.backbone.layers.0.blocks.1.attn.proj.weight", "module.module.backbone.layers.0.blocks.1.attn.proj.bias", "module.module.backbone.layers.0.blocks.1.norm2.weight", "module.module.backbone.layers.0.blocks.1.norm2.bias", "module.module.backbone.layers.0.blocks.1.mlp.fc1.weight", "module.module.backbone.layers.0.blocks.1.mlp.fc1.bias", "module.module.backbone.layers.0.blocks.1.mlp.fc2.weight", "module.module.backbone.layers.0.blocks.1.mlp.fc2.bias", "module.module.backbone.layers.0.downsample.reduction.weight", "module.module.backbone.layers.0.downsample.norm.weight", "module.module.backbone.layers.0.downsample.norm.bias", "module.module.backbone.layers.1.blocks.0.norm1.weight", "module.module.backbone.layers.1.blocks.0.norm1.bias", "module.module.backbone.layers.1.blocks.0.attn.relative_position_bias_table", "module.module.backbone.layers.1.blocks.0.attn.relative_position_index", "module.module.backbone.layers.1.blocks.0.attn.qkv.weight", "module.module.backbone.layers.1.blocks.0.attn.qkv.bias", "module.module.backbone.layers.1.blocks.0.attn.proj.weight", "module.module.backbone.layers.1.blocks.0.attn.proj.bias", "module.module.backbone.layers.1.blocks.0.norm2.weight", "module.module.backbone.layers.1.blocks.0.norm2.bias", "module.module.backbone.layers.1.blocks.0.mlp.fc1.weight", "module.module.backbone.layers.1.blocks.0.mlp.fc1.bias", "module.module.backbone.layers.1.blocks.0.mlp.fc2.weight", "module.module.backbone.layers.1.blocks.0.mlp.fc2.bias", "module.module.backbone.layers.1.blocks.1.norm1.weight", "module.module.backbone.layers.1.blocks.1.norm1.bias", "module.module.backbone.layers.1.blocks.1.attn.relative_position_bias_table", "module.module.backbone.layers.1.blocks.1.attn.relative_position_index", "module.module.backbone.layers.1.blocks.1.attn.qkv.weight", "module.module.backbone.layers.1.blocks.1.attn.qkv.bias", "module.module.backbone.layers.1.blocks.1.attn.proj.weight", "module.module.backbone.layers.1.blocks.1.attn.proj.bias", "module.module.backbone.layers.1.blocks.1.norm2.weight", 
"module.module.backbone.layers.1.blocks.1.norm2.bias", "module.module.backbone.layers.1.blocks.1.mlp.fc1.weight", "module.module.backbone.layers.1.blocks.1.mlp.fc1.bias", "module.module.backbone.layers.1.blocks.1.mlp.fc2.weight", "module.module.backbone.layers.1.blocks.1.mlp.fc2.bias", "module.module.backbone.layers.1.downsample.reduction.weight", "module.module.backbone.layers.1.downsample.norm.weight", "module.module.backbone.layers.1.downsample.norm.bias", "module.module.backbone.layers.2.blocks.0.norm1.weight", "module.module.backbone.layers.2.blocks.0.norm1.bias", "module.module.backbone.layers.2.blocks.0.attn.relative_position_bias_table", "module.module.backbone.layers.2.blocks.0.attn.relative_position_index", "module.module.backbone.layers.2.blocks.0.attn.qkv.weight", "module.module.backbone.layers.2.blocks.0.attn.qkv.bias", "module.module.backbone.layers.2.blocks.0.attn.proj.weight", "module.module.backbone.layers.2.blocks.0.attn.proj.bias", "module.module.backbone.layers.2.blocks.0.norm2.weight", "module.module.backbone.layers.2.blocks.0.norm2.bias", "module.module.backbone.layers.2.blocks.0.mlp.fc1.weight", "module.module.backbone.layers.2.blocks.0.mlp.fc1.bias", "module.module.backbone.layers.2.blocks.0.mlp.fc2.weight", "module.module.backbone.layers.2.blocks.0.mlp.fc2.bias", "module.module.backbone.layers.2.blocks.1.norm1.weight", "module.module.backbone.layers.2.blocks.1.norm1.bias", "module.module.backbone.layers.2.blocks.1.attn.relative_position_bias_table", "module.module.backbone.layers.2.blocks.1.attn.relative_position_index", "module.module.backbone.layers.2.blocks.1.attn.qkv.weight", "module.module.backbone.layers.2.blocks.1.attn.qkv.bias", "module.module.backbone.layers.2.blocks.1.attn.proj.weight", "module.module.backbone.layers.2.blocks.1.attn.proj.bias", "module.module.backbone.layers.2.blocks.1.norm2.weight", "module.module.backbone.layers.2.blocks.1.norm2.bias", "module.module.backbone.layers.2.blocks.1.mlp.fc1.weight", "module.module.backbone.layers.2.blocks.1.mlp.fc1.bias", "module.module.backbone.layers.2.blocks.1.mlp.fc2.weight", "module.module.backbone.layers.2.blocks.1.mlp.fc2.bias", "module.module.backbone.layers.2.blocks.2.norm1.weight", "module.module.backbone.layers.2.blocks.2.norm1.bias", "module.module.backbone.layers.2.blocks.2.attn.relative_position_bias_table", "module.module.backbone.layers.2.blocks.2.attn.relative_position_index", "module.module.backbone.layers.2.blocks.2.attn.qkv.weight", "module.module.backbone.layers.2.blocks.2.attn.qkv.bias", "module.module.backbone.layers.2.blocks.2.attn.proj.weight", "module.module.backbone.layers.2.blocks.2.attn.proj.bias", "module.module.backbone.layers.2.blocks.2.norm2.weight", "module.module.backbone.layers.2.blocks.2.norm2.bias", "module.module.backbone.layers.2.blocks.2.mlp.fc1.weight", "module.module.backbone.layers.2.blocks.2.mlp.fc1.bias", "module.module.backbone.layers.2.blocks.2.mlp.fc2.weight", "module.module.backbone.layers.2.blocks.2.mlp.fc2.bias", "module.module.backbone.layers.2.blocks.3.norm1.weight", "module.module.backbone.layers.2.blocks.3.norm1.bias", "module.module.backbone.layers.2.blocks.3.attn.relative_position_bias_table", "module.module.backbone.layers.2.blocks.3.attn.relative_position_index", "module.module.backbone.layers.2.blocks.3.attn.qkv.weight", "module.module.backbone.layers.2.blocks.3.attn.qkv.bias", "module.module.backbone.layers.2.blocks.3.attn.proj.weight", "module.module.backbone.layers.2.blocks.3.attn.proj.bias", 
"module.module.backbone.layers.2.blocks.3.norm2.weight", "module.module.backbone.layers.2.blocks.3.norm2.bias", "module.module.backbone.layers.2.blocks.3.mlp.fc1.weight", "module.module.backbone.layers.2.blocks.3.mlp.fc1.bias", "module.module.backbone.layers.2.blocks.3.mlp.fc2.weight", "module.module.backbone.layers.2.blocks.3.mlp.fc2.bias", "module.module.backbone.layers.2.blocks.4.norm1.weight", "module.module.backbone.layers.2.blocks.4.norm1.bias", "module.module.backbone.layers.2.blocks.4.attn.relative_position_bias_table", "module.module.backbone.layers.2.blocks.4.attn.relative_position_index", "module.module.backbone.layers.2.blocks.4.attn.qkv.weight", "module.module.backbone.layers.2.blocks.4.attn.qkv.bias", "module.module.backbone.layers.2.blocks.4.attn.proj.weight", "module.module.backbone.layers.2.blocks.4.attn.proj.bias", "module.module.backbone.layers.2.blocks.4.norm2.weight", "module.module.backbone.layers.2.blocks.4.norm2.bias", "module.module.backbone.layers.2.blocks.4.mlp.fc1.weight", "module.module.backbone.layers.2.blocks.4.mlp.fc1.bias", "module.module.backbone.layers.2.blocks.4.mlp.fc2.weight", "module.module.backbone.layers.2.blocks.4.mlp.fc2.bias", "module.module.backbone.layers.2.blocks.5.norm1.weight", "module.module.backbone.layers.2.blocks.5.norm1.bias", "module.module.backbone.layers.2.blocks.5.attn.relative_position_bias_table", "module.module.backbone.layers.2.blocks.5.attn.relative_position_index", "module.module.backbone.layers.2.blocks.5.attn.qkv.weight", "module.module.backbone.layers.2.blocks.5.attn.qkv.bias", "module.module.backbone.layers.2.blocks.5.attn.proj.weight", "module.module.backbone.layers.2.blocks.5.attn.proj.bias", "module.module.backbone.layers.2.blocks.5.norm2.weight", "module.module.backbone.layers.2.blocks.5.norm2.bias", "module.module.backbone.layers.2.blocks.5.mlp.fc1.weight", "module.module.backbone.layers.2.blocks.5.mlp.fc1.bias", "module.module.backbone.layers.2.blocks.5.mlp.fc2.weight", "module.module.backbone.layers.2.blocks.5.mlp.fc2.bias", "module.module.backbone.layers.2.downsample.reduction.weight", "module.module.backbone.layers.2.downsample.norm.weight", "module.module.backbone.layers.2.downsample.norm.bias", "module.module.backbone.layers.3.blocks.0.norm1.weight", "module.module.backbone.layers.3.blocks.0.norm1.bias", "module.module.backbone.layers.3.blocks.0.attn.relative_position_bias_table", "module.module.backbone.layers.3.blocks.0.attn.relative_position_index", "module.module.backbone.layers.3.blocks.0.attn.qkv.weight", "module.module.backbone.layers.3.blocks.0.attn.qkv.bias", "module.module.backbone.layers.3.blocks.0.attn.proj.weight", "module.module.backbone.layers.3.blocks.0.attn.proj.bias", "module.module.backbone.layers.3.blocks.0.norm2.weight", "module.module.backbone.layers.3.blocks.0.norm2.bias", "module.module.backbone.layers.3.blocks.0.mlp.fc1.weight", "module.module.backbone.layers.3.blocks.0.mlp.fc1.bias", "module.module.backbone.layers.3.blocks.0.mlp.fc2.weight", "module.module.backbone.layers.3.blocks.0.mlp.fc2.bias", "module.module.backbone.layers.3.blocks.1.norm1.weight", "module.module.backbone.layers.3.blocks.1.norm1.bias", "module.module.backbone.layers.3.blocks.1.attn.relative_position_bias_table", "module.module.backbone.layers.3.blocks.1.attn.relative_position_index", "module.module.backbone.layers.3.blocks.1.attn.qkv.weight", "module.module.backbone.layers.3.blocks.1.attn.qkv.bias", "module.module.backbone.layers.3.blocks.1.attn.proj.weight", 
"module.module.backbone.layers.3.blocks.1.attn.proj.bias", "module.module.backbone.layers.3.blocks.1.norm2.weight", "module.module.backbone.layers.3.blocks.1.norm2.bias", "module.module.backbone.layers.3.blocks.1.mlp.fc1.weight", "module.module.backbone.layers.3.blocks.1.mlp.fc1.bias", "module.module.backbone.layers.3.blocks.1.mlp.fc2.weight", "module.module.backbone.layers.3.blocks.1.mlp.fc2.bias", "module.module.backbone.norm0.weight", "module.module.backbone.norm0.bias", "module.module.backbone.norm1.weight", "module.module.backbone.norm1.bias", "module.module.backbone.norm2.weight", "module.module.backbone.norm2.bias", "module.module.backbone.norm3.weight", "module.module.backbone.norm3.bias", "module.module.decode_head.conv_seg.weight", "module.module.decode_head.conv_seg.bias", "module.module.decode_head.linear_c4.proj.weight", "module.module.decode_head.linear_c4.proj.bias", "module.module.decode_head.linear_c4.norm.weight", "module.module.decode_head.linear_c4.norm.bias", "module.module.decode_head.linear_c3.proj.weight", "module.module.decode_head.linear_c3.proj.bias", "module.module.decode_head.linear_c3.norm.weight", "module.module.decode_head.linear_c3.norm.bias", "module.module.decode_head.linear_c2.proj.weight", "module.module.decode_head.linear_c2.proj.bias", "module.module.decode_head.linear_c2.norm.weight", "module.module.decode_head.linear_c2.norm.bias", "module.module.decode_head.linear_c1.proj.weight", "module.module.decode_head.linear_c1.proj.bias", "module.module.decode_head.linear_c1.norm.weight", "module.module.decode_head.linear_c1.norm.bias", "module.module.decode_head.linear_c3_out.proj.weight", "module.module.decode_head.linear_c3_out.proj.bias", "module.module.decode_head.linear_c3_out.norm.weight", "module.module.decode_head.linear_c3_out.norm.bias", "module.module.decode_head.linear_c2_out.proj.weight", "module.module.decode_head.linear_c2_out.proj.bias", "module.module.decode_head.linear_c2_out.norm.weight", "module.module.decode_head.linear_c2_out.norm.bias", "module.module.decode_head.linear_c1_out.proj.weight", "module.module.decode_head.linear_c1_out.proj.bias", "module.module.decode_head.linear_c1_out.norm.weight", "module.module.decode_head.linear_c1_out.norm.bias", "module.module.decode_head.linear_fuse.proj.weight", "module.module.decode_head.linear_fuse.proj.bias", "module.module.decode_head.linear_fuse.norm.weight", "module.module.decode_head.linear_fuse.norm.bias", "module.module.decode_head.linear_pred.proj.weight", "module.module.decode_head.linear_pred.proj.bias", "module.module.auxiliary_head.conv_seg.weight", "module.module.auxiliary_head.conv_seg.bias", "module.module.auxiliary_head.convs.0.conv.weight", "module.module.auxiliary_head.convs.0.bn.weight", "module.module.auxiliary_head.convs.0.bn.bias", "module.module.auxiliary_head.convs.0.bn.running_mean", "module.module.auxiliary_head.convs.0.bn.running_var", "module.module.auxiliary_head.convs.0.bn.num_batches_tracked".

I am still using the dataloader to load the data; I just don't know how to modify the weights.

wondering516 avatar Aug 13 '22 06:08 wondering516

I kept only best_weight.pkl, trained on my own dataset, in the weights folder. Testing it on the GPU gives very good metrics and results. Now I want to predict a single image on a PC without a GPU, so I wrote predict_single.py to run the prediction on the CPU. The training images are 384x384 (H and W), and I hard-coded the other parameters into MultiEvalModule, so nothing needs to be read from the command line. Please take a look and tell me where my code goes wrong.

# time: 2022/8/11 14:55
# author: CoffeeChicken
# description: predict a single image

import os, warnings
import torch
import cv2
import torch.nn.functional as F
import numpy as np
import torch.nn as nn
import math
import logging
from torchvision import transforms
import torch.utils.data as data
from models.swinT import swin_tiny as swinT
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
import os

# os.environ['MASTER_ADDR'] = 'localhost'
# os.environ['MASTER_PORT'] = '39999'

# keyword arguments for bilinear resizing (used by resize_image)
up_kwargs = {'mode': 'bilinear', 'align_corners': False}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __init__(self, add_edge=True):
        """imagenet normalize"""
        self.normalize = transforms.Normalize((.485, .456, .406), (.229, .224, .225))
        self.add_edge = add_edge

    def get_edge(self, img, edge_width=3):
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        gray = cv2.GaussianBlur(gray, (11, 11), 0)
        edge = cv2.Canny(gray, 50, 150)
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (edge_width, edge_width))
        edge = cv2.dilate(edge, kernel)
        edge = edge / 255
        edge = torch.from_numpy(edge).unsqueeze(dim=0).float()

        return edge

    def __call__(self, sample):
        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W

        print("sample", sample)
        img = sample['image']
        print(img.shape)

        # todo: fix the parentheses
        img = np.array(img).astype(np.float32).transpose((2, 0, 1))

        img = torch.from_numpy(img).float().div(255)
        img = self.normalize(img)

        if self.add_edge:
            edge = self.get_edge(sample['image'])
            img = img + edge
        print("to tensor 后的图片shape", img.shape)

        return {'image': img}


def transform(sample):
    """
    :param sample:
    :return: the data in tensor format
    """
    composed_transforms = transforms.Compose([ToTensor(add_edge=False), ])
    return composed_transforms(sample)


def module_inference(module, image, flip=True):
    if flip:
        h_img = h_flip_image(image)
        v_img = v_flip_image(image)
        img = torch.cat([image, h_img, v_img], dim=0)
        cat_output = module(img)
        if isinstance(cat_output, (list, tuple)):
            cat_output = cat_output[0]
        output, h_output, v_output = cat_output.chunk(3, dim=0)
        output = output + h_flip_image(h_output) + v_flip_image(v_output)
    else:
        output = module(image)
        if isinstance(output, (list, tuple)):
            output = output[0]

    return output


def resize_image(img, h, w, **up_kwargs):
    return F.upsample(img, (h, w), **up_kwargs)


def crop_image(img, h0, h1, w0, w1):
    return img[:, :, h0:h1, w0:w1]


def h_flip_image(img):
    assert (img.dim() == 4)
    with torch.cuda.device_of(img):
        idx = torch.arange(img.size(3) - 1, -1, -1).type_as(img).long()
    return img.index_select(3, idx)


def v_flip_image(img):
    assert (img.dim() == 4)
    with torch.cuda.device_of(img):
        idx = torch.arange(img.size(3) - 1, -1, -1).type_as(img).long()
    return img.index_select(2, idx)


def hv_flip_image(img):
    assert (img.dim() == 4)
    with torch.cuda.device_of(img):
        idx = torch.arange(img.size(3) - 1, -1, -1).type_as(img).long()
    img = img.index_select(3, idx)
    return img.index_select(2, idx)


def pad_image(img, crop_size):
    """crop_size could be list:[h, w] or int"""
    b, c, h, w = img.size()
    # assert(c==3)
    if len(crop_size) > 1:
        padh = crop_size[0] - h if h < crop_size[0] else 0
        padw = crop_size[1] - w if w < crop_size[1] else 0
    else:
        padh = crop_size - h if h < crop_size else 0
        padw = crop_size - w if w < crop_size else 0
    # pad_values = -np.array(mean) / np.array(std)
    img_pad = img.new().resize_(b, c, h + padh, w + padw)
    # for i in range(c):
    # note that pytorch pad params is in reversed orders
    min_padh = min(padh, h)
    min_padw = min(padw, w)
    if padw < w and padh < h:
        img_pad[:, :, :, :] = F.pad(img[:, :, :, :], (0, padw, 0, padh), mode='reflect')
    else:
        img_pad[:, :, 0:h + min_padh - 1, 0:w + min_padw - 1] = \
            F.pad(img[:, :, :, :], (0, min_padw - 1, 0, min_padh - 1), mode='reflect')

        img_pad[:, :, :, :] = F.pad(img_pad[:, :, 0:h + min_padh - 1, 0:w + min_padw - 1],
                                    (0, padw - min_padw + 1, 0, padh - min_padh + 1), mode='constant', value=0)
    if len(crop_size) > 1:
        assert (img_pad.size(2) >= crop_size[0] and img_pad.size(3) >= crop_size[1])
    else:
        assert (img_pad.size(2) >= crop_size and img_pad.size(3) >= crop_size)
    return img_pad


class MultiEvalModule(nn.Module):
    """Multi-size Segmentation Evaluator"""

    def __init__(self, module, nclass, device_ids=None, flip=True, save_gpu_memory=False,
                 scales=[1.0], get_batch=1, crop_size=[512, 512], stride_rate=1 / 2):

        super(MultiEvalModule, self).__init__()
        self.module = module
        self.devices_ids = device_ids
        self.nclass = nclass
        self.crop_size = np.array(crop_size)
        self.scales = scales
        self.flip = flip
        self.get_batch = get_batch
        self.stride_rate = stride_rate
        self.save_gpu_memory = save_gpu_memory

    def forward(self, image):
        """Mult-size Evaluation"""
        # only single image is supported for evaluation
        batch, _, h, w = image.size()
        # assert(batch == 1)
        stride_rate = self.stride_rate
        with torch.cuda.device_of(image):
            if self.save_gpu_memory:
                scores = image.new().resize_(batch, self.nclass, h, w).zero_().cpu()
            else:
                scores = image.new().resize_(batch, self.nclass, h, w).zero_().cuda()

        for scale in self.scales:
            crop_size = self.crop_size
            stride = (crop_size * stride_rate).astype(np.int)

            if h > w:
                long_size = int(math.ceil(h * scale))
                height = long_size
                width = int(1.0 * w * long_size / h + 0.5)
                short_size = width
            else:
                long_size = int(math.ceil(w * scale))
                width = long_size
                height = int(1.0 * h * long_size / w + 0.5)
                short_size = height

            # resize image to current size
            cur_img = resize_image(image, height, width, **up_kwargs)
            if long_size <= np.max(crop_size):
                pad_img = pad_image(cur_img, crop_size)
                outputs = module_inference(self.module, pad_img, self.flip)
                outputs = crop_image(outputs, 0, height, 0, width)

            else:
                if short_size < np.min(crop_size):
                    # pad if needed
                    pad_img = pad_image(cur_img, crop_size)
                else:
                    pad_img = cur_img
                _, _, ph, pw = pad_img.size()
                # assert(ph >= height and pw >= width)
                # grid forward and normalize
                h_grids = int(math.ceil(1.0 * (ph - crop_size[0]) / stride[0])) + 1
                w_grids = int(math.ceil(1.0 * (pw - crop_size[1]) / stride[1])) + 1
                with torch.cuda.device_of(image):
                    if self.save_gpu_memory:
                        outputs = image.new().resize_(batch, self.nclass, ph, pw).zero_().cpu()
                        count_norm = image.new().resize_(batch, 1, ph, pw).zero_().cpu()
                    else:
                        outputs = image.new().resize_(batch, self.nclass, ph, pw).zero_().cuda()
                        count_norm = image.new().resize_(batch, 1, ph, pw).zero_().cuda()
                # grid evaluation
                location = []
                batch_size = []
                pad_img = pad_image(pad_img, [ph + crop_size[0], pw + crop_size[1]])  # expand pad_image

                for idh in range(h_grids):
                    for idw in range(w_grids):
                        h0 = idh * stride[0]
                        w0 = idw * stride[1]
                        h1 = min(h0 + crop_size[0], ph)
                        w1 = min(w0 + crop_size[1], pw)

                        crop_img = crop_image(pad_img, h0, h0 + crop_size[0], w0, w0 + crop_size[1])
                        # pad if needed
                        pad_crop_img = pad_image(crop_img, crop_size)
                        size_h, size_w = pad_crop_img.shape[-2:]
                        pad_crop_img = resize_image(pad_crop_img, crop_size[0], crop_size[1], **up_kwargs)
                        if self.get_batch > 1:
                            location.append([h0, w0, h1, w1])
                            batch_size.append(pad_crop_img)
                            if len(location) == self.get_batch or (idh + idw + 2) == (h_grids + w_grids):
                                batch_size = torch.cat(batch_size, dim=0).cuda()
                                location = np.array(location)
                                output = module_inference(self.module, batch_size, self.flip)
                                output = output.detach()
                                output = resize_image(output, size_h, size_w, **up_kwargs)
                                if self.save_gpu_memory:
                                    output = output.detach().cpu()  # to save gpu memory
                                else:
                                    output = output.detach()
                                for i in range(batch_size.shape[0]):
                                    outputs[:, :, location[i][0]:location[i][2], location[i][1]:location[i][3]] += \
                                        crop_image(output[i, ...].unsqueeze(dim=0), 0, location[i][2] - location[i][0],
                                                   0, location[i][3] - location[i][1])
                                    count_norm[:, :, location[i][0]:location[i][2], location[i][1]:location[i][3]] += 1
                                location = []
                                batch_size = []
                        else:
                            output = module_inference(self.module, pad_crop_img, self.flip)
                            if self.save_gpu_memory:
                                output = output.detach().cpu()  # to save gpu memory
                            else:
                                output = output.detach()
                            output = resize_image(output, size_h, size_w, **up_kwargs)
                            outputs[:, :, h0:h1, w0:w1] += crop_image(output,
                                                                      0, h1 - h0, 0, w1 - w0)
                            count_norm[:, :, h0:h1, w0:w1] += 1
                assert ((count_norm == 0).sum() == 0)
                outputs = outputs / count_norm
                outputs = outputs[:, :, :height, :width]
            score = resize_image(outputs, h, w, **up_kwargs)
            scores += score
        return scores


def label_to_RGB(image):
    """
    :param image: label image to be converted to RGB form
    :return:
    """
    RGB = np.zeros(shape=[image.shape[0], image.shape[1], 3], dtype=np.uint8)
    index = image == 0
    RGB[index] = np.array([255, 255, 255])
    index = image == 1
    RGB[index] = np.array([0, 0, 255])
    index = image == 2
    RGB[index] = np.array([0, 255, 255])
    index = image == 3
    RGB[index] = np.array([0, 255, 0])
    index = image == 4
    RGB[index] = np.array([255, 255, 0])
    index = image == 5
    RGB[index] = np.array([255, 0, 0])
    return RGB


class load_image(data.Dataset):
    """
        Only reads the image to be predicted and returns the transformed data.
    """

    def __init__(self, images_dir=""):
        super(load_image, self).__init__()
        self.images = []
        self.names = []
        self.image_dir = images_dir

        if images_dir.endswith('tif') or images_dir.endswith('png'):
            img = cv2.imread(images_dir)
            self.images.append(img)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        """
        :param index: index into the dataset
        :return: a dict; for a training set it would contain the augmented image and label
        """
        sample = {'image': self.images[index]}
        sample = transform(sample)
        return sample

    def __str__(self):
        return 'dataset names: {}'.format(self.names)


def predict_image(model, weight_path, image_loader):
    """
    Run prediction and save the result.
    :param image_loader: a dataloader containing only one image
    :param model: the model
    :param weight_path: path to the trained weights
    :return:
    """
    # number of classes
    nclasses = 2

    model = MultiEvalModule(model, nclass=nclasses, flip=True, scales=[0.5, 0.75, 1.0, 1.25, 1.5],
                            save_gpu_memory=False, crop_size=[384, 384], stride_rate=1 / 2, get_batch=1)
    model.eval()
    with torch.no_grad():
        model_state_file = weight_path
        if os.path.isfile(model_state_file):
            print('loading checkpoint successfully')
            checkpoint = torch.load(model_state_file, map_location=lambda storage, loc: storage)
            # checkpoint = torch.load(model_state_file, map_location=lambda storage, loc: storage)
            if 'state_dict' in checkpoint:
                checkpoint = checkpoint['state_dict']
            elif 'model' in checkpoint:
                checkpoint = checkpoint['model']
            else:
                checkpoint = checkpoint
            checkpoint = {k: v for k, v in checkpoint.items() if not 'n_averaged' in k}
            checkpoint = {k.replace('model.', 'module.'): v for k, v in checkpoint.items()}
            # load the processed state dict
            model.load_state_dict(checkpoint)
        else:
            warnings.warn('weight file does not exist!')

        names = "test_output3.png"

        # image_loader contains only one image
        for sample in image_loader:
            images = sample['image']
            print(images.shape)
            # todo: can this be changed to .cpu()?
            images = images.cuda()
            logits = model(images)
            
            logits = logits.argmax(dim=1)
            logits = logits.cpu().detach().numpy()

            vis_logits = label_to_RGB(logits.squeeze())[:, :, ::-1]
        
            cv2.imwrite(names, vis_logits)


if __name__ == '__main__':

    # cudnn.benchmark = True
    # cudnn.deterministic = False
    # cudnn.enabled = True
    # torch.cuda.set_device(0)
    # torch.distributed.init_process_group(backend="nccl", init_method="env://", world_size=1, rank=0)

    weight_dir = "weights/best_weight.pkl"
    image = "test.png"
    model = swinT(nclass=2, pretrained=False, aux=True, head="mlphead")
    print("swinT OK")

    # model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    
    device = torch.device('cpu')
    model = model.to(device)

    # model = nn.parallel.DistributedDataParallel(model, device_ids=[0, ], output_device=0)
    # print("syn OK")


    loaded_image = load_image(image)

    dataloader_val_full = DataLoader(
        loaded_image,
        batch_size=1,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        sampler=None)

    print("device OK")

    predict_image(model, weight_dir, dataloader_val_full)
    print("predict OK")

wondering516 avatar Aug 13 '22 06:08 wondering516

The weights are just key-value pairs, so you can modify the loaded keys as shown below. Adjust the model/module prefixes according to your actual keys and the keys the model expects.

checkpoint = {k.replace('model.', 'module.'): v for k, v in checkpoint.items()}

zyxu1996 avatar Aug 13 '22 06:08 zyxu1996

It looks like you need module.backbone.xx, but what gets loaded is module.module.backbone.xx. Could it be that this line, checkpoint = {k.replace('model.', 'module.'): v for k, v in checkpoint.items()}, turned model into module? Try changing it to checkpoint = {k.replace('model.', ''): v for k, v in checkpoint.items()}.
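
If it is unclear which rename is needed, one quick check is to print a few keys from both sides before calling load_state_dict, for example:

print(list(checkpoint.keys())[:3])          # keys stored in best_weight.pkl
print(list(model.state_dict().keys())[:3])  # keys the (wrapped) model expects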

zyxu1996 avatar Aug 13 '22 06:08 zyxu1996

Thank you! I fixed it following your suggestion. Many thanks for all your help over the past two days.

wondering516 avatar Aug 13 '22 07:08 wondering516