                        Add `device_map` support for `hf_generate.py` and `hf_chat.py`
This PR adds `--device_map auto`, which makes it possible to use these scripts with very large models that don't fit on a single GPU.
This also removes the need for FSDP support @alextrott16!
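For context, a rough sketch of how the flag might plumb through to `from_pretrained`. The argument names other than `--device_map` are illustrative, not necessarily what the scripts actually use:

```python
import argparse

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

parser = argparse.ArgumentParser()
parser.add_argument('--name_or_path', type=str, required=True)
# 'auto' lets HF Accelerate spread the weights across all visible devices
parser.add_argument('--device_map', type=str, default=None)
args = parser.parse_args()

from_pretrained_kwargs = {'torch_dtype': torch.float16}
if args.device_map is not None:
    # Passing device_map requires `accelerate` to be installed
    from_pretrained_kwargs['device_map'] = args.device_map

tokenizer = AutoTokenizer.from_pretrained(args.name_or_path)
model = AutoModelForCausalLM.from_pretrained(args.name_or_path,
                                             **from_pretrained_kwargs)
```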
How is it parallelizing the model?
@dskhudia the HF `device_map` feature is very simple: it just stores the model weights on different devices (respecting a model attribute called `_no_split_modules` that the model author writes) and then moves activations from one GPU to the next during the forward pass. Basically pipelining with one microbatch. So there is no speedup, but you have more GPU memory to host the model weights.
HF docs: https://huggingface.co/docs/accelerate/usage_guides/big_modeling
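You can see the resulting placement yourself; `transformers` records it on the model as `hf_device_map` when a `device_map` is used (model name here is just for illustration):

```python
from transformers import AutoModelForCausalLM

# With device_map='auto', Accelerate shards the weights across available
# GPUs (and CPU/disk if needed), respecting _no_split_modules.
model = AutoModelForCausalLM.from_pretrained('gpt2', device_map='auto')

# Which device each module landed on; hooks move activations between
# devices during forward, so there is no speedup, just more total memory.
print(model.hf_device_map)
# e.g. {'transformer.wte': 0, 'transformer.h.0': 0, ..., 'lm_head': 1}
```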
I tried refactoring the defaults so that `device_map=auto` is the default (which leads to much faster CPU init than leaving it out). I think this is better UX, but it requires that we have `accelerate` installed.
To satisfy this without users running into an error and having to `pip install accelerate` every time they use our scripts, I moved `accelerate` to be an LLM Foundry dependency, which I don't like... The alternative would be to separate the `inference/` folder deps into something like `pip install .[inference]`, which would include `accelerate`, `onnx`, `onnxruntime`, etc.
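For the second option, a minimal `setup.py` sketch of what an extras group could look like. The package list just mirrors this comment; pins are omitted and not what llm-foundry would actually ship:

```python
from setuptools import setup

# Hypothetical extras group so users could run `pip install .[inference]`
extra_deps = {
    'inference': [
        'accelerate',
        'onnx',
        'onnxruntime',
    ],
}

setup(
    name='llm-foundry',
    install_requires=[
        # ... base deps; accelerate would move here under option 1
    ],
    extras_require=extra_deps,
)
```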
WDYT about the two options @alextrott16? Or something else?
WRT the question of whether to add `accelerate` to the base dependencies, I don't have a particular problem with doing so. I think it just invites confusion to have a bunch of installation tags (or whatever they're called) depending on what scripts you want to use.
Side note: we can replace our `init_empty_weights` implementation (for doing meta initialization on HF models) with the `accelerate` version if we're installing it.
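That is, using `accelerate`'s context manager directly (model name again just for illustration):

```python
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained('gpt2')

# Allocates the model on the 'meta' device: parameters have shapes and
# dtypes but no storage, so init is near-instant and uses no RAM.
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

print(next(model.parameters()).device)  # meta
```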
Just chiming in that I am already using this branch for testing the chat model!
OK, will just go with adding accelerate as a dependency for now.