
fix torchao memory check

Open jiqing-feng opened this issue 2 months ago • 6 comments

The test pytest -rA tests/quantization/torchao/test_torchao.py::TorchAoTest::test_model_memory_usage failed with

>       assert unquantized_model_memory / quantized_model_memory >= expected_memory_saving_ratio
E       assert (1416704 / 1382912) >= 2.0
tests/quantization/torchao/test_torchao.py:512: AssertionError                                                                       

on A100. I suspect this is because the model is too small: most of the memory is consumed by CUDA kernel launches rather than the model weights. If we switch to a large model like black-forest-labs/FLUX.1-dev, the ratio becomes 24244073472 / 12473665536 = 1.9436206143278139.
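
To make the arithmetic explicit, here is a quick sketch (the per-component byte counts are the ones reported for A100 later in this thread): a fixed runtime overhead, i.e. the cuBLAS workspace plus forward activations, is added to both peaks, so the measured ratio approaches 1 for a tiny model.

def measured_ratio(bf16_weight_bytes, int8_weight_bytes, runtime_overhead_bytes):
    # Peak memory = weight bytes + a runtime overhead shared by both runs.
    return (bf16_weight_bytes + runtime_overhead_bytes) / (int8_weight_bytes + runtime_overhead_bytes)

print(measured_ratio(157184, 123393, 0))        # ~1.27: weights alone for the tiny test model
print(measured_ratio(157184, 123393, 1259520))  # ~1.02: with ~1.2 MB of runtime overhead, the assertion fails
print(24244073472 / 12473665536)                # ~1.94: FLUX.1-dev, where weights dominate the peak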

@sayakpaul, please review this PR. Thanks!

jiqing-feng · Oct 24 '25 05:10

To reproduce with black-forest-labs/FLUX.1-dev, you can run the following script:

import torch
import os
from diffusers import FluxTransformer2DModel, TorchAoConfig

torch.use_deterministic_algorithms(True)

def get_dummy_tensor_inputs(device=None, seed: int = 0):
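    # Shapes target FLUX.1-dev: hidden_states is (batch, 4096 image tokens, 64 channels) and
    # encoder_hidden_states is (batch, 48 text tokens, 4096), matching the transformer's
    # in_channels=64 / joint_attention_dim=4096 / pooled_projection_dim=768 config.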
    batch_size = 1
    num_latent_channels = 4096
    num_image_channels = 3
    height = width = 8
    sequence_length = 48
    embedding_dim = 768
    torch.manual_seed(seed)
    hidden_states = torch.randn((batch_size, num_latent_channels, height*width)).to(device, dtype=torch.bfloat16)
    torch.manual_seed(seed)
    encoder_hidden_states = torch.randn((batch_size, sequence_length, num_latent_channels)).to(
        device, dtype=torch.bfloat16
    )
    torch.manual_seed(seed)
    pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)
    torch.manual_seed(seed)
    text_ids = torch.randn((num_latent_channels, num_image_channels)).to(device, dtype=torch.bfloat16)
    torch.manual_seed(seed)
    image_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)
    timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)
    return {
        "hidden_states": hidden_states,
        "encoder_hidden_states": encoder_hidden_states,
        "pooled_projections": pooled_prompt_embeds,
        "txt_ids": text_ids,
        "img_ids": image_ids,
        "timestep": timestep,
        "guidance": timestep * 3054,
    }

@torch.no_grad()
@torch.inference_mode()
def get_memory_consumption_stat(model, inputs):
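    # Uses the module-level `device_module` (torch.cuda or torch.xpu) defined further down;
    # the measurements at the bottom of the script repeat these steps inline.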
    device_module.reset_peak_memory_stats()
    device_module.empty_cache()
    model(**inputs)
    max_mem_allocated = device_module.max_memory_allocated()
    return max_mem_allocated

torch_device = "xpu" if torch.xpu.is_available() else "cuda"
if torch_device == "cuda":
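    # ":16:8" caps the cuBLAS workspace size (deterministic cuBLAS requires :16:8 or :4096:8);
    # the workspace allocation still shows up in max_memory_allocated(), as discussed below.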
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
device_module = torch.xpu if torch.xpu.is_available() else torch.cuda
model_id = "black-forest-labs/FLUX.1-dev"
print(f"max allocated memory before loading: {device_module.max_memory_allocated()}")
inputs = get_dummy_tensor_inputs(device=torch_device)
print(f"max allocated memory after get inputs: {device_module.max_memory_allocated()}")
transformer = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer", quantization_config=None, torch_dtype=torch.bfloat16).to(torch_device)
print(f"max allocated memory after get bf16 model: {device_module.max_memory_allocated()}")

with torch.no_grad(), torch.inference_mode():
    transformer(**inputs)
print(f"max allocated memory after bf16 model inference: {device_module.max_memory_allocated()}")

del transformer
device_module.reset_peak_memory_stats()
device_module.empty_cache()

transformer = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer", quantization_config=TorchAoConfig("int8_weight_only"), torch_dtype=torch.bfloat16).to(torch_device)
print(f"max allocated memory after get int8 model: {device_module.max_memory_allocated()}")

with torch.no_grad(), torch.inference_mode():
    transformer(**inputs)
print(f"max allocated memory after int8 model inference: {device_module.max_memory_allocated()}")

jiqing-feng · Oct 24 '25 05:10

I just ran the test on an H100 and it worked fine.

sayakpaul · Oct 24 '25 07:10

> I just ran the test on an H100 and it worked fine.

It seems like a device-related issue. Can we switch to a larger model so other devices also pass? Or change the ratio as in this PR. We want the test to pass on XPU and A100.

jiqing-feng · Oct 24 '25 07:10

> I just ran the test on an H100 and it worked fine.

> It seems like a device-related issue. Can we switch to a larger model so other devices also pass? Or change the ratio as in this PR. We want the test to pass on XPU and A100.

Hi @sayakpaul, could you share the ratio on H100?

jiqing-feng · Oct 24 '25 07:10

@sayakpaul @DN6 I'd like to look into this case again, since we need to extend it to XPU as well. I also tested it on A100 and it fails there too. Here is the detailed memory breakdown on A100:

memory (B)              bf16      int8wo
pure model              157184    123393
cuBLAS workspace size   1179648   131072
pure forward            79872     79872

Since bf16 runs before int8wo in this test and the cuBLAS workspace is not explicitly freed via torch._C._cuda_clearCublasWorkspaces, the int8wo measurement also includes the larger workspace, so the case fails: bf16 peaks at 1416704 B (157184 + 1179648 + 79872) vs int8wo at 1382913 B (123393 + 1179648 + 79872), and 1416704 / 1382913 < 2.0. I do not have an H100 environment; I guess it passes on H100 mainly because the cuBLAS implementation there is different?
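
A minimal sketch of how the two runs could be isolated on CUDA, assuming the private helper torch._C._cuda_clearCublasWorkspaces mentioned above is available in the installed PyTorch build:

import torch

def reset_cuda_memory_state():
    # Drop the cuBLAS workspaces left over from the previous run so they are not
    # attributed to the next model's peak, then clear the allocator and peak stats.
    if hasattr(torch._C, "_cuda_clearCublasWorkspaces"):
        torch._C._cuda_clearCublasWorkspaces()  # private API, see discussion above
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

Called between the bf16 and int8wo measurements, each max_memory_allocated() reading would then start from a clean slate.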

On XPU there is no cuBLAS workspace; the memory breakdown is:

memory (B)     bf16     int8wo
pure model     158208   124416
pure forward   88064    88064

This is also reasonable: with weight-only quantization, the memory increase during the forward pass should be the same for bf16 and int8wo.

So, to make this test work on different hardware, could we adjust the ratio, or measure only the model memory instead of model memory plus runtime memory?
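
A minimal sketch of the model-only variant, assuming we read torch.cuda.memory_allocated (or the XPU equivalent) right after moving each model to the device and before any forward pass:

import torch

def model_only_memory(load_fn, device_module=torch.cuda):
    # Measure only the bytes the weights occupy on the device, excluding the
    # cuBLAS workspace and forward-pass activations discussed above.
    device_module.empty_cache()
    device_module.reset_peak_memory_stats()
    before = device_module.memory_allocated()
    model = load_fn()  # e.g. lambda: FluxTransformer2DModel.from_pretrained(...).to(device)
    after = device_module.memory_allocated()
    return model, after - before

The test could then compare the bf16 and int8wo byte counts from two such calls against the expected saving ratio, independent of runtime memory.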

sywangyi · Dec 01 '25 06:12

We could adjust the ratio, depending on the hardware type.
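
One possible shape for that, as an illustrative sketch only (the placeholder values below are not the ones in the PR, and keying on the backend alone would not distinguish A100 from H100, so the real change may need a finer-grained key):

import torch

# Placeholder thresholds keyed by backend; illustrative only.
EXPECTED_MEMORY_SAVING_RATIO = {
    "cuda": 2.0,  # the current test already passes on H100 per the discussion above
    "xpu": 1.0,   # placeholder for backends without the cuBLAS workspace overhead
}

backend = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cuda"
expected_memory_saving_ratio = EXPECTED_MEMORY_SAVING_RATIO.get(backend, 2.0)
# assert unquantized_model_memory / quantized_model_memory >= expected_memory_saving_ratio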

sayakpaul · Dec 01 '25 13:12

https://github.com/huggingface/diffusers/pull/12768, please help review. Thanks!

sywangyi · Dec 17 '25 07:12