[question] can't disable CP for specific (unsupported) SDPA op
Problem
Currently the context parallel API has five problems:
- It only supports applying CP to the whole model. If the model has a cross attention with an unsupported shape in its prep part, it's impossible to apply CP, since `_context_parallel` always overrides all SDPA calls and needs to wrap the whole backward pass.
- There is no shard/unshard with gradient support. When I try to apply CP to the transformer blocks only and keep the other SDPA calls replicated, the `context_parallel_unshard` in PyTorch has a `no_grad` decorator.
- Weight gradients inside the CP region are divided by the size of the CP mesh because we reduce them over DP+CP. This may work for optimizers with norm support, but it makes unit tests harder to write: we have to scale the gradients back to get the same gradients as the model without CP.
- The sequence length must be divisible by the CP size (CP * 2 for round-robin load balancing).
- A replicated input of the CP region may receive a wrong gradient because its gradient may be `Partial`; we have to check every replicated input and use `to_local(grad_placements=[Partial()])` (see the sketch after this list).
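For illustration, a minimal sketch of the `Partial` gradient workaround from the last item; `cp_mesh` and `cond` are hypothetical names for the CP sub-mesh and a conditioning tensor that is replicated across CP ranks but consumed inside the CP region:

```python
import torch
from torch.distributed.tensor import DTensor, Partial, Replicate

# hypothetical setup: `cp_mesh` is the CP sub-mesh, `cond` is a replicated
# conditioning tensor that feeds the CP region
cond_d = DTensor.from_local(cond, cp_mesh, [Replicate()], run_check=False)

# declare that the gradient flowing back into the local view is Partial, so it
# gets summed across CP ranks instead of being treated as already replicated
cond_local = cond_d.to_local(grad_placements=[Partial()])
```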
To resolve problem 1 above, I remove the `context_parallel` context to disable the SDPA override and only enable the `_enable_cp_dispatcher` context; then CP SDPA is enabled iff all inputs are converted to `DTensor`. Problem 2 is easy to resolve: just write some autograd functions.
Here are my questions:
- Is there a better way to support a CP region?
- Do you have any plan to support a CP region officially and resolve the issues above?
We do see some requests to apply CP to only some regions. This will require some communication before SDPA, as we will have to borrow some ranks from the DP sharding dimensions. Current CP definitely doesn't support this. We will discuss how and what is the best way to support this feature.
> To resolve problem 1 above, I remove the `context_parallel` context to disable the SDPA override and only enable the `_enable_cp_dispatcher` context; then CP SDPA is enabled iff all inputs are converted to `DTensor`.
Hi @FindDefinition, can you expand a little bit on your workaround for issue number 1? Thanks a lot.
@lkhphuc there are two dispatches in the PyTorch context parallel implementation:
1. A dispatch of the native op (e.g. `aten._scaled_dot_product_flash_attention.default`) when the inputs are `DTensor`; this only happens when the native SDPA op is invoked with `DTensor` inputs.
2. A dispatch of `F.scaled_dot_product_attention` to a simple implementation that only converts your `torch.Tensor` inputs to `DTensor`, so that the native dispatch in 1 applies.

So if we remove the second dispatch, `F.scaled_dot_product_attention` will use the non-parallel implementation when the inputs aren't `DTensor`, and the context parallel version when you convert q/k/v to `DTensor`.
Example:

```python
import torch
import torch.nn.functional as F
from torch.distributed.tensor import DTensor, Shard


# `copy_sig` and `get_parallel_manager` are user-defined helpers: `copy_sig` copies
# the signature of F.scaled_dot_product_attention, and `get_parallel_manager`
# returns the object holding the CP mesh and config.
@copy_sig(F.scaled_dot_product_attention)
def scaled_dot_product_attention_parallel(*args, **kwargs) -> torch.Tensor:
    """Context parallel version of scaled_dot_product_attention.

    This function just converts the inputs to DTensor and converts the result
    back to a local tensor. DTensor will dispatch the correct parallel version
    of SDPA when the inputs are DTensor and the cp context below is entered.
    """
    ctx = get_parallel_manager()
    if ctx is None or ctx.cp_mesh is None:
        # no CP configured: fall back to the regular implementation
        return F.scaled_dot_product_attention(*args, **kwargs)
    # pull query/key/value out of args or kwargs, whichever was used
    query = args[0] if len(args) > 0 else kwargs.pop("query")
    key = args[1] if len(args) > 1 else kwargs.pop("key")
    value = args[2] if len(args) > 2 else kwargs.pop("value")
    placement = Shard(ctx.get_cp_config().seq_dim)
    mesh = ctx.cp_mesh
    query = DTensor.from_local(query, mesh, [placement], run_check=False)
    key = DTensor.from_local(key, mesh, [placement], run_check=False)
    value = DTensor.from_local(value, mesh, [placement], run_check=False)
    if ctx.get_cp_config().head_parallel:
        # TP-like impl (head parallel): convert q/k/v from [B, H, N/CP, D]
        # (sequence-sharded) to [B, H/CP, N, D] (head-sharded)
        query = query.redistribute(mesh, [Shard(1)]).to_local()
        key = key.redistribute(mesh, [Shard(1)]).to_local()
        value = value.redistribute(mesh, [Shard(1)]).to_local()
        res = F.scaled_dot_product_attention(query, key, value, *args[3:], **kwargs)
        # shard the output back along the sequence dimension
        res_d = DTensor.from_local(res, mesh, [Shard(1)], run_check=False)
        return res_d.redistribute(mesh, [Shard(ctx.get_cp_config().seq_dim)]).to_local()
    else:
        # ring-attn-like impl, handled by PyTorch's DTensor SDPA dispatch
        res = F.scaled_dot_product_attention(query, key, value, *args[3:], **kwargs)
        assert isinstance(res, DTensor)
        return res.to_local()
```
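For context, a minimal usage sketch of the wrapper above; the shapes and the `is_causal` flag are illustrative, and it assumes `get_parallel_manager()` has been configured elsewhere when CP is wanted:

```python
# each rank holds its local sequence shard of q/k/v, e.g. [B, H, N/CP, D]
# when the CP config's seq_dim is 2; the shapes here are only illustrative
q = torch.randn(2, 8, 1024, 64, device="cuda", requires_grad=True)
k = torch.randn_like(q)
v = torch.randn_like(q)

# without a CP mesh configured this is a plain F.scaled_dot_product_attention call;
# with a CP mesh (and the dispatcher enabled, see below) it runs the CP version
out = scaled_dot_product_attention_parallel(q, k, v, is_causal=False)
```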
And here are autograd-aware shard/unshard helpers, which resolve problem 2:

```python
from typing import List

import torch
from torch.autograd import Function
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.experimental._attention import (
    _RoundRobinLoadBalancer,
    _SequentialSharder,
    _cp_options,
)


def _get_sharder_cls():
    # pick the same sharder PyTorch's CP implementation would use
    return _RoundRobinLoadBalancer if _cp_options.enable_load_balance else _SequentialSharder


class _CpShardFunction(Function):
    @staticmethod
    def forward(ctx, buffer, mesh, seq_dim):
        ctx.mesh = mesh
        ctx.seq_dim = seq_dim
        return _get_sharder_cls().shard(buffer, mesh, seq_dim)

    @staticmethod
    def backward(ctx, grad_output):
        # backward of shard is unshard: gather the sharded gradient
        grad = _get_sharder_cls().unshard(grad_output, ctx.mesh, ctx.seq_dim)
        return grad, None, None


class _CpUnshardFunction(Function):
    @staticmethod
    def forward(ctx, buffer, mesh, seq_dim):
        ctx.mesh = mesh
        ctx.seq_dim = seq_dim
        return _get_sharder_cls().unshard(buffer, mesh, seq_dim)

    @staticmethod
    def backward(ctx, grad_output):
        # backward of unshard is shard: keep only this rank's gradient slice
        grad = _get_sharder_cls().shard(grad_output, ctx.mesh, ctx.seq_dim)
        return grad, None, None


def context_parallel_shard(
    mesh: DeviceMesh,
    buffers: List[torch.Tensor],
    seq_dims: List[int],
) -> List[torch.Tensor]:
    """Shard the tensors (e.g., inputs) that should be sharded for context parallelism."""
    return [_CpShardFunction.apply(b, mesh, dim) for b, dim in zip(buffers, seq_dims)]


def context_parallel_unshard(
    mesh: DeviceMesh,
    buffers: List[torch.Tensor],
    seq_dims: List[int],
) -> List[torch.Tensor]:
    """Unshard the tensors (e.g., outputs) that are sharded due to context parallelism."""
    return [_CpUnshardFunction.apply(b, mesh, dim) for b, dim in zip(buffers, seq_dims)]
```
Putting it together (`cp_mesh` and `transformer_block` are defined elsewhere):

```python
from torch.distributed.tensor.experimental._attention import _enable_cp_dispatcher


class ExampleModule(torch.nn.Module):
    def forward(self, x):
        # when you use CP without `context_parallel`, you must shard/unshard manually
        x_shard = context_parallel_shard(cp_mesh, [x], [1])[0]
        # transformer_block is expected to call scaled_dot_product_attention_parallel
        x_shard = transformer_block(x_shard)
        x = context_parallel_unshard(cp_mesh, [x_shard], [1])[0]
        return x


mod = ExampleModule()
with _enable_cp_dispatcher():
    # here we don't use `context_parallel`, just `_enable_cp_dispatcher`.
    # keep in mind that you need to shard/unshard tensors manually.
    out = mod(x)
    out.sum().backward()
```
In the example above, you can use `F.scaled_dot_product_attention` for the parts where you don't want to apply CP, and `scaled_dot_product_attention_parallel` where you need CP, as sketched below.
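To make that concrete, here is a minimal sketch of a model that keeps an unsupported cross attention replicated and only parallelizes its transformer blocks; `prep_cross_attn`, `blocks`, and `cp_mesh` are hypothetical names, not part of the code above:

```python
class HybridModel(torch.nn.Module):
    def forward(self, x, context):
        # the prep cross attention stays replicated: its attention module calls
        # plain F.scaled_dot_product_attention, so no CP shape constraints apply
        x = self.prep_cross_attn(x, context)

        # only the transformer blocks run under CP: shard the sequence, run the
        # blocks (whose attention calls scaled_dot_product_attention_parallel),
        # then unshard; both helpers above are autograd-aware
        x = context_parallel_shard(cp_mesh, [x], [1])[0]
        for block in self.blocks:
            x = block(x)
        return context_parallel_unshard(cp_mesh, [x], [1])[0]
```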