
[core] Handle progress bar and logging in distributed environments

Open · sayakpaul opened this issue 1 month ago · 1 comment

What does this PR do?

We should handle logging (including progress bars) gracefully when operating under distributed setups, so that they are emitted once rather than duplicated by every rank.
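The usual pattern for this (a minimal sketch of the general idea, not necessarily how this PR implements it; `is_main_process` is a hypothetical helper) is to gate progress bars and log output on the process rank, which launchers such as `torchrun` expose via the `RANK` environment variable:

```python
import os

def is_main_process() -> bool:
    # torchrun and accelerate export RANK for every worker;
    # a plain single-process run has no RANK, so default to rank 0.
    return int(os.environ.get("RANK", "0")) == 0

# Example: only rank 0 keeps its progress bars and info-level logs.
show_progress = is_main_process()
```

Anything noisy (tqdm bars, warnings, info logs) can then be disabled when `is_main_process()` is false.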

Before this PR:
Loading pipeline components...:   0%|                                                                                | 0/7 [00:00<?, ?it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading pipeline components...:  29%|████████████████████▌                                                   | 2/7 [00:00<00:00,  6.76it/s]`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 59.58it/s]
Loading pipeline components...:  57%|█████████████████████████████████████████▏                              | 4/7 [00:00<00:00,  9.45it/s]`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 46.15it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 60.59it/s]
Loading pipeline components...:  86%|█████████████████████████████████████████████████████████████▋          | 6/7 [00:00<00:00, 10.19it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 48.16it/s]
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.95it/s]
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.48it/s]
Attention backends are an experimental feature and the API may be subject to change.
`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning.
Attention backends are an experimental feature and the API may be subject to change.
`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:07<00:00,  6.64it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:07<00:00,  6.63it/s]
With this PR:
Loading pipeline components...:   0%|                                                                                | 0/7 [00:00<?, ?it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading pipeline components...:  14%|██████████▎                                                             | 1/7 [00:00<00:02,  2.34it/s]`torch_dtype` is deprecated! Use `dtype` instead!
Loading pipeline components...:  43%|██████████████████████████████▊                                         | 3/7 [00:00<00:00,  5.90it/s]`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 61.48it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 62.15it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 47.45it/s]
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.62it/s]
Attention backends are an experimental feature and the API may be subject to change.
`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:07<00:00,  6.64it/s]

Notice that there's still:

Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 61.48it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 62.15it/s]

These bars come from loading the T5 checkpoint; we cannot control that output from diffusers because it is emitted by transformers.
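Until that's handled upstream, a user-side workaround is possible: transformers routes its warnings through the standard `logging` module under the `transformers` logger name, so non-zero ranks can raise that logger's level themselves. A sketch of the idea (the `quiet_non_main_ranks` helper is illustrative, not part of this PR):

```python
import logging
import os

def quiet_non_main_ranks(logger_name: str = "transformers") -> None:
    # Hypothetical helper: on ranks other than 0, raise the library's
    # logger to ERROR so duplicated warnings are suppressed.
    # Note this only affects log records; the checkpoint-shard tqdm bars
    # would need transformers' own progress-bar controls to disable.
    if int(os.environ.get("RANK", "0")) != 0:
        logging.getLogger(logger_name).setLevel(logging.ERROR)

quiet_non_main_ranks()
```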

Test Code
import torch
from torch import distributed as dist
from diffusers import DiffusionPipeline, ContextParallelConfig

def setup_distributed():
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    device = torch.device(f"cuda:{dist.get_rank()}")
    torch.cuda.set_device(device)
    return device

device = setup_distributed()

ulysses_degree = dist.get_world_size()
pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16,
).to(device)
pipeline.transformer.set_attention_backend("_native_cudnn")
pipeline.transformer.enable_parallelism(
    config=ContextParallelConfig(ulysses_degree=ulysses_degree)
)

prompt = """
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
"""

generator = torch.Generator().manual_seed(42)
image = pipeline(prompt, guidance_scale=3.5, num_inference_steps=50, generator=generator).images[0]

if dist.get_rank() == 0:
    image.save("output_ulysses.png")
if dist.is_initialized():
    dist.destroy_process_group()

sayakpaul · Dec 08 '25 13:12
