Gradient norm clipping with pipeline parallelism (PP)
Dear torchtitan team, I have a question regarding gradient norm clipping when using pipeline parallelism (PP) potentially combined with FSDP/DP/TP.
For simplicity, let's assume each process/GPU has a single PP stage. My understanding is that since the model is manually sharded, calling torch.nn.utils.clip_grad_norm_ will only compute the grad norm over the modules of the current PP stage.
https://github.com/pytorch/torchtitan/blob/eef8bb2b1b6f0875ab0581079e1511d51654910e/train.py#L298-L302
Since grad norm clipping requires computing the norm over the entire model (across all PP stages), does this mean we need to manually aggregate/reduce the grad norm across PP stages before the normalization? If so, what would be the correct approach?
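For concreteness, the kind of cross-stage aggregation I have in mind is roughly the sketch below; `local_norm` and `pp_group` are placeholder names for illustration, not existing torchtitan variables.

```python
# Minimal sketch of combining per-stage grad norms across PP ranks.
# `local_norm` (this stage's p-norm of its gradients) and `pp_group` are placeholders.
import torch
import torch.distributed as dist


def aggregate_grad_norm(
    local_norm: torch.Tensor, pp_group, norm_type: float = 2.0
) -> torch.Tensor:
    # Each PP rank holds ||g_stage||_p; the global norm over all stages is
    # (sum over stages of ||g_stage||_p ** p) ** (1 / p).
    total = local_norm ** norm_type
    dist.all_reduce(total, op=dist.ReduceOp.SUM, group=pp_group)
    return total ** (1.0 / norm_type)
```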
Any clarification or guidance would be greatly appreciated!
I agree with this.
cc: @wconstab @H-Huang we need to discuss how we should do clip_grad_norm_ with PP. Given our current design, we cannot solely rely on nn.utils.clip_grad_norm_. Each parameter DTensor will only have placements for FSDP and TP, not PP, so DTensor op dispatch is not aware of PP.
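To illustrate (a rough sketch only; the mesh shape and dim names below are placeholders): parameters of a given PP stage are DTensors over the FSDP/TP sub-mesh, so any norm computed through DTensor dispatch is reduced over those dimensions but never over the PP dimension.

```python
# Rough illustration only; mesh shape and dim names ("pp", "dp", "tp") are placeholders.
# Assumes this runs under torchrun with world_size == 8.
from torch.distributed.device_mesh import init_device_mesh

world_mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("pp", "dp", "tp"))

# Parameters of a given PP stage are DTensors whose device mesh spans only the dp/tp
# dimensions, so a norm computed via DTensor ops (as in nn.utils.clip_grad_norm_) is
# reduced over dp/tp but not over pp; the pp reduction must be issued manually,
# e.g. with an all_reduce over this group:
pp_group = world_mesh["pp"].get_group()
```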
The easiest solution is writing a custom clip_grad_norm_ once again, but maybe some other DTensor machinery can help here. cc: @tianyu-l @XilunWu
@awgu thank you so much for the follow-up!
I guess a naive implementation like the example below should work, but I would appreciate your feedback.
cc @wconstab @H-Huang @tianyu-l @XilunWu
The implementation below is based on torch v2.4.0's torch.nn.utils.clip_grad_norm_.
```python
from typing import Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch import distributed as dist
from torch.distributed import DeviceMesh
from torch.distributed._tensor import DTensor, Replicate
from torch.utils._foreach_utils import (
    _device_has_foreach_support,
    _group_tensors_by_device_and_dtype,
    _has_foreach_support,
)


@torch.no_grad()
def clip_grad_norm_(
    parameters: Union[Tensor, Iterable[Tensor]],
    max_norm: float,
    norm_type: float = 2.0,
    error_if_nonfinite: bool = False,
    foreach: Optional[bool] = None,
    pp_mesh: Optional[DeviceMesh] = None,
) -> Tensor:
    if pp_mesh is None:
        return torch.nn.utils.clip_grad_norm_(
            parameters,
            max_norm=max_norm,
            norm_type=norm_type,
            error_if_nonfinite=error_if_nonfinite,
            foreach=foreach,
        )

    if isinstance(parameters, Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if len(grads) == 0:
        return torch.tensor(0.0)
    first_device = grads[0].device
    grouped_grads: Dict[
        Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
    ] = _group_tensors_by_device_and_dtype(
        [grads]
    )  # type: ignore[assignment]

    norms: List[Tensor] = []
    for (device, _), ([device_grads], _) in grouped_grads.items():  # type: ignore[assignment]
        if (foreach is None and _has_foreach_support(device_grads, device)) or (
            foreach and _device_has_foreach_support(device)
        ):
            norms.extend(torch._foreach_norm(device_grads, norm_type))
        elif foreach:
            raise RuntimeError(
                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
            )
        else:
            norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])

    total_norm = torch.linalg.vector_norm(
        torch.stack([norm.to(first_device) for norm in norms]), norm_type
    )

    # ----- start modified from torch.nn.utils.clip_grad_norm_ -----
    if isinstance(total_norm, DTensor):
        # If total_norm is a DTensor, the placements must be
        # `torch.distributed._tensor.ops.math_ops._NormPartial`.
        # We can simply reduce the DTensor to get the total norm in this tensor's
        # process group and then convert it to a local tensor.
        total_norm = total_norm.redistribute(
            placements=[Replicate()] * total_norm.device_mesh.ndim
        ).to_local()

    total_norm **= norm_type
    dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group())
    total_norm **= 1.0 / norm_type
    # ----- end modified from torch.nn.utils.clip_grad_norm_ -----

    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
        raise RuntimeError(
            f"The total norm of order {norm_type} for gradients from "
            "`parameters` is non-finite, so it cannot be clipped. To disable "
            "this error and scale the gradients by the non-finite norm anyway, "
            "set `error_if_nonfinite=False`"
        )
    clip_coef = max_norm / (total_norm + 1e-6)
    # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
    # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
    # when the gradients do not reside in CPU memory.
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    for (device, _), ([device_grads], _) in grouped_grads.items():  # type: ignore[assignment]
        if (foreach is None and _has_foreach_support(device_grads, device)) or (
            foreach and _device_has_foreach_support(device)
        ):
            torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
        elif foreach:
            raise RuntimeError(
                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
            )
        else:
            clip_coef_clamped_device = clip_coef_clamped.to(device)
            for g in device_grads:
                g.mul_(clip_coef_clamped_device)
    return total_norm
```
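For context, a hypothetical call site in the training loop could look like the sketch below; names such as `world_mesh`, `parallel_dims`, `model_parts`, and `job_config.training.max_norm` are assumptions about the surrounding torchtitan code, not a tested integration.

```python
# Hypothetical call site (surrounding variable names are assumptions, not tested code).
# Pass the PP mesh only when pipeline parallelism is enabled, so that the plain
# torch.nn.utils.clip_grad_norm_ path is taken otherwise.
pp_mesh = world_mesh["pp"] if parallel_dims.pp_enabled else None
total_norm = clip_grad_norm_(
    [p for m in model_parts for p in m.parameters()],  # params of all local PP stages
    max_norm=job_config.training.max_norm,
    foreach=True,
    pp_mesh=pp_mesh,
)
```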
I believe local_map is a good fit here for implementing a custom clip_grad_norm_ for DTensor. @zijian-hu let me draft a PR based on your sample so that we can discuss it there. BTW, should this be applied to all norm ops as well?
Update: actually it's not local_map but rather something like a reversed version of local_map.
@XilunWu it would be great if you could let me know which other norm ops you were referring to. RMS norm and layer norm are computed within a single PP stage, so the aggregation/reduction across PP stages described above is not needed for them.
I believe only the grad norm needs to be reduced, since it is computed over all of the model's parameters; with PP, this requires an additional reduction/aggregation across stages.
@zijian-hu you're right. I realized that after chatting with @tianyu-l .
> draft a PR based on your sample so that we can discuss
@XilunWu is there a draft of this implementation somewhere?
@H-Huang in case Xilun is busy with other work items, I am more than happy to draft this PR later today.
@H-Huang Not yet, there are still some unclear points on the DTensor API design side. @zijian-hu I really appreciate the offer. We can review your PR first and land it if it looks good. Once the DTensor-side design is ready, I can migrate your change to use the new API.
@XilunWu @H-Huang the PR has been created: https://github.com/pytorch/torchtitan/pull/649
Fixed by https://github.com/pytorch/torchtitan/pull/649