composer
Enable Elastic Sharded Checkpointing
What does this PR do?
Before submitting
Save Local
- [x] implemented
- [x] approved by Daniel or Mihir
Save Remote
- [x] implemented
- [x] approved by Daniel or Mihir
Load Local (remember to fix torchmetrics + torch 2.0 sharded)
- [x] implemented
- [x] approved by Daniel or Mihir
Load Remote
- [x] implemented
- [x] approved by Daniel or Mihir
Double Check Torch 1.13 Works
- [x] done
Add Elasticity Test
- [x] implemented
- [ ] approved by Daniel or Mihir or Nikhil
Save/Load local symlinks
- [x] implemented
- [ ] approved by Daniel or Mihir
Single Node Manual Tests
- [x] done
Multinode Manual Tests
- [ ] done
Save/Load Remote Symlinks
- [x] implemented
- [ ] approved by Daniel or Mihir
Clean Up checkpoints
- [x] implemented
- [ ] approved by Daniel or Mihir
Support for legacy sharded files
- [x] implemented
- [ ] approved by Daniel or Mihir
What issue(s) does this change relate to?
fix CO-1953
Save Local
Finished.
Ok, I changed it so only optimizer gets bumped to the top. Can either @dakinggg or @mvpatel2000 click the approved by Daniel or Mihir box?
ok, @mvpatel2000, @dakinggg, I'm ready for the next partial PR review. This is for save remote. Most of the changes are here
Thanks y'all 🥰
Save Remote
done. thanks, @dakinggg and @mvpatel2000
Alright @dakinggg, @mvpatel2000, I got local load working. Give it a look when you can! it's here
Load Local
Done. Thanks @mvpatel2000 and @dakinggg!
Ok, @dakinggg, @mvpatel2000, load from remote is live. PTAL. it's here
Load Remote
Ok, this part is done! Thanks, @dakinggg and Mihir!
Add Elasticity Test
Done
Local Symlink Support
Done
Remote Symlink/Autoresume Support
done
Legacy Sharded Support
Done
Clean up local checkpoints
done
EDIT: sorry, just saw the manual tests in the PR description. Do we have any tests around backwards compat, or around the errors raised when it won't work?
What do you mean? Like loading old checkpoints? there is a load_old_checkpoint test
Please add a full PR description.
Done
I mostly meant, for that one thing where you forgot the parentheses on the function calls, did we have a test that would've caught that issue?
Confirmed. Got this error when running multinode load on torch 2.0 once I added the parentheses:
│ load_sharded_checkpoint │
│ │
│ 314 │ │
│ 315 │ using_multinode = dist.get_world_size() != dist.get_local_world_si │
│ 316 │ if not using_torch_2_0_1() and using_multinode: │
│ ❱ 317 │ │ raise ValueError( │
│ 318 │ │ │ f'Sharded checkpoint loading on >1 node requires torch ver │
│ 319 │ │ ) │
│ 320 │
╰──────────────────────────────────────────────────────────────────────────────╯
ValueError: Sharded checkpoint loading on >1 node requires torch version >= 2.0.1. You have torch version 2.0.0+cu117
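For anyone reading along, here is a minimal sketch of why the missing parentheses silently disabled this check, and of the kind of unit test that would have caught it. It is not the actual composer code: `using_torch_2_0_1` below is a hypothetical stand-in that pretends torch 2.0.0 is installed, and `check_buggy`, `check_fixed`, and `raises_value_error` are illustrative names only. The point is that a bare function object is always truthy, so `not using_torch_2_0_1` is always False and the guard never fires.

```python
def using_torch_2_0_1() -> bool:
    """Hypothetical stand-in: pretend the installed torch is 2.0.0."""
    return False


def check_buggy(world_size: int, local_world_size: int) -> None:
    using_multinode = world_size != local_world_size
    # Bug: the function is referenced, not called. A function object is
    # always truthy, so this condition is always False and nothing is raised.
    if not using_torch_2_0_1 and using_multinode:
        raise ValueError('Sharded checkpoint loading on >1 node requires torch version >= 2.0.1.')


def check_fixed(world_size: int, local_world_size: int) -> None:
    using_multinode = world_size != local_world_size
    # Fixed: the function is actually called, so the version is checked.
    if not using_torch_2_0_1() and using_multinode:
        raise ValueError('Sharded checkpoint loading on >1 node requires torch version >= 2.0.1.')


def raises_value_error(check, world_size: int, local_world_size: int) -> bool:
    """Return True if `check` raises ValueError for the given sizes."""
    try:
        check(world_size, local_world_size)
    except ValueError:
        return True
    return False


if __name__ == '__main__':
    # Simulated 2-node run: 16 ranks in total, 8 ranks per node.
    print('buggy raises:', raises_value_error(check_buggy, 16, 8))
    print('fixed raises:', raises_value_error(check_fixed, 16, 8))
```

A test along these lines needs no multinode hardware, since the world sizes are plain arguments here; in the real code path the version helper and the `dist` calls would presumably have to be patched instead.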