
Quantized Checkpoints have Floating-Point Weights

Open IsidoraR opened this issue 3 years ago • 20 comments

🐛 Describe the bug

Hello,

I'm using the QuantTrainModule to train a MobileNetV2 model (using the MobileNetV2 class in this repo), and the quantized checkpoints have 32-bit floating-point weights (as shown in the attached text file). Shouldn't the quantized checkpoints have 16-bit or 8-bit weights?

Python Code for Checking Weights Datatype:

    import os
    import torch

    # MobileNetV2 is the model class from this repo (its import is not shown in the original post)
    models_path = '/home/iradovan/edgeai-torchvision/data/checkpoints/mobV2_torch_QAT/2022-01-27_12-49-15_mobV2_edgai_torch/saved_models'
    saved_model_name = 'MobileNetV2_checkpoint_quantized.pth'
    saved_model_path = os.path.join(models_path, saved_model_name)
    checkpoint = torch.load(saved_model_path)
    model = MobileNetV2()
    model.load_state_dict(checkpoint['state_dict'], strict=False)

    # Print the dtype of every tensor stored in the checkpoint
    quant_state_dict = checkpoint['state_dict']
    for param_quant in quant_state_dict:
        print(param_quant, "\t", quant_state_dict[param_quant].dtype)

Thank you, Isidora

MODEL_CKPT_DTYPE_INFO.txt

Versions

Collecting environment information...
PyTorch version: 1.10.0
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.3 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.8.12 (default, Oct 12 2021, 13:49:34) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.13.0-27-generic-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: 10.1.243
GPU models and configuration: GPU 0: Quadro RTX 3000
Nvidia driver version: 470.86
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.21.4
[pip3] onnx-pytorch==0.1.4
[pip3] onnx2pytorch==0.4.1
[pip3] torch==1.10.0+cu111
[pip3] torchaudio==0.10.0
[pip3] torchinfo==1.5.4
[pip3] torchsummary==1.5.1
[pip3] torchvision==0.11.0a0+9e3ecf2
[conda] _pytorch_select 0.1 cpu_0
[conda] _tflow_select 2.3.0 mkl
[conda] blas 1.0 mkl
[conda] cudatoolkit 10.2.89 hfd86e86_1
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py38h7f8727e_0
[conda] mkl_fft 1.3.1 py38hd3c417c_0
[conda] mkl_random 1.2.2 py38h51133e4_0
[conda] numpy 1.21.4 pypi_0 pypi
[conda] numpy-base 1.21.2 py38h79a1101_0
[conda] onnx-pytorch 0.1.4 pypi_0 pypi
[conda] onnx2pytorch 0.4.1 pypi_0 pypi
[conda] pytorch 1.10.0 py3.8_cuda10.2_cudnn7.6.5_0 pytorch
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] tensorflow 2.4.1 mkl_py38hb2083e0_0
[conda] tensorflow-base 2.4.1 mkl_py38h43e0292_0
[conda] torch 1.10.0+cu111 pypi_0 pypi
[conda] torchaudio 0.10.0 py38_cu102 pytorch
[conda] torchinfo 1.5.4 pypi_0 pypi
[conda] torchsummary 1.5.1 pypi_0 pypi
[conda] torchvision 0.11.1 pypi_0 pypi

IsidoraR avatar Jan 31 '22 21:01 IsidoraR

When doing QAT with xnn.quantize in edgeai-torchvision, the weights are discretized internally during training, and training compensates for the accuracy loss caused by that discretization. Although the weights saved in the checkpoint are floating point, they have already been adjusted for the error that will occur when they are quantized.

When TIDL reads the QAT model, it applies the same kind of discretization for quantization. Please try the QAT model in TIDL or OSRT (for example, ONNXRuntime with TIDL offload); it should give good accuracy.
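To illustrate why a QAT checkpoint can hold floating-point tensors and still be "quantized", here is a minimal pure-Python sketch of quantize-dequantize ("fake quantization"). It is a generic symmetric per-tensor scheme chosen for illustration, not the actual algorithm inside xnn.quantize or TIDL; the function name `fake_quantize` and the sample weights are hypothetical.

```python
def fake_quantize(weights, num_bits=8):
    """Snap each float to a signed integer grid, then map it back to float.

    Generic symmetric per-tensor scheme for illustration only; this is an
    assumption, not the exact method used by xnn.quantize or TIDL.
    """
    qmax = 2 ** (num_bits - 1) - 1            # 127 for 8 bits
    max_abs = max(abs(w) for w in weights)
    if max_abs == 0.0:
        return [float(w) for w in weights], 1.0
    scale = max_abs / qmax                    # float step between grid points
    result = []
    for w in weights:
        q = round(w / scale)                  # integer code
        q = max(-qmax - 1, min(qmax, q))      # clamp to the signed range
        result.append(q * scale)              # stored value is still a float
    return result, scale


weights = [0.731, -0.052, 0.0049, -1.27, 0.318]   # made-up sample values
fq, scale = fake_quantize(weights)
for original, snapped in zip(weights, fq):
    print(f"{original:+.4f} -> {snapped:+.4f}")
```

Every returned value is a Python float, yet each one is an integer multiple of `scale`, so a runtime that later applies the same kind of discretization would lose little or nothing. That is the sense in which float weights in the checkpoint are consistent with quantization-aware training.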

mathmanu avatar Feb 01 '22 05:02 mathmanu

When I try to compile the quantized MobileNetV2 ONNX model (using the code from custom-model-onnx Jupyter notebook) in the EdgeAI Cloud, why does the kernel die?

Code for Compiling ONNX Model:

    import os
    import onnxruntime as rt
    import tqdm

    calib_images = ['sample-images/Unit_test_image.pgm']
    output_dir = 'custom-artifacts/onnx/mobV2'
    onnx_model_path = 'onnx/MobileNetV2_checkpoint_quantized_2_best.onnx'

    compile_options = {
        'tidl_tools_path': os.environ['TIDL_TOOLS_PATH'],
        'artifacts_folder': output_dir,
        'tensor_bits': 8,
        'accuracy_level': 1,
        'advanced_options:calibration_frames': len(calib_images),
        'advanced_options:calibration_iterations': 3,  # used if accuracy_level = 1
    }

    # Create the artifacts folder and clear any previous contents
    os.makedirs(output_dir, exist_ok=True)
    for root, dirs, files in os.walk(output_dir, topdown=False):
        for f in files:
            os.remove(os.path.join(root, f))
        for d in dirs:
            os.rmdir(os.path.join(root, d))

    so = rt.SessionOptions()
    EP_list = ['TIDLCompilationProvider', 'CPUExecutionProvider']
    sess = rt.InferenceSession(onnx_model_path, providers=EP_list, provider_options=[compile_options, {}], sess_options=so)

    input_details = sess.get_inputs()

    # preprocess_for_onnx_deeptad is my own preprocessing function (not shown)
    for num in tqdm.trange(len(calib_images)):
        output = list(sess.run(None, {input_details[0].name: preprocess_for_onnx_deeptad(calib_images[num])}))[0]

IsidoraR avatar Feb 02 '22 14:02 IsidoraR

I do not know why this crash occurs, but try the following anyway; these settings are for getting the best accuracy.

For QAT models, the compile options for best accuracy should include:

    'accuracy_level': 0,
    'advanced_options:quantization_scale_type': 1

You can see more details here: https://github.com/TexasInstruments/edgeai-torchvision/blob/master/docs/pixel2pixel/Quantization.md

And here:
https://github.com/TexasInstruments/edgeai-benchmark/blob/master/jai_benchmark/config_settings.py#L143
https://github.com/TexasInstruments/edgeai-benchmark/blob/master/jai_benchmark/config_settings.py#L157
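One way to keep the QAT and non-QAT settings from getting mixed up is to build the options dictionary in a single place. The sketch below is only an illustration: the helper name `make_compile_options` and its parameters are hypothetical, and the values are taken from this thread (QAT: 'accuracy_level' 0 with 'quantization_scale_type' 1; floating-point model: 'accuracy_level' 1 with 'quantization_scale_type' 0 and 3 calibration iterations), not from the TIDL documentation.

```python
def make_compile_options(tidl_tools_path, artifacts_folder, is_qat,
                         calibration_frames=1):
    """Build TIDL compile options using the values quoted in this thread.

    Hypothetical helper for illustration; check the linked Quantization.md
    and config_settings.py for the authoritative settings.
    """
    options = {
        "tidl_tools_path": tidl_tools_path,
        "artifacts_folder": artifacts_folder,
        "tensor_bits": 8,
        "advanced_options:calibration_frames": calibration_frames,
    }
    if is_qat:
        # Settings recommended in this thread for QAT models
        options["accuracy_level"] = 0
        options["advanced_options:quantization_scale_type"] = 1
    else:
        # Settings used in this thread for the floating-point model
        options["accuracy_level"] = 1
        options["advanced_options:quantization_scale_type"] = 0
        options["advanced_options:calibration_iterations"] = 3
    return options


qat_options = make_compile_options("/path/to/tidl_tools", "custom-artifacts/onnx/mobV2", is_qat=True)
print(qat_options)
```

The resulting dictionary can be passed as the first entry of `provider_options` when creating the ONNX Runtime session, in the same way `compile_options` is used elsewhere in this thread.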

mathmanu avatar Feb 02 '22 15:02 mathmanu

I set the compile options to the values listed above, and the kernel still dies when I try to compile the quantized MobV2 model. I also tried compiling the trained floating-point MobV2 model (after setting 'accuracy_level' to 1 and 'advanced_options:quantization_scale_type' to 0), and the kernel also dies when I try to compile the floating-point MobV2 model.

Also, I see that in the 'EP_List', there are two providers: 'TIDLCompilationProvider' and 'CPUExecutionProvider'. Are both of these providers always required for an onnx runtime inference session?

IsidoraR avatar Feb 02 '22 16:02 IsidoraR

As a debugging step, you can try to run purely on ARM without TIDL Offload: https://github.com/TexasInstruments/edgeai-tidl-tools/blob/master/examples/osrt_python/ort/onnxrt_ep.py#L154 That will be slow, but let's see if the kernel crash is still there.

mathmanu avatar Feb 02 '22 16:02 mathmanu

I can compile and run the quantized MobV2 model on ARM without TIDL Offload. However, when I try to compile the same model with TIDL Offload, the kernel dies.

IsidoraR avatar Feb 02 '22 17:02 IsidoraR

Is it possible for you to share an ONNX model that doesn't work? Maybe the original floating-point model.

mathmanu avatar Feb 03 '22 06:02 mathmanu

When I try to attach the ONNX model to this message, I get an error that says "We don't support that file type". Could I send you an email with the ONNX model?

IsidoraR avatar Feb 03 '22 14:02 IsidoraR

Can you try zipping it before uploading it here? If that doesn't work, you can share it by email.

mathmanu avatar Feb 03 '22 14:02 mathmanu

Ok, I attached a zip file with the quantized MobV2 ONNX model to this message.

MobileNetV2_checkpoint_quantized_2_best.zip

IsidoraR avatar Feb 03 '22 14:02 IsidoraR

Can you attach the original floating-point model as well (not QAT)?

mathmanu avatar Feb 03 '22 14:02 mathmanu

Yes, the original floating point model is attached to this message. MobileNetV2_checkpoint_23_best.zip

IsidoraR avatar Feb 03 '22 14:02 IsidoraR

I also tried replacing the fully connected layer (1x11520) with a 2D convolution layer (1x1280x3x3), and this version of the model can run only on the ARM without TIDL Offload. The kernel also dies when I try to compile this version of the model with TIDL Offload. I've attached the modified ONNX model to this message. MobileNetV2_checkpoint_quantized_2_tidl.zip

IsidoraR avatar Feb 03 '22 16:02 IsidoraR

Please run shape inference on the ONNX model as shown here: https://github.com/TexasInstruments/edgeai-tidl-tools/blob/master/examples/osrt_python/common_utils.py#L168

After that, it will work.

Shape inference is now required to run ONNX models on our ONNX Runtime with TIDL offload.
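For readers who do not want to open the link, a minimal sketch of running shape inference with the `onnx` package follows. It uses `onnx.shape_inference.infer_shapes_path`, and writes the result to a separate file instead of overwriting the input. The helper names and the `_shapes` suffix are hypothetical choices for this example, and the linked common_utils.py may do this differently.

```python
import os


def shape_inferred_path(model_path):
    """Hypothetical naming helper: 'model.onnx' -> 'model_shapes.onnx'."""
    root, ext = os.path.splitext(model_path)
    return f"{root}_shapes{ext}"


def run_shape_inference(model_path):
    """Write a copy of the model with inferred tensor shapes; return its path."""
    # Imported lazily so the naming helper above works without onnx installed
    from onnx import shape_inference

    out_path = shape_inferred_path(model_path)
    shape_inference.infer_shapes_path(model_path, out_path)
    return out_path


print(shape_inferred_path("onnx/MobileNetV2_checkpoint_quantized_2_best.onnx"))
```

Calling `run_shape_inference(onnx_model_path)` and passing the returned path to `rt.InferenceSession` would then compile the shape-annotated copy of the model.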

mathmanu avatar Feb 04 '22 08:02 mathmanu

After using shape inference, I can compile and run both versions of the quantized MobV2 ONNX model with TIDL Offload.

IsidoraR avatar Feb 04 '22 15:02 IsidoraR

I'm trying to run the quantized MobV2 model with emulated TIDL offload locally on my computer, but I'm getting an error that the libvx_tidl_rt.so cannot be opened even though that file is in the TIDL tools folder, and I exported the TIDL tools path:

    (benchmark) ~/edgeai-torchvision$ echo $TIDL_TOOLS_PATH
    /home/iradovan/tidl_tools

    (benchmark) ~$ find /home/iradovan/tidl_tools/ -name libvx_tidl_rt.so
    /home/iradovan/tidl_tools/libvx_tidl_rt.so

    (benchmark) ~/edgeai-torchvision$ python ./references/edgeailite/scripts/eval_mobV2_onnx.py
    Error - libvx_tidl_rt.so: cannot open shared object file: No such file or directory
    python: tidl_onnxRtImport_EP.cpp:197: bool TIDL_populateOptions(std::vector<std::pair<std::__cxx11::basic_string, std::__cxx11::basic_string > >): Assertion `data->infer_ops.lib' failed.
    Aborted

The error is thrown by the last line of code (shown below), which initializes the inference session:

    compile_options = {
        'tidl_tools_path': os.environ['TIDL_TOOLS_PATH'],
        'artifacts_folder': output_dir,
        'tensor_bits': 8,
        'accuracy_level': 0,
        'advanced_options:quantization_scale_type': 1,
        'advanced_options:calibration_frames': 5,
        # 'advanced_options:calibration_iterations': 3  # used if accuracy_level = 1
    }
    so = rt.SessionOptions()
    EP_list = ['TIDLCompilationProvider', 'CPUExecutionProvider']
    sess = rt.InferenceSession(onnx_model_path, providers=EP_list, provider_options=[compile_options, {}], sess_options=so)

Should I use different compile options or different ONNX runtime session options if I want to run the quantized MobV2 model with emulated TIDL Offload on my computer?

IsidoraR avatar Feb 07 '22 16:02 IsidoraR

I copied the following three files from tidl_tools into edgeai-torchvision: libtidl_onnxrt_EP.so, libvx_tidl_rt.so, and libvx_tidl_rt.so.1.0. Now I can run the quantized MobV2 model with emulated TIDL Offload on my computer.

IsidoraR avatar Feb 07 '22 17:02 IsidoraR

You can try setting LD_LIBRARY_PATH to avoid that copy: https://github.com/TexasInstruments/edgeai-tidl-tools/blob/master/setup.sh#L115
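The dynamic loader generally reads LD_LIBRARY_PATH when the process starts, so changing it from inside an already running Python script is usually too late; it needs to be exported in the shell first. A small check like the sketch below can at least flag the problem before the inference session is created. The helper name `missing_from_loader_path` is hypothetical, and the check assumes the shared libraries sit directly in the TIDL tools folder, as described earlier in this thread.

```python
import os


def missing_from_loader_path(lib_dir, env=None):
    """Return True if lib_dir is not listed in LD_LIBRARY_PATH."""
    env = os.environ if env is None else env
    entries = [p for p in env.get("LD_LIBRARY_PATH", "").split(os.pathsep) if p]
    target = os.path.normpath(lib_dir)
    return all(os.path.normpath(p) != target for p in entries)


# Assumption: TIDL_TOOLS_PATH points at the folder holding libvx_tidl_rt.so
tools_path = os.environ.get("TIDL_TOOLS_PATH", "")
if tools_path and missing_from_loader_path(tools_path):
    print(f"Warning: {tools_path} is not on LD_LIBRARY_PATH; "
          "export it in the shell before starting Python.")
```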

mathmanu avatar Feb 07 '22 17:02 mathmanu

Yes, that resolves the error. I forgot to export the LD_LIBRARY_PATH.

IsidoraR avatar Feb 07 '22 20:02 IsidoraR

> When doing QAT with xnn.quantize in edgeai-torchvision, the weights are discretized internally during training, and training compensates for the accuracy loss caused by that discretization. Although the weights saved in the checkpoint are floating point, they have already been adjusted for the error that will occur when they are quantized.
>
> When TIDL reads the QAT model, it applies the same kind of discretization for quantization. Please try the QAT model in TIDL or OSRT (for example, ONNXRuntime with TIDL offload); it should give good accuracy.

If I want to convert the floating-point model (pt or onnx) to an int8 model after doing QAT with xnn.quantize, how can I do that?

lambdayin avatar Jul 29 '22 11:07 lambdayin