
refactor TrainPipelineBase to clean input batch after the forward pass

Open · TroyGarden opened this issue 1 month ago · 1 comment

Summary:

Context

  • Previously, in TrainPipelineBase, the cur_batch (model input) was not released until loss.backward() was called.
  • However, the cur_batch is only needed during the forward pass.
  • This diff changes the order of operations so that the current batch is cleared right after the forward pass.

NOTE: peak memory usage typically occurs at the beginning of the backward pass, so clearing the no-longer-needed input batch before backward reduces the peak.
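The reordering can be sketched as follows. This is a minimal, hypothetical illustration of the training-step ordering (the names `progress_step`, `forward`, and `backward` are illustrative, not the actual TrainPipelineBase API); the key point is that the only reference to the input batch is dropped between the forward and backward passes, so the allocator can reclaim it before backward reaches peak usage.

```python
# Hypothetical sketch of the ordering change in a train-pipeline step.
# Event log stands in for real forward/backward work so the ordering is visible.
events = []

def forward(batch):
    events.append("forward")
    return "losses"

def backward(losses):
    events.append("backward")

def progress_step(cur_batch):
    losses = forward(cur_batch)
    # Before this diff: cur_batch stayed alive here until after backward().
    # After this diff: drop the only reference to the input batch now, so
    # its (GPU) memory can be freed before the backward-pass memory peak.
    cur_batch = None
    events.append("batch_cleared")
    backward(losses)

progress_step(object())
print(events)  # ['forward', 'batch_cleared', 'backward']
```

In the real pipeline the freed tensor memory is returned to the CUDA caching allocator, which is what lowers the peak-allocated numbers in the benchmark below.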

  • A benchmark comparison shows a saving of roughly 1–1.5× the input batch size (batch ~ 1 GB):

| name | GPU Peak Memory alloc | GPU Peak Memory reserved |
|---|---|---|
| before | 35.94 GB | 56.72 GB |
| after | 34.33 GB | 54.00 GB |
| before-inplace | 35.94 GB | 53.91 GB |
| after-inplace | 34.33 GB | 51.35 GB |

NOTE: copying the batch to GPU in place does not change the peak memory allocated, but it can reduce the peak memory reserved.
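The idea behind the in-place copy can be illustrated with a stdlib-only analogy (real code would call something like `device_tensor.copy_(host_tensor, non_blocking=True)` into a preallocated CUDA buffer; `copy_batch_inplace` and `device_buffer` below are hypothetical names). Reusing one buffer means the allocator never has to reserve a second block for the incoming batch, which lowers reserved memory even though the amount actually allocated per step is unchanged.

```python
# Hypothetical illustration: reuse one preallocated "device" buffer
# instead of allocating a fresh one per iteration.
device_buffer = bytearray(8)  # stands in for a preallocated GPU tensor

def copy_batch_inplace(host_batch: bytes) -> bytearray:
    # Overwrite the existing buffer's contents rather than allocating a
    # new buffer each step; no extra memory is ever reserved.
    device_buffer[: len(host_batch)] = host_batch
    return device_buffer

first = copy_batch_inplace(b"batch001")
second = copy_batch_inplace(b"batch002")
assert first is second  # the same buffer is reused across iterations
assert bytes(second) == b"batch002"
```

This matches the benchmark table above: the "-inplace" rows show the same peak allocation but lower peak reservation.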

Differential Revision: D85483966

— TroyGarden, Nov 10 '25 01:11

@TroyGarden has exported this pull request. If you are a Meta employee, you can view the originating Diff in D85483966.

— meta-codesync[bot], Nov 10 '25 01:11