Fix (graph/bias_correction): Fix when layer parameters are offloaded to `accelerate`
Currently, if a layer doesn't have a bias, `skip_if_no_bias=False`, and the parameters of the current module are offloaded with `accelerate`, applying bias correction fails with the following error:
Traceback (most recent call last):
File "/home/nfraser/workspace/optimum-amd/examples/quantization/brevitas/quantize_llm.py", line 161, in <module>
main(args)
File "/home/nfraser/workspace/optimum-amd/examples/quantization/brevitas/quantize_llm.py", line 65, in main
quantized_model = quantizer.quantize(qconfig, calibration_dataset)
File "/home/nfraser/.local/miniforge3/envs/20240516_oamd/lib/python3.9/site-packages/optimum/amd/brevitas/quantizer.py", line 244, in quantize
apply_bias_correction(
File "/home/nfraser/.local/miniforge3/envs/20240516_oamd/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/nfraser/.local/miniforge3/envs/20240516_oamd/lib/python3.9/site-packages/optimum/amd/brevitas/quantizer.py", line 337, in apply_bias_correction
model(**inps)
File "/home/nfraser/.local/miniforge3/envs/20240516_oamd/lib/python3.9/site-packages/brevitas/graph/calibrate.py", line 122, in __exit__
self.bias_correction.apply_correction(self.model)
File "/home/nfraser/.local/miniforge3/envs/20240516_oamd/lib/python3.9/site-packages/brevitas/graph/calibrate.py", line 292, in apply_correction
module.register_parameter(
File "/home/nfraser/.local/miniforge3/envs/20240516_oamd/lib/python3.9/site-packages/brevitas/nn/mixin/parameter.py", line 111, in register_parameter
super(QuantBiasMixin, self).register_parameter(name, value)
File "/home/nfraser/.local/miniforge3/envs/20240516_oamd/lib/python3.9/site-packages/brevitas/nn/mixin/parameter.py", line 81, in register_parameter
super(QuantWeightMixin, self).register_parameter(name, value)
File "/home/nfraser/.local/miniforge3/envs/20240516_oamd/lib/python3.9/site-packages/torch/nn/modules/module.py", line 557, in register_parameter
raise TypeError(f"cannot assign '{torch.typename(param)}' object to parameter '{name}' "
TypeError: cannot assign 'torch.meta.FloatTensor' object to parameter 'bias' (torch.nn.Parameter or None required)
This PR resolves this issue.
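For reference, the `TypeError` at the bottom of the traceback can be reproduced outside Brevitas with a minimal sketch. This is not the code changed by this PR. It assumes only that the value passed to `register_parameter` ends up as a plain tensor on the `meta` device; the layer sizes and the `correction` tensor are hypothetical stand-ins.

```python
import torch
from torch import nn

# Hypothetical stand-ins: a bias-free layer and a computed bias correction.
layer = nn.Linear(4, 2, bias=False)
correction = torch.zeros(2)

# Assumption: with offloading, the value reaches register_parameter as a
# plain tensor on the meta device rather than as an nn.Parameter.
meta_correction = correction.to("meta")

raised = False
message = ""
try:
    layer.register_parameter("bias", meta_correction)
except TypeError as err:
    # torch.nn.Module.register_parameter only accepts nn.Parameter or None.
    raised = True
    message = str(err)

# Wrapping the tensor in nn.Parameter is accepted.
layer.register_parameter("bias", nn.Parameter(correction))
output = layer(torch.ones(1, 4))
```

The sketch only shows that `register_parameter` rejects anything that is not an `nn.Parameter` or `None`; how the value is moved to the right device under `accelerate` is left to the actual change in this PR.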