diffusers
diffusers copied to clipboard
[core] Enable CP for kernels-based attention backends
What does this PR do?
Adds context parallelism (CP) support to the kernels-based attention backends.
Our CP support is quickly gaining traction. Currently, we have a few attention backends that are fully based on kernels. To help their adoption grow and to bring them closer to feature parity with the other backends, I think we should make them CP-compatible, too.
Code to test:
import argparse
import torch
from torch import distributed as dist
from diffusers import DiffusionPipeline, ContextParallelConfig, AutoModel
# Checkpoint identifier passed to both `from_pretrained` calls in `main()`.
CKPT_ID = "black-forest-labs/FLUX.1-dev"
def parse_args(argv=None):
    """Parse the command-line options of the test script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads ``sys.argv[1:]``, which is what a plain
            ``parse_args()`` call does.

    Returns:
        An ``argparse.Namespace`` with two attributes:
        ``cp_backend`` (one of "ring", "ulysses", "unified"; default "ulysses")
        and ``attn_backend`` (one of "flash_hub", "_flash_3_hub", "sage_hub";
        default "flash_hub").
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cp-backend",
        type=str,
        choices=["ring", "ulysses", "unified"],
        default="ulysses",
        help="Context parallel backend to use.",
    )
    parser.add_argument(
        "--attn-backend",
        type=str,
        choices=["flash_hub", "_flash_3_hub", "sage_hub"],
        default="flash_hub",
        help="Attention backend to use.",
    )
    return parser.parse_args(argv)
def setup_distributed():
    """Initialize the process group if needed and bind this process to a GPU.

    Initializes the default process group with the NCCL backend unless one is
    already initialized, selects a CUDA device for the current rank, makes it
    the current device, and returns it.

    Returns:
        The ``torch.device`` assigned to this process.
    """
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    # The global rank can exceed the number of GPUs visible on this node
    # (multi-node launches), so wrap it into the local device range. On a
    # single node with one process per GPU this equals the rank itself.
    device = torch.device("cuda", rank % torch.cuda.device_count())
    torch.cuda.set_device(device)
    return device
def main():
    """Generate one image with a context-parallel transformer and save it.

    Builds a ``ContextParallelConfig`` from ``--cp-backend``, loads the
    transformer with that config, builds the pipeline around it, selects the
    attention backend from ``--attn-backend``, runs a 50-step generation with
    a fixed seed, and writes the image to
    ``output_<cp_backend>_<attn_backend>.png`` on rank 0.

    The process group is destroyed on exit even when loading or inference
    raises, so a failing rank does not leave the group initialized.
    """
    args = parse_args()
    try:
        device = setup_distributed()
        world_size = dist.get_world_size()

        if args.cp_backend == "ring":
            cp_config = ContextParallelConfig(ring_degree=world_size)
        elif args.cp_backend == "unified":
            # NOTE(review): (world_size // 2) ** 2 equals world_size only when
            # world_size == 4 -- confirm which degree combinations
            # ContextParallelConfig accepts before running on other sizes.
            cp_config = ContextParallelConfig(ring_degree=world_size // 2, ulysses_degree=world_size // 2)
        else:
            cp_config = ContextParallelConfig(ulysses_degree=world_size)

        transformer = AutoModel.from_pretrained(
            CKPT_ID,
            subfolder="transformer",
            torch_dtype=torch.bfloat16,
            parallel_config=cp_config,
        )
        pipeline = DiffusionPipeline.from_pretrained(
            CKPT_ID,
            transformer=transformer,
            torch_dtype=torch.bfloat16,
        ).to(device)
        pipeline.transformer.set_attention_backend(args.attn_backend)

        # The prompt text is kept at column 0 so the string content is exactly
        # what the snippet shows (no indentation leaks into the prompt).
        prompt = """
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
"""

        # Every rank seeds its generator with the same constant.
        generator = torch.Generator().manual_seed(42)
        image = pipeline(
            prompt,
            guidance_scale=3.5,
            num_inference_steps=50,
            generator=generator,
        ).images[0]

        # Only rank 0 writes the output file.
        if dist.get_rank() == 0:
            image.save(f"output_{args.cp_backend}_{args.attn_backend}.png")
    finally:
        # Tear the group down on both the success and the failure path.
        if dist.is_initialized():
            dist.destroy_process_group()
# Run the script only when executed directly, not when imported.
if __name__ == "__main__":
    main()
Outputs:
| FA2 + Ulysses | FA3 + Ulysses | SAGE + Ulysses |
|---|---|---|
| _(output image not preserved)_ | _(output image not preserved)_ | _(output image not preserved)_ |
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.