[RFC] MPMD+SPMD Pipeline Parallelism
🚀 Feature
We propose an accelerator-agnostic, hybrid Single-Program Multiple-Data (SPMD)/Multiple-Program Multiple-Data (MPMD) Pipeline Parallelism implementation in PyTorch XLA. The key objectives are:
- Enable efficient model-parallel training for large language models, retaining the same SPMD-model semantics within each pipeline stage so that users can write code as if targeting a single large device.
- Enable heterogeneous programs and advanced scheduling strategies with MPMD (1F1B, Interleaved, ZBV Zero Bubble) across pipeline stages
- Leverage PyTorch's native distributed pipelining APIs
- Minimize model changes for users
In the image below, we see the MPMD+SPMD execution strategy on the RHS. For each pipeline stage/rank, we still require the XLA compiler to partition the single-device program with the needed CC ops, based on the GSPMD sharding annotations. To enable heterogeneous use cases and more complex scheduling strategies, we allow users to split their global devices (and pipeline stages) into separate SPMD single-device programs, invoking the MPMD execution strategy across the SPMD (local) worlds.
Motivation
Existing SPMD-only Pipeline Parallelism implementations (mainly JAX’s AXLearn [1] and Praxis [2]) have fundamental limitations: they require redundant computation across all ranks to maintain SPMD semantics, making pipeline parallelism inefficient and unable to easily support advanced scheduling strategies beyond GPipe. As model sizes increase, memory and communication bandwidth become critical bottlenecks. Our hybrid approach, tied to PyTorch’s native distributed pipelining APIs, aims to substantially improve performance, drawing from recent research results [3].
There have been earlier MPMD-only RFCs to introduce Pipeline Parallelism for PyTorch/XLA and upstream it to PyTorch (https://github.com/pytorch/xla/issues/6347). That effort should become relatively easier to achieve as we close some of the overlapping gaps with XLA. It is a subset of this RFC/project, and the design/implementation should not build solutions that break compatibility with running without SPMD. End-to-end testing and validation of that path is immediately out-of-scope, but will likely be included, as we don’t expect significant nuances if a user were to use a single participating device in their (local) SPMD worlds. Similarly, there is no evident risk in enabling SPMD-only PP applications.
Pitch
Native PyTorch
We build on top of PyTorch's distributed pipelining (formerly PiPPy) solution, leveraging FX graphs (https://pytorch.org/docs/stable/fx.html#torch.fx.GraphModule) for model partitioning, graph specification, and execution. We choose to extend the existing library because of:
- Relatively well established infrastructure on PyTorch, leveraging FX graphs for model partitioning
- Support for compiler-specific graph (torch.export) or manual model splitting
- Increasingly larger number of existing scheduling generators
- Declarative schedule format
The goal is to contribute to PyTorch and close the remaining parity gaps so the library works with XLA devices. The partitioned subgraphs are traced and compiled into executable programs that run on XLA devices. We include the needed XLA- and SPMD-specific invariants in the staging of the P2P communication across ranks, factoring in the device mesh (e.g. sharding specifications).
SPMD
We propose a solution that combines the benefits of both SPMD and MPMD paradigms, providing a foundation for future PP work:
- Maintain SPMD semantics (TP, FSDP, DP, CP) for each MPMD execution (PP rank)
- Enable customers to seamlessly use this new paradigm with existing interfaces and scheduling variants
- Ensure compatibility with existing interfaces
- Design for extensibility across PyTorch and JAX ecosystems
This will substantially improve training efficiency for large language models while maintaining compatibility with the broader PyTorch ecosystem.
Other parallelism techniques
Any other parallelism technique with SPMD should work seamlessly with Pipeline Parallelism. The idea is that each PP rank manages an independent SPMD world: besides P2P communication, all participating (local) devices execute the same program, and CC ops are still generated by XLA. Users can define and map the most optimal training configurations, e.g. TP mapped over high-bandwidth dimensions and DP/PP over low-bandwidth dimensions, as sketched below.
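As a minimal sketch (axis names and sizes here are hypothetical; the full listing appears in the XLA + SPMD example below), each PP rank owns a local SPMD mesh whose inner axis maps TP onto the high-bandwidth interconnect, while DP spans the lower-bandwidth dimension and PP itself only uses P2P transfers:

import numpy as np
import torch_xla.distributed.spmd as xs

# Assumed topology: this process drives pipeline stage `pp_rank` with 8 local devices.
pp_rank, devices_per_stage = 0, 8
local_ids = np.arange(pp_rank * devices_per_stage, (pp_rank + 1) * devices_per_stage)

# ("data", "model") = (2, 4): TP ("model") over the fast intra-host links, DP ("data")
# across hosts. PP is not an axis of the local mesh, since stages only exchange data
# via P2P; XLA still generates the CC ops for the local axes.
local_mesh = xs.Mesh(local_ids, (2, 4), ("data", "model"))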
Localized SPMD
PyTorch/XLA supports localized SPMD execution, allowing individual hosts to run SPMD programs on their local devices independently. This feature provides greater flexibility in multi-host environments, especially for MPMD+SPMD workloads. We implement localized SPMD by decoupling logical device indices from physical device IDs and supporting implicit localization within the XLA graph executor. There is already some preliminary work on this from Siyuan, though for a different use case (https://github.com/pytorch/xla/pull/8810).
- Logical Ordinal Abstraction: Within each localized SPMD program, devices are addressed using logical ordinals. Under the hood, the device assignment maps to the physical devices that the PjRt client manages.
- Implicit Device Ordinal Mapping: Within each graph executor and device instance, we infer the participating devices from the OpSharding. Invariants are added to ensure that for each SPMD world (replica group), the participating devices are uniform across all XLA tensors.
- Compilation Configuration: During compilation, the XLA compiler is configured to target only the local devices accessible to the current process (separate RFC).
This implementation allows SPMD programs to reason about their execution using a consistent logical device indexing scheme, while the runtime handles the mapping to physical devices. This should honor the distributed process groups and the OpSharding of all live tensors captured by the XLA graph tracing. We ensure that the implementation can serve heterogeneous applications, as well as the intended scope with pipeline parallelism. This needs to come with the capability of defining submeshes across all participating devices.
Furthermore, we start by requiring the same number of (local participating) devices for each program across all SPMD worlds. Hence, we maintain a direct mapping of each sharded tensor between SPMD worlds, and there is no need for CC ops to collect/reduce data that is communicated across pipeline stages (see the sketch below). The design/implementation should not introduce decisions that would complicate the extension to heterogeneous (local) SPMD worlds, namely gathering/reducing/resharding tensors preceding/following P2P communication.
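As a small illustrative sketch (hypothetical helper, not an API), consider 2 stages with 8 devices each: every localized SPMD program addresses its devices with logical ordinals 0..7, and the runtime resolves them to the physical IDs owned by that stage's PjRt client; with equal-sized worlds, logical ordinal k of stage i maps directly onto logical ordinal k of stage i+1 for P2P transfers.

DEVICES_PER_STAGE = 8  # assumed: same count for every (local) SPMD world

def logical_to_physical(pp_rank: int) -> dict[int, int]:
    """Map logical ordinals used inside a localized SPMD program to the
    physical device IDs managed by this stage's PjRt client."""
    start = pp_rank * DEVICES_PER_STAGE
    return {logical: start + logical for logical in range(DEVICES_PER_STAGE)}

# Stage 0 sees {0: 0, ..., 7: 7}; stage 1 sees {0: 8, ..., 7: 15}. Meshes and
# OpShardings inside each stage only reference the logical side, so the same
# SPMD program is valid for any stage.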
Process groups
Currently, PyTorch/XLA does not support the XLA backend together with SPMD, since there is a single replica and CC ops are generated by XLA's partitioner. We are working on relaxing this constraint to simplify the interfaces for SPMD+MPMD, enabling both the native XLA backend for torch.distributed (P2P communication) and other distributed APIs (DTensor): the former dictates the participating ranks for any given submesh, while the latter serves mainly to simplify/improve the user interface, allowing users to define sub-meshes while abstracting the MPMD+SPMD semantics. A sketch of the intended usage follows.
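A hedged sketch of the intended usage under this relaxation (today, initializing the xla backend with SPMD enabled is rejected; subgroup support with the xla backend is an assumption here):

import os
import torch.distributed as dist
import torch_xla.runtime as xr
import torch_xla.distributed.xla_backend  # registers the "xla" backend

xr.use_spmd()

# Assumed launch via torchrun, one process per pipeline stage.
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

# The global group only carries the cross-stage P2P traffic; all intra-stage
# collectives are still inserted by XLA's SPMD partitioner.
dist.init_process_group(backend="xla", init_method="xla://",
                        rank=rank, world_size=world_size)

# Optional: subgroups for adjacent stages (every process must create every
# group, per torch.distributed semantics); the schedule would pick its own.
stage_pair_groups = [dist.new_group(ranks=[i, i + 1]) for i in range(world_size - 1)]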
Model initialization
We want to ensure that model initialization on PyTorch/XLA does not materialize the entire model on the device for each stage. To avoid this, we want to extend the existing meta device context to capture only the metadata, ensuring that it can be used interchangeably with the XLA device. Once a model is traced, we can then send the sharded data to the device. Another consideration is RNG seeding: each rank should derive its own unique seed.
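A minimal sketch of the intended initialization flow, assuming standard meta-device initialization and Module.to_empty; the model, the placeholder stage submodule, and the per-rank seeding scheme are illustrative, not part of the proposal's API:

import os
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# 1) Build the module on the meta device: only shapes/dtypes are recorded,
#    no parameter storage is allocated for stages this rank does not own.
with torch.device("meta"):
    model = torch.nn.Linear(1024, 1024)  # stand-in for the real model

# 2) After splitting/tracing, materialize only this rank's stage parameters
#    on the XLA device and (re)initialize/shard them there.
stage_module = model  # placeholder: in practice, this rank's stage submodule
stage_module = stage_module.to_empty(device=device)

# 3) Each rank derives its own unique RNG seed, e.g. offset by its PP rank.
base_seed = 42
torch.manual_seed(base_seed + int(os.environ.get("RANK", 0)))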
XLA
Currently, XLA has no notion of MPMD+SPMD, and hence requires a single replica group with N participants for the entire SPMD world. It requires ranks to use 0-based ordinal indexing and the sharding annotations to describe the entire SPMD world for the SPMD partitioner. This requires an orthogonal RFC, which will be evaluated in parallel, covering PJRT client logical ID mappings, HLO specification proposals, and threaded scheduling for P2P communications.
XLA + SPMD:
import os
import sys
from typing import Optional
import numpy as np
import torch
from torch import nn
import torch.optim as optim
+ import torch_xla.core.xla_model as xm
+ import torch_xla.runtime as xr
+ import torch_xla.distributed.spmd as xs
+ import torch_xla.distributed.xla_backend
+ import torch_xla.distributed.parallel_loader as pl
import args_parse
import torch.distributed as dist
from pippy import pipeline, SplitPoint, ScheduleGPipe, PipelineStage
MODEL_OPTS = {
'--input_dim': {
'type': int,
'default': 16834,
},
'--train_dataset_len': {
'type': int,
'default': 1024 * 8,
},
'--pipeline_chunks': {
'type': int,
'default': 4,
}
}
FLAGS = {}
+ xr.use_spmd()
class SimpleLinear(nn.Module):
NUM_CLASSES = 3
def __init__(self):
super().__init__()
# Instead of Sequential, define layers separately for easier split points
self.layer0 = nn.Linear(FLAGS.input_dim, FLAGS.input_dim // 2)
self.relu = nn.ReLU()
self.layer1 = nn.Linear(FLAGS.input_dim // 2, 3)
self.layer2 = nn.Linear(3, self.NUM_CLASSES)
def forward(self, x):
x = self.layer0(x)
x = self.relu(x)
x = self.layer1(x)
x = self.layer2(x)
return x
def train():
+ # Torchrun is needed for Pipeline Parallelism by default. Generally, we
+ # don't need it for SPMD, and we could rely on `process_index` and
+ # `addressable_runtime_device_count` from PjRT runtime. However, it would
+ # be needed if we have multiple SPMD worlds within the same physical machine.
+ # Hence, we retain the requirement, and can relax it later on.
+ rank = int(os.environ["RANK"]) # or xr.process_index()
+ world_size = int(os.environ["WORLD_SIZE"]) # or xr.addressable_runtime_device_count()
- rank = int(os.environ["RANK"])
- world_size = int(os.environ["WORLD_SIZE"])
chunks = FLAGS.pipeline_chunks
- if torch.cuda.is_available():
- device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
- else:
- device = torch.device("cpu")
+ # Use XLA device
+ device = xm.xla_device()
print(f"Rank {rank} using device {device}")
+ # (Preferred) Leverage the DTensor/DeviceMesh variants for a more seamless
+ # user interface with submeshes.
+ # -----
+ # from torch.distributed.device_mesh import init_device_mesh
+ num_local_devices = xr.addressable_runtime_device_count()
+ global_mesh = init_device_mesh("xla", (chunks, num_local_devices, 1),
+ mesh_dim_names=("pp", "data", "model"))
+ local_mesh = global_mesh["data", "model"]
+ # -----
+ # Alternatively, build the meshes explicitly:
+ # -----
+ num_devices = xr.global_runtime_device_count()
+ num_local_devices = xr.addressable_runtime_device_count()
+
+ # Global mesh over all devices: (pp, data, model)
+ global_mesh_shape = (chunks, num_local_devices, 1)
+ global_mesh = xs.Mesh(np.arange(num_devices), global_mesh_shape, ("pp", "data", "model"))
+
+ # Local submesh for this rank's SPMD world
+ device_id_start = rank * num_local_devices
+ local_device_ids = np.arange(device_id_start, device_id_start + num_local_devices)
+ local_mesh_shape = global_mesh_shape[1:]
+ local_mesh = xs.Mesh(local_device_ids, local_mesh_shape, ("data", "model"))
+ # -----
# Initialize process group
- dist.init_process_group(rank=rank, world_size=world_size)
+ dist.init_process_group(
+ backend="xla",
+ init_method="xla://",
+ rank=rank,
+ world_size=world_size
+ )
+
torch.manual_seed(42)
model = SimpleLinear().to(device)
+ # Shard the model weights as needed:
+ # parallelize_model(model, local_mesh)
# Define split points for pipeline parallelism
split_spec = {
"layer0": SplitPoint.END,
"layer1": SplitPoint.END,
}
# Create a sample input for the pipeline
batch_size = FLAGS.batch_size
example_input = torch.randn(batch_size, FLAGS.input_dim, device=device)
# Create the pipeline and respective stage for the rank.
pipe = pipeline(model, chunks, example_args=(example_input,), split_spec=split_spec)
stage = PipelineStage(pipe, rank, device)
schedule = ScheduleGPipe(stage, chunks)
# Training loop
losses = []
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=FLAGS.lr)
for epoch in range(FLAGS.num_epochs):
for step, (data, target) in enumerate(data_generator()):
if rank == 0:
+ xs.mark_sharding(data, local_mesh, ('data', 'model'))
+ # or distribute_tensor(data, local_mesh, [Shard(0), Shard(1)])
schedule.step(data)
optimizer.zero_grad()
else:
output = schedule.step()
# Only the last rank computes loss and does backward
if rank == world_size - 1:
loss = loss_fn(output, target)
losses.append(loss.clone().detach())
loss.backward()
optimizer.step()
if step % FLAGS.log_steps == 0:
print(f"Epoch {epoch} step {step} loss {loss}")
+ xm.mark_step()
if __name__ == '__main__':
default_config = {
'batch_size': 128,
'num_epochs': 1,
'lr': 0.1,
'log_steps': 8,
'opts': MODEL_OPTS.items()
}
global FLAGS
FLAGS = args_parse.parse_common_options(**default_config)
print('Start training loop...')
train()
dist.destroy_process_group()
[1] https://github.com/apple/axlearn/blob/main/axlearn/common/pipeline.py
[2] https://github.com/google/praxis/blob/main/praxis/layers/pipeline.py
[3] https://arxiv.org/abs/2412.14374
[4] https://arxiv.org/abs/2203.12533
cc: @tengyifei @wconstab @kwen2501
Thank you for sharing this RFC!
We are reviewing it. cc @lsy323 @tengyifei @haifeng-jin @bhavya01 @pgmoka @GleasonK
We had minimally prototyped it with XLA + SPMD - mimicking communication across heterogeneous SPMD worlds, minus the native distributed pipelining APIs:
- SPMD localization (similar functionality to Siyuan's local CR)
- Inferring SPMD localization from all live tensors, abiding by the local SPMD mesh
- Enable distributed process groups for XLA + SPMD
- Injecting P2P communication as custom ops with the XLA backend
- Manually transposing P2P custom ops to the appropriate CC op after the SPMD partitioner (custom PJRT pass)
The latter two were Neuron-specific, but I'll adapt and share them as part of the RFC as well, as suggested by Siyuan.
Thanks for the RFC! This seems like quite a lot of work, and many details need to be filled in, e.g., how to support profiling.
Got a few questions:
> There is no need for CC ops to collect/reduce data that is communicated across pipeline stages.

How do you plan to sync input/output between different SPMD stages?
On the top level: I don't quite understand the benefit of the SPMD+MPMD solution in the RFC compared with the MPMD solution. I think leaving everything to the XLA compiler should be much easier and more efficient for overlapping communication overhead. Can you elaborate more on the motivation? Thanks!
Thanks @zpcore!
> How do you plan to sync input/output between different SPMD stages?
In the RFC above, we started by focusing on local SPMD worlds that have the same number of participants and localized SPMD meshes. Communication occurs across SPMD local ranks: for example, with 2 stages and 8 devices each, we'd have P2P communication as [[0, 8], [1, 9], ..., [7, 15]], transferring sharded data across stages (with the receiving rank using device data IR placeholders); see the sketch below. We're also careful not to make decisions that would prevent future extension to heterogeneous local SPMD worlds (different meshes or participant counts), which would require synchronization before or after stages.
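Purely to illustrate the pairing logic under the equal-world-size assumption (not an API), the cross-stage send/recv pairs can be derived as:

def p2p_pairs(num_stages: int, devices_per_stage: int) -> list[list[int]]:
    """Pair shard k of stage i with shard k of stage i+1; with equal-sized
    SPMD worlds, sharded data is transferred 1:1 with no gather/reduce."""
    pairs = []
    for stage in range(num_stages - 1):
        for local in range(devices_per_stage):
            pairs.append([stage * devices_per_stage + local,
                          (stage + 1) * devices_per_stage + local])
    return pairs

# p2p_pairs(2, 8) -> [[0, 8], [1, 9], ..., [7, 15]]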
On the top level question, capturing all different variants for the wider picture:
- MPMD-only: We sacrifice the benefits of the GSPMD model strategy, as users must compromise to enable advanced scheduling strategies. This approach requires using MPMD for all parallelism strategies (TP, DP, SP, CP, FSDP), whereas we want users to still benefit from lightweight annotations with XLA's partitioner handling the communication collective placements for the SPMD worlds. Users would need to manually add all communication collective placements, plus having to maintain the 1 process per device model.
- SPMD-only: Advanced scheduling strategies become unavailable, as we're limited to the GPipe variant. Pipeline parallelism is fundamentally challenging with this model strategy. GPipe requires homogeneous stages and padding the entire model computation, adding superfluous calculations to each rank - sometimes non-negligibly (LMHead, Embedding, etc). OpenXLA's potential work on reducing superfluous computation may help, but the requirement remains problematic. Recent presentations showed execution/communication overlap optimizations in OpenXLA for GPU with SPMD-only GPipe, achieving nice 1.67x and 1.14x improvements for Llama3.1 70B and 405B over GPipe's prior baseline (sequential computation and communication). However, these are limited to this basic scheduling strategy with suboptimal memory/performance characteristics.
We don't limit opportunities to optimize communication overlaps in the XLA compiler, as we can leverage insights from both model strategies. This will be addressed in OpenXLA RFC follow-ups, with clear separation between SPMD and MPMD communication collective operations, which can conceptually help with execution sequence dependencies, communication overlapping, and optimal computation scheduling. This work can be adapted when supporting MPMD+SPMD strategies, even without pipeline parallelism (e.g., heterogeneous SPMD computation programs). We can work towards feature parity for the two approaches, as there's substantial implementation (details) overlap - but that's not in scope.
The MPMD+SPMD strategy combines advantages from both models. Recent studies with Pathways and JaxPP [4 and 3, above] have extended JAX with MPMD support, though they subsequently focused on other aspects of the architecture stack. I tried to keep the motivation concise, but I can move some of these to the RFC if these provide a better comparative picture.
It would be great to discuss the runtime requirements of this work in the context of multi-host setting.
@rpsilva-aws High level questions:
- Are you proposing to have PyTorch/XLA capture the collectives used for Pipeline Parallelism into HLO ops? How would you run those ops?
- What is "P2P" communication? Taken literally, that means point-to-point communication between two devices. But I'm not sure that's what you meant?
@rpsilva-aws
> We had minimally prototyped it with XLA + SPMD - mimicking communication across heterogeneous SPMD worlds, minus the native distributed pipelining APIs.
Is this prototype on GitHub? I would be curious to take a look.
Sorry for the late response, I had a few other urgent items. Thanks for having a look! I'll start revisiting the item and questions.
> Are you proposing to have PyTorch/XLA capture the collectives used for Pipeline Parallelism into HLO ops? How would you run those ops?
Correct, we would use the XLA backend to intercept the send/recv collectives (from PT's native distributed pipelining APIs) and write the respective HLO ops; see the sketch below. We can easily relax the process group limitation with XLA + SPMD and do this. I can likely send out a PR to relax this invariant already and write a couple of tests to showcase it. We're already hoping to unify the native distributed APIs with XLA, so this is a good step.
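A hedged sketch of what that would look like at the framework level (hypothetical today: the xla process group does not yet lower P2P ops with SPMD enabled): the pipelining runtime issues ordinary torch.distributed send/recv calls, and the XLA backend records them into the traced stage graph.

import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm

# Assumes the xla process group from the example above is already initialized.
device = xm.xla_device()
rank, world_size = dist.get_rank(), dist.get_world_size()

# Forward activations to the next stage; receive from the previous one. Under
# the proposal these calls are captured as HLO send/recv (or equivalent) ops
# and materialize at the next mark_step(), rather than executing eagerly.
if rank < world_size - 1:
    acts = torch.randn(8, 1024, device=device)  # stand-in for stage output
    dist.send(acts, dst=rank + 1)
if rank > 0:
    recv_buf = torch.empty(8, 1024, device=device)  # shape/dtype known from tracing
    dist.recv(recv_buf, src=rank - 1)
xm.mark_step()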
What is "P2P" communication? Taken literally, that means point-to-point communication between two devices.
Yes, that is the intended terminology here. However, I use it precisely to indicate communication between 2 devices in different SPMD worlds / PP ranks. We can refer to the example I mentioned above. I agree with you that this could be clarified or made more explicit - since that's not necessarily the only use case. Similarly, I'll start modularizing all the items and introduce well defined functionality and test cases. We can start with MPMD-only communication tests with XLA backend, followed by SPMD.
> Is this prototype on GitHub? I would be curious to take a look.
The prototype was made with Neuron, and I had a custom PJRT pass to circumvent the global logical indices, but I will follow-up in branching that pass and sharing the end-to-end.
Also interested in seeing the prototype - curious if it was just MPMD, or did you launch processes that have local device assignments with N devices?
I'm also curious how much your efforts work with / don't work with something like IFRT::CallOp, I.e. ifrt.Call @pipeline_stage1(data) {devices = [0,1,2,3]}, where we keep everything single process, but capture "worlds" explicitly in the IR and build orchestration into the {xla/aws} plugin runtimes?
The [[0, 8], [1, 9], ..., [7, 15]] approach you mentioned sounds like the XLA:GPU pipelining's collective_permute approach, but that operates within a sub-world as you mentioned. Interested in your thoughts on how we best integrate this in PyTorch native, i.e. PT launches processes with N local devices and we (somehow) figure out the device assignments of the next stage, or we capture MPMD in the IR and build something that handles programs from there.
@pgmoka and I are exploring some design options like this, hopefully can share a doc soon and iterate with you from there 🙂 .
I would like to confirm my understanding of Model initialization:
From my reading, you suggest that we change data loading in the case of SPMD+MPMD such that we only load data on the device after we have traced. My assumption is that this is because we need the trace to tell what data will be utilized by which SPMD world.
Is this understanding correct, or are there things I am overlooking?
I would also be curious if you could go a little deeper on meta device context to only capture the metadata. What metadata are you more specifically thinking of here?
We had a great discussion, and I really appreciate your feedback, Kevin and Pedro! I will follow up on breaking down some of the items on our side, so it's clear(er) on where we stand and how we're parallelizing the workstream on our side. We can align on the granular items.
Also interested in seeing the prototype - curious if it was it just MPMD, or did you launch processes that have local device assignments with N devices?
It was the latter, though we had a bootstrapped version of the localized SPMD effort that Pedro is working on.
> I'm also curious how much your efforts work with / don't work with something like IFRT::CallOp, I.e. ifrt.Call @pipeline_stage1(data) {devices = [0,1,2,3]}, where we keep everything single process, but capture "worlds" explicitly in the IR and build orchestration into the {xla/aws} plugin runtimes? The [[0, 8], [1, 9], ..., [7, 15]] approach you mentioned sounds like the XLA:GPU pipelining's collective_permute approach, but that operates within a sub-world as you mentioned, interested in your thoughts on how we best integrate this in PyTorch Native, i.e. PT launches processes with N local devices and we (somehow) figure out the device assignments of the next stage, or capture MPMD in the IR and build something that handles programs from there.
We had definitely considered it, and your comment inspired us to dig deeper into the exact details. Glad we had the chance to discuss it - we agree it can be a parallel or follow-up item to chase, since we foresee higher initial effort on the IFRT side, gaps behind torch.export (backward) for training, and open questions on how it would integrate natively with PyTorch APIs - and we also don't necessarily require single-controller runtime applications at first. Most of the overlapping work should aim to be agnostic to these semantics, and we see no concerns about extending to it.
We most definitely want to pursue this down the line, particularly when some of the dependent work lands, so looking forward to it!
Adding some of the summary here for completeness:
""" Currently, PT/XLA has not yet adopted IFRT entirely. It still uses the original device runtime low-level interfaces that serves to merely abstract devices for any given underlying hardware with basic APIs. IFRT (1, 2, 3) is core to JAX, and allows user-facing frameworks to offer more complex, and wider use cases when interacting with the Plugin client. For instance, it provides the means to view global programs and management across hosts. These limitations also prevent us from seamlessly using certain APIs that would unlock single-controller applications with PP, since we could use existing IFRT op dialects to execute/compile certain programs with an explicitly defined set of “local” devices. These would be embedded directly on the IR, and directly handled by the respective runtime. For instance, we can annotate the devices that participate in any given part of the program, using this (global) operation to distribute the execution across ranks.
func.func @some_op(%arg0: tensor<32x64xf32>, %arg1: tensor<64x32xf32>) -> tensor<32x32xf32> {
  %0 = "mhlo.some_op"(%arg0, %arg1) : (tensor<32x64xf32>, tensor<64x32xf32>) -> tensor<32x32xf32>
  return %0 : tensor<32x32xf32>
}
%result = "ifrt.call"(%input1, %input2) {
  callee = @some_op,
  devices = %devices,          // devices = [0, 1, 2, 3]
  io_aliases = [[0, 0]],       // Alias input 0 with output 0
  donated_input_indices = [1]  // Input 1 can be donated
} : (tensor<32x64xf32>, tensor<64x32xf32>) -> tensor<32x32xf32>
Similarly, one could also use Ray to achieve remote single-controller runtime applications. Ray offers a unified compute framework for distributed workloads, and already offers a config interface with torch-xla. Ray consists of a main driving process that runs the workloads and can automatically initiate and manage several actors, each with their own device assignments. Interestingly, the existing interface relies on the distributed process groups to differentiate different backends.
In general, we consider the single- and multi- controller discussion to be orthogonal, and either options can be built on top of the foundational PyTorch integration with its native PP APIs. """
> From my reading, you suggest that we change data loading in the case of SPMD+MPMD such that we only load data on the device after we have traced. My assumption is that this is because we need the trace to tell what data will be utilized by which SPMD world. Is this understanding correct, or are there things I am overlooking?
Correct. We have the same flow as native PyTorch: we use the existing APIs with torch.export under the hood to produce the FX graph, and use the split submodules (based on the schedule) to retrace and lower the resulting modules, as sketched below.
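A hedged sketch of that flow with the upstream torch.distributed.pipelining APIs (per the current upstream docs; the retrace/lower-on-XLA step is the part this work adds, and the toy model and stage index are placeholders):

import torch
from torch.distributed.pipelining import pipeline, SplitPoint
import torch_xla.core.xla_model as xm

# Toy model whose submodule names match the split spec below.
class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer0 = torch.nn.Linear(1024, 1024)
        self.layer1 = torch.nn.Linear(1024, 1024)
    def forward(self, x):
        return self.layer1(self.layer0(x))

model = Tiny()
stage_idx = 0  # this rank's pipeline stage

# Split the model into stage submodules from one example microbatch;
# torch.export traces the FX graph under the hood.
example_microbatch = torch.randn(32, 1024)
split_spec = {"layer0": SplitPoint.END, "layer1": SplitPoint.END}
pipe = pipeline(model, mb_args=(example_microbatch,), split_spec=split_spec)

# Each rank keeps only its own stage's submodule; under this proposal it is
# then retraced and lowered lazily on the XLA device with its local SPMD mesh.
stage_module = pipe.get_stage_module(stage_idx)
stage_module = stage_module.to(xm.xla_device())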
> I would also be curious if you could go a little deeper on meta device context to only capture the metadata. What metadata are you more specifically thinking of here?
We are finalizing some of the remaining tweaks and we will share a detailed doc with you folks soon. In short, we are exploring using it to represent the preliminary setup of operations needed for the communication across ranks, including types, shapes, etc.
From https://github.com/pytorch/xla/issues/9019#issuecomment-2938686214. We are looking at Ray vs IFRT. I am wanting to write an RFC on the topic as I think it is a subject worth diving deeper into.
I am also thinking of writing an RFC on Local SPMD just to finish bringing everything that needs to be done there together.
Generally, this is a large RFC that we can split into workstreams. The ones I currently see are:
- Local SPMD
- Lazy tensor loading (model initialization).
- Controller (Ray or IFRT based)
For either Ray or IFRT based controller solutions, I believe we will need Local SPMD and lazy tensor loading.
I have made some investigation into Local SPMD. I am thinking of writing an RFC that outlines all the changes, and then finish spinning out bugs for it as an update to https://github.com/pytorch/xla/issues/9181.