x86 CPU: BF16 should improve decoding performance relative to FP32 on x86, even without hardware BF16

Open · swolchok opened this issue 4 months ago · 3 comments

🚀 The feature, motivation and pitch

As you might expect given that decoding is memory-bandwidth-bound, bf16 is much faster than fp32 on my M1 Mac, roughly 4x for next-token decoding in the numbers below (`python torchchat.py generate llama3.2-1b --device cpu --dtype <as specified>`):
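Concretely, the two runs were (dtype values as labeled on each result block):

```
python torchchat.py generate llama3.2-1b --device cpu --dtype bf16
python torchchat.py generate llama3.2-1b --device cpu --dtype fp32
```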

mac, dtype bf16:
      Average tokens/sec (total): 16.42
Average tokens/sec (first token): 0.51
Average tokens/sec (next tokens): 19.48

mac, dtype fp32:
      Average tokens/sec (total): 4.72
Average tokens/sec (first token): 0.66
Average tokens/sec (next tokens): 4.88

In contrast, on an x86 machine without hardware bf16 support, bf16 does not improve performance:

x86, dtype bf16:
      Average tokens/sec (total): 14.44
Average tokens/sec (first token): 7.66
Average tokens/sec (next tokens): 14.50

x86, dtype fp32:
      Average tokens/sec (total): 14.00
Average tokens/sec (first token): 7.71
Average tokens/sec (next tokens): 14.10

This matches what you would expect given the documented behavior of MKL's cblas_gemm_bf16bf16f32, which is to just upconvert to fp32 and call SGEMM.
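For reference, the scalar form of that upconversion is just a 16-bit shift, since bf16 is the truncated top half of an IEEE-754 float32 (an illustrative sketch, not MKL's actual code):

```cpp
#include <cstdint>
#include <cstring>

// bf16 keeps the sign, exponent, and top 7 mantissa bits of a float32, so
// upconversion is a shift into the high half of the 32-bit pattern.
float bf16_to_f32(uint16_t x) {
  uint32_t bits = static_cast<uint32_t>(x) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}
```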

Alternatives

N/A

Additional context

llama.cpp has a native x86 bfloat16 dot product kernel supporting AVX-512BF16, AVX-512F, AVX2, and scalar: https://github.com/ggerganov/llama.cpp/blob/master/ggml/src/ggml.c#L2149 .

RFC (Optional)

We should be able to do better by taking the same approach that PyTorch's ARM bfloat16 kernel (and llama.cpp's x86 bfloat16 kernel) takes in the absence of hardware support: convert each loaded vector of bfloat16s into two vectors of float32s and do the arithmetic at float32 precision. Since decoding is memory-bandwidth-bound and this halves the bytes read per weight, we should get a roughly 2x performance win for decoding; a sketch of the widening trick follows.
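For illustration, here is a minimal AVX2+FMA sketch of that widening approach (hypothetical function names, both operands bf16, no tail handling or unrolling; not the actual PyTorch or llama.cpp code, and a real kernel would also dispatch to AVX-512BF16 where available):

```cpp
#include <immintrin.h>
#include <cstdint>

// Widen 16 packed bfloat16 values into two float32 vectors: bf16 is the top
// 16 bits of a float32, so zero-extend each lane to 32 bits and shift it into
// the high half of the float bit pattern.
static inline void bf16x16_to_f32x2(const uint16_t* src, __m256& lo, __m256& hi) {
  __m256i raw = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src));
  __m128i raw_lo = _mm256_castsi256_si128(raw);       // first 8 bf16 values
  __m128i raw_hi = _mm256_extracti128_si256(raw, 1);  // last 8 bf16 values
  lo = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(raw_lo), 16));
  hi = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(raw_hi), 16));
}

// Dot product of two bf16 vectors with fp32 accumulation. Sketch only: n is
// assumed to be a multiple of 16; production code would handle the tail,
// unroll, and use multiple accumulators.
float bf16_dot_f32(const uint16_t* a, const uint16_t* b, int64_t n) {
  __m256 acc = _mm256_setzero_ps();
  for (int64_t i = 0; i < n; i += 16) {
    __m256 a_lo, a_hi, b_lo, b_hi;
    bf16x16_to_f32x2(a + i, a_lo, a_hi);
    bf16x16_to_f32x2(b + i, b_lo, b_hi);
    acc = _mm256_fmadd_ps(a_lo, b_lo, acc);
    acc = _mm256_fmadd_ps(a_hi, b_hi, acc);
  }
  // Horizontal sum of the 8-lane accumulator.
  __m128 sum = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));
  sum = _mm_add_ps(sum, _mm_movehl_ps(sum, sum));
  sum = _mm_add_ss(sum, _mm_shuffle_ps(sum, sum, 0x55));
  return _mm_cvtss_f32(sum);
}
```

The loads move half as many bytes as an fp32 kernel would for the same number of elements, which is where the expected decoding win comes from.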

swolchok · Oct 02 '24 00:10