brevitas
Example: How to use merge_bn correctly
There is an architecture I would like to quantise and retrain from its floating point counterpart. I would like to incorporate the merge_bn operation supported by Brevitas. How exactly would I do this here? An overview is good, but some code would be better. Note I only want to merge/fuse the Conv + BN + ReLU components. Here is my architecture:
class QuantizedModel(nn.Module):
    def __init__(self, config):
        super(QuantizedModel, self).__init__()
        self.weight_config = config
        k = 1
        self.quant_inp = qnn.QuantIdentity(
            bit_width=16, return_quant_tensor=True)
        self.conv1 = qnn.QuantConv2d(in_channels=3, out_channels=int(k * 128), kernel_size=3, padding=1, weight_bit_width=self.weight_config[0], return_quant_tensor=True, bias=True)
        self.bn1 = nn.BatchNorm2d(int(k * 128))
        self.relu1 = qnn.QuantReLU(bit_width=self.weight_config[0], return_quant_tensor=True)
        self.conv2 = qnn.QuantConv2d(int(k * 128), int(k * 128), kernel_size=3, padding=1, weight_bit_width=self.weight_config[1], return_quant_tensor=True, bias=True)
        self.bn2 = nn.BatchNorm2d(int(k * 128))
        self.relu2 = qnn.QuantReLU(bit_width=self.weight_config[1], return_quant_tensor=True)
        self.max_pool1 = qnn.QuantMaxPool2d(kernel_size=2, stride=2, return_quant_tensor=True)
        self.conv3 = qnn.QuantConv2d(int(k * 128), int(k * 256), kernel_size=3, padding=1, weight_bit_width=self.weight_config[2], return_quant_tensor=True, bias=True)
        self.bn3 = nn.BatchNorm2d(int(k * 256))
        self.relu3 = qnn.QuantReLU(bit_width=self.weight_config[2], return_quant_tensor=True)
        self.conv4 = qnn.QuantConv2d(int(k * 256), int(k * 256), kernel_size=3, padding=1, weight_bit_width=self.weight_config[3], return_quant_tensor=True, bias=True)
        self.bn4 = nn.BatchNorm2d(int(k * 256))
        self.relu4 = qnn.QuantReLU(bit_width=self.weight_config[3], return_quant_tensor=True)
        self.max_pool2 = qnn.QuantMaxPool2d(kernel_size=2, stride=2, return_quant_tensor=True)
        self.conv5 = qnn.QuantConv2d(int(k * 256), int(k * 512), kernel_size=3, padding=1, weight_bit_width=self.weight_config[4], return_quant_tensor=True, bias=True)
        self.bn5 = nn.BatchNorm2d(int(k * 512))
        self.relu5 = qnn.QuantReLU(bit_width=self.weight_config[4], return_quant_tensor=True)
        self.conv6 = qnn.QuantConv2d(int(k * 512), int(k * 512), kernel_size=3, padding=1, weight_bit_width=self.weight_config[5], return_quant_tensor=True, bias=True)
        self.bn6 = nn.BatchNorm2d(int(k * 512))
        self.relu6 = qnn.QuantReLU(bit_width=self.weight_config[5], return_quant_tensor=True)
        self.max_pool3 = qnn.QuantMaxPool2d(kernel_size=2, stride=2, return_quant_tensor=True)
        input_feats = 8192
        self.fc1 = qnn.QuantLinear(input_feats, int(k * 1024), weight_bit_width=self.weight_config[6], return_quant_tensor=True, bias=True)
        self.fc2 = qnn.QuantLinear(int(k * 1024), 10, weight_bit_width=self.weight_config[7], bias=True)

    def forward(self, x):
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.max_pool1(out)
        out = self.relu3(self.bn3(self.conv3(out)))
        out = self.relu4(self.bn4(self.conv4(out)))
        out = self.max_pool2(out)
        out = self.relu5(self.bn5(self.conv5(out)))
        out = self.relu6(self.bn6(self.conv6(out)))
        out = self.max_pool3(out)
        out = out.reshape(out.shape[0], -1)
        out = self.fc1(out)
        out = self.fc2(out)
        return out
Hi @g12bftd, as far as I understand, merging batch normalization layers is usually a post-training optimization, so I would train the model and then create a script that defines two instances of the same model: one with the batch norm layers and one without. Then I would loop over the model with batch norm, merging each conv & batch norm pair, and save the result into the new model (the one with no batch norm layers), so the code should look roughly as follows:
bn_model = QuantizedModel(bn=True)
model = QuantizedModel(bn=False)
for l in bn_model.modules():
    if isinstance(l, qnn.QuantConv2d):
        merge_bn(l, next_layer)  # next_layer: the BatchNorm2d that follows l (placeholder)
        model[index_of_corresponding_layer].load_state_dict(l.state_dict())  # placeholder index
torch.save(model.state_dict(), 'fused_QuantizedModel.pth')
Also, I would recommend defining the model using nn.Sequential to make it easier to loop over the model. A more concrete sketch for the architecture you posted is given below.
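For the QuantizedModel you posted, a rough sketch of the fusion pass could look like the following. It assumes the conv1/bn1 ... conv6/bn6 attribute names from your post, uses merge_bn from brevitas.nn.utils, and treats the checkpoint paths as placeholders rather than real files:
import torch
import torch.nn as nn
from brevitas.nn.utils import merge_bn

# config: the per-layer bit-width list from the original post (assumption).
bn_model = QuantizedModel(config)
bn_model.load_state_dict(torch.load('trained_QuantizedModel.pth'))  # hypothetical checkpoint path
bn_model.eval()  # fuse with the running statistics, not batch statistics

# (conv, bn) attribute pairs in the architecture above.
pairs = [('conv1', 'bn1'), ('conv2', 'bn2'), ('conv3', 'bn3'),
         ('conv4', 'bn4'), ('conv5', 'bn5'), ('conv6', 'bn6')]

for conv_name, bn_name in pairs:
    conv = getattr(bn_model, conv_name)
    bn = getattr(bn_model, bn_name)
    merge_bn(conv, bn)  # folds the BN statistics and affine parameters into conv
    # Replace the fused BN with an identity so the existing forward() still runs.
    setattr(bn_model, bn_name, nn.Identity())

torch.save(bn_model.state_dict(), 'fused_QuantizedModel.pth')
The nn.Identity() swap is just a convenience so the original forward() stays valid; you could instead define a second, BN-free model and copy the fused conv state dicts into it, as described above.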
@MohamedA95 I am new to Brevitas, so is it that we need to train with the classical BN layers first? If you could elaborate for a newbie such as me, that would help; all the models that I am trying to export need to have intermediate BN layers.
Hi @wilfredkisku, what do you mean by classical BN layers? Do you mean torch.nn.BatchNorm2d?
If that is what you mean, then if your model requires it you will have to use it. AFAIK Brevitas does not have a quant BN yet, but you can use torch BN and then fuse it with the previous conv, since both layers are linear transformations; check this link.
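To see why the fusion is possible: at inference time BN is just a per-channel affine transform, so it can be absorbed into the preceding conv's weight and bias. A minimal sketch of that arithmetic (a hypothetical helper written against plain torch.nn layers, not the Brevitas API):
import torch
import torch.nn as nn

def fold_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> None:
    # At inference BN computes y = gamma * (x - mean) / sqrt(var + eps) + beta,
    # which folds into the conv as w' = w * s and b' = (b - mean) * s + beta,
    # with s = gamma / sqrt(var + eps) applied per output channel.
    with torch.no_grad():
        s = bn.weight / torch.sqrt(bn.running_var + bn.eps)
        conv.weight.mul_(s.view(-1, 1, 1, 1))
        if conv.bias is None:
            conv.bias = nn.Parameter(torch.zeros_like(bn.running_mean))
        conv.bias.copy_((conv.bias - bn.running_mean) * s + bn.bias)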
@MohamedA95 Thank you for the reply. Yes, the models that I am using require torch.nn.BatchNorm2d layers. Can they also be fused with quantized layers? Thanks again.
Yes, they can be fused; Brevitas even has a function to do it under brevitas.nn.utils.
@MohamedA95 Thanks for all the help. I have been able to understand the idea behind fusing the layers. What I have done now is create two similar models, one with CONV + BN and the other without BN.
###########################
#### MODEL 1 ##############
###########################
from torch.nn import Module
import torch.nn.functional as F
import torch.nn as nn
import brevitas.nn as qnn
from brevitas.quant import Int8Bias as BiasQuant
class QuantWeightActLeNet(Module):
    def __init__(self):
        super(QuantWeightActLeNet, self).__init__()
        self.quant_inp = qnn.QuantIdentity(bit_width=4)
        self.conv1 = qnn.QuantConv2d(3, 6, 5, bias=True, weight_bit_width=4)
        self.relu1 = qnn.QuantReLU(bit_width=4)
        self.bn = nn.BatchNorm2d(6)

    def forward(self, x):
        out = self.quant_inp(x)
        out = self.relu1(self.bn(self.conv1(out)))
        return out
###########################
#### MODEL 2 ##############
###########################
class QuantWeightActLeNet_wo(Module):
    def __init__(self):
        super(QuantWeightActLeNet_wo, self).__init__()
        self.quant_inp = qnn.QuantIdentity(bit_width=4)
        self.conv1 = qnn.QuantConv2d(3, 6, 5, bias=True, weight_bit_width=4)
        self.relu1 = qnn.QuantReLU(bit_width=4)

    def forward(self, x):
        out = self.quant_inp(x)
        out = self.relu1(self.conv1(out))
        return out
quant_weight_act_lenet_wo = QuantWeightActLeNet_wo()
quant_weight_act_lenet = QuantWeightActLeNet()
I am using the merge_bn function to merge the CONV and BN layers:
#######################
###### MERGE ##########
#######################
# Helper imports, assuming the same modules Brevitas' own merge_bn relies on
# (brevitas.nn.utils and brevitas.proxy.parameter_quant; paths may differ across versions).
from torch.nn import Parameter
from brevitas.nn.utils import mul_add_from_bn, compute_channel_view_shape
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector, BiasQuantProxyFromInjector

def merge_bn(layer, bn, output_channel_dim=0):
    out = mul_add_from_bn(
        bn_mean=bn.running_mean,
        bn_var=bn.running_var,
        bn_eps=bn.eps,
        bn_weight=bn.weight.data.clone(),
        bn_bias=bn.bias.data.clone())
    mul_factor, add_factor = out
    # compute the shape of the channel
    out_ch_weight_shape = compute_channel_view_shape(layer.weight, output_channel_dim)
    # in-place operations multiply the layer weights with the mul_factor of the BN
    # without making a new copy of the Tensor
    layer.weight.data.mul_(mul_factor.view(out_ch_weight_shape))
    # handle if -> bias = True
    if layer.bias is not None:
        out_ch_bias_shape = compute_channel_view_shape(layer.bias, channel_dim=0)
        layer.bias.data.mul_(mul_factor.view(out_ch_bias_shape))
        layer.bias.data.add_(add_factor.view(out_ch_bias_shape))
    else:
        layer.bias = Parameter(add_factor)
    # re-initialize the quantization proxies so they pick up the fused parameters
    if (hasattr(layer, 'weight_quant') and
            isinstance(layer.weight_quant, WeightQuantProxyFromInjector)):
        layer.weight_quant.init_tensor_quant()
    if (hasattr(layer, 'bias_quant') and isinstance(layer.bias_quant, BiasQuantProxyFromInjector)):
        layer.bias_quant.init_tensor_quant()
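With those helpers in scope, the function can be applied to the Model 1 instance defined above, for example (assuming the model has already been trained so the BN running statistics are meaningful):
quant_weight_act_lenet.eval()  # fold using the running statistics
merge_bn(quant_weight_act_lenet.conv1, quant_weight_act_lenet.bn)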
But I am having issues while copying the trained weights and biases plus the additional quantization parameters that are present in the quantized layers such as QuantConv2d. I use concise code like the one below to build a dictionary of weights containing only the CONV entries and skipping BN (which has been fused with the CONV earlier):
for keys in pretrained_dict.keys():
    if keys.split('.')[0] != 'bn':
        processed_dict[keys] = pretrained_dict[keys]
quant_weight_act_lenet_wo.load_state_dict(processed_dict, strict=False)
I am able to copy the weights, but a few parameters associated with the Brevitas quantization library do not get copied. The error is given below:
_IncompatibleKeys(missing_keys=['quant_inp.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value', 'relu1.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value'], unexpected_keys=[])
I would be thankful for any help in this regard. Thanks again.
Hi @wilfredkisku,
I am not sure about your way of copying the state dict; I would do something like the following:
1- Define the two models, one with the batch norm and one without.
2- Loop over the model with BN, fusing each BN with its preceding conv.
3- Loop over the model without batch norm, copying the state dict from the other model (see the sketch after the line below), e.g.:
quant_weight_act_lenet_wo.conv1.load_state_dict(quant_weight_act_lenet.conv1.state_dict())
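Putting the three steps together for your LeNet example, a rough sketch (assuming merge_bn from brevitas.nn.utils and the two model instances you defined, with the BN model already trained) could be:
from brevitas.nn.utils import merge_bn

quant_weight_act_lenet.eval()  # use the running statistics for the fold

# Step 2: fold the BN statistics and affine parameters into the preceding conv.
merge_bn(quant_weight_act_lenet.conv1, quant_weight_act_lenet.bn)

# Step 3: copy the state module by module, so the activation quantizers'
# learned scaling parameters are carried over as well (this should cover the
# missing_keys for quant_inp and relu1 reported above).
quant_weight_act_lenet_wo.quant_inp.load_state_dict(quant_weight_act_lenet.quant_inp.state_dict())
quant_weight_act_lenet_wo.conv1.load_state_dict(quant_weight_act_lenet.conv1.state_dict())
quant_weight_act_lenet_wo.relu1.load_state_dict(quant_weight_act_lenet.relu1.state_dict())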
Hi @wilfredkisku, what do you mean by classical BN layers? Do you mean torch.nn.BatchNorm2d?
If that is what you mean, then if your model requires it you will have to use it. AFAIK Brevitas does not have a quant BN yet, but you can use torch BN and then fuse it with the previous conv, since both layers are linear transformations; check this link.
Hi, this reply mentioned that we are able to train fixed-point batch norm using BatchNorm2dToQuantScaleBias with power-of-two scale factors. I think that means quant BN is supported?
However, it is not clear to me how this is done for batch norm. Do the scale and bias change during training, or is it indeed a post-training transformation of the batch norm?