Enable torch.compile with ZeRO (Experimental)
This PR enables torch.compile with ZeRO stages 1/2/3. You need to add a compile section to your DeepSpeed config. The fields in that section are passed to torch.compile.
```json
"compile": {
  "disable": false,
  "backend": "inductor"
}
```
To enable a custom backend, you can pass the fully qualified name of the backend function. For example, if you have a backend function my_backend in my_backend.py in the current directory, you can enable it with "backend": "my_backend.my_backend". You can find an example in a unit test.
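A minimal sketch of what such a backend file could contain, plus one way a dotted name like "my_backend.my_backend" could be resolved to a function. The backend follows the standard torch.compile backend shape (it receives a torch.fx.GraphModule and example inputs, and returns a callable); returning gm.forward simply runs the captured graph eagerly. resolve_backend is a hypothetical helper, not DeepSpeed's actual code, and it assumes names without a dot (such as "inductor") are built-in backends passed to torch.compile unchanged.

```python
import importlib
from typing import Any, Callable, List, Union


def my_backend(gm: Any, example_inputs: List[Any]) -> Callable:
    """Trivial torch.compile backend: execute the captured FX graph as-is.

    torch.compile calls a backend with a torch.fx.GraphModule and example
    inputs; returning gm.forward means no further compilation happens.
    """
    return gm.forward


def resolve_backend(name: str) -> Union[str, Callable]:
    """Hypothetical helper mapping a config string to a torch.compile backend.

    - Names without a dot (e.g. "inductor") are assumed to be built-in
      backends and are returned unchanged for torch.compile to look up.
    - Dotted names (e.g. "my_backend.my_backend") are split into a module
      path and an attribute name, and the attribute is imported.
    """
    if "." not in name:
        return name
    module_name, _, attr = name.rpartition(".")
    module = importlib.import_module(module_name)
    backend = getattr(module, attr)
    if not callable(backend):
        raise ValueError(f"{name!r} does not refer to a callable backend")
    return backend
```

Accepting either a plain string or a dotted path keeps the config JSON-serializable while still allowing user-defined backends.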
So far, we have validated the results with Megatron-DeepSpeed. See the example for details.
NOTICE: This PR is a draft. We will need to validate the coverage and accuracy with many more examples.
@stas00, FYI
Amazing work, @tohtana! I'm looking forward to trying it out.
Here is some quick feedback:
Could we please flip disable to enabled so that the logic is consistent with other config values?
- no double-negation logic
- consistency with the other config sections
- the key should be enabled (and not enable), as all other config sections use that name
I tried it out, and the compiled engine doesn't seem to forward some (all?) custom methods to the unwrapped model. For example, this fails:
```
[28:7]: File "/data/env/lib/repos/retro-llama/tr043-dawn-llama-3/DeepSpeed/deepspeed/runtime/engine.py", line 468, in __getattr__
[28:7]: raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
[28:7]:AttributeError: 'DeepSpeedEngine' object has no attribute 'get_model_tflops_per_batch_per_gpu'
```
get_model_tflops_per_batch_per_gpu is just a regular method of the model, and the same setup works if I set "disable": true in the compile section.
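The traceback comes from the engine's __getattr__ not falling back to the wrapped model when compile is on. Below is a minimal pure-Python sketch of the forwarding behavior one would expect; EngineSketch is a hypothetical stand-in, not DeepSpeedEngine (which is a torch.nn.Module and has its own __getattr__ logic). Forward goes through the compiled callable, while unknown attributes are looked up on the original model.

```python
from typing import Any, Callable, Optional


class EngineSketch:
    """Hypothetical engine wrapper.

    Attribute lookups that fail on the engine fall back to the wrapped
    (uncompiled) model, so custom methods such as
    get_model_tflops_per_batch_per_gpu stay reachable whether or not a
    compiled callable is used for forward.
    """

    def __init__(self, module: Any, compiled_module: Optional[Callable] = None):
        self.module = module
        self.compiled_module = compiled_module

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Run forward through the compiled callable when one exists.
        target = self.compiled_module if self.compiled_module is not None else self.module
        return target(*args, **kwargs)

    def __getattr__(self, name: str) -> Any:
        # Only invoked when normal lookup fails. Read __dict__ directly to
        # avoid infinite recursion if "module" has not been set yet.
        module = self.__dict__.get("module")
        if module is not None and hasattr(module, name):
            return getattr(module, name)
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")


class TinyModel:
    def __call__(self, x: int) -> int:
        return x * 2

    def get_model_tflops_per_batch_per_gpu(self) -> float:
        return 196.0


engine = EngineSketch(TinyModel(), compiled_module=lambda x: x * 2)
print(engine.get_model_tflops_per_batch_per_gpu())  # → 196.0
```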
I worked around it by calling model.module.method..., but then I get many warnings and errors from the inductor backend, and eventually it fails. I have attached the log.
This is just training Llama-2 on a single node using Accelerate with torch-nightly from last night.
The llama model is the same as HF Transformers with some additional methods. https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
If I disable the DS profiler, then it runs despite the compilation errors/warnings. The log is the same as in the previous comment, except for the last traceback where it crashes.
I'm also observing a very strange performance-cycling behavior:
the TFLOPS per iteration go like this: 196, 196, 192, 196, 196, 192, 196, 196, 192. That is, two fast iterations and then a slower one, very exactly.
Without compile, it was a consistent 194.
So this suggests that something gets recompiled every 3 iterations.
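A quick check of the numbers above: the sequence repeats with period 3, and the compiled run's mean is (196 + 196 + 192) / 3 ≈ 194.7 TFLOPS, slightly above the uncompiled 194. The sketch below finds the smallest period of a per-iteration metric. It only confirms the periodicity; whether recompilation is the cause is still the hypothesis stated above. In recent PyTorch 2.x releases, setting the environment variable TORCH_LOGS="recompiles" should report each recompilation and its reason, which would confirm or rule it out.

```python
from statistics import mean
from typing import Sequence


def smallest_period(values: Sequence[float]) -> int:
    """Smallest p such that values[i] == values[i - p] for every i >= p.

    A sequence with no repetition returns len(values).
    """
    n = len(values)
    for p in range(1, n):
        if all(values[i] == values[i - p] for i in range(p, n)):
            return p
    return n


# Per-iteration TFLOPS reported with compile enabled (from the comment above).
tflops = [196, 196, 192, 196, 196, 192, 196, 196, 192]

period = smallest_period(tflops)
cycle = tflops[:period]
print(period, cycle, round(mean(tflops), 2))  # → 3 [196, 196, 192] 194.67
```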
@stas00 Thank you for your feedback! This PR is still experimental. Let me address the issues one by one.
The configuration disable is what I specifically sought feedback on. Currently, all configuration items under compile are passed to torch.compile, which accepts disable, not enable. This design was chosen for its simplicity, given the uncertainty of future changes in torch.compile. But we can define enable and flip it before passing it to torch.compile.
Do you have any further comments on this? If not, I will switch it to enable as you suggested. Actually, it is also my personal preference.
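A minimal sketch of "define enabled and flip it before passing it to torch.compile", assuming every other key in the section is forwarded unchanged. to_torch_compile_kwargs is a hypothetical helper, and the default of "off unless enabled is set" is an assumption.

```python
from typing import Any, Dict


def to_torch_compile_kwargs(section: Dict[str, Any]) -> Dict[str, Any]:
    """Map a DeepSpeed-style "compile" section to torch.compile keyword arguments.

    DeepSpeed sections conventionally use "enabled", whereas torch.compile takes
    "disable", so the flag is inverted here; every other key is passed through.
    Assumption: compilation is off unless "enabled" is set.
    """
    options = dict(section)  # copy so the caller's config is not mutated
    if "disable" in options:
        raise ValueError('use "enabled" in the compile section instead of "disable"')
    enabled = bool(options.pop("enabled", False))
    return {"disable": not enabled, **options}


print(to_torch_compile_kwargs({"enabled": True, "backend": "inductor"}))
# → {'disable': False, 'backend': 'inductor'}
```

With PyTorch installed, the result would be used as torch.compile(model, **to_torch_compile_kwargs(config["compile"])), since torch.compile accepts both backend and disable as keyword arguments.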
That's totally understandable, Masahiro. Tunji made that clear when he tagged me. If it's too early to provide feedback, please ping me when you're ready for it.
disable vs enabled:
Ideally, DeepSpeed users will never need to know anything about torch.compile specifics: many frameworks integrate this feature without having the user interact with it directly. So its API doesn't have to dictate DeepSpeed's API.
Since most (all?) DeepSpeed config sections use enabled, I'd say it would be most consistent to continue with that convention.
But this is the opinion of a single person, so please seek out the opinions of others.
@stas00 Thank you for your quick reply. It is probably difficult to reach a clear conclusion for now. I will simply switch it to enable; otherwise, many other users would have the same question as you.
For a clearer answer, we need more experience with the options DeepSpeed users need in their applications. Even the options of torch.compile may change.
1. Please note that it's `enabled` that DS uses everywhere else, not `enable`.
2. Regarding other options, I'd say use the minimal number of options:
   - Let's perhaps start with only `backend` and then pick the most sensible defaults for that option.
   - Then provide users an API where they can set their own `**torch_compile_kwargs` that will be passed to `torch.compile`. That way you're future-proofing the DeepSpeed API while allowing torch to do what they please: DeepSpeed will sync with future changes to keep up with sensible defaults, and power users will always be able to override the defaults.
```python
deepspeed_engine.set_torch_compile_kwargs(**kwargs)
```
2a. I don't know whether the current config file allows an arbitrary (not predefined) dict, so perhaps this could be possible:
```json
"compile": {
  "enabled": true,
  "backend": "inductor",
  "kwargs": {"key1": value, "key2": value}
}
```
This should definitely work:
```json
"compile": {
  "enabled": true,
  "backend": "inductor",
  "kwargs": "key1=value;key2=value"
}
```
But I don't know whether all torch.compile kwargs can be stringified.
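A minimal sketch of parsing the "key1=value;key2=value" form, which also illustrates the stringification concern: values are read with ast.literal_eval, so only Python literals (numbers, strings, booleans, None, and lists/dicts of those) come back typed, while anything else, such as a callable backend, can only survive as a plain string. parse_kwargs_string is a hypothetical helper, and it assumes no value contains a ";".

```python
import ast
from typing import Any, Dict


def parse_kwargs_string(spec: str) -> Dict[str, Any]:
    """Parse a hypothetical "key1=value;key2=value" string into a dict.

    Literal values are converted with ast.literal_eval; bare words that are
    not valid literals (e.g. max-autotune) are kept as strings.
    """
    result: Dict[str, Any] = {}
    for item in spec.split(";"):
        item = item.strip()
        if not item:
            continue
        key, sep, raw = item.partition("=")
        if not sep or not key.strip():
            raise ValueError(f"malformed entry: {item!r}")
        raw = raw.strip()
        try:
            value = ast.literal_eval(raw)
        except (ValueError, SyntaxError, TypeError):
            value = raw  # not a Python literal: keep the raw string
        result[key.strip()] = value
    return result


print(parse_kwargs_string("mode=max-autotune;fullgraph=True;dynamic=False"))
# → {'mode': 'max-autotune', 'fullgraph': True, 'dynamic': False}
```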
But providing a programmatic API for power users would be the most foolproof.
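A minimal sketch of the layered precedence suggested here: DeepSpeed defaults, then the config file's "kwargs", then programmatic overrides. CompileSettings and DEFAULT_COMPILE_KWARGS are hypothetical; the method name set_torch_compile_kwargs follows the proposal in this thread and is not necessarily the API DeepSpeed ended up with. The "inductor" default is an assumption.

```python
from typing import Any, Dict, Optional

DEFAULT_COMPILE_KWARGS: Dict[str, Any] = {"backend": "inductor"}  # assumed default


class CompileSettings:
    """Hypothetical holder for torch.compile kwargs with layered precedence:
    defaults < config-file values < programmatic overrides."""

    def __init__(self, config_section: Optional[Dict[str, Any]] = None):
        section = dict(config_section or {})
        self._from_config: Dict[str, Any] = {}
        if "backend" in section:
            self._from_config["backend"] = section["backend"]
        self._from_config.update(section.get("kwargs", {}))
        self._overrides: Dict[str, Any] = {}

    def set_torch_compile_kwargs(self, **kwargs: Any) -> None:
        # Programmatic API for power users. Unlike the JSON config, this can
        # carry values that are not serializable, such as a backend function.
        self._overrides.update(kwargs)

    def resolved(self) -> Dict[str, Any]:
        # Later layers win. "enabled" is handled separately and never forwarded.
        merged = dict(DEFAULT_COMPILE_KWARGS)
        merged.update(self._from_config)
        merged.update(self._overrides)
        return merged


settings = CompileSettings({"enabled": True, "kwargs": {"mode": "reduce-overhead"}})
settings.set_torch_compile_kwargs(fullgraph=True)
print(settings.resolved())
# → {'backend': 'inductor', 'mode': 'reduce-overhead', 'fullgraph': True}
```

Keeping the config layer JSON-only and the override layer in code means the config file format never has to anticipate every future torch.compile option.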
@stas00 (FYI @tjruwase), sorry for the delayed response. I have addressed some of the issues you raised:
- Calling custom functions of the model
- Configuration format in a DeepSpeed config (you can pass "kwargs": {"key1": value, "key2": value})
- API to set compiler options (torch_compile_kwargs)
I also added unit tests that verify these features. I am wondering whether we should merge this PR once the API design is finalized. I know that it still has a lot of limitations and currently works only for a limited set of model architectures, but we will need to work on improvements incrementally.
@tohtana, thank you for implementing my suggestions.
I haven't tested the code but looking at your tests this looks good.
I agree that doing this work incrementally is a good idea.
May I recommend adding a note in the API and config docs that this is an experimental API, subject to change at a moment's notice, so that you're not yet committing to it. Let users kick the tires and send feedback, and if it feels good, then you can commit to that API (here, the API is the new config).
Regarding future-proofing, perhaps run this proposal by Will and others on Slack and see whether it resonates with them or whether they have suggestions. They can't predict the future, but they surely already know what's on the 1-2 year plan, so that would help.
As far as the API goes, I think it might be better practice to let users apply torch.compile to the model themselves and have DeepSpeed interact with that. I suggest this for two reasons:
- Tunneling args for torch.compile through something else is a little tricky. (Another way to handle this might be to let users give you a partial where they have already applied their args; you just store their compiler_fn and call it.)
- I haven't looked at how you use compile yet, but if you're compiling specific parts of models, it might be nice to allow us to eventually compile larger sections as torch.compile improves.
Others may disagree with my suggestions - cc @ezyang @yf225 @wanchaol
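A minimal sketch of the compiler_fn idea: the user binds their own options with functools.partial, and the engine only stores the resulting callable and applies it after setting up its own hooks. To keep the sketch runnable without PyTorch, fake_compile stands in for torch.compile; it and EngineWithCompilerFn are hypothetical names, not DeepSpeed APIs.

```python
import functools
from typing import Any, Callable, Optional


def fake_compile(model: Callable, *, backend: str = "inductor", mode: Optional[str] = None) -> Callable:
    """Stand-in for torch.compile so this sketch runs without PyTorch.

    It records the options it was called with and returns a wrapper.
    """
    def compiled(*args: Any, **kwargs: Any) -> Any:
        return model(*args, **kwargs)

    compiled.compile_options = {"backend": backend, "mode": mode}  # type: ignore[attr-defined]
    return compiled


class EngineWithCompilerFn:
    """Hypothetical engine that stores a user-supplied compiler_fn and
    applies it only after installing its own per-module hooks."""

    def __init__(self, model: Callable, compiler_fn: Optional[Callable[[Callable], Callable]] = None):
        self.model = model
        self.compiler_fn = compiler_fn
        self._forward = model

    def compile(self) -> None:
        # (Hypothetical) ZeRO-3 gather/release hooks would be installed here first.
        if self.compiler_fn is not None:
            self._forward = self.compiler_fn(self.model)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self._forward(*args, **kwargs)


# The user binds their own options; the engine never needs to know them.
compiler_fn = functools.partial(fake_compile, backend="inductor", mode="max-autotune")
engine = EngineWithCompilerFn(lambda x: x * 2, compiler_fn=compiler_fn)
engine.compile()
print(engine(3))  # → 6
```

With PyTorch available, fake_compile would be replaced by torch.compile itself, e.g. functools.partial(torch.compile, backend="inductor", mode="max-autotune").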
Thank you @wconstab for your comment!
The simplest approach would be to allow users to run torch.compile and then pass the compiled model to DeepSpeed. However, this is not feasible because ZeRO3 sets hooks on each layer to gather parameters during initialization.
> Tunneling args for torch.compile through something else is a little tricky. (Another way to handle this might be to let users give you a partial where they have already applied their args; you just store their compiler_fn and call it.)
The current configuration was designed to integrate these settings into our "DeepSpeed config", which the DeepSpeed engine reads on initialization. The config can be a dictionary, but all items must be serializable, since it is often passed as a (path to a) JSON file. Therefore, I implemented the current approach as the standard procedure for compile configuration in DeepSpeed.
On the other hand, I believe we can still let users pass compiler_fn through a new API as another option, as you suggested. I already took a similar approach to pass a custom backend as a function (for consistency, I also allowed giving the package and function name as a string in the config).
> I haven't looked at how you use compile yet, but if you're compiling specific parts of models, it might be nice to allow us to eventually compile larger sections as torch.compile improves.
We currently compile the entire model provided to the DeepSpeed engine. We avoid compiling some hooks because certain communication operations are not yet compatible with torch.compile.
I expect users to apply torch.compiler.disable to a module or function within their model when needed.
> ZeRO3 sets hooks on each layer to gather parameters during initialization.
Can you say which type of hooks? We support some already, but maybe not the ones you need.
Hi @wconstab, thank you for your response!
DeepSpeed uses register_forward_pre_hook and register_forward_hook, and they actually work well with torch.compile.
However, DeepSpeed ZeRO3 recursively sets hooks on submodules in the given model. When entering a module, it gathers the parameters and also releases them when exiting the module. I think the behavior is similar to FSDP.
If users pass a compiled model, DeepSpeed cannot determine when it should gather/release the sharded parameters. The best way would be to set the same hooks based on the computation graph that torch.compile produces, but I want to keep changes minimal at this point.
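A minimal pure-Python sketch of the pattern described above: hooks are installed recursively so that each submodule "gathers" its parameters on entry and "releases" them on exit. ToyModule mimics the hook signatures of torch.nn.Module (a pre-hook receives the module and its positional inputs; a forward hook additionally receives the output), and install_zero3_style_hooks is a hypothetical simplification, not DeepSpeed's implementation. It shows why the hooks depend on module boundaries being visible at call time.

```python
from typing import Callable, List, Optional


class ToyModule:
    """Minimal stand-in for torch.nn.Module with forward pre/post hooks."""

    def __init__(self, name: str, children: Optional[List["ToyModule"]] = None):
        self.name = name
        self.children = children or []
        self._pre_hooks: List[Callable] = []
        self._post_hooks: List[Callable] = []

    def register_forward_pre_hook(self, hook: Callable) -> None:
        self._pre_hooks.append(hook)

    def register_forward_hook(self, hook: Callable) -> None:
        self._post_hooks.append(hook)

    def forward(self, x: int) -> int:
        for child in self.children:
            x = child(x)
        return x + 1

    def __call__(self, x: int) -> int:
        for hook in self._pre_hooks:
            hook(self, (x,))
        out = self.forward(x)
        for hook in self._post_hooks:
            hook(self, (x,), out)
        return out


def install_zero3_style_hooks(module: ToyModule, log: List[str]) -> None:
    """Recursively attach gather-on-enter / release-on-exit hooks."""
    module.register_forward_pre_hook(lambda m, inputs: log.append(f"gather {m.name}"))
    module.register_forward_hook(lambda m, inputs, output: log.append(f"release {m.name}"))
    for child in module.children:
        install_zero3_style_hooks(child, log)


events: List[str] = []
model = ToyModule("outer", [ToyModule("inner_a"), ToyModule("inner_b")])
install_zero3_style_hooks(model, events)
model(0)
print(events)
# → ['gather outer', 'gather inner_a', 'release inner_a', 'gather inner_b', 'release inner_b', 'release outer']
```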
OK, that makes sense. The alternative approach (maybe for later on) is to try to make the contents of the hooks 'safe to compile', so that the torch.compile'd graph already contains the release and materialize logic. But that is another level of work/design/issues.
@wconstab Thank you for your comment.
Regarding your last comment, the ZeRO3 hooks are currently not safe to compile. I have added torch.compiler.disable to some functions to work around this, but I believe we will be able to resolve it in the future. As far as I remember, I encountered errors related to communication collectives and to synchronization using CUDA events (I'm not sure about the latter). We should be able to solve the first issue by replacing the communication collectives with functional ones. We can also tweak the synchronization to make it compatible with torch.compile.
As for your first comment, I will allow users to pass compiler_fn as the most flexible option. However, I will still need to maintain the current approach that tunnels compiler options due to the reasons I mentioned above. This is more about compatibility with DeepSpeed than torch.compile.
FYI @wconstab @tjruwase
I addressed @wconstab's comment regarding compiler_fn. (CI seems to have some issues right now.)
@tohtana, thanks for this amazing PR. Can you please include some tests for CPU and NVMe offload?