
Correlation Between Fine-Tuning Model Files and Demo Pretrained Models

Open RyosukeSakaguchi opened this issue 1 year ago • 6 comments

Hi! Thank you for the amazing repository!

For fine-tuning, you provide pre-trained models where each model type comes with two files. For instance, for the model type "vit_l", there are "sam_vit_l_0b3195.pth" and "sam_vit_l_maskdecoder.pth". It appears that these two files are merged during fine-tuning to produce a single model file. https://github.com/SysCV/sam-hq/tree/main/train#expected-checkpoint

However, in the demo, one model file is provided for each model type as a pretrained model. https://github.com/SysCV/sam-hq#model-checkpoints

I assume that the "ViT-L HQ-SAM model" used in the demo corresponds to the pre-trained models "sam_vit_l_0b3195.pth" and "sam_vit_l_maskdecoder.pth" used for fine-tuning. Is my assumption correct?
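
For reference, a quick sketch of how one could check what each of the two files contains, by grouping the state-dict keys by top-level module name (this assumes both files were downloaded into pretrained_checkpoint/):

import torch
from collections import Counter

# Group the keys of each checkpoint by their top-level module name
for path in ["pretrained_checkpoint/sam_vit_l_0b3195.pth",
             "pretrained_checkpoint/sam_vit_l_maskdecoder.pth"]:
    state_dict = torch.load(path, map_location="cpu")
    prefixes = Counter(key.split(".")[0] for key in state_dict.keys())
    print(path, dict(prefixes))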

RyosukeSakaguchi avatar Oct 11 '23 01:10 RyosukeSakaguchi

Hi, thanks for your interest in our work!

Yes, we have two pre-trained checkpoints for the initialization of training, and they are merged at the end of training in this line. The merged checkpoint is the HQ-SAM model used in the demo.

We split it into two parts just to save storage space: since we only add learnable parameters in the decoder, we only need to save the lightweight learned decoder parameters at every epoch.
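
In case it helps, here is a minimal sketch of what the merge amounts to (the decoder checkpoint path is hypothetical; the repository's own merge code is the line referenced above): the learned decoder weights are written back into the base SAM state dict under the "mask_decoder." prefix.

import torch

# Sketch of the merge step, assuming a base SAM checkpoint and a saved
# lightweight decoder checkpoint (the decoder path here is hypothetical)
sam_ckpt = torch.load("pretrained_checkpoint/sam_vit_l_0b3195.pth", map_location="cpu")
hq_decoder = torch.load("work_dirs/hq_decoder_latest.pth", map_location="cpu")

merged = dict(sam_ckpt)
for key, value in hq_decoder.items():
    # decoder weights live under the "mask_decoder." prefix in the merged checkpoint
    merged["mask_decoder." + key] = value

torch.save(merged, "pretrained_checkpoint/sam_hq_vit_l.pth")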

ymq2017 avatar Oct 12 '23 21:10 ymq2017

Thank you for your reply! It's very helpful.

What kind of data were the two pre-trained checkpoints trained on?

RyosukeSakaguchi avatar Oct 13 '23 03:10 RyosukeSakaguchi

They are trained on the SA-1B dataset proposed in SAM.

ymq2017 avatar Oct 13 '23 22:10 ymq2017

Thank you! So how can we fine-tune with custom data, using the following checkpoints as pre-trained models? https://github.com/SysCV/sam-hq#model-checkpoints

The models provided here are merged checkpoints, so I think we need to split the merged model to fine-tune it. If my understanding is correct, could you provide the code to split the merged model?

RyosukeSakaguchi avatar Oct 21 '23 12:10 RyosukeSakaguchi

I am experimenting to determine whether fine-tuning is feasible with model files generated by reversing that merge procedure:

import torch

sam_hq = torch.load("pretrained_checkpoint/sam_hq_vit_l.pth", map_location=torch.device('cpu'))

hq_decoder = {}

# Collect every key belonging to the mask decoder in the merged checkpoint
mask_decoder_keys = []
for key in sam_hq.keys():
    if key.startswith('mask_decoder.'):
        mask_decoder_keys.append(key)

# Move the decoder weights into their own dict (without the "mask_decoder." prefix)
# and drop them from the remaining SAM state dict
for mask_decoder_key in mask_decoder_keys:
    hq_decoder[mask_decoder_key.replace("mask_decoder.", "")] = sam_hq[mask_decoder_key]
    sam_hq.pop(mask_decoder_key)

torch.save(sam_hq, "pretrained_checkpoint/sam_vit_l.pth")
torch.save(hq_decoder, "pretrained_checkpoint/sam_vit_l_maskdecoder.pth")

RyosukeSakaguchi avatar Oct 23 '23 07:10 RyosukeSakaguchi

Hi! After some trial and error, training succeeded when I split the checkpoint as follows. The parameters listed in no_need_keys are present in sam_hq_vit_l.pth, yet they appear to be unnecessary for both sam_vit_l.pth and sam_vit_l_maskdecoder.pth: if I include them in either state dict, I get an "Unexpected key(s) in state_dict" error, whereas training succeeds without errors once I delete them. What kind of parameters are these no_need_keys?

import torch

# Keys that exist in sam_hq_vit_l.pth but are rejected as "Unexpected key(s)"
# by both sam_vit_l.pth and sam_vit_l_maskdecoder.pth
no_need_keys = [
    "hf_token.weight",
    "hf_mlp.layers.0.weight", "hf_mlp.layers.0.bias",
    "hf_mlp.layers.1.weight", "hf_mlp.layers.1.bias",
    "hf_mlp.layers.2.weight", "hf_mlp.layers.2.bias",
    "compress_vit_feat.0.weight", "compress_vit_feat.0.bias",
    "compress_vit_feat.1.weight", "compress_vit_feat.1.bias",
    "compress_vit_feat.3.weight", "compress_vit_feat.3.bias",
    "embedding_encoder.0.weight", "embedding_encoder.0.bias",
    "embedding_encoder.1.weight", "embedding_encoder.1.bias",
    "embedding_encoder.3.weight", "embedding_encoder.3.bias",
    "embedding_maskfeature.0.weight", "embedding_maskfeature.0.bias",
    "embedding_maskfeature.1.weight", "embedding_maskfeature.1.bias",
    "embedding_maskfeature.3.weight", "embedding_maskfeature.3.bias",
]

sam_hq = torch.load("pretrained_checkpoint/sam_hq_vit_l.pth", map_location=torch.device('cpu'))

hq_decoder = {}

# Keep only the decoder keys that are not in no_need_keys
mask_decoder_keys = []
for key in sam_hq.keys():
    if key.startswith('mask_decoder.') and key.replace("mask_decoder.", "") not in no_need_keys:
        mask_decoder_keys.append(key)

# Copy those decoder weights (without the "mask_decoder." prefix) into their own dict
for mask_decoder_key in mask_decoder_keys:
    hq_decoder[mask_decoder_key.replace("mask_decoder.", "")] = sam_hq[mask_decoder_key]

# Remove the no_need_keys from the full checkpoint
for no_need_key in no_need_keys:
    sam_hq.pop(f"mask_decoder.{no_need_key}")

torch.save(sam_hq, "pretrained_checkpoint/sam_vit_l.pth")
torch.save(hq_decoder, "pretrained_checkpoint/sam_vit_l_maskdecoder.pth")
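
As a quick sanity check on the split (reusing the no_need_keys list and file paths above), re-merging the two output files should reproduce sam_hq_vit_l.pth except for exactly those keys:

# Re-merge the two split files and compare against the original merged checkpoint
sam_part = torch.load("pretrained_checkpoint/sam_vit_l.pth", map_location="cpu")
decoder_part = torch.load("pretrained_checkpoint/sam_vit_l_maskdecoder.pth", map_location="cpu")
original = torch.load("pretrained_checkpoint/sam_hq_vit_l.pth", map_location="cpu")

remerged = dict(sam_part)
for key, value in decoder_part.items():
    remerged["mask_decoder." + key] = value

# The only keys that should be missing are the no_need_keys, and all shared weights should match
missing = set(original.keys()) - set(remerged.keys())
assert missing == {f"mask_decoder.{k}" for k in no_need_keys}
assert all(torch.equal(original[k], remerged[k]) for k in remerged)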

RyosukeSakaguchi avatar Oct 23 '23 11:10 RyosukeSakaguchi