
tensor_parallel_example.py and sequence_parallel_example.py

Open · githubsgi opened this issue 4 months ago · 0 comments

The primary difference between the two files is the parallelization plan shown below. In the TP case I see only one all-reduce per iteration. Is that expected? It seems to be the same as DDP. In the SP case I see one all-gather and one reduce-scatter per iteration.
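The count matches DDP, but what gets reduced differs. With the TP plan, the single all-reduce carries out_proj's output activations in the forward pass (the `*c10d_functional.all_reduce` entry under out_proj's FORWARD PASS in the trace in this issue), whereas DDP all-reduces parameter gradients in the backward pass. A back-of-the-envelope comparison follows. It assumes the ToyModel shapes reported in the trace (Linear(10, 32), ReLU, Linear(32, 5), batch of 4), fp32 tensors, a ring all-reduce, and a DDP setup where all gradients fit in one bucket. The constants and the helper name are illustrative, not taken from the example scripts.

```python
# Back-of-the-envelope comparison of what is all-reduced per iteration.
# Sizes follow the ToyModel shapes in the CommDebugMode trace; fp32 and a
# ring all-reduce are assumptions, not measurements.
BATCH, D_IN, D_HID, D_OUT = 4, 10, 32, 5
WORLD_SIZE = 4  # DeviceMesh('cuda', [0, 1, 2, 3])
ELEM_BYTES = 4  # fp32


def ring_all_reduce_bytes_per_rank(numel, world=WORLD_SIZE, elem_bytes=ELEM_BYTES):
    """Bytes each rank sends in a ring all-reduce (reduce-scatter + all-gather)."""
    return 2 * (world - 1) * numel * elem_bytes / world


# TP: out_proj's partial output is all-reduced once, in the forward pass.
tp_numel = BATCH * D_OUT
# DDP (hypothetical, for comparison): every parameter gradient (weights + biases)
# is all-reduced in the backward pass; a model this small fits in one bucket.
ddp_numel = (D_HID * D_IN + D_HID) + (D_OUT * D_HID + D_OUT)

print("TP  all-reduce:", tp_numel, "elements,",
      ring_all_reduce_bytes_per_rank(tp_numel), "bytes/rank")
print("DDP all-reduce:", ddp_numel, "elements,",
      ring_all_reduce_bytes_per_rank(ddp_numel), "bytes/rank")
```

Under these assumptions it prints 20 elements (120.0 bytes per rank) for TP and 517 elements (3102.0 bytes per rank) for DDP. The TP payload grows with batch size (and sequence length), while the DDP payload grows with parameter count.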

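# Import paths as in recent PyTorch releases (older ones expose Shard under torch.distributed._tensor).
# model / tp_model (ToyModel instances) and device_mesh are created earlier in each script.
from torch.distributed.tensor import Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# sequence_parallel_example.py: input and output are sharded along dim 0,
# so in_proj all-gathers its input and out_proj reduce-scatters its output
sp_model = parallelize_module(
    module=model,
    device_mesh=device_mesh,
    parallelize_plan={
        "in_proj": ColwiseParallel(input_layouts=Shard(0)),
        "out_proj": RowwiseParallel(output_layouts=Shard(0)),
    },
)

# tensor_parallel_example.py: default layouts (replicated input and output),
# so out_proj all-reduces its partial output
tp_model = parallelize_module(
    module=tp_model,
    device_mesh=device_mesh,
    parallelize_plan={
        "in_proj": ColwiseParallel(),
        "out_proj": RowwiseParallel(),
    },
)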
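To check the collective counts independently of PyTorch, here is a pure-Python simulation of what the two plans compute in the forward pass. It is a sketch under stated assumptions: no torch, lists stand in for tensors, four simulated ranks, biases omitted, and every name (`Comm`, `local_partial`, `tp_forward`, `sp_forward`) is hypothetical. Weights are stored (in, out), so ColwiseParallel's `Shard(0)` on `nn.Linear.weight` corresponds to a column slice here and RowwiseParallel's `Shard(1)` to a row slice. Each rank then produces a partial sum of the output: with replicated input (TP) the partials are combined by one all-reduce, and with dim-0-sharded input and output (SP) the input is all-gathered first and the partials are reduce-scattered.

```python
import random

N, B, D_IN, D_HID, D_OUT = 4, 4, 10, 32, 5  # 4 simulated ranks, ToyModel-like sizes
K = D_HID // N                               # hidden units held by each rank


def matmul(a, b):
    return [[sum(a[i][p] * b[p][j] for p in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def add(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def relu(a):
    return [[max(0, v) for v in row] for row in a]


def rand_mat(rows, cols, rng):
    # Small integers keep the arithmetic exact, so results compare with ==.
    return [[rng.randint(-3, 3) for _ in range(cols)] for _ in range(rows)]


class Comm:
    """Simulated collectives over a list of per-rank matrices; counts each call."""

    def __init__(self):
        self.counts = {}

    def _record(self, name):
        self.counts[name] = self.counts.get(name, 0) + 1

    @staticmethod
    def _sum(parts):
        total = parts[0]
        for p in parts[1:]:
            total = add(total, p)
        return total

    def all_reduce(self, parts):
        self._record("all_reduce")
        total = self._sum(parts)
        return [total for _ in parts]

    def all_gather(self, parts):
        # Concatenate dim-0 shards; every rank receives the full matrix.
        self._record("all_gather")
        full = [row for p in parts for row in p]
        return [full for _ in parts]

    def reduce_scatter(self, parts):
        # Sum the partials, then rank r keeps the r-th dim-0 slice.
        self._record("reduce_scatter")
        total = self._sum(parts)
        step = len(total) // len(parts)
        return [total[r * step:(r + 1) * step] for r in range(len(parts))]


def local_partial(x, w1, w2, r):
    # Column slice of w1 (ColwiseParallel), row slice of w2 (RowwiseParallel).
    w1_r = [row[r * K:(r + 1) * K] for row in w1]
    w2_r = w2[r * K:(r + 1) * K]
    return matmul(relu(matmul(x, w1_r)), w2_r)  # this rank's partial output


def tp_forward(x, w1, w2, comm):
    partials = [local_partial(x, w1, w2, r) for r in range(N)]  # x replicated
    return comm.all_reduce(partials)                            # Partial -> Replicate


def sp_forward(x_shards, w1, w2, comm):
    xs = comm.all_gather(x_shards)                              # Shard(0) -> Replicate
    partials = [local_partial(xs[r], w1, w2, r) for r in range(N)]
    return comm.reduce_scatter(partials)                        # Partial -> Shard(0)


rng = random.Random(0)
x, w1, w2 = rand_mat(B, D_IN, rng), rand_mat(D_IN, D_HID, rng), rand_mat(D_HID, D_OUT, rng)
ref = matmul(relu(matmul(x, w1)), w2)  # unsharded reference

tp_comm, sp_comm = Comm(), Comm()
tp_out = tp_forward(x, w1, w2, tp_comm)
rows = B // N
sp_out = sp_forward([x[r * rows:(r + 1) * rows] for r in range(N)], w1, w2, sp_comm)

assert all(out == ref for out in tp_out)
assert [row for shard in sp_out for row in shard] == ref
print(tp_comm.counts)  # {'all_reduce': 1}
print(sp_comm.counts)  # {'all_gather': 1, 'reduce_scatter': 1}

# An all-reduce equals a reduce-scatter followed by an all-gather.
c = Comm()
parts = [rand_mat(B, D_OUT, rng) for _ in range(N)]
assert c.all_gather(c.reduce_scatter(parts))[0] == c.all_reduce(parts)[0]
```

The last check shows that an all-reduce can be decomposed into a reduce-scatter plus an all-gather. In this toy, though, the SP all-gather acts on the 10-wide input while the reduce-scatter acts on the 5-wide output, so the two collectives do not carry the same tensors as TP's single all-reduce.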

For the TP model, CommDebugMode also appears to show one all-reduce in the forward pass and none in the backward pass:

  FORWARD PASS
    *c10d_functional.all_reduce: 1
  BACKWARD PASS
    ToyModel
    *module type: class '__main__.ToyModel'
      FORWARD PASS
        *c10d_functional.all_reduce: 1
        ToyModel.in_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=0),)
         *bias: (Shard(dim=0),)
          FORWARD PASS
            **aten.addmm.default
              shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
              sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([32, 4]), torch.Size([4, 10])]
              sharding: [(Shard(dim=0),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([32]), torch.Size([32])]
              sharding: [(Shard(dim=0),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([32, 10]), torch.Size([32, 10])]
              sharding: [(Shard(dim=0),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
        ToyModel.relu
        *module type: class 'torch.nn.modules.activation.ReLU'
          FORWARD PASS
          BACKWARD PASS
        ToyModel.out_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=1),)
         *bias: (Replicate(),)
          FORWARD PASS
            *c10d_functional.all_reduce: 1
            **aten.addmm.default
              shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
              sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([4, 5]), torch.Size([5, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.mm.default
              shape: [torch.Size([5, 4]), torch.Size([4, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([5]), torch.Size([5])]
              sharding: [(Replicate(),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([5, 32]), torch.Size([5, 32])]
              sharding: [(Shard(dim=1),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
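The communication-free BACKWARD PASS is consistent with the math of the colwise/rowwise pair. In the trace, in_proj's BACKWARD PASS lists a single aten.mm ([32, 4] x [4, 10], the weight gradient), while out_proj's lists two (input gradient and weight gradient), so no gradient is computed for the model input. The pure-Python sketch below (no torch, bias-free, integer matrices, four simulated ranks, all names hypothetical) checks that each rank's weight-gradient shards are exactly its slices of the unsharded gradients. Unlike DDP, then, no gradient all-reduce appears to be needed within the TP group. Only dL/dx, the gradient for the replicated model input, is a sum of per-rank partials. Presumably that sum is where a backward all-reduce would come from, and it is skipped when the input does not require grad.

```python
import random

N, B, D_IN, D_HID, D_OUT = 4, 4, 10, 32, 5  # 4 simulated ranks, ToyModel-like sizes
K = D_HID // N


def matmul(a, b):
    return [[sum(a[i][p] * b[p][j] for p in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(col) for col in zip(*a)]


def add(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def rand_mat(rows, cols, rng):
    return [[rng.randint(-3, 3) for _ in range(cols)] for _ in range(rows)]


def backward(x, w1, w2, g):
    """Grads of y = relu(x @ w1) @ w2 given upstream grad g = dL/dy (bias-free)."""
    z = matmul(x, w1)
    h = [[max(0, v) for v in row] for row in z]
    dw2 = matmul(transpose(h), g)
    dh = matmul(g, transpose(w2))
    dz = [[d if v > 0 else 0 for d, v in zip(rd, rz)] for rd, rz in zip(dh, z)]
    dw1 = matmul(transpose(x), dz)
    dx = matmul(dz, transpose(w1))
    return dw1, dw2, dx


rng = random.Random(1)
x, w1, w2 = rand_mat(B, D_IN, rng), rand_mat(D_IN, D_HID, rng), rand_mat(D_HID, D_OUT, rng)
g = rand_mat(B, D_OUT, rng)  # dL/dy; identical on every rank since the TP output is replicated
dw1_full, dw2_full, dx_full = backward(x, w1, w2, g)

dx_partials = []
for r in range(N):
    w1_r = [row[r * K:(r + 1) * K] for row in w1]   # in_proj shard (ColwiseParallel)
    w2_r = w2[r * K:(r + 1) * K]                     # out_proj shard (RowwiseParallel)
    dw1_r, dw2_r, dx_r = backward(x, w1_r, w2_r, g)  # uses only rank-local data
    # Each rank's weight grads are exactly its slice of the full grads: no collective.
    assert dw1_r == [row[r * K:(r + 1) * K] for row in dw1_full]
    assert dw2_r == dw2_full[r * K:(r + 1) * K]
    dx_partials.append(dx_r)

# dL/dx is a sum of per-rank partials, i.e. it would need an all-reduce,
# but only if the model input requires grad.
dx_sum = dx_partials[0]
for p in dx_partials[1:]:
    dx_sum = add(dx_sum, p)
assert dx_sum == dx_full
```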

githubsgi · Jun 11 '25 01:06