Is there a way to speed up the collect_stats() method in PyTorch quantization?
I am calibrating a model for TensorRT inference on a server using pytorch-quantization. My calibration set is quite large (more than 1 million samples), and collect_stats is too slow for fast experimentation. Is there a way to speed this method up?
I'm using pytorch-quantization v2.1.0.
Environment
TensorRT Version: 8.0.2
NVIDIA GPU: Titan X Pascal
NVIDIA Driver Version: 470
CUDA Version: 11.4
CUDNN Version: N/A
Operating System: Ubuntu 18
Python Version (if applicable): N/A
TensorFlow Version (if applicable): N/A
PyTorch Version (if applicable): 1.9.1
Baremetal or Container (if so, version): N/A
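For context, when a max-style calibrator is used, the statistic gathered during calibration is essentially a running maximum of |x| over every batch fed through the model, so the cost grows linearly with the number of calibration batches. The NumPy sketch below illustrates that reduction under stated assumptions: `collect_amax` and its `num_batches` cap are hypothetical names for illustration, not pytorch-quantization APIs, and whether a capped subset yields a good enough amax for a particular model is something to verify experimentally.

```python
import numpy as np


def collect_amax(batches, num_batches=None):
    """Hypothetical running max-abs calibration over an iterable of batches.

    If num_batches is given, stop after that many batches; this bounds the
    calibration cost at the price of seeing less data.
    """
    amax = 0.0
    for i, batch in enumerate(batches):
        if num_batches is not None and i >= num_batches:
            break  # cap reached: skip the remaining calibration data
        # Per-batch statistic: largest absolute activation value.
        amax = max(amax, float(np.max(np.abs(batch))))
    return amax


batches = [np.array([1.0, -3.0]), np.array([2.0]), np.array([10.0])]
print(collect_amax(batches))                 # all batches
print(collect_amax(batches, num_batches=2))  # only the first two batches
```

Because the per-batch statistic is a simple max, stopping early never increases amax; it can only miss outliers that appear in later batches.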
@ttyio ^ ^
@gj-raza Currently, quantization is implemented as a sequence of PyTorch ops, and this can be accelerated by using a CUDA extension. I will create an internal feature request for this. Thanks!
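To illustrate why a sequence of separate ops is slower than a fused extension, here is a NumPy analogy of symmetric fake quantization written op by op. Each step materializes a new intermediate array, much as each PyTorch op launches its own kernel and writes its own temporary tensor; a fused CUDA kernel can do all the steps in a single pass over memory. This is a sketch, not pytorch-quantization's implementation: the function name is hypothetical, and the narrow symmetric range [-127, 127] is an assumption that may differ from the library's defaults.

```python
import numpy as np


def fake_quant_unfused(x, amax, num_bits=8):
    """Hypothetical symmetric fake quantization as separate element-wise ops."""
    bound = 2 ** (num_bits - 1) - 1  # 127 for 8 bits (narrow symmetric range, assumed)
    scale = bound / amax             # maps [-amax, amax] onto [-bound, bound]
    scaled = x * scale                         # op 1: scale (new array)
    rounded = np.round(scaled)                 # op 2: round to nearest, ties to even (new array)
    clamped = np.clip(rounded, -bound, bound)  # op 3: saturate out-of-range values (new array)
    return clamped / scale                     # op 4: dequantize back to float (new array)


x = np.array([0.0, 0.5, -1.0, 2.0], dtype=np.float32)
print(fake_quant_unfused(x, amax=1.0))
```

Values beyond amax (2.0 here) are clipped to amax, and in-range values land on the nearest of the 2 * 127 + 1 representable levels. A fused kernel would compute the same result without the three intermediate arrays.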
We will enable the CUDA extension by default in the next release. Closing this issue, thanks!