Do we need to set the masks for the bias in BN layers?
请问pruned_model里有对BN的bias做剪枝吗,如果是直接只对weight做掩膜的话剩下的bias还是会对网络造成很大的影响的吧?这样子还是不可以直接赋值给Compact_model的,这个问题应该怎么解决呢?
同问
剪枝的时候会先对 bias 做处理,用其后的卷积层或者 BN 层来吸收这个参数,然后再将权重赋值给 compact model,这样做的话就能保证 pruned_model (只对 weight 置 0 的模型)和 Compact_model 对相同的输入有相同的输出
请问如果用其后的卷积层吸收bias的话,因为没有BN层,所以是用来更新conv里的bias吗,但是conv的bias的维度与前一层BN的bias维度不一致的,应该怎么解决这个问题呢?
首先剩余的bias需要经过激活函数,然后跟下一层的权值相乘,之后加入mean(下一层含BN)或者bias(下一层不含BN)中
为什么不加入下一层(含BN)的bias,只加入到running_mean呢?
为什么不加入下一层(含BN)的bias,只加入到running_mean呢?
BN层没有bias
为什么不加入下一层(含BN)的bias,只加入到running_mean呢?
BN层没有bias
BN层有两个参数,一个gamma,一个是shift,我说的bias是指shift的参数,感觉应该跟running_mean一起更新才对
为什么不加入下一层(含BN)的bias,只加入到running_mean呢?
BN层没有bias
BN层有两个参数,一个gamma,一个是shift,我说的bias是指shift的参数,感觉应该跟running_mean一起更新才对 我的理解是这样的:
- 假设当前层(i)通道裁剪对应记录列表为mask, 1表示保留,0表示扔掉;当前层的shift = mask * shift + (1-mask) * shift,后面(1-mask) * shift表示剩余的shift, 记为remain_shift;
- 卷积+BN+激活操作如下:ReLU( gamma * (CONV(X) - mean) / sqrt(var) + shift )
- 根据上述公式,由于remain_shift对应部分gamma接近为0,第i层剩余部分为ReLU(remain_shift), 即下一层输入 X_full = X + ReLU(remain_shift),
- 多出的部分根据下一层是否含有BN处理,若含有BN,多余的部分可以加到mean里,或者加到shift里; 若不含BN,即该层操作变成 CONV(X) + bias, 多余的部分加到bias中
为什么不加入下一层(含BN)的bias,只加入到running_mean呢?
BN层没有bias
谢谢分析,理清了我的思路。关于第四点里的mean,它是指batch_mean还是running_mean呢?如果是running_mean,按照官方文档的设置它是用于evaluation代替training时的shift的,training时并不参与,这样子training时怎么影响下一层的数据呢?谢谢
为什么不加入下一层(含BN)的bias,只加入到running_mean呢?
BN层没有bias
谢谢分析,理清了我的思路。关于第四点里的mean,它是指batch_mean还是running_mean呢?如果是running_mean,按照官方文档的设置它是用于evaluation代替training时的shift的,training时并不参与,这样子training时怎么影响下一层的数据呢?谢谢
rolling_mean, 刚才说的操作是在训练完毕后的剪枝过程里,剪枝的时候不涉及训练
为什么不加入下一层(含BN)的bias,只加入到running_mean呢?
BN层没有bias
谢谢分析,理清了我的思路。关于第四点里的mean,它是指batch_mean还是running_mean呢?如果是running_mean,按照官方文档的设置它是用于evaluation代替training时的shift的,training时并不参与,这样子training时怎么影响下一层的数据呢?谢谢
rolling_mean, 刚才说的操作是在训练完毕后的剪枝过程里,剪枝的时候不涉及训练
明白了,感激不尽!!!
非常感谢作者的开源代码和您@lz20061213 的耐心解答,关于第4点我有一个疑惑,我看到在prune_utils.py文件的prune_model_keep_size()函数中,当下一层含有BN时,上层输出多余的部分经过卷积层后直接加到了running_mean里。请问在加到running_mean之前是不是需要求一下均值?因为running_mean里应该都是各个channel的均值。
非常感谢作者的开源代码和您@lz20061213 的耐心解答,关于第4点我有一个疑惑,我看到在prune_utils.py文件的prune_model_keep_size()函数中,当下一层含有BN时,上层输出多余的部分经过卷积层后直接加到了running_mean里。请问在加到running_mean之前是不是需要求一下均值?因为running_mean里应该都是各个channel的均值。
prune是在模型训练完后进行的,running_mean里面保存了训练时统计的均值,在测试时固定不变 (BN里面的mean和var参数训练的时候统计更新,测试时固定)
感谢您的回复,我的意思是在剪枝时,为保证 pruned_model (只对 weight 置 0 的模型)和 Compact_model 对相同的输入有相同的输出。下层 BN 层吸收上层剩余的remain_shift这个参数时,用到了BN 层里的running_mean去吸收,其直接加到了running_mean上。在测试时mean和var固定不变,这里剪枝时更改了其值,而且是不是也需要考虑对BN 层里的running_var的影响?
感谢您的回复,我的意思是在剪枝时,为保证 pruned_model (只对 weight 置 0 的模型)和 Compact_model 对相同的输入有相同的输出。下层 BN 层吸收上层剩余的remain_shift这个参数时,用到了BN 层里的running_mean去吸收,其直接加到了running_mean上。在测试时mean和var固定不变,这里剪枝时更改了其值,而且是不是也需要考虑对BN 层里的running_var的影响?
不知道是不是我的说法有问题,剪枝更改running_mean或bias目的就是为了保持计算结果不变。
我们假设就两层 第一层剪枝 第二层不剪枝
第一层操作:f(x) = CONV1+BN1+ReLU, 第二层操作g(x) = CONV2 + BN2 + ReLU
最后的输出就是g(f(x)) = ReLU(gamma2 * (CONV2(f(x)) - mean2) / sqrt(var2) + shift2 )
裁剪后呢, f(x) = f'(x) + ReLU(remain_shift1)
带入上面的式子我们可以发现,在保证第二层采用同样的计算方式和结果不变的情况下:
g(f(x)) = ReLU(gamma2 * (CONV2(f'(x)) + CONV2(ReLU(remain_shift1)) - mean2) / sqrt(var2) + shift2)
(暂时不考虑CONV2(ReLU(remain_shift1)的过程)
最简单的更改方式就是 mean2’ = mean2 - CONV2(ReLU(remain_shift1))
g(f'(x)) = ReLU(gamma2 * (CONV2(f'(x)) - mean2') / sqrt(var2) + shift2)
感谢您的回复,我的意思是在剪枝时,为保证 pruned_model (只对 weight 置 0 的模型)和 Compact_model 对相同的输入有相同的输出。下层 BN 层吸收上层剩余的remain_shift这个参数时,用到了BN 层里的running_mean去吸收,其直接加到了running_mean上。在测试时mean和var固定不变,这里剪枝时更改了其值,而且是不是也需要考虑对BN 层里的running_var的影响?
不知道是不是我的说法有问题,剪枝更改running_mean或bias目的就是为了保持计算结果不变。 我们假设就两层 第一层剪枝 第二层不剪枝 第一层操作:
f(x) = CONV1+BN1+ReLU, 第二层操作g(x) = CONV2 + BN2 + ReLU最后的输出就是g(f(x)) = ReLU(gamma2 * (CONV2(f(x)) - mean2) / sqrt(var2) + shift2 )裁剪后呢,f(x) = f'(x) + ReLU(remain_shift1)带入上面的式子我们可以发现,在保证第二层采用同样的计算方式和结果不变的情况下:g(f(x)) = ReLU(gamma2 * (CONV2(f'(x)) + CONV2(ReLU(remain_shift1)) - mean2) / sqrt(var2) + shift2)(暂时不考虑CONV2(ReLU(remain_shift1)的过程) 最简单的更改方式就是 mean2’ = mean2 - CONV2(ReLU(remain_shift1))g(f'(x)) = ReLU(gamma2 * (CONV2(f'(x)) - mean2') / sqrt(var2) + shift2)
非常感谢您的耐心解答,您的回答完美消除了我的疑问,我之前一直理解有误,万分感谢!!!