grok-1
Convert the JAX model to a PyTorch model usable with Hugging Face transformers
Can someone convert this JAX model to a PyTorch model implemented in transformers?
I asked Claude 3 (it's good, but your mileage may vary). How are you going to load this thing? It takes many 80 GB GPUs. https://gist.github.com/johndpope/0aa7b2709bf04c44626c019feb798cfe
Edit: There's a better implementation: https://huggingface.co/keyfan/grok-1-hf https://github.com/LagPixelLOL/grok-1-pytorch
Found this: https://huggingface.co/hpcai-tech/grok-1
"For now, a 8x80G multi-GPU machine is required)."
Is there a way to support distributed training? We do not have that many 80G cards.
Is there a way to support distributed training? We do not have that many 80G cards.
It might be possible, for example, with Hugging Face's built-in DDP tooling: https://huggingface.co/blog/pytorch-ddp-accelerate-transformers. The original model was likely run distributed at xAI.
There may be other platforms and/or libraries for distributed computing with neural networks outside of the Hugging Face option I mentioned. It's a light interest of mine, though I'm not highly experienced with running LLMs in distributed setups.
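If it helps, here is a minimal sketch of the Accelerate-based DDP pattern from that blog post. It uses a small stand-in model and dummy data just so it runs; it is not a grok-1 recipe, since grok would still need enough aggregate GPU memory across processes to hold its weights:

```python
# Minimal DDP-style training loop using Hugging Face Accelerate,
# in the spirit of the blog post linked above.
# Launch with: accelerate launch train_ddp.py
import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoModelForCausalLM

accelerator = Accelerator()

model = AutoModelForCausalLM.from_pretrained("gpt2")  # small stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# Dummy token ids just to keep the example self-contained.
dummy_ids = torch.randint(0, model.config.vocab_size, (16, 32))
loader = DataLoader(TensorDataset(dummy_ids), batch_size=2, shuffle=True)

# Accelerate wraps the model in DDP and shards the dataloader across processes.
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

model.train()
for (input_ids,) in loader:
    outputs = model(input_ids=input_ids, labels=input_ids)
    accelerator.backward(outputs.loss)
    optimizer.step()
    optimizer.zero_grad()
```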
It seems there is now a quantization of Grok-1 that takes its size down to ~120 GB for file and memory usage, instead of the 630+ GB. (That does not account for the additional memory used during forward-pass inference; it's a rough calculation of model size on disk and in memory at load time, just before inference or training, which is useful for getting started.) This moves things into the realm of high-end consumer hardware, which could allow more accessible inference with likely decent enough performance and accuracy out of the model. We've seen this done before with very large models (Falcon from TII, for example). A minimal loading sketch follows the links below.
- ~120 GB quantization: https://huggingface.co/Arki05/Grok-1-GGUF
- 630+ GB model file size based on: https://huggingface.co/hpcai-tech/grok-1/tree/main
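For what it's worth, GGUF quantizations like that are normally loaded through llama.cpp rather than transformers. Here is a rough sketch with llama-cpp-python, assuming a hypothetical local file path for the downloaded quant, and noting that grok-1 support depends on the llama.cpp version you build against:

```python
# Rough sketch: running a GGUF quantization with llama-cpp-python.
# The file path is hypothetical; whether llama.cpp supports grok-1's
# architecture depends on the version you have installed.
from llama_cpp import Llama

llm = Llama(
    model_path="./grok-1-Q3_K.gguf",  # hypothetical local path to the quantized weights
    n_gpu_layers=-1,                  # offload as many layers to GPU as will fit
    n_ctx=4096,                       # context window; adjust to available memory
)

out = llm("The answer to life, the universe, and everything is", max_tokens=32)
print(out["choices"][0]["text"])
```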
As for training grok, it's going to take additional memory beyond inference. How much, I wouldn't know; it's something someone with more experience training LLMs could likely answer.
Side note: it's possible to train or fine-tune an already pre-trained model in quantized form.
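To illustrate what that usually looks like, here is a QLoRA-style sketch with bitsandbytes + peft, shown on a small stand-in model. The target module names are architecture-specific placeholders; I have not verified any of this against the grok-1 conversions:

```python
# Sketch of fine-tuning a quantized model (QLoRA-style) with bitsandbytes + peft.
# Shown on a small stand-in model; target_modules are architecture-specific and
# would need to be checked for any grok-1 conversion.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",              # small stand-in; a grok conversion is ~630 GB
    quantization_config=bnb_config,
    device_map="auto",
)
model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],  # valid for OPT; placeholder for other models
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# Training from here proceeds with the usual Trainer / Accelerate loop.
```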
Overall, with a quality quantization, the model could likely be distributed across multiple compute nodes in quantized form for inference and/or training, at a resource requirement level that may be more attainable for you.
Exploration of what the model can run on, and how, is still ongoing. As before, someone will likely figure out how to run it decently on lower-end hardware than what it started with, which could then be stabilized for real use beyond prototype concepts for the rest of us.
@davidearlyoung Thanks for all the details. Do you know whether distributed inference works or not? We have some A100-40G cards and do not want to sacrifice quality by using quantization. We are thinking about whether it's possible to put it on two machines with both TP and PP enabled.
@Jeffwan
Do you know whether distributed inference works or not?
- I personally do not know if distributed inference works for grok-1 in pytorch.
- I do not have enough experience, or access to hardware, to verify it for myself.
- In the public AI/ML social circles that I'm part of, I have seen some claims that inference has been done with grok, but I do not know what hardware or setup they used to do so.
- For example, keyfan claims to have run benchmarks of grok using their converted pytorch form of the open weights from xai's original Jax release. (see: https://huggingface.co/keyfan/grok-1-hf)
- grok is large in its open-weight form for either Jax or pytorch (~318 GB on file for Jax, ~630 GB on file for pytorch).
- Due to its size in full open-weight form, it's likely keyfan had to do inference on a distributed setup in order to run their benchmarks. (Maybe reaching out to keyfan would be helpful.) A rough single-node loading sketch follows this list.
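If it helps, the usual first attempt on a single node with multiple GPUs is transformers' big-model loading with device_map and max_memory. This is only a sketch against the hpcai-tech conversion linked above; I have not run it myself, the exact tokenizer and trust_remote_code handling may differ per repo, and it gives layer-wise sharding on one machine, not multi-node TP/PP:

```python
# Sketch: single-node multi-GPU sharded inference via transformers/accelerate.
# Not verified against grok-1 myself; check the repo card for the exact
# tokenizer and trust_remote_code requirements. This is layer-wise sharding
# across the GPUs of one machine, not tensor or pipeline parallelism.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "hpcai-tech/grok-1"  # or another conversion linked above

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",                          # spread layers across visible GPUs
    max_memory={i: "75GiB" for i in range(8)},  # assumes 8 x 80G cards, with headroom
    trust_remote_code=True,
)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```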
Moving the conversation on. I just want to make sure that I understand you correctly:
We have some A100-40G cards ...
- My understanding is that you have access to A100-40G GPUs (likely Nvidia A100s with 40 GB of VRAM on each card).
- And it sounds like you have multiple GPUs available to you.
... and do not want to sacrifice quality by using quantization.
- I think what you are trying to say is that you do not intend to use quantized forms of grok-1 on your A100-40G cards. Let me know if I understood that correctly.
We are thinking about whether it's possible to put it on two machines with both TP and PP enabled.
- Are you trying to say that you and your team think that it is possible to load grok on two machines using Tensor Parallelism and Pipeline Parallelism?
- I'm also assuming that you will have one or more A100-40G cards on each machine. Let me know if I understood what you were saying there as well.
I have taken a few very brief looks into distributed computing for LLMs over the years. From what I understand (and I could be wrong to some degree), if you want to run the full model on GPU for inference with pytorch, in theory you will need enough VRAM across your distributed compute system to hold the model's on-disk size, plus additional VRAM for runtime overhead. How much extra is needed can vary based on a lot of factors. (I think roughly 20% might be enough for basic inference.)
Here is some quick napkin math (a small helper script reproducing these numbers follows the list):
- keyfan's or hpcai-tech's pytorch grok in BF16 at ~630 GB:
  - A100-40G: (630 + (630 * 0.2)) / 40 -> 756 / 40 = 18.9 (round up to 19)
  - A100-80G: (630 + (630 * 0.2)) / 80 = 9.45 (round up to 10)
- Arki05's grok GGUF quant (Int3 with importance matrix) at ~120 GB:
  - A100-40G: (120 + (120 * 0.2)) / 40 -> 144 / 40 = 3.6 (round up to 4)
  - A100-80G: (120 + (120 * 0.2)) / 80 = 1.8 (round up to 2)
- xai's original Jax grok checkpoint at ~318 GB (I can't speak for Jax directly since I'm not a user of it, but assuming it's similar to pytorch in memory requirements):
  - A100-40G: (318 + (318 * 0.2)) / 40 -> 381.6 / 40 = 9.54 (round up to 10)
  - A100-80G: (318 + (318 * 0.2)) / 80 = 4.77 (round up to 5)
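Here is the same estimate as a tiny Python helper, so the numbers above can be re-run with a different overhead assumption (the 20% overhead is my rough guess, not a measured figure):

```python
import math

def min_gpus(model_gb: float, vram_per_gpu_gb: float, overhead: float = 0.2) -> int:
    """Rough minimum GPU count: model size plus an assumed overhead fraction."""
    return math.ceil(model_gb * (1 + overhead) / vram_per_gpu_gb)

# Reproducing the numbers above:
print(min_gpus(630, 40))  # 19  (~630 GB pytorch conversion, A100-40G)
print(min_gpus(630, 80))  # 10
print(min_gpus(120, 40))  # 4   (~120 GB GGUF quant)
print(min_gpus(120, 80))  # 2
print(min_gpus(318, 40))  # 10  (~318 GB original Jax checkpoint)
print(min_gpus(318, 80))  # 5
```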
Old model memory calculator: https://huggingface.co/spaces/hf-accelerate/model-memory-usage (This is a bit aged. Still might be useful.)
Some multi-GPU training and inference info from the perspective of the transformers library:
- https://huggingface.co/docs/transformers/perf_train_gpu_many
- https://huggingface.co/docs/transformers/v4.15.0/parallelism (aged as well. Transformers v4.15.0)
A lot of what I've learned about LLM basics has been through Hugging Face, which is why most of the links I've shared are from them.
@davidearlyoung I really appreciate your informative analysis! Thanks a lot!
I personally do not know if distributed inference works for grok-1 in pytorch.
Yes! That's my question to the community as well. I see lots of people who do not have that many A100-80G cards to run this model, which is why I am curious whether multi-node inference works (it would definitely need TP, PP, or TP+PP).
I think what you are trying to say is that you do not intend to use quantized forms of grok-1 on your A100-40G cards. Let me know if I understood that correctly.
yes
xai's original Jax grok checkpoint at ~318 GB (I can't speak for Jax directly since I'm not a user of it, but assuming it's similar to pytorch in memory requirements):
A100-40G: (318 + (318 * 0.2)) / 40 -> 381.6 / 40 = 9.54 (round up to 10)
A100-80G: (318 + (318 * 0.2)) / 80 = 4.77 (round up to 5)
This is exactly the situation we are facing. If we do not use quantized forms of grok-1 (say we want the BF16 version), then there's no way to fit it into a single machine with 8 x A100-40G. Distributed inference is essentially a required technique in our environment.
I did some research and noticed that some frameworks like TensorRT-LLM and vLLM have support for TP and PP, but they seem to have some performance limitations. However, I have not tried them with grok-1 and just want to get some feedback from the community on the best practice or recommended way to run distributed inference. (The question may not even be valid, since those frameworks may not support this model.)
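For reference, this is roughly how TP and PP are requested in vLLM's offline API. The model id is a placeholder, pipeline_parallel_size only exists in newer vLLM versions (and multi-node PP needs a Ray cluster), and I have not verified that any grok-1 conversion is actually supported:

```python
# Sketch only: how tensor/pipeline parallelism are requested in vLLM's offline API.
# The model id is a placeholder; grok-1 support and pipeline parallelism depend on
# the vLLM version, and multi-node PP requires a Ray cluster.
from vllm import LLM, SamplingParams

llm = LLM(
    model="some-org/some-model",  # placeholder, not a verified grok-1 repo
    tensor_parallel_size=8,       # shard each layer across 8 GPUs within a node
    pipeline_parallel_size=2,     # split layers across 2 stages (newer vLLM only)
    dtype="bfloat16",
)

outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```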
Has anyone seen a system OOM? I have 1 TB of system memory with 8 GPUs, and while using the PyTorch model the process gets killed with an OOM kernel message. The Jax one seems to work fine with 1 TB of memory; it is using BF16.