[Bounty] PyTorch & HuggingFace Interface
Hello all,
I’ve made some updates to the exo library based on the bounty mentioned in this tweet/X post. These changes aim to integrate PyTorch and expand access to various language models through Hugging Face’s AutoModelForCausalLM.
What's New?
- ShardedHuggingFaceModel: Adds sharding support for Hugging Face models.
- PyTorchDynamicShardInferenceEngine: A new inference engine that uses PyTorch tensors for dynamic sharding.
These updates enable the exo library to use PyTorch, allowing access to a broader range of language models.
Limitations and Bugs
Right now the ShardedHuggingFaceModel is focused on using LlamaForCausalLM from the Hugging Face transformers library. From that model we break it up using LlamaModel and the layers it contains. We can then select a range of layers and run the PyTorch tensors through them as needed. I focused on llama3.1 8B, as that was the largest model I could even partially run.
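To make the layer-slicing idea concrete, here is a minimal sketch of the approach (class and parameter names are illustrative, not exo's actual API): a shard owns a contiguous slice of the model's decoder layers and runs the hidden state only through those.

```python
# Hedged sketch of per-shard layer slicing (names are illustrative, not exo's API).
class ShardedModelSketch:
    def __init__(self, layers, start_layer, end_layer):
        # keep only the layers this node is responsible for (inclusive range)
        self.layers = layers[start_layer:end_layer + 1]

    def forward(self, hidden_state):
        # run the hidden state through this shard's slice of layers
        for layer in self.layers:
            hidden_state = layer(hidden_state)
        return hidden_state

# stand-in "layers": each just adds its index to the value
layers = [lambda x, i=i: x + i for i in range(4)]  # layers 0..3
shard = ShardedModelSketch(layers, start_layer=1, end_layer=2)
print(shard.forward(0))  # layers 1 and 2 applied: 0 + 1 + 2 = 3
```

In the real model the "layers" would be the `LlamaModel` decoder blocks and the value passed between shards would be a PyTorch hidden-state tensor.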
Due to my current hardware limitations (specifically GPU and VRAM), I wasn't able to fully test this across multiple nodes. The model currently takes about 30 seconds per token for me (I have slow GPUs), which might be related to the absence of caching (not implemented due to VRAM constraints). It also runs without ever reaching an EOT, and the outputs seem random.
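The cost of the missing cache can be illustrated with simple arithmetic (this is not exo code): without a KV cache, generation step t re-runs attention over all t prefix tokens, so total work grows quadratically with output length; with a cache, each step only computes the new token's keys and values.

```python
# Illustrative arithmetic only: per-step attention work with and without a KV cache.
def work_without_cache(new_tokens):
    # step t reprocesses the whole t-token prefix
    return sum(range(1, new_tokens + 1))

def work_with_cache(new_tokens):
    # each step only handles the one new token
    return new_tokens

print(work_without_cache(100))  # 5050 token-steps of recomputation
print(work_with_cache(100))     # 100
```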
Request for Feedback
I’m sharing this in the hope that others can test it on more capable setups and provide feedback on how to enhance performance and stability.
Important Note on Meta LLaMA 3.1 Model
If you plan to test with the official Meta LLaMA 3.1 model, please note:
- Access: You'll need to request access and authenticate using `huggingface-cli` to download it.
- Command: Run the following command before using the model:

  ```shell
  huggingface-cli login
  ```

I'm exploring ways to simplify this process, but for now, it's necessary.
Chat API Update
- Added an option to select the LLaMA 3.1 model in the chat API.
Looking forward to any feedback or suggestions you might have.
Thank you
Hey, sorry for the delay. I haven't had a chance to check this properly yet. I'll be able to look next week.
> Hey, sorry for the delay. I haven't had a chance to check this properly yet. I'll be able to look next week.
Sounds good. Let me know anything needed. Thank you
Great work. You clearly thought about this and implemented a really nice solution. I particularly like the generalisation of model splitting, rather than doing each one separately.
Take a look through the comments I left.
The main thing I want to address and test is device support. We can make this the default inference engine if it works reliably across many devices.
On that point, if we can automate the bootstrapping of the environment for each user (e.g. install drivers, whatever else is needed to run on their device) that would be great. We don't have to do this in this PR/bounty, we can do another. But I would love to discuss and figure out how this can best be done.
I have updated my main fork branch with the pytorch interface changes. Please take a look and test. Thank you!
Hey @risingsunomi I'm thinking of making this the default inference engine on linux machines. Could you resolve conflicts please?
> Hey @risingsunomi I'm thinking of making this the default inference engine on linux machines. Could you resolve conflicts please?
Will do and clean up more.
@AlexCheema clean up finished and no conflicts with base branch
`torch` not added as a dependency
```
error loading and splitting model: Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install accelerate`
Error processing prompt: Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install accelerate`
Traceback (most recent call last):
  File "/Users/alex/exo/exo/main.py", line 161, in run_model_cli
    await node.process_prompt(shard, prompt, None, request_id=request_id)
  File "/Users/alex/exo/exo/orchestration/standard_node.py", line 98, in process_prompt
    resp = await self._process_prompt(base_shard, prompt, image_str, request_id, inference_state)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/alex/exo/exo/orchestration/standard_node.py", line 134, in _process_prompt
    result, inference_state, is_finished = await self.inference_engine.infer_prompt(request_id, shard, prompt, image_str, inference_state=inference_state)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/alex/exo/exo/inference/pytorch/inference.py", line 101, in infer_prompt
    await self.ensure_shard(shard)
  File "/Users/alex/exo/exo/inference/pytorch/inference.py", line 271, in ensure_shard
    self.stateful_sharded_model = ShardedHuggingFaceModel(
                                  ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/alex/exo/exo/inference/pytorch/model/hf.py", line 58, in __init__
    self.llm_model = AutoModelForCausalLM.from_pretrained(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/alex/exo/.venv/lib/python3.12/site-packages/transformers/models/auto/auto_factory.py", line 564, in from_pretrained
    return model_class.from_pretrained(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/alex/exo/.venv/lib/python3.12/site-packages/transformers/modeling_utils.py", line 3274, in from_pretrained
    raise ImportError(
ImportError: Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install accelerate`
```
The `accelerate` package needs to be installed.
Let me add the dependencies to the exo setup.py `install_requires`.
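For reference, a hedged sketch of what the extra entries might look like; the exact pins here are assumptions for illustration, not the final versions chosen for exo's setup.py.

```python
# Hedged sketch: extra dependencies the PyTorch engine needs in setup.py.
# The version pins are assumptions, not the final choices.
pytorch_requires = [
    "torch==2.4.0",
    "accelerate==0.33.0",
    "transformers==4.43.3",
]

# In setup.py these would be appended to the existing list, e.g.:
#   install_requires = base_requires + pytorch_requires
print("accelerate" in [r.split("==")[0] for r in pytorch_requires])  # True
```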
seems to use some other downloader (perhaps transformers?) it should use the exo downloader for integration with exo (also these other downloads aren't necessarily async friendly but the exo one is). also we are going to remove the transformers dependency as soon as possible as it bloats the whole of exo if we want to distribute this as an installable.
> seems to use some other downloader (perhaps transformers?) it should use the exo downloader for integration with exo (also these other downloads aren't necessarily async friendly but the exo one is). also we are going to remove the transformers dependency as soon as possible as it bloats the whole of exo if we want to distribute this as an installable.
for this look at the other inference engine implementations for reference. you should first download using the exo downloader then point PyTorch to the directory on disk.
> seems to use some other downloader (perhaps transformers?) it should use the exo downloader for integration with exo. also we are going to remove the transformers dependency as soon as possible as it bloats the whole of exo if we want to distribute this as an installable.
It uses the Hugging Face transformers downloader because of the way the transformer modules are initialized; the weights are downloaded when calling `from_pretrained`.
> seems to use some other downloader (perhaps transformers?) it should use the exo downloader for integration with exo (also these other downloads aren't necessarily async friendly but the exo one is). also we are going to remove the transformers dependency as soon as possible as it bloats the whole of exo if we want to distribute this as an installable.
>
> for this look at the other inference engine implementations for reference. you should first download using the exo downloader then point PyTorch to the directory on disk.
This would require a rewrite and creating a custom module outside of transformers. Removing transformers entirely is possible, but it will require more time.
It generates! Looks like some tokenizer issue. It never stops generating.
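The never-stopping behavior is what you'd see if the generation loop never checks the sampled token against the tokenizer's stop ids. A minimal sketch of the missing check (the ids below are placeholders; the real ones should be read from the tokenizer/model config, not hard-coded):

```python
# Hedged sketch of an EOS/EOT stop check. The ids are hypothetical placeholders;
# in practice read them from the tokenizer (e.g. tokenizer.eos_token_id).
EOS_IDS = {128001, 128009}

def generate(sample_next, max_new_tokens=16, eos_ids=EOS_IDS):
    out = []
    for _ in range(max_new_tokens):
        token = sample_next()
        if token in eos_ids:
            break  # stop instead of generating until max tokens
        out.append(token)
    return out

# fake sampler that emits 1, 2, 3, then a stop id
stream = iter([1, 2, 3, 128009, 4, 5])
print(generate(lambda: next(stream)))  # [1, 2, 3]
```

If the stop check is present but decoding still never terminates, the ids being compared probably don't match what the tokenizer actually emits, which would also explain random-looking output.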
> seems to use some other downloader (perhaps transformers?) it should use the exo downloader for integration with exo (also these other downloads aren't necessarily async friendly but the exo one is). also we are going to remove the transformers dependency as soon as possible as it bloats the whole of exo if we want to distribute this as an installable.
>
> for this look at the other inference engine implementations for reference. you should first download using the exo downloader then point PyTorch to the directory on disk.
>
> This would require a rewrite and creating a custom module outside of using transformers. I would need to do more work to remove transformers from this and is possible but will require more time.
That's fine, you don't need to get rid of transformers, but it should use exo's downloader for downloading the model. Then you can load the model in torch from disk.
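The suggested flow can be sketched as below. The `download_fn`/`load_fn` callables are stand-ins, not exo's real API: the first represents exo's async downloader returning a local directory, the second represents a from-disk load (e.g. `from_pretrained` with `local_files_only=True`) so transformers never hits the network itself.

```python
import asyncio
from pathlib import Path

# Hedged sketch: download via exo's async downloader, then load from disk.
# download_fn and load_fn are hypothetical stand-ins for the real calls.
async def ensure_model_local(model_id, download_fn, load_fn):
    model_dir = await download_fn(model_id)  # exo downloader -> local path
    # hand the local directory to the framework; local_files_only prevents
    # transformers from doing its own download
    return load_fn(Path(model_dir), local_files_only=True)

# stubbed demo
async def fake_download(model_id):
    return f"/models/{model_id}"

def fake_load(path, local_files_only):
    return f"loaded {path} (local_files_only={local_files_only})"

print(asyncio.run(ensure_model_local("llama-3.1-8b", fake_download, fake_load)))
```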
> seems to use some other downloader (perhaps transformers?) it should use the exo downloader for integration with exo (also these other downloads aren't necessarily async friendly but the exo one is). also we are going to remove the transformers dependency as soon as possible as it bloats the whole of exo if we want to distribute this as an installable.
>
> for this look at the other inference engine implementations for reference. you should first download using the exo downloader then point PyTorch to the directory on disk.
>
> This would require a rewrite and creating a custom module outside of using transformers. I would need to do more work to remove transformers from this and is possible but will require more time.
>
> That's fine you don't need to get rid of transformers but it should use exo's downloader for downloading the model. Then you can load the model in torch from disk.
Ok, let me see if I can overload the weight-initialization function and have it use exo's downloader.
> It generates! Looks like some tokenizer issue. It never stops generating.
Which model is this tested with? Will test more
Another issue (can be fixed last as this is a tricky one). We need to ensure that the torch operations are not blocking operations. This means the blocking parts need to be run on another thread. You can see how it was done for the MLX inference engine here where we use a ThreadPoolExecutor: https://github.com/exo-explore/exo/blob/2b9dec20eb25f8708455e13eabc744d653b7a286/exo/inference/mlx/sharded_inference_engine.py#L28
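The linked MLX pattern can be sketched without torch: push the blocking forward pass onto a worker thread via `run_in_executor` so the asyncio event loop stays responsive. The function names below are illustrative.

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

# Minimal sketch of the MLX-engine pattern: run blocking work on a worker
# thread so awaiting callers don't stall the event loop.
executor = ThreadPoolExecutor(max_workers=1)

def blocking_forward(x):
    # stand-in for a blocking torch forward pass
    return x * 2

async def infer(x):
    loop = asyncio.get_running_loop()
    # hand the blocking call to the executor and await its result
    return await loop.run_in_executor(executor, blocking_forward, x)

print(asyncio.run(infer(21)))  # 42
```

A single-worker executor also serializes access to the model, which avoids concurrent forward passes fighting over the same weights.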
> It generates! Looks like some tokenizer issue. It never stops generating.
>
> Which model is this tested with? Will test more
llama-3.1-8b
This command:

```shell
exo --inference-engine pytorch --run-model llama-3.1-8b
```
Anything I can help with @risingsunomi? I would really like to get this merged ASAP as a lot of people are trying to run on Linux and running into some issues
> Anything I can help with @risingsunomi? I would really like to get this merged ASAP as a lot of people are trying to run on Linux and running into some issues
Might need help with testing the larger models, but last night I got some of the HF downloader implemented, and I'm fixing the EOT issue with llama today. Will get through these by the end of the weekend. Sorry for the delay, but will hit it more today.
> Anything I can help with @risingsunomi? I would really like to get this merged ASAP as a lot of people are trying to run on Linux and running into some issues
>
> Might need help with testing the larger models but last night I got some of the hf downloader implemented and fixing the eot issue with llama today. Will get through these by the end of the weekend. Sorry for the delay but will hit it more today.
that's great. no need to apologise! looking forward to the fixes and will help test larger models on my machines
@risingsunomi have you pushed your latest changes? I still get the issue with torch not being in the setup.py script
> @risingsunomi have you pushed your latest changes? I still get the issue with torch not being in the setup.py script
Working on it now
> @risingsunomi have you pushed your latest changes? I still get the issue with torch not being in the setup.py script
I have added the deps to setup.py and tested that the pytorch inference engine now uses the exo downloader, but I need more tests to figure out the llama EOT issue. Moving it to my server to test further.
Error on `pip install`:

```
ERROR: Could not find a version that satisfies the requirement torch==2.4.0+cu124 (from exo) (from versions: 2.2.0, 2.2.1, 2.2.2, 2.3.0, 2.3.1, 2.4.0, 2.4.1)
ERROR: No matching distribution found for torch==2.4.0+cu124
```
> Error on `pip install`:
> ERROR: Could not find a version that satisfies the requirement torch==2.4.0+cu124 (from exo) (from versions: 2.2.0, 2.2.1, 2.2.2, 2.3.0, 2.3.1, 2.4.0, 2.4.1) ERROR: No matching distribution found for torch==2.4.0+cu124
Fixed it, but we would need to add more to setup.py that detects the CUDA version; working on that too.
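One way this detection could look, as a hedged sketch (the wheel tags and version mapping are assumptions for illustration): map a detected CUDA version to a torch requirement string, falling back to the plain wheel when no CUDA is found.

```python
# Hedged sketch: pick a torch requirement from a detected CUDA version.
# The base version and "+cuXXX" tags are illustrative assumptions.
def torch_requirement(cuda_version=None):
    if cuda_version is None:
        return "torch==2.4.0"  # default/CPU wheel from PyPI
    tag = "cu" + cuda_version.replace(".", "")  # e.g. "12.4" -> "cu124"
    return f"torch==2.4.0+{tag}"

print(torch_requirement("12.4"))  # torch==2.4.0+cu124
print(torch_requirement())        # torch==2.4.0
```

Note the original error is also expected from PyPI alone: the `+cuXXX` builds are only hosted on PyTorch's own wheel index, so pip needs `--extra-index-url https://download.pytorch.org/whl/cu124` (or similar), which `install_requires` can't supply by itself.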
Getting this output now after:

```shell
exo --inference-engine pytorch --run-model llama-3.1-8b
```