DiffSynth-Studio
DiffSynth-Studio copied to clipboard
Flux full train的模型转opensource format
Flux full train的模型在comfyui中无法加载,利用diffusers的脚本也无法转成diffusers格式,这个是否能支持
@justfuwei 暂不支持
Flux full train的模型在comfyui中无法加载,利用diffusers的脚本也无法转成diffusers格式,这个是否能支持
Maybe you could try this script for Flux dev? Please let me know if this script works for you.
import torch
from safetensors.torch import save_file

from diffsynth import load_state_dict
def swap_scale_shift(weight):
    """Reorder an adaLN parameter tensor from (shift, scale) to (scale, shift).

    The input stores the two halves as [shift, scale] along dim 0; the
    open-source Flux layout expects [scale, shift].
    """
    halves = weight.chunk(2, dim=0)
    return torch.cat((halves[1], halves[0]), dim=0)
def convert_diffusers_to_flux(
    original_state_dict, num_layers, num_single_layers
):
    """Convert a DiffSynth-Studio Flux state dict into the open-source
    (black-forest-labs) Flux checkpoint key layout.

    Keys are popped from ``original_state_dict`` as they are converted, so the
    input dict is consumed in place; anything left afterwards was not mapped.

    NOTE(review): despite the function name, the *input* appears to be the
    DiffSynth/training format and the *output* the open-source format —
    confirm against the checkpoints you feed it.

    Args:
        original_state_dict: mapping of DiffSynth parameter names to tensors.
        num_layers: number of double-stream transformer blocks
            (19 in the accompanying Flux-dev example).
        num_single_layers: number of single-stream transformer blocks
            (38 in the accompanying Flux-dev example).

    Returns:
        dict mapping open-source Flux parameter names to the same tensors.

    Raises:
        KeyError: if an expected DiffSynth key is missing from the input.
    """
    converted_state_dict = {}

    def _rename(new_key, old_key):
        # Move one tensor under its open-source name.
        converted_state_dict[new_key] = original_state_dict.pop(old_key)

    def _swap_scale_shift(weight):
        # The input stores (shift, scale) along dim 0; the open-source
        # layout expects (scale, shift).
        shift, scale = weight.chunk(2, dim=0)
        return torch.cat([scale, shift], dim=0)

    # Timestep / pooled-text (and optional guidance) embedders: each is an
    # MLP whose Linear layers sit at indices 0 ("in_layer") and 2 ("out_layer").
    embedders = [
        ("time_in", "time_embedder.timestep_embedder"),
        ("vector_in", "pooled_text_embedder"),
    ]
    # Guidance embedder keys exist only in some checkpoints; detect by name.
    if any("guidance" in k for k in original_state_dict):
        embedders.append(("guidance_in", "guidance_embedder.timestep_embedder"))
    for new_prefix, old_prefix in embedders:
        for new_layer, old_index in (("in_layer", 0), ("out_layer", 2)):
            for suffix in ("weight", "bias"):
                _rename(f"{new_prefix}.{new_layer}.{suffix}",
                        f"{old_prefix}.{old_index}.{suffix}")

    # Text / image input projections.
    for suffix in ("weight", "bias"):
        _rename(f"txt_in.{suffix}", f"context_embedder.{suffix}")
        _rename(f"img_in.{suffix}", f"x_embedder.{suffix}")

    # Double-stream blocks: (new suffix, old suffix) for weight+bias modules.
    double_linear_map = [
        ("img_mod.lin", "norm1_a.linear"),
        ("txt_mod.lin", "norm1_b.linear"),
        ("img_attn.qkv", "attn.a_to_qkv"),
        ("txt_attn.qkv", "attn.b_to_qkv"),
        ("img_mlp.0", "ff_a.0"),
        ("img_mlp.2", "ff_a.2"),
        ("txt_mlp.0", "ff_b.0"),
        ("txt_mlp.2", "ff_b.2"),
        ("img_attn.proj", "attn.a_to_out"),
        ("txt_attn.proj", "attn.b_to_out"),
    ]
    # q/k norm scales: a single ".weight" tensor becomes ".scale".
    double_norm_map = [
        ("img_attn.norm.query_norm", "attn.norm_q_a"),
        ("img_attn.norm.key_norm", "attn.norm_k_a"),
        ("txt_attn.norm.query_norm", "attn.norm_q_b"),
        ("txt_attn.norm.key_norm", "attn.norm_k_b"),
    ]
    for i in range(num_layers):
        new_prefix, old_prefix = f"double_blocks.{i}.", f"blocks.{i}."
        for new_name, old_name in double_linear_map:
            for suffix in ("weight", "bias"):
                _rename(f"{new_prefix}{new_name}.{suffix}",
                        f"{old_prefix}{old_name}.{suffix}")
        for new_name, old_name in double_norm_map:
            _rename(f"{new_prefix}{new_name}.scale",
                    f"{old_prefix}{old_name}.weight")

    # Single-stream blocks: old and new prefixes coincide ("single_blocks.{i}.");
    # only the module names change.
    single_linear_map = [
        ("modulation.lin", "norm.linear"),
        ("linear1", "to_qkv_mlp"),
        ("linear2", "proj_out"),
    ]
    single_norm_map = [
        ("norm.query_norm", "norm_q_a"),
        ("norm.key_norm", "norm_k_a"),
    ]
    for i in range(num_single_layers):
        prefix = f"single_blocks.{i}."
        for new_name, old_name in single_linear_map:
            for suffix in ("weight", "bias"):
                _rename(f"{prefix}{new_name}.{suffix}",
                        f"{prefix}{old_name}.{suffix}")
        for new_name, old_name in single_norm_map:
            _rename(f"{prefix}{new_name}.scale", f"{prefix}{old_name}.weight")

    # Final projection and adaLN modulation; the modulation tensors need their
    # (shift, scale) halves swapped to match the open-source layout.
    for suffix in ("weight", "bias"):
        _rename(f"final_layer.linear.{suffix}", f"final_proj_out.{suffix}")
        converted_state_dict[f"final_layer.adaLN_modulation.1.{suffix}"] = _swap_scale_shift(
            original_state_dict.pop(f"final_norm_out.linear.{suffix}")
        )
    return converted_state_dict
def main():
    """Entry point: load a checkpoint and re-save it in the Flux layout.

    The hard-coded paths are placeholders; edit them before running.
    The layer counts (19 double, 38 single) match the Flux-dev example
    discussed in this thread.
    """
    source_weights = load_state_dict("xxxx/step-40000.safetensors")
    flux_weights = convert_diffusers_to_flux(source_weights, 19, 38)
    save_file(flux_weights, "xxx/step-40000-convert.safetensors")


if __name__ == "__main__":
    main()
Flux full train的模型在comfyui中无法加载,利用diffusers的脚本也无法转成diffusers格式,这个是否能支持
Maybe you could try this script for Flux dev? Please let me know if this script works for you.
from diffsynth import load_state_dict from safetensors.torch import save_file def swap_scale_shift(weight): shift, scale = weight.chunk(2, dim=0) new_weight = torch.cat([scale, shift], dim=0) return new_weight def convert_diffusers_to_flux( original_state_dict, num_layers, num_single_layers ): converted_state_dict = {} ## time_embedder.timestep_embedder -> time_in converted_state_dict["time_in.in_layer.weight"] = original_state_dict.pop( "time_embedder.timestep_embedder.0.weight" ) converted_state_dict[ "time_in.in_layer.bias" ] = original_state_dict.pop( "time_embedder.timestep_embedder.0.bias" ) converted_state_dict["time_in.out_layer.weight"] = original_state_dict.pop( "time_embedder.timestep_embedder.2.weight" ) converted_state_dict[ "time_in.out_layer.bias" ] = original_state_dict.pop( "time_embedder.timestep_embedder.2.bias" ) ## pooled_text_embedder -> vector_in converted_state_dict["vector_in.in_layer.weight"] = original_state_dict.pop( "pooled_text_embedder.0.weight" ) converted_state_dict["vector_in.in_layer.bias"] = original_state_dict.pop( "pooled_text_embedder.0.bias" ) converted_state_dict[ "vector_in.out_layer.weight"] = original_state_dict.pop( "pooled_text_embedder.2.weight" ) converted_state_dict["vector_in.out_layer.bias"] = original_state_dict.pop( "pooled_text_embedder.2.bias" ) # guidance has_guidance = any("guidance" in k for k in original_state_dict) if has_guidance: converted_state_dict["guidance_in.in_layer.weight"] = original_state_dict.pop( "guidance_embedder.timestep_embedder.0.weight" ) converted_state_dict["guidance_in.in_layer.bias"] = original_state_dict.pop( "guidance_embedder.timestep_embedder.0.bias" ) converted_state_dict["guidance_in.out_layer.weight"] = original_state_dict.pop( "guidance_embedder.timestep_embedder.2.weight" ) converted_state_dict["guidance_in.out_layer.bias"] = original_state_dict.pop( "guidance_embedder.timestep_embedder.2.bias" ) # context_embedder converted_state_dict["txt_in.weight"] = 
original_state_dict.pop("context_embedder.weight") converted_state_dict["txt_in.bias"] = original_state_dict.pop("context_embedder.bias") # x_embedder converted_state_dict["img_in.weight"] = original_state_dict.pop("x_embedder.weight") converted_state_dict["img_in.bias"] = original_state_dict.pop("x_embedder.bias") # double transformer blocks for i in range(num_layers): block_prefix = f"blocks.{i}." # norms. ## norm1_a converted_state_dict[f"double_blocks.{i}.img_mod.lin.weight"] = original_state_dict.pop( f"{block_prefix}norm1_a.linear.weight" ) converted_state_dict[f"double_blocks.{i}.img_mod.lin.bias"] = original_state_dict.pop( f"{block_prefix}norm1_a.linear.bias" ) ## norm1_b converted_state_dict[f"double_blocks.{i}.txt_mod.lin.weight"] = original_state_dict.pop( f"{block_prefix}norm1_b.linear.weight" ) converted_state_dict[f"double_blocks.{i}.txt_mod.lin.bias"] = original_state_dict.pop( f"{block_prefix}norm1_b.linear.bias" ) # qkv converted_state_dict[f"double_blocks.{i}.img_attn.qkv.weight"] = original_state_dict.pop( f"{block_prefix}attn.a_to_qkv.weight" ) converted_state_dict[f"double_blocks.{i}.img_attn.qkv.bias"] = original_state_dict.pop( f"{block_prefix}attn.a_to_qkv.bias" ) converted_state_dict[f"double_blocks.{i}.txt_attn.qkv.weight"] = original_state_dict.pop( f"{block_prefix}attn.b_to_qkv.weight" ) converted_state_dict[f"double_blocks.{i}.txt_attn.qkv.bias"] = original_state_dict.pop( f"{block_prefix}attn.b_to_qkv.bias" ) # qk_norm converted_state_dict[f"double_blocks.{i}.img_attn.norm.query_norm.scale"] = original_state_dict.pop( f"{block_prefix}attn.norm_q_a.weight" ) converted_state_dict[f"double_blocks.{i}.img_attn.norm.key_norm.scale"] = original_state_dict.pop( f"{block_prefix}attn.norm_k_a.weight" ) converted_state_dict[f"double_blocks.{i}.txt_attn.norm.query_norm.scale"] = original_state_dict.pop( f"{block_prefix}attn.norm_q_b.weight" ) converted_state_dict[f"double_blocks.{i}.txt_attn.norm.key_norm.scale"] = original_state_dict.pop( 
f"{block_prefix}attn.norm_k_b.weight" ) # ff converted_state_dict[f"double_blocks.{i}.img_mlp.0.weight"] = original_state_dict.pop( f"{block_prefix}ff_a.0.weight" ) converted_state_dict[f"double_blocks.{i}.img_mlp.0.bias"] = original_state_dict.pop( f"{block_prefix}ff_a.0.bias" ) converted_state_dict[f"double_blocks.{i}.img_mlp.2.weight"] = original_state_dict.pop( f"{block_prefix}ff_a.2.weight" ) converted_state_dict[f"double_blocks.{i}.img_mlp.2.bias"] = original_state_dict.pop( f"{block_prefix}ff_a.2.bias" ) converted_state_dict[f"double_blocks.{i}.txt_mlp.0.weight"] = original_state_dict.pop( f"{block_prefix}ff_b.0.weight" ) converted_state_dict[f"double_blocks.{i}.txt_mlp.0.bias"] = original_state_dict.pop( f"{block_prefix}ff_b.0.bias" ) converted_state_dict[f"double_blocks.{i}.txt_mlp.2.weight"] = original_state_dict.pop( f"{block_prefix}ff_b.2.weight" ) converted_state_dict[f"double_blocks.{i}.txt_mlp.2.bias"] = original_state_dict.pop( f"{block_prefix}ff_b.2.bias" ) # output projections. converted_state_dict[f"double_blocks.{i}.img_attn.proj.weight"] = original_state_dict.pop( f"{block_prefix}attn.a_to_out.weight" ) converted_state_dict[f"double_blocks.{i}.img_attn.proj.bias"] = original_state_dict.pop( f"{block_prefix}attn.a_to_out.bias" ) converted_state_dict[f"double_blocks.{i}.txt_attn.proj.weight"] = original_state_dict.pop( f"{block_prefix}attn.b_to_out.weight" ) converted_state_dict[f"double_blocks.{i}.txt_attn.proj.bias"] = original_state_dict.pop( f"{block_prefix}attn.b_to_out.bias" ) # single transformer blocks for i in range(num_single_layers): block_prefix = f"single_blocks.{i}." 
# norm.linear <- single_blocks.0.modulation.lin converted_state_dict[f"single_blocks.{i}.modulation.lin.weight"] = original_state_dict.pop( f"{block_prefix}norm.linear.weight" ) converted_state_dict[f"single_blocks.{i}.modulation.lin.bias"] = original_state_dict.pop( f"{block_prefix}norm.linear.bias" ) # qkv converted_state_dict[f"single_blocks.{i}.linear1.weight"] = original_state_dict.pop( f"{block_prefix}to_qkv_mlp.weight" ) converted_state_dict[f"single_blocks.{i}.linear1.bias"] = original_state_dict.pop( f"{block_prefix}to_qkv_mlp.bias" ) # qk norm converted_state_dict[ f"single_blocks.{i}.norm.query_norm.scale"] = original_state_dict.pop( f"{block_prefix}norm_q_a.weight" ) converted_state_dict[f"single_blocks.{i}.norm.key_norm.scale"] = original_state_dict.pop( f"{block_prefix}norm_k_a.weight" ) # output projections. converted_state_dict[f"single_blocks.{i}.linear2.weight"] = original_state_dict.pop( f"{block_prefix}proj_out.weight" ) converted_state_dict[f"single_blocks.{i}.linear2.bias"] = original_state_dict.pop( f"{block_prefix}proj_out.bias" ) converted_state_dict["final_layer.linear.weight"] = original_state_dict.pop("final_proj_out.weight") converted_state_dict["final_layer.linear.bias"] = original_state_dict.pop("final_proj_out.bias") converted_state_dict["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(original_state_dict.pop("final_norm_out.linear.weight")) converted_state_dict["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(original_state_dict.pop("final_norm_out.linear.bias")) # # converted_state_dict["final_norm_out.linear.weight"] = swap_scale_shift( # original_state_dict.pop("final_layer.adaLN_modulation.1.weight") # ) # converted_state_dict["final_norm_out.linear.bias"] = swap_scale_shift( # original_state_dict.pop("final_layer.adaLN_modulation.1.bias") # ) return converted_state_dict def main(): state_dict = load_state_dict("xxxx/step-40000.safetensors") convert_state_dict = convert_diffusers_to_flux(state_dict, 19, 38) 
save_file(convert_state_dict, "xxx/step-40000-convert.safetensors") if __name__ == "__main__": main()
The script works fine — thanks for the great job!
@Artiprocher 或许辛苦开发者考虑集成一下不?我看lora有一个参数叫Align_to_open_source_format?