RoMa

Keep getting "Input type and weight type should be the same" error after quantization

Open skill-diver opened this issue 1 year ago • 6 comments

Hi,

I keep getting this error after quantizing the model:

Traceback (most recent call last):
  File "/home/my/roma_quant/experiments/quant.py", line 145, in <module>
    test_mega1500(model, "quant")
  File "/home/my/roma_quant/experiments/quant.py", line 49, in test_mega1500
    mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/my/roma_quant/roma/benchmarks/megadepth_pose_estimation_benchmark.py", line 73, in benchmark
    dense_matches, dense_certainty = model.match(
                                     ^^^^^^^^^^^^
  File "/opt/miniconda3/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/my/roma_quant/roma/models/matcher.py", line 704, in match
    corresps = self.forward_symmetric(batch)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/my/roma_quant/roma/models/matcher.py", line 559, in forward_symmetric
    corresps, dec_timer = self.decoder(f_q_pyramid,
                          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/my/roma_quant/roma/models/matcher.py", line 370, in forward
    f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/lib/python3.11/site-packages/torch/nn/modules/container.py", line 219, in forward
    input = module(input)
            ^^^^^^^^^^^^^
  File "/opt/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/lib/python3.11/site-packages/optimum/quanto/nn/qconv2d.py", line 55, in forward
    return self._conv_forward(input, self.qweight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 454, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/lib/python3.11/site-packages/optimum/quanto/tensor/weights/qbytes.py", line 272, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/lib/python3.11/site-packages/optimum/quanto/tensor/weights/qbytes.py", line 324, in __torch_dispatch__
    return qfallback(op, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/lib/python3.11/site-packages/optimum/quanto/tensor/qtensor.py", line 29, in qfallback
    return callable(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/lib/python3.11/site-packages/torch/_ops.py", line 1061, in __call__
    return self._op(*args, **(kwargs or {}))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same

Do you know what kind of setting I could add to avoid this issue?

skill-diver • Dec 15 '24 07:12

I would run a debugger and set a breakpoint near the error. From what I remember, there is an explicit .float() call around there, which may cause issues if you are not autocasting. Do you have autocast on or off?

Parskatt • Dec 15 '24 07:12
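The dtype clash described above can be reproduced with a plain PyTorch layer. This is an illustrative sketch, not RoMa code: it assumes a CPU-only setup, with bfloat16 standing in for the CUDA HalfTensor from the traceback. Calling a float32 nn.Conv2d directly on a low-precision input raises a RuntimeError, while the same call inside torch.autocast casts input and weight to a common dtype first.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# A plain float32 conv layer, standing in for one projection layer of the decoder.
conv = nn.Conv2d(3, 8, kernel_size=3)

# Low-precision input. bfloat16 on CPU stands in for the CUDA HalfTensor
# in the traceback so the sketch runs without a GPU.
x = torch.randn(1, 3, 16, 16, dtype=torch.bfloat16)


def run_without_autocast():
    """Call the conv directly: the input and weight dtypes disagree."""
    try:
        conv(x)
    except RuntimeError as err:
        return str(err)  # exact wording varies by backend
    return None


def run_with_autocast():
    """Under autocast, conv2d casts input and weight to the autocast dtype."""
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        return conv(x)


error_message = run_without_autocast()
y = run_with_autocast()
print(error_message is not None)  # True
print(y.dtype)                    # torch.bfloat16
```

A quantized layer may not get the same treatment from autocast, which is one plausible reason the error only appears after quantization; the breakpoint suggested above is the way to confirm which tensor arrives with which dtype.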

Yes, I found that turning off autocast in matcher.py and setting torch.float32 when defining the model solves this annoying error.

skill-diver • Dec 15 '24 09:12

I mean you should delete the "with autocast" line in front of the code block in matcher.py.

skill-diver • Dec 15 '24 10:12
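Instead of deleting the "with autocast" line outright, the same effect can be had by passing enabled=False to torch.autocast, which turns the context into a no-op. The sketch below is a hypothetical stand-in for the decoder step, not RoMa's actual class: the names ProjDecoderSketch, amp and amp_dtype are illustrative, and it assumes CPU with bfloat16 as the low-precision dtype.

```python
import torch
import torch.nn as nn


class ProjDecoderSketch(nn.Module):
    """Hypothetical decoder step that wraps its projection in autocast.
    `amp` and `amp_dtype` are illustrative names, not RoMa's API."""

    def __init__(self, amp: bool, amp_dtype: torch.dtype = torch.bfloat16):
        super().__init__()
        self.amp = amp
        self.amp_dtype = amp_dtype
        self.proj = nn.Sequential(nn.Conv2d(4, 4, kernel_size=1), nn.ReLU())

    def forward(self, f: torch.Tensor) -> torch.Tensor:
        # enabled=False behaves like deleting the `with` line,
        # but keeps mixed precision switchable from one flag.
        with torch.autocast(device_type=f.device.type, dtype=self.amp_dtype, enabled=self.amp):
            return self.proj(f)


f = torch.randn(2, 4, 8, 8)  # float32 features
# .float() mirrors "set torch.float32 when defining the model".
fp32_model = ProjDecoderSketch(amp=False).float()
amp_model = ProjDecoderSketch(amp=True)

print(fp32_model(f).dtype)  # torch.float32
print(amp_model(f).dtype)   # torch.bfloat16
```

Keeping a flag rather than removing the line lets the float32 path used with the quantized model and the AMP path coexist in one codebase.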

Do you know how to keep AMP enabled when doing quantization?

skill-diver • Dec 24 '24 06:12

In general, I don't think AMP plays particularly well with explicit casts inside the autocast region. What exactly are you trying to quantize? Weights as fp16?

Parskatt • Dec 24 '24 09:12
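One concrete way explicit casts and AMP interact: inside an autocast region, an explicit .float() does not force float32 for ops on autocast's low-precision list, because autocast recasts the inputs anyway. Forcing float32 takes a nested autocast(enabled=False) plus the cast, which is the pattern PyTorch's AMP documentation describes. A minimal sketch, assuming CPU with bfloat16:

```python
import torch

a = torch.randn(4, 4)
b = torch.randn(4, 4)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    # The explicit .float() casts are overridden: autocast still runs mm in bfloat16.
    overridden = torch.mm(a.float(), b.float())

    # To really get float32 here, disable autocast locally and cast the inputs.
    with torch.autocast(device_type="cpu", enabled=False):
        forced_fp32 = torch.mm(a.float(), b.float())

print(overridden.dtype)   # torch.bfloat16
print(forced_fp32.dtype)  # torch.float32
```

Code that relies on explicit casts therefore behaves differently depending on whether autocast is active, which is why mixing it with other dtype-changing machinery can be fragile.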

I am trying to do int8 quantization with the quanto library (https://github.com/huggingface/optimum-quanto), but I get this error at inference time.

skill-diver • Dec 24 '24 19:12
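The traceback shows the quantized weight reaching F.conv2d as a FloatTensor while the input is a HalfTensor. The sketch below illustrates that mechanism with a hand-rolled int8 weight-only scheme in plain PyTorch. It is not quanto's implementation: quantize_int8_per_channel and dequantize are hypothetical helpers, the scheme is symmetric per-output-channel quantization, and bfloat16 on CPU again stands in for fp16 on CUDA. Dequantizing with a float32 scale reproduces the mismatch; dequantizing to the input's dtype avoids it.

```python
import torch
import torch.nn.functional as F


def quantize_int8_per_channel(w: torch.Tensor):
    """Hypothetical helper: symmetric per-output-channel int8 quantization.
    Returns int8 values and a float32 scale per output channel."""
    max_abs = w.abs().amax(dim=(1, 2, 3), keepdim=True).clamp(min=1e-8)
    scale = (max_abs / 127.0).float()
    q = torch.round(w / scale).clamp(-127, 127).to(torch.int8)
    return q, scale


def dequantize(q: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Hypothetical helper: rebuild a floating-point weight in `dtype`."""
    return q.to(dtype) * scale.to(dtype)


torch.manual_seed(0)
w = torch.randn(8, 3, 3, 3)  # float32 conv weight
q, scale = quantize_int8_per_channel(w)
x = torch.randn(1, 3, 16, 16, dtype=torch.bfloat16)  # low-precision input

# Dequantizing to the scale's float32 dtype reproduces the mismatch.
try:
    F.conv2d(x, dequantize(q, scale, torch.float32))
    mismatch = False
except RuntimeError:
    mismatch = True

# Dequantizing to the input's dtype lets the conv run in low precision.
y = F.conv2d(x, dequantize(q, scale, x.dtype))
print(mismatch, y.dtype)  # True torch.bfloat16
```

With a real library the equivalent fix is to make the dtype of the dequantized weights and the dtype of the activations agree, either by running everything in float32 as found earlier in this thread or by checking in the quanto documentation how the weight dtype is chosen.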