
torch.compile() and FSDP/DDP wrappers are called in the wrong order.

ani300 opened this issue on Mar 16, 2023 · 2 comments

System Info

transformers main branch

Who can help?

@sgugger

Information

  • [ ] The official example scripts
  • [X] My own modified scripts

Tasks

  • [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [X] My own task or dataset (give details below)

Reproduction

When training or fine-tuning a model, enable torch.compile() and FSDP by passing torch_compile=True and fsdp="full_shard auto_wrap" as training arguments (minimal sketch below).
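
For reference, roughly the setup that triggers this; the model and dataset are placeholders, and it assumes a multi-GPU torchrun launch:

```python
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="out",
    torch_compile=True,            # Trainer calls torch.compile() on the model
    fsdp="full_shard auto_wrap",   # Trainer wraps the model in FSDP
)

# `model` and `train_dataset` are placeholders, defined elsewhere
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
trainer.train()
```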

The model is compiled before the FSDP wrapping, which prevents optimizations on the backward pass. According to the PyTorch docs, both the DDP and FSDP wrappers have special optimizations that run with torch.compile() to ensure model training doesn't end up slower instead of faster (see here).
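In plain PyTorch, the documented pattern for DDP illustrates the ordering: wrap first, then compile the wrapped module, so the compiler can account for the gradient-bucket allreduces. A minimal sketch (tiny placeholder model, assumes a torchrun launch on GPUs):

```python
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
model = nn.Linear(16, 16).to(local_rank)  # placeholder model

# Problematic ordering (compile first, then wrap):
#   model = DDP(torch.compile(model), device_ids=[local_rank])

# Ordering recommended by the PyTorch docs (wrap first, then compile):
model = DDP(model, device_ids=[local_rank])
model = torch.compile(model)
```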

Expected behavior

The model therefore needs to be torch.compile()'d after being wrapped in either FSDP or DDP. Right now that is not the case in src/transformers/trainer.py: compile() is the first call in _wrap_model(). Before making a PR with the change, I figured I'd file this bug report to make sure nothing prevents that change from happening. A rough sketch of the proposed ordering follows.
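
This is not the actual _wrap_model() code, just a sketch of the intended control flow; the function name and flags are invented for illustration, and the real Trainer also passes auto-wrap policies, mixed precision, etc. to the wrappers:

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_then_compile(model: nn.Module, *, use_fsdp: bool = False,
                      use_ddp: bool = False, do_compile: bool = False) -> nn.Module:
    # Distributed wrapping happens first ...
    if use_fsdp:
        model = FSDP(model)  # the real Trainer also passes an auto-wrap policy, etc.
    elif use_ddp:
        model = DDP(model)
    # ... and torch.compile() runs last, on the wrapped module.
    if do_compile:
        model = torch.compile(model)
    return model
```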

ani300 · Mar 16 '23 20:03

Note that we haven't tested torch.compile with any kind of distributed training yet, so it's normal if there are issues. If you have the fix, we'd be happy to look at a PR!

sgugger · Mar 16 '23 20:03

Ok! I'll make the PR then; I just figured I'd ask first.

ani300 · Mar 16 '23 21:03