LayerNorm with None weight throws exception
https://github.com/huggingface/optimum-quanto/blob/b0cce2435f0b72d8d8a6f0dc6b18dc409160b394/optimum/quanto/nn/qlayernorm.py#L44
A LayerNorm whose weight is None (e.g. one created with elementwise_affine=False) will raise an exception here.
Flux(
  (pe_embedder): EmbedND()
  (img_in): Linear(in_features=64, out_features=3072, bias=True)
  (time_in): MLPEmbedder(
    (in_layer): Linear(in_features=256, out_features=3072, bias=True)
    (silu): SiLU()
    (out_layer): Linear(in_features=3072, out_features=3072, bias=True)
  )
  (vector_in): MLPEmbedder(
    (in_layer): Linear(in_features=768, out_features=3072, bias=True)
    (silu): SiLU()
    (out_layer): Linear(in_features=3072, out_features=3072, bias=True)
  )
  (guidance_in): MLPEmbedder(
    (in_layer): Linear(in_features=256, out_features=3072, bias=True)
    (silu): SiLU()
    (out_layer): Linear(in_features=3072, out_features=3072, bias=True)
  )
  (txt_in): Linear(in_features=4096, out_features=3072, bias=True)
  (double_blocks): ModuleList(
    (0-18): 19 x DoubleStreamBlock(
      (img_mod): Modulation(
        (lin): Linear(in_features=3072, out_features=18432, bias=True)
      )
      (img_norm1): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
      (img_attn): SelfAttention(
        (qkv): Linear(in_features=3072, out_features=9216, bias=True)
        (norm): QKNorm(
          (query_norm): RMSNorm()
          (key_norm): RMSNorm()
        )
        (proj): Linear(in_features=3072, out_features=3072, bias=True)
      )
      (img_norm2): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
      (img_mlp): Sequential(
        (0): Linear(in_features=3072, out_features=12288, bias=True)
        (1): GELU(approximate='tanh')
        (2): Linear(in_features=12288, out_features=3072, bias=True)
      )
      (txt_mod): Modulation(
        (lin): Linear(in_features=3072, out_features=18432, bias=True)
      )
      (txt_norm1): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
      (txt_attn): SelfAttention(
        (qkv): Linear(in_features=3072, out_features=9216, bias=True)
        (norm): QKNorm(
          (query_norm): RMSNorm()
          (key_norm): RMSNorm()
        )
        (proj): Linear(in_features=3072, out_features=3072, bias=True)
      )
      (txt_norm2): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
      (txt_mlp): Sequential(
        (0): Linear(in_features=3072, out_features=12288, bias=True)
        (1): GELU(approximate='tanh')
        (2): Linear(in_features=12288, out_features=3072, bias=True)
      )
    )
  )
  (single_blocks): ModuleList(
    (0-37): 38 x SingleStreamBlock(
      (linear1): Linear(in_features=3072, out_features=21504, bias=True)
      (linear2): Linear(in_features=15360, out_features=3072, bias=True)
      (norm): QKNorm(
        (query_norm): RMSNorm()
        (key_norm): RMSNorm()
      )
      (pre_norm): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
      (mlp_act): GELU(approximate='tanh')
      (modulation): Modulation(
        (lin): Linear(in_features=3072, out_features=9216, bias=True)
      )
    )
  )
  (final_layer): LastLayer(
    (norm_final): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
    (linear): Linear(in_features=3072, out_features=64, bias=True)
    (adaLN_modulation): Sequential(
      (0): SiLU()
      (1): Linear(in_features=3072, out_features=6144, bias=True)
    )
  )
)
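Every LayerNorm in the dump above is created with elementwise_affine=False, so its weight is None. A minimal repro sketch, not taken from the issue, assuming that requesting activation quantization is what makes quanto wrap LayerNorm modules:

import torch
from optimum.quanto import quantize, qint8

model = torch.nn.Sequential(
    torch.nn.Linear(64, 3072),
    # Same configuration as the Flux norms above: weight and bias are None.
    torch.nn.LayerNorm((3072,), eps=1e-06, elementwise_affine=False),
)

# Expected to raise in qlayernorm.py when the quantized wrapper reads
# attributes of module.weight, which is None here.
quantize(model, weights=qint8, activations=qint8)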
This is a perfectly fine workaround for excluding the affected modules:
...
exclude=[
    name
    for name, module in unet.named_modules()
    if isinstance(module, torch.nn.LayerNorm) and module.weight is None
]
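For context, a sketch of how that exclusion list could be passed to quantize; the surrounding call is an assumption, and load_flux_transformer is a hypothetical placeholder for however the Flux model above is instantiated:

import torch
from optimum.quanto import quantize, qint8

unet = load_flux_transformer()  # hypothetical: the Flux model printed above

# Exclude every LayerNorm whose weight is None (elementwise_affine=False),
# so quanto never tries to wrap them.
exclude = [
    name
    for name, module in unet.named_modules()
    if isinstance(module, torch.nn.LayerNorm) and module.weight is None
]
quantize(unet, weights=qint8, activations=qint8, exclude=exclude)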