
[core] Enable CP for kernels-based attention backends

Open sayakpaul opened this pull request 1 month ago • 1 comment

What does this PR do?

Adds context parallelism (CP) support to the kernels-based attention backends.

Our CP support is quickly gaining traction. We currently have a few attention backends that are fully based on kernels. To grow their adoption and bring them closer to feature parity with the native backends, I think we should make them CP-compatible, too.
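For context, the kernels-based backends already plug into the public set_attention_backend API; a minimal single-GPU sketch (no CP) looks roughly like this, assuming the Hub-provided flash attention kernel is available as the "flash_hub" backend:

import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
# Kernels-based attention backend fetched from the Hub (assumed to be available).
pipeline.transformer.set_attention_backend("flash_hub")
image = pipeline("a cat sipping a margarita", num_inference_steps=28).images[0]

The goal of this PR is for the same backend switch to keep working when the transformer is loaded with a parallel_config, as in the test script below.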

Code to test:
import argparse
import torch
from torch import distributed as dist
from diffusers import DiffusionPipeline, ContextParallelConfig, AutoModel


CKPT_ID = "black-forest-labs/FLUX.1-dev"

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cp-backend",
        type=str,
        choices=["ring", "ulysses", "unified"],
        default="ulysses",
        help="Context parallel backend to use.",
    )
    parser.add_argument(
        "--attn-backend",
        type=str,
        choices=["flash_hub", "_flash_3_hub", "sage_hub"],
        default="flash_hub",
        help="Attention backend to use.",
    )
    return parser.parse_args()


def setup_distributed():
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    return device


def main():
    args = parse_args()

    device = setup_distributed()
    world_size = dist.get_world_size()
    # Decide how the world size is split across ring / Ulysses / unified CP.
    if args.cp_backend == "ring":
        cp_config = ContextParallelConfig(ring_degree=world_size)
    elif args.cp_backend == "unified":
        cp_config = ContextParallelConfig(ring_degree=world_size // 2, ulysses_degree=world_size // 2)
    else:
        cp_config = ContextParallelConfig(ulysses_degree=world_size)

    transformer = AutoModel.from_pretrained(
        CKPT_ID, 
        subfolder="transformer", 
        torch_dtype=torch.bfloat16, 
        parallel_config=cp_config
    )

    pipeline = DiffusionPipeline.from_pretrained(
        CKPT_ID, transformer=transformer, torch_dtype=torch.bfloat16,
    ).to(device)
    # Switch the transformer to the requested kernels-based attention backend.
    pipeline.transformer.set_attention_backend(args.attn_backend)

    prompt = """
    cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
    highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
    """

    generator = torch.Generator().manual_seed(42)
    image = pipeline(
        prompt,
        guidance_scale=3.5,
        num_inference_steps=50,
        generator=generator,
    ).images[0]

    if dist.get_rank() == 0:
        image.save(f"output_{args.cp_backend}_{args.attn_backend}.png")

    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
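The script is meant to be launched with one process per GPU via torchrun, for example (the filename is illustrative):

torchrun --nproc_per_node=2 test_cp.py --cp-backend ulysses --attn-backend flash_hub

Note that the unified option splits the world size in half between the ring and Ulysses degrees, so it presumably expects a world size of at least 4.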

Outputs:

(Output images: FA2 + Ulysses, FA3 + Ulysses, SAGE + Ulysses; and Ring, Ulysses, Unified.)

sayakpaul · Dec 09 '25
