
How can we use te.Linear with weight parallel?

Open · zigzagcai opened this issue 9 months ago · 3 comments

Hi developers,

Thanks for introducing such a great project that enables FP8 training.

In my training framework, we have a weight parallel implementation that all-gathers weights and reduce-scatters gradients, similar to ZeRO-3. In the forward pass, it all-gathers the weight and then calls linear_forward_op (which is actually torch.nn.functional.linear).

However, looking at the code of te.Linear, I see that the FP8 computation is handled by a torch.autograd.Function named _Linear.

So I wonder how we can integrate te.Linear with our weight parallel implementation. As I understand it, the forward and backward ops used in our weight parallel implementation depend on torch.nn.functional.linear, which is not compatible with the ops used in TE's _Linear.

Thanks in advance to anyone who can provide some hints!

cc @ksivaman @timmoon10 @cyanguwa

zigzagcai, Mar 04 '25 12:03

PyTorch FSDP gathers the module params before each forward and backward so that module implementations can access them as usual. I wonder if your framework could use a similar approach, perhaps using PyTorch module hooks, e.g. all-gathering params in a pre-forward callback and deallocating them in a post-forward callback. Things get trickier with FP8 and MXFP8 support, since caching the FP8/MXFP8 weight is an important performance optimization.
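As a rough illustration of the hook-based approach, the sketch below keeps only a flat shard of the weight on the module, materializes the full weight in a forward pre-hook, and drops the module's reference in a forward hook. `attach_gather_hooks`, `make_sharded_linear`, `weight_shard`, and `all_gather_fn` are hypothetical names. `nn.Linear` stands in for any module whose `forward` reads `self.weight`; whether this works unchanged with te.Linear (which also manages its own FP8 weight caches) is not verified here, and FP8 weight caching is not handled at all. The demo runs in a single process with world size 1, so the "all-gather" is just a reshape.

```python
import torch
import torch.nn as nn


def attach_gather_hooks(module: nn.Module, all_gather_fn) -> None:
    """Gather the full weight before forward and drop it afterwards.

    Assumes `module.weight_shard` is this rank's shard (an nn.Parameter) and that
    `all_gather_fn(shard)` returns the full weight as a differentiable function
    of the shard, so gradients flow back to the shard.
    """

    def pre_forward(mod, args):
        # Materialize the full weight so mod.forward() can read mod.weight as usual.
        mod.weight = all_gather_fn(mod.weight_shard)

    def post_forward(mod, args, output):
        # Drop the module's reference to the full weight. Autograd may still keep
        # it alive for backward; FSDP-style code re-gathers in backward instead.
        mod.weight = None
        return output

    module.register_forward_pre_hook(pre_forward)
    module.register_forward_hook(post_forward)


def make_sharded_linear(in_features: int, out_features: int):
    """Single-process demo (world size 1): the 'shard' is the whole flattened weight."""
    lin = nn.Linear(in_features, out_features, bias=False)
    full = lin.weight.detach().clone()
    del lin.weight  # unregister the full-size parameter
    lin.weight_shard = nn.Parameter(full.flatten().clone())
    shape = full.shape
    attach_gather_hooks(lin, lambda shard: shard.view(shape))
    return lin, full
```

In a real multi-rank setup, `all_gather_fn` would wrap a collective over the process group, and a pre-backward step would re-gather the weight so the full tensor does not have to stay alive between forward and backward.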

If you are just looking for more fine-grained access to our linear layer implementation, we do have some functional APIs, `_functional_forward` (https://github.com/NVIDIA/TransformerEngine/blob/2ad5da952e42c6fe7bd09bee8810f7f6c195cbd8/transformer_engine/pytorch/ops/basic/basic_linear.py#L335) and `_functional_backward` (https://github.com/NVIDIA/TransformerEngine/blob/2ad5da952e42c6fe7bd09bee8810f7f6c195cbd8/transformer_engine/pytorch/ops/basic/basic_linear.py#L539). These are experimental, though, and we can't make any guarantees about the stability of their APIs.

timmoon10, Mar 07 '25 20:03


Hi @timmoon10,

Thanks for your reply! I have tried your suggestion, replacing our default linear fwd/bwd ops with TransformerEngine's BasicLinear._functional_forward and BasicLinear._functional_backward. However, I cannot find any FP8 GEMM kernels in the trace. It seems that _functional_forward and _functional_backward still call BF16 GEMM kernels rather than FP8 GEMM kernels.

```python
from __future__ import annotations  # keep annotations lazy so WPCommunicator need not be imported here

from typing import Optional

import torch
from torch import nn
from torch.cuda.amp import custom_bwd, custom_fwd
from transformer_engine.pytorch.ops.basic.basic_linear import BasicLinear

# WPCommunicator comes from our framework (InternEvo); its import is omitted here.


class WPFusedDenseFunc(torch.autograd.Function):
    "FusedDenseFunc for weight parallel, adapted from the flash-attn implementation."

    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor],
        module: nn.Module,
        communicator: WPCommunicator,
        return_residual=False,
    ):
        ctx.compute_weight_gradient = weight.requires_grad
        ctx.return_residual = return_residual
        ctx.module = module
        ctx.communicator = communicator

        assert bias is None
        assert not return_residual

        if torch.is_autocast_enabled():
            x = x.to(dtype=torch.get_autocast_gpu_dtype())
        x = x.contiguous()

        # all-gather the full weight (and bias) from the local shards
        total_weight = communicator.weight_hook(weight, module=module)
        total_bias = bias if bias is None else communicator.weight_hook(bias, module=module, is_bias=True)

        if torch.is_autocast_enabled():
            total_weight = total_weight.to(dtype=torch.get_autocast_gpu_dtype())
            if total_bias is not None:
                total_bias = total_bias.to(dtype=torch.get_autocast_gpu_dtype())

        total_weight = total_weight.contiguous()
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
        if min(batch_dim, n, *total_weight.shape) > 65535 * 32:
            raise RuntimeError("fused_dense only supports matrix dims <= 2M")

        output, _, _ = BasicLinear._functional_forward(input=x, weight=total_weight, bias=total_bias)

        # release memory
        del total_weight
        del total_bias

        # parallel strategy-specific communication callback 2.
        # see more details in the communicator for different parallel strategies.
        # gather seq dim when head parallel_output is False
        if hasattr(communicator, "output_hook"):
            output, _ = communicator.output_hook(output, async_op=False)

        # save only the weight shard; the full weight is re-gathered in backward
        saved_x = x if ctx.compute_weight_gradient else None
        ctx.save_for_backward(saved_x, weight, bias)

        return output if not return_residual else (output, x)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output, *args):
        module: nn.Module = ctx.module
        communicator: WPCommunicator = ctx.communicator
        x, weight, bias = ctx.saved_tensors

        # parallel strategy-specific communication callback 3.
        # see more details in the communicator for different parallel strategies.
        if hasattr(communicator, "grad_output_hook"):
            grad_output, _ = communicator.grad_output_hook(grad_output, async_op=False)

        grad_output = grad_output.contiguous()
        if ctx.return_residual:
            (grad_input,) = args
            grad_input = grad_input.contiguous()

        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])

        # re-gather the full weight for the backward GEMMs
        total_weight = communicator.weight_hook(weight, module=module)

        # compute weight grad, then reduce-scatter it asynchronously
        if ctx.needs_input_grad[1]:
            assert ctx.compute_weight_gradient
            x = x.reshape(batch_dim, x.shape[-1])
            _, grad_weight = BasicLinear._functional_backward(grad_output=grad_output, input=x, weight=total_weight)
            grad_weight, grad_weight_sync = communicator.grad_hook(
                grad_weight, async_op=True, module=module, is_bias=False
            )
        else:
            grad_weight = None

        # compute input grad: grad_output @ W, expressed as a forward GEMM with W^T
        if ctx.needs_input_grad[0]:
            grad_input, _, _ = BasicLinear._functional_forward(input=grad_output, weight=total_weight.t())
            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
        else:
            grad_input = None

        del total_weight

        if ctx.needs_input_grad[1]:
            grad_weight_sync.wait()

        # one gradient per forward input: x, weight, bias, module, communicator, return_residual
        return grad_input, grad_weight, None, None, None, None
```

CPU trace: (screenshot not preserved)

CUDA trace: (screenshot not preserved)

@timmoon10, could you share some insight into how to enable FP8 GEMM kernels with this internal API? Thanks in advance!

zigzagcai, Mar 14 '25 07:03

The basic idea of our ZeRO-3 weight parallel implementation: in WPFusedDenseFunc (https://github.com/InternLM/InternEvo/blob/feat/refactor-impl/internlm/model/model_ops/modules/linear.py#L171-L315), we all-gather the weights in the forward pass, then all-gather the weights again and reduce-scatter the gradients in the backward pass. We then apply this custom autograd function in https://github.com/InternLM/InternEvo/blob/feat/refactor-impl/internlm/model/model_ops/modules/linear.py#L532-L678.
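The forward/backward pattern described above can be sketched framework-agnostically as a custom `torch.autograd.Function` that saves only the weight shard and re-gathers it in backward. This is a simplified illustration, not InternEvo's WPFusedDenseFunc. `ZeRO3LinearFunc`, `all_gather_fn`, and `reduce_scatter_fn` are hypothetical names; with torch.distributed the two callables would wrap collectives such as `dist.all_gather_into_tensor` and `dist.reduce_scatter_tensor`. Note that all three GEMMs here are plain PyTorch ops, which is exactly the part an FP8 path would have to replace.

```python
import torch
import torch.nn.functional as F


class ZeRO3LinearFunc(torch.autograd.Function):
    """All-gather the weight in forward, re-gather it in backward,
    and reduce-scatter the weight gradient back to shards."""

    @staticmethod
    def forward(ctx, x, weight_shard, all_gather_fn, reduce_scatter_fn):
        full_weight = all_gather_fn(weight_shard)
        # Save only the shard so the full weight can be freed after forward.
        ctx.save_for_backward(x, weight_shard)
        ctx.all_gather_fn = all_gather_fn
        ctx.reduce_scatter_fn = reduce_scatter_fn
        return F.linear(x, full_weight)

    @staticmethod
    def backward(ctx, grad_output):
        x, weight_shard = ctx.saved_tensors
        full_weight = ctx.all_gather_fn(weight_shard)  # re-gather for backward
        grad_x = grad_output @ full_weight  # dgrad: (..., out) @ (out, in)
        # wgrad: flatten batch dims, then (out, N) @ (N, in)
        go_2d = grad_output.reshape(-1, grad_output.shape[-1])
        x_2d = x.reshape(-1, x.shape[-1])
        grad_full_weight = go_2d.t() @ x_2d
        grad_shard = ctx.reduce_scatter_fn(grad_full_weight)
        # One gradient per forward input; the two callables get None.
        return grad_x, grad_shard, None, None
```

With world size 1 both collectives reduce to the identity, e.g. `ZeRO3LinearFunc.apply(x, w, lambda t: t, lambda t: t)`, which makes it easy to check the gradients against `F.linear` before wiring in real communication.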

So I wonder how we could integrate TE FP8 with our custom ZeRO-3 weight parallel implementation.

zigzagcai, Mar 17 '25 08:03