Fix Quantization Aware Training for BEIT, Eva, and SwinTransformerV2
Dear all,
When trying to perform Quantization Aware Training (QAT), modules are being wrapped with a QuantWrapper.
But because some models implement the qkv projection through torch.nn.functional.linear (so that a custom bias can be passed), the forward pass accesses self.qkv.weight directly.
During QAT, self.qkv.weight is no longer reachable, since the QuantWrapper does not forward attribute access to the wrapped module, which results in the error below:
```
Traceback (most recent call last):
File "/workspace/main.py", line 127, in main
trainer.fit(model, data_module)
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 544, in fit
call._call_and_handle_interrupt(
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 580, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 989, in _run
results = self._run_stage()
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1035, in _run_stage
self.fit_loop.run()
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py", line 202, in run
self.advance()
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py", line 359, in advance
self.epoch_loop.run(self._data_fetcher)
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 136, in run
self.advance(data_fetcher)
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 240, in advance
batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 187, in run
self._optimizer_step(batch_idx, closure)
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 265, in _optimizer_step
call._call_lightning_module_hook(
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 157, in _call_lightning_module_hook
output = fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/core/module.py", line 1282, in optimizer_step
optimizer.step(closure=optimizer_closure)
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/core/optimizer.py", line 151, in step
step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 230, in optimizer_step
return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/amp.py", line 77, in optimizer_step
closure_result = closure()
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 140, in __call__
self._result = self.closure(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 126, in closure
step_output = self._step_fn()
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 315, in _training_step
training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 309, in _call_strategy_hook
output = fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 382, in training_step
return self.lightning_module.training_step(*args, **kwargs)
File "/workspace/fringuantai/model.py", line 611, in training_step
y_hat, confidence, y_hat_ctx = self(inputs, context=context)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/workspace/fringuantai/model.py", line 594, in forward
return self.model(x, context=context, target_context=target_context)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/workspace/fringuantai/model.py", line 486, in forward
features = self.backbone(x)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/timm/models/eva.py", line 580, in forward
x = self.forward_features(x)
File "/opt/conda/lib/python3.10/site-packages/timm/models/eva.py", line 568, in forward_features
x = blk(x, rope=rot_pos_embed)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/timm/models/eva.py", line 236, in forward
x = x + self.drop_path1(self.attn(self.norm1(x), rope=rope, attn_mask=attn_mask))
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/timm/models/eva.py", line 113, in forward
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1695, in __getattr__
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'QuantWrapper' object has no attribute 'weight'
```
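For reference, a minimal sketch of how this can be triggered with the eager-mode quantization utilities (the model name and qconfig below are only illustrative, not my exact training setup):

```python
import timm
import torch
from torch.ao.quantization import add_quant_dequant, get_default_qat_qconfig, propagate_qconfig_

# Illustrative sketch: wrap every leaf module that has a qconfig in a QuantWrapper,
# as an eager-mode QAT preparation step would.
model = timm.create_model('beit_base_patch16_224', pretrained=False)
model.qconfig = get_default_qat_qconfig('fbgemm')
propagate_qconfig_(model)           # push the qconfig down to every submodule
model = add_quant_dequant(model)    # the attention's qkv Linear is now hidden inside a QuantWrapper

model(torch.randn(1, 3, 224, 224))  # -> AttributeError: 'QuantWrapper' object has no attribute 'weight'
```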
The fix is pretty straightforward: just remove the direct accesses to self.qkv.weight and, where possible, add qkv_bias to the output of self.qkv(x) instead.
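In the attention forward, that would look roughly like this (a sketch of the relevant excerpt only; as far as I can tell, self.qkv is an nn.Linear created with bias=False and qkv_bias is concatenated from q_bias, zeros for k, and v_bias):

```python
# before: breaks once self.qkv has been wrapped, because the weight is read directly
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)

# after: call the module so wrappers stay transparent, then add the
# separately assembled bias (the Linear itself was created with bias=False)
qkv = self.qkv(x)
if qkv_bias is not None:
    qkv = qkv + qkv_bias
```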
@clementpoiret my concern with this is that it changes a kernel that's likely a fused addmm into two separate kernels, which isn't great. Aren't there more advanced quantization options these days that can handle nn.Module and functional calls equally well?
Hmm, I see. I can't really find a clean solution to that point. Actually, F.linear is quantized just fine (although it is slower than using nn.Linear in the current implementation).
Given the implementation of QuantWrapper, the only fix that comes to mind is to call F.linear(x, self.qkv.module.weight, qkv_bias) if self.qkv is an instance of QuantWrapper...
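i.e. something like this in the attention forward (sketch only):

```python
import torch.nn.functional as F
from torch.ao.quantization import QuantWrapper

# reach through the wrapper to the underlying Linear when QAT has wrapped it
qkv_weight = self.qkv.module.weight if isinstance(self.qkv, QuantWrapper) else self.qkv.weight
qkv = F.linear(input=x, weight=qkv_weight, bias=qkv_bias)
```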
But that puts a fix for a very specific use case into a model definition that shouldn't be more complex than necessary, so I'm not really happy with this solution. What do you think about it? Another option is simply to document the issue, so users know they have to override the forward pass if they want to quantize it.
@clementpoiret going to implement as per that linked PR above; it uses an attribute to determine which path to use, but I don't think I want to punch that all the way through the model/module init, so it's probably best to do:
```python
# enable the separate-bias path on every module that supports it
for n, m in model.named_modules():
    if hasattr(m, 'qkv_bias_separate'):
        setattr(m, 'qkv_bias_separate', True)
```
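The attention forward would then branch on that flag, roughly along these lines (just a sketch, not the actual diff):

```python
if self.qkv_bias_separate:
    # module call + separate bias add: stays transparent to QuantWrapper, but two kernels
    qkv = self.qkv(x)
    qkv = qkv + qkv_bias
else:
    # default: keep the fused addmm via F.linear
    qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
```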
Okay, thanks for the update @rwightman :)