ColossalAI-Examples

[enhancement] Examplify `all_reduce()` for tensor_parallel_*

Open ofey404 opened this issue 3 years ago • 3 comments

The 1D Tensor Parallelism tutorial mentions the use of `all_reduce()`, but the attached example does not show how to do it.

Quote:

on each processor, then use an all-reduce to aggregate the results as $Z = Y_1B_1 + Y_2B_2$
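
The quoted aggregation can be checked numerically. Below is a minimal single-process sketch. It assumes the usual 1D setup: the first weight A is split column-wise into A_1, A_2 so that Y_i = X A_i, and the second weight B is split row-wise into the matching blocks B_1, B_2. The two "processors" are simulated with plain Python lists, the all-reduce is replaced by an ordinary elementwise sum, and all matrix values are made up for illustration.

```python
# Single-process simulation of a 2-way 1D tensor-parallel MLP (no real processes).

def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def add(a, b):
    """Elementwise sum of two equally shaped matrices (stands in for the all-reduce)."""
    return [[x + y for x, y in zip(r1, r2)] for r1, r2 in zip(a, b)]

X = [[1.0, 2.0], [3.0, 4.0]]           # input, 2 x 2 (made-up values)
A = [[1.0, 0.0, 2.0, -1.0],            # first weight, 2 x 4
     [0.5, 1.0, -1.0, 3.0]]
B = [[1.0, 2.0],                       # second weight, 4 x 2
     [0.0, 1.0],
     [-1.0, 0.5],
     [2.0, 0.0]]

# Column split of A: each simulated processor keeps 2 of the 4 columns.
A1 = [row[:2] for row in A]
A2 = [row[2:] for row in A]
# Row split of B: each simulated processor keeps the matching 2 rows.
B1, B2 = B[:2], B[2:]

Y1 = matmul(X, A1)  # computed on processor 1
Y2 = matmul(X, A2)  # computed on processor 2

# Summing the partial products plays the role of the all-reduce: Z = Y1 B1 + Y2 B2.
Z = add(matmul(Y1, B1), matmul(Y2, B2))

# Reference: the unsplit computation Z = (X A) B.
Z_ref = matmul(matmul(X, A), B)
print(Z == Z_ref)
```

Since the sum of the partial products equals the unsplit product, each processor only needs its own slices of A and B plus a single all-reduce at the end.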

So I made this enhancement, which prints the shapes of the layer weights and outputs, and the first 10 elements of `x` on each rank, before and after calling `all_reduce()`.

Output:

Weight of the first linear layer: torch.Size([256, 512])
Weight of the second linear layer: torch.Size([512, 256])
Output of the first linear layer: torch.Size([16, 512])
Output of the second linear layer: torch.Size([16, 256])
Output of the dropout layer: torch.Size([16, 256])
On rank 0, first 10 elements of x:
tensor([-0.1215, -0.3460, -0.2717, -0.0932, -0.4238, -0.0999, -0.0000,  0.2923,
        -0.1130, -0.0000], device='cuda:0', grad_fn=<SliceBackward0>)

On rank 1, first 10 elements of x:
tensor([-0.1215, -0.3460, -0.2717, -0.0932, -0.4238, -0.0999, -0.0000,  0.2923,
        -0.1130, -0.0000], device='cuda:1', grad_fn=<SliceBackward0>)

After `all_reduce()`, first 10 elements of x:
tensor([-0.2431, -0.6920, -0.5434, -0.1864, -0.8475, -0.1998, -0.0000,  0.5845,
        -0.2259, -0.0000], device='cuda:0', grad_fn=<SliceBackward0>)

Output of the all_reduce opration: torch.Size([16, 256])

ofey404, Apr 17 '22 01:04

If this is appropriate, I could make similar changes to the remaining `tensor_parallel_*.py` files.

ofey404, Apr 17 '22 01:04

Hi @ofey404, thank you for your contribution! @kurisusnowdeng, could you please help review this PR? Thanks.

binmakeswell, Apr 18 '22 11:04

Hi, you don't need to do an all-reduce in your customized models, as the all-reduce is already done inside `col_nn.Linear`. See https://github.com/hpcaitech/ColossalAI/blob/91a5999825137ffb4d575b21bf4c6cb41033161a/colossalai/nn/layer/parallel_1d/layers.py#L664

ver217, Jun 15 '22 10:06
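
The reply above can be illustrated with a minimal single-process simulation. It assumes, as the reply states, that the row-parallel layer already all-reduces its partial outputs. `row_parallel_linear` and `all_reduce_sum` are hypothetical stand-ins, not ColossalAI or `torch.distributed` functions, and the numbers are made up. With two ranks, a second manual all-reduce sums two identical tensors, which doubles them. This is consistent with the output posted in the issue, where `x` is identical on rank 0 and rank 1 and each value roughly doubles after the extra `all_reduce()` (for example, -0.1215 becomes -0.2431).

```python
# Single-process simulation of 2 ranks; not ColossalAI or torch.distributed code.

def all_reduce_sum(per_rank):
    """Hypothetical stand-in for a SUM all-reduce: every rank ends up
    holding the elementwise sum of all ranks' tensors."""
    total = [sum(vals) for vals in zip(*per_rank)]
    return [list(total) for _ in per_rank]

def row_parallel_linear(partials):
    """Hypothetical row-parallel layer: each rank holds a partial product
    Y_i B_i, and the layer itself performs the all-reduce."""
    return all_reduce_sum(partials)

# Made-up partial outputs Y_i B_i on rank 0 and rank 1 (one row each).
partials = [[-0.5, 0.25, 1.0], [0.25, -0.5, 0.5]]

out = row_parallel_linear(partials)  # the full result, already on both ranks
print(out[0] == out[1])              # ranks hold identical tensors

extra = all_reduce_sum(out)          # a second, manual all-reduce
print(extra[0] == [2 * v for v in out[0]])  # every value is doubled
```

In other words, a manual all-reduce after such a layer does not aggregate anything new; it only multiplies the already-aggregated result by the number of ranks.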