Potential LoRA performance issue
Thought I'd raise this in case there's an issue.
Steps to reproduce:
1 - create FLUX.1 LoRA fine-tune at Replicate
2 - generate images using Replicate's FLUX.1[dev] API using the FLUX.1 LoRA
3 - result: perfect face matching to LoRA training
4 - load same LoRA using LoRA load techniques in README.md for flux-fp8-api
5 - generate image with same prompt and parameters using flux-fp8-api
6 - result: very poor face matching to LoRA training
Any suggestions for what I might try?
Here's how I implemented LoRA loading in main.py:
--- main.orig.py 2024-08-29 15:34:42.612578339 +0200
+++ main.py 2024-08-29 14:53:52.603088816 +0200
@@ -2,9 +2,9 @@
import uvicorn
from api import app
-
def parse_args():
parser = argparse.ArgumentParser(description="Launch Flux API server")
+ # Existing arguments...
parser.add_argument(
"-c",
"--config-path",
@@ -145,9 +145,17 @@
dest="quantize_flow_embedder_layers",
help="Quantize the flow embedder layers in the flow model, saves ~512MB vram usage, but precision loss is very noticeable",
)
+
+ # New arguments for LoRA loading
+ parser.add_argument(
+ "-L", "--lora-paths", type=str, help="Comma-separated paths to LoRA checkpoint files"
+ )
+ parser.add_argument(
+ "-S", "--lora-scales", type=str, default="1.0", help="Comma-separated scales for each LoRA"
+ )
+
return parser.parse_args()
-
def main():
args = parse_args()
@@ -192,8 +200,16 @@
)
app.state.model = FluxPipeline.load_pipeline_from_config(config)
- uvicorn.run(app, host=args.host, port=args.port)
+ # If LoRA paths are provided, apply them sequentially
+ if args.lora_paths:
+ lora_paths = args.lora_paths.split(',')
+ lora_scales = [float(scale) for scale in args.lora_scales.split(',')] if args.lora_scales else [1.0] * len(lora_paths)
+
+ # Apply each LoRA sequentially
+ for lora_path, scale in zip(lora_paths, lora_scales):
+ app.state.model.load_lora(lora_path, scale=scale)
+ uvicorn.run(app, host=args.host, port=args.port)
if __name__ == "__main__":
main()
And I call main.py like this:
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python main.py --config-path configs/config-dev-offload-1-4080.json --port 7888 --host <IP> --lora-scale 1 --lora-path /root/flux-fp8-api.working/models/lora.safetensors
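In case it's relevant, here's a quick sketch I can run to inspect the LoRA file's key naming, since diffusers-style and kohya-style checkpoints use different key prefixes and store alpha differently (the path matches my command above; nothing repo-specific is assumed beyond safetensors being installed):

# Inspect the LoRA checkpoint's key layout to see which naming convention it uses.
from collections import Counter
from safetensors import safe_open

lora_file = "/root/flux-fp8-api.working/models/lora.safetensors"

with safe_open(lora_file, framework="pt", device="cpu") as f:
    keys = list(f.keys())

print("total keys:", len(keys))
print("most common prefixes:", Counter(k.split(".")[0] for k in keys).most_common(10))
print("alpha keys (if any):", [k for k in keys if "alpha" in k][:5])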
Ah that's interesting. Are you using the latest code? There was a bug earlier where it was always setting the lora alpha to 1.0 for huggingface diffusers loras. Though it could be something else.
Thanks for the reply! Yes, I'm using the latest code.
It's possible that there are some lora loading specifics that I didn't implement well, but I'm not entirely sure what those would be. I will have to look into other lora implementations.
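For reference, the merge rule most LoRA implementations converge on is W' = W + scale * (alpha / rank) * (B @ A), and the alpha/rank factor is usually where they diverge. A minimal sketch of that rule (not the actual code in this repo or in diffusers):

import torch

def merge_lora_delta(weight: torch.Tensor,
                     lora_down: torch.Tensor,   # "A": (rank, in_features)
                     lora_up: torch.Tensor,     # "B": (out_features, rank)
                     alpha: float,
                     scale: float = 1.0) -> torch.Tensor:
    # Standard LoRA merge: W' = W + scale * (alpha / rank) * (B @ A).
    # If alpha is silently defaulted to 1.0 instead of rank, the effective
    # strength drops by a factor of rank, which would look like a "weak" LoRA.
    rank = lora_down.shape[0]
    delta = (lora_up.float() @ lora_down.float()) * (alpha / rank) * scale
    return (weight.float() + delta).to(weight.dtype)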
Thanks for your amazing work. I also had some issues with the LoRAs.
- On a 4090, the LoRAs load correctly and work, but I was getting a different and weaker effect than with diffusers. I had to increase the weight to around 2.0 to get better results, but I noticed prompt adherence quality decreased with the bigger LoRA weights.
- On an H100, the LoRAs load without error, but I kept getting black images with txt2img. I was able to generate with init_image, but the same LoRA weight issue was present there too.
Here's the config I used for the 4090:
{
"version": "flux-dev",
"params": {
"in_channels": 64,
"vec_in_dim": 768,
"context_in_dim": 4096,
"hidden_size": 3072,
"mlp_ratio": 4.0,
"num_heads": 24,
"depth": 19,
"depth_single_blocks": 38,
"axes_dim": [
16,
56,
56
],
"theta": 10000,
"qkv_bias": true,
"guidance_embed": true
},
"ae_params": {
"resolution": 256,
"in_channels": 3,
"ch": 128,
"out_ch": 3,
"ch_mult": [
1,
2,
4,
4
],
"num_res_blocks": 2,
"z_channels": 16,
"scale_factor": 0.3611,
"shift_factor": 0.1159
},
"ckpt_path": "flux1-dev.safetensors",
"ae_path": "ae.safetensors",
"repo_id": "black-forest-labs/FLUX.1-dev",
"repo_flow": "flux1-dev.safetensors",
"repo_ae": "ae.safetensors",
"text_enc_max_length": 512,
"text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
"text_enc_device": "cuda:0",
"ae_device": "cuda:0",
"flux_device": "cuda:0",
"flow_dtype": "float16",
"ae_dtype": "bfloat16",
"text_enc_dtype": "bfloat16",
"flow_quantization_dtype": "qfloat8",
"text_enc_quantization_dtype": "qint4",
"ae_quantization_dtype": "qfloat8",
"compile_extras": true,
"compile_blocks": true,
"offload_text_encoder": true,
"offload_vae": true,
"offload_flow": false
}
And here's the config I used for the H100:
{
"version": "flux-dev",
"params": {
"in_channels": 64,
"vec_in_dim": 768,
"context_in_dim": 4096,
"hidden_size": 3072,
"mlp_ratio": 4.0,
"num_heads": 24,
"depth": 19,
"depth_single_blocks": 38,
"axes_dim": [
16,
56,
56
],
"theta": 10000,
"qkv_bias": true,
"guidance_embed": true
},
"ae_params": {
"resolution": 256,
"in_channels": 3,
"ch": 128,
"out_ch": 3,
"ch_mult": [
1,
2,
4,
4
],
"num_res_blocks": 2,
"z_channels": 16,
"scale_factor": 0.3611,
"shift_factor": 0.1159
},
"ckpt_path": "flux1-dev.safetensors",
"ae_path": "ae.safetensors",
"repo_id": "black-forest-labs/FLUX.1-dev",
"repo_flow": "flux1-dev.safetensors",
"repo_ae": "ae.safetensors",
"text_enc_max_length": 512,
"text_enc_path": "city96/t5-v1_1-xxl-encoder-bf16",
"text_enc_device": "cuda:0",
"ae_device": "cuda:0",
"flux_device": "cuda:0",
"flow_dtype": "float16",
"ae_dtype": "bfloat16",
"text_enc_dtype": "bfloat16",
"flow_quantization_dtype": "qfloat8",
"text_enc_quantization_dtype": "qint4",
"ae_quantization_dtype": "qfloat8",
"compile_extras": true,
"compile_blocks": true,
"offload_text_encoder": false,
"offload_vae": false,
"offload_flow": false
}
Hope it's helpful, thanks.
If you're getting black images, I would recommend setting flow_dtype to bfloat16; it should help a bit. I'm still a bit unsure how I'm supposed to handle lora alphas when they're not given in a lora's state dict, since I believe different trainers use different values and I have no idea which trainer uses which default haha.. Sorry 😢
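To make the ambiguity concrete: it all comes down to the alpha/rank multiplier applied to (B @ A). One heuristic would be to treat a missing alpha as alpha == rank (i.e. a multiplier of 1.0) and leave the rest to the user-supplied scale. Just a sketch of the idea, with assumed key names, not necessarily what lora_loading.py does today:

def effective_scale(lora_sd: dict, down_key: str, alpha_key: str, user_scale: float = 1.0) -> float:
    # Effective multiplier applied to (B @ A): user_scale * alpha / rank.
    # Heuristic: if the checkpoint stores no alpha, assume alpha == rank
    # (multiplier 1.0) rather than defaulting alpha to 1.0.
    rank = lora_sd[down_key].shape[0]
    alpha = float(lora_sd[alpha_key]) if alpha_key in lora_sd else float(rank)
    return user_scale * alpha / rank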
How can I help? Maybe I can't, but thought I'd offer.
Thanks 😄 - well, if you find anything off in my lora loading implementation here https://github.com/aredden/flux-fp8-api/blob/main/lora_loading.py, let me know and I'll change it, or you can submit a pull request and I'll look it over. Up to you 😄