Code for segmenting a single image
With the original Segment Anything, all I have to do to segment a single image is download the checkpoint and run:
from segment_anything import SamPredictor, sam_model_registry
sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>")
predictor = SamPredictor(sam)
predictor.set_image(<your_image>)
masks, _, _ = predictor.predict(<input_prompts>)
How can I do something similar in FastSAM3D? I downloaded the FastSAM3D checkpoint and have a 3D image with prompts ready. How do I use FastSAM3D to segment my image using my prompts?
You can use infer.sh and modify the parameters as needed. -vp is the path where the visualizations are written, and -tdp is the path to the image data you want to segment.
python validation.py --seed 2023 \
  -vp ./results/vis_sam_med3d \
  -tdp data/initial_test_dataset/total_segment -nc 1
Regarding this issue: my dataset is in the specified format and I have saved the FastSAM3D checkpoint locally. When I run infer.sh, I get errors about missing encoder blocks while the model is loading. Additionally, looking through validation.py, I notice that many functions refer to tuning the model. In my case, I do not want to tune the model; I only want to test its segmentation on my data. Any help with this would be greatly appreciated.
Here is my error message:
RuntimeError: Error(s) in loading state_dict for Sam3D: Missing key(s) in state_dict: "image_encoder.blocks.0.norm1.weight", "image_encoder.blocks.0.norm1.bias", "image_encoder.blocks.0.attn.rel_pos_d", "image_encoder.blocks.0.attn.rel_pos_h", "image_encoder.blocks.0.attn.rel_pos_w", "image_encoder.blocks.0.attn.qkv.weight", "image_encoder.blocks.0.attn.qkv.bias", "image_encoder.blocks.0.attn.proj.weight", "image_encoder.blocks.0.attn.proj.bias", "image_encoder.blocks.1.norm1.weight", "image_encoder.blocks.1.norm1.bias", "image_encoder.blocks.1.attn.rel_pos_d", "image_encoder.blocks.1.attn.rel_pos_h", "image_encoder.blocks.1.attn.rel_pos_w", "image_encoder.blocks.1.attn.qkv.weight", "image_encoder.blocks.1.attn.qkv.bias", "image_encoder.blocks.1.attn.proj.weight", "image_encoder.blocks.1.attn.proj.bias", "image_encoder.blocks.6.norm1.weight", "image_encoder.blocks.6.norm1.bias", "image_encoder.blocks.6.attn.rel_pos_d", "image_encoder.blocks.6.attn.rel_pos_h", "image_encoder.blocks.6.attn.rel_pos_w", "image_encoder.blocks.6.attn.qkv.weight", "image_encoder.blocks.6.attn.qkv.bias", "image_encoder.blocks.6.attn.proj.weight", "image_encoder.blocks.6.attn.proj.bias", "image_encoder.blocks.6.norm2.weight", "image_encoder.blocks.6.norm2.bias", "image_encoder.blocks.6.mlp.lin1.weight", "image_encoder.blocks.6.mlp.lin1.bias", "image_encoder.blocks.6.mlp.lin2.weight", "image_encoder.blocks.6.mlp.lin2.bias", "image_encoder.blocks.7.norm1.weight", "image_encoder.blocks.7.norm1.bias", "image_encoder.blocks.7.attn.rel_pos_d", "image_encoder.blocks.7.attn.rel_pos_h", "image_encoder.blocks.7.attn.rel_pos_w", "image_encoder.blocks.7.attn.qkv.weight", "image_encoder.blocks.7.attn.qkv.bias", "image_encoder.blocks.7.attn.proj.weight", "image_encoder.blocks.7.attn.proj.bias", "image_encoder.blocks.7.norm2.weight", "image_encoder.blocks.7.norm2.bias", "image_encoder.blocks.7.mlp.lin1.weight", "image_encoder.blocks.7.mlp.lin1.bias", "image_encoder.blocks.7.mlp.lin2.weight", 
"image_encoder.blocks.7.mlp.lin2.bias", "image_encoder.blocks.8.norm1.weight", "image_encoder.blocks.8.norm1.bias", "image_encoder.blocks.8.attn.rel_pos_d", "image_encoder.blocks.8.attn.rel_pos_h", "image_encoder.blocks.8.attn.rel_pos_w", "image_encoder.blocks.8.attn.qkv.weight", "image_encoder.blocks.8.attn.qkv.bias", "image_encoder.blocks.8.attn.proj.weight", "image_encoder.blocks.8.attn.proj.bias", "image_encoder.blocks.8.norm2.weight", "image_encoder.blocks.8.norm2.bias", "image_encoder.blocks.8.mlp.lin1.weight", "image_encoder.blocks.8.mlp.lin1.bias", "image_encoder.blocks.8.mlp.lin2.weight", "image_encoder.blocks.8.mlp.lin2.bias", "image_encoder.blocks.9.norm1.weight", "image_encoder.blocks.9.norm1.bias", "image_encoder.blocks.9.attn.rel_pos_d", "image_encoder.blocks.9.attn.rel_pos_h", "image_encoder.blocks.9.attn.rel_pos_w", "image_encoder.blocks.9.attn.qkv.weight", "image_encoder.blocks.9.attn.qkv.bias", "image_encoder.blocks.9.attn.proj.weight", "image_encoder.blocks.9.attn.proj.bias", "image_encoder.blocks.9.norm2.weight", "image_encoder.blocks.9.norm2.bias", "image_encoder.blocks.9.mlp.lin1.weight", "image_encoder.blocks.9.mlp.lin1.bias", "image_encoder.blocks.9.mlp.lin2.weight", "image_encoder.blocks.9.mlp.lin2.bias", "image_encoder.blocks.10.norm1.weight", "image_encoder.blocks.10.norm1.bias", "image_encoder.blocks.10.attn.rel_pos_d", "image_encoder.blocks.10.attn.rel_pos_h", "image_encoder.blocks.10.attn.rel_pos_w", "image_encoder.blocks.10.attn.qkv.weight", "image_encoder.blocks.10.attn.qkv.bias", "image_encoder.blocks.10.attn.proj.weight", "image_encoder.blocks.10.attn.proj.bias", "image_encoder.blocks.10.norm2.weight", "image_encoder.blocks.10.norm2.bias", "image_encoder.blocks.10.mlp.lin1.weight", "image_encoder.blocks.10.mlp.lin1.bias", "image_encoder.blocks.10.mlp.lin2.weight", "image_encoder.blocks.10.mlp.lin2.bias", "image_encoder.blocks.11.norm1.weight", "image_encoder.blocks.11.norm1.bias", "image_encoder.blocks.11.attn.rel_pos_d", 
"image_encoder.blocks.11.attn.rel_pos_h", "image_encoder.blocks.11.attn.rel_pos_w", "image_encoder.blocks.11.attn.qkv.weight", "image_encoder.blocks.11.attn.qkv.bias", "image_encoder.blocks.11.attn.proj.weight", "image_encoder.blocks.11.attn.proj.bias", "image_encoder.blocks.11.norm2.weight", "image_encoder.blocks.11.norm2.bias", "image_encoder.blocks.11.mlp.lin1.weight", "image_encoder.blocks.11.mlp.lin1.bias", "image_encoder.blocks.11.mlp.lin2.weight", "image_encoder.blocks.11.mlp.lin2.bias". size mismatch for image_encoder.blocks.2.attn.rel_pos_d: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([15, 64]). size mismatch for image_encoder.blocks.2.attn.rel_pos_h: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([15, 64]). size mismatch for image_encoder.blocks.2.attn.rel_pos_w: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([15, 64]). size mismatch for image_encoder.blocks.3.attn.rel_pos_d: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]). size mismatch for image_encoder.blocks.3.attn.rel_pos_h: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]). size mismatch for image_encoder.blocks.3.attn.rel_pos_w: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]). size mismatch for image_encoder.blocks.4.attn.rel_pos_d: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]). size mismatch for image_encoder.blocks.4.attn.rel_pos_h: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]). 
size mismatch for image_encoder.blocks.4.attn.rel_pos_w: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]). size mismatch for image_encoder.blocks.5.attn.rel_pos_d: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([15, 64]). size mismatch for image_encoder.blocks.5.attn.rel_pos_h: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([15, 64]). size mismatch for image_encoder.blocks.5.attn.rel_pos_w: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([15, 64]).
I'm experiencing the same issue as @williamtrayn0r. There seems to be a mismatch between the architecture defined in build_sam3D.py and the checkpoint you provided. Could you help us resolve this?
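To narrow down a mismatch like this before calling load_state_dict, it can help to compare parameter names and shapes directly. The following is a minimal sketch, not code from the FastSAM3D repository: diff_state_dicts is a hypothetical helper, and the toy entries below are placeholders loosely modelled on the error above. With PyTorch, the two inputs could be built as {k: tuple(v.shape) for k, v in state_dict.items()} from the model's state_dict() and from the loaded checkpoint.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def diff_state_dicts(
    model_shapes: Dict[str, Shape],
    ckpt_shapes: Dict[str, Shape],
) -> Tuple[List[str], List[str], Dict[str, Tuple[Shape, Shape]]]:
    """Compare parameter names and shapes of a model and a checkpoint.

    Returns (missing, unexpected, mismatched):
      missing    - names the model expects but the checkpoint lacks
      unexpected - names in the checkpoint that the model does not have
      mismatched - name -> (checkpoint shape, model shape) where they differ
    """
    missing = sorted(k for k in model_shapes if k not in ckpt_shapes)
    unexpected = sorted(k for k in ckpt_shapes if k not in model_shapes)
    mismatched = {
        k: (ckpt_shapes[k], model_shapes[k])
        for k in model_shapes
        if k in ckpt_shapes and ckpt_shapes[k] != model_shapes[k]
    }
    return missing, unexpected, mismatched


# Toy placeholder data (illustrative only, not read from a real checkpoint).
model_shapes = {
    "image_encoder.blocks.2.attn.rel_pos_d": (15, 64),
    "image_encoder.blocks.6.attn.rel_pos_d": (15, 64),
}
ckpt_shapes = {
    "image_encoder.blocks.2.attn.rel_pos_d": (15, 128),
}

missing, unexpected, mismatched = diff_state_dicts(model_shapes, ckpt_shapes)
print("missing:", missing)
print("unexpected:", unexpected)
for name, (ckpt_shape, model_shape) in mismatched.items():
    print(f"{name}: checkpoint {ckpt_shape} vs model {model_shape}")
```

Missing names that cover whole blocks suggest the model was built with more layers than the checkpoint contains, while shape mismatches point at differing per-layer settings.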
Which model setting are you using? Can you show me your model setting in build_sam3D.py? @noahrecher
I am using the "vit_b_ori" ViT in build_sam3D.py and the latest checkpoints downloaded from your repository. Please let me know if you require further information. Thank you!
Can you try setting "skip layer = 2" on line 196 of build_sam3D.py?
I changed it but I'm still getting an error:
RuntimeError: Error(s) in loading state_dict for Sam3D: Missing key(s) in state_dict: "image_encoder.blocks.6.norm1.weight", "image_encoder.blocks.6.norm1.bias", "image_encoder.blocks.6.attn.rel_pos_d", "image_encoder.blocks.6.attn.rel_pos_h", "image_encoder.blocks.6.attn.rel_pos_w", "image_encoder.blocks.6.attn.qkv.weight", "image_encoder.blocks.6.attn.qkv.bias", "image_encoder.blocks.6.attn.proj.weight", "image_encoder.blocks.6.attn.proj.bias", "image_encoder.blocks.6.norm2.weight", "image_encoder.blocks.6.norm2.bias", "image_encoder.blocks.6.mlp.lin1.weight", "image_encoder.blocks.6.mlp.lin1.bias", "image_encoder.blocks.6.mlp.lin2.weight", "image_encoder.blocks.6.mlp.lin2.bias", "image_encoder.blocks.7.norm1.weight", "image_encoder.blocks.7.norm1.bias", "image_encoder.blocks.7.attn.rel_pos_d", "image_encoder.blocks.7.attn.rel_pos_h", "image_encoder.blocks.7.attn.rel_pos_w", "image_encoder.blocks.7.attn.qkv.weight", "image_encoder.blocks.7.attn.qkv.bias", "image_encoder.blocks.7.attn.proj.weight", "image_encoder.blocks.7.attn.proj.bias", "image_encoder.blocks.7.norm2.weight", "image_encoder.blocks.7.norm2.bias", "image_encoder.blocks.7.mlp.lin1.weight", "image_encoder.blocks.7.mlp.lin1.bias", "image_encoder.blocks.7.mlp.lin2.weight", "image_encoder.blocks.7.mlp.lin2.bias", "image_encoder.blocks.8.norm1.weight", "image_encoder.blocks.8.norm1.bias", "image_encoder.blocks.8.attn.rel_pos_d", "image_encoder.blocks.8.attn.rel_pos_h", "image_encoder.blocks.8.attn.rel_pos_w", "image_encoder.blocks.8.attn.qkv.weight", "image_encoder.blocks.8.attn.qkv.bias", "image_encoder.blocks.8.attn.proj.weight", "image_encoder.blocks.8.attn.proj.bias", "image_encoder.blocks.8.norm2.weight", "image_encoder.blocks.8.norm2.bias", "image_encoder.blocks.8.mlp.lin1.weight", "image_encoder.blocks.8.mlp.lin1.bias", "image_encoder.blocks.8.mlp.lin2.weight", "image_encoder.blocks.8.mlp.lin2.bias", "image_encoder.blocks.9.norm1.weight", "image_encoder.blocks.9.norm1.bias", 
"image_encoder.blocks.9.attn.rel_pos_d", "image_encoder.blocks.9.attn.rel_pos_h", "image_encoder.blocks.9.attn.rel_pos_w", "image_encoder.blocks.9.attn.qkv.weight", "image_encoder.blocks.9.attn.qkv.bias", "image_encoder.blocks.9.attn.proj.weight", "image_encoder.blocks.9.attn.proj.bias", "image_encoder.blocks.9.norm2.weight", "image_encoder.blocks.9.norm2.bias", "image_encoder.blocks.9.mlp.lin1.weight", "image_encoder.blocks.9.mlp.lin1.bias", "image_encoder.blocks.9.mlp.lin2.weight", "image_encoder.blocks.9.mlp.lin2.bias", "image_encoder.blocks.10.norm1.weight", "image_encoder.blocks.10.norm1.bias", "image_encoder.blocks.10.attn.rel_pos_d", "image_encoder.blocks.10.attn.rel_pos_h", "image_encoder.blocks.10.attn.rel_pos_w", "image_encoder.blocks.10.attn.qkv.weight", "image_encoder.blocks.10.attn.qkv.bias", "image_encoder.blocks.10.attn.proj.weight", "image_encoder.blocks.10.attn.proj.bias", "image_encoder.blocks.10.norm2.weight", "image_encoder.blocks.10.norm2.bias", "image_encoder.blocks.10.mlp.lin1.weight", "image_encoder.blocks.10.mlp.lin1.bias", "image_encoder.blocks.10.mlp.lin2.weight", "image_encoder.blocks.10.mlp.lin2.bias", "image_encoder.blocks.11.norm1.weight", "image_encoder.blocks.11.norm1.bias", "image_encoder.blocks.11.attn.rel_pos_d", "image_encoder.blocks.11.attn.rel_pos_h", "image_encoder.blocks.11.attn.rel_pos_w", "image_encoder.blocks.11.attn.qkv.weight", "image_encoder.blocks.11.attn.qkv.bias", "image_encoder.blocks.11.attn.proj.weight", "image_encoder.blocks.11.attn.proj.bias", "image_encoder.blocks.11.norm2.weight", "image_encoder.blocks.11.norm2.bias", "image_encoder.blocks.11.mlp.lin1.weight", "image_encoder.blocks.11.mlp.lin1.bias", "image_encoder.blocks.11.mlp.lin2.weight", "image_encoder.blocks.11.mlp.lin2.bias". size mismatch for image_encoder.blocks.2.attn.rel_pos_d: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]). 
size mismatch for image_encoder.blocks.2.attn.rel_pos_h: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]). size mismatch for image_encoder.blocks.2.attn.rel_pos_w: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]). size mismatch for image_encoder.blocks.3.attn.rel_pos_d: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]). size mismatch for image_encoder.blocks.3.attn.rel_pos_h: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]). size mismatch for image_encoder.blocks.3.attn.rel_pos_w: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]). size mismatch for image_encoder.blocks.4.attn.rel_pos_d: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([15, 64]). size mismatch for image_encoder.blocks.4.attn.rel_pos_h: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([15, 64]). size mismatch for image_encoder.blocks.4.attn.rel_pos_w: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([15, 64]). size mismatch for image_encoder.blocks.5.attn.rel_pos_d: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]). size mismatch for image_encoder.blocks.5.attn.rel_pos_h: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]). size mismatch for image_encoder.blocks.5.attn.rel_pos_w: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]).
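The shapes in the size-mismatch lines can be read as a hint about what the checkpoint expects. The sketch below assumes that FastSAM3D keeps the relative-position convention of the original Segment Anything code, where each rel_pos table has 2 * size - 1 rows (size being the attention input size along that axis) and one column per attention-head channel. That convention is an assumption here and is not confirmed anywhere in this thread, and decode_rel_pos is a hypothetical helper.

```python
from typing import Tuple


def decode_rel_pos(shape: Tuple[int, int]) -> Tuple[int, int]:
    """Recover (attention input size, per-head width) from a rel_pos shape.

    Assumes the SAM-style layout: rows = 2 * size - 1, columns = head width.
    """
    rows, head_dim = shape
    if rows < 1 or rows % 2 == 0:
        raise ValueError(f"expected an odd, positive row count, got {rows}")
    return (rows + 1) // 2, head_dim


# Shapes copied from the error messages above.
checkpoint_entry = decode_rel_pos((15, 128))
model_entry_a = decode_rel_pos((27, 64))
model_entry_b = decode_rel_pos((15, 64))

print("checkpoint:", checkpoint_entry)  # (8, 128)
print("model (27 rows):", model_entry_a)  # (14, 64)
print("model (15 rows):", model_entry_b)  # (8, 64)
```

Under that assumption, every checkpoint table corresponds to an attention size of 8 with a head width of 128, whereas the model built here mixes sizes 14 and 8 with a head width of 64, which is consistent with the checkpoint having been trained with a different encoder configuration from the one being built.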
Can you try using this build file temporarily?
https://github.com/arcadelab/FastSAM3D_slicer/blob/main/fastsam%2Fsegment_anything%2Fbuild_sam3D.py
I will upload updated checkpoints soon.
This worked, ty!