torchchat
x86 CPU: BF16 should improve decoding performance relative to FP32 on x86, even without hardware BF16
🚀 The feature, motivation and pitch
As you might expect given that decoding is memory-bandwidth-bound, bf16 is substantially faster than fp32 on my M1 Mac (roughly 4x for next-token decoding in this run). Command: `python torchchat.py generate llama3.2-1b --device cpu --dtype <as specified>`
mac, dtype bf16:
Average tokens/sec (total): 16.42
Average tokens/sec (first token): 0.51
Average tokens/sec (next tokens): 19.48
mac, dtype fp32:
Average tokens/sec (total): 4.72
Average tokens/sec (first token): 0.66
Average tokens/sec (next tokens): 4.88
In contrast, on an x86 machine without hardware bf16 support, bf16 does not improve performance:
x86, dtype bf16:
Average tokens/sec (total): 14.44
Average tokens/sec (first token): 7.66
Average tokens/sec (next tokens): 14.50
x86, dtype fp32:
Average tokens/sec (total): 14.00
Average tokens/sec (first token): 7.71
Average tokens/sec (next tokens): 14.10
This matches what you would expect given the documented behavior of MKL's `cblas_gemm_bf16bf16f32`, which is to just upconvert to fp32 and call SGEMM.
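For illustration, that documented fallback amounts to something like the sketch below (schematic only, not MKL's actual implementation; the helper and wrapper names are made up for this example). Because both bf16 operands are expanded into fp32 temporaries before the GEMM runs, the bandwidth advantage of keeping the weights in bf16 is lost:

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

#include <mkl.h>  // CBLAS interface providing cblas_sgemm

// Widen a bf16 buffer (stored as raw uint16_t bit patterns) into fp32.
// bf16 is just the top 16 bits of an fp32 value, so widening is a shift.
static std::vector<float> bf16_to_fp32(const uint16_t* src, int64_t n) {
  std::vector<float> dst(n);
  for (int64_t i = 0; i < n; ++i) {
    uint32_t bits = static_cast<uint32_t>(src[i]) << 16;
    std::memcpy(&dst[i], &bits, sizeof(float));
  }
  return dst;
}

// Schematic equivalent of the documented software fallback: upconvert both
// bf16 operands to fp32 temporaries, then hand the work to SGEMM.
void gemm_bf16bf16f32_fallback(int M, int N, int K,
                               const uint16_t* A,  // M x K, row-major
                               const uint16_t* B,  // K x N, row-major
                               float* C) {         // M x N, row-major
  std::vector<float> A32 = bf16_to_fp32(A, static_cast<int64_t>(M) * K);
  std::vector<float> B32 = bf16_to_fp32(B, static_cast<int64_t>(K) * N);
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K,
              /*alpha=*/1.0f, A32.data(), /*lda=*/K, B32.data(), /*ldb=*/N,
              /*beta=*/0.0f, C, /*ldc=*/N);
}
```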
Alternatives
N/A
Additional context
llama.cpp has a native x86 bfloat16 dot-product kernel supporting AVX-512BF16, AVX-512F, AVX2, and scalar: https://github.com/ggerganov/llama.cpp/blob/master/ggml/src/ggml.c#L2149
RFC (Optional)
We should be able to do better by taking the same approach that PyTorch's ARM bfloat16 kernel (and llama.cpp's x86 bfloat16 kernel) takes in the absence of hardware support: convert each loaded vector of bfloat16s to two vectors of float32s and do the arithmetic at float32 precision (see the sketch below). Since decoding is memory-bandwidth-bound, we should get a ~2x performance win for decoding.
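A minimal sketch of that widening trick for a single dot product, using AVX2 intrinsics (illustrative only, not PyTorch's or llama.cpp's actual kernel; it assumes FMA is available, that `n` is a multiple of 16, and that the bf16 inputs are passed as raw `uint16_t` bit patterns):

```cpp
#include <immintrin.h>
#include <cstdint>
// compile with e.g. -O2 -mavx2 -mfma

// Widen 16 packed bf16 values into two fp32 vectors. bf16 -> fp32 is just
// "place the 16 bits in the top half of a 32-bit word", i.e. a left shift.
static inline void bf16x16_to_fp32x2(const uint16_t* p, __m256& lo, __m256& hi) {
  __m256i raw = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(p));
  __m256i lo32 = _mm256_cvtepu16_epi32(_mm256_castsi256_si128(raw));
  __m256i hi32 = _mm256_cvtepu16_epi32(_mm256_extracti128_si256(raw, 1));
  lo = _mm256_castsi256_ps(_mm256_slli_epi32(lo32, 16));
  hi = _mm256_castsi256_ps(_mm256_slli_epi32(hi32, 16));
}

// Horizontal sum of the 8 lanes of an __m256.
static inline float hsum256(__m256 v) {
  __m128 lo = _mm256_castps256_ps128(v);
  __m128 hi = _mm256_extractf128_ps(v, 1);
  lo = _mm_add_ps(lo, hi);                  // 4 partial sums
  __m128 shuf = _mm_movehdup_ps(lo);        // [1,1,3,3]
  __m128 sums = _mm_add_ps(lo, shuf);       // [0+1, _, 2+3, _]
  shuf = _mm_movehl_ps(shuf, sums);         // bring 2+3 down to lane 0
  sums = _mm_add_ss(sums, shuf);
  return _mm_cvtss_f32(sums);
}

// Dot product of two bf16 vectors: load bf16, widen to fp32 on the fly, and
// accumulate at fp32 precision. Only the bf16 bytes cross the memory bus.
float bf16_dot_avx2(const uint16_t* a, const uint16_t* b, int64_t n) {
  __m256 acc0 = _mm256_setzero_ps();
  __m256 acc1 = _mm256_setzero_ps();
  for (int64_t i = 0; i < n; i += 16) {
    __m256 a_lo, a_hi, b_lo, b_hi;
    bf16x16_to_fp32x2(a + i, a_lo, a_hi);
    bf16x16_to_fp32x2(b + i, b_lo, b_hi);
    acc0 = _mm256_fmadd_ps(a_lo, b_lo, acc0);
    acc1 = _mm256_fmadd_ps(a_hi, b_hi, acc1);
  }
  return hsum256(_mm256_add_ps(acc0, acc1));
}
```

On CPUs that do have AVX-512BF16, the widen-and-FMA pair could be replaced by the native bf16 dot-product instruction (the `_mm512_dpbf16_ps` intrinsic), which is presumably what llama.cpp's AVX-512BF16 path uses; the widening sketch above is only needed as the fallback for AVX2/AVX-512F machines.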