About phase 2 training's pretrained weights
Hi,
Where can I get "sam_vit_h_prompt_encoder.pth" and "sam_vit_h_mask_decoder.pth" for phase 2 training?
Regards
I have the same question.
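import torch
from segment_anything import sam_model_registry  # assumes the official segment_anything package
src_path = 'weights/sam_vit_h_4b8939.pth'  # the original SAM ViT-H checkpoint
sam = sam_model_registry["vit_h"](checkpoint=src_path)
torch.save(sam.prompt_encoder.state_dict(), 'weights/sam_vit_h_prompt_encoder.pth')
torch.save(sam.mask_decoder.state_dict(), 'weights/sam_vit_h_mask_decoder.pth')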
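For those asking how to use the three lines above: they are a one-off conversion step that writes two files into weights/. The sketch below is a minimal illustration under stated assumptions: ToySam is a hypothetical stand-in with the same submodule names as SAM (not the real model), and it saves to in-memory buffers instead of weights/ so it runs without the real checkpoint. It shows what each saved file contains (the full-model keys with the "<name>." prefix removed) and that the file loads back into the matching submodule with load_state_dict.

```python
import io

import torch
import torch.nn as nn


class ToySam(nn.Module):
    """Hypothetical stand-in for SAM: only the submodule names match."""

    def __init__(self) -> None:
        super().__init__()
        self.image_encoder = nn.Linear(4, 4)
        self.prompt_encoder = nn.Linear(4, 2)
        self.mask_decoder = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 1))


model = ToySam()

# Same pattern as the three lines above: save each submodule's state_dict separately
# (here into in-memory buffers instead of .pth files).
buffers = {}
for name in ("prompt_encoder", "mask_decoder"):
    buf = io.BytesIO()
    torch.save(getattr(model, name).state_dict(), buf)
    buffers[name] = buf

# The submodule's keys are the full-model keys with the "mask_decoder." prefix stripped.
full_keys = model.state_dict().keys()
decoder_keys = set(model.mask_decoder.state_dict().keys())
assert decoder_keys == {k[len("mask_decoder."):] for k in full_keys if k.startswith("mask_decoder.")}

# Reload into a freshly initialised model; strict=True (the default) raises on any key mismatch.
fresh = ToySam()
for name, buf in buffers.items():
    buf.seek(0)
    getattr(fresh, name).load_state_dict(torch.load(buf))
```

With the real files, the same check would be `sam.mask_decoder.load_state_dict(torch.load('weights/sam_vit_h_mask_decoder.pth'))`; how the phase 2 training script itself loads them is not shown in this thread.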
Oh, thank you. Another question: in phase 1, where can I find repvit_m1_distill_300.pth?
https://github.com/THU-MIG/RepViT
https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m0_9_distill_300e.pth
Thank you.
I have the same question.
sam = sam_model_registry["vit_h"](checkpoint=src_path)
torch.save(sam.prompt_encoder.state_dict(), 'weights/sam_vit_h_prompt_encoder.pth')
torch.save(sam.mask_decoder.state_dict(), 'weights/sam_vit_h_mask_decoder.pth')
I also ran into this problem. Could you explain in more detail how to use these three lines of code? Sorry, I'm new to this. Thank you.
How do you use these three lines of code? Thank you.
I have the same question.
import torch

# Load the original weights (on CPU; no GPU is needed for this step)
sam_weight = torch.load('weights/sam_vit_h_4b8939.pth', map_location='cpu')
key_word = 'prompt_encoder'  # or 'mask_decoder'

new_weight = {}
prefix = f"{key_word}."
for key in sam_weight.keys():
    # Check whether the key starts with the target component's prefix
    if key.startswith(prefix):
        # Strip the component prefix, keeping the inner structure
        new_key = key[len(prefix):]
        new_weight[new_key] = sam_weight[key]
    elif key == key_word:
        # Special case: the key name is exactly the component name
        new_weight = sam_weight[key]
        break

# Save the processed weights
torch.save(new_weight, f'weights/sam_vit_h_{key_word}.pth')
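The loop above extracts one component per run, so it has to be run twice with key_word changed. A minimal sketch, assuming you want both components from a single pass over the checkpoint: split_by_component is a hypothetical helper (not part of the RepViT or segment_anything code), it works on any dict-like state_dict, and unlike the loop above it does not handle the key == key_word special case. The toy dict stands in for the real checkpoint so the sketch runs without torch.

```python
from typing import Any, Dict, Iterable, Mapping


def split_by_component(
    state_dict: Mapping[str, Any], components: Iterable[str]
) -> Dict[str, Dict[str, Any]]:
    """Hypothetical helper: group checkpoint entries by top-level component name.

    A key like "mask_decoder.layer.weight" goes to result["mask_decoder"] as
    "layer.weight". Keys that belong to no listed component (for example
    "image_encoder.*") are skipped.
    """
    result: Dict[str, Dict[str, Any]] = {name: {} for name in components}
    for key, value in state_dict.items():
        # Split at the first dot only, so the inner structure is kept intact.
        head, sep, rest = key.partition(".")
        if sep and head in result:
            result[head][rest] = value
    return result


# Toy stand-in for the loaded SAM checkpoint (real values would be tensors).
toy = {"image_encoder.w": 1, "prompt_encoder.pe.w": 2, "mask_decoder.out.b": 3}
parts = split_by_component(toy, ("prompt_encoder", "mask_decoder"))
print(parts)  # {'prompt_encoder': {'pe.w': 2}, 'mask_decoder': {'out.b': 3}}
```

With the real checkpoint, each entry could then be written with `torch.save(parts[name], f'weights/sam_vit_h_{name}.pth')`, which produces the same two file names asked about at the top of the thread.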