Missing f8 dtypes
Hi, the unified memory of Apple silicon devices is compelling for AI training: it often gives them significantly more memory for parameters and gradients than the best consumer or even data-center-grade discrete GPUs.
However, inspecting the data types list, the smallest float dtype I see in mlx today is 16 bits (f16 or bf16).
Adding 8-bit floats to mlx would effectively double the maximum possible model size.
Would this be possible in software, or does it require a hardware update?
If it is possible with software, how could we make it happen?
Some sensible options for the exponent/mantissa (e/m) split of a default 8-bit float dtype could be:
e5m2, e4m3, or e3m4.
Then the question becomes: how would we rank these and decide which one is best for the most likely use cases?
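For intuition, here's a rough sketch of what each split buys you, assuming plain IEEE-style encodings (bias 2^(e-1) - 1, all-ones exponent reserved for inf/nan); real fp8 specs like OCP e4m3 bend those rules a bit, so treat the numbers as ballpark:

```python
# Rough comparison of the candidate 8-bit splits: 1 sign bit + e exponent bits + m mantissa bits.
# Assumes a plain IEEE-754-style layout; actual fp8 standards differ in the corner cases.

def fp8_stats(e, m):
    bias = 2 ** (e - 1) - 1
    max_exp = (2 ** e - 2) - bias            # largest non-reserved exponent
    max_normal = 2.0 ** max_exp * (2 - 2.0 ** -m)
    min_normal = 2.0 ** (1 - bias)
    eps = 2.0 ** -m                          # relative precision (mantissa step)
    return max_normal, min_normal, eps

for name, e, m in [("e5m2", 5, 2), ("e4m3", 4, 3), ("e3m4", 3, 4)]:
    vmax, vmin, eps = fp8_stats(e, m)
    print(f"{name}: max ~{vmax:g}, min normal ~{vmin:g}, relative step ~{eps:g}")
```

More exponent bits buy dynamic range (useful for gradients), more mantissa bits buy precision, which is basically the trade-off behind e5m2 vs e4m3.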
@bionicles I'm not affiliated with mlx, but I think the easiest option would be some software emulation initially, since f8 can be treated like a form of quantization.
Eventual hardware support would be super cool!
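For what it's worth, here's a minimal sketch of what that emulation could look like, assuming the ml_dtypes package is available (it provides numpy fp8 dtypes): weights stay at 8 bits in memory and get up-cast before each op, which is also where the overhead would come from.

```python
import numpy as np
import ml_dtypes  # assumption: this package is installed; it provides numpy fp8 dtypes

# Pretend these are model weights: store them at 8 bits...
w_fp32 = np.random.randn(4, 4).astype(np.float32)
w_fp8 = w_fp32.astype(ml_dtypes.float8_e4m3fn)    # lossy cast down to e4m3

# ...but compute in a wider type; this up-cast on every op is the emulation overhead.
x = np.random.randn(4).astype(np.float32)
y = w_fp8.astype(np.float32) @ x

print("storage bytes:", w_fp8.nbytes, "vs", w_fp32.nbytes)
print("max abs rounding error:", np.abs(w_fp8.astype(np.float32) - w_fp32).max())
```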
Hi @bionicles ! We do support 3, 4, 6 and 8 bit quantization so you can use large models in MLX already.
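For example, here's a quick sketch of that path, assuming the current mx.quantize / mx.dequantize / nn.quantize signatures (worth double-checking against the docs for your MLX version):

```python
import mlx.core as mx
import mlx.nn as nn

# Group-wise affine quantization of a single weight matrix.
w = mx.random.normal((512, 512))
w_q, scales, biases = mx.quantize(w, group_size=64, bits=8)
w_hat = mx.dequantize(w_q, scales, biases, group_size=64, bits=8)
print("mean abs error:", mx.mean(mx.abs(w - w_hat)))

# Or quantize every compatible layer of a model in place.
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))
nn.quantize(model, group_size=64, bits=4)
```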
Regarding supporting FP8 variants: I think it is quite unlikely that they will be added anytime soon. It is not technically very hard (we are already emulating bf16 for older machines that don't support it natively), but I am not sure that it provides any benefit for now. It will be significantly slower than fp32 and it will add several more types that increase the footprint of the library without being seriously used.
Let me know what you think.
It will be significantly slower than fp32
How do we know fp8 would be slower than fp32? It seems like there would be significantly less calculation involved if there were fewer bits to manipulate, so I don't understand how fp8 could be slower than fp32. Can you please clarify the rationale?
Additionally, the main reason to add an fp8 dtype would be to increase the maximum achievable model parameter count. Since the parameters would be smaller, couldn't we fit more of them?
I don't want to run quantized versions of models trained at higher precision, I want to train larger models at reduced precision, because then I can train them on local hardware instead of needing to depend on externals.
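For concreteness, the weights-only arithmetic looks roughly like this (96 GB is just a hypothetical budget; training also needs gradients and optimizer state, which shrink the gain, but the scaling is the same):

```python
# Hypothetical budget: how many parameters fit in 96 GB of unified memory,
# counting weights only (optimizer state and activations would reduce this a lot).
budget_bytes = 96 * 1024**3

for name, bytes_per_param in [("fp32", 4), ("fp16/bf16", 2), ("fp8", 1)]:
    params = budget_bytes / bytes_per_param
    print(f"{name}: ~{params / 1e9:.0f}B parameters")
```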
@bionicles
How do we know fp8 would be slower than fp32? It seems like there would be significantly less calculation involved if there were fewer bits to manipulate, so I don't understand how fp8 could be slower than fp32. Can you please clarify the rationale?
Arithmetic Logic Units (ALUs) don't process data one bit at a time; they handle multiple bits simultaneously. However, modern processors are optimized for 32-bit/64-bit operations. Processing 8-bit data requires additional steps, such as grouping bits into 32-bit words, which can introduce overhead and reduce performance. Therefore, while fp8 reduces data size, the lack of native support for 8-bit operations in most hardware can lead to slower processing compared to fp32.
My bad for the snarky terse comment.
Even if you're right, what if the name of the game is max parameter counts? Having imprecise parameters also makes it easier to find optimal ones because there are fewer options to try. Furthermore, as those docs indicate, can't SIMD operations pack many f8 values into a larger word, so the overheads you mention wouldn't make as big of a difference?
Really, we're often memory constrained. Even if it's slower, isn't it potentially still better to be able to train larger models slowly than to go OOM or be forced to train smaller ones and then quantize them? What about test-time training or EBMs?
Check out this segment of Jeff Dean's interview with Dwarkesh:
The first thing he really dives into is this idea of reduced precision. I just think it's worth having that option to train at reduced precision in MLX. The yield on time investment could be big, right? WDYT?
Yeah, the Apple Neural Engine actually uses 8-bit internally (I doubt it's float though). The Apple GPU in Metal apparently runs most ops as FP16 (or BF16), so that's actually peak throughput.
Really, to fit these models in we have to either quantize everything, run a massively reduced parameter count, or selectively quantize some layers lower. The more knobs we can turn, the better.
(You can look at CoreML ops; it is heavily geared towards 8-bit.)
Apologies for butting in, but I stumbled across this while attempting to convert an FP8 checkpoint to an mlx model. I could just cast the fp8 tensors to f32, but am I right in thinking that if I do that followed by an 8-bit quant I would lose a fairly significant amount of data from the original model? And if that's correct, would native support for fp8 allow me to convert to an mlx model without quantization but at a similar sort of size? I'm hitting the limits of what my M2 Pro can store in its 16GB, so if I could eke out some performance without having to upgrade then I'm all for it.
I could just cast the fp8 tensors to f32, but am I right in thinking that if I do that followed by an 8-bit quant I would lose a fairly significant amount of data from the original model?
Not necessarily. The conversion from fp8 -> fp16 is lossless since any fp8 number is representable by fp16 (or bf16 or fp32).
The quantization from fp16 -> q8 tends to be minimally lossy. In my experience q8 models score the same on evals as fp16 / bf16.
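If you want to sanity-check both claims on your own tensors, here's a small sketch, assuming ml_dtypes is installed for the fp8 side and the current mx.quantize / mx.dequantize signatures:

```python
import numpy as np
import ml_dtypes        # assumption: installed; provides numpy fp8 dtypes
import mlx.core as mx

# fp8 -> fp16 is exact: every e4m3 value is representable in fp16.
x8 = np.random.randn(1024).astype(ml_dtypes.float8_e4m3fn)
x16 = x8.astype(np.float16)
assert np.array_equal(x16.astype(np.float32), x8.astype(np.float32))

# fp16 -> q8 (group-wise affine quantization) is lossy, but only slightly.
w = mx.array(x16.astype(np.float32)).reshape(16, 64)
w_q, scales, biases = mx.quantize(w, group_size=64, bits=8)
w_hat = mx.dequantize(w_q, scales, biases, group_size=64, bits=8)
print("max abs quantization error:", mx.max(mx.abs(w - w_hat)))
```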
The quantization from fp16 -> q8 tends to be minimally lossy. In my experience q8 models score the same on evals as fp16 / bf16.
Do you think q8 activations could be interesting for performance - maybe even integration into the sdpa functions?
Do you think q8 activations could be interesting for performance - maybe even integration into the sdpa functions?
With KV cache quantization absolutely! We already have that in mlx-lm in fact. But it's not a fused quantized SDPA which could in theory be even faster. There is a stale PR in MLX for exactly that though.
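For a rough sense of why the KV cache is worth quantizing, with made-up but plausible 7B-class dimensions:

```python
# Rough KV-cache size for a hypothetical 7B-class model at long context (numbers illustrative).
layers, kv_heads, head_dim, seq_len = 32, 8, 128, 32_768
elems = 2 * layers * kv_heads * head_dim * seq_len   # K and V tensors

for name, bytes_per_elem in [("fp16", 2), ("q8", 1)]:
    gib = elems * bytes_per_elem / 1024**3           # q8 also adds small per-group scales/biases
    print(f"{name}: ~{gib:.1f} GiB")
```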
With KV cache quantization absolutely! We already have that in mlx-lm in fact. But it's not a fused quantized SDPA which could in theory be even faster.
I was wondering what you thought about keeping the activations quantized, though. For instance, a quantized rms norm, gelu, and so on. Not sure if these ops using quantized inputs would give any benefit, but perhaps for memory-boundedness in decode?
I was wondering what you thought about keeping the activations quantized, though. For instance, a quantized rms norm, gelu, and so on. Not sure if these ops using quantized inputs would give any benefit, but perhaps for memory-boundedness in decode?
I don't think it would be useful, and it could possibly make things worse. Quantization helps most with the memory-bandwidth-bound parts of the computation, so for all intents and purposes anything that looks like a matrix-vector multiplication. For other stuff it is just precision loss at best and would likely also result in a slowdown.
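To make that concrete, a bandwidth-only cost model for decode (all numbers illustrative):

```python
# Back-of-the-envelope: decode is dominated by streaming the weights once per token,
# so time per token is roughly (weight bytes) / (memory bandwidth).
params = 7e9
bandwidth = 400e9     # bytes/s, an illustrative unified-memory figure

for name, bytes_per_param in [("fp16", 2), ("q8 / fp8", 1)]:
    seconds = params * bytes_per_param / bandwidth
    print(f"{name}: ~{seconds * 1e3:.0f} ms/token -> ~{1 / seconds:.0f} tok/s")
```

Element-wise ops like rms norm or gelu touch comparatively little data per token, so lowering their precision mostly just loses accuracy.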
Isn't this useful for the CUDA backend? Apart from that, there are some providers releasing fp8-only models (Kimi K2), so having native support for fp8 would help during quantization. But there are alternative solutions.
Isn't this useful for the CUDA backend?
Yes, very much so. We will likely add fp8... just waiting for the right time. For Apple silicon it's still not that useful. For fp8 models like DSV3 or Kimi, you can already load and quantize them in their native format.
MLX will do on-the-fly conversion to bf16 followed by quantization without ever materializing the full bf16 model.
For the CUDA backend it's way more useful, but we have higher-priority items in progress there.
@awni waiting for M5 and fp8 ;)
@awni sorry to tag you again now that the M5 is official: can you confirm tensor cores and fp8?