oneflow icon indicating copy to clipboard operation
oneflow copied to clipboard

Implement fuse conv bn qat module

Open Ldpe2G opened this issue 3 years ago • 0 comments

实现google量化感知训练文章中提到的,训练过程中融合 conv 和 bn 的模块 QatFuseConvBN

为了简化导出 onnx 的过程,这个模块内部包含了三个状态,首先除了 nn.Module 自带的 self.training 成员,还额外给该模块添加了一个 self.is_freezed 的成员,用于指定是否bn 的参数真正 融合进了 conv 的参数中,下面具体解释三个状态代表什么意思:

状态一 ,self.self.is_freezed=False, self.training = True

此时是在对模型做量化训练,训练过程中首先对该模块的输入做模拟量化,然后通过一个卷积模块得到输出,接着计算输出的均值和方差,然后动态融合bn和conv的参数,的到新的折叠了bn的conv参数,然后对该参数做模拟量化,最后将模拟量化后的参数和输入重新输入到一个卷积操作中。

状态二,self.self.is_freezed=False, self.training = False

此时和状态一唯一不同的地方时,均值和方差用的是 moving_meanmoving_var,用于在训练过程中对验证集做测试。

状态三,self.self.is_freezed=True

给用户提供了一个 helper 函数 freeze_all_qat_submodules,调用之后会将所有类型是QatFuseConvBN 的子模块中的,bn参数和conv参数融合之后,直接替换conv掉参数,然后前向过程就和变为 输入模拟量化,权值模拟量化,然后调用卷积这三个操作,不会再包含bn,而且这个过程是不可逆的,只有在训练完,转Onnx之前才这么做。

Ldpe2G avatar Jul 26 '22 03:07 Ldpe2G