Fix meta tensor error with bitsandbytes quantization and device_map

What does this PR do?

Fixes #12719

This PR fixes a critical issue where using bitsandbytes quantization with device_map='balanced' (or other device_map strategies) on transformers models within diffusers pipelines results in a meta tensor error: NotImplementedError: Cannot copy out of meta tensor; no data!

Root Cause

When loading transformers models with both:

  • quantization_config (bitsandbytes 4-bit/8-bit)
  • device_map (especially 'balanced' for multi-GPU)

The combination of low_cpu_mem_usage=True (the default) and device_map causes transformers to use meta tensors for memory-efficient loading. However, bitsandbytes quantization state objects cannot be materialized from the meta device.

The error occurs because:

  1. With low_cpu_mem_usage=True and a device_map, transformers uses meta tensors as placeholders
  2. During quantization, bitsandbytes creates its quantization state (code, absmax tensors) on the meta device
  3. accelerate's AlignDevicesHook then tries to move parameters to their target devices via quant_state.to(device)
  4. The quantization state's tensors are still on the meta device and cannot be copied or moved, as the snippet below reproduces in isolation
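
The failing behavior can be reproduced outside diffusers entirely. A meta tensor is a shape-and-dtype placeholder with no backing storage, so any attempt to copy it to a real device raises the same error:

import torch

# A meta tensor carries shape and dtype but no data.
t = torch.empty(4, 4, device='meta')
print(t.device)  # meta

# Copying out of a meta tensor fails -- there is nothing to copy.
t.to('cpu')  # NotImplementedError: Cannot copy out of meta tensor; no data!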

Solution

Disable low_cpu_mem_usage when loading transformers models with bitsandbytes quantization (llm_int8, fp4, nf4) and a device_map. This ensures tensors are materialized during loading rather than kept as meta placeholders, allowing the quantization state to be moved to its target devices.

Changes

  • Modified _load_sub_model in pipeline_loading_utils.py to detect bitsandbytes quantization + device_map combinations (see the sketch after this list)
  • Added logic to set low_cpu_mem_usage=False for these cases
  • Added informative logging when this workaround is applied
  • Added comprehensive documentation explaining the issue
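
The check is conceptually along these lines. This is only a rough sketch: the helper name, attribute checks, and log message are illustrative, not the exact diff.

import logging

logger = logging.getLogger(__name__)

def _maybe_disable_low_cpu_mem_usage(loading_kwargs, quantization_config, device_map):
    # Illustrative only: detect bitsandbytes quantization (4-bit or 8-bit)
    # combined with a device_map, and force tensors to be fully materialized.
    is_bnb = getattr(quantization_config, 'load_in_4bit', False) or getattr(
        quantization_config, 'load_in_8bit', False
    )
    if is_bnb and device_map is not None:
        loading_kwargs['low_cpu_mem_usage'] = False
        logger.info(
            'bitsandbytes quantization with a device_map detected; setting '
            'low_cpu_mem_usage=False so the quantization state is not created '
            'on the meta device.'
        )
    return loading_kwargs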

Testing

This fix allows the exact code from issue #12719 to work correctly:

import diffusers, torch
qwen = diffusers.QwenImagePipeline.from_pretrained(
    'Qwen/Qwen-Image',
    quantization_config=diffusers.PipelineQuantizationConfig(
        quant_backend='bitsandbytes_4bit',
        quant_kwargs={'load_in_4bit': True, 'bnb_4bit_quant_type': 'nf4', 'bnb_4bit_compute_dtype': torch.float16},
        components_to_quantize=['transformer', 'text_encoder']
    ),
    torch_dtype=torch.float16,
    device_map='balanced'
)
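
Once the pipeline loads, a quick smoke test confirms inference runs end to end. The prompt and step count below are arbitrary; the .images output attribute follows the usual diffusers pipeline convention.

# Arbitrary smoke test: any prompt and settings will do.
image = qwen(prompt='a lighthouse at dawn', num_inference_steps=20).images[0]
image.save('smoke_test.png')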

Impact

  • ✅ Enables multi-GPU quantized inference with device_map strategies
  • ✅ Maintains backward compatibility (only affects bitsandbytes + device_map case)
  • ✅ No performance regression for other quantization methods
  • ⚠️ Slightly higher memory usage during loading for affected cases (necessary tradeoff)

cc @yiyixuxu @DN6
