
[question] Can't disable CP for specific (unsupported) SDPA op

Open FindDefinition opened this issue 11 months ago • 3 comments

Problem

Currently the context parallel API has five problems:

  1. CP can only be applied to the whole model. If the model has some cross attention in its prep part with an unsupported shape, it's impossible to apply CP, because _context_parallel always overrides all SDPA calls and needs to wrap the whole backward pass.
  2. There is no shard/unshard with gradient support. When I try to apply CP to the transformer blocks only and keep the other SDPA calls replicated, the context_parallel_unshard in PyTorch has a no_grad decorator.
  3. Weight gradients inside the CP region are divided by the size of the CP mesh because we reduce them over the DP+CP mesh. This may work for optimizers with gradient-norm support, but it makes unit tests harder to write: we have to scale the gradients back to get the same values as the model without CP.
  4. The sequence length must be divisible by the CP degree (CP * 2 for round-robin load balancing).
  5. A replicated input to the CP region may get a wrong gradient because its gradient may be Partial; we have to check every replicated input and use to_local(grad_placements=[Partial()]) (see the sketch after this list).
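
For item 5, here is a minimal sketch of the workaround, assuming a recent PyTorch where DTensor, Replicate and Partial are exported from torch.distributed.tensor; the helper name replicate_with_partial_grad is made up:

import torch
from torch.distributed.tensor import DTensor, Replicate, Partial

def replicate_with_partial_grad(x: torch.Tensor, cp_mesh) -> torch.Tensor:
    # Treat x as replicated over the CP mesh; in backward, declare the incoming
    # local gradient as Partial so DTensor sums it over the CP ranks before it
    # reaches x, instead of silently using one rank's partial gradient.
    x_d = DTensor.from_local(x, cp_mesh, [Replicate()], run_check=False)
    return x_d.to_local(grad_placements=[Partial()])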

To resolve problem 1 above, I removed the context_parallel context to disable the SDPA override and only enabled the _enable_cp_dispatcher context; then CP SDPA is enabled iff all inputs are converted to DTensor. Problem 2 is easy to resolve: just write some autograd functions.

Here are my questions:

  1. Is there a better way to support a CP region?
  2. Do you have any plan to support CP regions officially and resolve the issues above?

FindDefinition avatar Dec 20 '24 11:12 FindDefinition

We do see some requests to apply CP to only some regions. This will require some communication before SDPA, as we will have to borrow some ranks from the DP sharding dimensions. Current CP definitely doesn't support this. We will discuss how best to support this feature.

fegin avatar Dec 20 '24 17:12 fegin

To resolve problem 1 above, I removed the context_parallel context to disable the SDPA override and only enabled the _enable_cp_dispatcher context; then CP SDPA is enabled iff all inputs are converted to DTensor.

Hi @FindDefinition, can you expand a little bit on your workaround for issue number 1? Thanks a lot.

lkhphuc avatar Mar 12 '25 08:03 lkhphuc

@lkhphuc there are two dispatches in the PyTorch context parallel implementation:

  1. A dispatch of the native op (e.g. aten._scaled_dot_product_flash_attention.default) when the inputs are DTensor; this dispatch only happens when the native SDPA op is invoked with DTensor inputs.
  2. A dispatch of F.scaled_dot_product_attention to a simple implementation that only converts your torch.Tensor inputs to DTensor, so that the native dispatch in 1 applies.

So if we remove the second dispatch, F.scaled_dot_product_attention uses the non-parallel implementation when the inputs aren't DTensor, and the context parallel version when you convert q/k/v to DTensor.

Example:

import torch
import torch.nn.functional as F
from typing import List
from torch.autograd import Function
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor, Shard

@copy_sig(F.scaled_dot_product_attention)  # copy_sig / get_parallel_manager are our own helpers
def scaled_dot_product_attention_parallel(*args, **kwargs) -> torch.Tensor:
    """Context Parallel version of scaled_dot_product_attention.
    This function just converts the inputs to DTensor and converts the result back to a local tensor.
    DTensor dispatches the correct parallel version of SDPA when the inputs are DTensor and
    the CP dispatcher context below is entered.
    """
    ctx = get_parallel_manager()
    if ctx is None or ctx.cp_mesh is None:
        # no CP mesh configured: plain local SDPA
        return F.scaled_dot_product_attention(*args, **kwargs)
    # pull query/key/value out of args/kwargs so the remaining arguments can be forwarded unchanged
    if len(args) == 0:
        query = kwargs["query"]
        kwargs.pop("query")
    else:
        query = args[0]
    if len(args) <= 1:
        key = kwargs["key"]
        kwargs.pop("key")
    else:
        key = args[1]
    if len(args) <= 2:
        value = kwargs["value"]
        kwargs.pop("value")
    else:
        value = args[2]
    placement = Shard(ctx.get_cp_config().seq_dim)
    mesh = ctx.cp_mesh
    query = DTensor.from_local(query, mesh, [placement], run_check=False)
    key = DTensor.from_local(key, mesh, [placement], run_check=False)
    value = DTensor.from_local(value, mesh, [placement], run_check=False)
    if ctx.get_cp_config().head_parallel:
        # TP-like impl (head parallel): q/k/v come in as [B, H, N/CP, D];
        # redistribute from Shard(seq_dim) to Shard(1) to get [B, H/CP, N, D].
        query = query.redistribute(mesh, [Shard(1)]).to_local()
        key = key.redistribute(mesh, [Shard(1)]).to_local()
        value = value.redistribute(mesh, [Shard(1)]).to_local()
        # run local SDPA on the head shard, then shard the result back over the sequence
        res = F.scaled_dot_product_attention(query, key, value, *args[3:], **kwargs)
        res_d = DTensor.from_local(res, mesh, [Shard(1)], run_check=False)
        res = res_d.redistribute(mesh, [Shard(ctx.get_cp_config().seq_dim)]).to_local()
        return res
    else:
        # ring-attention-like impl; PyTorch's DTensor dispatch handles the communication
        res = F.scaled_dot_product_attention(query, key, value, *args[3:], **kwargs)
        assert isinstance(res, DTensor)
        return res.to_local()

class _CpShardFunction(Function):
    """Shard along the sequence dim in forward, unshard the gradient in backward."""

    @staticmethod
    def forward(ctx, buffer, mesh, seq_dim):
        from torch.distributed.tensor.experimental._attention import _RoundRobinLoadBalancer, _SequentialSharder, _cp_options
        sharder_cls = _RoundRobinLoadBalancer if _cp_options.enable_load_balance else _SequentialSharder
        ctx.mesh = mesh
        ctx.seq_dim = seq_dim
        return sharder_cls.shard(buffer, mesh, seq_dim)

    @staticmethod
    def backward(ctx, grad_output):
        # unshard the gradient back to the full sequence
        from torch.distributed.tensor.experimental._attention import _RoundRobinLoadBalancer, _SequentialSharder, _cp_options
        sharder_cls = _RoundRobinLoadBalancer if _cp_options.enable_load_balance else _SequentialSharder
        grad = sharder_cls.unshard(grad_output, ctx.mesh, ctx.seq_dim)
        return grad, None, None

class _CpUnshardFunction(Function):
    """Unshard along the sequence dim in forward, re-shard the gradient in backward."""

    @staticmethod
    def forward(ctx, buffer, mesh, seq_dim):
        from torch.distributed.tensor.experimental._attention import _RoundRobinLoadBalancer, _SequentialSharder, _cp_options
        sharder_cls = _RoundRobinLoadBalancer if _cp_options.enable_load_balance else _SequentialSharder
        ctx.mesh = mesh
        ctx.seq_dim = seq_dim
        return sharder_cls.unshard(buffer, mesh, seq_dim)

    @staticmethod
    def backward(ctx, grad_output):
        # re-shard the gradient along the sequence dim
        from torch.distributed.tensor.experimental._attention import _RoundRobinLoadBalancer, _SequentialSharder, _cp_options
        sharder_cls = _RoundRobinLoadBalancer if _cp_options.enable_load_balance else _SequentialSharder
        grad = sharder_cls.shard(grad_output, ctx.mesh, ctx.seq_dim)
        return grad, None, None


def context_parallel_shard(
    mesh: DeviceMesh,
    buffers: List[torch.Tensor],
    seq_dims: List[int],
) -> List[torch.Tensor]:
    """
    Shard the tensors (e.g., inputs of the CP region) that should be sharded due to context parallelism, with gradient support.
    """
    return [_CpShardFunction.apply(b, mesh, dim) for b, dim in zip(buffers, seq_dims)]


def context_parallel_unshard(
    mesh: DeviceMesh,
    buffers: List[torch.Tensor],
    seq_dims: List[int],
) -> List[torch.Tensor]:
    """
    Unshard the tensors (e.g., output) that are sharded due to context parallelism, with gradient support.
    """
    return [_CpUnshardFunction.apply(b, mesh, dim) for b, dim in zip(buffers, seq_dims)]

class ExampleModule(torch.nn.Module):
    # `cp_mesh` and `transformer_block` are assumed to be defined elsewhere.
    def forward(self, x):
        # when you use CP without `context_parallel`, you must shard/unshard manually.
        x_shard = context_parallel_shard(cp_mesh, [x], [1])[0]
        x_shard = transformer_block(x_shard)
        x = context_parallel_unshard(cp_mesh, [x_shard], [1])[0]
        return x

from torch.distributed.tensor.experimental._attention import _enable_cp_dispatcher

mod = ExampleModule()
with _enable_cp_dispatcher():
    # here we don't use `context_parallel`, only `_enable_cp_dispatcher`.
    # keep in mind that you need to shard/unshard tensors manually.
    out = mod(x)
    loss = out.sum()  # toy loss for illustration
    loss.backward()

    

In the example above, you can use F.scaled_dot_product_attention when you don't want to apply CP, and scaled_dot_product_attention_parallel when you do.
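
For instance, a hypothetical block mixing the two (the function and tensor names below are made up; q/k/v are assumed to already be sharded along the sequence dim):

def mixed_attention(q, k, v, q_cross, k_cross, v_cross):
    # CP region: the wrapper converts q/k/v to DTensor, so the CP SDPA impl runs
    self_out = scaled_dot_product_attention_parallel(q, k, v)
    # the unsupported cross attention keeps using the plain local SDPA, since the
    # `context_parallel` monkey patch was never applied
    cross_out = F.scaled_dot_product_attention(q_cross, k_cross, v_cross)
    return self_out, cross_out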

FindDefinition avatar Mar 12 '25 10:03 FindDefinition