
Flux1-Dev inference with single file ComfyUI/SD-Forge Safetensors

ddpasa opened this issue 1 month ago • 12 comments

Is it possible to run inference with diffusers using a single-file safetensors created for ComfyUI/SD-Forge?

It looks like FluxPipeline.from_single_file() might be intended for this purpose, but I'm getting the following errors:

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_single_file("./flux1-dev-fp8.safetensors", torch_dtype=torch.float8_e4m3fn, use_safetensors=True)
Traceback (most recent call last):
  File "/home/user/flux/imgen.py", line 9, in <module>
    pipe = FluxPipeline.from_single_file("./flux1-dev-fp8.safetensors", torch_dtype=torch.float8_e4m3fn, use_safetensors=True)
  File "/home/user/.local/lib/python3.13/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
    return fn(*args, **kwargs)
  File "/home/user/.local/lib/python3.13/site-packages/diffusers/loaders/single_file.py", line 509, in from_single_file
    loaded_sub_model = load_single_file_sub_model(
        library_name=library_name,
    ...<11 lines>...
        **kwargs,
    )
  File "/home/user/.local/lib/python3.13/site-packages/diffusers/loaders/single_file.py", line 127, in load_single_file_sub_model
    loaded_sub_model = create_diffusers_t5_model_from_checkpoint(
        class_obj,
    ...<4 lines>...
        local_files_only=local_files_only,
    )
  File "/home/user/.local/lib/python3.13/site-packages/diffusers/loaders/single_file_utils.py", line 2156, in create_diffusers_t5_model_from_checkpoint
    model.load_state_dict(diffusers_format_checkpoint)
    ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/.local/lib/python3.13/site-packages/torch/nn/modules/module.py", line 2641, in load_state_dict
    raise RuntimeError(
    ...<3 lines>...
    )
RuntimeError: Error(s) in loading state_dict for T5EncoderModel:
	Missing key(s) in state_dict: "encoder.embed_tokens.weight". 

I checked the safetensors file and the T5 encoder is present. However, its keys are named differently, which confuses diffusers.
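
For reference, a minimal sketch of one way to inspect the keys without loading the tensors (same file path as in the snippet above):

from safetensors import safe_open

# Read only the header of the single-file checkpoint and list the T5 keys.
with safe_open("./flux1-dev-fp8.safetensors", framework="pt") as f:
    t5_keys = [k for k in f.keys() if "t5xxl" in k]

print(len(t5_keys), t5_keys[:3])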

ddpasa avatar Nov 16 '25 11:11 ddpasa

@ddpasa It's possible to load the Flux transformer individually from the file:

import torch
from diffusers import FluxTransformer2DModel
model = FluxTransformer2DModel.from_single_file("./flux1-dev-fp8.safetensors", torch_dtype=torch.float8_e4m3fn)

Just FYI, diffusers doesn't handle automatic casting during inference if the weights are loaded in FP8. To enable that behaviour, you should use layerwise casting:

import torch
from diffusers import FluxTransformer2DModel
model = FluxTransformer2DModel.from_single_file("./flux1-dev-fp8.safetensors", torch_dtype=torch.bfloat16)
model.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
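
The individually loaded transformer can then be dropped into the full pipeline; a rough sketch, assuming the remaining components come from the official black-forest-labs/FLUX.1-dev repo:

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_single_file("./flux1-dev-fp8.safetensors", torch_dtype=torch.bfloat16)
transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)

# Swap the single-file transformer into a pipeline built from the official repo.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)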

Re: the text encoder, I'm afraid that if the keys are different, the pipeline won't be able to load the T5 encoder. You could either convert the keys to match the HF T5 encoder (a rough sketch of that route follows the snippet below) or use the model from the official repo:

import torch
from diffusers import FluxPipeline
from transformers import T5EncoderModel

ckpt_path = "<path to checkpoint file>"
pipe = FluxPipeline.from_single_file(
    ckpt_path,
    text_encoder_2=None,
    torch_dtype=torch.bfloat16
)
pipe.text_encoder_2 = T5EncoderModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="text_encoder_2",
    torch_dtype=torch.bfloat16
)
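
If you'd rather go the key-conversion route, a rough, untested sketch; the prefix and the shared/embed_tokens duplication are assumptions about how these single-file checkpoints are laid out:

import torch
from safetensors.torch import load_file
from transformers import T5EncoderModel

state_dict = load_file("./flux1-dev-fp8.safetensors")

# Strip the single-file prefix so the keys match the HF T5 encoder naming.
prefix = "text_encoders.t5xxl.transformer."
t5_state_dict = {
    key[len(prefix):]: value.to(torch.bfloat16)
    for key, value in state_dict.items()
    if key.startswith(prefix)
}
# Single-file checkpoints store the shared embedding only once, so reuse it.
t5_state_dict["encoder.embed_tokens.weight"] = t5_state_dict["shared.weight"]

text_encoder_2 = T5EncoderModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="text_encoder_2", torch_dtype=torch.bfloat16
)
text_encoder_2.load_state_dict(t5_state_dict)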

DN6 avatar Nov 17 '25 07:11 DN6

Thanks @DN6!

This is not the original Flux model, but a finetune trained with sd-scripts. That is why I'm trying to load it this way.

ddpasa avatar Nov 17 '25 08:11 ddpasa

I think the fine tune would be just for the transformer (finetuning the T5 model is not very common). So I think you should be able to safely load just the transformer weights from your checkpoint and use the T5 in the official repo.

DN6 avatar Nov 27 '25 09:11 DN6

I think the fine tune would be just for the transformer (finetuning the T5 model is not very common). So I think you should be able to safely load just the transformer weights from your checkpoint and use the T5 in the official repo.

Why not? It's just a single flag in sd-scripts to also train the clip or t5xxl encoders, and it works well for model performance too.

ddpasa avatar Nov 28 '25 10:11 ddpasa

@ddpasa encoder.embed_tokens.weight is probably the only affected key, see: https://github.com/huggingface/transformers/blob/6db4332171df2b4099c44c7a5c01258b91f7394a/src/transformers/models/t5/modeling_t5.py#L1181-L1186
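
As a quick illustration of the tie (a toy config with arbitrary values, not the real model):

from transformers import T5Config, T5EncoderModel

# encoder.embed_tokens shares its weight with the top-level shared embedding,
# which is why the single-file checkpoint only needs shared.weight.
config = T5Config(vocab_size=32, d_model=8, d_kv=4, d_ff=16, num_layers=1, num_heads=2)
model = T5EncoderModel(config)
print(model.encoder.embed_tokens.weight is model.shared.weight)  # True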

This should be a fairly simple fix to convert_sd3_t5_checkpoint_to_diffusers, something like:

diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index d4676ba25..2678976e9 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -2124,11 +2124,15 @@ def convert_sd3_t5_checkpoint_to_diffusers(checkpoint):
     text_model_dict = {}
 
     remove_prefixes = ["text_encoders.t5xxl.transformer."]
+    modify_keys = {"encoder.embed_tokens.weight": "encoder.shared.weight"}
 
     for key in keys:
         for prefix in remove_prefixes:
             if key.startswith(prefix):
                 diffusers_key = key.replace(prefix, "")
+                for original, modified in modify_keys.items():
+                    if diffusers_key.startswith(original):
+                        diffusers_key = diffusers_key.replace(original, modified)
                 text_model_dict[diffusers_key] = checkpoint.get(key)
 
     return text_model_dict

hlky avatar Nov 28 '25 11:11 hlky

Thanks @hlky, the parameters look like this in the safetensors file:

...
text_encoders.clip_l.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias
text_encoders.clip_l.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight
text_encoders.clip_l.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias
text_encoders.clip_l.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight
text_encoders.clip_l.transformer.text_model.final_layer_norm.bias
text_encoders.clip_l.transformer.text_model.final_layer_norm.weight
text_encoders.clip_l.transformer.text_projection.weight
text_encoders.t5xxl.logit_scale
text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight
text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.o.weight
text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.q.weight
text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight
text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.v.weight
text_encoders.t5xxl.transformer.encoder.block.0.layer.0.layer_norm.weight
text_encoders.t5xxl.transformer.encoder.block.0.layer.1.DenseReluDense.wi_0.weight
text_encoders.t5xxl.transformer.encoder.block.0.layer.1.DenseReluDense.wi_1.weight
text_encoders.t5xxl.transformer.encoder.block.0.layer.1.DenseReluDense.wo.weight
text_encoders.t5xxl.transformer.encoder.block.0.layer.1.layer_norm.weight
text_encoders.t5xxl.transformer.encoder.block.1.layer.0.SelfAttention.k.weight
text_encoders.t5xxl.transformer.encoder.block.1.layer.0.SelfAttention.o.weight
...

ddpasa avatar Nov 28 '25 17:11 ddpasa

@ddpasa Is there a text_encoders.t5xxl.transformer.shared.weight?

hlky avatar Nov 28 '25 18:11 hlky

@ddpasa Is there a text_encoders.t5xxl.transformer.shared.weight?

@hlky Yes, there is. I uploaded a printout of all the state dict keys here:

flux.txt

ddpasa avatar Nov 28 '25 19:11 ddpasa

@ddpasa Thanks. On second look, the easiest fix for this issue is to install accelerate: loading is indeed broken on the model.load_state_dict path, but it works on the load_model_dict_into_meta path, which is used when accelerate is available.

On the Diffusers end, the model.load_state_dict path can be fixed like this:

diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index d4676ba25..164d70f50 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -2131,6 +2131,8 @@ def convert_sd3_t5_checkpoint_to_diffusers(checkpoint):
                 diffusers_key = key.replace(prefix, "")
                 text_model_dict[diffusers_key] = checkpoint.get(key)
 
+    text_model_dict["encoder.embed_tokens.weight"] = text_model_dict["shared.weight"]
+
     return text_model_dict
 

The accelerate load_model_dict_into_meta path still works with this change, so it should be a safe change to make.

There are, however, a number of warnings (one for every key):

UserWarning: for shared.weight: copying from a non-meta parameter in the checkpoint to a meta parameter in the current model, which is a no-op. (Did you mean to pass `assign=True` to assign items in the state dictionary to their corresponding key in the module instead of copying them in place?)

So assign=True should be added to model.load_state_dict:

diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index d4676ba25..3e115a344 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -2158,7 +2158,7 @@ def create_diffusers_t5_model_from_checkpoint(
         load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
         empty_device_cache()
     else:
-        model.load_state_dict(diffusers_format_checkpoint)
+        model.load_state_dict(diffusers_format_checkpoint, assign=True)
 
     use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (torch_dtype == torch.float16)
     if use_keep_in_fp32_modules:
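
As background, a toy sketch (unrelated to the diffusers internals) of what assign=True changes when the target parameters live on the meta device:

import torch
import torch.nn as nn

# Parameters created under the meta device have no storage, so an in-place
# copy from a checkpoint is a no-op, which is what the warning above is about.
with torch.device("meta"):
    layer = nn.Linear(2, 2)

state = {"weight": torch.ones(2, 2), "bias": torch.zeros(2)}

# assign=True replaces the meta parameters with the checkpoint tensors instead.
layer.load_state_dict(state, assign=True)
print(layer.weight.device)  # cpu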

Test with:

import torch
from diffusers import FluxPipeline
from huggingface_hub import hf_hub_download

checkpoint = hf_hub_download("Comfy-Org/flux1-dev", "flux1-dev-fp8.safetensors")

pipe = FluxPipeline.from_single_file(checkpoint, torch_dtype=torch.float8_e4m3fn)

tl;dr pip install accelerate

hlky avatar Nov 28 '25 20:11 hlky

Thanks @hlky, I can confirm that pip installing accelerate makes the single-file loader work. It does download something from the HF Hub though, maybe the config?

Would the same approach work for LoRAs as well? The problem right now is that if you load a Kohya/ComfyUI Flux1-dev LoRA into diffusers, the clip-l and t5xxl components of the LoRA are silently ignored and only the other parts are loaded.

ddpasa avatar Nov 29 '25 14:11 ddpasa

Hi @ddpasa, yes, we auto-download the model config when loading from a single file. Additionally, what version of transformers are you using to load the model?

DN6 avatar Dec 03 '25 06:12 DN6

Hi @ddpasa, yes, we auto-download the model config when loading from a single file. Additionally, what version of transformers are you using to load the model?

I git-installed the latest dev version from a few weeks ago:

$ pip list | grep -i diffusers
diffusers                  0.36.0.dev0

Is it possible to run in a completely offline mode so it doesn't need to download anything?

ddpasa avatar Dec 03 '25 16:12 ddpasa