Cannot use the diffusers pipeline for Flux inference with a fully finetuned model
Using:

python flux_train.py \
  --mixed_precision bf16 \
  --pretrained_model_name_or_path /data/flux1/flux1-dev.safetensors \
  --clip_l /data/openai/clip-vit-large-patch14/clip_l.safetensors \
  --t5xxl /data/google/t5-v1_1-xxl/t5xxl_fp16.safetensors \
  --ae /data/flux1/ae.safetensors \
  --save_model_as safetensors \
  --save_every_n_epochs 5 \
  --sdpa \
  --persistent_data_loader_workers \
  --max_data_loader_n_workers 2 \
  --seed 42 \
  --gradient_checkpointing \
  --mixed_precision bf16 \
  --save_precision bf16 \
  --dataset_config ./dataset_config/0715.toml \
  --output_dir /data/flux_model_hand_finetuned/0715 \
  --output_name sample \
  --learning_rate 5e-5 \
  --max_train_epochs 20 \
  --sdpa \
  --highvram \
  --sample_prompts /data/hand_finetuning_dataset/sample_prompts.txt \
  --sample_every_n_epoch 1 \
  --cache_text_encoder_outputs_to_disk \
  --cache_latents_to_disk \
  --optimizer_type adafactor \
  --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" \
  --lr_scheduler constant_with_warmup \
  --max_grad_norm 0.0 \
  --timestep_sampling shift \
  --discrete_flow_shift 3.1582 \
  --model_prediction_type raw \
  --guidance_scale 1.0 \
  --fused_backward_pass \
  --blocks_to_swap 8 \
  --full_bf16
I got sample.safetensors, but it cannot be loaded by the diffusers pipeline because some state dict keys do not match:
from diffusers import FluxPipeline, FluxTransformer2DModel
from safetensors.torch import load_file

# instantiate the transformer with the FLUX.1-dev configuration
finetuned_transformer = FluxTransformer2DModel(
    patch_size=1,
    in_channels=64,
    out_channels=64,
    num_layers=19,
    num_single_layers=38,
    attention_head_dim=128,
    num_attention_heads=24,
    joint_attention_dim=4096,
    pooled_projection_dim=768,
    guidance_embeds=True,
)

# load the checkpoint produced by flux_train.py
state_dict = load_file("/data/flux_model_hand_finetuned/0715/sample/sample.safetensors")
finetuned_transformer.load_state_dict(state_dict)
I get the following error:
Traceback (most recent call last):
  File "/home/zyf/repos/sd-scripts-sd3/inference_test_hands.py", line 27, in <module>
    finetuned_transformer.load_state_dict(state_dict)
  File "/home/zyf/miniconda3/envs/flux/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2584, in load_state_dict
    raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for FluxTransformer2DModel:
    Missing key(s) in state_dict: "time_text_embed.timestep_embedder.linear_1.weight", "time_text_embed.timestep_embedder.linear_1.bias", "time_text_embed.timestep_embedder.linear_2.weight", "time_text_embed.timestep_embedder.linear_2.bias", "time_text_embed.guidance_embedder.linear_1.weight", "time_text_embed.guidance_embedder.linear_1.bias", "time_text_embed.guidance_embedder.linear_2.weight", "time_text_embed.guidance_embedder.linear_2.bias", "time_text_embed.text_embedder.linear_1.weight", "time_text_embed.text_embedder.linear_1.bias", ......
    Unexpected key(s) in state_dict: "double_blocks.0.img_attn.norm.key_norm.scale", "double_blocks.0.img_attn.norm.query_norm.scale", "double_blocks.0.img_attn.proj.bias", "double_blocks.0.img_attn.proj.weight", "double_blocks.0.img_attn.qkv.bias", "double_blocks.0.img_attn.qkv.weight", "double_blocks.0.img_mlp.0.bias", ......
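Since the unexpected keys are "double_blocks.*" / "single_blocks.*", I assume the checkpoint saved by flux_train.py keeps the original BFL key layout, while FluxTransformer2DModel.load_state_dict() expects the diffusers layout ("time_text_embed.*", "transformer_blocks.*", ...). Is the intended way to let diffusers convert the checkpoint, roughly like the sketch below? This is untested on my side; using FluxTransformer2DModel.from_single_file for this and the "black-forest-labs/FLUX.1-dev" repo id are my assumptions for illustration.

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel

# Load the sd-scripts output (original "double_blocks.*" / "single_blocks.*" layout)
# and let diffusers map it to its own key layout. Not confirmed to work for
# checkpoints produced by flux_train.py; the path is from my setup above.
transformer = FluxTransformer2DModel.from_single_file(
    "/data/flux_model_hand_finetuned/0715/sample/sample.safetensors",
    torch_dtype=torch.bfloat16,
)

# Reuse the rest of the FLUX.1-dev pipeline and swap in the finetuned transformer.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)

Or is there a conversion step I should run on sample.safetensors first before loading it with diffusers?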