QuaRot

Precision Loss in bf16 Model with float32 Rotation Calculation

Open · Niko-zyf opened this issue 8 months ago · 2 comments

Description:

I am seeing a significant accuracy drop when running the QuaRot algorithm on a device that only supports float32 computation. The rotations were originally computed in double precision; on this device they have to be cast to float32. This causes a substantial accuracy loss in the bf16 model, most visibly in Pass1, which drops by 10%. Interestingly, Pass8 remains mostly unchanged.
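The description hinges on casting the rotations from double precision to float32. A quick way to see what that cast does to the rotation matrix itself is to measure how far Q·Qᵀ drifts from the identity. The sketch below is pure Python and assumes a randomized Hadamard rotation (random ±1 signs times a Sylvester Hadamard matrix scaled by 1/sqrt(n)). The helpers `to_f32`, `randomized_hadamard`, and `orthogonality_error` are hypothetical and not part of the QuaRot code. Each entry is rounded to float32 with `struct`, and every dot product is then summed exactly with `math.fsum`, so the measurement adds no error of its own.

```python
import math
import random
import struct


def to_f32(x: float) -> float:
    """Round a Python float (IEEE float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def randomized_hadamard(n: int, seed: int = 0) -> list[list[float]]:
    """Hypothetical helper: diag(random signs) @ Sylvester Hadamard / sqrt(n).

    Sylvester construction: H[i][j] = (-1)^popcount(i & j), valid when n is a power of two.
    """
    assert n > 0 and n & (n - 1) == 0, "n must be a power of two"
    rng = random.Random(seed)
    signs = [rng.choice((-1.0, 1.0)) for _ in range(n)]
    scale = 1.0 / math.sqrt(n)
    return [
        [signs[i] * scale * (-1.0 if bin(i & j).count("1") % 2 else 1.0) for j in range(n)]
        for i in range(n)
    ]


def orthogonality_error(q: list[list[float]]) -> float:
    """max over (i, k) of |(Q Q^T)_ik - delta_ik|, each dot product summed exactly."""
    n = len(q)
    worst = 0.0
    for i in range(n):
        for k in range(n):
            dot = math.fsum(q[i][j] * q[k][j] for j in range(n))
            worst = max(worst, abs(dot - (1.0 if i == k else 0.0)))
    return worst


if __name__ == "__main__":
    q64 = randomized_hadamard(32)
    q32 = [[to_f32(x) for x in row] for row in q64]  # the rotation cast to float32
    print(f"float64 rotation: max |QQ^T - I| = {orthogonality_error(q64):.3e}")
    print(f"float32 rotation: max |QQ^T - I| = {orthogonality_error(q32):.3e}")
```

For n = 32, 1/sqrt(n) is irrational, so float32 rounding perturbs every entry by a relative amount of at most 2⁻²⁴. The float32 matrix is therefore orthogonal only to roughly float32 machine epsilon. For n a power of four (16, 64, 4096, ...), 1/sqrt(n) is a power of two, so this particular matrix is exactly representable in float32. In that case, any loss would have to come from the arithmetic that applies or fuses the rotation rather than from storing it.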

Expected Behavior:

I expected the accuracy of the bf16 model to be less affected, assuming float32 would provide sufficient numerical stability despite the reduction from double precision.

Observed Behavior:

Pass1 accuracy decreases by 10%.

Pass8 accuracy remains nearly unchanged.

Steps to Reproduce:

Apply the QuaRot algorithm to a bf16 model with the rotation computations done in float32.

Observe the accuracy changes across metrics, notably Pass1 and Pass8.
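The reproduction step can also be isolated numerically, without running the full evaluation. The sketch below makes two assumptions: the rotation is fused into a weight matrix as W' = W·Q (the actual side and transpose depend on the layer), and the fused result is cast back to bf16. It computes W·Q once in float64 and once with every product and partial sum rounded to float32. It then casts both results to bf16 with round-to-nearest-even and counts how many entries differ. All helper names are hypothetical and the sizes are toy values. The sequential float32 accumulation is only an emulation, since real GPU kernels accumulate in different (blocked, FMA-based) orders.

```python
import math
import random
import struct


def to_f32(x: float) -> float:
    """Round a Python float (IEEE float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round to float32, then to bfloat16 with round-to-nearest-even (finite inputs only)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # add just under half an ulp; ties go to even
    bits &= 0xFFFF0000                   # keep sign, exponent, and 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def randomized_hadamard(n: int, rng: random.Random) -> list[list[float]]:
    """Hypothetical helper: diag(random signs) @ Sylvester Hadamard / sqrt(n), n a power of two."""
    signs = [rng.choice((-1.0, 1.0)) for _ in range(n)]
    scale = 1.0 / math.sqrt(n)
    return [[signs[i] * scale * (-1.0 if bin(i & j).count("1") % 2 else 1.0)
             for j in range(n)] for i in range(n)]


def matmul(a, b, rnd):
    """a @ b, passing every product and every partial sum through `rnd`."""
    out = []
    for row in a:
        out_row = []
        for j in range(len(b[0])):
            acc = 0.0
            for k, a_k in enumerate(row):
                acc = rnd(acc + rnd(a_k * b[k][j]))
            out_row.append(acc)
        out.append(out_row)
    return out


def count_bf16_mismatches(rows: int, n: int, seed: int = 0) -> tuple[int, int]:
    """Fuse a rotation into bf16 weights in float64 vs emulated float32; compare in bf16."""
    rng = random.Random(seed)
    w = [[to_bf16(rng.gauss(0.0, 0.02)) for _ in range(n)] for _ in range(rows)]  # bf16 weights
    q64 = randomized_hadamard(n, rng)
    q32 = [[to_f32(x) for x in r] for r in q64]
    ref = matmul(w, q64, lambda x: x)  # float64 reference path
    low = matmul(w, q32, to_f32)       # float32-only path
    mismatches = sum(
        to_bf16(x) != to_bf16(y) for rr, rl in zip(ref, low) for x, y in zip(rr, rl)
    )
    return mismatches, rows * n


if __name__ == "__main__":
    m, total = count_bf16_mismatches(rows=8, n=128)
    print(f"{m} of {total} fused bf16 weights differ between float64 and float32 fusion")
```

A nonzero count means the float32-only path already yields a slightly different bf16 model before any quantization. The sketch measures the size of that effect. It does not show whether such differences explain the long-context degradation.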

Is this behavior typical? Are there any known solutions?

Niko-zyf · Apr 15 '25 04:04

Thanks @Niko-zyf for your issue.

I'm not sure I understood your issue correctly. As I remember, we ran the rotation in FP32 and it did not change the results much. Could you please share the commands to reproduce your results? Also, could you please define Pass1 and Pass8? It's a bit strange that an 8-bit case would have higher accuracy than FP16.

sashkboos · Apr 16 '25 08:04

@sashkboos Thanks for following up. To clarify:

Precision Observations

Perplexity (PPL) metrics show minimal impact, as expected.

Critical precision loss manifests in long-context generation (32k tokens), leading to irreversible degradation where the model generates nonsensical outputs.

Evaluation Methodology

For mathematical tasks: 8 temperature-sampled attempts per prompt.

Pass8: probability of at least one correct answer in 8 attempts.

Pass1: average accuracy across all 8 attempts.

Original Double-Precision Rationale

My question about the double-precision implementation stems from wanting to understand:

Whether the initial design specifically accounted for error accumulation over long sequences

How critical double precision was for preserving the properties of the rotation matrices (e.g., orthogonality)

Whether there is inherent precision headroom in the original implementation that float32 eliminates

The core paradox appears to be:

❗️ Standard metrics (PPL) don't reveal the issue

❗️ Generation quality collapses specifically in long-context scenarios

❗️ Pass1/Pass8 divergence suggests precision affects result consistency more than best-case performance
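Using the Pass1 and Pass8 definitions from the Evaluation Methodology above, a small sketch shows how Pass1 can fall while Pass8 stays flat. If degradation removes some correct samples per prompt but leaves at least one wherever there was one, only Pass1 moves. The per-prompt results below are made-up illustrative data, not measurements from this issue, and `attempts` is a hypothetical helper.

```python
def pass_at_1(results: list[list[bool]]) -> float:
    """Average accuracy across all attempts of a prompt, averaged over prompts."""
    return sum(sum(r) / len(r) for r in results) / len(results)


def pass_at_8(results: list[list[bool]]) -> float:
    """Fraction of prompts with at least one correct answer among the 8 attempts."""
    assert all(len(r) == 8 for r in results)
    return sum(any(r) for r in results) / len(results)


def attempts(correct: int, total: int = 8) -> list[bool]:
    """Hypothetical helper: `correct` successes out of `total` sampled attempts."""
    return [True] * correct + [False] * (total - correct)


# Illustrative, made-up data: number of correct answers per prompt out of 8 samples.
baseline = [attempts(c) for c in (8, 6, 2, 0)]
degraded = [attempts(c) for c in (6, 4, 1, 0)]  # fewer hits, but no prompt drops to zero

for name, res in (("baseline", baseline), ("degraded", degraded)):
    print(f"{name}: pass@1={pass_at_1(res):.5f}  pass@8={pass_at_8(res):.5f}")
# baseline: pass@1=0.50000  pass@8=0.75000
# degraded: pass@1=0.34375  pass@8=0.75000
```

With exactly 8 samples per prompt, these empirical definitions coincide with the common unbiased pass@k estimator for k = 1 and k = 8. Pass1 is therefore a measure of per-sample reliability, while Pass8 only registers a change once a prompt loses its last correct sample.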

Would appreciate insights into:

Historical design decisions around numerical precision

Any known thresholds for error accumulation in rotation operations

Potential mitigation strategies for float32 constraints
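On the mitigation question, one generic float32-only technique is compensated (Kahan) summation. Neither this thread nor QuaRot is known to use it. Kahan summation carries the rounding error of each addition forward, so small contributions are not lost during long accumulations. The sketch emulates float32 arithmetic in pure Python, and the helper names are hypothetical. It rests on one assumption: that accumulation error, rather than the final cast to bf16, is what matters here.

```python
import math
import struct


def to_f32(x: float) -> float:
    """Round a Python float (IEEE float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def naive_sum_f32(values) -> float:
    """Left-to-right accumulation with every addition rounded to float32."""
    acc = 0.0
    for v in values:
        acc = to_f32(acc + to_f32(v))
    return acc


def kahan_sum_f32(values) -> float:
    """Kahan compensated summation with every operation rounded to float32."""
    total = 0.0
    comp = 0.0  # running compensation: low-order bits lost by previous additions
    for v in values:
        y = to_f32(to_f32(v) - comp)          # re-inject what was lost last time
        t = to_f32(total + y)                 # this addition may drop low bits of y
        comp = to_f32(to_f32(t - total) - y)  # recover exactly what was dropped
        total = t
    return total


# One large value followed by many values equal to half a float32 ulp of 1.0.
values = [1.0] + [2.0**-24] * 1000
exact = math.fsum(values)
print(f"exact (float64): {exact!r}")
print(f"naive float32:   {naive_sum_f32(values)!r}")
print(f"Kahan float32:   {kahan_sum_f32(values)!r}")
```

Naive accumulation stays at 1.0 because every addition is a tie that rounds back to the even value. The compensated sum recovers the small contributions. The cost is roughly four floating-point operations per element instead of one.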

Niko-zyf · Apr 16 '25 08:04