Support bfloat16 for quantization convert
Currently, we cast the weights to float16 during quantization. However, since there have been significant performance improvements for bfloat16 quantization, I am wondering if we can also support bfloat16 during quantization.
https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/convert.py#L102
https://github.com/ml-explore/mlx/pull/663
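For illustration, here is a rough sketch of what an opt-in dtype for the conversion step could look like. The `convert_weights` helper and its `dtype` argument are hypothetical, not the actual mlx_lm API; the idea is just to cast floating-point weights to the requested type before quantization instead of always using float16.

```python
# Hypothetical sketch of an opt-in conversion dtype (convert_weights and the
# dtype argument are illustrative, not the actual mlx_lm API).
import mlx.core as mx

DTYPES = {"float16": mx.float16, "bfloat16": mx.bfloat16, "float32": mx.float32}
FLOAT_DTYPES = (mx.float16, mx.bfloat16, mx.float32)

def convert_weights(weights, dtype="float16"):
    """Cast floating-point weights to the requested dtype before quantizing."""
    target = DTYPES[dtype]
    return {
        k: v.astype(target) if v.dtype in FLOAT_DTYPES else v
        for k, v in weights.items()
    }
```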
I'm not opposed to supporting bfloat but I would not want to make it the default:
- float16 is still considerably faster given it has native support. The benchmarks in #663 don't tell the whole story. I see 42 TPS with fp16 vs 32 with bfloat16.
- the precision loss from bfloat16 -> float16 is nothing compared to float16 -> 4-bit quantization. So I think the impact of using bfloat is minor from a precision standpoint. Although maybe you have some case in mind that would be easier in bfloat16?
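As a rough illustration of that precision point (this comparison is mine, not from the thread), one can measure the round-trip error of a bfloat16 -> float16 cast against the error of 4-bit quantization on the same weights:

```python
# Rough comparison (illustrative only): error from a bfloat16 -> float16 cast
# versus error from 4-bit quantization of the same weight matrix.
import mlx.core as mx

w = mx.random.normal((4096, 4096)).astype(mx.bfloat16)
w32 = w.astype(mx.float32)

# bfloat16 -> float16 cast error
cast_err = mx.abs(w32 - w.astype(mx.float16).astype(mx.float32)).max()

# 4-bit quantization round-trip error (group_size=64, bits=4 are the mlx defaults)
wq, scales, biases = mx.quantize(w.astype(mx.float16), group_size=64, bits=4)
w_deq = mx.dequantize(wq, scales, biases, group_size=64, bits=4)
quant_err = mx.abs(w32 - w_deq.astype(mx.float32)).max()

print(f"cast error: {cast_err.item():.2e}, 4-bit quantization error: {quant_err.item():.2e}")
```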
I have tried building mlx from the latest master branch, but I didn't notice the significant performance improvement stated in #663. I was thinking that it might be because we quantize the model in float16. If that's not the case, I am happy to close this issue.
Not sure I follow your comment. But just in case, #663 really only sped up bfloat16 quantized models, of which there are not very many. I believe this one is one of the few examples. There should be no change for float16 and float32 quantized models.
Okay, I will try quantizing a bfloat16 model to see the performance difference. Sorry, my previous comments were saying that I tried converting models with the current convert.py but didn't see much improvement, and noticed that we are doing float16 quantization. So I raised this issue to see if we can enable bfloat16 quantization.
@awni Somehow I am getting an error for the bfloat16 quantized model (M2 Ultra) with mlx's master build:
libc++abi: terminating due to uncaught exception of type std::runtime_error: [metal::Device] Unable to load kernel contiguous_scan_inclusive_sum_bfloat16_bfloat16
Can you share the command you ran? I didn't realize there was a scan in the LLM code..
It does look like our bfloat16 scans are commented out.. not sure why that is.
python -m mlx_lm.server --model <path_to_bfloat16_quant_model>
This happened with mlx-community/Mistral-7B-Instruct-v0.2-4bit-mlx as well. I double-checked: it only happens with mlx_lm.server; mlx_lm.generate works fine.
Oh I see, that is the cumsum from the top-k sampling. That will be an issue for bfloat. One workaround, until we figure out why there is no scan for bfloat (and possibly implement it), is to cast the logits to float32 before the sampling step if they are bfloat16.
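A minimal sketch of that workaround, assuming a cumsum-based sampler along the lines of nucleus (top-p) sampling just to show where the scan appears (the `sample` function below is illustrative, not the actual mlx_lm sampling code): cast bfloat16 logits to float32 up front so the bfloat16 scan kernel is never needed.

```python
import mlx.core as mx

def sample(logits: mx.array, top_p: float = 0.9, temp: float = 1.0) -> mx.array:
    if logits.dtype == mx.bfloat16:
        # Workaround: the Metal scan kernels are not built for bfloat16,
        # so do the sampling math in float32.
        logits = logits.astype(mx.float32)
    probs = mx.softmax(logits / temp, axis=-1)
    order = mx.argsort(-probs, axis=-1)
    sorted_probs = mx.take_along_axis(probs, order, axis=-1)
    # This cumsum is the inclusive scan that raised the Metal kernel error.
    cumulative = mx.cumsum(sorted_probs, axis=-1)
    # Keep only the tokens whose cumulative probability fits within top_p.
    masked = mx.where(cumulative - sorted_probs < top_p, sorted_probs, 0.0)
    choice = mx.random.categorical(mx.log(masked), axis=-1)
    return mx.take_along_axis(order, mx.expand_dims(choice, -1), axis=-1).squeeze(-1)
```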
@angeloskath or @jagrit06 do you know why bfloat is not supported in the scans? Is that because there are no simd reductions for bfloat?
@angeloskath might have the best insight for scans in particular, but I think it came down to exclusive and inclusive reduction primitives for the scans that weren't there for bfloat - I'm not sure if we ended up making the same workarounds available for it as we did for the other simd reductions.