Precision Loss in bf16 Model with float32 Rotation Calculation
Description:
I am seeing a significant accuracy drop when running the QuaRot algorithm on a device limited to float32 arithmetic. The rotation computation was originally designed for double precision (float64), so on this device the rotations are cast to float32. This substantially reduces the accuracy of the bf16 model, most noticeably in Pass1, which drops by 10%. Interestingly, Pass8 remains mostly unchanged.
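One way to quantify what the float32 cast does to the rotation matrix itself is to measure how far Q^T Q is from the identity. The sketch below assumes NumPy and uses a Haar-random orthogonal matrix from a QR decomposition as a stand-in for QuaRot's actual rotations (which, as far as I know, are mainly randomized Hadamard matrices). The size n = 1024 and the helper names `random_orthogonal` and `orthogonality_error` are hypothetical. It prints the orthogonality error of the float64 matrix and of its float32 cast, next to the bf16 unit roundoff (2^-8) for scale.

```python
import numpy as np


def random_orthogonal(n: int, seed: int = 0) -> np.ndarray:
    """Haar-random orthogonal matrix in float64 (stand-in for QuaRot's rotation)."""
    rng = np.random.default_rng(seed)
    a = rng.standard_normal((n, n))
    q, r = np.linalg.qr(a)
    # Flip column signs so the distribution is uniform (Haar) over orthogonal matrices.
    return q * np.sign(np.diag(r))


def orthogonality_error(q: np.ndarray) -> float:
    """Max absolute entry of Q^T Q - I, evaluated in float64 to isolate the cast error."""
    q64 = q.astype(np.float64)
    return float(np.max(np.abs(q64.T @ q64 - np.eye(q.shape[0]))))


q64 = random_orthogonal(1024)
err64 = orthogonality_error(q64)
err32 = orthogonality_error(q64.astype(np.float32))

print(f"float64 rotation orthogonality error: {err64:.2e}")
print(f"float32 rotation orthogonality error: {err32:.2e}")
print(f"bf16 unit roundoff (2^-8):            {2.0 ** -8:.2e}")
```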
Expected Behavior:
I expected the accuracy of the bf16 model to be less affected, assuming float32 would provide enough numerical stability despite the reduction from double precision.
Observed Behavior:
Pass1 accuracy decreases by 10%.
Pass8 accuracy remains nearly unchanged.
Steps to Reproduce:
Apply the QuaRot algorithm to a bf16 model with the rotations computed in float32.
Observe the accuracy changes across the evaluation metrics, notably Pass1 and Pass8.
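A minimal NumPy sketch of the core numerical step in this setup: fusing a rotation into a bf16 weight matrix in float64 versus entirely in float32, rounding both results to bf16, and counting how many bf16 weights differ. It is not QuaRot's pipeline: the weight matrix is random, the rotation is a QR-based stand-in, and bf16 is emulated by round-to-nearest-even on the top 16 bits of a float32, since NumPy has no bfloat16 dtype. The helpers `to_bf16`, `random_orthogonal` and `fused_bf16_mismatch` are hypothetical names.

```python
import numpy as np


def to_bf16(x: np.ndarray) -> np.ndarray:
    """Round values to bfloat16 (round-to-nearest-even), returned as float32.

    Emulated by keeping the top 16 bits of each float32; NaN/Inf handling is
    ignored for brevity. float64 inputs are first rounded to float32.
    """
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    # Bias of 0x7FFF plus the lowest kept bit implements ties-to-even.
    rounding_bias = np.uint32(0x7FFF) + ((bits >> np.uint32(16)) & np.uint32(1))
    rounded = (bits + rounding_bias) & np.uint32(0xFFFF0000)
    return rounded.view(np.float32)


def random_orthogonal(n: int, rng: np.random.Generator) -> np.ndarray:
    """Haar-random orthogonal matrix in float64 (stand-in for the real rotation)."""
    q, r = np.linalg.qr(rng.standard_normal((n, n)))
    return q * np.sign(np.diag(r))


def fused_bf16_mismatch(n: int = 512, seed: int = 0) -> float:
    """Fraction of bf16 weights that differ between float64 and float32 fusion."""
    rng = np.random.default_rng(seed)
    # Hypothetical bf16 weights (random values, not a real checkpoint).
    w = to_bf16(rng.standard_normal((n, n)).astype(np.float32) * 0.02)
    q64 = random_orthogonal(n, rng)
    # Fuse in double precision, then round to bf16 (via float32; double rounding ignored).
    fused_from_64 = to_bf16((w.astype(np.float64) @ q64).astype(np.float32))
    # Fuse the same rotation entirely in float32, then round to bf16.
    fused_from_32 = to_bf16(w @ q64.astype(np.float32))
    return float(np.mean(fused_from_64 != fused_from_32))


mismatch = fused_bf16_mismatch()
print(f"fraction of bf16 weights that differ: {mismatch:.2e}")
```

Varying `n` and `seed` gives a rough feel for how often the float32 fusion path lands on a different bf16 value than the float64 path for a single layer.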
Is this behavior expected? Are there any known solutions?
Thanks @Niko-zyf for opening this issue.
I am not sure I understood your issue correctly. As I remember, we did run the rotation in FP32, and it did not change the results much. Can you please provide the commands to run so we can reproduce your results? Also, can you please define Pass1 and Pass8? It's a bit strange that an 8-bit case has higher accuracy than FP16.
@sashkboos Thanks for following up. To clarify:
Precision Observations
Perplexity (PPL) metrics show minimal impact, as expected.
Critical precision loss manifests in long-context generation (32k tokens), leading to irreversible degradation where the model generates nonsensical outputs.
Evaluation Methodology
For mathematical tasks: 8 temperature-sampled attempts per prompt.
Pass8: probability of at least one correct answer in 8 attempts.
Pass1: average accuracy across all 8 attempts.
Original Double-Precision Rationale
My question about the double-precision implementation stems from wanting to understand:
Whether the initial design specifically accounted for error accumulation over long sequences
How critical double precision was for preserving the rotation matrix's properties (e.g., orthogonality)
Whether the original implementation has inherent precision headroom that float32 eliminates
The core paradox appears to be:
❗️ Standard metrics (PPL) don't reveal the issue
❗️ Generation quality collapses specifically in long-context scenarios
❗️ Pass1/Pass8 divergence suggests precision affects result consistency more than best-case performance
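To illustrate how the two metrics, as defined in the evaluation methodology above (8 sampled attempts per prompt; Pass1 = average accuracy over the attempts, Pass8 = at least one correct attempt), can diverge: in the hypothetical outcomes below, every prompt loses one correct attempt, so Pass1 drops while Pass8 stays at 1.0. The data and the helper names `pass1` and `pass8` are illustrative only, not results from this issue.

```python
from typing import List


def pass1(results: List[List[bool]]) -> float:
    """Average per-prompt accuracy over all sampled attempts."""
    return sum(sum(r) / len(r) for r in results) / len(results)


def pass8(results: List[List[bool]]) -> float:
    """Fraction of prompts with at least one correct attempt out of 8."""
    return sum(any(r) for r in results) / len(results)


# Hypothetical outcomes: 4 prompts x 8 attempts (True = correct answer).
baseline = [
    [True] * 6 + [False] * 2,
    [True] * 5 + [False] * 3,
    [True] * 4 + [False] * 4,
    [True] * 3 + [False] * 5,
]
# Same prompts after a hypothetical precision change: one fewer correct
# attempt per prompt, but each prompt still succeeds at least once.
degraded = [
    [True] * 5 + [False] * 3,
    [True] * 4 + [False] * 4,
    [True] * 3 + [False] * 5,
    [True] * 2 + [False] * 6,
]

print(f"baseline: Pass1={pass1(baseline):.4f}, Pass8={pass8(baseline):.2f}")
print(f"degraded: Pass1={pass1(degraded):.4f}, Pass8={pass8(degraded):.2f}")
```

Because Pass8 only asks whether any attempt succeeds, it saturates for prompts that are solved reliably, whereas Pass1 reflects every individual attempt.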
Would appreciate insights into:
Historical design decisions around numerical precision
Any known thresholds for error accumulation in rotation operations
Potential mitigation strategies under float32 constraints
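On the mitigation question, one generic technique (not something this thread or QuaRot is known to provide) is to emulate roughly twice float32 precision with error-free transformations: Knuth's TwoSum and Dekker's TwoProduct, combined as in the compensated dot product (Dot2) of Ogita, Rump and Oishi. The sketch below assumes NumPy, uses only float32 arithmetic, and assumes each float32 operation is rounded exactly once (no extended-precision intermediates). The helper names are hypothetical. Applied to every row/column pair of a rotation-fusion matmul it would be much slower than a plain float32 matmul, so it may be most useful as a check of whether the fusion error matters at all.

```python
import math

import numpy as np


def two_sum(a, b):
    """Error-free addition (Knuth): a + b == s + e exactly, all in float32."""
    s = a + b
    bb = s - a
    e = (a - (s - bb)) + (b - bb)
    return s, e


def split(a):
    """Veltkamp split of float32 values into high and low 12-bit halves."""
    c = np.float32(4097.0) * a  # 4097 = 2**12 + 1 for float32's 24-bit significand
    hi = c - (c - a)
    return hi, a - hi


def two_product(a, b):
    """Error-free multiplication (Dekker): a * b == p + e exactly, all in float32."""
    p = a * b
    ah, al = split(a)
    bh, bl = split(b)
    e = al * bl - (((p - ah * bh) - al * bh) - ah * bl)
    return p, e


def dot2_float32(x, y):
    """Compensated dot product (Dot2-style) using only float32 operations."""
    x = np.asarray(x, dtype=np.float32)
    y = np.asarray(y, dtype=np.float32)
    prods, prod_errs = two_product(x, y)  # vectorized error-free products
    p, s = prods[0], prod_errs[0]
    for h, r in zip(prods[1:], prod_errs[1:]):
        p, q = two_sum(p, h)
        s = s + (q + r)  # accumulate the rounding errors separately
    return p + s


rng = np.random.default_rng(0)
x = rng.standard_normal(1000).astype(np.float32)
y = rng.standard_normal(1000).astype(np.float32)

# Reference: float32 * float32 products are exact in float64; fsum rounds the sum correctly.
ref = math.fsum(float(a) * float(b) for a, b in zip(x, y))

naive = np.float32(0.0)
for a, b in zip(x, y):
    naive = naive + a * b  # plain sequential float32 accumulation

print(f"naive float32 error: {abs(float(naive) - ref):.2e}")
print(f"dot2 float32 error:  {abs(float(dot2_float32(x, y)) - ref):.2e}")
```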