long-context-attention

Usage with PyTorch FSDP

Open benihime91 opened this issue 8 months ago • 3 comments

How can this be used with PyTorch FSDP for 2D-parallel training? For example, we want to apply USP attention within a single node and FSDP across multiple nodes.

benihime91, Apr 06 '25 12:04

The parallel group for USP and FSDP should be the same. You can wrap the USP-applied module with FSDP.

feifeibear, Apr 07 '25 03:04

I've successfully implemented and tested hybrid parallelism with the PyTorch DeviceMesh API. Below is my working implementation, which sets up a device mesh with data parallelism (DP), ring parallelism, and Ulysses parallelism:

  • The current implementation uses DeviceMesh to create a 3D parallel structure
  • Tested with Flash Attention, with outputs validated against a non-distributed reference implementation
  • The mesh is configured as (dp_degree, sp_ring_degree, sp_ulysses_degree), with a corresponding process group per dimension
  • It appears to work with --sp_ulysses_degree 2 --sp_ring_degree 4 --dp_degree 1 --use_ulysses_low --ring_impl_type "basic"
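Before launching with a given set of degrees, it can help to sanity-check them. The sketch below is only an illustration: `Degrees` and `check_degrees` are hypothetical helpers, not part of yunchang or PyTorch, and the three checks are assumptions about what a consistent setup looks like (the mesh covers every rank, Ulysses can split the heads evenly, and the sequence can be sharded evenly across the ring and Ulysses ranks). The head count and sequence length are example values.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Degrees:
    """Hypothetical container mirroring the three degree flags."""
    dp: int = 1
    ring: int = 1
    ulysses: int = 1


def check_degrees(deg: Degrees, world_size: int, nheads: int, seqlen: int) -> None:
    """Raise ValueError if the degrees look inconsistent (assumed checks)."""
    mesh_size = deg.dp * deg.ring * deg.ulysses
    if mesh_size != world_size:
        raise ValueError(f"mesh covers {mesh_size} ranks, world size is {world_size}")
    # Ulysses redistributes attention heads across its group via all-to-all.
    if nheads % deg.ulysses != 0:
        raise ValueError(f"{nheads} heads not divisible by ulysses degree {deg.ulysses}")
    # Assumed: every sequence-parallel rank should hold an equally sized shard.
    sp_degree = deg.ring * deg.ulysses
    if seqlen % sp_degree != 0:
        raise ValueError(f"seqlen {seqlen} not divisible by sp degree {sp_degree}")


# The configuration from the bullet above, on 8 GPUs: passes silently.
check_degrees(Degrees(dp=1, ring=4, ulysses=2), world_size=8, nheads=12, seqlen=4096)
```

Calling the same check with `world_size=4` would raise a `ValueError`, since the mesh would need 8 ranks.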

I'm now looking to integrate FSDP (Fully Sharded Data Parallel) with this setup. My understanding is that FSDP should be applied over the data parallel group (dp_group) created by the device mesh. Is this the correct approach?
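To make the question concrete, it may help to see which global ranks end up in each group. The sketch below is a pure-Python illustration under one assumption: that the mesh assigns ranks in row-major order over (dp, ring, ulysses), as when the ranks 0..N-1 are reshaped to the mesh shape. `mesh_groups` is a hypothetical helper, and the dp degree of 2 is chosen only for illustration (the tested configuration used dp_degree 1).

```python
from itertools import product
from typing import Dict, List, Sequence


def mesh_groups(shape: Sequence[int], names: Sequence[str]) -> Dict[str, List[List[int]]]:
    """For each named mesh dimension, list the groups of global ranks that
    differ only along that dimension (row-major rank layout assumed)."""
    # Row-major strides: the last dimension varies fastest.
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]

    groups: Dict[str, List[List[int]]] = {}
    for dim, name in enumerate(names):
        # Fix every other coordinate, then walk along `dim`.
        others = [range(1) if i == dim else range(n) for i, n in enumerate(shape)]
        dim_groups = []
        for coord in product(*others):
            base = sum(c * s for c, s in zip(coord, strides))
            dim_groups.append([base + k * strides[dim] for k in range(shape[dim])])
        groups[name] = dim_groups
    return groups


layout = mesh_groups((2, 2, 2), ("dp", "ring", "ulysses"))
for name, ranks in layout.items():
    print(name, ranks)
# dp [[0, 4], [1, 5], [2, 6], [3, 7]]
# ring [[0, 2], [1, 3], [4, 6], [5, 7]]
# ulysses [[0, 1], [2, 3], [4, 5], [6, 7]]
```

Under this layout, applying FSDP over the "dp" dimension alone would shard parameters across pairs such as ranks 0 and 4. Whether that, or a group that also spans the ring and Ulysses ranks, is the right choice is the open question here; the sketch only shows the rank arithmetic.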

import os
import sys
from dataclasses import dataclass
from typing import Any, Dict, Optional

import torch
from fastcore.script import call_parse
from flash_attn import flash_attn_func
from loguru import logger
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from yunchang import EXTRACT_FUNC_DICT, AsyncLongContextAttention, LongContextAttention, set_seq_parallel_pg
from yunchang.globals import PROCESS_GROUP
from yunchang.kernels import AttnType


@dataclass
class ParallelDims:
    sp_ulysses_degree: int = 1
    sp_ring_degree: int = 1
    dp_degree: int = 1

    def build_device_mesh(self, device_type: str = "cuda", use_ulysses_low: bool = True) -> DeviceMesh:
        if use_ulysses_low:
            mesh_shape = (self.dp_degree, self.sp_ring_degree, self.sp_ulysses_degree)
            mesh_dim_names = ("dp", "ring", "ulysses")
        else:
            mesh_shape = (self.dp_degree, self.sp_ulysses_degree, self.sp_ring_degree)
            mesh_dim_names = ("dp", "ulysses", "ring")
        return init_device_mesh(device_type=device_type, mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names)


@call_parse
def main(
    sp_ulysses_degree: int = 1,
    sp_ring_degree: int = 1,
    dp_degree: int = 1,
    use_ulysses_low: bool = True,
    ring_impl_type: str = "basic",
):
    world_size = int(os.environ["WORLD_SIZE"])
    logger.info(f"world_size: {world_size}")
    torch.distributed.init_process_group(backend="nccl")

    rank = torch.distributed.get_rank()

    logger.remove()
    logger.add(sys.stdout, level="INFO", format=f"[ rank={rank} ]" + "{time} {level} {message}", colorize=True)

    parallel_dims = ParallelDims(
        sp_ulysses_degree=sp_ulysses_degree,
        sp_ring_degree=sp_ring_degree,
        dp_degree=dp_degree,
    )
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    world_mesh = parallel_dims.build_device_mesh(device_type=device_type, use_ulysses_low=use_ulysses_low)

    dp_mesh = world_mesh["dp"]
    dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
    logger.info(f"dp_degree: {dp_degree}, dp_rank: {dp_rank}")

    ring_mesh = world_mesh["ring"]
    ring_degree, ring_rank = ring_mesh.size(), ring_mesh.get_local_rank()
    logger.info(f"ring_degree: {ring_degree}, ring_rank: {ring_rank}")

    ulysses_mesh = world_mesh["ulysses"]
    ulysses_degree, ulysses_rank = ulysses_mesh.size(), ulysses_mesh.get_local_rank()
    logger.info(f"ulysses_degree: {ulysses_degree}, ulysses_rank: {ulysses_rank}")

    PROCESS_GROUP.ULYSSES_PG = world_mesh["ulysses"].get_group()
    PROCESS_GROUP.RING_PG = world_mesh["ring"].get_group()

    batch_size = 1
    seqlen = 4096
    nheads = 12
    d = 128
    dropout_p = 0
    deterministic = False

    # Use the node-local rank as the device index so the script also works across multiple nodes.
    device = torch.device(f"cuda:{int(os.environ.get('LOCAL_RANK', rank))}")
    q = torch.randn(batch_size, seqlen, nheads, d, device=device, requires_grad=True, dtype=torch.bfloat16)
    k = torch.randn(batch_size, seqlen, nheads, d, device=device, requires_grad=True, dtype=torch.bfloat16)
    v = torch.randn(batch_size, seqlen, nheads, d, device=device, requires_grad=True, dtype=torch.bfloat16)
    dout = torch.randn(batch_size, seqlen, nheads, d, device=device, requires_grad=True, dtype=torch.bfloat16)

    logger.info(f"q: {q.shape}, k: {k.shape}, v: {v.shape}, dout: {dout.shape}")

    torch.distributed.broadcast(q, src=0)
    torch.distributed.broadcast(k, src=0)
    torch.distributed.broadcast(v, src=0)
    torch.distributed.broadcast(dout, src=0)

    # Use EXTRACT_FUNC_DICT to shard the tensors
    local_q = EXTRACT_FUNC_DICT[ring_impl_type](q, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree).detach().clone()
    local_k = EXTRACT_FUNC_DICT[ring_impl_type](k, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree).detach().clone()
    local_v = EXTRACT_FUNC_DICT[ring_impl_type](v, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree).detach().clone()

    local_q.requires_grad = True
    local_k.requires_grad = True
    local_v.requires_grad = True

    usp_attn = LongContextAttention(ring_impl_type=ring_impl_type, attn_type=AttnType.FA)

    if rank == 0:
        print("#" * 30)
        print("# ds-ulysses forward:")
        print("#" * 30)

    window_size = (-1, -1)
    alibi_slopes, attn_bias = None, None
    dropout_mask = None

    logger.info(f"local_q: {local_q.shape}, local_k: {local_k.shape}, local_v: {local_v.shape}")

    local_out = usp_attn(
        local_q,
        local_k,
        local_v,
        dropout_p=dropout_p,
        causal=False,
        window_size=window_size,
        softcap=0.0,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=True,
    )

    # extract local dout
    local_dout = EXTRACT_FUNC_DICT[ring_impl_type](dout, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree).detach().clone()

    if rank == 0:
        print("#" * 30)
        print("# ds-ulysses backward:")
        print("#" * 30)

    local_out.backward(local_dout)

    torch.distributed.barrier()

    if rank == 0:
        print("#" * 30)
        print("# local forward:")
        print("#" * 30)
    # reference, a local flash attn
    out_ref, _, _ = flash_attn_func(
        q,
        k,
        v,
        dropout_p=dropout_p,
        causal=False,
        window_size=window_size,
        softcap=0.0,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=True,
    )

    if rank == 0:
        print("#" * 30)
        print("# local backward:")
        print("#" * 30)

    out_ref.backward(dout)

    local_out_ref = EXTRACT_FUNC_DICT[ring_impl_type](out_ref, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree)

    if rank == 0:
        print("local (rank) out", local_out)
        print("out (distributed) - out_ref (non-distributed) diff", local_out_ref - local_out)

    torch.testing.assert_close(local_out, local_out_ref, atol=1e-1, rtol=0)

benihime91, Apr 07 '25 09:04

@feifeibear @benihime91 Thank you very much for the information. I ran the script provided by @benihime91, but got the following error during the backward pass. I suspect my environment is misconfigured. Could you share the environment needed to run the test code, including the torch, CUDA, and flash-attn versions? Alternatively, a requirements.txt for the environment would be much appreciated.

[rank0]: Traceback (most recent call last):
[rank0]:   File "/mnt/mmv-video/rhyshen/workspace/DanceGRPO/tests/test_yunchang.py", line 174, in <module>
[rank0]:     main(
[rank0]:   File "/mnt/mmv-video/rhyshen/workspace/DanceGRPO/tests/test_yunchang.py", line 134, in main
[rank0]:     local_out.backward(local_dout)
[rank0]:   File "/mnt/mmv-video/rhyshen/env/diff/lib/python3.10/site-packages/torch/_tensor.py", line 626, in backward
[rank0]:     torch.autograd.backward(
[rank0]:   File "/mnt/mmv-video/rhyshen/env/diff/lib/python3.10/site-packages/torch/autograd/__init__.py", line 347, in backward
[rank0]:     _engine_run_backward(
[rank0]:   File "/mnt/mmv-video/rhyshen/env/diff/lib/python3.10/site-packages/torch/autograd/graph.py", line 823, in _engine_run_backward
[rank0]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]: RuntimeError: function RingFlashAttnFuncBackward returned an incorrect number of gradients (expected 14, got 13)

By the way, this is a reproducible error.
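When comparing environments for an error like this, a short version report makes it easier to see differences between setups. The sketch below uses only the standard library; the distribution names in `PACKAGES` are assumptions about how the packages are registered with pip and may need adjusting. The message itself comes from PyTorch's rule that a custom autograd function's backward must return one gradient per forward input, so the mismatch may lie in the installed package versions rather than in CUDA itself, though the thread does not confirm the cause.

```python
import platform
from importlib import metadata
from typing import Dict, Iterable

# Assumed pip distribution names; adjust to match `pip list` if they differ.
PACKAGES = ("torch", "flash-attn", "yunchang")


def collect_versions(packages: Iterable[str] = PACKAGES) -> Dict[str, str]:
    """Return the installed version of each package, or 'not installed'."""
    report = {"python": platform.python_version()}
    for name in packages:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "not installed"
    return report


if __name__ == "__main__":
    for name, version in collect_versions().items():
        print(f"{name}=={version}")
```

The CUDA version that torch was built against can be read separately from `torch.version.cuda` once torch is importable.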

dutsc, Sep 26 '25 08:09