
examples : add sample SAM inference

ggerganov opened this issue · 0 comments

Work in progress

  • hparams: https://github.com/facebookresearch/segment-anything/blob/efeab7296ab579d4a261e554eca80faf6b33924a/segment_anything/build_sam.py#L13-L44
  • cmd:
    python scripts/amg.py --checkpoint ./sam_vit_b_01ec64.pth --model-type vit_b --input img.jpg --output img.out --device cpu
    
PTH tensors for ViT-B
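The listing below can be reproduced with a short script along these lines (a minimal sketch; it assumes PyTorch is installed and the checkpoint path matches the command above):

import torch

# Load the ViT-B checkpoint (a plain state dict) and print each tensor's name and shape.
state_dict = torch.load("sam_vit_b_01ec64.pth", map_location="cpu")
for name, tensor in state_dict.items():
    print(name, tensor.shape)
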
image_encoder.neck.0.weight torch.Size([256, 768, 1, 1])
image_encoder.neck.1.weight torch.Size([256])
image_encoder.neck.1.bias torch.Size([256])
image_encoder.neck.2.weight torch.Size([256, 256, 3, 3])
image_encoder.neck.3.weight torch.Size([256])
image_encoder.neck.3.bias torch.Size([256])
image_encoder.patch_embed.proj.weight torch.Size([768, 3, 16, 16])
image_encoder.patch_embed.proj.bias torch.Size([768])
image_encoder.blocks.0.norm1.weight torch.Size([768])
image_encoder.blocks.0.norm1.bias torch.Size([768])
image_encoder.blocks.0.attn.rel_pos_h torch.Size([27, 64])
image_encoder.blocks.0.attn.rel_pos_w torch.Size([27, 64])
image_encoder.blocks.0.attn.qkv.weight torch.Size([2304, 768])
image_encoder.blocks.0.attn.qkv.bias torch.Size([2304])
image_encoder.blocks.0.attn.proj.weight torch.Size([768, 768])
image_encoder.blocks.0.attn.proj.bias torch.Size([768])
image_encoder.blocks.0.norm2.weight torch.Size([768])
image_encoder.blocks.0.norm2.bias torch.Size([768])
image_encoder.blocks.0.mlp.lin1.weight torch.Size([3072, 768])
image_encoder.blocks.0.mlp.lin1.bias torch.Size([3072])
image_encoder.blocks.0.mlp.lin2.weight torch.Size([768, 3072])
image_encoder.blocks.0.mlp.lin2.bias torch.Size([768])
image_encoder.blocks.1.norm1.weight torch.Size([768])
image_encoder.blocks.1.norm1.bias torch.Size([768])
image_encoder.blocks.1.attn.rel_pos_h torch.Size([27, 64])
image_encoder.blocks.1.attn.rel_pos_w torch.Size([27, 64])
image_encoder.blocks.1.attn.qkv.weight torch.Size([2304, 768])
image_encoder.blocks.1.attn.qkv.bias torch.Size([2304])
image_encoder.blocks.1.attn.proj.weight torch.Size([768, 768])
image_encoder.blocks.1.attn.proj.bias torch.Size([768])
image_encoder.blocks.1.norm2.weight torch.Size([768])
image_encoder.blocks.1.norm2.bias torch.Size([768])
image_encoder.blocks.1.mlp.lin1.weight torch.Size([3072, 768])
image_encoder.blocks.1.mlp.lin1.bias torch.Size([3072])
image_encoder.blocks.1.mlp.lin2.weight torch.Size([768, 3072])
image_encoder.blocks.1.mlp.lin2.bias torch.Size([768])
image_encoder.blocks.2.norm1.weight torch.Size([768])
image_encoder.blocks.2.norm1.bias torch.Size([768])
image_encoder.blocks.2.attn.rel_pos_h torch.Size([127, 64])
image_encoder.blocks.2.attn.rel_pos_w torch.Size([127, 64])
image_encoder.blocks.2.attn.qkv.weight torch.Size([2304, 768])
image_encoder.blocks.2.attn.qkv.bias torch.Size([2304])
image_encoder.blocks.2.attn.proj.weight torch.Size([768, 768])
image_encoder.blocks.2.attn.proj.bias torch.Size([768])
image_encoder.blocks.2.norm2.weight torch.Size([768])
image_encoder.blocks.2.norm2.bias torch.Size([768])
image_encoder.blocks.2.mlp.lin1.weight torch.Size([3072, 768])
image_encoder.blocks.2.mlp.lin1.bias torch.Size([3072])
image_encoder.blocks.2.mlp.lin2.weight torch.Size([768, 3072])
image_encoder.blocks.2.mlp.lin2.bias torch.Size([768])
image_encoder.blocks.3.norm1.weight torch.Size([768])
image_encoder.blocks.3.norm1.bias torch.Size([768])
image_encoder.blocks.3.attn.rel_pos_h torch.Size([27, 64])
image_encoder.blocks.3.attn.rel_pos_w torch.Size([27, 64])
image_encoder.blocks.3.attn.qkv.weight torch.Size([2304, 768])
image_encoder.blocks.3.attn.qkv.bias torch.Size([2304])
image_encoder.blocks.3.attn.proj.weight torch.Size([768, 768])
image_encoder.blocks.3.attn.proj.bias torch.Size([768])
image_encoder.blocks.3.norm2.weight torch.Size([768])
image_encoder.blocks.3.norm2.bias torch.Size([768])
image_encoder.blocks.3.mlp.lin1.weight torch.Size([3072, 768])
image_encoder.blocks.3.mlp.lin1.bias torch.Size([3072])
image_encoder.blocks.3.mlp.lin2.weight torch.Size([768, 3072])
image_encoder.blocks.3.mlp.lin2.bias torch.Size([768])
image_encoder.blocks.4.norm1.weight torch.Size([768])
image_encoder.blocks.4.norm1.bias torch.Size([768])
image_encoder.blocks.4.attn.rel_pos_h torch.Size([27, 64])
image_encoder.blocks.4.attn.rel_pos_w torch.Size([27, 64])
image_encoder.blocks.4.attn.qkv.weight torch.Size([2304, 768])
image_encoder.blocks.4.attn.qkv.bias torch.Size([2304])
image_encoder.blocks.4.attn.proj.weight torch.Size([768, 768])
image_encoder.blocks.4.attn.proj.bias torch.Size([768])
image_encoder.blocks.4.norm2.weight torch.Size([768])
image_encoder.blocks.4.norm2.bias torch.Size([768])
image_encoder.blocks.4.mlp.lin1.weight torch.Size([3072, 768])
image_encoder.blocks.4.mlp.lin1.bias torch.Size([3072])
image_encoder.blocks.4.mlp.lin2.weight torch.Size([768, 3072])
image_encoder.blocks.4.mlp.lin2.bias torch.Size([768])
image_encoder.blocks.5.norm1.weight torch.Size([768])
image_encoder.blocks.5.norm1.bias torch.Size([768])
image_encoder.blocks.5.attn.rel_pos_h torch.Size([127, 64])
image_encoder.blocks.5.attn.rel_pos_w torch.Size([127, 64])
image_encoder.blocks.5.attn.qkv.weight torch.Size([2304, 768])
image_encoder.blocks.5.attn.qkv.bias torch.Size([2304])
image_encoder.blocks.5.attn.proj.weight torch.Size([768, 768])
image_encoder.blocks.5.attn.proj.bias torch.Size([768])
image_encoder.blocks.5.norm2.weight torch.Size([768])
image_encoder.blocks.5.norm2.bias torch.Size([768])
image_encoder.blocks.5.mlp.lin1.weight torch.Size([3072, 768])
image_encoder.blocks.5.mlp.lin1.bias torch.Size([3072])
image_encoder.blocks.5.mlp.lin2.weight torch.Size([768, 3072])
image_encoder.blocks.5.mlp.lin2.bias torch.Size([768])
image_encoder.blocks.6.norm1.weight torch.Size([768])
image_encoder.blocks.6.norm1.bias torch.Size([768])
image_encoder.blocks.6.attn.rel_pos_h torch.Size([27, 64])
image_encoder.blocks.6.attn.rel_pos_w torch.Size([27, 64])
image_encoder.blocks.6.attn.qkv.weight torch.Size([2304, 768])
image_encoder.blocks.6.attn.qkv.bias torch.Size([2304])
image_encoder.blocks.6.attn.proj.weight torch.Size([768, 768])
image_encoder.blocks.6.attn.proj.bias torch.Size([768])
image_encoder.blocks.6.norm2.weight torch.Size([768])
image_encoder.blocks.6.norm2.bias torch.Size([768])
image_encoder.blocks.6.mlp.lin1.weight torch.Size([3072, 768])
image_encoder.blocks.6.mlp.lin1.bias torch.Size([3072])
image_encoder.blocks.6.mlp.lin2.weight torch.Size([768, 3072])
image_encoder.blocks.6.mlp.lin2.bias torch.Size([768])
image_encoder.blocks.7.norm1.weight torch.Size([768])
image_encoder.blocks.7.norm1.bias torch.Size([768])
image_encoder.blocks.7.attn.rel_pos_h torch.Size([27, 64])
image_encoder.blocks.7.attn.rel_pos_w torch.Size([27, 64])
image_encoder.blocks.7.attn.qkv.weight torch.Size([2304, 768])
image_encoder.blocks.7.attn.qkv.bias torch.Size([2304])
image_encoder.blocks.7.attn.proj.weight torch.Size([768, 768])
image_encoder.blocks.7.attn.proj.bias torch.Size([768])
image_encoder.blocks.7.norm2.weight torch.Size([768])
image_encoder.blocks.7.norm2.bias torch.Size([768])
image_encoder.blocks.7.mlp.lin1.weight torch.Size([3072, 768])
image_encoder.blocks.7.mlp.lin1.bias torch.Size([3072])
image_encoder.blocks.7.mlp.lin2.weight torch.Size([768, 3072])
image_encoder.blocks.7.mlp.lin2.bias torch.Size([768])
image_encoder.blocks.8.norm1.weight torch.Size([768])
image_encoder.blocks.8.norm1.bias torch.Size([768])
image_encoder.blocks.8.attn.rel_pos_h torch.Size([127, 64])
image_encoder.blocks.8.attn.rel_pos_w torch.Size([127, 64])
image_encoder.blocks.8.attn.qkv.weight torch.Size([2304, 768])
image_encoder.blocks.8.attn.qkv.bias torch.Size([2304])
image_encoder.blocks.8.attn.proj.weight torch.Size([768, 768])
image_encoder.blocks.8.attn.proj.bias torch.Size([768])
image_encoder.blocks.8.norm2.weight torch.Size([768])
image_encoder.blocks.8.norm2.bias torch.Size([768])
image_encoder.blocks.8.mlp.lin1.weight torch.Size([3072, 768])
image_encoder.blocks.8.mlp.lin1.bias torch.Size([3072])
image_encoder.blocks.8.mlp.lin2.weight torch.Size([768, 3072])
image_encoder.blocks.8.mlp.lin2.bias torch.Size([768])
image_encoder.blocks.9.norm1.weight torch.Size([768])
image_encoder.blocks.9.norm1.bias torch.Size([768])
image_encoder.blocks.9.attn.rel_pos_h torch.Size([27, 64])
image_encoder.blocks.9.attn.rel_pos_w torch.Size([27, 64])
image_encoder.blocks.9.attn.qkv.weight torch.Size([2304, 768])
image_encoder.blocks.9.attn.qkv.bias torch.Size([2304])
image_encoder.blocks.9.attn.proj.weight torch.Size([768, 768])
image_encoder.blocks.9.attn.proj.bias torch.Size([768])
image_encoder.blocks.9.norm2.weight torch.Size([768])
image_encoder.blocks.9.norm2.bias torch.Size([768])
image_encoder.blocks.9.mlp.lin1.weight torch.Size([3072, 768])
image_encoder.blocks.9.mlp.lin1.bias torch.Size([3072])
image_encoder.blocks.9.mlp.lin2.weight torch.Size([768, 3072])
image_encoder.blocks.9.mlp.lin2.bias torch.Size([768])
image_encoder.blocks.10.norm1.weight torch.Size([768])
image_encoder.blocks.10.norm1.bias torch.Size([768])
image_encoder.blocks.10.attn.rel_pos_h torch.Size([27, 64])
image_encoder.blocks.10.attn.rel_pos_w torch.Size([27, 64])
image_encoder.blocks.10.attn.qkv.weight torch.Size([2304, 768])
image_encoder.blocks.10.attn.qkv.bias torch.Size([2304])
image_encoder.blocks.10.attn.proj.weight torch.Size([768, 768])
image_encoder.blocks.10.attn.proj.bias torch.Size([768])
image_encoder.blocks.10.norm2.weight torch.Size([768])
image_encoder.blocks.10.norm2.bias torch.Size([768])
image_encoder.blocks.10.mlp.lin1.weight torch.Size([3072, 768])
image_encoder.blocks.10.mlp.lin1.bias torch.Size([3072])
image_encoder.blocks.10.mlp.lin2.weight torch.Size([768, 3072])
image_encoder.blocks.10.mlp.lin2.bias torch.Size([768])
image_encoder.blocks.11.norm1.weight torch.Size([768])
image_encoder.blocks.11.norm1.bias torch.Size([768])
image_encoder.blocks.11.attn.rel_pos_h torch.Size([127, 64])
image_encoder.blocks.11.attn.rel_pos_w torch.Size([127, 64])
image_encoder.blocks.11.attn.qkv.weight torch.Size([2304, 768])
image_encoder.blocks.11.attn.qkv.bias torch.Size([2304])
image_encoder.blocks.11.attn.proj.weight torch.Size([768, 768])
image_encoder.blocks.11.attn.proj.bias torch.Size([768])
image_encoder.blocks.11.norm2.weight torch.Size([768])
image_encoder.blocks.11.norm2.bias torch.Size([768])
image_encoder.blocks.11.mlp.lin1.weight torch.Size([3072, 768])
image_encoder.blocks.11.mlp.lin1.bias torch.Size([3072])
image_encoder.blocks.11.mlp.lin2.weight torch.Size([768, 3072])
image_encoder.blocks.11.mlp.lin2.bias torch.Size([768])
prompt_encoder.pe_layer.positional_encoding_gaussian_matrix torch.Size([2, 128])
mask_decoder.transformer.layers.0.self_attn.q_proj.weight torch.Size([256, 256])
mask_decoder.transformer.layers.0.self_attn.q_proj.bias torch.Size([256])
mask_decoder.transformer.layers.0.self_attn.k_proj.weight torch.Size([256, 256])
mask_decoder.transformer.layers.0.self_attn.k_proj.bias torch.Size([256])
mask_decoder.transformer.layers.0.self_attn.v_proj.weight torch.Size([256, 256])
mask_decoder.transformer.layers.0.self_attn.v_proj.bias torch.Size([256])
mask_decoder.transformer.layers.0.self_attn.out_proj.weight torch.Size([256, 256])
mask_decoder.transformer.layers.0.self_attn.out_proj.bias torch.Size([256])
mask_decoder.transformer.layers.0.norm1.weight torch.Size([256])
mask_decoder.transformer.layers.0.norm1.bias torch.Size([256])
mask_decoder.transformer.layers.0.cross_attn_token_to_image.q_proj.weight torch.Size([128, 256])
mask_decoder.transformer.layers.0.cross_attn_token_to_image.q_proj.bias torch.Size([128])
mask_decoder.transformer.layers.0.cross_attn_token_to_image.k_proj.weight torch.Size([128, 256])
mask_decoder.transformer.layers.0.cross_attn_token_to_image.k_proj.bias torch.Size([128])
mask_decoder.transformer.layers.0.cross_attn_token_to_image.v_proj.weight torch.Size([128, 256])
mask_decoder.transformer.layers.0.cross_attn_token_to_image.v_proj.bias torch.Size([128])
mask_decoder.transformer.layers.0.cross_attn_token_to_image.out_proj.weight torch.Size([256, 128])
mask_decoder.transformer.layers.0.cross_attn_token_to_image.out_proj.bias torch.Size([256])
mask_decoder.transformer.layers.0.norm2.weight torch.Size([256])
mask_decoder.transformer.layers.0.norm2.bias torch.Size([256])
mask_decoder.transformer.layers.0.mlp.lin1.weight torch.Size([2048, 256])
mask_decoder.transformer.layers.0.mlp.lin1.bias torch.Size([2048])
mask_decoder.transformer.layers.0.mlp.lin2.weight torch.Size([256, 2048])
mask_decoder.transformer.layers.0.mlp.lin2.bias torch.Size([256])
mask_decoder.transformer.layers.0.norm3.weight torch.Size([256])
mask_decoder.transformer.layers.0.norm3.bias torch.Size([256])
mask_decoder.transformer.layers.0.norm4.weight torch.Size([256])
mask_decoder.transformer.layers.0.norm4.bias torch.Size([256])
mask_decoder.transformer.layers.0.cross_attn_image_to_token.q_proj.weight torch.Size([128, 256])
mask_decoder.transformer.layers.0.cross_attn_image_to_token.q_proj.bias torch.Size([128])
mask_decoder.transformer.layers.0.cross_attn_image_to_token.k_proj.weight torch.Size([128, 256])
mask_decoder.transformer.layers.0.cross_attn_image_to_token.k_proj.bias torch.Size([128])
mask_decoder.transformer.layers.0.cross_attn_image_to_token.v_proj.weight torch.Size([128, 256])
mask_decoder.transformer.layers.0.cross_attn_image_to_token.v_proj.bias torch.Size([128])
mask_decoder.transformer.layers.0.cross_attn_image_to_token.out_proj.weight torch.Size([256, 128])
mask_decoder.transformer.layers.0.cross_attn_image_to_token.out_proj.bias torch.Size([256])
mask_decoder.transformer.layers.1.self_attn.q_proj.weight torch.Size([256, 256])
mask_decoder.transformer.layers.1.self_attn.q_proj.bias torch.Size([256])
mask_decoder.transformer.layers.1.self_attn.k_proj.weight torch.Size([256, 256])
mask_decoder.transformer.layers.1.self_attn.k_proj.bias torch.Size([256])
mask_decoder.transformer.layers.1.self_attn.v_proj.weight torch.Size([256, 256])
mask_decoder.transformer.layers.1.self_attn.v_proj.bias torch.Size([256])
mask_decoder.transformer.layers.1.self_attn.out_proj.weight torch.Size([256, 256])
mask_decoder.transformer.layers.1.self_attn.out_proj.bias torch.Size([256])
mask_decoder.transformer.layers.1.norm1.weight torch.Size([256])
mask_decoder.transformer.layers.1.norm1.bias torch.Size([256])
mask_decoder.transformer.layers.1.cross_attn_token_to_image.q_proj.weight torch.Size([128, 256])
mask_decoder.transformer.layers.1.cross_attn_token_to_image.q_proj.bias torch.Size([128])
mask_decoder.transformer.layers.1.cross_attn_token_to_image.k_proj.weight torch.Size([128, 256])
mask_decoder.transformer.layers.1.cross_attn_token_to_image.k_proj.bias torch.Size([128])
mask_decoder.transformer.layers.1.cross_attn_token_to_image.v_proj.weight torch.Size([128, 256])
mask_decoder.transformer.layers.1.cross_attn_token_to_image.v_proj.bias torch.Size([128])
mask_decoder.transformer.layers.1.cross_attn_token_to_image.out_proj.weight torch.Size([256, 128])
mask_decoder.transformer.layers.1.cross_attn_token_to_image.out_proj.bias torch.Size([256])
mask_decoder.transformer.layers.1.norm2.weight torch.Size([256])
mask_decoder.transformer.layers.1.norm2.bias torch.Size([256])
mask_decoder.transformer.layers.1.mlp.lin1.weight torch.Size([2048, 256])
mask_decoder.transformer.layers.1.mlp.lin1.bias torch.Size([2048])
mask_decoder.transformer.layers.1.mlp.lin2.weight torch.Size([256, 2048])
mask_decoder.transformer.layers.1.mlp.lin2.bias torch.Size([256])
mask_decoder.transformer.layers.1.norm3.weight torch.Size([256])
mask_decoder.transformer.layers.1.norm3.bias torch.Size([256])
mask_decoder.transformer.layers.1.norm4.weight torch.Size([256])
mask_decoder.transformer.layers.1.norm4.bias torch.Size([256])
mask_decoder.transformer.layers.1.cross_attn_image_to_token.q_proj.weight torch.Size([128, 256])
mask_decoder.transformer.layers.1.cross_attn_image_to_token.q_proj.bias torch.Size([128])
mask_decoder.transformer.layers.1.cross_attn_image_to_token.k_proj.weight torch.Size([128, 256])
mask_decoder.transformer.layers.1.cross_attn_image_to_token.k_proj.bias torch.Size([128])
mask_decoder.transformer.layers.1.cross_attn_image_to_token.v_proj.weight torch.Size([128, 256])
mask_decoder.transformer.layers.1.cross_attn_image_to_token.v_proj.bias torch.Size([128])
mask_decoder.transformer.layers.1.cross_attn_image_to_token.out_proj.weight torch.Size([256, 128])
mask_decoder.transformer.layers.1.cross_attn_image_to_token.out_proj.bias torch.Size([256])
mask_decoder.transformer.final_attn_token_to_image.q_proj.weight torch.Size([128, 256])
mask_decoder.transformer.final_attn_token_to_image.q_proj.bias torch.Size([128])
mask_decoder.transformer.final_attn_token_to_image.k_proj.weight torch.Size([128, 256])
mask_decoder.transformer.final_attn_token_to_image.k_proj.bias torch.Size([128])
mask_decoder.transformer.final_attn_token_to_image.v_proj.weight torch.Size([128, 256])
mask_decoder.transformer.final_attn_token_to_image.v_proj.bias torch.Size([128])
mask_decoder.transformer.final_attn_token_to_image.out_proj.weight torch.Size([256, 128])
mask_decoder.transformer.final_attn_token_to_image.out_proj.bias torch.Size([256])
mask_decoder.transformer.norm_final_attn.weight torch.Size([256])
mask_decoder.transformer.norm_final_attn.bias torch.Size([256])
prompt_encoder.point_embeddings.0.weight torch.Size([1, 256])
prompt_encoder.point_embeddings.1.weight torch.Size([1, 256])
prompt_encoder.point_embeddings.2.weight torch.Size([1, 256])
prompt_encoder.point_embeddings.3.weight torch.Size([1, 256])
prompt_encoder.not_a_point_embed.weight torch.Size([1, 256])
mask_decoder.output_upscaling.0.weight torch.Size([256, 64, 2, 2])
mask_decoder.output_upscaling.0.bias torch.Size([64])
mask_decoder.output_upscaling.1.weight torch.Size([64])
mask_decoder.output_upscaling.1.bias torch.Size([64])
mask_decoder.output_upscaling.3.weight torch.Size([64, 32, 2, 2])
mask_decoder.output_upscaling.3.bias torch.Size([32])
mask_decoder.output_hypernetworks_mlps.0.layers.0.weight torch.Size([256, 256])
mask_decoder.output_hypernetworks_mlps.0.layers.0.bias torch.Size([256])
mask_decoder.output_hypernetworks_mlps.0.layers.1.weight torch.Size([256, 256])
mask_decoder.output_hypernetworks_mlps.0.layers.1.bias torch.Size([256])
mask_decoder.output_hypernetworks_mlps.0.layers.2.weight torch.Size([32, 256])
mask_decoder.output_hypernetworks_mlps.0.layers.2.bias torch.Size([32])
mask_decoder.output_hypernetworks_mlps.1.layers.0.weight torch.Size([256, 256])
mask_decoder.output_hypernetworks_mlps.1.layers.0.bias torch.Size([256])
mask_decoder.output_hypernetworks_mlps.1.layers.1.weight torch.Size([256, 256])
mask_decoder.output_hypernetworks_mlps.1.layers.1.bias torch.Size([256])
mask_decoder.output_hypernetworks_mlps.1.layers.2.weight torch.Size([32, 256])
mask_decoder.output_hypernetworks_mlps.1.layers.2.bias torch.Size([32])
mask_decoder.output_hypernetworks_mlps.2.layers.0.weight torch.Size([256, 256])
mask_decoder.output_hypernetworks_mlps.2.layers.0.bias torch.Size([256])
mask_decoder.output_hypernetworks_mlps.2.layers.1.weight torch.Size([256, 256])
mask_decoder.output_hypernetworks_mlps.2.layers.1.bias torch.Size([256])
mask_decoder.output_hypernetworks_mlps.2.layers.2.weight torch.Size([32, 256])
mask_decoder.output_hypernetworks_mlps.2.layers.2.bias torch.Size([32])
mask_decoder.output_hypernetworks_mlps.3.layers.0.weight torch.Size([256, 256])
mask_decoder.output_hypernetworks_mlps.3.layers.0.bias torch.Size([256])
mask_decoder.output_hypernetworks_mlps.3.layers.1.weight torch.Size([256, 256])
mask_decoder.output_hypernetworks_mlps.3.layers.1.bias torch.Size([256])
mask_decoder.output_hypernetworks_mlps.3.layers.2.weight torch.Size([32, 256])
mask_decoder.output_hypernetworks_mlps.3.layers.2.bias torch.Size([32])
prompt_encoder.mask_downscaling.0.weight torch.Size([4, 1, 2, 2])
prompt_encoder.mask_downscaling.0.bias torch.Size([4])
prompt_encoder.mask_downscaling.1.weight torch.Size([4])
prompt_encoder.mask_downscaling.1.bias torch.Size([4])
prompt_encoder.mask_downscaling.3.weight torch.Size([16, 4, 2, 2])
prompt_encoder.mask_downscaling.3.bias torch.Size([16])
prompt_encoder.mask_downscaling.4.weight torch.Size([16])
prompt_encoder.mask_downscaling.4.bias torch.Size([16])
prompt_encoder.mask_downscaling.6.weight torch.Size([256, 16, 1, 1])
prompt_encoder.mask_downscaling.6.bias torch.Size([256])
prompt_encoder.no_mask_embed.weight torch.Size([1, 256])
mask_decoder.iou_prediction_head.layers.0.weight torch.Size([256, 256])
mask_decoder.iou_prediction_head.layers.0.bias torch.Size([256])
mask_decoder.iou_prediction_head.layers.1.weight torch.Size([256, 256])
mask_decoder.iou_prediction_head.layers.1.bias torch.Size([256])
mask_decoder.iou_prediction_head.layers.2.weight torch.Size([4, 256])
mask_decoder.iou_prediction_head.layers.2.bias torch.Size([4])
mask_decoder.iou_token.weight torch.Size([1, 256])
mask_decoder.mask_tokens.weight torch.Size([4, 256])
image_encoder.pos_embed torch.Size([1, 64, 64, 768])
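
For the ggml port, the shapes above line up with the ViT-B hyperparameters defined in the linked build_sam.py. A hedged summary follows (values inferred from the listing and the linked code, so worth double-checking against the source):

# ViT-B SAM hyperparameters, cross-checked against the tensor shapes above.
# Treat these as a working assumption until verified against build_sam.py.
hparams = {
    "img_size": 1024,            # pos_embed is [1, 64, 64, 768] and patch size is 16 -> 64 * 16
    "patch_size": 16,            # patch_embed.proj.weight is [768, 3, 16, 16]
    "enc_embed_dim": 768,        # width of the encoder norms / qkv input
    "enc_depth": 12,             # blocks.0 .. blocks.11
    "enc_num_heads": 12,         # head dim 64 (rel_pos_* second dim) -> 768 / 64
    "enc_global_attn_indexes": [2, 5, 8, 11],  # blocks whose rel_pos_* has 127 = 2*64 - 1 rows (global attention)
    "window_size": 14,           # windowed blocks have rel_pos_* with 27 = 2*14 - 1 rows
    "prompt_embed_dim": 256,     # neck output channels / mask decoder width
}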

ggerganov · Apr 09 '23 08:04