
Error(s) in initializing SD3ControlNetModel by from_transformer

ChenhLiwnl opened this issue • 9 comments

Describe the bug

(Screenshot of the error output, WechatIMG659; not reproduced here.)

Reproduction

from diffusers.models.controlnet_sd3 import SD3ControlNetModel
from diffusers.models.transformers import SD3Transformer2DModel
transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="transformer"
)
controlnet = SD3ControlNetModel.from_transformer(transformer)

Logs

No response

System Info

  • 🤗 Diffusers version: 0.29.2
  • Platform: Linux-4.18.0
  • Running on a notebook?: No
  • Running on Google Colab?: No
  • Python version: 3.8.3
  • PyTorch version (GPU?): 2.0.1 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.23.4
  • Transformers version: 4.41.2
  • Accelerate version: 0.23.0
  • PEFT version: installed
  • Bitsandbytes version: installed
  • xFormers version: not installed
  • GPU: NVIDIA A100 80GB
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help?

No response

(ChenhLiwnl, Jun 27 '24)

Cc: @haofanwang

(sayakpaul, Jun 28 '24)

Yeah, I think we reversed strict=False and strict=True. Can you try this? If it works, would you be willing to open a PR? https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/controlnet_sd3.py#L242

        config = transformer.config
        config["num_layers"] = num_layers or config.num_layers
        controlnet = cls(**config)

        if load_weights_from_transformer:
            controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
            controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
            controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
            controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)

            controlnet.pos_embed_input = zero_module(controlnet.pos_embed_input)

        return controlnet

(yiyixuxu, Jun 28 '24)
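For context on the suggestion above: PyTorch's Module.load_state_dict defaults to strict=True and raises a RuntimeError when the checkpoint and the module have different sets of keys. With strict=False it loads the keys that match and returns the missing and unexpected ones instead. The sketch below shows this with toy nn.Sequential modules, not the SD3 classes.

```python
import torch.nn as nn


def compare_strict_modes() -> dict:
    # Toy modules (not the SD3 classes): the source has a layer the target lacks.
    source = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
    target = nn.Sequential(nn.Linear(4, 4))
    outcome = {}
    try:
        target.load_state_dict(source.state_dict())  # strict=True by default
        outcome["strict"] = "loaded"
    except RuntimeError:
        outcome["strict"] = "RuntimeError"
    # strict=False loads matching keys and reports the rest instead of raising.
    result = target.load_state_dict(source.state_dict(), strict=False)
    outcome["missing"] = list(result.missing_keys)
    outcome["unexpected"] = list(result.unexpected_keys)
    return outcome


print(compare_strict_modes())
# {'strict': 'RuntimeError', 'missing': [], 'unexpected': ['1.weight', '1.bias']}
```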

It still doesn't work and reports the same error.

(ChenhLiwnl, Jul 01 '24)

controlnet.transformer_blocks and transformer.transformer_blocks differ in the last block, and I'm trying to find out why. I think that is the cause. So I set the last layer of controlnet.transformer_blocks to

        JointTransformerBlock(
            dim=self.inner_dim,
            num_attention_heads=num_attention_heads,
            attention_head_dim=self.inner_dim,
            context_pre_only=True,
        )

and this time it works. But I'm not sure it is correct to do so.

(ChenhLiwnl, Jul 01 '24)

SD3's transformer blocks are:

        self.transformer_blocks = nn.ModuleList(
            [
                JointTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.inner_dim,
                    context_pre_only=i == num_layers - 1,
                )
                for i in range(self.config.num_layers)
            ]
        )

while sd3_controlnet's transformer blocks are:

        self.transformer_blocks = nn.ModuleList(
            [
                JointTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.inner_dim,
                    context_pre_only=False,
                )
                for i in range(self.config.num_layers)
            ]
        )

Maybe this is the problem?

(ChenhLiwnl, Jul 01 '24)
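One way to see why strict=False alone did not help: PyTorch's load_state_dict still raises on shape mismatches even when strict=False. The sketch below uses a hypothetical, heavily simplified ToyJointBlock, not the real JointTransformerBlock. As a simplifying assumption, it gives a context_pre_only block a smaller context modulation layer and no context feed-forward. Copying an equally deep base into an all-False ControlNet fails on the last block. A shallower ControlNet loads fine and only ignores the extra keys.

```python
import torch.nn as nn


class ToyJointBlock(nn.Module):
    """Hypothetical, heavily simplified stand-in for JointTransformerBlock.

    Assumption: a context_pre_only block has a smaller context modulation
    layer and no context feed-forward; all other sub-modules are omitted.
    """

    def __init__(self, dim: int, context_pre_only: bool):
        super().__init__()
        self.norm1_context = nn.Linear(dim, (2 if context_pre_only else 6) * dim)
        self.ff_context = None if context_pre_only else nn.Linear(dim, dim)


def make_blocks(num_layers: int, dim: int, last_is_pre_only: bool) -> nn.ModuleList:
    # Base transformer: only the final block is context_pre_only.
    # ControlNet (as quoted from v0.29.2): no block is context_pre_only.
    return nn.ModuleList(
        [
            ToyJointBlock(dim, context_pre_only=last_is_pre_only and i == num_layers - 1)
            for i in range(num_layers)
        ]
    )


def copy_blocks(controlnet_layers: int, base_layers: int = 4, dim: int = 8) -> str:
    base = make_blocks(base_layers, dim, last_is_pre_only=True)
    controlnet = make_blocks(controlnet_layers, dim, last_is_pre_only=False)
    try:
        result = controlnet.load_state_dict(base.state_dict(), strict=False)
    except RuntimeError as err:
        # Shape mismatches are errors even with strict=False.
        return "size mismatch" if "size mismatch" in str(err) else "other error"
    return f"loaded, {len(result.unexpected_keys)} unexpected keys ignored"


print(copy_blocks(4))  # size mismatch  (same depth as the base)
print(copy_blocks(2))  # loaded, 6 unexpected keys ignored
```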

@wangqixun Can you explain a bit here? The last block is different.

(haofanwang, Jul 01 '24)

In fact, the example ControlNet (InstantX/SD3-Controlnet-Canny) loads correctly, so it's a little confusing...

(ChenhLiwnl, Jul 01 '24)

No worries, I'm testing. Will update soon.

(haofanwang, Jul 01 '24)

Things are a bit different for a transformer-based ControlNet, because the number of layers in the ControlNet is not fixed. A UNet-based ControlNet, by contrast, always uses the down_blocks and mid_block.

(1) If we set context_pre_only=i == num_layers - 1 in SD3ControlNetModel, we have to set num_layers to the same value as in the SD3 base model; otherwise there will be a size mismatch error. But this is not recommended, because in that case the ControlNet is very heavy and behaves more like a ReferenceNet. As you can see from our released checkpoints, num_layers for the ControlNet is 6 or 12, akin to the half copy of the UNet used by UNet-based ControlNets.

(2) So our current solution is to set num_layers in from_transformer to 12 by default instead of None. Then you can freely load weights from the transformer, because we only use intermediate layers, whose context_pre_only is False. The only restriction is that we cannot set num_layers=24, because the last block is different.

In your case, you can set it manually:

        controlnet = SD3ControlNetModel.from_transformer(transformer, num_layers=6)

@ChenhLiwnl

(haofanwang, Jul 01 '24)
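The constraint described in (1) and (2) can be written as a small check. The helper below is hypothetical and not part of diffusers. It assumes ControlNet blocks are copied from the start of the base transformer, and it uses the default of 12 mentioned above.

```python
from typing import Optional


def resolve_controlnet_num_layers(
    num_layers: Optional[int], base_num_layers: int, default: int = 12
) -> int:
    """Hypothetical helper (not a diffusers API) for the rule described above.

    Assumes ControlNet blocks are copied from the start of the base transformer,
    so the depth must stop before the base's final, context_pre_only block.
    """
    n = default if num_layers is None else num_layers
    if not 1 <= n < base_num_layers:
        raise ValueError(f"num_layers must be in [1, {base_num_layers - 1}], got {n}")
    return n
```

With a 24-block base, resolve_controlnet_num_layers(None, 24) returns 12 and a value of 6 is accepted. A value of 24 raises ValueError, matching the restriction on the last block.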

Sorry, but another question: I noticed that on the main branch the transformer blocks' attention_head_dim is set to self.config.attention_head_dim, while in the released v0.29.2 it is self.inner_dim. For the released model, self.config.attention_head_dim seems to be 64 while self.inner_dim is 1536. They are not the same value, so which one is right?

(ChenhLiwnl, Jul 04 '24)

@ChenhLiwnl It's a bug we fixed; main is correct: https://github.com/huggingface/diffusers/pull/8608

(yiyixuxu, Jul 08 '24)
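The two numbers in the question are related by the head count. As we understand SD3Transformer2DModel, inner_dim is num_attention_heads times attention_head_dim. With the values quoted above, this implies 24 heads, so the per-head size (64) and the full width (1536) differ by a factor of 24:

```latex
\text{inner\_dim} = \text{num\_attention\_heads} \times \text{attention\_head\_dim}
\quad\Rightarrow\quad
\text{num\_attention\_heads} = \frac{1536}{64} = 24
```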

Closing this now since the issue is resolved! :)

(yiyixuxu, Jul 08 '24)

Hi there, when num_layers is specified in SD3ControlNetModel.from_transformer, the function fails with an unexpected-keys error in state_dict, caused by the remaining layers of the SD3 transformer. I think it can be fixed with a for loop that loads the weights of only the specified layers. I'd be grateful if you could fix that; otherwise I can fix it myself.

(FYRichie, Jul 30 '24)
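Here is a minimal sketch of the per-layer loop suggested above. It uses toy nn.Linear blocks rather than the SD3 classes, and copy_leading_blocks is a hypothetical helper, not a diffusers API. Each ControlNet block loads strictly from the base block at the same index, so the extra base layers are never touched.

```python
import torch.nn as nn


def copy_leading_blocks(target_blocks: nn.ModuleList, source_blocks: nn.ModuleList) -> int:
    """Hypothetical helper: copy weights block by block for the first
    len(target_blocks) blocks of source_blocks, using strict loading."""
    if len(target_blocks) > len(source_blocks):
        raise ValueError("target has more blocks than the source")
    for target, source in zip(target_blocks, source_blocks):
        # Strict per-block load: a genuine key or shape mismatch still raises.
        target.load_state_dict(source.state_dict())
    return len(target_blocks)


# Toy usage: a 6-block "base" and a 2-block "ControlNet" with identical block shapes.
base = nn.ModuleList([nn.Linear(8, 8) for _ in range(6)])
controlnet = nn.ModuleList([nn.Linear(8, 8) for _ in range(2)])
copied = copy_leading_blocks(controlnet, base)
```

Loading block by block keeps strict checking on for the layers that are actually copied, instead of silencing every mismatch with strict=False.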

Could you open a new issue with a minimal reproducer? Cc: @haofanwang

(sayakpaul, Jul 30 '24)

Sure, I've opened the issue.

(FYRichie, Jul 30 '24)