yolov7

Fuse ImplicitA and Convolution

Open • Eren-Corn0712 opened this issue 1 year ago • 1 comment

Hello Dr. Chien-Yao. When I test-ran the fuse layer in the detect head, it raised an error, so I reimplemented part of it. Could you please check whether it is correct?
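For context, the functions below read `ia.implicit` and `im.implicit`. The sketch here shows what those modules are assumed to look like, modeled on yolov7's `ImplicitA` and `ImplicitM` in `models/common.py` (the exact init details may differ), together with the unfused order in which the detect head applies them around its 1x1 output conv. `head_branch` is a hypothetical helper added only for illustration.

```python
import torch
import torch.nn as nn


class ImplicitA(nn.Module):
    # Learned per-channel offset, added to the feature map *before* the conv
    def __init__(self, channel, mean=0.0, std=0.02):
        super().__init__()
        self.implicit = nn.Parameter(torch.zeros(1, channel, 1, 1))
        nn.init.normal_(self.implicit, mean=mean, std=std)

    def forward(self, x):
        return self.implicit + x


class ImplicitM(nn.Module):
    # Learned per-channel scale, multiplied onto the conv output
    def __init__(self, channel, mean=1.0, std=0.02):
        super().__init__()
        self.implicit = nn.Parameter(torch.ones(1, channel, 1, 1))
        nn.init.normal_(self.implicit, mean=mean, std=std)

    def forward(self, x):
        return self.implicit * x


def head_branch(x, ia, conv, im):
    # Hypothetical helper: unfused order im(conv(ia(x)))
    return im(conv(ia(x)))


# Usage: 8 input channels, 6 output channels, 1x1 conv
x = torch.randn(2, 8, 4, 4)
ia, conv, im = ImplicitA(8), nn.Conv2d(8, 6, 1), ImplicitM(6)
y = head_branch(x, ia, conv, im)
print(y.shape)  # torch.Size([2, 6, 4, 4])
```

Note the shapes: `ia.implicit` is `[1, in_channels, 1, 1]` while `im.implicit` is `[1, out_channels, 1, 1]`, which is why the two fusions reshape them differently.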

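import torch
import torch.nn as nn


def fuse_conv_and_ia(conv, ia):
    # Fold ImplicitA (a per-channel offset added *before* the conv) into the bias:
    # W(x + a) + b = Wx + (Wa + b). Exact only for a 1x1 conv, as in the detect head.
    fusedconv = nn.Conv2d(conv.in_channels,
                          conv.out_channels,
                          kernel_size=conv.kernel_size,
                          stride=conv.stride,
                          padding=conv.padding,
                          dilation=conv.dilation,
                          groups=conv.groups,
                          bias=True).requires_grad_(False).to(conv.weight.device)

    # Prepare filters: conv.weight is [c_out, c_in, 1, 1], ia.implicit is [1, c_in, 1, 1]
    c_out, c_in, _, _ = conv.weight.shape

    with torch.no_grad():
        w_conv = conv.weight.reshape(c_out, c_in)
        b_conv = conv.bias if conv.bias is not None else torch.zeros_like(fusedconv.bias)
        w_ia = ia.implicit.reshape(c_in, 1)

        # The weights are unchanged, but they must still be copied;
        # otherwise fusedconv keeps its random initialization
        fusedconv.weight.copy_(conv.weight)
        fusedconv.bias.copy_(torch.matmul(w_conv, w_ia).squeeze(1) + b_conv)

    return fusedconv


def fuse_conv_and_im(conv, im):
    # Fold ImplicitM (a per-channel scale applied *after* the conv) into weight and bias:
    # m(Wx + b) = (mW)x + mb
    fusedconv = nn.Conv2d(conv.in_channels,
                          conv.out_channels,
                          kernel_size=conv.kernel_size,
                          stride=conv.stride,
                          padding=conv.padding,
                          dilation=conv.dilation,
                          groups=conv.groups,
                          bias=True).requires_grad_(False).to(conv.weight.device)

    # Prepare filters: im.implicit is [1, c_out, 1, 1]
    _, c_out, _, _ = im.implicit.shape

    with torch.no_grad():
        b_conv = conv.bias if conv.bias is not None else torch.zeros_like(fusedconv.bias)
        w1_im = im.implicit.reshape(c_out)    # [c_out], scales the bias
        w2_im = im.implicit.transpose(0, 1)   # [c_out, 1, 1, 1], scales each output filter
        fusedconv.bias.copy_(b_conv * w1_im)
        fusedconv.weight.copy_(conv.weight * w2_im)

    return fusedconv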
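One way to check a fusion like this is to compare the fused output against the unfused `im(conv(ia(x)))` path on random data. The sketch below does that with plain tensors standing in for `ia.implicit` and `im.implicit` (the names `W`, `b`, `a`, `m` are illustrative), assuming a 1x1 conv with no padding. With a larger kernel and zero padding, folding the ImplicitA offset into the bias would not be exact at the borders, because the padded pixels never receive the offset.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
c_in, c_out = 8, 6
x = torch.randn(2, c_in, 5, 5)
W = torch.randn(c_out, c_in, 1, 1)   # 1x1 conv weight
b = torch.randn(c_out)               # conv bias
a = torch.randn(1, c_in, 1, 1)       # stands in for ia.implicit
m = torch.randn(1, c_out, 1, 1)      # stands in for im.implicit

# Unfused path: im(conv(ia(x)))
y_ref = m * F.conv2d(x + a, W, b)

# ImplicitA -> bias: W(x + a) + b = Wx + (Wa + b)
b_fused = b + (W.reshape(c_out, c_in) @ a.reshape(c_in, 1)).squeeze(1)
# ImplicitM -> weight and bias: m(Wx + b') = (mW)x + mb'
W_fused = W * m.transpose(0, 1)      # [c_out, 1, 1, 1] broadcasts over each filter
b_fused = b_fused * m.reshape(c_out)

y_fused = F.conv2d(x, W_fused, b_fused)
print(torch.allclose(y_ref, y_fused, atol=1e-5))  # True
```

The order matters: ImplicitA is folded first because it sits before the conv, and ImplicitM then scales the already-updated bias.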

Eren-Corn0712, Aug 05 '22 15:08

The new version provides a more concise fuse function; you could check it.
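The reply does not show the new function. A more concise approach is to fold both terms directly into the existing conv in place, instead of building new `nn.Conv2d` modules. The sketch below shows one way to do that; `fuse_implicit_inplace` is a hypothetical name, and this is not necessarily the repository's code.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def fuse_implicit_inplace(conv, ia_implicit, im_implicit):
    # Hypothetical helper: folds ImplicitA (applied before the conv) and
    # ImplicitM (applied after it) into a plain 1x1 conv, modifying it in place.
    # ia_implicit: [1, c_in, 1, 1], im_implicit: [1, c_out, 1, 1]
    c_out, c_in, kh, kw = conv.weight.shape
    assert (kh, kw) == (1, 1) and conv.groups == 1, "sketch only handles plain 1x1 convs"
    assert conv.bias is not None, "sketch assumes the conv already has a bias"
    # ImplicitA: bias += W @ a (uses the weights before ImplicitM scales them)
    conv.bias += (conv.weight.reshape(c_out, c_in) @ ia_implicit.reshape(c_in, 1)).squeeze(1)
    # ImplicitM: scale the bias and every output filter by m
    conv.bias *= im_implicit.reshape(c_out)
    conv.weight *= im_implicit.transpose(0, 1)
    return conv


# Usage: compare against the unfused im(conv(ia(x))) path
torch.manual_seed(0)
conv = nn.Conv2d(8, 6, 1)
a = torch.randn(1, 8, 1, 1)  # stands in for ia.implicit
m = torch.randn(1, 6, 1, 1)  # stands in for im.implicit
x = torch.randn(2, 8, 5, 5)

with torch.no_grad():
    y_ref = m * conv(x + a)
fuse_implicit_inplace(conv, a, m)
with torch.no_grad():
    y_fused = conv(x)
print(torch.allclose(y_ref, y_fused, atol=1e-5))  # True
```

Working in place under `torch.no_grad()` avoids both the missing weight copy and the autograd complaints that `copy_` into a fresh module can trigger.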

WongKinYiu, Aug 06 '22 00:08