Evaluation script to reproduce numbers in SAM-HQ paper

ankitgoyalumd opened this issue 1 year ago · 4 comments

Great work! I am looking for a script that reproduces the IoU and boundary IoU numbers. I looked into the train folder, and there is an evaluation example shown there. However, it uses the checkpoint sam_vit_l_0b3195.pth.

The masks predicted from this checkpoint are of extremely poor quality, leading me to believe I should have been using sam_hq_vit_l.pth, shown in the main README of the repo. However, when I pass sam_hq_vit_l.pth to the checkpoint argument of train.py along with the --eval flag, it fails to load the checkpoint and errors out because the keys do not match.

Please advise how I can reproduce the results.
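For reference, mask IoU and boundary IoU can be computed roughly as in the sketch below, assuming binary NumPy masks. The dilation ratio and helper names are illustrative and not the repo's exact evaluation code:

```python
# Simplified sketch of mask IoU and boundary IoU for binary (H, W) masks.
import numpy as np
import cv2


def mask_iou(pred: np.ndarray, gt: np.ndarray) -> float:
    """IoU between two binary masks with values in {0, 1}."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    inter = np.logical_and(pred, gt).sum()
    union = np.logical_or(pred, gt).sum()
    return float(inter) / max(float(union), 1.0)


def mask_to_boundary(mask: np.ndarray, dilation_ratio: float = 0.02) -> np.ndarray:
    """Boundary region of a binary mask: pixels within distance d of the contour,
    with d proportional to the image diagonal (as in the Boundary IoU definition)."""
    h, w = mask.shape
    d = max(1, int(round(dilation_ratio * np.sqrt(h ** 2 + w ** 2))))
    kernel = np.ones((3, 3), dtype=np.uint8)
    # Zero-pad so that mask pixels on the image border also count as boundary.
    padded = cv2.copyMakeBorder(mask.astype(np.uint8), 1, 1, 1, 1,
                                cv2.BORDER_CONSTANT, value=0)
    eroded = cv2.erode(padded, kernel, iterations=d)[1:h + 1, 1:w + 1]
    return mask.astype(np.uint8) - eroded  # mask minus its interior


def boundary_iou(pred: np.ndarray, gt: np.ndarray) -> float:
    """IoU restricted to the boundary regions of the two masks."""
    return mask_iou(mask_to_boundary(pred), mask_to_boundary(gt))
```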

ankitgoyalumd avatar Jul 22 '23 03:07 ankitgoyalumd

Hi, the evaluation script in the train folder is:

python -m torch.distributed.launch --nproc_per_node=1 train.py --checkpoint ./pretrained_checkpoint/sam_vit_l_0b3195.pth --model-type vit_l --output work_dirs/hq_sam_l --eval --restore-model work_dirs/hq_sam_l/epoch_11.pth

In this script, we load sam_vit_l_0b3195.pth for the encoder output and load the additional parameters trained by us with the --restore-model argument. During training, we only learn a small number of parameters and save them; for evaluation on the HQ datasets, we only need to load this group of parameters. An example checkpoint for --restore-model can be found here: https://drive.google.com/file/d/1Ml41Wwl5QIKjoIyRYp2GgmONovKVJ0-C/view?usp=sharing. You can also train it yourself.
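In other words, the --eval run combines two files with different roles: the frozen SAM ViT-L checkpoint for the encoders, and a small decoder-only checkpoint for the HQ parameters. A minimal sketch of the idea, assuming the --restore-model file is a plain state dict (the loading code here is illustrative, not the exact code in train.py):

```python
# Illustrative sketch of the two checkpoints used by --eval.
import torch
from segment_anything import sam_model_registry  # model builder shipped with the repo

# 1) Full (frozen) SAM ViT-L weights: provide the image/prompt encoder features.
sam = sam_model_registry["vit_l"](
    checkpoint="./pretrained_checkpoint/sam_vit_l_0b3195.pth"
)

# 2) Decoder-only HQ weights saved during training and passed to --restore-model.
hq_decoder_state = torch.load("work_dirs/hq_sam_l/epoch_11.pth", map_location="cpu")
print(list(hq_decoder_state.keys())[:5])
# Expect MaskDecoderHQ keys such as hf_token.weight or compress_vit_feat.0.weight,
# with no "image_encoder." / "mask_decoder." prefixes from a full-model checkpoint.
```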

ymq2017 avatar Jul 24 '23 22:07 ymq2017

Hi,

Thanks for the response. Can you please confirm whether the published results can be reproduced using this checkpoint? I tried --restore-model ./pretrained_checkpoint/sam_hq_vit_l.pth and got the following errors:

RuntimeError: Error(s) in loading state_dict for MaskDecoderHQ: Missing key(s) in state_dict: "transformer.layers.0.self_attn.q_proj.weight", "transformer.layers.0.self_attn.q_proj.bias", "transformer.layers.0.self_attn.k_proj.weight", "transformer.layers.0.self_attn.k_proj.bias", "transformer.layers.0.self_attn.v_proj.weight", "transformer.layers.0.self_attn.v_proj.bias", "transformer.layers.0.self_attn.out_proj.weight", "transformer.layers.0.self_attn.out_proj.bias", "transformer.layers.0.norm1.weight", "transformer.layers.0.norm1.bias", "transformer.layers.0.cross_attn_token_to_image.q_proj.weight", "transformer.layers.0.cross_attn_token_to_image.q_proj.bias", "transformer.layers.0.cross_attn_token_to_image.k_proj.weight", "transformer.layers.0.cross_attn_token_to_image.k_proj.bias", "transformer.layers.0.cross_attn_token_to_image.v_proj.weight", "transformer.layers.0.cross_attn_token_to_image.v_proj.bias", "transformer.layers.0.cross_attn_token_to_image.out_proj.weight", "transformer.layers.0.cross_attn_token_to_image.out_proj.bias", "transformer.layers.0.norm2.weight", "transformer.layers.0.norm2.bias", "transformer.layers.0.mlp.lin1.weight", "transformer.layers.0.mlp.lin1.bias", "transformer.layers.0.mlp.lin2.weight", "transformer.layers.0.mlp.lin2.bias", "transformer.layers.0.norm3.weight", "transformer.layers.0.norm3.bias", "transformer.layers.0.norm4.weight", "transformer.layers.0.norm4.bias", "transformer.layers.0.cross_attn_image_to_token.q_proj.weight", "transformer.layers.0.cross_attn_image_to_token.q_proj.bias", "transformer.layers.0.cross_attn_image_to_token.k_proj.weight", "transformer.layers.0.cross_attn_image_to_token.k_proj.bias", "transformer.layers.0.cross_attn_image_to_token.v_proj.weight", "transformer.layers.0.cross_attn_image_to_token.v_proj.bias", "transformer.layers.0.cross_attn_image_to_token.out_proj.weight", "transformer.layers.0.cross_attn_image_to_token.out_proj.bias", "transformer.layers.1.self_attn.q_proj.weight", "transformer.layers.1.self_attn.q_proj.bias", "transformer.layers.1.self_attn.k_proj.weight", "transformer.layers.1.self_attn.k_proj.bias", "transformer.layers.1.self_attn.v_proj.weight", "transformer.layers.1.self_attn.v_proj.bias", "transformer.layers.1.self_attn.out_proj.weight", "transformer.layers.1.self_attn.out_proj.bias", "transformer.layers.1.norm1.weight", "transformer.layers.1.norm1.bias", "transformer.layers.1.cross_attn_token_to_image.q_proj.weight", "transformer.layers.1.cross_attn_token_to_image.q_proj.bias", "transformer.layers.1.cross_attn_token_to_image.k_proj.weight", "transformer.layers.1.cross_attn_token_to_image.k_proj.bias", "transformer.layers.1.cross_attn_token_to_image.v_proj.weight", "transformer.layers.1.cross_attn_token_to_image.v_proj.bias", "transformer.layers.1.cross_attn_token_to_image.out_proj.weight", "transformer.layers.1.cross_attn_token_to_image.out_proj.bias", "transformer.layers.1.norm2.weight", "transformer.layers.1.norm2.bias", "transformer.layers.1.mlp.lin1.weight", "transformer.layers.1.mlp.lin1.bias", "transformer.layers.1.mlp.lin2.weight", "transformer.layers.1.mlp.lin2.bias", "transformer.layers.1.norm3.weight", "transformer.layers.1.norm3.bias", "transformer.layers.1.norm4.weight", "transformer.layers.1.norm4.bias", "transformer.layers.1.cross_attn_image_to_token.q_proj.weight", "transformer.layers.1.cross_attn_image_to_token.q_proj.bias", "transformer.layers.1.cross_attn_image_to_token.k_proj.weight", "transformer.layers.1.cross_attn_image_to_token.k_proj.bias", 
"transformer.layers.1.cross_attn_image_to_token.v_proj.weight", "transformer.layers.1.cross_attn_image_to_token.v_proj.bias", "transformer.layers.1.cross_attn_image_to_token.out_proj.weight", "transformer.layers.1.cross_attn_image_to_token.out_proj.bias", "transformer.final_attn_token_to_image.q_proj.weight", "transformer.final_attn_token_to_image.q_proj.bias", "transformer.final_attn_token_to_image.k_proj.weight", "transformer.final_attn_token_to_image.k_proj.bias", "transformer.final_attn_token_to_image.v_proj.weight", "transformer.final_attn_token_to_image.v_proj.bias", "transformer.final_attn_token_to_image.out_proj.weight", "transformer.final_attn_token_to_image.out_proj.bias", "transformer.norm_final_attn.weight", "transformer.norm_final_attn.bias", "iou_token.weight", "mask_tokens.weight", "output_upscaling.0.weight", "output_upscaling.0.bias", "output_upscaling.1.weight", "output_upscaling.1.bias", "output_upscaling.3.weight", "output_upscaling.3.bias", "output_hypernetworks_mlps.0.layers.0.weight", "output_hypernetworks_mlps.0.layers.0.bias", "output_hypernetworks_mlps.0.layers.1.weight", "output_hypernetworks_mlps.0.layers.1.bias", "output_hypernetworks_mlps.0.layers.2.weight", "output_hypernetworks_mlps.0.layers.2.bias", "output_hypernetworks_mlps.1.layers.0.weight", "output_hypernetworks_mlps.1.layers.0.bias", "output_hypernetworks_mlps.1.layers.1.weight", "output_hypernetworks_mlps.1.layers.1.bias", "output_hypernetworks_mlps.1.layers.2.weight", "output_hypernetworks_mlps.1.layers.2.bias", "output_hypernetworks_mlps.2.layers.0.weight", "output_hypernetworks_mlps.2.layers.0.bias", "output_hypernetworks_mlps.2.layers.1.weight", "output_hypernetworks_mlps.2.layers.1.bias", "output_hypernetworks_mlps.2.layers.2.weight", "output_hypernetworks_mlps.2.layers.2.bias", "output_hypernetworks_mlps.3.layers.0.weight", "output_hypernetworks_mlps.3.layers.0.bias", "output_hypernetworks_mlps.3.layers.1.weight", "output_hypernetworks_mlps.3.layers.1.bias", "output_hypernetworks_mlps.3.layers.2.weight", "output_hypernetworks_mlps.3.layers.2.bias", "iou_prediction_head.layers.0.weight", "iou_prediction_head.layers.0.bias", "iou_prediction_head.layers.1.weight", "iou_prediction_head.layers.1.bias", "iou_prediction_head.layers.2.weight", "iou_prediction_head.layers.2.bias", "hf_token.weight", "hf_mlp.layers.0.weight", "hf_mlp.layers.0.bias", "hf_mlp.layers.1.weight", "hf_mlp.layers.1.bias", "hf_mlp.layers.2.weight", "hf_mlp.layers.2.bias", "compress_vit_feat.0.weight", "compress_vit_feat.0.bias", "compress_vit_feat.1.weight", "compress_vit_feat.1.bias", "compress_vit_feat.3.weight", "compress_vit_feat.3.bias", "embedding_encoder.0.weight", "embedding_encoder.0.bias", "embedding_encoder.1.weight", "embedding_encoder.1.bias", "embedding_encoder.3.weight", "embedding_encoder.3.bias", "embedding_maskfeature.0.weight", "embedding_maskfeature.0.bias", "embedding_maskfeature.1.weight", "embedding_maskfeature.1.bias", "embedding_maskfeature.3.weight", "embedding_maskfeature.3.bias". Unexpected key(s) in state_dict: "image_encoder.neck.0.weight", "image_encoder.neck.1.weight", "image_encoder.neck.1.bias", "image_encoder.neck.2.weight", "image_encoder.neck.3.weight", "image_encoder.neck.3.bias", "image_encoder.patch_embed.proj.weight", "image_encoder.patch_embed.proj.bias", "image_encoder.blocks.0.norm1.weight", "image_encoder.blocks.0.norm1.bias", "image_encoder.blocks.0.attn.rel_pos_h", "image_encoder.blocks.0.attn.rel_pos_w", "image_encoder.blocks.0.attn.qkv.weight", "image_encod

Please advise which checkpoint to use to reproduce the results.

Thanks, Ankit

ankitgoyalumd avatar Jul 30 '23 21:07 ankitgoyalumd

Hi, when evaluating on the four HQ datasets, we only need to load the mask_decoder part of HQ-SAM. We do it this way because it saves storage space during training. For example, with --restore-model work_dirs/hq_sam_l/epoch_11.pth, the file epoch_11.pth is the decoder part of sam_hq_vit_l.pth. An example checkpoint for --restore-model can be found at the link above. You can also train it yourself.
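If you only have the merged sam_hq_vit_l.pth, one possible workaround is to extract just its decoder weights into a file that --restore-model can load. A minimal sketch, assuming the full checkpoint is a flat state dict whose decoder entries carry a "mask_decoder." prefix (inferred from the key names in the error above); the output file name is an example:

```python
# Sketch: build a decoder-only file from the merged sam_hq_vit_l.pth checkpoint.
# The "mask_decoder." prefix is an assumption; verify it against your checkpoint.
import torch

full_state = torch.load("./pretrained_checkpoint/sam_hq_vit_l.pth", map_location="cpu")

prefix = "mask_decoder."
decoder_state = {
    key[len(prefix):]: value          # strip the prefix so keys match MaskDecoderHQ
    for key, value in full_state.items()
    if key.startswith(prefix)
}

print(f"kept {len(decoder_state)} of {len(full_state)} tensors")
torch.save(decoder_state, "work_dirs/hq_sam_l/hq_decoder_only.pth")
# Then pass --restore-model work_dirs/hq_sam_l/hq_decoder_only.pth to train.py --eval.
```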

ymq2017 avatar Jul 31 '23 23:07 ymq2017

I faced this error too. I used the command:

python3 -m torch.distributed.launch --nproc_per_node=1 train.py --checkpoint ./pretrained_checkpoint/sam_vit_l_0b3195.pth --model-type vit_l --output work_dirs/hq_sam_l/ --eval --restore-model work_dirs/hq_sam_l/sam_hq_epoch_11.pth

where sam_hq_epoch_11.pth was the final checkpoint that the training had saved. However, when I tried any of the earlier checkpoints, it worked as expected.

vishakhalall avatar Aug 17 '23 04:08 vishakhalall