torchrec
refactor TrainPipelineBase to clean input batch after the forward pass
Summary:
Context:
- Previously in `TrainPipelineBase`, the `cur_batch` (model input) was not released until `loss.backward()` was called.
- However, the `cur_batch` is only needed during the forward pass.
- This diff changes the order of clearing the current batch so that it is cleared right after the forward pass.
NOTE: peak memory usage usually occurs at the beginning of the backward pass, so clearing the no-longer-needed input batch before backward reduces the peak.
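The reordering above can be sketched as follows. This is a hypothetical, simplified illustration of the idea (not the actual `TrainPipelineBase` code); the function name and arguments are made up for the example:

```python
import torch

def train_step(model, optimizer, loss_fn, cur_batch, target):
    """Sketch of a train step that drops the input batch right after forward.

    Before this change, the pipeline held a reference to `cur_batch` until
    after `loss.backward()`; releasing it immediately after the forward pass
    lets the allocator reuse that memory during backward, where the memory
    peak typically occurs.
    """
    optimizer.zero_grad()
    output = model(cur_batch)
    # The input batch is only needed for the forward pass: release it now,
    # before backward, instead of after.
    del cur_batch
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
    return loss.detach()
```

Dropping the Python reference is what allows PyTorch's caching allocator to reclaim the batch's memory; the saved activations needed for backward are held by the autograd graph, not by `cur_batch` itself.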
- Benchmark comparison indicates a saving of roughly 1–1.5× the input batch size (input batch ~1 GB):
| name | GPU Peak Memory alloc | GPU Peak Memory reserved |
|---|---|---|
| before | 35.94 GB | 56.72 GB |
| after | 34.33 GB | 54.00 GB |
| before-inplace | 35.94 GB | 53.91 GB |
| after-inplace | 34.33 GB | 51.35 GB |
NOTE: copying the batch to the GPU in place does not change the peak GPU memory allocation, but it can reduce the peak memory reservation.
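A minimal sketch of what "in-place copy to GPU" means here, assuming a preallocated staging buffer (the buffer name and helper are hypothetical, not from the diff). Reusing one device buffer via `Tensor.copy_` keeps the allocation count flat, which matches the table: peak *allocated* memory is unchanged, while the caching allocator no longer needs to *reserve* extra blocks for a fresh device tensor per batch:

```python
import torch

# Fall back to CPU when no GPU is present; copy_ works the same either way.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preallocated once, reused for every iteration (shape is illustrative).
staging = torch.empty(8, 4, device=device)

def stage_batch(cpu_batch: torch.Tensor) -> torch.Tensor:
    """Copy the host batch into the reused device buffer in place."""
    # non_blocking=True can overlap the host-to-device copy with compute
    # when the source is in pinned memory; it is a harmless hint otherwise.
    staging.copy_(cpu_batch, non_blocking=True)
    return staging
```

Every call returns the same underlying storage, so no new device memory is allocated per batch.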
Differential Revision: D85483966
@TroyGarden has exported this pull request. If you are a Meta employee, you can view the originating Diff in D85483966.