Correlation Between Fine-Tuning Model Files and Demo Pretrained Models
Hi! Thank you for the amazing repository!
For the fine-tuning process, you provide pre-trained models where each model type comes with two model files. For instance, for the model type "vit_l", "sam_vit_l_0b3195.pth" and "sam_vit_l_maskdecoder.pth" are provided as pre-trained models, and these two files appear to be merged through fine-tuning into a single model file. https://github.com/SysCV/sam-hq/tree/main/train#expected-checkpoint
However, in the demo, one model file is provided for each model type as a pretrained model. https://github.com/SysCV/sam-hq#model-checkpoints
I assume that the "ViT-L HQ-SAM model" used in the demo corresponds to the pre-trained models "sam_vit_l_0b3195.pth" and "sam_vit_l_maskdecoder.pth" used for fine-tuning. Is my assumption correct?
Hi, thanks for your interest in our work!
Yes, we use two pre-trained checkpoints to initialize training, and they are merged at the end of training in this line. The merged checkpoint is the HQ-SAM model used in the demo.
We split it into two parts just to save storage space: since we only add learnable parameters in the decoder, we only need to save the learned lightweight decoder parameters at every epoch.
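Conceptually, the merge is just a state-dict update. Below is a minimal sketch of that step, assuming the saved decoder checkpoint stores its weights without the "mask_decoder." prefix; the decoder file name is hypothetical and this is not the repository's exact merge code.

import torch

# Sketch of the end-of-training merge (assumed layout, not the repo's exact code):
# the frozen SAM weights plus the learned decoder weights become one HQ-SAM state dict.
sam_ckpt = torch.load("pretrained_checkpoint/sam_vit_l_0b3195.pth", map_location="cpu")
hq_decoder = torch.load("hq_decoder_final.pth", map_location="cpu")  # hypothetical per-epoch output file

merged = dict(sam_ckpt)
for key, value in hq_decoder.items():
    # Decoder weights are stored without the "mask_decoder." prefix,
    # so restore it when writing them into the full model.
    merged[f"mask_decoder.{key}"] = value

torch.save(merged, "pretrained_checkpoint/sam_hq_vit_l.pth")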
Thank you for your reply! It's very helpful.
What kind of data were the two pre-trained checkpoints trained on?
They are trained on the SA-1B dataset proposed in SAM.
Thank you! So, how can we use the following checkpoints as the pre-trained model for fine-tuning on custom data? https://github.com/SysCV/sam-hq#model-checkpoints
The models provided there are merged checkpoints, so I think we need to split the merged model before fine-tuning it. If my understanding is correct, could you provide the code to split the merged model?
I am experimenting to see whether fine-tuning is feasible with model files generated by reversing this merge procedure:
import torch

# Load the merged HQ-SAM checkpoint and split it back into the two files
# expected by the training code.
sam_hq = torch.load("pretrained_checkpoint/sam_hq_vit_l.pth", map_location=torch.device("cpu"))
hq_decoder = {}
mask_decoder_keys = []
for key in sam_hq.keys():
    if key.startswith("mask_decoder."):
        mask_decoder_keys.append(key)

# Move every mask-decoder entry out of the merged state dict.
for mask_decoder_key in mask_decoder_keys:
    hq_decoder[mask_decoder_key.replace("mask_decoder.", "")] = sam_hq[mask_decoder_key]
    sam_hq.pop(mask_decoder_key)

torch.save(sam_hq, "pretrained_checkpoint/sam_vit_l.pth")
torch.save(hq_decoder, "pretrained_checkpoint/sam_vit_l_maskdecoder.pth")
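For reference, a small inspection sketch (it only loads the two output files from the paths above and prints key counts) to see which keys ended up where:

import torch

sam_part = torch.load("pretrained_checkpoint/sam_vit_l.pth", map_location="cpu")
decoder_part = torch.load("pretrained_checkpoint/sam_vit_l_maskdecoder.pth", map_location="cpu")

print(len(sam_part), "keys left in sam_vit_l.pth")
print(len(decoder_part), "keys in sam_vit_l_maskdecoder.pth")
# The split should leave no decoder entries in the SAM part.
print(any(k.startswith("mask_decoder.") for k in sam_part))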
Hi! After some trial and error, training succeeded when I split the checkpoint as follows. The parameters listed in no_need_keys are present in sam_hq_vit_l.pth, yet they seem to be unnecessary for both sam_vit_l.pth and sam_vit_l_maskdecoder.pth. If I include the no_need_keys in either model file's dictionary, I get an "Unexpected key(s) in state_dict" error; when I delete them, training succeeds without any errors. What kind of parameters are these no_need_keys?
import torch

# HQ-specific parameters that exist only in the merged HQ-SAM decoder;
# they belong to neither sam_vit_l.pth nor sam_vit_l_maskdecoder.pth.
no_need_keys = [
    "hf_token.weight",
    "hf_mlp.layers.0.weight", "hf_mlp.layers.0.bias", "hf_mlp.layers.1.weight", "hf_mlp.layers.1.bias", "hf_mlp.layers.2.weight", "hf_mlp.layers.2.bias",
    "compress_vit_feat.0.weight", "compress_vit_feat.0.bias", "compress_vit_feat.1.weight", "compress_vit_feat.1.bias", "compress_vit_feat.3.weight", "compress_vit_feat.3.bias",
    "embedding_encoder.0.weight", "embedding_encoder.0.bias", "embedding_encoder.1.weight", "embedding_encoder.1.bias", "embedding_encoder.3.weight", "embedding_encoder.3.bias",
    "embedding_maskfeature.0.weight", "embedding_maskfeature.0.bias", "embedding_maskfeature.1.weight", "embedding_maskfeature.1.bias", "embedding_maskfeature.3.weight", "embedding_maskfeature.3.bias",
]

sam_hq = torch.load("pretrained_checkpoint/sam_hq_vit_l.pth", map_location=torch.device("cpu"))
hq_decoder = {}
mask_decoder_keys = []
for key in sam_hq.keys():
    if key.startswith("mask_decoder.") and key.replace("mask_decoder.", "") not in no_need_keys:
        mask_decoder_keys.append(key)

# Copy the plain SAM decoder weights (prefix stripped) into the standalone decoder checkpoint.
for mask_decoder_key in mask_decoder_keys:
    hq_decoder[mask_decoder_key.replace("mask_decoder.", "")] = sam_hq[mask_decoder_key]

# Drop the HQ-specific parameters from the full checkpoint.
for no_need_key in no_need_keys:
    sam_hq.pop(f"mask_decoder.{no_need_key}")

torch.save(sam_hq, "pretrained_checkpoint/sam_vit_l.pth")
torch.save(hq_decoder, "pretrained_checkpoint/sam_vit_l_maskdecoder.pth")
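A quick sanity check (a sketch reusing the output paths above) that the HQ-specific parameters are really gone from both files before training starts:

import torch

sam_part = torch.load("pretrained_checkpoint/sam_vit_l.pth", map_location="cpu")
decoder_part = torch.load("pretrained_checkpoint/sam_vit_l_maskdecoder.pth", map_location="cpu")

# HQ-only decoder modules; neither split file should contain them.
hq_prefixes = ("hf_token", "hf_mlp", "compress_vit_feat", "embedding_encoder", "embedding_maskfeature")
print([k for k in sam_part if k.replace("mask_decoder.", "").startswith(hq_prefixes)])  # expect []
print([k for k in decoder_part if k.startswith(hq_prefixes)])  # expect []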