accelerate
Update how FSDP uses state dicts for Torch 2.0
Torch 2.0 has made some changes with FSDP and state dicts:
- `full_optim_state_dict` is deprecated, and should be replaced with `optim_state_dict`.
- `scatter_full_optim_state_dict` is deprecated, and should be replaced with `optim_state_dict_to_load`.
- `optim_state_dict` does not support `optim_input`.
- `set_state_dict_type` now takes an optimizer state dict, along with a model state dict.
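To summarize the renames, here is a small illustrative sketch (not part of accelerate or PyTorch; the mapping and helper names are hypothetical) of how the deprecated optimizer state-dict entry points line up with their Torch 2.0 replacements:

```python
# Hypothetical lookup table summarizing the FSDP renames described above.
# The real APIs live on torch.distributed.fsdp.FullyShardedDataParallel;
# this is only a name-level summary for illustration.
DEPRECATED_FSDP_APIS = {
    "full_optim_state_dict": "optim_state_dict",
    "scatter_full_optim_state_dict": "optim_state_dict_to_load",
}

def resolve_api(name: str, torch_version: tuple) -> str:
    """Return the API name to use for a given (major, minor) torch version."""
    if torch_version >= (2, 0):
        # On Torch >= 2.0, prefer the new name if the old one is deprecated.
        return DEPRECATED_FSDP_APIS.get(name, name)
    return name

print(resolve_api("full_optim_state_dict", (2, 0)))   # optim_state_dict
print(resolve_api("full_optim_state_dict", (1, 13)))  # full_optim_state_dict
```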
I have updated the FSDP plugin to mirror these changes, using `set_state_dict` to request the full state dict as needed. I tried to keep parity with the existing logic in `save_model`, and only apply these changes on new versions of torch.
This works fine in my limited testing on Torch 2.0.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Note: this will not work 100% properly until this PyTorch bugfix PR is merged: https://github.com/pytorch/pytorch/pull/97110
cc @pacman100
@pacman100 Do you have any suggestions on torch versioning? There are some conditionals in the FSDP plugin that branch on it.
In this case, my patch here won't work until my patch on PyTorch is also merged / released (probably the next version?). Should I just say if torch <= current version, then use old logic, else new logic?
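For the gating itself, something like the following pure-Python sketch could work (the helper names here are hypothetical and for illustration only; accelerate's own version-comparison utilities would presumably be used in the actual patch):

```python
def parse_version(v: str) -> tuple:
    """Parse a torch-style version string such as "2.0.1+cu118" into a tuple.

    Local build suffixes ("+cu118") and non-numeric segments ("dev...") are
    dropped, so nightlies compare by their base version. Illustrative only.
    """
    base = v.split("+")[0]
    parts = []
    for segment in base.split("."):
        digits = ""
        for ch in segment:
            if ch.isdigit():
                digits += ch
            else:
                break
        parts.append(int(digits) if digits else 0)
    return tuple(parts)

def use_new_fsdp_state_dict_api(torch_version: str) -> bool:
    # Gate on the release that includes the upstream fix; the exact threshold
    # (2.0.1 here) is an assumption pending the PyTorch release containing it.
    return parse_version(torch_version) >= (2, 0, 1)

print(use_new_fsdp_state_dict_api("2.0.1"))        # True
print(use_new_fsdp_state_dict_api("2.0.0+cu117"))  # False
```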
@pacman100 The PyTorch PR has now been merged. Let me know how you want me to handle versioning, and I can finalize this.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
@pacman100 friendly ping
Hello @VikParuchuri, I don't see a clear way of implementing the version check for this. I think the best way forward is to wait for the PyTorch 2.1 release before merging this PR, as the deprecated features will remain available until version 2.2.
That way nothing breaks, and these changes land against a stable PyTorch release.
Let me know if you have any other thoughts, and apologies for the delay.
That's fine by me!
Hello @VikParuchuri, PyTorch version 2.0.1 has been released. Please add that as the version check, resolve the quality checks, and then we are good to go. Thank you!
Hello @VikParuchuri, thank you for all the work. As we couldn't merge this PR at the time, I folded these changes into the PR above, fixed some issues with them, and added you as a co-author.
Thank you!