TransformerEngine
Why is all-gather (AG) overlap disabled when is_grad_enabled is False?
When I test a model with full activation recomputation, the forward all-gather communication is not overlapped, because is_grad_enabled is False during the recomputed forward pass. I see the following code in the _LayerNormLinear class:
```python
if ub_overlap_ag:
    tp_world_size = get_distributed_world_size(tp_group)
    if tp_world_size == 1 or (not is_grad_enabled):
        ub_overlap_ag = False
if ub_overlap_ag:
    dim_size = list(inputmat.size())
    dim_size[0] = dim_size[0] * tp_world_size
    ub_obj_lnout = get_ub(ub_name + "_fprop")
    if return_layernorm_output:
        # First prepare LN output in higher precision,
        # which will be later copied to a FP8 UB
        ln_out = torch.empty_like(inputmat, memory_format=torch.contiguous_format)
    else:
        ln_out = ub_obj_lnout.get_ubuf_output(0)
```
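For context on why is_grad_enabled ends up False here: with full recomputation, the first forward pass of a checkpointed block runs with autograd disabled, so any code gated on torch.is_grad_enabled() takes the no-grad path even during training. A minimal sketch with torch.utils.checkpoint, used here only to illustrate the general mechanism (TransformerEngine's own recompute utilities behave the same way in this respect, as far as I can tell):

```python
import torch
from torch.utils.checkpoint import checkpoint

def block(x):
    # Reports whether autograd is recording inside the checkpointed region.
    print("is_grad_enabled:", torch.is_grad_enabled())
    return x * 2

x = torch.randn(4, requires_grad=True)
y = checkpoint(block, x, use_reentrant=True)  # first forward runs under no_grad -> prints False
y.sum().backward()                            # backward recompute runs under enable_grad -> prints True
```

So with full recomputation the overlap check above sees is_grad_enabled == False on the very pass where the all-gather happens.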
Why is ub_overlap_ag set to False under the `(not is_grad_enabled)` condition?
I have the same question; could you explain the reason? I want to use comm-gemm overlap in the prefill stage of inference.
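As a probing experiment only (not a supported TransformerEngine configuration, and whether the overlap path is actually safe without autograd is exactly what this issue asks): if the all-gather overlap is disabled solely by the `not is_grad_enabled` check quoted above, forcing grad mode on around the prefill forward should re-enable it. Since inference inputs do not require grad, autograd still records nothing. `model` and `prefill_tokens` below are placeholders:

```python
import torch

# Hypothetical workaround sketch: keep torch.is_grad_enabled() True during
# prefill so the `(not is_grad_enabled)` branch is not taken. No autograd
# graph is built as long as no input tensor requires grad.
with torch.enable_grad():
    logits = model(prefill_tokens)
```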