
Fix FP8 linear layer dimension check to prevent runtime error

Open ssam18 opened this pull request 2 months ago • 2 comments

Fixes #6390

Problem

When use_fp8=True is enabled in HybridParallelPlugin and the model has output layers with dimensions not divisible by 16 (e.g., binary classification with 2 outputs), the training fails with:

Expected both dimensions of mat2 to be divisible by 16 but got torch.Size([768, 2])

Root Cause

torch._scaled_mm requires both dimensions of the weight matrix to be divisible by 16. The existing check in linear_fp8() only validated:

  • Input dimension (input.shape[-1])
  • Batch dimensions (np.prod(input.shape[:-1]))

But it did not check the output dimension (weight.shape[0]).

When using GPT2ForSequenceClassification with num_labels=2, the score head is an nn.Linear with weight shape [2, 768]; torch._scaled_mm receives its transpose as the [768, 2] mat2 shown in the error above, and the output dimension 2 is not divisible by 16.
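For illustration, a quick standalone way to see the offending shape (a sketch using transformers directly, not the ColossalAI code path):

```python
from transformers import GPT2Config, GPT2ForSequenceClassification

# The classification head is nn.Linear(n_embd=768, num_labels=2, bias=False),
# so its weight has out_features=2 in the first dimension.
model = GPT2ForSequenceClassification(GPT2Config(num_labels=2))
print(model.score.weight.shape)               # torch.Size([2, 768])
print(model.score.weight.shape[0] % 16 == 0)  # False -> incompatible with torch._scaled_mm
```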

Solution

Added a check for weight.shape[0] % 16 != 0 so that linear_fp8() falls back to regular F.linear when the output dimension is not compatible with FP8.

if input.shape[-1] % 16 != 0 or np.prod(input.shape[:-1]) % 16 != 0 or weight.shape[0] % 16 != 0:
    return F.linear(input, weight, bias)
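For context, the same decision can be sketched as a standalone helper (hypothetical, not part of ColossalAI) to show which layers keep the FP8 path and which fall back:

```python
import numpy as np
import torch

def fp8_dims_ok(inp: torch.Tensor, weight: torch.Tensor) -> bool:
    # Mirrors the updated guard in linear_fp8(): every dimension that
    # torch._scaled_mm touches must be divisible by 16.
    return (
        inp.shape[-1] % 16 == 0
        and np.prod(inp.shape[:-1]) % 16 == 0
        and weight.shape[0] % 16 == 0          # newly added output-dimension check
    )

x = torch.randn(32, 768)          # 32 tokens of hidden size 768
proj_w = torch.randn(768, 768)    # transformer projection: FP8-compatible
score_w = torch.randn(2, 768)     # num_labels=2 head: not FP8-compatible

print(fp8_dims_ok(x, proj_w))     # True  -> torch._scaled_mm path
print(fp8_dims_ok(x, score_w))    # False -> F.linear fallback
```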

Testing

This fix allows the model to:

  • Use FP8 for layers with compatible dimensions (performance benefit)
  • Fall back to standard FP16/BF16 for incompatible layers (correctness)
  • Run successfully with small output dimensions (e.g., binary classification)

The change is backward compatible and doesn't affect existing working configurations.
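As a quick sanity check, something like the following could be run on an FP8-capable GPU (a sketch; the import path colossalai.quantization.fp8.linear_fp8 is assumed):

```python
import torch
from colossalai.quantization.fp8 import linear_fp8  # assumed import path

x = torch.randn(32, 768, device="cuda", dtype=torch.bfloat16)
w = torch.randn(2, 768, device="cuda", dtype=torch.bfloat16)   # num_labels=2 head

# Before the fix this raised "Expected both dimensions of mat2 to be divisible
# by 16"; with the added check it routes through F.linear instead.
out = linear_fp8(x, w, None)
print(out.shape)   # torch.Size([32, 2])
```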

ssam18 · Nov 12 '25 17:11

Can someone take a look at the PR?

ssam18 · Nov 30 '25 14:11

@ryanrussell @gothicx @tiansiyuan @jeffra Can someone take a look at this PR? I am happy to help and contribute to this repo!

SamareshSingh · Dec 01 '25 20:12