Moving qint4 models takes a large amount of time

Open gabe56f opened this issue 1 year ago • 4 comments

Quantizing the new Flux.1 models, saving them to a file, and then loading them back in using (re)quantize takes a huge amount of time (900+ s, specs here). I looked into the requantize function, and through basic debugging I found that the .to(device) call takes most of the time (~90%). In case this was a local problem, I did a fresh install of Linux with minimal dependencies to see whether faulty drivers were to blame, but they weren't.

The following is the script I'd like to run:

import json
import torch
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from optimum.quanto import freeze, requantize, qint4, quantization_map, quantize
from safetensors.torch import load_file, save_file

dtype = torch.bfloat16
repo = "black-forest-labs/FLUX.1-schnell"
transformer = FluxTransformer2DModel.from_pretrained(repo, subfolder="transformer", torch_dtype=dtype)
quantize(transformer, weights=qint4)
freeze(transformer)
# save_file expects a dict of tensors, so pass the state dict rather than the module
save_file(transformer.state_dict(), "model.safetensors")
with open("quantization_map.json", "w") as f:
    json.dump(quantization_map(transformer), f)
del transformer

with torch.device("meta"):
    with open("config.json", "r") as f:
        transformer = FluxTransformer2DModel.from_config(json.load(f)).to(dtype)

tdict = load_file("model.safetensors")
with open("quantization_map.json", "r") as f:
    tmap = json.load(f)

requantize(transformer, tdict, tmap, device="cuda")
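
The "~90% in .to(device)" figure above comes from basic debugging. A minimal sketch of how such a per-stage breakdown could be measured is shown here; `timed`, `share`, and `timings` are hypothetical helpers, not part of optimum.quanto, and the sleeps are stand-ins for the real load and move steps.

```python
import time
from contextlib import contextmanager

# Hypothetical helper (not part of optimum.quanto): wall-clock seconds per label.
timings = {}

@contextmanager
def timed(label):
    """Accumulate the elapsed time of the wrapped block under `label`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        timings[label] = timings.get(label, 0.0) + time.perf_counter() - start

def share(label):
    """Fraction of the total recorded time spent under `label`."""
    total = sum(timings.values())
    return timings[label] / total if total else 0.0

# Stand-in workloads: in the real script these would wrap load_file(...)
# and the .to(device) call respectively.
with timed("load"):
    time.sleep(0.01)
with timed("move"):
    time.sleep(0.05)

print({k: round(v, 3) for k, v in timings.items()}, round(share("move"), 2))
```

When timing a move to a CUDA device, calling torch.cuda.synchronize() before leaving the timed block avoids under-reporting, since transfers can be asynchronous.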

gabe56f avatar Aug 06 '24 04:08 gabe56f

Hijacking this issue, this slowness also happens every time when you move the models between cpu and gpu on my end.

Edit: Slowness after the first move on my end is caused by the CPU using TinyGemmQBitsTensor and AMD / HIP using QBitsTensor. Forcing the CPU to use QBitsTensor as well eliminates the slowness.

Disty0 avatar Aug 13 '24 21:08 Disty0

@Disty0 is correct: when moving models between devices, the weights might be reformatted as expected by the optimized kernel used on that device. This happens in QBitsTensor.create(). Here, the weights are first loaded on CPU and formatted for the TinyGemm PyTorch kernels. This takes time, and is pretty useless because eventually they are put back in their original format when moving to a CUDA device. I could add a context to disable kernel optimizations, something like with quanto.disable_optimizations():, and wrap the weights loading with that context in requantize, so that the formatting only happens when moving to the target device.

dacorvo avatar Aug 14 '24 12:08 dacorvo

That is (imho) probably the best course of action; it's exactly what I ended up doing with torchao in the meantime.

gabe56f avatar Aug 14 '24 12:08 gabe56f

The following locally stored quantized objects load 100% faster than quantizing them on the fly.

#1 Do this once: quantize the transformer and the T5 encoder in Flux:

import json
from pathlib import Path

import torch
from diffusers import FluxTransformer2DModel
from transformers import T5EncoderModel
from optimum.quanto import freeze, qint8, quantize, quantization_map

base_model = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16

# Quantize the transformer weights to int8 and freeze them
transformer = FluxTransformer2DModel.from_pretrained(
    base_model,
    subfolder="transformer",
    torch_dtype=dtype,
)
quantize(transformer, weights=qint8)
freeze(transformer)

# Save the quantized weights together with their quantization map
save_directory = "./flux-dev/fluxtransformer2dmodel_qint8"
transformer.save_pretrained(save_directory)
qmap_name = Path(save_directory, "quanto_qmap.json")
qmap = quantization_map(transformer)
with open(qmap_name, "w", encoding="utf8") as f:
    json.dump(qmap, f, indent=4)
print('Transformer done')

# Repeat the same steps for the T5 text encoder
text_encoder_2 = T5EncoderModel.from_pretrained(
    base_model,
    subfolder="text_encoder_2",
    torch_dtype=dtype,
)
quantize(text_encoder_2, weights=qint8)
freeze(text_encoder_2)

save_directory = "./flux-dev/t5encodermodel_qint8"
text_encoder_2.save_pretrained(save_directory)
qmap_name = Path(save_directory, "quanto_qmap.json")
qmap = quantization_map(text_encoder_2)
with open(qmap_name, "w", encoding="utf8") as f:
    json.dump(qmap, f, indent=4)
print('T5encoder done')

Now you can load these saved models when running inference:

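import time

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel
from transformers import T5EncoderModel
from optimum.quanto import QuantizedTransformersModel, QuantizedDiffusersModel

repo_id = "black-forest-labs/FLUX.1-dev"
# Assumption: `device` was not defined in the original post; use "cpu" to take the offload branch below
device = "cuda"
start = time.perf_counter()
dtype = torch.bfloat16

# Wrapper that lets quanto reload the quantized Flux transformer
class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
    base_class = FluxTransformer2DModel

transformer = QuantizedFluxTransformer2DModel.from_pretrained(
    "./flux-dev/fluxtransformer2dmodel_qint8"
).to(dtype=dtype)

# Wrapper for the quantized T5 encoder; from_config is aliased to the private _from_config
class QuantizedT5EncoderModelForCausalLM(QuantizedTransformersModel):
    auto_class = T5EncoderModel
    auto_class.from_config = auto_class._from_config

text_encoder_2 = QuantizedT5EncoderModelForCausalLM.from_pretrained(
    "./flux-dev/t5encodermodel_qint8"
).to(dtype=dtype)

pipe = FluxPipeline.from_pretrained(repo_id, transformer=transformer, text_encoder_2=text_encoder_2, torch_dtype=dtype)

if device == 'cpu':
    pipe.enable_model_cpu_offload()
else:
    pipe = pipe.to(device)  # 'cuda'

# Report the load time (`start` was otherwise unused in the original post)
print(f"Pipeline ready in {time.perf_counter() - start:.1f}s")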

ukaprch avatar Aug 19 '24 23:08 ukaprch

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.

github-actions[bot] avatar Sep 19 '24 02:09 github-actions[bot]

This issue was closed because it has been stalled for 5 days with no activity.

github-actions[bot] avatar Sep 24 '24 02:09 github-actions[bot]