mlx-examples
Keep `dtype` of Models
While playing around with the TinyLlama and Llama examples, I noticed that weights are always cast to `float16`, regardless of the dtype in which they are stored:
- TinyLlama weights are stored in `float32`
- Llama weights are stored in `bfloat16`
I'd suggest keeping those dtypes in `convert.py`. For `float32` this should be straightforward. Handling `bfloat16` is a bit more challenging.
- `np` does not support `bfloat16`. Since `np.savez` is used, that would need to be changed. Fortunately, `mx` has `savez` implemented and it works for `bfloat16`, so I'd suggest using that one instead.
- Converting from torch is currently done with `v = v.to(torch.float16).numpy()`. To convert to mx without precision loss, we would need to change this to `v = mx.array(v.to(torch.float32).numpy(), dtype=mx.bfloat16)` to avoid any intermediate rounding. (Maybe there is a better method that does not use np as the intermediate data format.)
So I am proposing to move from `np.savez` to `mx.savez` in all examples where applicable, and to keep the original dtype of the models unless an explicit conversion, such as quantization, is requested.
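For illustration, here is a minimal sketch of what such a convert step could look like, assuming the weights arrive as a torch state dict (`state_dict` and `out_path` are placeholder names, not the actual variables in `convert.py`):

```python
import mlx.core as mx
import torch


def convert_preserving_dtype(state_dict, out_path):
    weights = {}
    for k, v in state_dict.items():
        if v.dtype == torch.bfloat16:
            # np has no bfloat16; the float32 round trip is lossless,
            # and we cast back to bfloat16 on the mx side
            weights[k] = mx.array(v.to(torch.float32).numpy(), dtype=mx.bfloat16)
        else:
            # float32 / float16 weights go through numpy unchanged
            weights[k] = mx.array(v.numpy())
    # mx.savez can store bfloat16 arrays, which np.savez cannot
    mx.savez(out_path, **weights)
```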
A quick test shows that computation in bfloat16 does actually generate different tokens (llama-2-13b-chat):
```
In the beginning the Universe was created.
This has made a great deal of people very angry and been widely regarded as a bad move.
-- Douglas Noble, "The Book of Infinite wisdom"
< I think this is a great example of a joke that is both funny and thought-proving. It takes a familiar concept (the beginning of the Univere) and subverts it in a way that is both surprising and insightful. The puncutation and wording also add to the
---
> I think this is a great example of a joke that is both funny and thought-proving. It takes a well-known concept (the beginning of the Univere) and adds a new and absurd spin to it, creating a humorously absurd situation. At the same time, it also
```
As expected, changing precision has more impact on tokens generated later.
I am happy to create a PR if there is positive feedback on the proposal.
> So I am proposing to move from `np.savez` to `mx.savez` in all examples where applicable, and to keep the original dtype of the models unless an explicit conversion, such as quantization, is requested.
I think that is a good idea; it lets us save bfloat16, which I like.
> To convert to mx without precision loss, we would need to change this to `v = mx.array(v.to(torch.float32).numpy())` to avoid any intermediate rounding. (Maybe there is a better method that does not use np as the intermediate data format.)
My only concern here is that we have to be careful not to blow up memory for larger models. So we could do something like this?
```python
weights = {}
for k, v in state_dict.items():
    weights[k] = mx.array(v.to(torch.float32).numpy())
    del v
```
Are you interested in making this change @dastrobu? (Our llm example surface area is getting large, so we could also wait on this until we do a bit of consolidation...)
@dastrobu Not following; it looks like you are converting the bfloat16 weights to float32, and saving them as float32. At what point do you convert them back to bfloat16?
@awni is there a direct way of preserving the bfloat16 format?
FWIW, my generation quality is pretty poor compared to the reference implementation of Mistral Instruct v0.2, which is loaded and run in bfloat16. The answers from the reference implementation are perfect; I'm not seeing the same quality of output from MLX. It may be due to the type conversion, but I'm just not sure :(
> @dastrobu Not following; it looks like you are converting the bfloat16 weights to float32, and saving them as float32. At what point do you convert them back to bfloat16?
Good point, I was implicitly assuming we either quantize or convert to a lower precision dtype.
> FWIW, my generation quality is pretty poor compared to the reference implementation of Mistral Instruct v0.2, which is loaded and run in bfloat16. The answers from the reference implementation are perfect; I'm not seeing the same quality of output from MLX. It may be due to the type conversion, but I'm just not sure :(
@vgoklani I think I know the issue there. This line needs to be added to the config.json in the mlx community Hugging Face repo. Also make sure you are using the latest mlx-examples, which properly reads that.
Edit: I added the line, so if you pull the repo it should have `rope_theta` set.
thank you @awni
> Not following; it looks like you are converting the bfloat16 weights to float32, and saving them as float32. At what point do you convert them back to bfloat16?
@vgoklani you are totally right; I messed up the code snippet, my apologies. What I meant was `v = mx.array(v.to(torch.float32).numpy(), dtype=mx.bfloat16)` instead of `v = mx.array(v.to(torch.float32).numpy())`. (I have updated the issue with the correct snippet.)
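As a quick sanity check (my own example, not from the issue): every bfloat16 value is exactly representable in float32, so the round trip through float32 is lossless and casting back in mx reproduces the original values:

```python
import mlx.core as mx
import numpy as np
import torch

v = torch.randn(4, 4, dtype=torch.bfloat16)  # stand-in for a bfloat16 weight
a = mx.array(v.to(torch.float32).numpy(), dtype=mx.bfloat16)

assert a.dtype == mx.bfloat16
# compare in float32, since numpy cannot represent bfloat16 directly
assert np.array_equal(np.array(a.astype(mx.float32)), v.to(torch.float32).numpy())
```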
> Good point, I was implicitly assuming we either quantize or convert to a lower precision dtype.
@awni I think this should be decided by the user. By default, we should simply use the dtype chosen by the model creators, without any rounding, to best reproduce the results the model produces in other frameworks.
> My only concern here is that we have to be careful not to blow up memory for larger models. So we could do something like this?
@awni I share your concern, but there would only be one tensor at a time stored in float32, even with the current implementation:
```python
for k, v in state.items():
    v = mx.array(v.to(torch.float32).numpy(), mx.bfloat16)
```
Your idea to remove the converted weights from the state dict is another optimization, which sounds like a good idea (though deleting `v` won't free the memory; we can do `state[k] = None` instead).
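Putting both points together, the loop could look roughly like this (a sketch only; `state` is assumed to be the torch state dict, and the output file name is made up):

```python
import mlx.core as mx

# state: dict of torch tensors, e.g. loaded via torch.load(...)
weights = {}
for k, v in state.items():
    weights[k] = mx.array(v.to(torch.float32).numpy(), mx.bfloat16)
    # replacing the dict entry drops the reference to the torch tensor,
    # so at most one float32 copy is alive at a time
    state[k] = None
mx.savez("weights.npz", **weights)
```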
> Are you interested in making this change @dastrobu? (Our llm example surface area is getting large, so we could also wait on this until we do a bit of consolidation...)
@awni yes I am. My only concern is that I am not sure I will find the time to download all the models and run a test after changing the code. (Having a CI, e.g. GitHub workflows or CircleCI, would really be helpful for this repo.)
I'd suggest creating a draft PR for the llama example so we can discuss the changes on one specific example first, and then decide how to roll them out to all the other examples. Let me know what you think.
All great points! Thanks! I'll take a look at your PR for this