FastSAM3D

Code for segmenting a single image

AlexanderZeilmann opened this issue · 9 comments

With the original Segment Anything, all I have to do to segment a single image is download the checkpoint and run:

from segment_anything import SamPredictor, sam_model_registry
sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>")
predictor = SamPredictor(sam)
predictor.set_image(<your_image>)
masks, _, _ = predictor.predict(<input_prompts>)

How can I do something similar with FastSAM3D? I downloaded the FastSAM3D checkpoint and have a 3D image with prompts ready. How do I use FastSAM3D to segment my image with my prompts?

AlexanderZeilmann avatar Aug 29 '24 16:08 AlexanderZeilmann

You could use infer.sh and modify the parameters as you want. -vp is the visualization path the output is written to, and -tdp is the path of the data you want to segment.

python validation.py --seed 2023 \
  -vp ./results/vis_sam_med3d \
  -tdp data/initial_test_dataset/total_segment \
  -nc 1
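
If you want something closer to the SamPredictor flow from the original question, a minimal programmatic sketch could look roughly like this. FastSAM3D follows the SAM-Med3D-style interface, so the registry name, module path, checkpoint layout, and preprocessing below are assumptions; validation.py shows the exact steps.

import torch
from segment_anything.build_sam3D import sam_model_registry3D  # module path assumed

device = "cuda" if torch.cuda.is_available() else "cpu"
# Build the architecture without weights, then load the downloaded checkpoint.
model = sam_model_registry3D["vit_b_ori"](checkpoint=None).to(device)
ckpt = torch.load("<path/to/fastsam3d_checkpoint>", map_location=device)
model.load_state_dict(ckpt.get("model_state_dict", ckpt))  # nesting key is an assumption
model.eval()

with torch.no_grad():
    # image: a (1, 1, D, H, W) float tensor, resampled/normalized as in validation.py
    image_embedding = model.image_encoder(image)
    # points: a (coords, labels) click prompt in the model's input space
    sparse, dense = model.prompt_encoder(points=points, boxes=None, masks=None)
    low_res_masks, _ = model.mask_decoder(
        image_embeddings=image_embedding,
        image_pe=model.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse,
        dense_prompt_embeddings=dense,
        multimask_output=False,
    )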

skill-diver avatar Sep 16 '24 23:09 skill-diver

Regarding this issue, I have my dataset in the specified format and have saved the FastSAM3D checkpoint locally. When I run infer.sh, I get errors when loading the model, regarding missing encoder blocks. Additionally, looking through validation.py, I notice many functions refer to tuning the model. In my case, I do not want to tune the model; I want to test the model's segmentation on my data. Any help with this would be greatly appreciated.

Here is the error message:

RuntimeError: Error(s) in loading state_dict for Sam3D: Missing key(s) in state_dict: "image_encoder.blocks.0.norm1.weight", "image_encoder.blocks.0.norm1.bias", "image_encoder.blocks.0.attn.rel_pos_d", "image_encoder.blocks.0.attn.rel_pos_h", "image_encoder.blocks.0.attn.rel_pos_w", "image_encoder.blocks.0.attn.qkv.weight", "image_encoder.blocks.0.attn.qkv.bias", "image_encoder.blocks.0.attn.proj.weight", "image_encoder.blocks.0.attn.proj.bias", "image_encoder.blocks.1.norm1.weight", "image_encoder.blocks.1.norm1.bias", "image_encoder.blocks.1.attn.rel_pos_d", "image_encoder.blocks.1.attn.rel_pos_h", "image_encoder.blocks.1.attn.rel_pos_w", "image_encoder.blocks.1.attn.qkv.weight", "image_encoder.blocks.1.attn.qkv.bias", "image_encoder.blocks.1.attn.proj.weight", "image_encoder.blocks.1.attn.proj.bias", "image_encoder.blocks.6.norm1.weight", "image_encoder.blocks.6.norm1.bias", "image_encoder.blocks.6.attn.rel_pos_d", "image_encoder.blocks.6.attn.rel_pos_h", "image_encoder.blocks.6.attn.rel_pos_w", "image_encoder.blocks.6.attn.qkv.weight", "image_encoder.blocks.6.attn.qkv.bias", "image_encoder.blocks.6.attn.proj.weight", "image_encoder.blocks.6.attn.proj.bias", "image_encoder.blocks.6.norm2.weight", "image_encoder.blocks.6.norm2.bias", "image_encoder.blocks.6.mlp.lin1.weight", "image_encoder.blocks.6.mlp.lin1.bias", "image_encoder.blocks.6.mlp.lin2.weight", "image_encoder.blocks.6.mlp.lin2.bias", "image_encoder.blocks.7.norm1.weight", "image_encoder.blocks.7.norm1.bias", "image_encoder.blocks.7.attn.rel_pos_d", "image_encoder.blocks.7.attn.rel_pos_h", "image_encoder.blocks.7.attn.rel_pos_w", "image_encoder.blocks.7.attn.qkv.weight", "image_encoder.blocks.7.attn.qkv.bias", "image_encoder.blocks.7.attn.proj.weight", "image_encoder.blocks.7.attn.proj.bias", "image_encoder.blocks.7.norm2.weight", "image_encoder.blocks.7.norm2.bias", "image_encoder.blocks.7.mlp.lin1.weight", "image_encoder.blocks.7.mlp.lin1.bias", "image_encoder.blocks.7.mlp.lin2.weight", "image_encoder.blocks.7.mlp.lin2.bias", "image_encoder.blocks.8.norm1.weight", "image_encoder.blocks.8.norm1.bias", "image_encoder.blocks.8.attn.rel_pos_d", "image_encoder.blocks.8.attn.rel_pos_h", "image_encoder.blocks.8.attn.rel_pos_w", "image_encoder.blocks.8.attn.qkv.weight", "image_encoder.blocks.8.attn.qkv.bias", "image_encoder.blocks.8.attn.proj.weight", "image_encoder.blocks.8.attn.proj.bias", "image_encoder.blocks.8.norm2.weight", "image_encoder.blocks.8.norm2.bias", "image_encoder.blocks.8.mlp.lin1.weight", "image_encoder.blocks.8.mlp.lin1.bias", "image_encoder.blocks.8.mlp.lin2.weight", "image_encoder.blocks.8.mlp.lin2.bias", "image_encoder.blocks.9.norm1.weight", "image_encoder.blocks.9.norm1.bias", "image_encoder.blocks.9.attn.rel_pos_d", "image_encoder.blocks.9.attn.rel_pos_h", "image_encoder.blocks.9.attn.rel_pos_w", "image_encoder.blocks.9.attn.qkv.weight", "image_encoder.blocks.9.attn.qkv.bias", "image_encoder.blocks.9.attn.proj.weight", "image_encoder.blocks.9.attn.proj.bias", "image_encoder.blocks.9.norm2.weight", "image_encoder.blocks.9.norm2.bias", "image_encoder.blocks.9.mlp.lin1.weight", "image_encoder.blocks.9.mlp.lin1.bias", "image_encoder.blocks.9.mlp.lin2.weight", "image_encoder.blocks.9.mlp.lin2.bias", "image_encoder.blocks.10.norm1.weight", "image_encoder.blocks.10.norm1.bias", "image_encoder.blocks.10.attn.rel_pos_d", "image_encoder.blocks.10.attn.rel_pos_h", "image_encoder.blocks.10.attn.rel_pos_w", "image_encoder.blocks.10.attn.qkv.weight", "image_encoder.blocks.10.attn.qkv.bias", 
"image_encoder.blocks.10.attn.proj.weight", "image_encoder.blocks.10.attn.proj.bias", "image_encoder.blocks.10.norm2.weight", "image_encoder.blocks.10.norm2.bias", "image_encoder.blocks.10.mlp.lin1.weight", "image_encoder.blocks.10.mlp.lin1.bias", "image_encoder.blocks.10.mlp.lin2.weight", "image_encoder.blocks.10.mlp.lin2.bias", "image_encoder.blocks.11.norm1.weight", "image_encoder.blocks.11.norm1.bias", "image_encoder.blocks.11.attn.rel_pos_d", "image_encoder.blocks.11.attn.rel_pos_h", "image_encoder.blocks.11.attn.rel_pos_w", "image_encoder.blocks.11.attn.qkv.weight", "image_encoder.blocks.11.attn.qkv.bias", "image_encoder.blocks.11.attn.proj.weight", "image_encoder.blocks.11.attn.proj.bias", "image_encoder.blocks.11.norm2.weight", "image_encoder.blocks.11.norm2.bias", "image_encoder.blocks.11.mlp.lin1.weight", "image_encoder.blocks.11.mlp.lin1.bias", "image_encoder.blocks.11.mlp.lin2.weight", "image_encoder.blocks.11.mlp.lin2.bias". size mismatch for image_encoder.blocks.2.attn.rel_pos_d: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([15, 64]). size mismatch for image_encoder.blocks.2.attn.rel_pos_h: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([15, 64]). size mismatch for image_encoder.blocks.2.attn.rel_pos_w: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([15, 64]). size mismatch for image_encoder.blocks.3.attn.rel_pos_d: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]). size mismatch for image_encoder.blocks.3.attn.rel_pos_h: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]). size mismatch for image_encoder.blocks.3.attn.rel_pos_w: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]). size mismatch for image_encoder.blocks.4.attn.rel_pos_d: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]). size mismatch for image_encoder.blocks.4.attn.rel_pos_h: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]). size mismatch for image_encoder.blocks.4.attn.rel_pos_w: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]). size mismatch for image_encoder.blocks.5.attn.rel_pos_d: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([15, 64]). size mismatch for image_encoder.blocks.5.attn.rel_pos_h: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([15, 64]). size mismatch for image_encoder.blocks.5.attn.rel_pos_w: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([15, 64]).

williamtrayn0r avatar Oct 09 '24 12:10 williamtrayn0r

I'm experiencing the same issue as @williamtrayn0r. There seems to be a mismatch between the architecture defined in build_sam3D.py and the checkpoint you provided. Could you help us resolve this?
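
For reference, a small script along these lines shows exactly where the checkpoint and the freshly built model diverge. This is a minimal sketch: the module path, registry name, and the checkpoint's nesting key are assumptions.

import torch
from segment_anything.build_sam3D import sam_model_registry3D  # module path assumed

# Build the model the same way validation.py does, but without loading weights.
model = sam_model_registry3D["vit_b_ori"](checkpoint=None)
model_state = model.state_dict()

ckpt = torch.load("<path/to/fastsam3d_checkpoint>", map_location="cpu")
state = ckpt.get("model_state_dict", ckpt)  # unwrap if the weights are nested (assumed key)

missing = sorted(set(model_state) - set(state))
unexpected = sorted(set(state) - set(model_state))
mismatched = [k for k in set(state) & set(model_state)
              if state[k].shape != model_state[k].shape]

print(f"missing: {len(missing)}, unexpected: {len(unexpected)}, shape mismatch: {len(mismatched)}")
for k in mismatched:
    print(k, tuple(state[k].shape), "->", tuple(model_state[k].shape))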

noahrecher avatar Mar 16 '25 20:03 noahrecher

Which model setting are you using? Can you show me your model setting in build_sam3D.py? @noahrecher

skill-diver avatar Mar 17 '25 19:03 skill-diver

I am using the "vit_b_ori" ViT in build_sam3D.py and the latest checkpoints downloaded from your repository. Please let me know if you require further information. Thank you!

noahrecher avatar Mar 17 '25 20:03 noahrecher

Can you try setting skip_layer = 2 on line 196 of build_sam3D.py?
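
For context, roughly what that edit looks like; the surrounding code and the flag's exact role are assumptions on our part, based on the distilled 6-block encoder:

# build_sam3D.py, around line 196 (your copy may differ)
skip_layer = 2  # assumption: controls how the builder maps the distilled
                # encoder blocks; it must match the released checkpoint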

skill-diver avatar Mar 18 '25 01:03 skill-diver

I changed it but I'm still getting an error:

RuntimeError: Error(s) in loading state_dict for Sam3D: Missing key(s) in state_dict: "image_encoder.blocks.6.norm1.weight", "image_encoder.blocks.6.norm1.bias", "image_encoder.blocks.6.attn.rel_pos_d", "image_encoder.blocks.6.attn.rel_pos_h", "image_encoder.blocks.6.attn.rel_pos_w", "image_encoder.blocks.6.attn.qkv.weight", "image_encoder.blocks.6.attn.qkv.bias", "image_encoder.blocks.6.attn.proj.weight", "image_encoder.blocks.6.attn.proj.bias", "image_encoder.blocks.6.norm2.weight", "image_encoder.blocks.6.norm2.bias", "image_encoder.blocks.6.mlp.lin1.weight", "image_encoder.blocks.6.mlp.lin1.bias", "image_encoder.blocks.6.mlp.lin2.weight", "image_encoder.blocks.6.mlp.lin2.bias", "image_encoder.blocks.7.norm1.weight", "image_encoder.blocks.7.norm1.bias", "image_encoder.blocks.7.attn.rel_pos_d", "image_encoder.blocks.7.attn.rel_pos_h", "image_encoder.blocks.7.attn.rel_pos_w", "image_encoder.blocks.7.attn.qkv.weight", "image_encoder.blocks.7.attn.qkv.bias", "image_encoder.blocks.7.attn.proj.weight", "image_encoder.blocks.7.attn.proj.bias", "image_encoder.blocks.7.norm2.weight", "image_encoder.blocks.7.norm2.bias", "image_encoder.blocks.7.mlp.lin1.weight", "image_encoder.blocks.7.mlp.lin1.bias", "image_encoder.blocks.7.mlp.lin2.weight", "image_encoder.blocks.7.mlp.lin2.bias", "image_encoder.blocks.8.norm1.weight", "image_encoder.blocks.8.norm1.bias", "image_encoder.blocks.8.attn.rel_pos_d", "image_encoder.blocks.8.attn.rel_pos_h", "image_encoder.blocks.8.attn.rel_pos_w", "image_encoder.blocks.8.attn.qkv.weight", "image_encoder.blocks.8.attn.qkv.bias", "image_encoder.blocks.8.attn.proj.weight", "image_encoder.blocks.8.attn.proj.bias", "image_encoder.blocks.8.norm2.weight", "image_encoder.blocks.8.norm2.bias", "image_encoder.blocks.8.mlp.lin1.weight", "image_encoder.blocks.8.mlp.lin1.bias", "image_encoder.blocks.8.mlp.lin2.weight", "image_encoder.blocks.8.mlp.lin2.bias", "image_encoder.blocks.9.norm1.weight", "image_encoder.blocks.9.norm1.bias", "image_encoder.blocks.9.attn.rel_pos_d", "image_encoder.blocks.9.attn.rel_pos_h", "image_encoder.blocks.9.attn.rel_pos_w", "image_encoder.blocks.9.attn.qkv.weight", "image_encoder.blocks.9.attn.qkv.bias", "image_encoder.blocks.9.attn.proj.weight", "image_encoder.blocks.9.attn.proj.bias", "image_encoder.blocks.9.norm2.weight", "image_encoder.blocks.9.norm2.bias", "image_encoder.blocks.9.mlp.lin1.weight", "image_encoder.blocks.9.mlp.lin1.bias", "image_encoder.blocks.9.mlp.lin2.weight", "image_encoder.blocks.9.mlp.lin2.bias", "image_encoder.blocks.10.norm1.weight", "image_encoder.blocks.10.norm1.bias", "image_encoder.blocks.10.attn.rel_pos_d", "image_encoder.blocks.10.attn.rel_pos_h", "image_encoder.blocks.10.attn.rel_pos_w", "image_encoder.blocks.10.attn.qkv.weight", "image_encoder.blocks.10.attn.qkv.bias", "image_encoder.blocks.10.attn.proj.weight", "image_encoder.blocks.10.attn.proj.bias", "image_encoder.blocks.10.norm2.weight", "image_encoder.blocks.10.norm2.bias", "image_encoder.blocks.10.mlp.lin1.weight", "image_encoder.blocks.10.mlp.lin1.bias", "image_encoder.blocks.10.mlp.lin2.weight", "image_encoder.blocks.10.mlp.lin2.bias", "image_encoder.blocks.11.norm1.weight", "image_encoder.blocks.11.norm1.bias", "image_encoder.blocks.11.attn.rel_pos_d", "image_encoder.blocks.11.attn.rel_pos_h", "image_encoder.blocks.11.attn.rel_pos_w", "image_encoder.blocks.11.attn.qkv.weight", "image_encoder.blocks.11.attn.qkv.bias", "image_encoder.blocks.11.attn.proj.weight", "image_encoder.blocks.11.attn.proj.bias", "image_encoder.blocks.11.norm2.weight", 
"image_encoder.blocks.11.norm2.bias", "image_encoder.blocks.11.mlp.lin1.weight", "image_encoder.blocks.11.mlp.lin1.bias", "image_encoder.blocks.11.mlp.lin2.weight", "image_encoder.blocks.11.mlp.lin2.bias". size mismatch for image_encoder.blocks.2.attn.rel_pos_d: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]). size mismatch for image_encoder.blocks.2.attn.rel_pos_h: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]). size mismatch for image_encoder.blocks.2.attn.rel_pos_w: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]). size mismatch for image_encoder.blocks.3.attn.rel_pos_d: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]). size mismatch for image_encoder.blocks.3.attn.rel_pos_h: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]). size mismatch for image_encoder.blocks.3.attn.rel_pos_w: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]). size mismatch for image_encoder.blocks.4.attn.rel_pos_d: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([15, 64]). size mismatch for image_encoder.blocks.4.attn.rel_pos_h: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([15, 64]). size mismatch for image_encoder.blocks.4.attn.rel_pos_w: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([15, 64]). size mismatch for image_encoder.blocks.5.attn.rel_pos_d: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]). size mismatch for image_encoder.blocks.5.attn.rel_pos_h: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]). size mismatch for image_encoder.blocks.5.attn.rel_pos_w: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]).

noahrecher avatar Mar 18 '25 07:03 noahrecher

Can you try to use this building method temporarily?

https://github.com/arcadelab/FastSAM3D_slicer/blob/main/fastsam%2Fsegment_anything%2Fbuild_sam3D.py
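
With that file in place, loading should look roughly like this; the module path and registry name below are read off the linked repo's layout and may differ:

import torch
# Module path follows the linked FastSAM3D_slicer layout (assumed):
from fastsam.segment_anything.build_sam3D import sam_model_registry3D

# "vit_b_ori" matches the setting discussed above; the builder is expected
# to load the checkpoint itself when a path is passed.
model = sam_model_registry3D["vit_b_ori"](checkpoint="<path/to/fastsam3d_checkpoint>")
model.eval()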

I will upload new checkpoints soon.

skill-diver avatar Mar 25 '25 00:03 skill-diver

This worked, ty!

noahrecher avatar Mar 25 '25 13:03 noahrecher