yolov7
yolov7 copied to clipboard
Fuse ImplicitA and Convolution
建堯博士您好,關於在detect head的 fuse layer有試跑過一次有跳出問題,因此針對部分有重新實作了一次。 請您確認一下是否正確。
def fuse_conv_and_ia(conv, ia):
    """Fold an additive implicit (ImplicitA) applied *before* ``conv`` into the conv.

    Returns a new conv equivalent to ``conv(x + ia.implicit)``::

        conv(x + a) = W * x + (W * a + b)

    The filters are copied unchanged and the constant term ``W * a`` is folded
    into the bias. Because ``a`` is constant over space, ``W * a`` reduces to
    ``W.sum(dim=(2, 3)) @ a`` (computed per group), so any kernel size and any
    ``groups`` value are supported.

    Note: with zero padding (``conv.padding`` > 0) the fold is exact only for
    outputs whose receptive field does not touch the padded border.

    Args:
        conv: the ``nn.Conv2d`` that consumes the implicit-shifted input.
        ia: module holding an ``implicit`` tensor that is broadcast-added to
            the conv input; expected to have ``conv.in_channels`` elements
            (e.g. shape ``(1, C_in, 1, 1)``).

    Returns:
        A new ``nn.Conv2d`` with ``bias=True`` and gradients disabled.
    """
    fusedconv = nn.Conv2d(conv.in_channels,
                          conv.out_channels,
                          kernel_size=conv.kernel_size,
                          stride=conv.stride,
                          padding=conv.padding,
                          dilation=conv.dilation,
                          groups=conv.groups,
                          bias=True,
                          padding_mode=conv.padding_mode).requires_grad_(False).to(conv.weight.device)

    groups = conv.groups
    c_out = conv.out_channels
    with torch.no_grad():
        # Spatially constant input => each filter acts as the sum of its taps.
        w = conv.weight.sum(dim=(2, 3)).reshape(groups, c_out // groups, -1)  # (G, C_out/G, C_in/G)
        a = ia.implicit.reshape(groups, -1, 1)                                # (G, C_in/G, 1)
        shift = torch.bmm(w, a).reshape(c_out)
        if conv.bias is not None:
            shift = shift + conv.bias
        # The original conv's filters must be carried over; only the bias changes.
        fusedconv.weight.copy_(conv.weight)
        fusedconv.bias.copy_(shift)
    return fusedconv
def fuse_conv_and_im(conv, im):
    """Fold a multiplicative implicit (ImplicitM) applied *after* ``conv`` into the conv.

    Returns a new conv equivalent to ``conv(x) * im.implicit``::

        m * (W * x + b) = (m * W) * x + m * b

    That is, each output channel's filter and bias are scaled by that
    channel's implicit factor.

    Args:
        conv: the ``nn.Conv2d`` whose output is scaled.
        im: module holding an ``implicit`` tensor that multiplies the conv
            output; expected to have ``conv.out_channels`` elements
            (e.g. shape ``(1, C_out, 1, 1)``).

    Returns:
        A new ``nn.Conv2d`` with ``bias=True`` and gradients disabled.
    """
    fusedconv = nn.Conv2d(conv.in_channels,
                          conv.out_channels,
                          kernel_size=conv.kernel_size,
                          stride=conv.stride,
                          padding=conv.padding,
                          dilation=conv.dilation,
                          groups=conv.groups,
                          bias=True,
                          padding_mode=conv.padding_mode).requires_grad_(False).to(conv.weight.device)

    with torch.no_grad():
        scale = im.implicit.reshape(-1)  # one factor per output channel
        fusedconv.weight.copy_(conv.weight * scale.view(-1, 1, 1, 1))
        if conv.bias is not None:
            fusedconv.bias.copy_(conv.bias * scale)
        else:
            # m * 0 == 0; clear the randomly initialised bias.
            fusedconv.bias.zero_()
    return fusedconv
The new version provides a more concise fuse function; you could check it out.