Subclass `torch.distributed.pipelining.PipelineStage` for PFO
PyTorch 2.4 introduces a new API for pipeline parallelism, `torch.distributed.pipelining`, which includes `PipelineStage`. We can subclass `PipelineStage` and override `forward_one_chunk` and `backward_one_chunk` so that each first sets the GPU's frequency through the asynchronous frequency controller and then runs the actual forward or backward computation.
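A minimal sketch of the override pattern, written as a mixin so it runs without torch or a GPU. `FakeStage` and `RecordingController` are hypothetical stand-ins for `PipelineStage` and the async frequency controller, the `configure_frequency` helper is invented for illustration, and the single frequency per direction (with made-up MHz values) is a simplifying assumption. The overrides pass arguments through untouched so they do not depend on the exact upstream signatures.

```python
from typing import Any, List, Tuple


class RecordingController:
    """Hypothetical stand-in for the asynchronous frequency controller."""

    def __init__(self) -> None:
        self.requests: List[int] = []

    def set_frequency(self, frequency_mhz: int) -> None:
        # A real controller would hand the request to a background worker
        # and return immediately; this one only records it.
        self.requests.append(frequency_mhz)


class FrequencyControlMixin:
    """Requests a GPU frequency, then defers to the base stage's chunk method."""

    def configure_frequency(self, controller: Any, forward_mhz: int, backward_mhz: int) -> None:
        # Hypothetical helper: one frequency per direction (an assumption).
        self._freq_controller = controller
        self._forward_mhz = forward_mhz
        self._backward_mhz = backward_mhz

    def forward_one_chunk(self, *args: Any, **kwargs: Any) -> Any:
        self._freq_controller.set_frequency(self._forward_mhz)
        return super().forward_one_chunk(*args, **kwargs)

    def backward_one_chunk(self, *args: Any, **kwargs: Any) -> Any:
        self._freq_controller.set_frequency(self._backward_mhz)
        return super().backward_one_chunk(*args, **kwargs)


class FakeStage:
    """Stand-in base class that only logs which chunk method was called."""

    def __init__(self) -> None:
        self.calls: List[Tuple[str, int]] = []

    def forward_one_chunk(self, fwd_chunk_id: int, args: tuple, kwargs: Any = None) -> tuple:
        self.calls.append(("forward", fwd_chunk_id))
        return args

    def backward_one_chunk(self, bwd_chunk_id: int) -> None:
        self.calls.append(("backward", bwd_chunk_id))


class FrequencyControlledStage(FrequencyControlMixin, FakeStage):
    """The mixin comes first in the bases so its overrides run before the base's."""


controller = RecordingController()
stage = FrequencyControlledStage()
stage.configure_frequency(controller, forward_mhz=1400, backward_mhz=1100)
out = stage.forward_one_chunk(0, (1.0,))
stage.backward_one_chunk(0)
```

In real use the mixin would presumably be combined with `torch.distributed.pipelining.PipelineStage` in place of `FakeStage`, and the frequency would likely vary per microbatch rather than per direction.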
If users already have an instance of `PipelineStage` (manual splitting) or `_PipelineStage` (automatic splitting with `pipeline`), we can provide a static method on our `PipelineStage` subclass that converts the user's existing pipeline stage into an instance of ours.
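One possible way to do the conversion is to reassign the existing object's `__class__` to a dynamically created subclass, so all state built by the user's stage is kept. This is only a sketch of one reading of the idea: `from_stage`, `UserStage`, and `ListController` are hypothetical names, the MHz values are made up, and whether the real `PipelineStage` classes tolerate class reassignment has not been checked here.

```python
from typing import Any, List, Tuple


class FrequencyControlMixin:
    """Requests a GPU frequency before each forward/backward chunk."""

    def forward_one_chunk(self, *args: Any, **kwargs: Any) -> Any:
        self._freq_controller.set_frequency(self._forward_mhz)
        return super().forward_one_chunk(*args, **kwargs)

    def backward_one_chunk(self, *args: Any, **kwargs: Any) -> Any:
        self._freq_controller.set_frequency(self._backward_mhz)
        return super().backward_one_chunk(*args, **kwargs)

    @staticmethod
    def from_stage(stage: Any, controller: Any, forward_mhz: int, backward_mhz: int) -> Any:
        """Convert an existing stage in place and return the same object."""
        base = type(stage)
        if not issubclass(base, FrequencyControlMixin):
            # Build a subclass of the user's stage class on the fly so this
            # works for both manually and automatically created stages.
            new_cls = type(
                f"FrequencyControlled{base.__name__}",
                (FrequencyControlMixin, base),
                {},
            )
            stage.__class__ = new_cls
        stage._freq_controller = controller
        stage._forward_mhz = forward_mhz
        stage._backward_mhz = backward_mhz
        return stage


class UserStage:
    """Stand-in for a stage the user has already constructed."""

    def __init__(self, stage_index: int) -> None:
        self.stage_index = stage_index
        self.calls: List[Tuple[str, int]] = []

    def forward_one_chunk(self, fwd_chunk_id: int, args: tuple, kwargs: Any = None) -> tuple:
        self.calls.append(("forward", fwd_chunk_id))
        return args

    def backward_one_chunk(self, bwd_chunk_id: int) -> None:
        self.calls.append(("backward", bwd_chunk_id))


class ListController:
    """Stand-in controller that records requested frequencies."""

    def __init__(self) -> None:
        self.requests: List[int] = []

    def set_frequency(self, frequency_mhz: int) -> None:
        self.requests.append(frequency_mhz)


user_stage = UserStage(stage_index=1)
controller = ListController()
converted = FrequencyControlMixin.from_stage(user_stage, controller, 1400, 1100)
converted.forward_one_chunk(0, (2.0,))
converted.backward_one_chunk(0)
```

Because the object is converted in place, any schedule that already holds a reference to the user's stage would pick up the frequency-setting behavior without being rebuilt. An alternative is to construct a fresh instance and copy attributes over, which avoids class reassignment but requires knowing what state to copy.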
A proof of concept (POC) can be built on TorchTitan's `train.py` without having to modify TorchTitan itself.