Fix FP8 linear layer dimension check to prevent runtime error
Fixes #6390
Problem
When use_fp8=True is enabled in HybridParallelPlugin and the model has output layers with dimensions not divisible by 16 (e.g., binary classification with 2 outputs), the training fails with:
Expected both dimensions of mat2 to be divisible by 16 but got torch.Size([768, 2])
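For context, a minimal sketch of the configuration that hits this error (distributed launch and training-loop boilerplate omitted; plugin arguments other than use_fp8 are illustrative assumptions, not taken from the failing script):

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from transformers import GPT2Config, GPT2ForSequenceClassification

# A 2-label classification head: GPT2's score layer becomes Linear(768 -> 2).
model = GPT2ForSequenceClassification(GPT2Config(num_labels=2))

# Enable FP8 linear layers through the plugin; tp/pp sizes and precision are
# placeholder values for this sketch.
plugin = HybridParallelPlugin(tp_size=1, pp_size=1, precision="bf16", use_fp8=True)
booster = Booster(plugin=plugin)
# After booster.boost(...), the first forward pass through the score layer
# fails inside linear_fp8 with the torch._scaled_mm shape error quoted above.
```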
Root Cause
torch._scaled_mm requires both dimensions of the weight matrix to be divisible by 16. The existing check in linear_fp8() only validated:
- Input dimension (input.shape[-1])
- Batch dimensions (np.prod(input.shape[:-1]))
But it did not check the output dimension (weight.shape[0]).
When using GPT2ForSequenceClassification with num_labels=2, the score layer's weight has shape [2, 768] (which torch._scaled_mm sees as a [768, 2] mat2), and 2 is not divisible by 16.
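The shape is easy to confirm directly (a small check, assuming the default GPT2 hidden size of 768):

```python
from transformers import GPT2Config, GPT2ForSequenceClassification

model = GPT2ForSequenceClassification(GPT2Config(num_labels=2))
print(model.score.weight.shape)  # torch.Size([2, 768]); out_features=2 is not a multiple of 16
```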
Solution
Added a check for weight.shape[0] % 16 != 0 to fall back to regular F.linear when the output dimension is not compatible with FP8:
if input.shape[-1] % 16 != 0 or np.prod(input.shape[:-1]) % 16 != 0 or weight.shape[0] % 16 != 0:
    return F.linear(input, weight, bias)
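A quick sanity check of the updated condition using shape arithmetic only (the helper name below is made up for illustration; it simply mirrors the guard above):

```python
import numpy as np

def falls_back_to_f_linear(input_shape, weight_shape):
    # Mirrors the updated guard in linear_fp8: any dimension that
    # torch._scaled_mm cannot handle sends the call to F.linear instead.
    return (
        input_shape[-1] % 16 != 0
        or np.prod(input_shape[:-1]) % 16 != 0
        or weight_shape[0] % 16 != 0
    )

assert not falls_back_to_f_linear((32, 128, 768), (768, 768))  # stays on the FP8 path
assert falls_back_to_f_linear((32, 128, 768), (2, 768))        # score head now falls back
```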
Testing
This fix allows the model to:
- Use FP8 for layers with compatible dimensions (performance benefit)
- Fallback to standard FP16/BF16 for incompatible layers (correctness)
- Run successfully with small output dimensions (e.g., binary classification)
The change is backward compatible and doesn't affect existing working configurations.
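A hedged test sketch for the fallback path (the import location of linear_fp8 is an assumption about the repo layout and may need adjusting):

```python
import torch
import torch.nn.functional as F

from colossalai.quantization.fp8 import linear_fp8  # assumed module path


def test_linear_fp8_falls_back_for_small_out_dim():
    # out_features=2 is not divisible by 16, so the new check should route the
    # call through F.linear and the outputs should match exactly.
    x = torch.randn(32, 768, dtype=torch.bfloat16, device="cuda")
    w = torch.randn(2, 768, dtype=torch.bfloat16, device="cuda")
    torch.testing.assert_close(linear_fp8(x, w, None), F.linear(x, w))
```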
@ryanrussell @gothicx @tiansiyuan @jeffra Could someone take a look at this PR? I am happy to help and contribute to this repo!