
weight/repvit_m1_distill_300.th

Open gold123fish opened this issue 1 year ago • 1 comments

Hello. The author uses PRETRAINED: weights/repvit_m1_distill_300.thm in training/config/rep_it_m1-fuse_cas-distill.yaml, but I can't find a weight file with exactly that name at https://github.com/THU-MIG/RepViT/releases, so I downloaded repvit_m0_9_distill_300e.pth instead. At the final validation step it reports missing keys such as ['image_encoder.features.0.0.c.weight', 'image_encoder.features.0.0.bn.weight', 'image_encoder.features.0.0.bn.bias', ...], and the resulting mIoU is not very good. Have you encountered the same problem? How should it be resolved? Thank you!


gold123fish, Dec 25 '24 13:12

Change RepVGGDW to the following implementation:

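import torch

# Conv2d_BN is the conv + BatchNorm helper already defined in the same model
# file; it must provide a fuse() method that returns a plain Conv2d.

class RepVGGDW(torch.nn.Module):
    def __init__(self, ed) -> None:
        super().__init__()
        # 3x3 depthwise conv followed by BatchNorm
        self.conv = Conv2d_BN(ed, ed, 3, 1, 1, groups=ed)
        # 1x1 depthwise conv with bias and no BatchNorm of its own
        self.conv1 = torch.nn.Conv2d(ed, ed, 1, 1, 0, groups=ed)
        # BatchNorm applied to the sum of the three branches
        self.bn = torch.nn.BatchNorm2d(ed)
        self.dim = ed

    def forward(self, x):
        # 3x3 branch + 1x1 branch + identity shortcut, then BatchNorm
        return self.bn((self.conv(x) + self.conv1(x)) + x)

    @torch.no_grad()
    def fuse(self):
        # Fold the inner BatchNorm into the 3x3 conv
        conv = self.conv.fuse()
        conv1 = self.conv1

        conv_w = conv.weight
        conv_b = conv.bias
        conv1_w = conv1.weight
        conv1_b = conv1.bias

        # Zero-pad the 1x1 kernel to 3x3 so it can be added to the 3x3 kernel
        conv1_w = torch.nn.functional.pad(conv1_w, [1,1,1,1])

        # The identity shortcut is a 3x3 kernel with a single 1 at the centre
        identity = torch.nn.functional.pad(torch.ones(conv1_w.shape[0], conv1_w.shape[1], 1, 1, device=conv1_w.device), [1,1,1,1])

        # Merge the three branches into one 3x3 depthwise conv
        final_conv_w = conv_w + conv1_w + identity
        final_conv_b = conv_b + conv1_b

        conv.weight.data.copy_(final_conv_w)
        conv.bias.data.copy_(final_conv_b)

        # Fold the outer BatchNorm (running statistics) into the merged conv
        bn = self.bn
        w = bn.weight / (bn.running_var + bn.eps)**0.5
        w = conv.weight * w[:, None, None, None]
        b = bn.bias + (conv.bias - bn.running_mean) * bn.weight / \
            (bn.running_var + bn.eps)**0.5
        conv.weight.data.copy_(w)
        conv.bias.data.copy_(b)
        return conv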
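
The branch-merging step in fuse() relies on a 3x3 depthwise conv, a 1x1 depthwise conv and the identity shortcut adding up to a single 3x3 kernel. Below is a minimal pure-Python sketch that checks this for one channel. It is an illustration only: the helper names (conv2d_same, centre_3x3, add) and the sample values are made up for this example, and biases and the BatchNorm fold are left out.

```python
def conv2d_same(x, k):
    """Single-channel 2D cross-correlation with zero padding (odd square kernel)."""
    h, w = len(x), len(x[0])
    r = len(k) // 2
    out = [[0.0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            s = 0.0
            for di in range(-r, r + 1):
                for dj in range(-r, r + 1):
                    ii, jj = i + di, j + dj
                    if 0 <= ii < h and 0 <= jj < w:
                        s += x[ii][jj] * k[di + r][dj + r]
            out[i][j] = s
    return out


def centre_3x3(v):
    """Place a scalar at the centre of a 3x3 kernel of zeros (the padding step)."""
    return [[0.0, 0.0, 0.0], [0.0, v, 0.0], [0.0, 0.0, 0.0]]


def add(a, b):
    """Element-wise sum of two equally sized 2D lists."""
    return [[p + q for p, q in zip(ra, rb)] for ra, rb in zip(a, b)]


# Hypothetical sample input and kernels for one channel
x = [[1.0, 2.0, 0.5], [-1.0, 3.0, 2.0], [0.0, 1.5, -2.0]]
k3 = [[0.1, -0.2, 0.3], [0.0, 0.5, -0.1], [0.2, 0.1, -0.3]]
k1 = 0.7

# Unfused: 3x3 branch + 1x1 branch + identity shortcut
branches = add(add(conv2d_same(x, k3), conv2d_same(x, [[k1]])), x)

# Fused: one 3x3 kernel = k3 + padded 1x1 kernel + padded identity
merged_kernel = add(add(k3, centre_3x3(k1)), centre_3x3(1.0))
fused = conv2d_same(x, merged_kernel)

max_diff = max(abs(p - q) for ra, rb in zip(branches, fused) for p, q in zip(ra, rb))
print(max_diff)
```

The two results should agree up to floating-point rounding. In the real module the same identity holds per channel, the biases of the two conv branches simply add, and the outer BatchNorm then becomes a per-channel scale and shift on the merged kernel.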

Note: the most elegant way would be to rename the keys in repvit_m0_9_distill_300e.pth so that they match the keys the EdgeSAM model expects.
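
A rough sketch of that key-renaming idea, using plain dicts in place of real state dicts so that it runs without PyTorch. Everything specific here is an assumption: the helper names remap_keys and diff_keys are hypothetical, and the single rule that prepends `image_encoder.` is only a guess based on the missing-keys message above. The real differences between the two checkpoints have to be found by diffing the actual key lists.

```python
def remap_keys(state_dict, rules):
    """Return a new dict with the first matching (old_prefix, new_prefix) rule applied to each key."""
    out = {}
    for key, value in state_dict.items():
        for old, new in rules:
            if key.startswith(old):
                key = new + key[len(old):]
                break
        out[key] = value
    return out


def diff_keys(expected, actual):
    """Return (missing, unexpected) key lists, in the spirit of a load report."""
    expected, actual = set(expected), set(actual)
    return sorted(expected - actual), sorted(actual - expected)


# Toy stand-ins: integers instead of tensors, two keys instead of a full model
checkpoint = {
    "features.0.0.c.weight": 1,
    "features.0.0.bn.weight": 2,
}
expected_keys = [
    "image_encoder.features.0.0.c.weight",
    "image_encoder.features.0.0.bn.weight",
]

# Before remapping, every expected key is reported as missing
missing_before, unexpected_before = diff_keys(expected_keys, checkpoint)

# Hypothetical rule: prepend the encoder prefix to every checkpoint key
remapped = remap_keys(checkpoint, [("", "image_encoder.")])
missing_after, unexpected_after = diff_keys(expected_keys, remapped)

print(missing_before, unexpected_before)
print(missing_after, unexpected_after)
```

With a real checkpoint, the same two helpers could be applied to the dict returned by torch.load and to model.state_dict().keys(). If the checkpoint nests its weights under a top-level key, that would need to be unwrapped first.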

huaibovip, May 10 '25 04:05