
Pipeline parallelism with SPMD

Open · amithrm opened this issue 1 year ago • 3 comments

🚀 The feature, motivation and pitch

Motivation

SPMD sharding in pytorch/XLA offers model parallelism by sharding tensors within an operator. However, we need a mechanism to combine this capability with pipeline parallelism for large models that cannot rely on SPMD sharding alone (via the mark_sharding APIs), whether for performance reasons or because of memory constraints.
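For context, here is a minimal sketch of the intra-operator (SPMD) sharding referred to above, assuming a recent torch_xla build where the SPMD API lives under `torch_xla.distributed.spmd` (module paths have moved between releases):

```python
import numpy as np
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # enable SPMD execution mode

device = xm.xla_device()
num_devices = xr.global_runtime_device_count()

# 1D mesh over all devices, with a single axis named 'model'.
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('model',))

linear = nn.Linear(1024, 4096, bias=False).to(device)
# Shard the weight's output dimension across the 'model' axis; the input
# dimension (None) stays replicated.
xs.mark_sharding(linear.weight, mesh, ('model', None))
```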

Pitch

The high-level idea is to integrate the pipeline-parallel functionality of the existing NeuronxDistributed package with GSPMD: https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/neuronx-distributed/pipeline_parallelism_overview.html. As described in the docs, “In NeuronxDistributed, we use Pytorch’s FX to trace the model and do partition on the FX IR. User simply needs to specify where to cut the pipeline stages, and our algorithm will cut the pipeline stages and assign the corresponding modules to each Neuron core automatically.”
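To make the FX-based cutting concrete, here is an illustrative sketch (not the NeuronxDistributed implementation) of how a traced graph exposes module-call nodes whose qualified names, such as `layers.4`, can serve as stage boundaries:

```python
import torch.nn as nn
from torch import fx

class Stack(nn.Module):
  def __init__(self):
    super().__init__()
    self.layers = nn.ModuleList(nn.Linear(16, 16) for _ in range(8))

  def forward(self, x):
    for layer in self.layers:
      x = layer(x)
    return x

graph_module = fx.symbolic_trace(Stack())

pipeline_cut = 'layers.4'  # first module of the second stage
stage = 0
for node in graph_module.graph.nodes:
  if node.op == 'call_module' and node.target == pipeline_cut:
    stage = 1  # every node from here on belongs to the second stage
  print(stage, node.op, node.target)
```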

Alternatives

No response

Additional context

No response

amithrm · Feb 29 '24 05:02

A simple example to get the conversation started and to use as a target for feature completeness.

```python
# Imports below are an assumption about the target environment (torch_xla with
# SPMD support plus NeuronxDistributed for NxDPPModel); exact module paths may
# differ between releases.
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
import torch_xla.utils.utils as xu
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import Mesh

from neuronx_distributed.pipeline import NxDPPModel

# FLAGS (input_dim, batch_size, train_dataset_len, num_epochs, profile,
# profiler_port) and NUM_DEVICES are assumed to come from the surrounding
# test harness.

pipeline_cuts = ['layers.4']

class SimpleLinear(nn.Module):

  def __init__(self):
    super(SimpleLinear, self).__init__()
    self.fc1 = nn.Linear(FLAGS.input_dim, FLAGS.input_dim * 4, bias=False)
    self.relu = nn.ReLU()
    self.fc2 = nn.Linear(FLAGS.input_dim * 4, FLAGS.input_dim, bias=False)

  def forward(self, x):
    y = self.relu(self.fc1(x))
    z = self.fc2(y)
    return z

class StackedLinear(nn.Module):

  def __init__(self):
    super(StackedLinear, self).__init__()
    self.layers = nn.ModuleList([SimpleLinear() for _ in range(0, 10)])

  def forward(self, x):
    # Apply the stacked SimpleLinear blocks sequentially.
    for layer in self.layers:
      x = layer(x)
    return x

device = xm.xla_device()

def train():
  num_epochs = 1
  lr = 0.1
  train_loader = xu.SampleGenerator(
      data=(torch.randn(FLAGS.batch_size, 2, FLAGS.input_dim),
            torch.randn(FLAGS.batch_size, 2, FLAGS.input_dim)),
      sample_count=FLAGS.train_dataset_len // FLAGS.batch_size)
  torch.manual_seed(42)
  model = StackedLinear().to(device)

  model = NxDPPModel(
    model,
    transformer_layer_cls=SimpleLinear,
    num_microbatches=FLAGS.batch_size,
    output_loss_value_spec=(True),
    input_names=['x'],
    pipeline_cuts=pipeline_cuts,
    trace_file_path=None,
    leaf_module_cls=None,
    autowrap_modules=None,
    use_zero1_optimizer=True,
   )

  # NUM_DEVICES is assumed to be provided by the harness (e.g. 32 XLA devices).
  num_devices = NUM_DEVICES
  # Define a 2D mesh with all devices along the second ('y') axis.
  mesh_shape = (1, num_devices)

  device_ids = np.arange(num_devices)
  mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))

  for l in model.layers:
    # Shard fc1's weight along its output dimension (column parallel).
    xs.mark_sharding(l.fc1.weight, mesh, ('y', 'x'))
    # Shard fc2's weight along its input dimension (row parallel).
    xs.mark_sharding(l.fc2.weight, mesh, ('x', 'y'))

  optimizer = optim.SGD(model.parameters(), lr=lr)
  # loss_fn was not defined in the original snippet; MSE is assumed here.
  loss_fn = nn.MSELoss()

  def train_loop_fn(loader, epoch):
    model.train()
    for step, (data, target) in enumerate(loader):
      with xp.StepTrace('train_linear_model'):
        with xp.Trace('build_graph'):
          data = data.to(device)
          target = target.to(device)
          optimizer.zero_grad()
          output = model(data)
          loss = loss_fn(output, target)
          loss.backward()
        optimizer.step()
      xm.mark_step()
      if step % 10 == 0:
        print(f"Epoch {epoch} step {step} loss {loss}")

  for epoch in range(FLAGS.num_epochs):
    train_loop_fn(train_loader, epoch)

  return model

if FLAGS.profile:
  server = xp.start_server(FLAGS.profiler_port)

print('Start training loop...')
m = train()
t = torch.randn(10, FLAGS.input_dim).to(device)
m(t).cpu()

```

amithrm · Feb 29 '24 05:02

While trying to make this work, I hit a basic issue and filed a ticket for it: https://github.com/pytorch/xla/issues/6647

amithrm · Feb 29 '24 05:02

Thanks @amithrm, +1! Looking forward to pipelining using GSPMD!

yeounoh · Mar 02 '24 01:03