I think CGA is not applied to `w_q @ w_k` when using StatsQ with CGA.
Hi, I have a quick question about the StatsQuantizer_specific_4_qkreparam_cga implementation. https://github.com/nbasyl/OFQ/blob/7ed37d1dd33d39395edbf49fcbbc52f678ecf961/src/quantization/quantizer/statsq.py#L191C1-L192C1
The final line of the forward pass is `quan_weights = quan_weights_no_grad.detach() - real_weights.detach() + real_weights`.
However, inside the `if self.training` branch the chain is `b4_round -> af_round -> quan_weights_no_grad -> .detach()`, so no gradient exists with respect to `b4_round`. In other words, CGA is computed in the forward pass but never applied in the backward pass; the backward gradient is just the STE flowing through `real_weights`.
This appears to implement a Straight-Through Estimator (STE) for the entire block. Doesn't this mean the CGA logic from earlier in the function is effectively bypassed and never applied to the gradient of the `w_q @ w_k` tensor? A minimal sketch of the gradient flow is below.
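Here is a small standalone sketch of what I mean (hypothetical tensor names, not the repo's code): the `a.detach() - b.detach() + b` pattern routes the entire gradient through `b`, so whatever was done on the detached path cannot influence the backward pass.

```python
import torch

# Hypothetical stand-ins (not the repo's code):
# `real_weights` plays the role of real_weights, `cga_adjusted` the role of the
# CGA-processed, rounded weights (quan_weights_no_grad).
real_weights = torch.randn(4, 4, requires_grad=True)
cga_adjusted = (real_weights * 2.0).round()  # placeholder for the CGA + rounding path

# Same pattern as the final line of forward():
quan_weights = cga_adjusted.detach() - real_weights.detach() + real_weights

quan_weights.sum().backward()
# The gradient is the identity w.r.t. real_weights (pure STE); nothing from the
# CGA/rounding branch contributes, because that branch was detached.
print(real_weights.grad)  # all ones
```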
Thanks in advance for any clarification.