
Why is all-gather (AG) overlap disabled when is_grad_enabled is False?

Open · sallyjunjun opened this issue 11 months ago · 1 comment

When I test a model with full recomputation, the forward all-gather communication is not overlapped, because is_grad_enabled is False during the forward pass under full recomputation. I found the following code in the _LayerNormLinear class:

    if ub_overlap_ag:
        tp_world_size = get_distributed_world_size(tp_group)
        if tp_world_size == 1 or (not is_grad_enabled):
            ub_overlap_ag = False
    if ub_overlap_ag:
        dim_size = list(inputmat.size())
        dim_size[0] = dim_size[0] * tp_world_size
        ub_obj_lnout = get_ub(ub_name + "_fprop")
        if return_layernorm_output:
            # First prepare LN output in higher precision,
            # which will be later copied to a FP8 UB
            ln_out = torch.empty_like(inputmat, memory_format=torch.contiguous_format)
        else:
            ln_out = ub_obj_lnout.get_ubuf_output(0)

Why is ub_overlap_ag set to False under the '(not is_grad_enabled)' condition?
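
To make the question concrete, the gating condition from the excerpt above can be isolated as a small standalone function. This is only a sketch: resolve_ub_overlap_ag is a hypothetical name that does not exist in TransformerEngine, and it takes the world size and grad mode as plain arguments instead of calling get_distributed_world_size or querying PyTorch.

```python
def resolve_ub_overlap_ag(ub_overlap_ag: bool, tp_world_size: int, is_grad_enabled: bool) -> bool:
    """Hypothetical helper mirroring the quoted condition; not part of TransformerEngine."""
    # Overlap is dropped for a single-rank TP group or whenever grad mode is off.
    if ub_overlap_ag and (tp_world_size == 1 or not is_grad_enabled):
        return False
    return ub_overlap_ag


# Ordinary training forward: gradients enabled, overlap stays on.
print(resolve_ub_overlap_ag(True, tp_world_size=8, is_grad_enabled=True))   # True
# Forward with grad mode off (full recomputation, inference): overlap is dropped.
print(resolve_ub_overlap_ag(True, tp_world_size=8, is_grad_enabled=False))  # False
# Single-rank tensor-parallel group: overlap is dropped regardless of grad mode.
print(resolve_ub_overlap_ag(True, tp_world_size=1, is_grad_enabled=True))   # False
```

The second call is the case described in this issue: the forward pass runs with grad mode off, so the all-gather is no longer overlapped even though ub_overlap_ag was requested.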

sallyjunjun · Jan 10 '25 02:01

Same question. Could you explain the reason? I want to use comm-GEMM overlap in the prefill stage of inference.

KevinZeng08 · Feb 28 '25 16:02