Support deepspeed dynamo
What does this PR do?
This PR tries to make 🤗 accelerate/transformers respect https://github.com/microsoft/DeepSpeed/pull/4878.
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the contributor guideline, Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [ ] Did you write any new necessary tests?
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@pacman100 since it's deepspeed related, and @tohtana since you implemented the deepspeed part.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Hello, one overall comment: I get the error below when I run the following test:
pytest -sv tests/deepspeed/test_deepspeed.py -k test_basic_dynamo_run
Thanks! I've committed your dynamo fix, and I'll look at the failed test.
This is ready to be reviewed again :) @pacman100
Without using the env variable TORCHDYNAMO_DEBUG_FUNCTION=forward, I get the following error:
File "/raid/sourab/transformers/src/transformers/models/bert/modeling_bert.py", line 286, in forward
mixed_query_layer = self.query(hidden_states)result = forward_call(*args, **kwargs)
File "/raid/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
File "/raid/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
return F.linear(input, self.weight, self.bias)
File "/raid/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 655, in catch_errors
return callback(frame, cache_entry, hooks, frame_state)
File "/raid/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 727, in _convert_frame
return self._call_impl(*args, **kwargs)
File "/raid/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
result = inner_convert(frame, cache_entry, hooks, frame_state)
File "/raid/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 383, in _convert_frame_assert
compiled_product = _compile(
File "/raid/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 665, in _compile
result = forward_call(*args, **kwargs)
File "/raid/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
raise InternalTorchDynamoError(str(e)).with_traceback(
File "/raid/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 646, in _compile
return F.linear(input, self.weight, self.bias)
File "/raid/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 655, in catch_errors
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/raid/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
return callback(frame, cache_entry, hooks, frame_state)
File "/raid/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 727, in _convert_frame
r = func(*args, **kwargs)
File "/raid/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 626, in compile_inner
result = inner_convert(frame, cache_entry, hooks, frame_state)
File "/raid/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 383, in _convert_frame_assert
check_fn = CheckFunctionManager(
File "/raid/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/_dynamo/guards.py", line 1011, in __init__
compiled_product = _compile(
File "/raid/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 665, in _compile
raise InternalTorchDynamoError(str(e)).with_traceback(guard.create(builder)
File "/raid/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 646, in _compile
File "/raid/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/_guards.py", line 246, in create
return self.create_fn(builder, self)
File "/raid/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/_dynamo/guards.py", line 448, in CONSTANT_MATCH
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/raid/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
val = self.get(guard.name)
File "/raid/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/_dynamo/guards.py", line 258, in get
r = func(*args, **kwargs)
File "/raid/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 626, in compile_inner
return eval(name, self.scope, CLOSURE_VARS)
File "<string>", line 1, in <module>
check_fn = CheckFunctionManager(
File "/raid/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch/_dynamo/guards.py", line 1011, in __init__
torch._dynamo.exc.InternalTorchDynamoError: type object 'FunctionMeta' has no attribute 'forward'
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
cc @tohtana and @tjruwase in case you have an idea about this and the steps to overcome it.
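For anyone reproducing this, here is a minimal sketch of the workaround mentioned above (restricting dynamo to functions named `forward`). Setting the variable from Python before torch is imported is an assumption about when the config is read; exporting it in the shell before the pytest command is the safer route.

```python
# Sketch of the workaround referenced above (not part of this PR): restrict
# TorchDynamo to tracing only functions named "forward".
# Assumption: the variable is read when torch._dynamo loads its config,
# so it must be set before torch is imported (or exported in the shell).
import os

os.environ["TORCHDYNAMO_DEBUG_FUNCTION"] = "forward"

import torch  # imported after setting the variable on purpose
```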
Hi @oraluben, @pacman100, Thank you for your report! Sorry for my late response.
We found that the error is caused by `no_grad` during evaluation. Currently we reuse the compiled model even after the grad mode has changed. I thought dynamo automatically recompiles a model when necessary, but it seems that is not always the case.
I will try to fix this by recompiling when the grad mode changes.
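A minimal, self-contained illustration of the behavior being discussed (this is not DeepSpeed code; it only shows the grad-mode change that triggers the problem):

```python
# Sketch: the same compiled module is called with gradients enabled (training)
# and then under no_grad (evaluation). Dynamo guards on the grad mode, so the
# second call is expected to trigger a recompile rather than reuse the old graph.
import torch

model = torch.compile(torch.nn.Linear(8, 8))
x = torch.randn(2, 8)

model(x)                 # first call: compiled with grad mode enabled
with torch.no_grad():
    model(x)             # grad mode changed -> recompilation expected here
```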
Hi @oraluben, @pacman100, I found that torch does recompile the model when the grad mode changes. We actually have the following two issues:
- Custom linear module: DeepSpeed enables its custom linear module when Z3 is enabled. However, it does not work with `torch.compile`, so we disable the module when `torch.compile` is enabled. DeepSpeed disables it in `Init()` and checks whether compile is enabled. We expect `enabled` in the compile config to be a boolean, but `auto` is passed to `Init()` (`'compile': {'enabled': 'auto', 'backend': 'auto'}`), so DeepSpeed does not disable the custom function (see the sketch after this comment). It seems `auto` is set at https://github.com/huggingface/accelerate/blob/a3ce1dffcde40831a28101e32e8e98a0c0fa9d0a/src/accelerate/utils/dataclasses.py#L818 Is this the expected behavior? On the other hand, DeepSpeedEngine receives `'compile': {'enabled': True, 'backend': 'inductor'}`. Can we pass the same to `Init()`?
- Z3 hook function: Dynamo fails to compile one of the functions in the Z3 hooks. We can exclude that function from the compilation target as in https://github.com/microsoft/DeepSpeed/pull/5325. (We already exclude some other Z3 hook functions.)

With the custom linear function forcibly disabled and the Z3 hook function excluded, the example above works. Can you give us a suggestion for the first issue? If it works, we can merge https://github.com/microsoft/DeepSpeed/pull/5325 for the second one.
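To make the first issue concrete, here is an illustrative sketch (the exact DeepSpeed check may differ, and the helper name below is hypothetical) of why the string `'auto'` never triggers the disable path while a real boolean does:

```python
# Hypothetical stand-in for the check in Init(): it expects a boolean
# "enabled" flag, so the string "auto" currently forwarded by accelerate
# never matches, and the Z3 custom linear module stays enabled under torch.compile.
def should_disable_custom_linear(compile_cfg: dict) -> bool:
    return compile_cfg.get("enabled") is True

print(should_disable_custom_linear({"enabled": "auto", "backend": "auto"}))    # False (what Init() gets today)
print(should_disable_custom_linear({"enabled": True, "backend": "inductor"}))  # True  (what DeepSpeedEngine gets)
```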
> We expect `enabled` in the compile config to be a boolean, but `auto` is passed to `Init()` (`'compile': {'enabled': 'auto', 'backend': 'auto'}`).
That sounds like I'm initializing the config in the wrong place; can you give me some advice on the proper way to do it? @pacman100
Separately, I've submitted a patch to torch: https://github.com/pytorch/pytorch/pull/124273. I think it's safe to land this PR once that patch goes into torch.
@umchand, FYI
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
not stale, still working on that
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
It seems that the compile wrapper was removed in DeepSpeed 0.14.4.
ref: https://github.com/microsoft/DeepSpeed/pull/5581
Is any work still going on to revamp this PR?
@skyshine102 Thanks for the reminder! I've created #3069 for that; it's WIP for now but should be easy to test.
@oraluben Thanks for opening another PR. I thought this old PR had been merged and was searching for the line diffs that delete the ds_config-related code; it turns out it never got merged, lol. I think the new PR is correct, simply leveraging the args from TorchDynamoPlugin!