pytorch_bn_fusion

No effect after “fuse_bn_recursively”

Open Ironteen opened this issue 4 years ago • 3 comments

I loaded MobileNetV2 and ran the `fuse_bn_recursively` function on it, then printed the network structures of the two models. The fused network looks the same as the initial network. Am I using it incorrectly?

`
import torch
from bn_fusion import fuse_bn_recursively
from pytorchcv.model_provider import get_model as ptcv_get_model

if __name__ == '__main__':

    net = ptcv_get_model('mobilenetv2_w1', pretrained=True)

    net1 = fuse_bn_recursively(net)
    net1.eval()

    # Parameters of the original network, keyed by name
    net_dict1 = {}
    for idx, (name, param) in enumerate(net.named_parameters()):
        net_dict1[name] = param

    # Parameters of the network returned by the fusion call, keyed by name
    net_dict2 = {}
    for idx, (name, param) in enumerate(net1.named_parameters()):
        net_dict2[name] = param
    names = net_dict1.keys()

    # Count parameters whose shapes differ between the two networks
    diff_cnt = 0
    for name in names:
        if net_dict1[name].shape != net_dict2[name].shape:
            diff_cnt += 1
    print("diff params:", diff_cnt)
`

Ironteen, Mar 22 '20 11:03
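A check based on parameter shapes may not reveal fusion even when it happens, because folding batch norm into a convolution leaves the convolution's weight shape unchanged. Also, if `fuse_bn_recursively` modifies the model in place and returns it (an assumption, not confirmed in this thread), `net` and `net1` in the snippet above would be the same object. A minimal sketch of an alternative check follows. It uses a toy model and PyTorch's `torch.nn.utils.fusion.fuse_conv_bn_eval` as a stand-in for the fusion step; neither comes from this repository, and `count_bn` is a hypothetical helper.

```python
import copy

import torch
from torch import nn
from torch.nn.utils.fusion import fuse_conv_bn_eval


def count_bn(model: nn.Module) -> int:
    """Count BatchNorm2d layers anywhere in the module tree (hypothetical helper)."""
    return sum(isinstance(m, nn.BatchNorm2d) for m in model.modules())


torch.manual_seed(0)

# Toy stand-in for a real network: one Conv -> BN -> ReLU block in eval mode.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(8),
    nn.ReLU(),
).eval()

# Give BN non-trivial statistics so that fusion is not a near no-op.
with torch.no_grad():
    model[1].running_mean.normal_()
    model[1].running_var.uniform_(0.5, 2.0)
    model[1].weight.uniform_(0.5, 1.5)
    model[1].bias.normal_()

# Keep an independent copy *before* fusing, in case the fusion works in place.
reference = copy.deepcopy(model)

# Stand-in fusion step: fold BN into the conv and drop the BN layer.
fused = nn.Sequential(fuse_conv_bn_eval(model[0], model[1]), model[2]).eval()

x = torch.randn(1, 3, 16, 16)
with torch.no_grad():
    max_diff = (reference(x) - fused(x)).abs().max().item()

print("BN layers before:", count_bn(reference))
print("BN layers after: ", count_bn(fused))
print("conv weight shape unchanged:", reference[0].weight.shape == fused[0].weight.shape)
print("max output difference:", max_diff)
```

Counting layer types and comparing outputs on the same input shows both that the structure changed and that the fused model still computes the same function, which a shape comparison alone cannot show.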

Please take a look at the recent example. You need some bells and whistles to make it work.

lext, Mar 22 '20 12:03

> Please take a look at the recent example. You need some bells and whistles to make it work.

Thank you very much for your prompt reply. I think I understand what you mean. This project is very enlightening for me.

Ironteen, Mar 22 '20 12:03

Actually, I think the BN fusion in this repo is for depthwise convolution, not standard convolution. Correct me if I am wrong. Please check the implementation in kito:

    if conv_layer_type == 'Conv2D':
        for i in range(conv_weights.shape[-1]):
            conv_weights[:, :, :, i] *= A[i]
    elif conv_layer_type == 'Conv2DTranspose':
        for i in range(conv_weights.shape[-2]):
            conv_weights[:, :, i, :] *= A[i]
    elif conv_layer_type == 'DepthwiseConv2D':
        for i in range(conv_weights.shape[-2]):
            conv_weights[:, :, i, :] *= A[i]
    elif conv_layer_type == 'Conv3D':
        for i in range(conv_weights.shape[-1]):
            conv_weights[:, :, :, :, i] *= A[i]

jiangyuqi1017, Jun 23 '20 10:06
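The arithmetic behind the quoted `A[i]` scaling can be checked in isolation. The sketch below is not taken from this repository or from kito. It folds batch norm into a 1x1 convolution evaluated at a single pixel (which reduces to a matrix-vector product), using plain Python lists and made-up numbers; all function names are hypothetical. The scale `A[o] = gamma[o] / sqrt(var[o] + eps)` belongs to output channel `o`, so it multiplies whichever weight axis indexes output channels. In Keras's `Conv2D` kernel layout `(kh, kw, in, out)` that is the last axis, while in PyTorch's `Conv2d` layout `(out, in / groups, kh, kw)` it is the first axis for standard and depthwise convolutions alike, so the axis being scaled depends on the framework's layout as well as on the layer type.

```python
import math


def conv1x1(x, weight, bias):
    """1x1 convolution at one pixel: weight[o][i] maps input channel i to output channel o."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def batch_norm(y, gamma, beta, mean, var, eps):
    """Inference-mode batch norm applied per channel."""
    return [
        g * (v - m) / math.sqrt(s + eps) + b
        for v, g, b, m, s in zip(y, gamma, beta, mean, var)
    ]


def fold_bn(weight, bias, gamma, beta, mean, var, eps):
    """Fold batch norm into the preceding convolution's weight and bias."""
    # One scale per output channel.
    scale = [g / math.sqrt(s + eps) for g, s in zip(gamma, var)]
    # Scale every weight that feeds output channel o by scale[o].
    fused_w = [[a * w for w in row] for a, row in zip(scale, weight)]
    # Shift the bias by the running mean, rescale, then add beta.
    fused_b = [a * (b - m) + bt for a, b, m, bt in zip(scale, bias, mean, beta)]
    return fused_w, fused_b


# Made-up values: 3 input channels, 2 output channels.
weight = [[0.5, -1.0, 2.0], [1.5, 0.25, -0.75]]
bias = [0.1, -0.2]
gamma, beta = [1.2, 0.8], [0.05, -0.3]
mean, var, eps = [0.4, -0.6], [2.0, 0.5], 1e-5
x = [1.0, 2.0, -1.0]

unfused = batch_norm(conv1x1(x, weight, bias), gamma, beta, mean, var, eps)
fused_w, fused_b = fold_bn(weight, bias, gamma, beta, mean, var, eps)
fused = conv1x1(x, fused_w, fused_b)

max_diff = max(abs(a - b) for a, b in zip(unfused, fused))
print("max difference between unfused and fused outputs:", max_diff)
```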