torch.compile() and FSDP/DDP wrappers are called in the wrong order.
System Info
transformers main branch
Who can help?
@sgugger
Information
- [ ] The official example scripts
- [X] My own modified scripts
Tasks
- [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- [X] My own task or dataset (give details below)
Reproduction
When training/fine-tuning a model, activate torch.compile() and FSDP by passing torch_compile=True and fsdp="full_shard auto_wrap" as training arguments.
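For reference, a minimal sketch of the relevant training arguments (only torch_compile and fsdp matter for this report; the output directory is a placeholder, and a real run would be launched with a multi-GPU launcher such as torchrun):

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",              # placeholder output directory
    torch_compile=True,            # Trainer calls torch.compile() on the model
    fsdp="full_shard auto_wrap",   # Trainer wraps the model in FSDP
)
```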
The model is compiled before the FSDP wrapping, which prevents optimizations on the backward passes. According to the PyTorch docs, both the DDP and FSDP wrappers apply special optimizations when used with torch.compile() to ensure model training doesn't end up slower instead of faster (see here).
Expected behavior
The model would therefore need to be torch.compile()'d after being wrapped in either FSDP or DDP. Right now, in src/transformers/trainer.py, that is not the case: compile() is the first call in _wrap_model(). Before making a PR with the change, I figured I'd file this bug report to ensure nothing prevents that change from happening.
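In plain PyTorch terms, the intended order would look roughly like the sketch below (this is not the actual _wrap_model() change; the dummy Linear model and the single-process NCCL group are placeholders just to make the example self-contained):

```python
import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Single-process process group so FSDP can be constructed (placeholder setup).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("nccl", rank=0, world_size=1)
torch.cuda.set_device(0)

model = torch.nn.Linear(8, 8).cuda()

# Wrap first, compile second, so the compiled graph sees the FSDP (or DDP)
# wrapper and the distributed backward-pass optimizations still apply.
model = FSDP(model)
model = torch.compile(model)

dist.destroy_process_group()
```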
Note that we haven't tested torch.compile with any kind of distributed training yet, so it's normal if there are issues. If you have the fix, we'd be happy to look at a PR!
Ok! I'll make the PR then, just figured I'd ask before.