[RFC] Controller for SPMD+MPMD
Background
Work is currently underway to design a solution that lets mark_sharding trace the model before it is loaded onto devices (https://github.com/pytorch/xla/issues/9341). Together with Local SPMD, this should enable us to achieve SPMD+MPMD as described in its RFC. One remaining question is which controller to leverage and how to do so. This RFC proposes an approach and gives two examples of how SPMD+MPMD can be achieved.
API Discussion
Before getting into the specifics of the controller, it is worth briefly discussing the user experience of SPMD+MPMD; specifically, how to handle pipeline parallelism while also doing gSPMD. I see two different approaches: (1) hide some of that process behind a newly created API, or a new level of abstraction; (2) leverage existing pipeline parallelism tooling.
There is a temptation to hide this behind a new API to simplify the process as much as possible and create an easy user experience. However, PyTorch already has strong tooling around pipeline parallelism. These tools are widely used externally, and they already ease the work of coordinating multiple processes that run different parts of the pipeline.
Rather than creating a new API standard, it is likely better to approach this from a PyTorch angle: “this is a PyTorch backend, how do I do pipeline parallelism with PyTorch?” From that angle, it is better to support SPMD+MPMD in these existing pipeline parallelism APIs than to create a new API.
Approach
The general approach will be to:
- Trace model without loading it to devices
- Split model into individually executing modules
- Create processes to execute the split modules
- Have each module executed by a process that is responsible for running gSPMD for it
From an implementation perspective, the idea is that with Local SPMD and deferred model initialization in place, APIs that specialize in pipeline parallelism should be able to manage their individual processes; a rough sketch of this flow is shown below.
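The following is a minimal sketch of that flow, under the assumption that the deferred initialization work in https://github.com/pytorch/xla/issues/9341 and Local SPMD are available. `MyModel` and `split_into_stage_modules` are hypothetical placeholders for the model and for whatever splitting mechanism (e.g. PiPPy's tracer) produces per-stage modules; this is not a finalized implementation.

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

# 1. Trace/initialize the model without materializing it on devices.
with torch.device("meta"):
    model = MyModel()  # hypothetical model class

# 2. Split the model into per-stage modules (e.g. via PiPPy's tracer).
stage_modules = split_into_stage_modules(model)  # hypothetical splitting helper

# 3./4. Each stage module is handed to its own process; inside that process,
# the controller calls something like run_stage() to execute gSPMD locally.
def run_stage(stage_module: torch.nn.Module, batch: torch.Tensor) -> torch.Tensor:
    xr.use_spmd()
    num_devices = xr.global_runtime_device_count()
    # With Local SPMD, the mesh would instead span only this process's devices.
    mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ("data", "model"))

    device = xm.xla_device()
    # Assumes deferred (meta) parameters can be materialized onto XLA devices,
    # which is the goal of the work referenced above.
    stage_module = stage_module.to(device)
    batch = batch.to(device)
    xs.mark_sharding(batch, mesh, ("data", None))  # shard the batch over the "data" axis
    return stage_module(batch)
```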
PiPPy
PiPPy is the pipeline parallelism library created by PyTorch. It has an overall tool set that might be convenient. With PiPPy, pipeline parallelism usually involves:
- Initializing a model without loading it to devices
- Creating a pipe through `pipeline`
  - At this step, a `GraphModule` is created which contains the modules for each process to execute later
- Initializing a process group (`dist.init_process_group`)
- Creating `PipelineStage`s based on the pipe
- Executing each pipeline stage
You can see a step-by-step walkthrough in PiPPy’s README, or a Llama model example here.
Either way, this lets PiPPy administer the individual processes while each process executes gSPMD for the specific modules it was created with. A rough sketch of this flow is shown below.
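To make the steps concrete, here is a rough sketch of how the PiPPy flow could look against an XLA backend, using the `torch.distributed.pipelining` APIs that PiPPy has been upstreamed into. The model, split point, and device handling are illustrative assumptions rather than a finalized design.

```python
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
from torch.distributed.pipelining import pipeline, SplitPoint, ScheduleGPipe

# 1. Initialize the model without loading it onto devices.
with torch.device("meta"):
    model = MyTransformer()  # hypothetical model

# 2. Create a pipe; pipeline() traces the model into a GraphModule and splits it
# into one submodule per pipeline stage.
example_input = torch.randn(8, 1024, device="meta")
pipe = pipeline(
    model,
    mb_args=(example_input,),
    split_spec={"layers.4": SplitPoint.BEGINNING},  # hypothetical split point
)

# 3. Initialize the process group, one process per pipeline stage (env:// init assumed).
dist.init_process_group()
rank = dist.get_rank()

# 4. Build this process's pipeline stage; with SPMD+MPMD the stage module would be
# materialized on XLA devices and sharded via mark_sharding inside this process.
stage = pipe.build_stage(rank, device=xm.xla_device())

# 5. Execute the schedule; each process runs gSPMD for its own stage.
schedule = ScheduleGPipe(stage, n_microbatches=4)
batch = torch.randn(8, 1024)  # a real (non-meta) input batch on the first stage
output = schedule.step(batch) if rank == 0 else schedule.step()
```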
Ray
Ray is a cluster controller for Python that provides a lot of utility for scaling large applications, including AI workloads. Ray does not have an explicit pipeline parallelism API, but pipeline parallelism can be achieved by leveraging its actors:
- Leverage PiPPy's `pipeline` to create a `GraphModule`
- Leverage the `GraphModule` to identify module splits
- Create Ray actors based on these graph modules
- Launch the Ray actors, and wait for them to resolve
Ray will administer the different actor pods while each pod executes gSPMD for the specific modules it was created with. A rough sketch of this variant is shown below.
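For illustration, a minimal sketch of the Ray actor variant follows. The `stage_modules` list and the `run_spmd_stage` helper are assumptions (the latter standing in for the per-stage gSPMD execution sketched in the Approach section), and inter-stage activation transfer is omitted for brevity.

```python
import ray

ray.init()

@ray.remote
class StageActor:
    """One actor per pipeline stage; Ray schedules and monitors the actor process."""

    def __init__(self, stage_module):
        self.stage_module = stage_module

    def run(self, batch):
        # Inside the actor process: materialize the stage module on XLA devices,
        # apply mark_sharding, and execute gSPMD for this stage only.
        return run_spmd_stage(self.stage_module, batch)  # hypothetical helper

# stage_modules: the per-stage submodules extracted from PiPPy's GraphModule.
actors = [StageActor.remote(module) for module in stage_modules]

# Launch the actors and wait for them to resolve. A real pipeline would also
# forward activations between consecutive stages (e.g. via Ray object refs).
futures = [actor.run.remote(batch) for actor in actors]
results = ray.get(futures)
```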
A tale of two pipeline parallelism approaches
PyTorch/XLA currently has a pipeline parallelism approach documented in https://github.com/pytorch/xla/tree/r2.7?tab=readme-ov-file. In this existing approach, each device is associated with a process. As the original SPMD+MPMD RFC highlighted, this approach is flawed because we are unable to apply gSPMD when using pipeline parallelism. The effort here to let gSPMD run within pipeline parallelism through PiPPy, Ray, and other APIs might therefore look like a confusing duplication of functionality.
Given that, it is worth noting that after the SPMD+MPMD effort we should reassess our existing pipeline parallelism methodology, and see whether it is possible to deduplicate in favor of the more PyTorch-native approach suggested in the RFC.
@pgmoka Thanks! Are there existing applications using PiPPy with Ray today, and what would the benefits be? The PiPPy approach is fully compatible with our initial approach, so I am checking on the benefits we would get from Ray for this feature in particular.
@rpsilva-aws In the RFC, this is more to demonstrate how pipeline parallelism can be achieved through different libraries. Both Ray and PiPPy are valid options that customers may seek out. To a certain extent, I think the preference will be dictated by which tool a user is used to. A benchmark comparing both would be interesting as an addendum to the SPMD+MPMD feature.
PiPPy has usability features listed in https://github.com/pytorch/PiPPy?tab=readme-ov-file#what-is-pippy; it seems to be the easier tool to use out of the box. Ray has pretty extensive debugging tools (https://docs.ray.io/en/latest/ray-observability/index.html), which I can see being quite useful for analyzing the different gSPMD worlds.
Does this answer your question?