problem
After dividing the pixels of the binary images by 255 and setting num_class to 2 for training, I got an error when loading the model for segmentation: Missing key(s) in state_dict: "image_encoder.pos_embed", "image_encoder.patch_embed.proj.weight", "image_encoder.patch_embed.proj.bias", "image_encoder.blocks.0.norm1.weight", "image_encoder.blocks.0.norm1.bias", "image_encoder.blocks.0.attn.rel_pos_h", "image_encoder.blocks.0.attn.rel_pos_w", "image_encoder.blocks.0.attn.qkv.weight", "image_encoder.blocks.0.attn.qkv.bias", "image_encoder.blocks.0.attn.proj.weight", "image_encoder.blocks.0.attn.proj.bias", "image_encoder.blocks.0.norm2.weight", "image_encoder.blocks.0.norm2.bias", "image_encoder.blocks.0.mlp.lin1.weight", "image_encoder.blocks.0.mlp.lin1.bias", "image_encoder.blocks.0.mlp.lin2.weight", "image_encoder.blocks.0.mlp.lin2.bias", "image_encoder.blocks.1.norm1.weight", "image_encoder.blocks.1.norm1.bias", "image_encoder.blocks.1.attn.rel_pos_h", "image_e............ Did something go wrong during my training?
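For reference, here is a minimal NumPy sketch of the mask preprocessing described above. It assumes 8-bit masks with background 0 and foreground 255; mask_to_labels is a hypothetical helper, not part of the repository. Thresholding gives integer labels {0, 1} for num_class=2 even if resizing has introduced intermediate gray values, which plain division by 255 would turn into fractional labels.

```python
import numpy as np


def mask_to_labels(mask: np.ndarray, threshold: int = 127) -> np.ndarray:
    """Map an 8-bit binary mask (background 0, foreground 255) to class indices.

    Thresholding instead of plain division keeps stray intermediate values
    (e.g. from interpolation) from becoming fractional labels.
    """
    if mask.ndim == 3:
        # Collapse an accidental 3-channel mask to a single channel.
        mask = mask[..., 0]
    return (mask > threshold).astype(np.int64)


if __name__ == "__main__":
    demo = np.array([[0, 255], [128, 3]], dtype=np.uint8)
    print(mask_to_labels(demo))
    print(demo / 255)  # plain division: 128 -> ~0.502, not a valid class index
```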
Was your checkpoint downloaded from the site? Please make sure the checkpoint is correct.
Yes, I'm sure.
This error occurs when the checkpoint is not correct. Could you provide the command and the whole log? Thank you very much.
python train_learnable_sam.py --image C:\Users\Duan\Desktop\LearnablePromptSAM-main\train\image_cut --mask_path C:\Users\Duan\Desktop\LearnablePromptSAM-main\train\gt_cut --model_name vit_b --checkpoint C:\Users\Duan\Desktop\LearnablePromptSAM-main\ckpts\sam_vit_b_01ec64.pth --save_path C:\Users\Duan\Desktop\LearnablePromptSAM-main\ckpts --lr 0.05 --mix_precision --optimizer sgd
Sorry, I meant the training log. You can check the checkpoint like this:
python -c "import torch; print(torch.load(r'C:\Users\Duan\Desktop\LearnablePromptSAM-main\ckpts\sam_vit_b_01ec64.pth', map_location='cpu').keys())"
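The one-liner prints every key. As a slightly more readable check, the sketch below groups the keys by their first dotted component; the script name inspect_ckpt.py and the helper summarize_prefixes are hypothetical. For the original sam_vit_b_01ec64.pth one would expect groups such as image_encoder, prompt_encoder and mask_decoder, while an extra leading name (for example sam.) would suggest a file saved from a wrapper model rather than a plain SAM checkpoint. Passing the path as a command-line argument also sidesteps Python's backslash escapes in Windows paths.

```python
import sys
from collections import Counter


def summarize_prefixes(keys):
    """Count state_dict keys by their first dotted component."""
    return Counter(key.split(".", 1)[0] for key in keys)


def main(path):
    import torch  # imported here so summarize_prefixes works without torch

    state = torch.load(path, map_location="cpu")
    for prefix, count in sorted(summarize_prefixes(state.keys()).items()):
        print(f"{prefix}: {count} keys")


if __name__ == "__main__":
    if len(sys.argv) != 2:
        print("usage: python inspect_ckpt.py CHECKPOINT")
    else:
        main(sys.argv[1])
```

Usage: python inspect_ckpt.py ckpts/sam_vit_b_01ec64.pth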
If the checkpoint is correct, I suggest changing the command to use relative paths (run it from the LearnablePromptSAM-main directory):
python train_learnable_sam.py --image train/image_cut --mask_path train/gt_cut --model_name vit_b --checkpoint ckpts/sam_vit_b_01ec64.pth --save_path ckpts --lr 0.05 --mix_precision --optimizer sgd
You may have chosen the wrong model. If you trained the SAM model, use PromptSAM instead of PromptDiNo for segmentation.
I'm having the same problem as the original poster, and it seems to be caused by my model test code. You answered a question in issue 11 about how to use the model for testing, but I'm sorry, I didn't really understand what you meant. I wrote a test script and ran into the problem above; I'm not sure whether it is caused by an error in the test code. Could you please correct me?
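import numpy as np
import torch
from PIL import Image
from albumentations import Compose, Normalize, Resize
from learnerable_seg import PromptSAM  # assumed module name; adjust to wherever PromptSAM is defined

model = PromptSAM("vit_b", checkpoint="/home/b/桌面/LearnablePromptSAM-main/save_model/sam_vit_b_prompt.pth", num_classes=2, reduction=4, upsample_times=2, groups=4)
model.eval()  # inference mode
img = Image.open("/home/b/桌面/mt/struct-uncertainty-main/DRIVE数据集/total_img/img_folder/01_test.tif").convert("RGB")
img = np.asarray(img)
img_size = 1024
pixel_mean = [0.5] * 3
pixel_std = [0.5] * 3
# Test time: only resize and normalize; random augmentations (ColorJitter, flips) belong in training
transform = Compose(
    [
        Resize(img_size, img_size),
        Normalize(mean=pixel_mean, std=pixel_std),
    ]
)
x = transform(image=img)["image"]  # albumentations returns a dict, not an array
x = torch.from_numpy(x).permute(2, 0, 1).unsqueeze(0).float()  # HWC -> 1 x C x H x W
with torch.no_grad():
    pred = model(x)
pred = torch.softmax(pred, dim=1)
pred = torch.max(pred, dim=1)[1]  # per-pixel class index
print(pred)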
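A NumPy-only sketch of the same test-time preprocessing and post-processing, for checking shapes without loading the model. It assumes the image has already been resized to img_size and mirrors albumentations Normalize with its default max_pixel_value=255, i.e. (pixel / 255 - mean) / std; to_model_input and logits_to_labels are hypothetical helper names.

```python
import numpy as np

PIXEL_MEAN = np.array([0.5, 0.5, 0.5], dtype=np.float32)
PIXEL_STD = np.array([0.5, 0.5, 0.5], dtype=np.float32)


def to_model_input(img: np.ndarray) -> np.ndarray:
    """HxWx3 uint8 image (already resized) -> 1x3xHxW float32 batch."""
    x = (img.astype(np.float32) / 255.0 - PIXEL_MEAN) / PIXEL_STD
    return np.transpose(x, (2, 0, 1))[np.newaxis]  # HWC -> CHW, then add batch dim


def logits_to_labels(logits: np.ndarray) -> np.ndarray:
    """1xCxHxW class scores -> HxW label map.

    Softmax is monotonic, so argmax over the raw logits gives the same
    labels as softmax followed by max(...)[1].
    """
    return np.argmax(logits, axis=1)[0]
```

With PyTorch, torch.from_numpy(to_model_input(img)) gives a tensor in the layout the model call above expects.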
The checkpoint loaded by PromptSAM here is the weights file saved during training. Judging by its size, which is about the same as the original SAM model, it seems to store all the layers, not just the trainable ones. Is that right?
According to the error message, the failure happens inside the PromptSAM class, called in the first line of the code above, at: self.sam = sam_model_registry[model_name](checkpoint)
def build_sam_vit_b(checkpoint=None): return _build_sam(...)
It fails in sam_model_registry, which suggests that sam_model_registry cannot load a checkpoint whose structure differs from the original SAM model.
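A possible explanation, to be checked against the repository code: if PromptSAM stores SAM as self.sam, a state_dict saved from the whole PromptSAM model would have keys such as sam.image_encoder.pos_embed, while the SAM builder loads strictly and expects image_encoder.pos_embed, which would produce exactly this "Missing key(s)" error. Two workarounds, sketched under that assumption: build PromptSAM from the original SAM checkpoint and then load the trained file into the whole model, or strip the prefix before loading into plain SAM. Both helpers below are hypothetical, and the "sam." prefix is an assumption.

```python
def strip_prefix(state_dict, prefix="sam."):
    """Keep only keys starting with prefix and remove it
    (sam.image_encoder.pos_embed -> image_encoder.pos_embed).
    Keys outside the prefix, e.g. extra trainable layers, are dropped."""
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}


def load_trained_prompt_sam(model_cls, model_name, sam_checkpoint, trained_checkpoint, **kwargs):
    """Build the wrapper from the original SAM weights, then overwrite all
    parameters with the fine-tuned state_dict saved during training."""
    import torch  # imported here so strip_prefix works without torch

    model = model_cls(model_name, checkpoint=sam_checkpoint, **kwargs)
    state = torch.load(trained_checkpoint, map_location="cpu")
    model.load_state_dict(state)  # strict by default: raises on any key mismatch
    return model
```

Usage, with the arguments from the test code above and paths relative to the repository root: model = load_trained_prompt_sam(PromptSAM, "vit_b", "ckpts/sam_vit_b_01ec64.pth", "save_model/sam_vit_b_prompt.pth", num_classes=2, reduction=4, upsample_times=2, groups=4)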