
Update how FSDP uses state dicts for Torch 2.0

VikParuchuri opened this issue 2 years ago · 11 comments

Torch 2.0 has made some changes with FSDP and state dicts:

  • full_optim_state_dict is deprecated, and should be replaced with optim_state_dict.
  • scatter_full_optim_state_dict is deprecated, and should be replaced with optim_state_dict_to_load.
  • optim_state_dict does not support optim_input.
  • set_state_dict_type now takes an optimizer state dict, along with a model state dict.
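
As a hedged sketch of the renames above (the mapping and helper below are illustrative only, not accelerate or torch code — the real attributes live on torch.distributed.fsdp.FullyShardedDataParallel):

```python
# Illustrative mapping of the Torch 2.0 FSDP deprecations listed above:
# deprecated helper name -> its replacement.
FSDP_STATE_DICT_RENAMES = {
    "full_optim_state_dict": "optim_state_dict",
    "scatter_full_optim_state_dict": "optim_state_dict_to_load",
}


def replacement_for(name: str) -> str:
    """Return the Torch 2.0 replacement for a deprecated FSDP helper name,
    or the name itself if it was not renamed."""
    return FSDP_STATE_DICT_RENAMES.get(name, name)
```

Note that the new optim_state_dict also drops the optim_input argument, so call sites need to change their arguments as well, not just the function name.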

I have updated the FSDP plugin to mirror these changes, using set_state_dict_type to request the full state dict where needed. I tried to keep parity with the existing logic in save_model, and to apply these changes only on newer versions of torch.

This works fine in my limited testing on Torch 2.0.

VikParuchuri · Mar 19 '23

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

Note: this will not work 100% correctly until this PyTorch bugfix PR is merged: https://github.com/pytorch/pytorch/pull/97110

VikParuchuri · Mar 19 '23

cc @pacman100

sgugger · Mar 20 '23

@pacman100 Do you have any suggestions on torch versioning? There are some conditionals in the FSDP plugin that branch on it.

In this case, my patch here won't work until my PyTorch patch is also merged and released (probably in the next version?). Should I just branch on the version: if torch is at or below the current release, use the old logic, otherwise the new logic?
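
A minimal sketch of that branch, assuming a plain version-string comparison (both helper names below are hypothetical; the 2.0.1 cutoff assumes the fix ships in the next patch release, which a later comment in this thread confirms):

```python
def parse_version(v: str) -> tuple:
    """Parse a torch version string such as "2.0.1+cu118" into an int tuple,
    dropping any local build suffix ("+cu118") and non-numeric tails ("dev...")."""
    core = v.split("+")[0]
    parts = []
    for component in core.split(".")[:3]:
        digits = ""
        for ch in component:
            if ch.isdigit():
                digits += ch
            else:
                break  # stop at the first non-digit ("1rc1" -> 1)
        parts.append(int(digits or 0))
    return tuple(parts)


def use_new_fsdp_state_dict_api(torch_version: str, minimum: str = "2.0.1") -> bool:
    """True when the installed torch is new enough for the optim_state_dict APIs."""
    return parse_version(torch_version) >= parse_version(minimum)
```

In real code the first argument would come from torch.__version__; accelerate also ships version-comparison utilities that could serve the same purpose.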

VikParuchuri · Mar 20 '23

@pacman100 The PyTorch PR has now been merged. Let me know how you want me to handle versioning, and I can finalize this.

VikParuchuri · Mar 29 '23

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] · Apr 23 '23

@pacman100 friendly ping

sgugger · Apr 24 '23

Hello @VikParuchuri, I don't see a clean way of implementing the version check for this. I think the best way forward is to wait for the PyTorch 2.1 release before merging this PR, since the deprecated features will remain available until version 2.2.

That way nothing breaks, and these changes land against a stable PyTorch release.

Let me know if you have any other thoughts, and apologies for the delay.

pacman100 · Apr 24 '23

That's fine by me!

VikParuchuri · May 03 '23

Hello @VikParuchuri, PyTorch 2.0.1 has been released. Please use that for the version check, resolve the quality checks, and then we are good to go. Thank you!

pacman100 · May 09 '23


Hello @VikParuchuri, thank you for all the work. As we couldn't merge this PR at the time, I folded these changes into the above PR, fixed some issues with them, and added you as a co-author.

pacman100 · Jun 13 '23

Thank you!

pacman100 · Jun 13 '23