
Quantized checkpoint is not smaller / faster

Open sophia1488 opened this issue 4 years ago • 8 comments

Hi, first, thanks for your implementation!
It wasn't hard to apply your code to my current model. However, the dumped checkpoint is the same size as the original model's, and I wonder how to store it with less storage.

Thank you and hope to get your feedback.

sophia1488 avatar Mar 16 '22 10:03 sophia1488

Hi Sophia, thank you for using my code. The checkpoint is the same size because I dump the floating-point weights (rather than the quantized ones). Quantization runs on the fly, generating the quantized weights during the forward pass.

If you want to save the quantized weights, you can declare them as buffers (like the running statistics of Batch Norm), and then they will also be saved into the checkpoint (sketched below). If you do so, you will find both the floating-point and the quantized weights in your checkpoint; you can then remove the floating-point ones from it manually.
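A minimal sketch of that approach; here quan_w_fn stands in for the repo's weight quantizer, and activation quantization is omitted:

import torch
import torch.nn as nn
import torch.nn.functional as F

class QuanConv2d(nn.Conv2d):
    def __init__(self, *args, quan_w_fn=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.quan_w_fn = quan_w_fn
        # A buffer (like BatchNorm's running stats) is included in
        # state_dict() but is not a trainable parameter.
        self.register_buffer('quantized_weight', torch.zeros_like(self.weight))

    def forward(self, x):
        # Quantize on the fly; assigning the result to the registered buffer
        # means the latest quantized weights are dumped with the checkpoint.
        self.quantized_weight = self.quan_w_fn(self.weight)
        return F.conv2d(x, self.quantized_weight, self.bias,
                        self.stride, self.padding, self.dilation, self.groups)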

zhutmost avatar Mar 16 '22 16:03 zhutmost

I am not very familiar with deployment tools. I guess there are more convenient ways to generate a truly quantized model (I mean one whose stored weights are all integers; right now it is only fake-quantized) from the pretrained weights.
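For instance, PyTorch's own post-training dynamic quantization stores Linear weights as real int8 tensors; a generic illustration, not specific to lsq-net:

import torch
import torch.nn as nn

# A toy float model; quantize_dynamic converts the Linear layers so that
# their weights are stored as genuine int8 tensors in the checkpoint.
model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))
quantized = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
torch.save(quantized.state_dict(), 'dynamic_int8.pth')  # hypothetical path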

zhutmost avatar Mar 16 '22 16:03 zhutmost

Hi, thanks for your quick reply! I modified the code and it worked!

For people who want to save quantized weights: basically, the modification is that the LSQ module returns x after round_pass, together with s_scale, and I dump these two tensors instead of self.weight and quan_w_fn.s of QuanConv2d (see the sketch below).
https://github.com/zhutmost/lsq-net/blob/2c24a96be06d044fa4c7d651727f4574b8d88c86/quan/quantizer/lsq.py#L56
And I implemented functions for dumping & loading the quantized checkpoint.
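A self-contained sketch of the idea; grad_scale and the repo's per-channel handling are omitted for brevity:

import torch
import torch.nn as nn

def round_pass(x):
    # Straight-through estimator: round in the forward pass,
    # identity gradient in the backward pass.
    return (x.round() - x).detach() + x

class LsqQuanSketch(nn.Module):
    """Simplified LSQ weight quantizer returning (integer tensor, scale)."""
    def __init__(self, bit=8):
        super().__init__()
        self.thd_neg = -(2 ** (bit - 1))
        self.thd_pos = 2 ** (bit - 1) - 1
        self.s = nn.Parameter(torch.ones(1))  # learned step size

    def forward(self, x):
        x = x / self.s
        x = torch.clamp(x, self.thd_neg, self.thd_pos)
        x = round_pass(x)  # integer-valued, but still float dtype
        return x, self.s   # modified: was `return x * self.s`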

Finally, the saved checkpoint is reduced by more than 70%, and I could also load the checkpoint for evaluation / inference. 😀

Thanks again ~ (I could also share my implementation if needed.)

sophia1488 avatar Mar 18 '22 03:03 sophia1488

Great. Thank you, Sophia. I'm still struggling with my thesis. Maybe I will consider this feature in a few weeks (and also fix some known compatibility bugs).

zhutmost avatar Mar 18 '22 05:03 zhutmost

Actually, a newer version is maintained in NeuralZip, and you can find more experiment results there. But it is developed with the PyTorch Lightning framework rather than raw PyTorch.

zhutmost avatar Mar 18 '22 05:03 zhutmost

Hi @zhutmost, the quantized model's accuracy is good! But I found that inference after quantization is slower. Do you have any comments on this? Below are the inference times. (Note: the times do not include checkpoint loading, and I use the same code for all three runs.)

  • not loading any checkpoint: 21.555 s
  • loading the original model: 22.770 s
  • loading the quantized model: 35.845 s

Thanks! 🙏

sophia1488 avatar Mar 30 '22 10:03 sophia1488

I have no idea; I can't explain it without more details.

In my code, a validation/test epoch (which is essentially inference) is much faster than a training epoch.

zhutmost avatar Mar 30 '22 16:03 zhutmost

Hi, thanks for your quick reply. Below is the code I modified from your repo. Due to the characteristics of my model, I didn't modify the code related to QuanLinear.

When dumping the checkpoint:

state_dict = model.state_dict()
for name, module in model.named_modules():
    if isinstance(module, QuanConv2d):
        # 1. Save self.quantized_weight (cast to int8) and self.w_scale.
        state_dict[f'{name}.quantized_weight'] = module.quantized_weight.to(torch.int8)
        state_dict[f'{name}.w_scale'] = module.w_scale
        # 2. Remove quan_w_fn.s and weight; they are not needed for inference.
        state_dict.pop(f'{name}.quan_w_fn.s', None)
        state_dict.pop(f'{name}.weight', None)

When loading the checkpoint:

# Names of the modules whose quantized weights were saved.
quantized_weights = [k[: -len('.quantized_weight')]
                     for k in state_dict if k.endswith('.quantized_weight')]
for name in quantized_weights:
    # Reconstruct a float32 weight so that QuanConv2d.weight is filled in
    # when calling model.load_state_dict().
    state_dict[f'{name}.weight'] = state_dict[f'{name}.w_scale'] * state_dict[f'{name}.quantized_weight']
    state_dict.pop(f'{name}.w_scale', None)
    state_dict.pop(f'{name}.quantized_weight', None)
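Putting the two loops together (a sketch; the file name is hypothetical, and strict=False is needed because the quan_w_fn.s entries were dropped from the checkpoint):

import torch

# Dumping: run the first loop on state_dict = model.state_dict(), then:
torch.save(state_dict, 'quantized_ckpt.pth')

# Loading: read the checkpoint back, rebuild the float32 weights with the
# second loop, then load it (strict=False since quan_w_fn.s keys are gone):
state_dict = torch.load('quantized_ckpt.pth')
model.load_state_dict(state_dict, strict=False)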

In class QuanConv2d:

class QuanConv2d(nn.Conv2d):
    def __init__(self, ..., inference=False):
        ...
        self.inference = inference
        ...
    def forward(self, x):
        # Quantize the input (this runs in both modes).
        quantized_act, a_scale = self.quan_a_fn(x)
        act = quantized_act * a_scale
        # Quantize the weight.
        if not self.inference:
            self.quantized_weight, self.w_scale = self.quan_w_fn(self.weight)
            weight = self.quantized_weight * self.w_scale
        else:
            # The saved quantized weight, already dequantized to float32 at load time.
            weight = self.weight
        # Note: the convolution itself runs in float32 in both modes, so no
        # speedup is expected from this path alone.
        return self._conv_forward(act, weight)

I don't understand why it gets slower either :( Thank you.

sophia1488 avatar Mar 31 '22 06:03 sophia1488

Hi sophia, is your problem solved?

784582008 avatar Oct 14 '22 02:10 784582008

Honestly, it's still slow but I'm not dealing with this problem now. Thanks!

sophia1488 avatar Oct 14 '22 03:10 sophia1488

Hi, thanks for your quick reply. Do you have a better way now?

784582008 avatar Oct 14 '22 03:10 784582008

Hello, can you share this part of the code?

BYFgithub avatar Sep 02 '23 12:09 BYFgithub