
Enable Elastic Sharded Checkpointing

Open eracah opened this issue 2 years ago • 12 comments

What does this PR do?

Before submitting

Save Local

  • [x] implemented
  • [x] approved by Daniel or Mihir

Save Remote

  • [x] implemented
  • [x] approved by Daniel or Mihir

Load Local (remember to fix torchmetrics + torch 2.0 sharded)

  • [x] implemented
  • [x] approved by Daniel or Mihir

Load Remote

  • [x] implemented
  • [x] approved by Daniel or Mihir

Double Check Torch 1.13 Works

  • [x] done

Add Elasticity Test

  • [x] implemented
  • [ ] approved by Daniel or Mihir or Nikhil

Save/Load local symlinks

  • [x] implemented
  • [ ] approved by Daniel or Mihir

Single Node Manual Tests

  • [x] done

Multinode Manual Tests

  • [ ] done

Save/Load Remote Symlinks

  • [x] implemented
  • [ ] approved by Daniel or Mihir

Clean Up checkpoints

  • [x] implemented
  • [ ] approved by Daniel or Mihir

Support for legacy sharded files

  • [x] implemented
  • [ ] approved by Daniel or Mihir

What issue(s) does this change relate to?

fix CO-1953

eracah, May 31 '23

Save Local

Finished. Ok, I changed it so only the optimizer gets bumped to the top. Can either @dakinggg or @mvpatel2000 click the "approved by Daniel or Mihir" box?

eracah, May 31 '23

Ok, @mvpatel2000, @dakinggg, I'm ready for the next partial PR review. This one is for save remote. Most of the changes are here.

Thanks y'all 🥰

eracah, Jun 01 '23

Save Remote

Done. Thanks, @dakinggg and @mvpatel2000!

eracah, Jun 06 '23

Alright @dakinggg, @mvpatel2000, I got local load working. Give it a look when you can! It's here.

eracah, Jun 07 '23

Load Local

Done. Thanks, @mvpatel2000 and @dakinggg!

eracah, Jun 07 '23

Ok, @dakinggg, @mvpatel2000, load from remote is live. PTAL. It's here.

eracah, Jun 07 '23

Load Remote

Ok, this part is done! Thanks, @dakinggg and Mihir!

eracah, Jun 09 '23

Add Elasticity Test

Done

eracah, Jun 09 '23

Local Symlink Support

Done

eracah, Jun 15 '23

Remote Symlink/Autoresume Support

Done

eracah, Jun 15 '23

Legacy Sharded Support

Done

eracah, Jun 16 '23

Clean up local checkpoints

Done

eracah, Jun 20 '23

EDIT: sorry, just saw the manual tests in the PR description. Do we have any tests around backwards compatibility, or around the errors raised when it won't work?

What do you mean? Like loading old checkpoints? There is a load_old_checkpoint test.

eracah, Jul 27 '23

Please add a full PR description,

Done

eracah, Jul 27 '23

I mostly meant: for that one thing where you forgot the parentheses on the function calls, did we have a test that would've caught that issue?

dakinggg, Jul 27 '23
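For context on the question above: a bare function object is always truthy in Python, so a guard written as `not using_torch_2_0_1` (without the call) always evaluates to `False` and never raises. The sketch below is a minimal, stdlib-only illustration under stated assumptions: `INSTALLED_TORCH`, `buggy_check`, `fixed_check`, and `raises_value_error` are hypothetical names, and `using_torch_2_0_1` here is a stand-in, not Composer's actual helper. A test that pins an old version and exercises the multinode path would have flagged the dead guard.

```python
# Hypothetical stand-ins for illustration; not Composer's actual implementation.
INSTALLED_TORCH = "2.0.0"  # pretend installed torch version


def using_torch_2_0_1() -> bool:
    """Return True if the (pretend) installed torch release is >= 2.0.1."""
    release = INSTALLED_TORCH.split("+", 1)[0]
    return tuple(int(p) for p in release.split(".")) >= (2, 0, 1)


def buggy_check(using_multinode: bool) -> None:
    # BUG: the function object itself is truthy, so `not using_torch_2_0_1`
    # is always False and this branch is dead code.
    if not using_torch_2_0_1 and using_multinode:
        raise ValueError("Sharded checkpoint loading on >1 node requires torch version >= 2.0.1.")


def fixed_check(using_multinode: bool) -> None:
    # FIX: call the function so its boolean result is what gets tested.
    if not using_torch_2_0_1() and using_multinode:
        raise ValueError("Sharded checkpoint loading on >1 node requires torch version >= 2.0.1.")


def raises_value_error(check, using_multinode: bool) -> bool:
    """Test helper: True if `check` raises ValueError for the given setting."""
    try:
        check(using_multinode)
    except ValueError:
        return True
    return False


# With an old torch version and multinode=True, only the fixed guard raises.
print(raises_value_error(buggy_check, True))   # False: the bug hides the error
print(raises_value_error(fixed_check, True))   # True
print(raises_value_error(fixed_check, False))  # False: single node is allowed
```

Linters such as pylint or mypy can also catch this class of bug, since they warn when a function is used in a boolean context without being called.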


Confirmed. Got this error when running multinode load on torch 2.0 once I added the parentheses:

│ load_sharded_checkpoint                                                      │
│                                                                              │
│   314 │                                                                      │
│   315 │   using_multinode = dist.get_world_size() != dist.get_local_world_si │
│   316 │   if not using_torch_2_0_1() and using_multinode:                    │
│ ❱ 317 │   │   raise ValueError(                                              │
│   318 │   │   │   f'Sharded checkpoint loading on >1 node requires torch ver │
│   319 │   │   )                                                              │
│   320                                                                        │
╰──────────────────────────────────────────────────────────────────────────────╯
ValueError: Sharded checkpoint loading on >1 node requires torch version >= 2.0.1. You have torch version 2.0.0+cu117

eracah, Jul 28 '23
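The error above comes from a version gate combined with a multinode check (`world_size != local_world_size`, as in the traceback). A rough stdlib-only sketch of that gate follows, assuming plain numeric release components; `parse_release`, `meets_minimum`, `is_multinode`, and `check_sharded_load` are hypothetical helpers, not Composer's API. It drops the `+cu117` local build label before comparing, which is why 2.0.0+cu117 still sorts below 2.0.1. For real version strings with pre-release tags, `packaging.version.Version` is the more robust choice.

```python
def parse_release(version: str) -> tuple[int, ...]:
    """Parse the numeric release part of e.g. '2.0.0+cu117' into (2, 0, 0).

    Hypothetical helper: drops the '+local' build label and assumes purely
    numeric components (no 'a1' / 'rc1' pre-release tags).
    """
    release = version.split("+", 1)[0]
    return tuple(int(part) for part in release.split("."))


def meets_minimum(version: str, minimum: str = "2.0.1") -> bool:
    """True if the release tuple of `version` is >= that of `minimum`."""
    return parse_release(version) >= parse_release(minimum)


def is_multinode(world_size: int, local_world_size: int) -> bool:
    # Mirrors the traceback's check: more ranks overall than on this node.
    return world_size != local_world_size


def check_sharded_load(version: str, world_size: int, local_world_size: int) -> None:
    """Raise if a multinode sharded load is attempted on too old a torch."""
    if not meets_minimum(version) and is_multinode(world_size, local_world_size):
        raise ValueError(
            "Sharded checkpoint loading on >1 node requires torch version >= 2.0.1. "
            f"You have torch version {version}"
        )


try:
    # 2 nodes x 8 ranks on torch 2.0.0+cu117: the gate fires.
    check_sharded_load("2.0.0+cu117", world_size=16, local_world_size=8)
except ValueError as err:
    print(err)

# Single node (world_size == local_world_size): allowed on any version.
check_sharded_load("2.0.0+cu117", world_size=8, local_world_size=8)
```

Comparing integer tuples rather than raw strings avoids lexicographic surprises (for example, the string "2.10" sorts before "2.9"), and a shorter tuple like (2, 0) correctly sorts below (2, 0, 1).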