LearnablePromptSAM

problem

Open liuadan opened this issue 1 year ago • 7 comments

After dividing the pixel values of the binary image by 255 and setting num_class to 2 for training, I got an error when loading the model for segmentation:

Missing key(s) in state_dict: "image_encoder.pos_embed", "image_encoder.patch_embed.proj.weight", "image_encoder.patch_embed.proj.bias", "image_encoder.blocks.0.norm1.weight", "image_encoder.blocks.0.norm1.bias", "image_encoder.blocks.0.attn.rel_pos_h", "image_encoder.blocks.0.attn.rel_pos_w", "image_encoder.blocks.0.attn.qkv.weight", "image_encoder.blocks.0.attn.qkv.bias", "image_encoder.blocks.0.attn.proj.weight", "image_encoder.blocks.0.attn.proj.bias", "image_encoder.blocks.0.norm2.weight", "image_encoder.blocks.0.norm2.bias", "image_encoder.blocks.0.mlp.lin1.weight", "image_encoder.blocks.0.mlp.lin1.bias", "image_encoder.blocks.0.mlp.lin2.weight", "image_encoder.blocks.0.mlp.lin2.bias", "image_encoder.blocks.1.norm1.weight", "image_encoder.blocks.1.norm1.bias", "image_encoder.blocks.1.attn.rel_pos_h", "image_e............

Did something go wrong during my training?

liuadan, Aug 02 '24 10:08

Was your checkpoint downloaded from the site? Please make sure the checkpoint is correct.

Qsingle, Aug 06 '24 06:08

Yes, I'm sure.

liuadan, Aug 06 '24 06:08

This error occurs when the checkpoint is not correct. Could you provide the command and the full log? Thank you very much.

Qsingle, Aug 06 '24 07:08

python train_learnable_sam.py --image C:\Users\Duan\Desktop\LearnablePromptSAM-main\train\image_cut --mask_path C:\Users\Duan\Desktop\LearnablePromptSAM-main\train\gt_cut --model_name vit_b --checkpoint C:\Users\Duan\Desktop\LearnablePromptSAM-main\ckpts\sam_vit_b_01ec64.pth --save_path C:\Users\Duan\Desktop\LearnablePromptSAM-main\ckpts --lr 0.05 --mix_precision --optimizer sgd

liuadan, Aug 06 '24 07:08

Sorry, please provide the training log. You can check the checkpoint this way (the path is written as a raw string so that the backslashes in the Windows path are not treated as escape sequences):

python -c "import torch; print(torch.load(r'C:\Users\Duan\Desktop\LearnablePromptSAM-main\ckpts\sam_vit_b_01ec64.pth', map_location='cpu').keys())"

If the checkpoint is correct, I suggest changing the command to:

python train_learnable_sam.py --image train/image_cut --mask_path train/gt_cut --model_name vit_b --checkpoint ckpts/sam_vit_b_01ec64.pth --save_path ckpts --lr 0.05 --mix_precision --optimizer sgd
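
To make the output of the key check easier to read, the key names can be grouped by their first dotted component. The following is only a sketch: `summarize_prefixes` is a hypothetical helper, not part of the repository, and it runs on hand-written key lists. With a real file, the keys returned by the `torch.load` call above would be passed in instead. The `sam.`-prefixed list is an assumed example of what a checkpoint saved from a wrapper model might contain; nothing in this thread confirms it.

```python
from collections import Counter


def summarize_prefixes(keys):
    """Count state_dict keys by their first dotted component."""
    return Counter(key.split(".", 1)[0] for key in keys)


# Key names in the layout that the error message asks for.
expected_style = [
    "image_encoder.pos_embed",
    "image_encoder.patch_embed.proj.weight",
    "image_encoder.blocks.0.norm1.weight",
]

# Assumption for illustration: the same weights saved from a wrapper
# module that keeps SAM under an attribute named "sam".
wrapped_style = ["sam." + key for key in expected_style]

print(dict(summarize_prefixes(expected_style)))
print(dict(summarize_prefixes(wrapped_style)))
```

If the two summaries show different top-level names, the file is probably not in the layout the loader expects, even when the download itself is intact.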

Qsingle, Aug 07 '24 02:08

You may have chosen the wrong model. If you trained the SAM model, use PromptSAM instead of PromptDiNo as the model for segmentation.

CNwuyueyu, Oct 16 '24 15:10

I'm having the same problem as the original poster, and it seems to come from the model test code. You answered the question in issue 11 about how to use the model for testing, but I'm sorry, I didn't really understand the answer. I tried to write a test file and then ran into the problem above. I'm not sure whether it is caused by an error in my test code. Could you please correct me?

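import numpy as np
import torch
from PIL import Image
from albumentations import Compose, ColorJitter, VerticalFlip, HorizontalFlip, Resize, Normalize
# PromptSAM is defined in this repository; the module name below is assumed.
from learnerable_seg import PromptSAM

# The checkpoint passed here is the weights file saved during training.
model = PromptSAM("vit_b", checkpoint="/home/b/桌面/LearnablePromptSAM-main/save_model/sam_vit_b_prompt.pth", num_classes=2, reduction=4, upsample_times=2, groups=4)
model.eval()
img = Image.open("/home/b/桌面/mt/struct-uncertainty-main/DRIVE数据集/total_img/img_folder/01_test.tif").convert("RGB")
img = np.asarray(img)
img_size = 1024
pixel_mean=[0.5]*3
pixel_std=[0.5]*3
# Note: ColorJitter and the flips are random augmentations that are usually left out at test time.
transform = Compose(
    [
        ColorJitter(),
        VerticalFlip(),
        HorizontalFlip(),
        Resize(img_size, img_size),
        Normalize(mean=pixel_mean, std=pixel_std)
    ]
)
# The transform returns a dict; take the image and convert HWC -> 1xCxHxW float tensor.
x = transform(image=img)["image"]
x = torch.from_numpy(x).permute(2, 0, 1).unsqueeze(0).float()
with torch.no_grad():
    pred = model(x)
pred = torch.softmax(pred, dim=1)
pred = torch.max(pred, dim=1)[1]
print(pred)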
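
For reference, with a mean and std of 0.5 and the default `max_pixel_value` of 255 in albumentations, `Normalize` maps 8-bit pixel values to roughly the range [-1, 1]. A minimal pure-Python sketch of that arithmetic follows; `normalize_pixel` is an illustrative helper, not a library function.

```python
def normalize_pixel(value, mean=0.5, std=0.5, max_pixel_value=255.0):
    """Apply (value - mean * max) / (std * max) to one channel value."""
    return (value - mean * max_pixel_value) / (std * max_pixel_value)


# Darkest, mid-range and brightest 8-bit values.
for v in (0, 128, 255):
    print(v, normalize_pixel(v))
```

The test-time preprocessing should use the same mean and std as training, otherwise the model sees inputs on a different scale.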

The checkpoint loaded by PromptSAM here is the weights file saved during training. Judging by its size, which is about the same as the original SAM model, it seems to store all the layers rather than only the trainable ones. Is that right?

According to the error message, the failure is in the PromptSAM class, triggered by the first line of the code above: self.sam = sam_model_registry[model_name](checkpoint)

def build_sam_vit_b(checkpoint=None): return _build_sam(.......

The error is raised inside sam_model_registry, which suggests that sam_model_registry cannot load a checkpoint whose structure has changed.
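
One possible reading of this error, offered as an assumption rather than a confirmed diagnosis, is that the saved file stores the SAM weights under a wrapper prefix while the SAM loader expects the bare key names. PyTorch's strict `load_state_dict` reports exactly two groups in that situation: missing keys and unexpected keys. The sketch below reproduces that comparison in pure Python; `diff_keys`, `strip_prefix`, the `sam.` prefix and the `up_conv.weight` entry are all hypothetical names used only for illustration.

```python
def diff_keys(expected, found):
    """Return (missing, unexpected) key lists, as a strict load would report."""
    expected, found = set(expected), set(found)
    return sorted(expected - found), sorted(found - expected)


def strip_prefix(state, prefix):
    """Keep only entries whose key starts with `prefix`, with the prefix removed."""
    return {k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)}


# Keys the SAM loader asks for (taken from the error message).
expected = ["image_encoder.pos_embed", "image_encoder.patch_embed.proj.weight"]

# Hypothetical trained checkpoint: SAM weights under "sam." plus one extra layer.
saved = {
    "sam.image_encoder.pos_embed": 0,
    "sam.image_encoder.patch_embed.proj.weight": 1,
    "up_conv.weight": 2,
}

missing, unexpected = diff_keys(expected, saved)
print(missing)     # every expected key is reported as missing
print(unexpected)  # every saved key is reported as unexpected

fixed = strip_prefix(saved, "sam.")
print(diff_keys(expected, fixed))  # no missing or unexpected keys remain
```

If the real checkpoint follows this pattern, an alternative to rewriting keys would be to build the model from the original SAM checkpoint and then load the trained file into the whole wrapper model, but that depends on how the training script saved its weights.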

zxcvbnmkj, Apr 07 '25 13:04