TransformerEngine
How can we use te.Linear with weight parallel?
Hi developers,
Thanks for introducing such a great project that enables FP8 training.
In my training framework, we have a weight parallel implementation that does weight all-gather and reduce-scatter, similar to ZeRO3. In that implementation, the forward pass all-gathers the weight and then calls the linear_forward_op (which is actually torch.nn.functional.linear).
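A minimal sketch of that forward pattern, with hypothetical names and assuming the weight is sharded along dim 0 across the process group:

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F


def weight_parallel_forward(x, weight_shard, group=None):
    """ZeRO3-style forward: all-gather the weight shards, then run the GEMM."""
    world_size = dist.get_world_size(group)
    total_weight = torch.empty(world_size * weight_shard.shape[0], *weight_shard.shape[1:],
                               dtype=weight_shard.dtype, device=weight_shard.device)
    dist.all_gather_into_tensor(total_weight, weight_shard.contiguous(), group=group)
    out = F.linear(x, total_weight)   # the linear_forward_op mentioned above
    del total_weight                  # release the gathered copy right away
    return out
```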
But when I check the code of te.Linear, there is a torch.autograd.Function named _Linear that handles FP8 computation.
So, I just wonder how we can integrate te.Linear with our weight parallel implementation? From my understanding, the forward and backward ops used in our weight parallel implementation depend on torch.nn.functional.linear, which is not compatible with the ops used in te._Linear.
Thanks in advance if anybody could provide some hints!
cc @ksivaman @timmoon10 @cyanguwa
PyTorch FSDP gathers the module params before each forward and backward so that module implementations can just access them like normal. I wonder if your framework could use a similar approach, perhaps using PyTorch module hooks, e.g. all-gathering params with a pre-forward callback and deallocating them with a post-forward callback. Things get trickier with FP8 and MXFP8 support, since caching the FP8/MXFP8 weight is an important performance optimization.
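A minimal, forward-only sketch of that hook-based idea (the helper attach_weight_parallel_hooks is hypothetical, not TE API; it assumes the weight is sharded along dim 0, treats it as a plain nn.Parameter, and omits both the backward re-gather that FSDP performs and any handling of the cached FP8 weight):

```python
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te


def attach_weight_parallel_hooks(module: te.Linear, group=None):
    """Gather the sharded weight before forward, restore the shard afterwards.

    Forward-only sketch: the backward pass also needs the full weight, so a real
    implementation would re-gather it around backward (as FSDP does); omitted here.
    """
    world_size = dist.get_world_size(group)

    def pre_forward(mod, args):
        shard = mod.weight.data
        full = torch.empty(world_size * shard.shape[0], *shard.shape[1:],
                           dtype=shard.dtype, device=shard.device)
        dist.all_gather_into_tensor(full, shard.contiguous(), group=group)
        mod._weight_shard = shard   # stash the local shard
        mod.weight.data = full      # the module now sees the full weight

    def post_forward(mod, args, output):
        mod.weight.data = mod._weight_shard   # drop the gathered copy
        del mod._weight_shard
        return output

    module.register_forward_pre_hook(pre_forward)
    module.register_forward_hook(post_forward)
```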
If you are just looking for more fine-grained access to our linear layer implementation, we do have some functional APIs:
BasicLinear._functional_forward: https://github.com/NVIDIA/TransformerEngine/blob/2ad5da952e42c6fe7bd09bee8810f7f6c195cbd8/transformer_engine/pytorch/ops/basic/basic_linear.py#L335
BasicLinear._functional_backward: https://github.com/NVIDIA/TransformerEngine/blob/2ad5da952e42c6fe7bd09bee8810f7f6c195cbd8/transformer_engine/pytorch/ops/basic/basic_linear.py#L539
These are experimental though, and we can't make any guarantees on the stability of their APIs.
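For illustration, a minimal sketch of how those two functions are called, using only the keyword arguments and return arities that appear in the code later in this thread; the extra return values are ignored here, and the exact signatures are internal and subject to change:

```python
import torch
from transformer_engine.pytorch.ops.basic.basic_linear import BasicLinear

x = torch.randn(32, 1024, device="cuda", dtype=torch.bfloat16)
w = torch.randn(4096, 1024, device="cuda", dtype=torch.bfloat16)

# Forward: y = x @ w.T (plus optional bias); the two extra return values are unused here.
y, _, _ = BasicLinear._functional_forward(input=x, weight=w)

# Backward: given dL/dy, compute dL/dx and dL/dw.
dy = torch.randn_like(y)
dx, dw = BasicLinear._functional_backward(grad_output=dy, input=x, weight=w)
```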
Hi @timmoon10,
Thanks for your reply! I have tried your approach: I switched the default linear fwd/bwd ops to TransformerEngine's BasicLinear._functional_forward and BasicLinear._functional_backward. But in the trace I cannot find any FP8 GEMM kernels. It seems _functional_forward and _functional_backward still call BF16 GEMM kernels, not FP8 GEMMs.
```python
from typing import Optional

import torch
from torch import nn
from torch.cuda.amp import custom_bwd, custom_fwd
from transformer_engine.pytorch.ops.basic.basic_linear import BasicLinear

# WPCommunicator is provided by our training framework (see the InternEvo links below).


class WPFusedDenseFunc(torch.autograd.Function):
    """FusedDenseFunc for weight parallel, which is optimized based on the flash implementation."""

    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor],
        module: nn.Module,
        communicator: WPCommunicator,
        return_residual=False,
    ):
        ctx.compute_weight_gradient = weight.requires_grad
        ctx.return_residual = return_residual
        ctx.module = module
        ctx.communicator = communicator

        assert bias is None
        assert not return_residual

        if torch.is_autocast_enabled():
            x = x.to(dtype=torch.get_autocast_gpu_dtype())
        x = x.contiguous()

        total_weight = communicator.weight_hook(weight, module=module)
        total_bias = bias if bias is None else communicator.weight_hook(bias, module=module, is_bias=True)

        if torch.is_autocast_enabled():
            total_weight = total_weight.to(dtype=torch.get_autocast_gpu_dtype())
            if total_bias is not None:
                total_bias = total_bias.to(dtype=torch.get_autocast_gpu_dtype())

        total_weight = total_weight.contiguous()
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
        if min(batch_dim, n, *total_weight.shape) > 65535 * 32:
            raise RuntimeError("fused_dense only supports matrix dims <= 2M")

        output, _, _ = BasicLinear._functional_forward(input=x, weight=total_weight, bias=total_bias)

        # release memory
        del total_weight
        del total_bias

        # parallel strategy-specific communication callback 2.
        # see more details in the communicator for different parallel strategies.
        # gather seq dim when head parallel_output is False
        if hasattr(communicator, "output_hook"):
            output, _ = communicator.output_hook(output, async_op=False)

        saved_x = None if ctx.compute_weight_gradient is False else x
        ctx.save_for_backward(saved_x, weight, bias)
        return output if not return_residual else (output, x)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output, *args):
        module: nn.Module = ctx.module
        communicator: WPCommunicator = ctx.communicator
        x, weight, bias = ctx.saved_tensors

        # parallel strategy-specific communication callback 3.
        # see more details in the communicator for different parallel strategies.
        if hasattr(communicator, "grad_output_hook"):
            grad_output, _ = communicator.grad_output_hook(grad_output, async_op=False)

        grad_output = grad_output.contiguous()
        if ctx.return_residual:
            (grad_input,) = args
            grad_input = grad_input.contiguous()

        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])

        total_weight = communicator.weight_hook(weight, module=module)

        # compute weight grad
        if ctx.needs_input_grad[1]:
            assert ctx.compute_weight_gradient
            x = x.reshape(batch_dim, x.shape[-1])
            _, grad_weight = BasicLinear._functional_backward(grad_output=grad_output, input=x, weight=total_weight)
            grad_weight, grad_weight_sync = communicator.grad_hook(
                grad_weight, async_op=True, module=module, is_bias=False
            )
        else:
            grad_weight = None

        grad_bias = grad_output if ctx.needs_input_grad[2] else None

        if ctx.needs_input_grad[0]:
            grad_input, _, _ = BasicLinear._functional_forward(input=grad_output, weight=total_weight.t())
            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
        else:
            grad_input = None

        del total_weight

        if ctx.needs_input_grad[1]:
            grad_weight_sync.wait()

        return grad_input, grad_weight, None, None, None, None, None
```
CPU Trace: (screenshot attached)
CUDA Trace: (screenshot attached)
Could you please share some insights about how to enable FP8 GEMM kernels with this internal API? @timmoon10 Thanks in advance!
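For comparison, the module-level path (te.Linear) only runs FP8 GEMMs inside the te.fp8_autocast context; a minimal sketch of that documented usage is below. Whether the private functional APIs also need an active FP8 recipe or explicit quantizers to hit the FP8 GEMM path is exactly the open question here.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max")
linear = te.Linear(1024, 1024, bias=False).cuda()
x = torch.randn(32, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = linear(x)   # FP8 GEMM kernels should show up in the trace for this call

y.sum().backward()  # backward (dgrad/wgrad GEMMs) runs outside the autocast context
```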
The basic idea of our ZeRO3 weight parallel implementation:
In WPFusedDenseFunc (https://github.com/InternLM/InternEvo/blob/feat/refactor-impl/internlm/model/model_ops/modules/linear.py#L171-L315), we all-gather the weights in the fwd pass, then all-gather the weights and reduce-scatter the gradients in the bwd pass. We then apply this customized autograd function in https://github.com/InternLM/InternEvo/blob/feat/refactor-impl/internlm/model/model_ops/modules/linear.py#L532-L678
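A minimal sketch of the backward-pass communication (hypothetical names; the full weight gradient is reduce-scattered back to per-rank shards, optionally overlapped with the input-gradient GEMM):

```python
import torch
import torch.distributed as dist


def reduce_scatter_weight_grad(grad_weight_full, group=None):
    """Reduce-scatter the full weight gradient back into the local shard."""
    world_size = dist.get_world_size(group)
    shard_rows = grad_weight_full.shape[0] // world_size
    grad_shard = torch.empty(shard_rows, *grad_weight_full.shape[1:],
                             dtype=grad_weight_full.dtype, device=grad_weight_full.device)
    # async_op=True lets this overlap with the dgrad GEMM; call handle.wait() before using the result.
    handle = dist.reduce_scatter_tensor(grad_shard, grad_weight_full.contiguous(),
                                        group=group, async_op=True)
    return grad_shard, handle
```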
So, I just wonder how we could integrate TE FP8 with our customized ZeRO3 weight parallel implementation?