Implement parallel model preloading
@AlexCheema
Implement Parallel Model Preloading
Description
This PR introduces parallel model preloading to significantly reduce startup times for large models distributed across multiple nodes. By leveraging asyncio, we now preload model shards into memory concurrently, followed by a sequential initialization step.
Changes
- Added `preload_model` method to the `InferenceEngine` abstract class
- Implemented `preload_model` in `MLXDynamicShardInferenceEngine`
- Updated `ensure_shard` method to work with preloaded models
- Modified `main.py` to use parallel preloading
Implementation Details
- `InferenceEngine` now has an abstract `preload_model` method
- `MLXDynamicShardInferenceEngine.preload_model` loads model config and weights without full initialization
- `ensure_shard` completes initialization using preloaded data
- Main script uses `asyncio.gather` for parallel preloading
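A simplified sketch of this flow follows. The `Shard` fields, loader helpers, and model-building step here are illustrative placeholders rather than exo's exact API:

```python
import asyncio
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass(frozen=True)
class Shard:
    model_id: str
    start_layer: int
    end_layer: int


async def load_config(model_id: str) -> dict:
    await asyncio.sleep(0.1)  # stand-in for reading the model config
    return {"model_id": model_id}


async def load_weights(model_id: str, start_layer: int, end_layer: int) -> dict:
    await asyncio.sleep(0.5)  # stand-in for loading weight files for this layer range
    return {"layers": (start_layer, end_layer)}


class InferenceEngine(ABC):
    @abstractmethod
    async def preload_model(self, shard: Shard) -> None:
        """Load config and weights into memory without full initialization."""

    @abstractmethod
    async def ensure_shard(self, shard: Shard) -> None:
        """Finish initialization, reusing preloaded data when available."""


class MLXDynamicShardInferenceEngine(InferenceEngine):
    def __init__(self) -> None:
        self.preloaded: dict[Shard, tuple[dict, dict]] = {}
        self.model = None

    async def preload_model(self, shard: Shard) -> None:
        config = await load_config(shard.model_id)
        weights = await load_weights(shard.model_id, shard.start_layer, shard.end_layer)
        self.preloaded[shard] = (config, weights)

    async def ensure_shard(self, shard: Shard) -> None:
        config, weights = self.preloaded.pop(shard, (None, None))
        if config is None:  # not preloaded, fall back to loading on demand
            config = await load_config(shard.model_id)
            weights = await load_weights(shard.model_id, shard.start_layer, shard.end_layer)
        # Sequential initialization step: build the model from the loaded pieces
        self.model = {"config": config, "weights": weights}


async def startup(engine: InferenceEngine, shards: list[Shard]) -> None:
    # Preload every shard concurrently, then finish initialization in order
    await asyncio.gather(*(engine.preload_model(s) for s in shards))
    for s in shards:
        await engine.ensure_shard(s)


if __name__ == "__main__":
    shards = [Shard("example-model", 0, 15), Shard("example-model", 16, 31)]
    asyncio.run(startup(MLXDynamicShardInferenceEngine(), shards))
```

The point of the pattern is that `asyncio.gather` overlaps the I/O-heavy loading across shards, while `ensure_shard` keeps the final initialization sequential.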
Performance Improvements
- Startup time for multi-shard models is expected to decrease significantly
- Resource utilization during startup is more efficient
How to Test
- Run the main script with a multi-shard model
- Observe logs for parallel preloading and sequential initialization
- Compare startup times with the previous sequential loading approach
Future Work
- Fine-tune the balance between parallel preloading and sequential initialization
- Implement similar optimizations for other inference engines (e.g., TinyGrad)
If you feel like supporting me:
https://buymeacoffee.com/aybanda
Hey, is this AI generated?
We don't accept AI-generated PRs.
This doesn't really achieve its intended purpose: calling `preload_model` in `main.py` doesn't really make sense since exo doesn't know up front which shards you are going to use.
Hey @AlexCheema, I got your point, and yes, I generated this using AI.
Instead of preloading in `main.py`, we could modify the `ensure_shard` method to implement a more efficient loading process. Here's an approach that might work better with your design: modifying `ensure_shard` in `MLXDynamicShardInferenceEngine`. I think this would be more suitable because it:
- Loads config and weights concurrently
- Doesn't require changes to `main.py` or other parts of exo
- Keeps the loading process within the `ensure_shard` method, maintaining your existing architecture
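A rough, self-contained sketch of what I mean; `load_config` and `load_weights` are placeholder names for whatever exo actually uses to read the config and weight files, not its real helpers:

```python
import asyncio

# Placeholder loaders: stand-ins for however the engine actually reads
# the config and weight files for a shard.
async def load_config(model_id):
    await asyncio.sleep(0.1)
    return {"model_id": model_id}

async def load_weights(model_id, start_layer, end_layer):
    await asyncio.sleep(0.5)
    return {"layers": (start_layer, end_layer)}


class MLXDynamicShardInferenceEngine:
    def __init__(self):
        self.shard = None
        self.model = None

    async def ensure_shard(self, shard):
        if self.shard == shard:
            return  # already initialized for this shard

        # Fetch config and weights concurrently instead of one after the other
        config, weights = await asyncio.gather(
            load_config(shard.model_id),
            load_weights(shard.model_id, shard.start_layer, shard.end_layer),
        )

        # Sequential part stays as-is: build the model from the loaded pieces
        self.model = {"config": config, "weights": weights}
        self.shard = shard
```

Only `ensure_shard` changes; the concurrency is limited to the config/weights fetch for the shard that is actually requested.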
If you are interested in this let me know, I will change the code accordingly.