EdgeSAM

About phase 2 training's pretrained weights

Open • andrecmwang opened this issue 1 year ago • 8 comments

Hi,

Where can I get "sam_vit_h_prompt_encoder.pth" and "sam_vit_h_mask_decoder.pth" for phase 2 training?

Regards

andrecmwang • Aug 20 '24

I have the same question.

Jaaaahan • Aug 26 '24

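import torch
from segment_anything import sam_model_registry
src_path = 'weights/sam_vit_h_4b8939.pth'  # assumed location of the original SAM ViT-H checkpoint
sam = sam_model_registry["vit_h"](checkpoint=src_path)
torch.save(sam.prompt_encoder.state_dict(), 'weights/sam_vit_h_prompt_encoder.pth')
torch.save(sam.mask_decoder.state_dict(), 'weights/sam_vit_h_mask_decoder.pth')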

AmorFati2016 • Sep 18 '24
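For context on what those three lines write out: calling `state_dict()` on a submodule such as `sam.prompt_encoder` returns only that submodule's entries, with keys relative to it, i.e. the full checkpoint keys minus the leading `prompt_encoder.` or `mask_decoder.`. The torch-free sketch below mimics that key mapping on a plain dict, assuming a flat checkpoint layout like `sam_vit_h_4b8939.pth`. `component_state` is a hypothetical helper, and the keys and string values are illustrative stand-ins for real tensors.

```python
from typing import Any, Dict, Mapping


def component_state(full_state: Mapping[str, Any], component: str) -> Dict[str, Any]:
    """Return the entries of a flat checkpoint that belong to one submodule.

    Keys are made relative to the submodule by dropping the leading
    "<component>." prefix, which is the key format the submodule's own
    state_dict() produces.
    """
    prefix = component + "."
    return {
        key[len(prefix):]: value
        for key, value in full_state.items()
        if key.startswith(prefix)
    }


# Illustrative keys only; a real SAM checkpoint has many more entries.
full = {
    "image_encoder.pos_embed": "t0",
    "prompt_encoder.pe_layer.positional_encoding_gaussian_matrix": "t1",
    "mask_decoder.iou_token.weight": "t2",
}
print(component_state(full, "prompt_encoder"))  # {'pe_layer.positional_encoding_gaussian_matrix': 't1'}
print(component_state(full, "mask_decoder"))    # {'iou_token.weight': 't2'}
```

Matching on `component + "."` rather than the bare name keeps a hypothetical key such as `prompt_encoder_extra.x` from being picked up by mistake.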

Oh, thank you. Another question: in phase 1, where can I find repvit_m1_distill_300.pth?

VictorZoo • Sep 18 '24

https://github.com/THU-MIG/RepViT
https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m0_9_distill_300e.pth

AmorFati2016 • Sep 18 '24

Thank you.

Jaaaahan • Sep 19 '24

I ran into the same problem. How do I use these three lines of code? Could you be more specific? Sorry, I'm new to this. Thank you.

wowangle97 • Dec 11 '24

How do you use these three lines of code? Thank you.

wowangle97 • Dec 11 '24
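On the how-to question: the three lines are ordinary Python, so they run in a script or an interactive session once `torch` and Meta's `segment_anything` package are installed and the original ViT-H checkpoint has been downloaded. Below is a sketch of the same export wrapped in a reusable function, assuming the checkpoint sits at `weights/sam_vit_h_4b8939.pth`. `export_sam_components` and `output_paths` are hypothetical names, not part of EdgeSAM or segment_anything.

```python
import os
from typing import Dict

# Attribute names of the two SAM submodules to export.
COMPONENTS = ("prompt_encoder", "mask_decoder")


def output_paths(out_dir: str) -> Dict[str, str]:
    """Map each component name to the .pth file name asked about in this issue."""
    return {name: os.path.join(out_dir, f"sam_vit_h_{name}.pth") for name in COMPONENTS}


def export_sam_components(checkpoint: str, out_dir: str = "weights") -> Dict[str, str]:
    """Build SAM ViT-H from `checkpoint` and save each component's state_dict."""
    # Deferred imports: torch and segment_anything are only needed for the export itself.
    import torch
    from segment_anything import sam_model_registry

    sam = sam_model_registry["vit_h"](checkpoint=checkpoint)
    os.makedirs(out_dir, exist_ok=True)
    paths = output_paths(out_dir)
    for name, path in paths.items():
        # getattr(sam, "prompt_encoder") is the same module as sam.prompt_encoder
        torch.save(getattr(sam, name).state_dict(), path)
    return paths
```

For example, after `pip install git+https://github.com/facebookresearch/segment-anything.git`, calling `export_sam_components("weights/sam_vit_h_4b8939.pth")` from a Python shell should write both files into `weights/`. Keeping the torch imports inside the function means the file can still be imported without torch installed.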

import torch

# Load the original weights
sam_weight = torch.load('weights/sam_vit_h_4b8939.pth')
key_word = 'prompt_encoder'  # or 'mask_decoder'

new_weight = {}
prefix = f"{key_word}."

for key in sam_weight.keys():
    # Check whether the key carries the target component's prefix
    if key.startswith(prefix):
        # Strip the component prefix and keep the internal structure
        new_key = key[len(prefix):]
        new_weight[new_key] = sam_weight[key]
    elif key == key_word:
        # Special case: the key name is exactly the component name
        new_weight = sam_weight[key]
        break

# Save the processed weights
torch.save(new_weight, f'weights/sam_vit_h_{key_word}.pth')

qinfendekaizhou • Aug 05 '25
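Whichever route produced the two files, one way to sanity-check them before starting phase 2 is to load each back into a freshly built SAM ViT-H with `load_state_dict(strict=False)` and inspect the reported missing and unexpected keys. This is a sketch assuming `torch` and `segment_anything` are installed; `check_component_files` is a hypothetical helper, and building the full ViT-H model just to check keys needs roughly as much memory as the model itself.

```python
from typing import Dict, List


def check_component_files(paths: Dict[str, str]) -> Dict[str, List[str]]:
    """Load exported component files into a fresh SAM ViT-H and report key mismatches.

    `paths` maps a component attribute name ("prompt_encoder" or "mask_decoder")
    to the .pth file exported for it. An empty list means the keys line up.
    """
    import torch
    from segment_anything import sam_model_registry

    # No checkpoint: random initial weights are enough to compare key names.
    sam = sam_model_registry["vit_h"]()
    problems: Dict[str, List[str]] = {}
    for name, path in paths.items():
        state = torch.load(path, map_location="cpu")
        # strict=False returns the mismatches instead of raising on them.
        result = getattr(sam, name).load_state_dict(state, strict=False)
        problems[name] = list(result.missing_keys) + list(result.unexpected_keys)
    return problems
```

Usage would look like `check_component_files({"prompt_encoder": "weights/sam_vit_h_prompt_encoder.pth", "mask_decoder": "weights/sam_vit_h_mask_decoder.pth"})`; two empty lists suggest the files match what the modules expect.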