
How to save a "prequantized_flow" safetensor?

Open smuelpeng opened this issue 1 year ago • 5 comments

Hello,

The documentation mentions that the --prequantized-flow option can be used to load a prequantized model, which reduces the checkpoint size by about 50% and shortens the startup time (default: False).

However, I couldn’t find any interface in the repository to enable this functionality. Could you please provide guidance on how to store and load a prequantized model to save resources and initialization time?

Looking forward to your response, thank you!

smuelpeng avatar Sep 09 '24 09:09 smuelpeng

Ah! Essentially it's just the checkpoint that gets created after loading the model and running at least 12 steps of inference. You could do something like this in the root of the repo:

from flux_pipeline import FluxPipeline, ModelVersion
from safetensors.torch import save_file

prompt = "some prompt"
pipe = FluxPipeline.load_pipeline_from_config_path("./configs/your-config.json")

# Warm-up inference triggers the fp8 quantization of the flow model's weights.
# Schnell uses a few short 4-step runs; other versions need at least 12 steps.
if pipe.config.version == ModelVersion.flux_schnell:
    for x in range(3):
        pipe.generate(prompt=prompt, num_steps=4)
else:
    pipe.generate(prompt=prompt, num_steps=12)

# The state dict now holds the quantized weights; save it for reuse.
quantized_state_dict = pipe.model.state_dict()

save_file(quantized_state_dict, "some-model-prequantized.safetensors")
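
To actually use the saved file on later runs, your config needs to point at the prequantized checkpoint and enable the prequantized-flow behavior. Below is a rough sketch of one way to do that; note that the key names ("ckpt_path", "prequantized_flow") are assumptions for illustration, so check the example configs in the repo for the real field names.

import json
from flux_pipeline import FluxPipeline

# Point a copy of the config at the prequantized file. "ckpt_path" and
# "prequantized_flow" are placeholder key names -- verify against the repo's
# example configs before relying on them.
with open("./configs/your-config.json") as f:
    cfg = json.load(f)
cfg["ckpt_path"] = "some-model-prequantized.safetensors"
cfg["prequantized_flow"] = True
with open("./configs/your-config-prequantized.json", "w") as f:
    json.dump(cfg, f, indent=2)

# Later runs can then load the already-quantized weights and skip the warm-up inference.
pipe = FluxPipeline.load_pipeline_from_config_path("./configs/your-config-prequantized.json")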

aredden avatar Sep 10 '24 00:09 aredden

Thank you for your helpful response. The solution works well for loading pre-quantized safetensors files.

However, do you have any suggestions for saving and loading a torch-compiled Flux model? Currently, the time spent compiling the Flux model at initialization is quite long, and I'm looking for ways to streamline this process.

smuelpeng avatar Sep 12 '24 09:09 smuelpeng

Ah, you can speed that up by using nightly torch; for me, compilation only takes a few seconds (maybe 3-4) at most.
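
As far as I know there isn't a straightforward way to serialize the compiled module itself, but persisting torch's inductor cache between runs can cut warm-up time on later launches. A minimal sketch is below, assuming a torch version recent enough to support the FX graph cache; the environment variables are real inductor settings, but whether they help here is an assumption worth verifying.

import os

# Must be set before torch (and anything that imports it) is loaded.
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"            # cache compiled FX graphs on disk
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "./inductor_cache"  # keep the cache next to the repo

from flux_pipeline import FluxPipeline

pipe = FluxPipeline.load_pipeline_from_config_path("./configs/your-config.json")
# The first run still compiles; subsequent launches reuse the on-disk cache.
pipe.generate(prompt="warm-up prompt", num_steps=4)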

aredden avatar Sep 13 '24 16:09 aredden

I appreciate your amazing work!

For me, torch-nightly takes 9-18 seconds per inference on the first 3 warm-up inferences, and stable torch takes 1-1.5 minutes per inference on the first 3 inferences.

Am I missing something?

Muawizodux avatar Sep 24 '24 17:09 Muawizodux

That seems correct; it's possible it's just related to the CPU. I have a 7950X, so everything runs very fast.

aredden avatar Sep 25 '24 21:09 aredden