filter-grafting
filter-grafting copied to clipboard
你好,使用VGG模型运行grafting_cifar时,出现错误
models的vgg代码中构造网络时使用了 _make_layers构造卷积层,在运行grafting.py时会出现UnboundLocalError: local variable 'w' referenced before assignment
你在函数开始的时候给w一个初始值就好了
可能不行,代码中根据‘conv’ 字符串查找卷积层,但是这里vgg用了_make_layers将参数层加入到features序列中 ,导致key中没有‘conv’, 只有‘features’ ··· def grafting(net, epoch): …… for i, (key, u) in enumerate(net.state_dict().items()): if 'conv' in key: w = round(args.a / np.pi * np.arctan(args.c * (entropy(u) - entropy(checkpoint[key]))) + 0.5, 2) model[key] = u * w + checkpoint[key] * (1 - w) net.load_state_dict(model) ···
哦哦,那这里需要改一下判断语句了,比如:if len(n.shape)==4