Usage with PyTorch FSDP
How can this be used with PyTorch FSDP for 2D parallel training? For example, we want to apply USP attention within a single node and FSDP across multiple nodes.
The parallel group for USP and FSDP should be the same. You can wrap the USP-applied module with FSDP.
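A minimal sketch of that wiring is below. MyTransformer and its blocks/attn attributes are hypothetical stand-ins for your own model, and the set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size) argument order is an assumption; adapt as needed.

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from yunchang import LongContextAttention, set_seq_parallel_pg
from yunchang.kernels import AttnType

dist.init_process_group(backend="nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank % torch.cuda.device_count())

# USP (Ulysses x ring) inside a node, e.g. 2 x 4 = 8 GPUs per node;
# the remaining replicas form the data-parallel dimension.
sp_ulysses_degree, sp_ring_degree = 2, 4
set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)

model = MyTransformer().cuda()          # hypothetical model
for block in model.blocks:              # hypothetical attribute
    # Swap the block's attention kernel for USP attention.
    block.attn = LongContextAttention(attn_type=AttnType.FA)

# Wrap the USP-applied module with FSDP. FSDP shards over the default (world)
# process group unless process_group=... is given; set it explicitly if the
# FSDP group needs to line up with the group used for the USP setup above.
model = FSDP(model, use_orig_params=True)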
I've successfully implemented and tested the PyTorch DeviceMesh API for hybrid parallelism. Here's my working implementation that demonstrates the device mesh setup with data parallelism (DP), ring parallelism, and Ulysses parallelism:
- The current implementation uses DeviceMesh to create a 3D parallel structure
- Successfully tested with Flash Attention and validated outputs against a non-distributed reference implementation
- The mesh structure is configured as (dp_degree, sp_ring_degree, sp_ulysses_degree) with corresponding process groups
- It seems to work for --sp_ulysses_degree 2 --sp_ring_degree 4 --dp_degree 1 --use_ulysses_low --ring_impl_type "basic"
I'm now looking to integrate FSDP (Fully Sharded Data Parallel) with this setup. My understanding is that FSDP should be applied over the data parallel group (dp_group) created by the device mesh. Is this the correct approach? (A rough sketch of what I mean follows the script below.)
import os
import sys
from dataclasses import dataclass
from typing import Any, Dict, Optional
import torch
from fastcore.script import call_parse
from flash_attn import flash_attn_func
from loguru import logger
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from yunchang import EXTRACT_FUNC_DICT, AsyncLongContextAttention, LongContextAttention, set_seq_parallel_pg
from yunchang.globals import PROCESS_GROUP
from yunchang.kernels import AttnType
@dataclass
class ParallelDims:
    sp_ulysses_degree: int = 1
    sp_ring_degree: int = 1
    dp_degree: int = 1

    def build_device_mesh(self, device_type: str = "cuda", use_ulysses_low: bool = True) -> DeviceMesh:
        if use_ulysses_low:
            mesh_shape = (self.dp_degree, self.sp_ring_degree, self.sp_ulysses_degree)
            mesh_dim_names = ("dp", "ring", "ulysses")
        else:
            mesh_shape = (self.dp_degree, self.sp_ulysses_degree, self.sp_ring_degree)
            mesh_dim_names = ("dp", "ulysses", "ring")
        return init_device_mesh(device_type=device_type, mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names)
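# DeviceMesh lays ranks out row-major, so the last mesh dimension varies fastest;
# with use_ulysses_low=True the Ulysses group is therefore built from adjacent ranks.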
@call_parse
def main(
    sp_ulysses_degree: int = 1,
    sp_ring_degree: int = 1,
    dp_degree: int = 1,
    use_ulysses_low: bool = True,
    ring_impl_type: str = "basic",
):
    world_size = int(os.environ["WORLD_SIZE"])
    logger.info(f"world_size: {world_size}")
    torch.distributed.init_process_group(backend="nccl")
    rank = torch.distributed.get_rank()
    logger.remove()
    logger.add(sys.stdout, level="INFO", format=f"[ rank={rank} ]" + "{time} {level} {message}", colorize=True)
    parallel_dims = ParallelDims(
        sp_ulysses_degree=sp_ulysses_degree,
        sp_ring_degree=sp_ring_degree,
        dp_degree=dp_degree,
    )
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    world_mesh = parallel_dims.build_device_mesh(device_type=device_type, use_ulysses_low=use_ulysses_low)
    dp_mesh = world_mesh["dp"]
    dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
    logger.info(f"dp_degree: {dp_degree}, dp_rank: {dp_rank}")
    ring_mesh = world_mesh["ring"]
    ring_degree, ring_rank = ring_mesh.size(), ring_mesh.get_local_rank()
    logger.info(f"ring_degree: {ring_degree}, ring_rank: {ring_rank}")
    ulysses_mesh = world_mesh["ulysses"]
    ulysses_degree, ulysses_rank = ulysses_mesh.size(), ulysses_mesh.get_local_rank()
    logger.info(f"ulysses_degree: {ulysses_degree}, ulysses_rank: {ulysses_rank}")
    PROCESS_GROUP.ULYSSES_PG = world_mesh["ulysses"].get_group()
    PROCESS_GROUP.RING_PG = world_mesh["ring"].get_group()
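    # LongContextAttention reads these yunchang globals to find the process groups
    # it uses for the Ulysses all-to-all and the ring P2P communication.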
    batch_size = 1
    seqlen = 4096
    nheads = 12
    d = 128
    dropout_p = 0
    deterministic = False
    q = torch.randn(batch_size, seqlen, nheads, d, device=torch.device(f"cuda:{rank}"), requires_grad=True, dtype=torch.bfloat16)
    k = torch.randn(batch_size, seqlen, nheads, d, device=torch.device(f"cuda:{rank}"), requires_grad=True, dtype=torch.bfloat16)
    v = torch.randn(batch_size, seqlen, nheads, d, device=torch.device(f"cuda:{rank}"), requires_grad=True, dtype=torch.bfloat16)
    dout = torch.randn(batch_size, seqlen, nheads, d, device=torch.device(f"cuda:{rank}"), requires_grad=True, dtype=torch.bfloat16)
    logger.info(f"q: {q.shape}, k: {k.shape}, v: {v.shape}, dout: {dout.shape}")
    torch.distributed.broadcast(q, src=0)
    torch.distributed.broadcast(k, src=0)
    torch.distributed.broadcast(v, src=0)
    torch.distributed.broadcast(dout, src=0)
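    # Broadcasting from rank 0 gives every rank identical full-length tensors, so each
    # rank's sharded output can later be checked against the single-GPU reference.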
    # Use EXTRACT_FUNC_DICT to shard the tensors
    local_q = EXTRACT_FUNC_DICT[ring_impl_type](q, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree).detach().clone()
    local_k = EXTRACT_FUNC_DICT[ring_impl_type](k, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree).detach().clone()
    local_v = EXTRACT_FUNC_DICT[ring_impl_type](v, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree).detach().clone()
    local_q.requires_grad = True
    local_k.requires_grad = True
    local_v.requires_grad = True
    usp_attn = LongContextAttention(ring_impl_type=ring_impl_type, attn_type=AttnType.FA)
    if rank == 0:
        print("#" * 30)
        print("# ds-ulysses forward:")
        print("#" * 30)
    window_size = (-1, -1)
    alibi_slopes, attn_bias = None, None
    dropout_mask = None
    logger.info(f"local_q: {local_q.shape}, local_k: {local_k.shape}, local_v: {local_v.shape}")
    local_out = usp_attn(
        local_q,
        local_k,
        local_v,
        dropout_p=dropout_p,
        causal=False,
        window_size=window_size,
        softcap=0.0,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=True,
    )
    # extract local dout
    local_dout = EXTRACT_FUNC_DICT[ring_impl_type](dout, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree).detach().clone()
    if rank == 0:
        print("#" * 30)
        print("# ds-ulysses backward:")
        print("#" * 30)
    local_out.backward(local_dout)
    torch.distributed.barrier()
    if rank == 0:
        print("#" * 30)
        print("# local forward:")
        print("#" * 30)
    # reference, a local flash attn
    out_ref, _, _ = flash_attn_func(
        q,
        k,
        v,
        dropout_p=dropout_p,
        causal=False,
        window_size=window_size,
        softcap=0.0,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=True,
    )
    if rank == 0:
        print("#" * 30)
        print("# local backward:")
        print("#" * 30)
    out_ref.backward(dout)
    local_out_ref = EXTRACT_FUNC_DICT[ring_impl_type](out_ref, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree)
    if rank == 0:
        print("local (rank) out", local_out)
        print("out (distributed) - out_ref (non-distributed) diff", local_out_ref - local_out)
    torch.testing.assert_close(local_out, local_out_ref, atol=1e-1, rtol=0)
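
For the FSDP part of my question, a rough sketch of what I have in mind is below. It reuses dp_mesh, d, nheads, ring_impl_type and the sharded local_q/local_k/local_v from the script above; TinyAttnModel is a hypothetical toy module and this is not validated end to end.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


class TinyAttnModel(torch.nn.Module):
    # Hypothetical toy module: one projection followed by USP attention.
    def __init__(self, head_dim: int, num_heads: int):
        super().__init__()
        self.proj = torch.nn.Linear(num_heads * head_dim, num_heads * head_dim, dtype=torch.bfloat16)
        self.attn = LongContextAttention(ring_impl_type=ring_impl_type, attn_type=AttnType.FA)

    def forward(self, q, k, v):
        b, s, h, hd = q.shape
        q = self.proj(q.reshape(b, s, h * hd)).reshape(b, s, h, hd)
        return self.attn(q, k, v, causal=False)


model = TinyAttnModel(d, nheads).cuda()
# FSDP shards the parameters over the data-parallel group only; the Ulysses and
# ring groups registered on PROCESS_GROUP above keep handling the
# sequence-parallel communication inside the attention call.
model = FSDP(model, process_group=dp_mesh.get_group(), use_orig_params=True)
out = model(local_q, local_k, local_v)  # sequence-sharded inputs, as in the test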
@feifeibear @benihime91 Thank you very much for your information. I ran the script provided by @benihime91, but I got the following error during the backward pass. I suspect my environment configuration is incorrect. Could you please provide the environment configuration required to run the test code (the PyTorch version, CUDA version, flash-attn version, etc.)? Alternatively, I would be very grateful if you could provide a requirements.txt file for the environment.
[rank0]: Traceback (most recent call last):
[rank0]: File "/mnt/mmv-video/rhyshen/workspace/DanceGRPO/tests/test_yunchang.py", line 174, in <module>
[rank0]: main(
[rank0]: File "/mnt/mmv-video/rhyshen/workspace/DanceGRPO/tests/test_yunchang.py", line 134, in main
[rank0]: local_out.backward(local_dout)
[rank0]: File "/mnt/mmv-video/rhyshen/env/diff/lib/python3.10/site-packages/torch/_tensor.py", line 626, in backward
[rank0]: torch.autograd.backward(
[rank0]: File "/mnt/mmv-video/rhyshen/env/diff/lib/python3.10/site-packages/torch/autograd/__init__.py", line 347, in backward
[rank0]: _engine_run_backward(
[rank0]: File "/mnt/mmv-video/rhyshen/env/diff/lib/python3.10/site-packages/torch/autograd/graph.py", line 823, in _engine_run_backward
[rank0]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
[rank0]: RuntimeError: function RingFlashAttnFuncBackward returned an incorrect number of gradients (expected 14, got 13)
By the way, this is a reproducible error.