Hi,
Author, I keep getting this error after quantizing the model:
Traceback (most recent call last):
File "/home/my/roma_quant/experiments/quant.py", line 145, in
test_mega1500(model, "quant")
File "/home/my/roma_quant/experiments/quant.py", line 49, in test_mega1500
mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/my/roma_quant/roma/benchmarks/megadepth_pose_estimation_benchmark.py", line 73, in benchmark
dense_matches, dense_certainty = model.match(
^^^^^^^^^^^^
File "/opt/miniconda3/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/my/roma_quant/roma/models/matcher.py", line 704, in match
corresps = self.forward_symmetric(batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/my/roma_quant/roma/models/matcher.py", line 559, in forward_symmetric
corresps, dec_timer = self.decoder(f_q_pyramid,
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/my/roma_quant/roma/models/matcher.py", line 370, in forward
f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/lib/python3.11/site-packages/torch/nn/modules/container.py", line 219, in forward
input = module(input)
^^^^^^^^^^^^^
File "/opt/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
result = forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/lib/python3.11/site-packages/optimum/quanto/nn/qconv2d.py", line 55, in forward
return self._conv_forward(input, self.qweight, self.bias)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 454, in _conv_forward
return F.conv2d(input, weight, bias, self.stride,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/lib/python3.11/site-packages/optimum/quanto/tensor/weights/qbytes.py", line 272, in torch_function
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/lib/python3.11/site-packages/optimum/quanto/tensor/weights/qbytes.py", line 324, in torch_dispatch
return qfallback(op, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/lib/python3.11/site-packages/optimum/quanto/tensor/qtensor.py", line 29, in qfallback
return callable(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/lib/python3.11/site-packages/torch/ops.py", line 1061, in call
return self._op(*args, **(kwargs or {}))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same
Do you know what kind of setting I could add to avoid this issue?
I would run a debugger and set a breakpoint around the error. From what I remember there is a .float() call around there which may cause issues if you are not autocasting. Do you have autocast on or off?
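For reference, here is a minimal sketch of the mismatch itself (nothing RoMa-specific, just plain torch): a conv whose weight stayed in fp32 fails on an fp16 input outside autocast, but succeeds inside it, which is why the autocast setting matters here.

```python
import torch
import torch.nn.functional as F

# Minimal sketch (not from the repo): fp32 weight + fp16 input on CUDA.
if torch.cuda.is_available():
    x = torch.randn(1, 3, 8, 8, device="cuda", dtype=torch.float16)  # half input
    w = torch.randn(4, 3, 3, 3, device="cuda", dtype=torch.float32)  # float weight
    try:
        F.conv2d(x, w)
    except RuntimeError as err:
        print(err)  # Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) ...
    with torch.autocast("cuda"):
        F.conv2d(x, w)  # autocast casts both operands to fp16, so this succeeds
```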
Yes, I found that turning off autocast in matcher.py and setting torch.float32 when defining the model solves this annoying error.
I mean you should delete the "with autocast" wrapping that code block in matcher.py.
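Something along these lines, a rough sketch with illustrative names (match_fp32, batch), not the exact code in matcher.py:

```python
import torch

def match_fp32(model, batch):
    """Rough sketch of the workaround described above: define the model in fp32,
    keep the inputs in fp32, and do NOT wrap the forward pass in torch.autocast,
    so the quantized convs' fp32 fallback weights and their inputs match."""
    batch = {k: (v.float() if torch.is_tensor(v) else v) for k, v in batch.items()}
    with torch.no_grad():  # plain no_grad here, instead of "with autocast" in matcher.py
        return model.eval()(batch)
```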
Do you know how to enable AMP when doing quantization?
In general I think AMP doesn't play particularly well with explicit casts inside. What are you trying to quantize exactly? Weights as fp16?
I am trying to do int8 quantization using the quanto library, but I get this error at inference time.
https://github.com/huggingface/optimum-quanto
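For reference, the basic weight-only int8 flow with quanto looks roughly like this (a minimal sketch of the library's quantize/freeze calls, not my exact script):

```python
from optimum.quanto import quantize, freeze, qint8

def quantize_int8(model):
    """Weight-only int8 quantization with optimum-quanto's public API;
    the RoMa-specific setup is omitted here."""
    quantize(model, weights=qint8)  # replace Linear/Conv2d modules with quantized versions
    freeze(model)                   # materialize the int8 weights
    return model.eval()
```

With the model defined in fp32 and no autocast around the forward pass, the fallback conv2d then sees matching dtypes.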