torchtitan
improve reshard_after_forward logic
According to the discussion in https://github.com/pytorch/torchtitan/issues/1091.
The CI failure is because FSDPMemTracker is not compatible with fully_shard applied to a list of modules. @sanketpurandare will help address this soon. Let's land this PR once that fix is available.
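For context, a minimal sketch of the call pattern in question: FSDP2's fully_shard accepts either a single module or a list of modules, and the list form groups their parameters into one FSDP param group (one all-gather). The toy model and the specific grouping below are illustrative, not this PR's actual code.

```python
import torch.nn as nn
from torch.distributed.fsdp import fully_shard  # FSDP2 API

class ToyTransformer(nn.Module):
    def __init__(self, dim: int = 64, n_layers: int = 2, vocab: int = 128):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab, dim)
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_layers))
        self.norm = nn.LayerNorm(dim)
        self.output = nn.Linear(dim, vocab)

# Assumes torch.distributed is already initialized with a device mesh.
model = ToyTransformer()
for layer in model.layers:
    fully_shard(layer, reshard_after_forward=True)
# Passing a *list* groups several small modules into a single FSDP param
# group; this is the form FSDPMemTracker cannot handle yet.
fully_shard([model.norm, model.output], reshard_after_forward=False)
fully_shard(model)  # root wrap picks up any remaining parameters
```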
@tianyu-l I think it's also acceptable for now to let norm be assigned to the root module's param group. In other words, just wrap tok_embeddings and output each separately, as sketched below.
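A sketch of that alternative, assuming the same tok_embeddings / layers / norm / output module layout as in the toy model above (mirroring torchtitan's Llama model); the reshard_after_forward choices are illustrative:

```python
import torch.nn as nn
from torch.distributed.fsdp import fully_shard  # FSDP2 API

def apply_fsdp(model: nn.Module) -> None:
    # Wrap tok_embeddings and output each in their own param group; norm is
    # deliberately left unwrapped so its parameters land in the root group.
    fully_shard(model.tok_embeddings, reshard_after_forward=True)
    for layer in model.layers:
        fully_shard(layer, reshard_after_forward=True)
    # output runs last in forward, so resharding it after forward would only
    # add an extra all-gather at the start of backward.
    fully_shard(model.output, reshard_after_forward=False)
    # Root wrap: parameters not claimed above (here, norm) join the root
    # module's param group, so no list-of-modules call is needed.
    fully_shard(model)
```

This sidesteps the list-of-modules form that FSDPMemTracker cannot handle, at the cost of norm's weight being all-gathered together with the root group.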
Rebased to merge the PR.