DiffSynth-Studio icon indicating copy to clipboard operation
DiffSynth-Studio copied to clipboard

Flux full train的模型转opensource format

Open justfuwei opened this issue 5 months ago • 4 comments

Flux full train的模型在comfyui中无法加载,利用diffusers的脚本也无法转成diffusers格式,这个是否能支持

justfuwei avatar Sep 25 '25 02:09 justfuwei

@justfuwei 暂不支持

Artiprocher avatar Sep 25 '25 07:09 Artiprocher

Flux full train的模型在comfyui中无法加载,利用diffusers的脚本也无法转成diffusers格式,这个是否能支持

Maybe you could try this script for Flux dev? Please let me know if this script works for you.

from diffsynth import load_state_dict
from safetensors.torch import save_file

def swap_scale_shift(weight):
    """Reorder a fused adaLN modulation tensor from (shift, scale) to (scale, shift).

    DiffSynth's ``final_norm_out.linear`` stores the two modulation halves as
    (shift, scale) along dim 0, while the open-source Flux ``final_layer``
    expects (scale, shift); this swaps the two halves.

    Args:
        weight: tensor whose dim-0 length is even; first half is shift,
            second half is scale.

    Returns:
        A new tensor with the two halves concatenated in (scale, shift) order.
    """
    import torch  # fix: torch was never imported in the original script

    shift, scale = weight.chunk(2, dim=0)
    return torch.cat([scale, shift], dim=0)

def convert_diffusers_to_flux(
    original_state_dict, num_layers, num_single_layers
):
    """Convert a DiffSynth-Studio Flux DiT state dict to the open-source
    (black-forest-labs) Flux checkpoint layout.

    NOTE(review): despite the name, the source layout handled here is
    DiffSynth's (``blocks.N.*`` / ``single_blocks.N.*``), not diffusers'.
    The name is kept for backward compatibility with existing callers.

    Args:
        original_state_dict: mapping of parameter name -> tensor in the
            DiffSynth layout. Consumed destructively: every converted entry
            is popped, so a non-empty remainder afterwards reveals keys that
            were not mapped.
        num_layers: number of double-stream transformer blocks (19 for Flux dev).
        num_single_layers: number of single-stream blocks (38 for Flux dev).

    Returns:
        dict mapping open-source Flux parameter names to the same tensors
        (the final adaLN modulation tensors are re-ordered, see below).

    Raises:
        KeyError: if an expected source key is missing.
    """
    import torch  # fix: torch was never imported in the original script

    def _swap_scale_shift(weight):
        # DiffSynth stores (shift, scale); open-source Flux expects (scale, shift).
        shift, scale = weight.chunk(2, dim=0)
        return torch.cat([scale, shift], dim=0)

    converted_state_dict = {}

    def _move(dst, src, transform=None):
        # Pop so that leftovers in original_state_dict expose unconverted keys.
        tensor = original_state_dict.pop(src)
        converted_state_dict[dst] = transform(tensor) if transform else tensor

    # Scalar-conditioning embedders: Linear(0)/act(1)/Linear(2) stacks.
    embedders = [
        ("time_in", "time_embedder.timestep_embedder"),
        ("vector_in", "pooled_text_embedder"),
    ]
    # Guidance embedder only exists in guidance-distilled checkpoints (e.g. Flux dev).
    if any("guidance" in key for key in original_state_dict):
        embedders.append(("guidance_in", "guidance_embedder.timestep_embedder"))
    for dst, src in embedders:
        for layer, index in (("in_layer", 0), ("out_layer", 2)):
            for suffix in ("weight", "bias"):
                _move(f"{dst}.{layer}.{suffix}", f"{src}.{index}.{suffix}")

    # Token embedders.
    for suffix in ("weight", "bias"):
        _move(f"txt_in.{suffix}", f"context_embedder.{suffix}")
        _move(f"img_in.{suffix}", f"x_embedder.{suffix}")

    # Double-stream blocks: (open-source sub-key, DiffSynth sub-key),
    # each moved with both .weight and .bias.
    double_pairs = [
        ("img_mod.lin", "norm1_a.linear"),
        ("txt_mod.lin", "norm1_b.linear"),
        ("img_attn.qkv", "attn.a_to_qkv"),
        ("txt_attn.qkv", "attn.b_to_qkv"),
        ("img_mlp.0", "ff_a.0"),
        ("img_mlp.2", "ff_a.2"),
        ("txt_mlp.0", "ff_b.0"),
        ("txt_mlp.2", "ff_b.2"),
        ("img_attn.proj", "attn.a_to_out"),
        ("txt_attn.proj", "attn.b_to_out"),
    ]
    # QK norm gains: a single ".weight" tensor renamed to ".scale".
    double_norms = [
        ("img_attn.norm.query_norm", "attn.norm_q_a"),
        ("img_attn.norm.key_norm", "attn.norm_k_a"),
        ("txt_attn.norm.query_norm", "attn.norm_q_b"),
        ("txt_attn.norm.key_norm", "attn.norm_k_b"),
    ]
    for i in range(num_layers):
        src_prefix = f"blocks.{i}."
        dst_prefix = f"double_blocks.{i}."
        for dst, src in double_pairs:
            for suffix in ("weight", "bias"):
                _move(f"{dst_prefix}{dst}.{suffix}", f"{src_prefix}{src}.{suffix}")
        for dst, src in double_norms:
            _move(f"{dst_prefix}{dst}.scale", f"{src_prefix}{src}.weight")

    # Single-stream blocks (the "single_blocks.N." prefix is shared by both layouts).
    single_pairs = [
        ("modulation.lin", "norm.linear"),
        ("linear1", "to_qkv_mlp"),
        ("linear2", "proj_out"),
    ]
    single_norms = [
        ("norm.query_norm", "norm_q_a"),
        ("norm.key_norm", "norm_k_a"),
    ]
    for i in range(num_single_layers):
        prefix = f"single_blocks.{i}."
        for dst, src in single_pairs:
            for suffix in ("weight", "bias"):
                _move(f"{prefix}{dst}.{suffix}", f"{prefix}{src}.{suffix}")
        for dst, src in single_norms:
            _move(f"{prefix}{dst}.scale", f"{prefix}{src}.weight")

    # Output head; the adaLN modulation needs its (shift, scale) halves swapped.
    for suffix in ("weight", "bias"):
        _move(f"final_layer.linear.{suffix}", f"final_proj_out.{suffix}")
        _move(
            f"final_layer.adaLN_modulation.1.{suffix}",
            f"final_norm_out.linear.{suffix}",
            transform=_swap_scale_shift,
        )

    return converted_state_dict

def main(
    input_path="xxxx/step-40000.safetensors",
    output_path="xxx/step-40000-convert.safetensors",
    num_layers=19,
    num_single_layers=38,
):
    """Load a DiffSynth-trained Flux checkpoint, convert it to the
    open-source Flux layout, and save it as a safetensors file.

    Args:
        input_path: DiffSynth-format checkpoint to read (default is the
            original placeholder path — replace it with your own).
        output_path: destination for the converted checkpoint.
        num_layers: double-stream block count (19 for Flux dev).
        num_single_layers: single-stream block count (38 for Flux dev).
    """
    state_dict = load_state_dict(input_path)
    converted = convert_diffusers_to_flux(state_dict, num_layers, num_single_layers)
    save_file(converted, output_path)

if __name__ == "__main__":
    main()

wormys avatar Oct 09 '25 07:10 wormys

Flux full train的模型在comfyui中无法加载,利用diffusers的脚本也无法转成diffusers格式,这个是否能支持

Maybe you could try this script for Flux dev? Please let me know if this script works for you.

from diffsynth import load_state_dict
from safetensors.torch import save_file

def swap_scale_shift(weight):
    """Reorder a fused adaLN modulation tensor from (shift, scale) to (scale, shift).

    DiffSynth's ``final_norm_out.linear`` stores the two modulation halves as
    (shift, scale) along dim 0, while the open-source Flux ``final_layer``
    expects (scale, shift); this swaps the two halves.

    Args:
        weight: tensor whose dim-0 length is even; first half is shift,
            second half is scale.

    Returns:
        A new tensor with the two halves concatenated in (scale, shift) order.
    """
    import torch  # fix: torch was never imported in the original script

    shift, scale = weight.chunk(2, dim=0)
    return torch.cat([scale, shift], dim=0)

def convert_diffusers_to_flux(
    original_state_dict, num_layers, num_single_layers
):
    """Convert a DiffSynth-Studio Flux DiT state dict to the open-source
    (black-forest-labs) Flux checkpoint layout.

    NOTE(review): despite the name, the source layout handled here is
    DiffSynth's (``blocks.N.*`` / ``single_blocks.N.*``), not diffusers'.
    The name is kept for backward compatibility with existing callers.

    Args:
        original_state_dict: mapping of parameter name -> tensor in the
            DiffSynth layout. Consumed destructively: every converted entry
            is popped, so a non-empty remainder afterwards reveals keys that
            were not mapped.
        num_layers: number of double-stream transformer blocks (19 for Flux dev).
        num_single_layers: number of single-stream blocks (38 for Flux dev).

    Returns:
        dict mapping open-source Flux parameter names to the same tensors
        (the final adaLN modulation tensors are re-ordered, see below).

    Raises:
        KeyError: if an expected source key is missing.
    """
    import torch  # fix: torch was never imported in the original script

    def _swap_scale_shift(weight):
        # DiffSynth stores (shift, scale); open-source Flux expects (scale, shift).
        shift, scale = weight.chunk(2, dim=0)
        return torch.cat([scale, shift], dim=0)

    converted_state_dict = {}

    def _move(dst, src, transform=None):
        # Pop so that leftovers in original_state_dict expose unconverted keys.
        tensor = original_state_dict.pop(src)
        converted_state_dict[dst] = transform(tensor) if transform else tensor

    # Scalar-conditioning embedders: Linear(0)/act(1)/Linear(2) stacks.
    embedders = [
        ("time_in", "time_embedder.timestep_embedder"),
        ("vector_in", "pooled_text_embedder"),
    ]
    # Guidance embedder only exists in guidance-distilled checkpoints (e.g. Flux dev).
    if any("guidance" in key for key in original_state_dict):
        embedders.append(("guidance_in", "guidance_embedder.timestep_embedder"))
    for dst, src in embedders:
        for layer, index in (("in_layer", 0), ("out_layer", 2)):
            for suffix in ("weight", "bias"):
                _move(f"{dst}.{layer}.{suffix}", f"{src}.{index}.{suffix}")

    # Token embedders.
    for suffix in ("weight", "bias"):
        _move(f"txt_in.{suffix}", f"context_embedder.{suffix}")
        _move(f"img_in.{suffix}", f"x_embedder.{suffix}")

    # Double-stream blocks: (open-source sub-key, DiffSynth sub-key),
    # each moved with both .weight and .bias.
    double_pairs = [
        ("img_mod.lin", "norm1_a.linear"),
        ("txt_mod.lin", "norm1_b.linear"),
        ("img_attn.qkv", "attn.a_to_qkv"),
        ("txt_attn.qkv", "attn.b_to_qkv"),
        ("img_mlp.0", "ff_a.0"),
        ("img_mlp.2", "ff_a.2"),
        ("txt_mlp.0", "ff_b.0"),
        ("txt_mlp.2", "ff_b.2"),
        ("img_attn.proj", "attn.a_to_out"),
        ("txt_attn.proj", "attn.b_to_out"),
    ]
    # QK norm gains: a single ".weight" tensor renamed to ".scale".
    double_norms = [
        ("img_attn.norm.query_norm", "attn.norm_q_a"),
        ("img_attn.norm.key_norm", "attn.norm_k_a"),
        ("txt_attn.norm.query_norm", "attn.norm_q_b"),
        ("txt_attn.norm.key_norm", "attn.norm_k_b"),
    ]
    for i in range(num_layers):
        src_prefix = f"blocks.{i}."
        dst_prefix = f"double_blocks.{i}."
        for dst, src in double_pairs:
            for suffix in ("weight", "bias"):
                _move(f"{dst_prefix}{dst}.{suffix}", f"{src_prefix}{src}.{suffix}")
        for dst, src in double_norms:
            _move(f"{dst_prefix}{dst}.scale", f"{src_prefix}{src}.weight")

    # Single-stream blocks (the "single_blocks.N." prefix is shared by both layouts).
    single_pairs = [
        ("modulation.lin", "norm.linear"),
        ("linear1", "to_qkv_mlp"),
        ("linear2", "proj_out"),
    ]
    single_norms = [
        ("norm.query_norm", "norm_q_a"),
        ("norm.key_norm", "norm_k_a"),
    ]
    for i in range(num_single_layers):
        prefix = f"single_blocks.{i}."
        for dst, src in single_pairs:
            for suffix in ("weight", "bias"):
                _move(f"{prefix}{dst}.{suffix}", f"{prefix}{src}.{suffix}")
        for dst, src in single_norms:
            _move(f"{prefix}{dst}.scale", f"{prefix}{src}.weight")

    # Output head; the adaLN modulation needs its (shift, scale) halves swapped.
    for suffix in ("weight", "bias"):
        _move(f"final_layer.linear.{suffix}", f"final_proj_out.{suffix}")
        _move(
            f"final_layer.adaLN_modulation.1.{suffix}",
            f"final_norm_out.linear.{suffix}",
            transform=_swap_scale_shift,
        )

    return converted_state_dict

def main(
    input_path="xxxx/step-40000.safetensors",
    output_path="xxx/step-40000-convert.safetensors",
    num_layers=19,
    num_single_layers=38,
):
    """Load a DiffSynth-trained Flux checkpoint, convert it to the
    open-source Flux layout, and save it as a safetensors file.

    Args:
        input_path: DiffSynth-format checkpoint to read (default is the
            original placeholder path — replace it with your own).
        output_path: destination for the converted checkpoint.
        num_layers: double-stream block count (19 for Flux dev).
        num_single_layers: single-stream block count (38 for Flux dev).
    """
    state_dict = load_state_dict(input_path)
    converted = convert_diffusers_to_flux(state_dict, num_layers, num_single_layers)
    save_file(converted, output_path)

if __name__ == "__main__":
    main()

The script works fine — thanks for the great work!

benlee12 avatar Oct 20 '25 12:10 benlee12

@Artiprocher 或许辛苦开发者考虑集成一下不?我看lora有一个参数叫Align_to_open_source_format?

wormys avatar Oct 20 '25 12:10 wormys