[DC-AE] Add the official Deep Compression Autoencoder code (32x, 64x, 128x compression ratios)
What does this PR do?
This PR adds the official DC-AE (Deep Compression Autoencoder for Efficient High-Resolution Diffusion Models) to the diffusers library. DC-AE is the first autoencoder able to compress images into a 32x, 64x, or 128x latent space without performance degradation. It is also the autoencoder used by the powerful T2I base model SANA.
Paper: https://arxiv.org/abs/2410.10733v1
Original code repo: https://github.com/mit-han-lab/efficientvit/tree/master/applications/dc_ae
Core contributor of DC-AE: working with @chenjy2003
Core library:
- Docs: @stevhliu and @sayakpaul
- General functionalities: @sayakpaul @yiyixuxu @DN6
We want to collaborate on this PR together with friends from HF. Feel free to contact me here. Cc: @sayakpaul
Looking forward to this @lawrence-cj!
I think it's better to collaborate on this PR together. WDYT @chenjy2003
OK. I can also help to verify the correctness.
@chenjy2003 @lawrence-cj Do you have any checkpoint that works with this PR currently, or the appropriate conversion script that can take the original checkpoint and convert it to be usable with this PR? I can load the SANA AE as expected with the original code but not with the current PR: https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0. Currently, for this PR, I see the following error:
Error
```python
from diffusers import AutoencoderDC
from safetensors.torch import load_file

state_dict = load_file("/raid/aryan/dc-ae-sana/model.safetensors")

ae = AutoencoderDC(
    in_channels=3,
    latent_channels=32,
    encoder_width_list=[128, 256, 512, 512, 1024, 1024],
    encoder_depth_list=[2, 2, 2, 3, 3, 3],
    encoder_block_type=["ResBlock", "ResBlock", "ResBlock", "EViTS5GLU", "EViTS5GLU", "EViTS5GLU"],
    encoder_norm="rms2d",
    encoder_act="silu",
    downsample_block_type="Conv",
    decoder_width_list=[128, 256, 512, 512, 1024, 1024],
    decoder_depth_list=[3, 3, 3, 3, 3, 3],
    decoder_block_type=["ResBlock", "ResBlock", "ResBlock", "EViTS5GLU", "EViTS5GLU", "EViTS5GLU"],
    decoder_norm="rms2d",
    decoder_act="silu",
    upsample_block_type="InterpolateConv",
    scaling_factor=0.41407,
)
ae.load_state_dict(state_dict)
print(ae)
```
RuntimeError: Error(s) in loading state_dict for AutoencoderDC:
Missing key(s) in state_dict: "encoder.project_in.weight", "encoder.project_in.bias", "encoder.stages.0.0.conv1.conv.weight", "encoder.stages.0.0.conv1.conv.bias", "encoder.stages.0.0.conv2.conv.weight", "encoder.stages.0.0.conv2.norm.weight", "encoder.stages.0.0.conv2.norm.bias", "encoder.stages.0.1.conv1.conv.weight", "encoder.stages.0.1.conv1.conv.bias", "encoder.stages.0.1.conv2.conv.weight", "encoder.stages.0.1.conv2.norm.weight", "encoder.stages.0.1.conv2.norm.bias", "encoder.stages.0.2.main.weight", "encoder.stages.0.2.main.bias", "encoder.stages.1.0.conv1.conv.weight", "encoder.stages.1.0.conv1.conv.bias", "encoder.stages.1.0.conv2.conv.weight", "encoder.stages.1.0.conv2.norm.weight", "encoder.stages.1.0.conv2.norm.bias", "encoder.stages.1.1.conv1.conv.weight", "encoder.stages.1.1.conv1.conv.bias", "encoder.stages.1.1.conv2.conv.weight", "encoder.stages.1.1.conv2.norm.weight", "encoder.stages.1.1.conv2.norm.bias", "encoder.stages.1.2.main.weight", "encoder.stages.1.2.main.bias", "encoder.stages.2.0.conv1.conv.weight", "encoder.stages.2.0.conv1.conv.bias", "encoder.stages.2.0.conv2.conv.weight", "encoder.stages.2.0.conv2.norm.weight", "encoder.stages.2.0.conv2.norm.bias", "encoder.stages.2.1.conv1.conv.weight", "encoder.stages.2.1.conv1.conv.bias", "encoder.stages.2.1.conv2.conv.weight", "encoder.stages.2.1.conv2.norm.weight", "encoder.stages.2.1.conv2.norm.bias", "encoder.stages.2.2.main.weight", "encoder.stages.2.2.main.bias", "encoder.stages.3.0.context_module.qkv.0.weight", "encoder.stages.3.0.context_module.aggreg.0.0.weight", "encoder.stages.3.0.context_module.aggreg.0.1.weight", "encoder.stages.3.0.context_module.proj.0.weight", "encoder.stages.3.0.context_module.proj.1.weight", "encoder.stages.3.0.context_module.proj.1.bias", "encoder.stages.3.0.local_module.inverted_conv.conv.weight", "encoder.stages.3.0.local_module.inverted_conv.conv.bias", "encoder.stages.3.0.local_module.depth_conv.conv.weight", "encoder.stages.3.0.local_module.depth_conv.conv.bias", "encoder.stages.3.0.local_module.point_conv.conv.weight", "encoder.stages.3.0.local_module.point_conv.norm.weight", "encoder.stages.3.0.local_module.point_conv.norm.bias", "encoder.stages.3.1.context_module.qkv.0.weight", "encoder.stages.3.1.context_module.aggreg.0.0.weight", "encoder.stages.3.1.context_module.aggreg.0.1.weight", "encoder.stages.3.1.context_module.proj.0.weight", "encoder.stages.3.1.context_module.proj.1.weight", "encoder.stages.3.1.context_module.proj.1.bias", "encoder.stages.3.1.local_module.inverted_conv.conv.weight", "encoder.stages.3.1.local_module.inverted_conv.conv.bias", "encoder.stages.3.1.local_module.depth_conv.conv.weight", "encoder.stages.3.1.local_module.depth_conv.conv.bias", "encoder.stages.3.1.local_module.point_conv.conv.weight", "encoder.stages.3.1.local_module.point_conv.norm.weight", "encoder.stages.3.1.local_module.point_conv.norm.bias", "encoder.stages.3.2.context_module.qkv.0.weight", "encoder.stages.3.2.context_module.aggreg.0.0.weight", "encoder.stages.3.2.context_module.aggreg.0.1.weight", "encoder.stages.3.2.context_module.proj.0.weight", "encoder.stages.3.2.context_module.proj.1.weight", "encoder.stages.3.2.context_module.proj.1.bias", "encoder.stages.3.2.local_module.inverted_conv.conv.weight", "encoder.stages.3.2.local_module.inverted_conv.conv.bias", "encoder.stages.3.2.local_module.depth_conv.conv.weight", "encoder.stages.3.2.local_module.depth_conv.conv.bias", "encoder.stages.3.2.local_module.point_conv.conv.weight", "encoder.stages.3.2.local_module.point_conv.norm.weight", 
"encoder.stages.3.2.local_module.point_conv.norm.bias", "encoder.stages.3.3.main.weight", "encoder.stages.3.3.main.bias", "encoder.stages.4.0.context_module.qkv.0.weight", "encoder.stages.4.0.context_module.aggreg.0.0.weight", "encoder.stages.4.0.context_module.aggreg.0.1.weight", "encoder.stages.4.0.context_module.proj.0.weight", "encoder.stages.4.0.context_module.proj.1.weight", "encoder.stages.4.0.context_module.proj.1.bias", "encoder.stages.4.0.local_module.inverted_conv.conv.weight", "encoder.stages.4.0.local_module.inverted_conv.conv.bias", "encoder.stages.4.0.local_module.depth_conv.conv.weight", "encoder.stages.4.0.local_module.depth_conv.conv.bias", "encoder.stages.4.0.local_module.point_conv.conv.weight", "encoder.stages.4.0.local_module.point_conv.norm.weight", "encoder.stages.4.0.local_module.point_conv.norm.bias", "encoder.stages.4.1.context_module.qkv.0.weight", "encoder.stages.4.1.context_module.aggreg.0.0.weight", "encoder.stages.4.1.context_module.aggreg.0.1.weight", "encoder.stages.4.1.context_module.proj.0.weight", "encoder.stages.4.1.context_module.proj.1.weight", "encoder.stages.4.1.context_module.proj.1.bias", "encoder.stages.4.1.local_module.inverted_conv.conv.weight", "encoder.stages.4.1.local_module.inverted_conv.conv.bias", "encoder.stages.4.1.local_module.depth_conv.conv.weight", "encoder.stages.4.1.local_module.depth_conv.conv.bias", "encoder.stages.4.1.local_module.point_conv.conv.weight", "encoder.stages.4.1.local_module.point_conv.norm.weight", "encoder.stages.4.1.local_module.point_conv.norm.bias", "encoder.stages.4.2.context_module.qkv.0.weight", "encoder.stages.4.2.context_module.aggreg.0.0.weight", "encoder.stages.4.2.context_module.aggreg.0.1.weight", "encoder.stages.4.2.context_module.proj.0.weight", "encoder.stages.4.2.context_module.proj.1.weight", "encoder.stages.4.2.context_module.proj.1.bias", "encoder.stages.4.2.local_module.inverted_conv.conv.weight", "encoder.stages.4.2.local_module.inverted_conv.conv.bias", "encoder.stages.4.2.local_module.depth_conv.conv.weight", "encoder.stages.4.2.local_module.depth_conv.conv.bias", "encoder.stages.4.2.local_module.point_conv.conv.weight", "encoder.stages.4.2.local_module.point_conv.norm.weight", "encoder.stages.4.2.local_module.point_conv.norm.bias", "encoder.stages.4.3.main.weight", "encoder.stages.4.3.main.bias", "encoder.stages.5.0.context_module.qkv.0.weight", "encoder.stages.5.0.context_module.aggreg.0.0.weight", "encoder.stages.5.0.context_module.aggreg.0.1.weight", "encoder.stages.5.0.context_module.proj.0.weight", "encoder.stages.5.0.context_module.proj.1.weight", "encoder.stages.5.0.context_module.proj.1.bias", "encoder.stages.5.0.local_module.inverted_conv.conv.weight", "encoder.stages.5.0.local_module.inverted_conv.conv.bias", "encoder.stages.5.0.local_module.depth_conv.conv.weight", "encoder.stages.5.0.local_module.depth_conv.conv.bias", "encoder.stages.5.0.local_module.point_conv.conv.weight", "encoder.stages.5.0.local_module.point_conv.norm.weight", "encoder.stages.5.0.local_module.point_conv.norm.bias", "encoder.stages.5.1.context_module.qkv.0.weight", "encoder.stages.5.1.context_module.aggreg.0.0.weight", "encoder.stages.5.1.context_module.aggreg.0.1.weight", "encoder.stages.5.1.context_module.proj.0.weight", "encoder.stages.5.1.context_module.proj.1.weight", "encoder.stages.5.1.context_module.proj.1.bias", "encoder.stages.5.1.local_module.inverted_conv.conv.weight", "encoder.stages.5.1.local_module.inverted_conv.conv.bias", "encoder.stages.5.1.local_module.depth_conv.conv.weight", 
"encoder.stages.5.1.local_module.depth_conv.conv.bias", "encoder.stages.5.1.local_module.point_conv.conv.weight", "encoder.stages.5.1.local_module.point_conv.norm.weight", "encoder.stages.5.1.local_module.point_conv.norm.bias", "encoder.stages.5.2.context_module.qkv.0.weight", "encoder.stages.5.2.context_module.aggreg.0.0.weight", "encoder.stages.5.2.context_module.aggreg.0.1.weight", "encoder.stages.5.2.context_module.proj.0.weight", "encoder.stages.5.2.context_module.proj.1.weight", "encoder.stages.5.2.context_module.proj.1.bias", "encoder.stages.5.2.local_module.inverted_conv.conv.weight", "encoder.stages.5.2.local_module.inverted_conv.conv.bias", "encoder.stages.5.2.local_module.depth_conv.conv.weight", "encoder.stages.5.2.local_module.depth_conv.conv.bias", "encoder.stages.5.2.local_module.point_conv.conv.weight", "encoder.stages.5.2.local_module.point_conv.norm.weight", "encoder.stages.5.2.local_module.point_conv.norm.bias", "encoder.project_out.main.0.conv.weight", "encoder.project_out.main.0.conv.bias", "decoder.stages.0.0.main.conv.weight", "decoder.stages.0.0.main.conv.bias", "decoder.stages.0.1.conv1.conv.weight", "decoder.stages.0.1.conv1.conv.bias", "decoder.stages.0.1.conv2.conv.weight", "decoder.stages.0.1.conv2.norm.weight", "decoder.stages.0.1.conv2.norm.bias", "decoder.stages.0.2.conv1.conv.weight", "decoder.stages.0.2.conv1.conv.bias", "decoder.stages.0.2.conv2.conv.weight", "decoder.stages.0.2.conv2.norm.weight", "decoder.stages.0.2.conv2.norm.bias", "decoder.stages.0.3.conv1.conv.weight", "decoder.stages.0.3.conv1.conv.bias", "decoder.stages.0.3.conv2.conv.weight", "decoder.stages.0.3.conv2.norm.weight", "decoder.stages.0.3.conv2.norm.bias", "decoder.stages.1.0.main.conv.weight", "decoder.stages.1.0.main.conv.bias", "decoder.stages.1.1.conv1.conv.weight", "decoder.stages.1.1.conv1.conv.bias", "decoder.stages.1.1.conv2.conv.weight", "decoder.stages.1.1.conv2.norm.weight", "decoder.stages.1.1.conv2.norm.bias", "decoder.stages.1.2.conv1.conv.weight", "decoder.stages.1.2.conv1.conv.bias", "decoder.stages.1.2.conv2.conv.weight", "decoder.stages.1.2.conv2.norm.weight", "decoder.stages.1.2.conv2.norm.bias", "decoder.stages.1.3.conv1.conv.weight", "decoder.stages.1.3.conv1.conv.bias", "decoder.stages.1.3.conv2.conv.weight", "decoder.stages.1.3.conv2.norm.weight", "decoder.stages.1.3.conv2.norm.bias", "decoder.stages.2.0.main.conv.weight", "decoder.stages.2.0.main.conv.bias", "decoder.stages.2.1.conv1.conv.weight", "decoder.stages.2.1.conv1.conv.bias", "decoder.stages.2.1.conv2.conv.weight", "decoder.stages.2.1.conv2.norm.weight", "decoder.stages.2.1.conv2.norm.bias", "decoder.stages.2.2.conv1.conv.weight", "decoder.stages.2.2.conv1.conv.bias", "decoder.stages.2.2.conv2.conv.weight", "decoder.stages.2.2.conv2.norm.weight", "decoder.stages.2.2.conv2.norm.bias", "decoder.stages.2.3.conv1.conv.weight", "decoder.stages.2.3.conv1.conv.bias", "decoder.stages.2.3.conv2.conv.weight", "decoder.stages.2.3.conv2.norm.weight", "decoder.stages.2.3.conv2.norm.bias", "decoder.stages.3.0.main.conv.weight", "decoder.stages.3.0.main.conv.bias", "decoder.stages.3.1.context_module.qkv.0.weight", "decoder.stages.3.1.context_module.aggreg.0.0.weight", "decoder.stages.3.1.context_module.aggreg.0.1.weight", "decoder.stages.3.1.context_module.proj.0.weight", "decoder.stages.3.1.context_module.proj.1.weight", "decoder.stages.3.1.context_module.proj.1.bias", "decoder.stages.3.1.local_module.inverted_conv.conv.weight", "decoder.stages.3.1.local_module.inverted_conv.conv.bias", 
"decoder.stages.3.1.local_module.depth_conv.conv.weight", "decoder.stages.3.1.local_module.depth_conv.conv.bias", "decoder.stages.3.1.local_module.point_conv.conv.weight", "decoder.stages.3.1.local_module.point_conv.norm.weight", "decoder.stages.3.1.local_module.point_conv.norm.bias", "decoder.stages.3.2.context_module.qkv.0.weight", "decoder.stages.3.2.context_module.aggreg.0.0.weight", "decoder.stages.3.2.context_module.aggreg.0.1.weight", "decoder.stages.3.2.context_module.proj.0.weight", "decoder.stages.3.2.context_module.proj.1.weight", "decoder.stages.3.2.context_module.proj.1.bias", "decoder.stages.3.2.local_module.inverted_conv.conv.weight", "decoder.stages.3.2.local_module.inverted_conv.conv.bias", "decoder.stages.3.2.local_module.depth_conv.conv.weight", "decoder.stages.3.2.local_module.depth_conv.conv.bias", "decoder.stages.3.2.local_module.point_conv.conv.weight", "decoder.stages.3.2.local_module.point_conv.norm.weight", "decoder.stages.3.2.local_module.point_conv.norm.bias", "decoder.stages.3.3.context_module.qkv.0.weight", "decoder.stages.3.3.context_module.aggreg.0.0.weight", "decoder.stages.3.3.context_module.aggreg.0.1.weight", "decoder.stages.3.3.context_module.proj.0.weight", "decoder.stages.3.3.context_module.proj.1.weight", "decoder.stages.3.3.context_module.proj.1.bias", "decoder.stages.3.3.local_module.inverted_conv.conv.weight", "decoder.stages.3.3.local_module.inverted_conv.conv.bias", "decoder.stages.3.3.local_module.depth_conv.conv.weight", "decoder.stages.3.3.local_module.depth_conv.conv.bias", "decoder.stages.3.3.local_module.point_conv.conv.weight", "decoder.stages.3.3.local_module.point_conv.norm.weight", "decoder.stages.3.3.local_module.point_conv.norm.bias", "decoder.stages.4.0.main.conv.weight", "decoder.stages.4.0.main.conv.bias", "decoder.stages.4.1.context_module.qkv.0.weight", "decoder.stages.4.1.context_module.aggreg.0.0.weight", "decoder.stages.4.1.context_module.aggreg.0.1.weight", "decoder.stages.4.1.context_module.proj.0.weight", "decoder.stages.4.1.context_module.proj.1.weight", "decoder.stages.4.1.context_module.proj.1.bias", "decoder.stages.4.1.local_module.inverted_conv.conv.weight", "decoder.stages.4.1.local_module.inverted_conv.conv.bias", "decoder.stages.4.1.local_module.depth_conv.conv.weight", "decoder.stages.4.1.local_module.depth_conv.conv.bias", "decoder.stages.4.1.local_module.point_conv.conv.weight", "decoder.stages.4.1.local_module.point_conv.norm.weight", "decoder.stages.4.1.local_module.point_conv.norm.bias", "decoder.stages.4.2.context_module.qkv.0.weight", "decoder.stages.4.2.context_module.aggreg.0.0.weight", "decoder.stages.4.2.context_module.aggreg.0.1.weight", "decoder.stages.4.2.context_module.proj.0.weight", "decoder.stages.4.2.context_module.proj.1.weight", "decoder.stages.4.2.context_module.proj.1.bias", "decoder.stages.4.2.local_module.inverted_conv.conv.weight", "decoder.stages.4.2.local_module.inverted_conv.conv.bias", "decoder.stages.4.2.local_module.depth_conv.conv.weight", "decoder.stages.4.2.local_module.depth_conv.conv.bias", "decoder.stages.4.2.local_module.point_conv.conv.weight", "decoder.stages.4.2.local_module.point_conv.norm.weight", "decoder.stages.4.2.local_module.point_conv.norm.bias", "decoder.stages.4.3.context_module.qkv.0.weight", "decoder.stages.4.3.context_module.aggreg.0.0.weight", "decoder.stages.4.3.context_module.aggreg.0.1.weight", "decoder.stages.4.3.context_module.proj.0.weight", "decoder.stages.4.3.context_module.proj.1.weight", "decoder.stages.4.3.context_module.proj.1.bias", 
"decoder.stages.4.3.local_module.inverted_conv.conv.weight", "decoder.stages.4.3.local_module.inverted_conv.conv.bias", "decoder.stages.4.3.local_module.depth_conv.conv.weight", "decoder.stages.4.3.local_module.depth_conv.conv.bias", "decoder.stages.4.3.local_module.point_conv.conv.weight", "decoder.stages.4.3.local_module.point_conv.norm.weight", "decoder.stages.4.3.local_module.point_conv.norm.bias", "decoder.stages.5.0.context_module.qkv.0.weight", "decoder.stages.5.0.context_module.aggreg.0.0.weight", "decoder.stages.5.0.context_module.aggreg.0.1.weight", "decoder.stages.5.0.context_module.proj.0.weight", "decoder.stages.5.0.context_module.proj.1.weight", "decoder.stages.5.0.context_module.proj.1.bias", "decoder.stages.5.0.local_module.inverted_conv.conv.weight", "decoder.stages.5.0.local_module.inverted_conv.conv.bias", "decoder.stages.5.0.local_module.depth_conv.conv.weight", "decoder.stages.5.0.local_module.depth_conv.conv.bias", "decoder.stages.5.0.local_module.point_conv.conv.weight", "decoder.stages.5.0.local_module.point_conv.norm.weight", "decoder.stages.5.0.local_module.point_conv.norm.bias", "decoder.stages.5.1.context_module.qkv.0.weight", "decoder.stages.5.1.context_module.aggreg.0.0.weight", "decoder.stages.5.1.context_module.aggreg.0.1.weight", "decoder.stages.5.1.context_module.proj.0.weight", "decoder.stages.5.1.context_module.proj.1.weight", "decoder.stages.5.1.context_module.proj.1.bias", "decoder.stages.5.1.local_module.inverted_conv.conv.weight", "decoder.stages.5.1.local_module.inverted_conv.conv.bias", "decoder.stages.5.1.local_module.depth_conv.conv.weight", "decoder.stages.5.1.local_module.depth_conv.conv.bias", "decoder.stages.5.1.local_module.point_conv.conv.weight", "decoder.stages.5.1.local_module.point_conv.norm.weight", "decoder.stages.5.1.local_module.point_conv.norm.bias", "decoder.stages.5.2.context_module.qkv.0.weight", "decoder.stages.5.2.context_module.aggreg.0.0.weight", "decoder.stages.5.2.context_module.aggreg.0.1.weight", "decoder.stages.5.2.context_module.proj.0.weight", "decoder.stages.5.2.context_module.proj.1.weight", "decoder.stages.5.2.context_module.proj.1.bias", "decoder.stages.5.2.local_module.inverted_conv.conv.weight", "decoder.stages.5.2.local_module.inverted_conv.conv.bias", "decoder.stages.5.2.local_module.depth_conv.conv.weight", "decoder.stages.5.2.local_module.depth_conv.conv.bias", "decoder.stages.5.2.local_module.point_conv.conv.weight", "decoder.stages.5.2.local_module.point_conv.norm.weight", "decoder.stages.5.2.local_module.point_conv.norm.bias", "decoder.project_out.0.weight", "decoder.project_out.0.bias", "decoder.project_out.2.conv.weight", "decoder.project_out.2.conv.bias".
Unexpected key(s) in state_dict: "encoder.project_in.conv.bias", "encoder.project_in.conv.weight", "encoder.stages.0.op_list.0.main.conv1.conv.bias", "encoder.stages.0.op_list.0.main.conv1.conv.weight", "encoder.stages.0.op_list.0.main.conv2.conv.weight", "encoder.stages.0.op_list.0.main.conv2.norm.bias", "encoder.stages.0.op_list.0.main.conv2.norm.weight", "encoder.stages.0.op_list.1.main.conv1.conv.bias", "encoder.stages.0.op_list.1.main.conv1.conv.weight", "encoder.stages.0.op_list.1.main.conv2.conv.weight", "encoder.stages.0.op_list.1.main.conv2.norm.bias", "encoder.stages.0.op_list.1.main.conv2.norm.weight", "encoder.stages.0.op_list.2.main.conv.bias", "encoder.stages.0.op_list.2.main.conv.weight", "encoder.stages.1.op_list.0.main.conv1.conv.bias", "encoder.stages.1.op_list.0.main.conv1.conv.weight", "encoder.stages.1.op_list.0.main.conv2.conv.weight", "encoder.stages.1.op_list.0.main.conv2.norm.bias", "encoder.stages.1.op_list.0.main.conv2.norm.weight", "encoder.stages.1.op_list.1.main.conv1.conv.bias", "encoder.stages.1.op_list.1.main.conv1.conv.weight", "encoder.stages.1.op_list.1.main.conv2.conv.weight", "encoder.stages.1.op_list.1.main.conv2.norm.bias", "encoder.stages.1.op_list.1.main.conv2.norm.weight", "encoder.stages.1.op_list.2.main.conv.bias", "encoder.stages.1.op_list.2.main.conv.weight", "encoder.stages.2.op_list.0.main.conv1.conv.bias", "encoder.stages.2.op_list.0.main.conv1.conv.weight", "encoder.stages.2.op_list.0.main.conv2.conv.weight", "encoder.stages.2.op_list.0.main.conv2.norm.bias", "encoder.stages.2.op_list.0.main.conv2.norm.weight", "encoder.stages.2.op_list.1.main.conv1.conv.bias", "encoder.stages.2.op_list.1.main.conv1.conv.weight", "encoder.stages.2.op_list.1.main.conv2.conv.weight", "encoder.stages.2.op_list.1.main.conv2.norm.bias", "encoder.stages.2.op_list.1.main.conv2.norm.weight", "encoder.stages.2.op_list.2.main.conv.bias", "encoder.stages.2.op_list.2.main.conv.weight", "encoder.stages.3.op_list.0.context_module.main.aggreg.0.0.weight", "encoder.stages.3.op_list.0.context_module.main.aggreg.0.1.weight", "encoder.stages.3.op_list.0.context_module.main.proj.conv.weight", "encoder.stages.3.op_list.0.context_module.main.proj.norm.bias", "encoder.stages.3.op_list.0.context_module.main.proj.norm.weight", "encoder.stages.3.op_list.0.context_module.main.qkv.conv.weight", "encoder.stages.3.op_list.0.local_module.main.depth_conv.conv.bias", "encoder.stages.3.op_list.0.local_module.main.depth_conv.conv.weight", "encoder.stages.3.op_list.0.local_module.main.inverted_conv.conv.bias", "encoder.stages.3.op_list.0.local_module.main.inverted_conv.conv.weight", "encoder.stages.3.op_list.0.local_module.main.point_conv.conv.weight", "encoder.stages.3.op_list.0.local_module.main.point_conv.norm.bias", "encoder.stages.3.op_list.0.local_module.main.point_conv.norm.weight", "encoder.stages.3.op_list.1.context_module.main.aggreg.0.0.weight", "encoder.stages.3.op_list.1.context_module.main.aggreg.0.1.weight", "encoder.stages.3.op_list.1.context_module.main.proj.conv.weight", "encoder.stages.3.op_list.1.context_module.main.proj.norm.bias", "encoder.stages.3.op_list.1.context_module.main.proj.norm.weight", "encoder.stages.3.op_list.1.context_module.main.qkv.conv.weight", "encoder.stages.3.op_list.1.local_module.main.depth_conv.conv.bias", "encoder.stages.3.op_list.1.local_module.main.depth_conv.conv.weight", "encoder.stages.3.op_list.1.local_module.main.inverted_conv.conv.bias", "encoder.stages.3.op_list.1.local_module.main.inverted_conv.conv.weight", 
"encoder.stages.3.op_list.1.local_module.main.point_conv.conv.weight", "encoder.stages.3.op_list.1.local_module.main.point_conv.norm.bias", "encoder.stages.3.op_list.1.local_module.main.point_conv.norm.weight", "encoder.stages.3.op_list.2.context_module.main.aggreg.0.0.weight", "encoder.stages.3.op_list.2.context_module.main.aggreg.0.1.weight", "encoder.stages.3.op_list.2.context_module.main.proj.conv.weight", "encoder.stages.3.op_list.2.context_module.main.proj.norm.bias", "encoder.stages.3.op_list.2.context_module.main.proj.norm.weight", "encoder.stages.3.op_list.2.context_module.main.qkv.conv.weight", "encoder.stages.3.op_list.2.local_module.main.depth_conv.conv.bias", "encoder.stages.3.op_list.2.local_module.main.depth_conv.conv.weight", "encoder.stages.3.op_list.2.local_module.main.inverted_conv.conv.bias", "encoder.stages.3.op_list.2.local_module.main.inverted_conv.conv.weight", "encoder.stages.3.op_list.2.local_module.main.point_conv.conv.weight", "encoder.stages.3.op_list.2.local_module.main.point_conv.norm.bias", "encoder.stages.3.op_list.2.local_module.main.point_conv.norm.weight", "encoder.stages.3.op_list.3.main.conv.bias", "encoder.stages.3.op_list.3.main.conv.weight", "encoder.stages.4.op_list.0.context_module.main.aggreg.0.0.weight", "encoder.stages.4.op_list.0.context_module.main.aggreg.0.1.weight", "encoder.stages.4.op_list.0.context_module.main.proj.conv.weight", "encoder.stages.4.op_list.0.context_module.main.proj.norm.bias", "encoder.stages.4.op_list.0.context_module.main.proj.norm.weight", "encoder.stages.4.op_list.0.context_module.main.qkv.conv.weight", "encoder.stages.4.op_list.0.local_module.main.depth_conv.conv.bias", "encoder.stages.4.op_list.0.local_module.main.depth_conv.conv.weight", "encoder.stages.4.op_list.0.local_module.main.inverted_conv.conv.bias", "encoder.stages.4.op_list.0.local_module.main.inverted_conv.conv.weight", "encoder.stages.4.op_list.0.local_module.main.point_conv.conv.weight", "encoder.stages.4.op_list.0.local_module.main.point_conv.norm.bias", "encoder.stages.4.op_list.0.local_module.main.point_conv.norm.weight", "encoder.stages.4.op_list.1.context_module.main.aggreg.0.0.weight", "encoder.stages.4.op_list.1.context_module.main.aggreg.0.1.weight", "encoder.stages.4.op_list.1.context_module.main.proj.conv.weight", "encoder.stages.4.op_list.1.context_module.main.proj.norm.bias", "encoder.stages.4.op_list.1.context_module.main.proj.norm.weight", "encoder.stages.4.op_list.1.context_module.main.qkv.conv.weight", "encoder.stages.4.op_list.1.local_module.main.depth_conv.conv.bias", "encoder.stages.4.op_list.1.local_module.main.depth_conv.conv.weight", "encoder.stages.4.op_list.1.local_module.main.inverted_conv.conv.bias", "encoder.stages.4.op_list.1.local_module.main.inverted_conv.conv.weight", "encoder.stages.4.op_list.1.local_module.main.point_conv.conv.weight", "encoder.stages.4.op_list.1.local_module.main.point_conv.norm.bias", "encoder.stages.4.op_list.1.local_module.main.point_conv.norm.weight", "encoder.stages.4.op_list.2.context_module.main.aggreg.0.0.weight", "encoder.stages.4.op_list.2.context_module.main.aggreg.0.1.weight", "encoder.stages.4.op_list.2.context_module.main.proj.conv.weight", "encoder.stages.4.op_list.2.context_module.main.proj.norm.bias", "encoder.stages.4.op_list.2.context_module.main.proj.norm.weight", "encoder.stages.4.op_list.2.context_module.main.qkv.conv.weight", "encoder.stages.4.op_list.2.local_module.main.depth_conv.conv.bias", "encoder.stages.4.op_list.2.local_module.main.depth_conv.conv.weight", 
"encoder.stages.4.op_list.2.local_module.main.inverted_conv.conv.bias", "encoder.stages.4.op_list.2.local_module.main.inverted_conv.conv.weight", "encoder.stages.4.op_list.2.local_module.main.point_conv.conv.weight", "encoder.stages.4.op_list.2.local_module.main.point_conv.norm.bias", "encoder.stages.4.op_list.2.local_module.main.point_conv.norm.weight", "encoder.stages.4.op_list.3.main.conv.bias", "encoder.stages.4.op_list.3.main.conv.weight", "encoder.stages.5.op_list.0.context_module.main.aggreg.0.0.weight", "encoder.stages.5.op_list.0.context_module.main.aggreg.0.1.weight", "encoder.stages.5.op_list.0.context_module.main.proj.conv.weight", "encoder.stages.5.op_list.0.context_module.main.proj.norm.bias", "encoder.stages.5.op_list.0.context_module.main.proj.norm.weight", "encoder.stages.5.op_list.0.context_module.main.qkv.conv.weight", "encoder.stages.5.op_list.0.local_module.main.depth_conv.conv.bias", "encoder.stages.5.op_list.0.local_module.main.depth_conv.conv.weight", "encoder.stages.5.op_list.0.local_module.main.inverted_conv.conv.bias", "encoder.stages.5.op_list.0.local_module.main.inverted_conv.conv.weight", "encoder.stages.5.op_list.0.local_module.main.point_conv.conv.weight", "encoder.stages.5.op_list.0.local_module.main.point_conv.norm.bias", "encoder.stages.5.op_list.0.local_module.main.point_conv.norm.weight", "encoder.stages.5.op_list.1.context_module.main.aggreg.0.0.weight", "encoder.stages.5.op_list.1.context_module.main.aggreg.0.1.weight", "encoder.stages.5.op_list.1.context_module.main.proj.conv.weight", "encoder.stages.5.op_list.1.context_module.main.proj.norm.bias", "encoder.stages.5.op_list.1.context_module.main.proj.norm.weight", "encoder.stages.5.op_list.1.context_module.main.qkv.conv.weight", "encoder.stages.5.op_list.1.local_module.main.depth_conv.conv.bias", "encoder.stages.5.op_list.1.local_module.main.depth_conv.conv.weight", "encoder.stages.5.op_list.1.local_module.main.inverted_conv.conv.bias", "encoder.stages.5.op_list.1.local_module.main.inverted_conv.conv.weight", "encoder.stages.5.op_list.1.local_module.main.point_conv.conv.weight", "encoder.stages.5.op_list.1.local_module.main.point_conv.norm.bias", "encoder.stages.5.op_list.1.local_module.main.point_conv.norm.weight", "encoder.stages.5.op_list.2.context_module.main.aggreg.0.0.weight", "encoder.stages.5.op_list.2.context_module.main.aggreg.0.1.weight", "encoder.stages.5.op_list.2.context_module.main.proj.conv.weight", "encoder.stages.5.op_list.2.context_module.main.proj.norm.bias", "encoder.stages.5.op_list.2.context_module.main.proj.norm.weight", "encoder.stages.5.op_list.2.context_module.main.qkv.conv.weight", "encoder.stages.5.op_list.2.local_module.main.depth_conv.conv.bias", "encoder.stages.5.op_list.2.local_module.main.depth_conv.conv.weight", "encoder.stages.5.op_list.2.local_module.main.inverted_conv.conv.bias", "encoder.stages.5.op_list.2.local_module.main.inverted_conv.conv.weight", "encoder.stages.5.op_list.2.local_module.main.point_conv.conv.weight", "encoder.stages.5.op_list.2.local_module.main.point_conv.norm.bias", "encoder.stages.5.op_list.2.local_module.main.point_conv.norm.weight", "encoder.project_out.main.op_list.0.conv.bias", "encoder.project_out.main.op_list.0.conv.weight", "decoder.stages.0.op_list.0.main.conv.conv.bias", "decoder.stages.0.op_list.0.main.conv.conv.weight", "decoder.stages.0.op_list.1.main.conv1.conv.bias", "decoder.stages.0.op_list.1.main.conv1.conv.weight", "decoder.stages.0.op_list.1.main.conv2.conv.weight", "decoder.stages.0.op_list.1.main.conv2.norm.bias", 
"decoder.stages.0.op_list.1.main.conv2.norm.weight", "decoder.stages.0.op_list.2.main.conv1.conv.bias", "decoder.stages.0.op_list.2.main.conv1.conv.weight", "decoder.stages.0.op_list.2.main.conv2.conv.weight", "decoder.stages.0.op_list.2.main.conv2.norm.bias", "decoder.stages.0.op_list.2.main.conv2.norm.weight", "decoder.stages.0.op_list.3.main.conv1.conv.bias", "decoder.stages.0.op_list.3.main.conv1.conv.weight", "decoder.stages.0.op_list.3.main.conv2.conv.weight", "decoder.stages.0.op_list.3.main.conv2.norm.bias", "decoder.stages.0.op_list.3.main.conv2.norm.weight", "decoder.stages.1.op_list.0.main.conv.conv.bias", "decoder.stages.1.op_list.0.main.conv.conv.weight", "decoder.stages.1.op_list.1.main.conv1.conv.bias", "decoder.stages.1.op_list.1.main.conv1.conv.weight", "decoder.stages.1.op_list.1.main.conv2.conv.weight", "decoder.stages.1.op_list.1.main.conv2.norm.bias", "decoder.stages.1.op_list.1.main.conv2.norm.weight", "decoder.stages.1.op_list.2.main.conv1.conv.bias", "decoder.stages.1.op_list.2.main.conv1.conv.weight", "decoder.stages.1.op_list.2.main.conv2.conv.weight", "decoder.stages.1.op_list.2.main.conv2.norm.bias", "decoder.stages.1.op_list.2.main.conv2.norm.weight", "decoder.stages.1.op_list.3.main.conv1.conv.bias", "decoder.stages.1.op_list.3.main.conv1.conv.weight", "decoder.stages.1.op_list.3.main.conv2.conv.weight", "decoder.stages.1.op_list.3.main.conv2.norm.bias", "decoder.stages.1.op_list.3.main.conv2.norm.weight", "decoder.stages.2.op_list.0.main.conv.conv.bias", "decoder.stages.2.op_list.0.main.conv.conv.weight", "decoder.stages.2.op_list.1.main.conv1.conv.bias", "decoder.stages.2.op_list.1.main.conv1.conv.weight", "decoder.stages.2.op_list.1.main.conv2.conv.weight", "decoder.stages.2.op_list.1.main.conv2.norm.bias", "decoder.stages.2.op_list.1.main.conv2.norm.weight", "decoder.stages.2.op_list.2.main.conv1.conv.bias", "decoder.stages.2.op_list.2.main.conv1.conv.weight", "decoder.stages.2.op_list.2.main.conv2.conv.weight", "decoder.stages.2.op_list.2.main.conv2.norm.bias", "decoder.stages.2.op_list.2.main.conv2.norm.weight", "decoder.stages.2.op_list.3.main.conv1.conv.bias", "decoder.stages.2.op_list.3.main.conv1.conv.weight", "decoder.stages.2.op_list.3.main.conv2.conv.weight", "decoder.stages.2.op_list.3.main.conv2.norm.bias", "decoder.stages.2.op_list.3.main.conv2.norm.weight", "decoder.stages.3.op_list.0.main.conv.conv.bias", "decoder.stages.3.op_list.0.main.conv.conv.weight", "decoder.stages.3.op_list.1.context_module.main.aggreg.0.0.weight", "decoder.stages.3.op_list.1.context_module.main.aggreg.0.1.weight", "decoder.stages.3.op_list.1.context_module.main.proj.conv.weight", "decoder.stages.3.op_list.1.context_module.main.proj.norm.bias", "decoder.stages.3.op_list.1.context_module.main.proj.norm.weight", "decoder.stages.3.op_list.1.context_module.main.qkv.conv.weight", "decoder.stages.3.op_list.1.local_module.main.depth_conv.conv.bias", "decoder.stages.3.op_list.1.local_module.main.depth_conv.conv.weight", "decoder.stages.3.op_list.1.local_module.main.inverted_conv.conv.bias", "decoder.stages.3.op_list.1.local_module.main.inverted_conv.conv.weight", "decoder.stages.3.op_list.1.local_module.main.point_conv.conv.weight", "decoder.stages.3.op_list.1.local_module.main.point_conv.norm.bias", "decoder.stages.3.op_list.1.local_module.main.point_conv.norm.weight", "decoder.stages.3.op_list.2.context_module.main.aggreg.0.0.weight", "decoder.stages.3.op_list.2.context_module.main.aggreg.0.1.weight", "decoder.stages.3.op_list.2.context_module.main.proj.conv.weight", 
"decoder.stages.3.op_list.2.context_module.main.proj.norm.bias", "decoder.stages.3.op_list.2.context_module.main.proj.norm.weight", "decoder.stages.3.op_list.2.context_module.main.qkv.conv.weight", "decoder.stages.3.op_list.2.local_module.main.depth_conv.conv.bias", "decoder.stages.3.op_list.2.local_module.main.depth_conv.conv.weight", "decoder.stages.3.op_list.2.local_module.main.inverted_conv.conv.bias", "decoder.stages.3.op_list.2.local_module.main.inverted_conv.conv.weight", "decoder.stages.3.op_list.2.local_module.main.point_conv.conv.weight", "decoder.stages.3.op_list.2.local_module.main.point_conv.norm.bias", "decoder.stages.3.op_list.2.local_module.main.point_conv.norm.weight", "decoder.stages.3.op_list.3.context_module.main.aggreg.0.0.weight", "decoder.stages.3.op_list.3.context_module.main.aggreg.0.1.weight", "decoder.stages.3.op_list.3.context_module.main.proj.conv.weight", "decoder.stages.3.op_list.3.context_module.main.proj.norm.bias", "decoder.stages.3.op_list.3.context_module.main.proj.norm.weight", "decoder.stages.3.op_list.3.context_module.main.qkv.conv.weight", "decoder.stages.3.op_list.3.local_module.main.depth_conv.conv.bias", "decoder.stages.3.op_list.3.local_module.main.depth_conv.conv.weight", "decoder.stages.3.op_list.3.local_module.main.inverted_conv.conv.bias", "decoder.stages.3.op_list.3.local_module.main.inverted_conv.conv.weight", "decoder.stages.3.op_list.3.local_module.main.point_conv.conv.weight", "decoder.stages.3.op_list.3.local_module.main.point_conv.norm.bias", "decoder.stages.3.op_list.3.local_module.main.point_conv.norm.weight", "decoder.stages.4.op_list.0.main.conv.conv.bias", "decoder.stages.4.op_list.0.main.conv.conv.weight", "decoder.stages.4.op_list.1.context_module.main.aggreg.0.0.weight", "decoder.stages.4.op_list.1.context_module.main.aggreg.0.1.weight", "decoder.stages.4.op_list.1.context_module.main.proj.conv.weight", "decoder.stages.4.op_list.1.context_module.main.proj.norm.bias", "decoder.stages.4.op_list.1.context_module.main.proj.norm.weight", "decoder.stages.4.op_list.1.context_module.main.qkv.conv.weight", "decoder.stages.4.op_list.1.local_module.main.depth_conv.conv.bias", "decoder.stages.4.op_list.1.local_module.main.depth_conv.conv.weight", "decoder.stages.4.op_list.1.local_module.main.inverted_conv.conv.bias", "decoder.stages.4.op_list.1.local_module.main.inverted_conv.conv.weight", "decoder.stages.4.op_list.1.local_module.main.point_conv.conv.weight", "decoder.stages.4.op_list.1.local_module.main.point_conv.norm.bias", "decoder.stages.4.op_list.1.local_module.main.point_conv.norm.weight", "decoder.stages.4.op_list.2.context_module.main.aggreg.0.0.weight", "decoder.stages.4.op_list.2.context_module.main.aggreg.0.1.weight", "decoder.stages.4.op_list.2.context_module.main.proj.conv.weight", "decoder.stages.4.op_list.2.context_module.main.proj.norm.bias", "decoder.stages.4.op_list.2.context_module.main.proj.norm.weight", "decoder.stages.4.op_list.2.context_module.main.qkv.conv.weight", "decoder.stages.4.op_list.2.local_module.main.depth_conv.conv.bias", "decoder.stages.4.op_list.2.local_module.main.depth_conv.conv.weight", "decoder.stages.4.op_list.2.local_module.main.inverted_conv.conv.bias", "decoder.stages.4.op_list.2.local_module.main.inverted_conv.conv.weight", "decoder.stages.4.op_list.2.local_module.main.point_conv.conv.weight", "decoder.stages.4.op_list.2.local_module.main.point_conv.norm.bias", "decoder.stages.4.op_list.2.local_module.main.point_conv.norm.weight", 
"decoder.stages.4.op_list.3.context_module.main.aggreg.0.0.weight", "decoder.stages.4.op_list.3.context_module.main.aggreg.0.1.weight", "decoder.stages.4.op_list.3.context_module.main.proj.conv.weight", "decoder.stages.4.op_list.3.context_module.main.proj.norm.bias", "decoder.stages.4.op_list.3.context_module.main.proj.norm.weight", "decoder.stages.4.op_list.3.context_module.main.qkv.conv.weight", "decoder.stages.4.op_list.3.local_module.main.depth_conv.conv.bias", "decoder.stages.4.op_list.3.local_module.main.depth_conv.conv.weight", "decoder.stages.4.op_list.3.local_module.main.inverted_conv.conv.bias", "decoder.stages.4.op_list.3.local_module.main.inverted_conv.conv.weight", "decoder.stages.4.op_list.3.local_module.main.point_conv.conv.weight", "decoder.stages.4.op_list.3.local_module.main.point_conv.norm.bias", "decoder.stages.4.op_list.3.local_module.main.point_conv.norm.weight", "decoder.stages.5.op_list.0.context_module.main.aggreg.0.0.weight", "decoder.stages.5.op_list.0.context_module.main.aggreg.0.1.weight", "decoder.stages.5.op_list.0.context_module.main.proj.conv.weight", "decoder.stages.5.op_list.0.context_module.main.proj.norm.bias", "decoder.stages.5.op_list.0.context_module.main.proj.norm.weight", "decoder.stages.5.op_list.0.context_module.main.qkv.conv.weight", "decoder.stages.5.op_list.0.local_module.main.depth_conv.conv.bias", "decoder.stages.5.op_list.0.local_module.main.depth_conv.conv.weight", "decoder.stages.5.op_list.0.local_module.main.inverted_conv.conv.bias", "decoder.stages.5.op_list.0.local_module.main.inverted_conv.conv.weight", "decoder.stages.5.op_list.0.local_module.main.point_conv.conv.weight", "decoder.stages.5.op_list.0.local_module.main.point_conv.norm.bias", "decoder.stages.5.op_list.0.local_module.main.point_conv.norm.weight", "decoder.stages.5.op_list.1.context_module.main.aggreg.0.0.weight", "decoder.stages.5.op_list.1.context_module.main.aggreg.0.1.weight", "decoder.stages.5.op_list.1.context_module.main.proj.conv.weight", "decoder.stages.5.op_list.1.context_module.main.proj.norm.bias", "decoder.stages.5.op_list.1.context_module.main.proj.norm.weight", "decoder.stages.5.op_list.1.context_module.main.qkv.conv.weight", "decoder.stages.5.op_list.1.local_module.main.depth_conv.conv.bias", "decoder.stages.5.op_list.1.local_module.main.depth_conv.conv.weight", "decoder.stages.5.op_list.1.local_module.main.inverted_conv.conv.bias", "decoder.stages.5.op_list.1.local_module.main.inverted_conv.conv.weight", "decoder.stages.5.op_list.1.local_module.main.point_conv.conv.weight", "decoder.stages.5.op_list.1.local_module.main.point_conv.norm.bias", "decoder.stages.5.op_list.1.local_module.main.point_conv.norm.weight", "decoder.stages.5.op_list.2.context_module.main.aggreg.0.0.weight", "decoder.stages.5.op_list.2.context_module.main.aggreg.0.1.weight", "decoder.stages.5.op_list.2.context_module.main.proj.conv.weight", "decoder.stages.5.op_list.2.context_module.main.proj.norm.bias", "decoder.stages.5.op_list.2.context_module.main.proj.norm.weight", "decoder.stages.5.op_list.2.context_module.main.qkv.conv.weight", "decoder.stages.5.op_list.2.local_module.main.depth_conv.conv.bias", "decoder.stages.5.op_list.2.local_module.main.depth_conv.conv.weight", "decoder.stages.5.op_list.2.local_module.main.inverted_conv.conv.bias", "decoder.stages.5.op_list.2.local_module.main.inverted_conv.conv.weight", "decoder.stages.5.op_list.2.local_module.main.point_conv.conv.weight", "decoder.stages.5.op_list.2.local_module.main.point_conv.norm.bias", 
"decoder.stages.5.op_list.2.local_module.main.point_conv.norm.weight", "decoder.project_out.op_list.0.bias", "decoder.project_out.op_list.0.weight", "decoder.project_out.op_list.2.conv.bias", "decoder.project_out.op_list.2.conv.weight".
If not, I can take the original code and modify it to Diffusers format (no problem/inconvenience for me).
Also curious why you chose to train an AE over the more common and extensively used variational ones these days :eyes:
Hi @a-r-r-o-w, I have a conversion script, but it is based on the original codebase. I think giving you converted checkpoints may be better. Where do you think I should upload the converted checkpoints?
Different from previous autoencoders, which have a spatial compression ratio of 8, our deep compression autoencoders have a spatial compression ratio of up to 128, resulting in far fewer tokens to be processed by diffusion models, thus accelerating both diffusion training and inference.
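For intuition, here is a quick back-of-the-envelope sketch of how the compression ratio translates into token counts for a 1024x1024 image (pure arithmetic, not tied to any particular checkpoint, and assuming one token per spatial latent position):

```python
# Rough token-count comparison for a 1024x1024 input image, assuming each
# spatial position of the latent becomes one token for the diffusion model.
resolution = 1024
for f in (8, 32, 64, 128):
    side = resolution // f
    print(f"f{f}: latent {side}x{side} -> {side * side} tokens")
# f8: latent 128x128 -> 16384 tokens
# f32: latent 32x32 -> 1024 tokens
# f64: latent 16x16 -> 256 tokens
# f128: latent 8x8 -> 64 tokens
```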
I think just the conversion script would be super helpful for now, as I started modifying the original code in the interest of time. The current checkpoint (if it works for this PR) could be uploaded to your personal HF account, provided you've verified that it behaves exactly the same as the original checkpoint.
Also, if I understand correctly, this is the main checkpoint that we need to support for Sana, yes? https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0
The other checkpoints are different variants of DCAE that you would like to support from Diffusers, but the Sana checkpoint is the main, yes? If so, I will try to remove all the branches of code and first support Sana in the minimal Diffusers-style implementation. Once we complete that, we can work on the Sana pipeline PR and eventually pick up the other DC autoencoders. Does that work? I will keep the implementation generic enough that newer things can be added easily, so we don't have to worry about incompatibility much
Sure. Here is the conversion script, which should run from within the efficientvit repo.
```python
import os
import sys
from dataclasses import dataclass
from omegaconf import OmegaConf, MISSING
from typing import Any
from PIL import Image
import torch
import torchvision.transforms as transforms
from torchvision.utils import save_image
import ipdb

# from efficientvit.models.efficientvit.dc_ae import DCAE, DCAEConfig
from efficientvit.ae_model_zoo import create_dc_ae_model_cfg, DCAE_HF
from efficientvit.models.utils.network import get_submodule_weights
from efficientvit.models.nn.ops import ResidualBlock, ResBlock, ConvPixelUnshuffleDownSampleLayer, ConvLayer, EfficientViTBlock, ConvPixelShuffleUpSampleLayer, InterpolateConvUpSampleLayer
from efficientvit.models.nn.norm import TritonRMSNorm2d

sys.path.append("/home/junyuc/workspace/code/diffusers")
from src.diffusers.models.autoencoders.autoencoder_dc import DCAE as DCAE_diffusers


@dataclass
class ConvertConfig:
    model_name: str = MISSING


def modify_prefix(old_prefix: str, new_prefix: str, state_dict):
    # Rename every key that starts with `old_prefix` so it starts with `new_prefix` instead.
    for key in list(state_dict.keys()):
        key: str
        if key.startswith(old_prefix):
            state_dict[new_prefix + key.removeprefix(old_prefix)] = state_dict.pop(key)


def trms2d_to_rms2d(norm: Any):
    if isinstance(norm, str):
        if norm == "trms2d":
            norm = "rms2d"
        return norm
    else:
        return list(trms2d_to_rms2d(norm_) for norm_ in norm)


def main():
    cfg: ConvertConfig = OmegaConf.to_object(OmegaConf.merge(OmegaConf.structured(ConvertConfig), OmegaConf.from_cli()))
    # model_cfg = create_dc_ae_model_cfg(cfg.model_name, cfg.input_path)
    # model = DCAE(model_cfg)
    model = DCAE_HF.from_pretrained(f"mit-han-lab/{cfg.model_name}")
    model_cfg = model.cfg
    model_cfg.encoder.block_type = [block_type.replace("_", "") for block_type in model_cfg.encoder.block_type]
    model_cfg.decoder.block_type = [block_type.replace("_", "") for block_type in model_cfg.decoder.block_type]
    model_diffusers = DCAE_diffusers(
        model_cfg.in_channels, model_cfg.latent_channels,
        model_cfg.encoder.block_type, model_cfg.encoder.width_list, model_cfg.encoder.depth_list, trms2d_to_rms2d(model_cfg.encoder.norm), model_cfg.encoder.act, model_cfg.encoder.downsample_block_type,
        model_cfg.decoder.block_type, model_cfg.decoder.width_list, model_cfg.decoder.depth_list, trms2d_to_rms2d(model_cfg.decoder.norm), model_cfg.decoder.act, model_cfg.decoder.upsample_block_type
    )

    state_dict = model.state_dict()
    # Drop training-only weights (discriminator and perceptual loss).
    for key in list(state_dict.keys()):
        if key.startswith("discriminator") or key.startswith("perceptual_loss"):
            del state_dict[key]

    # Remap the original EfficientViT-style encoder keys to the diffusers module layout.
    modify_prefix("encoder.project_in.conv", "encoder.project_in", state_dict)
    for stage_id, stage in enumerate(model.encoder.stages):
        num_blocks = len(stage.op_list)
        if num_blocks == 0:
            continue
        for block_id, block in enumerate(stage.op_list):
            if isinstance(block, ResidualBlock):
                if isinstance(block.main, ConvPixelUnshuffleDownSampleLayer):
                    modify_prefix(f"encoder.stages.{stage_id}.op_list.{block_id}.main.conv.conv", f"encoder.stages.{stage_id}.{block_id}.main.conv", state_dict)
                elif isinstance(block.main, ConvLayer):
                    modify_prefix(f"encoder.stages.{stage_id}.op_list.{block_id}.main.conv", f"encoder.stages.{stage_id}.{block_id}.main", state_dict)
                elif isinstance(block.main, ResBlock):
                    modify_prefix(f"encoder.stages.{stage_id}.op_list.{block_id}.main.conv1", f"encoder.stages.{stage_id}.{block_id}.conv1", state_dict)
                    modify_prefix(f"encoder.stages.{stage_id}.op_list.{block_id}.main.conv2", f"encoder.stages.{stage_id}.{block_id}.conv2", state_dict)
                else:
                    raise ValueError(f"encoder block main {block.main} is not supported")
            elif isinstance(block, EfficientViTBlock):
                modify_prefix(f"encoder.stages.{stage_id}.op_list.{block_id}", f"encoder.stages.{stage_id}.{block_id}", state_dict)
                modify_prefix(f"encoder.stages.{stage_id}.{block_id}.local_module.main", f"encoder.stages.{stage_id}.{block_id}.local_module", state_dict)
                modify_prefix(f"encoder.stages.{stage_id}.{block_id}.context_module.main", f"encoder.stages.{stage_id}.{block_id}.context_module", state_dict)
                modify_prefix(f"encoder.stages.{stage_id}.{block_id}.context_module.qkv.conv", f"encoder.stages.{stage_id}.{block_id}.context_module.qkv.0", state_dict)
                modify_prefix(f"encoder.stages.{stage_id}.{block_id}.context_module.qkv.norm", f"encoder.stages.{stage_id}.{block_id}.context_module.qkv.1", state_dict)
                modify_prefix(f"encoder.stages.{stage_id}.{block_id}.context_module.proj.conv", f"encoder.stages.{stage_id}.{block_id}.context_module.proj.0", state_dict)
                modify_prefix(f"encoder.stages.{stage_id}.{block_id}.context_module.proj.norm", f"encoder.stages.{stage_id}.{block_id}.context_module.proj.1", state_dict)
            else:
                raise ValueError(f"encoder block {block} is not supported")
    modify_prefix(f"encoder.project_out.main.op_list.0", f"encoder.project_out.main.0", state_dict)

    # Remap the decoder keys in the same way.
    for stage_id, stage in enumerate(model.decoder.stages):
        num_blocks = len(model.decoder.stages[stage_id].op_list)
        if num_blocks == 0:
            continue
        for block_id, block in enumerate(stage.op_list):
            if isinstance(block, ResidualBlock):
                if isinstance(block.main, (ConvPixelShuffleUpSampleLayer, InterpolateConvUpSampleLayer)):
                    modify_prefix(f"decoder.stages.{stage_id}.op_list.{block_id}.main.conv.conv", f"decoder.stages.{stage_id}.{block_id}.main.conv", state_dict)
                elif isinstance(block.main, ResBlock):
                    modify_prefix(f"decoder.stages.{stage_id}.op_list.{block_id}.main.conv1", f"decoder.stages.{stage_id}.{block_id}.conv1", state_dict)
                    modify_prefix(f"decoder.stages.{stage_id}.op_list.{block_id}.main.conv2", f"decoder.stages.{stage_id}.{block_id}.conv2", state_dict)
                else:
                    raise ValueError(f"decoder block main {block.main} is not supported")
            elif isinstance(block, EfficientViTBlock):
                modify_prefix(f"decoder.stages.{stage_id}.op_list.{block_id}", f"decoder.stages.{stage_id}.{block_id}", state_dict)
                modify_prefix(f"decoder.stages.{stage_id}.{block_id}.local_module.main", f"decoder.stages.{stage_id}.{block_id}.local_module", state_dict)
                modify_prefix(f"decoder.stages.{stage_id}.{block_id}.context_module.main", f"decoder.stages.{stage_id}.{block_id}.context_module", state_dict)
                modify_prefix(f"decoder.stages.{stage_id}.{block_id}.context_module.qkv.conv", f"decoder.stages.{stage_id}.{block_id}.context_module.qkv.0", state_dict)
                modify_prefix(f"decoder.stages.{stage_id}.{block_id}.context_module.qkv.norm", f"decoder.stages.{stage_id}.{block_id}.context_module.qkv.1", state_dict)
                modify_prefix(f"decoder.stages.{stage_id}.{block_id}.context_module.proj.conv", f"decoder.stages.{stage_id}.{block_id}.context_module.proj.0", state_dict)
                modify_prefix(f"decoder.stages.{stage_id}.{block_id}.context_module.proj.norm", f"decoder.stages.{stage_id}.{block_id}.context_module.proj.1", state_dict)
            else:
                raise ValueError(f"decoder block {block} is not supported")
    for block_id, block in enumerate(model.decoder.project_out.op_list):
        if isinstance(block, (ConvLayer, torch.nn.ReLU, TritonRMSNorm2d)):
            modify_prefix(f"decoder.project_out.op_list.{block_id}", f"decoder.project_out.{block_id}", state_dict)
        elif isinstance(block, ConvPixelShuffleUpSampleLayer):
            modify_prefix(f"decoder.project_out.op_list.{block_id}.conv", f"decoder.project_out.{block_id}", state_dict)
        else:
            raise ValueError(f"decoder project out block {block} is not supported")

    model_diffusers.load_state_dict(state_dict, strict=True)

    # Round-trip an image through the converted model as a quick sanity check.
    torch.set_grad_enabled(False)
    device = torch.device("cuda")
    dtype = torch.float16
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    image = Image.open("assets/fig/girl.png")
    x = transform(image)[None].to(device=device, dtype=dtype)
    model_diffusers = model_diffusers.to(device=device, dtype=dtype).eval()
    latent = model_diffusers.encode(x)
    print(latent.shape)
    y = model_diffusers.decode(latent).sample
    save_image(y * 0.5 + 0.5, f"demo.jpg")


if __name__ == "__main__":
    main()


"""
python convert_checkpoint.py model_name=dc-ae-f32c32-sana-1.0
python convert_checkpoint.py model_name=dc-ae-f32c32-in-1.0
"""
```
> The other checkpoints are different variants of DCAE that you would like to support from Diffusers, but the Sana checkpoint is the main, yes? If so, I will try to remove all the branches of code and first support Sana in the minimal Diffusers-style implementation. Once we complete that, we can work on the Sana pipeline PR and eventually pick up the other DC autoencoders. Does that work? I will keep the implementation generic enough that newer things can be added easily, so we don't have to worry about incompatibility much
Agreed.
My current conversion script is supposed to support all the checkpoints. Also, these checkpoints don't differ a lot, so I think supporting all the models at the same time won't take much additional effort. However, if you find that supporting all the models is much harder, I think it's okay to support dc-ae-f32c32-sana-1.0 first.
We have a safetensors checkpoint for you to test here: https://huggingface.co/Efficient-Large-Model/dc_ae_f32c32_sana_1.0_diffusers @a-r-r-o-w
Thanks Junsong! We are very close to being able to merge this. I've simplified the original implementation quite a lot here: https://github.com/huggingface/diffusers/tree/aryan-dcae. I pushed my changes to a separate branch because a merge conflict occurred when pulling the latest changes. You can see the diff here.
I think all the DC-AE checkpoints can be supported together without too much difficulty, as @chenjy2003 mentioned. Now we are only left with removing some builder methods, cleaning up the ViT classes, logical separation into intermediate blocks if possible, and tests/docs. So far, there is a numerical absmax difference of 0.00013 between the original and Diffusers implementations, probably because the order of operations on GPU can affect precision (but I can't see any differences visually). I am heading to bed now and will be able to complete this by tomorrow, but if you would like to take up anything above, please feel free to do so.
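To illustrate the suspected cause, here is a toy example (not the DC-AE code itself): float32 arithmetic is not associative, so two mathematically equivalent implementations that reduce or fuse operations in a different order can disagree slightly.

```python
import torch

# Toy illustration: summing the same numbers in a different order gives a
# slightly different float32 result, because float addition is not associative.
torch.manual_seed(0)
x = torch.randn(1_000_000, dtype=torch.float32)

s1 = x.sum()                                                    # one reduction order
s2 = torch.stack([chunk.sum() for chunk in x.chunk(64)]).sum()  # another reduction order
print((s1 - s2).abs().item())  # typically a small non-zero value
```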
(copied content from #10064)
I believe this version matches the original Sana VAE checkpoint completely. I have yet to verify the correctness of all the other variants, so I'll share the unit tests after completing that testing.
To run the conversion, I use:
```shell
python3 scripts/convert_dcae_to_diffusers.py --vae_ckpt_path /raid/aryan/dc-ae-sana/model.safetensors --output_path /raid/aryan/sana-vae-diffusers
```
Here is some inference code for testing:
code
```python
import numpy as np
import torch
from diffusers import AutoencoderDC
from diffusers.utils import load_image
from PIL import Image


@torch.no_grad()
def main():
    ae = AutoencoderDC.from_pretrained("/raid/aryan/sana-vae-diffusers/")
    ae = ae.to("cuda")

    image = load_image("inputs/astronaut.jpg").resize((512, 512))
    image = np.array(image)
    image = torch.from_numpy(image)
    image = image / 127.5 - 1.0
    image = image.unsqueeze(0).permute(0, 3, 1, 2).to("cuda")

    encoded = ae.encode(image)
    print("encoded:", encoded.shape)
    decoded = ae.decode(encoded)
    print("decoded:", decoded.shape)

    output = decoded[0].permute(1, 2, 0)
    output = (output + 1) / 2.0 * 255.0
    output = output.clamp(0.0, 255.0)
    output = output.detach().cpu().numpy().astype(np.uint8)
    output = Image.fromarray(output)
    output.save("output.png")

    original_encoded = torch.load("original_dcae_encoded.pt", weights_only=True)
    original_decoded = torch.load("original_dcae_decoded.pt", weights_only=True)
    encoded_diff = encoded - original_encoded
    decoded_diff = decoded - original_decoded
    print(encoded_diff.abs().max(), encoded_diff.abs().sum())
    print(decoded_diff.abs().max(), decoded_diff.abs().sum())


main()
```
| Original image | Reconstruction |
|---|---|
I think it is okay to skip the diffusers-side VAE tests for now and pick them up in a follow-up PR after #9808 is merged. I will add the documentation after verifying that all checkpoints work as expected and finalizing the diffusers implementation following reviews.
@a-r-r-o-w Hi Aryan, may I ask a question here? I noticed that only block_out_channels is shared between the encoder and the decoder, while all other args like block_types, layers_per_block, and qkv_multiscales are separate. I'm wondering if it would be better to also separate block_out_channels.
@chenjy2003 That sounds good! I think I made that simplification because I noticed that both were the same for encoder/decoder, but keeping the distinction for this should be okay too.
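For reference, a hypothetical sketch of what a fully separated config could look like; the argument names below are illustrative (mirroring the encoder_/decoder_ prefixes already used in this PR) and not necessarily the final AutoencoderDC signature:

```python
from diffusers import AutoencoderDC

# Illustrative only: separate channel lists for encoder and decoder, alongside
# the per-side depths that are already separated in this PR. Argument names
# are assumptions, not the confirmed API.
ae = AutoencoderDC(
    in_channels=3,
    latent_channels=32,
    encoder_block_out_channels=(128, 256, 512, 512, 1024, 1024),
    decoder_block_out_channels=(128, 256, 512, 512, 1024, 1024),
    encoder_layers_per_block=(2, 2, 2, 3, 3, 3),
    decoder_layers_per_block=(3, 3, 3, 3, 3, 3),
)
```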
Hi @lawrence-cj @chenjy2003! Apologies for the delay here, but I think we should be good for a full review soon by @yiyixuxu. I've written some unit tests to verify that all original checkpoints work with the diffusers adaptation.
For converting a checkpoint, here's an example command:
```shell
python3 scripts/convert_dcae_to_diffusers.py --config_name dc-ae-f128c512-mix-1.0 --output_path /raid/aryan/dc-ae/dc-ae-f128c512-mix-1.0-diffusers
```
This is the unit test code:
unit test
```python
import numpy as np
import torch
from diffusers import AutoencoderDC
from diffusers.utils import load_image
from diffusers.utils.testing_utils import enable_full_determinism
from PIL import Image

enable_full_determinism()

model_names = [
    "dc-ae-f32c32-sana-1.0",
    "dc-ae-f32c32-in-1.0",
    "dc-ae-f32c32-mix-1.0",
    "dc-ae-f64c128-in-1.0",
    "dc-ae-f64c128-mix-1.0",
    "dc-ae-f128c512-in-1.0",
    "dc-ae-f128c512-mix-1.0",
]


@torch.no_grad()
def main(model_name: str):
    print("Processing:", model_name)
    ae = AutoencoderDC.from_pretrained(f"/raid/aryan/dc-ae/{model_name}-diffusers")
    ae = ae.to("cuda")
    model_name = model_name.replace(".", "-")

    image = load_image("inputs/astronaut.jpg").resize((512, 512))
    image = np.array(image)
    image = torch.from_numpy(image)
    image = image / 127.5 - 1.0
    image = image.unsqueeze(0).permute(0, 3, 1, 2).to("cuda")

    encoded = ae.encode(image)
    print("encoded:", encoded.shape)
    decoded = ae.decode(encoded)
    print("decoded:", decoded.shape)

    output = decoded[0].permute(1, 2, 0)
    output = (output + 1) / 2.0 * 255.0
    output = output.clamp(0.0, 255.0)
    output = output.detach().cpu().numpy().astype(np.uint8)
    output = Image.fromarray(output)
    output.save(f"output-dcae-{model_name}.png")

    original_encoded = torch.load(f"/home/aryan/work/efficientvit/original-encoded-{model_name}.pt", weights_only=True)
    original_decoded = torch.load(f"/home/aryan/work/efficientvit/original-decoded-{model_name}.pt", weights_only=True)
    encoded_diff = encoded - original_encoded
    decoded_diff = decoded - original_decoded
    print(encoded_diff.abs().max(), encoded_diff.abs().sum())
    print(decoded_diff.abs().max(), decoded_diff.abs().sum())


for model_name in model_names:
    main(model_name)
```
This is the output:
Processing: dc-ae-f32c32-sana-1.0
encoded: torch.Size([1, 32, 16, 16])
decoded: torch.Size([1, 3, 512, 512])
tensor(0.1609, device='cuda:0') tensor(38.5961, device='cuda:0')
tensor(0.0196, device='cuda:0') tensor(166.9286, device='cuda:0')
Processing: dc-ae-f32c32-in-1.0
encoded: torch.Size([1, 32, 16, 16])
decoded: torch.Size([1, 3, 512, 512])
tensor(0.0055, device='cuda:0') tensor(3.0150, device='cuda:0')
tensor(0.0109, device='cuda:0') tensor(46.8522, device='cuda:0')
Processing: dc-ae-f32c32-mix-1.0
encoded: torch.Size([1, 32, 16, 16])
decoded: torch.Size([1, 3, 512, 512])
tensor(0.0017, device='cuda:0') tensor(1.6487, device='cuda:0')
tensor(0.0074, device='cuda:0') tensor(52.5738, device='cuda:0')
Processing: dc-ae-f64c128-in-1.0
encoded: torch.Size([1, 128, 8, 8])
decoded: torch.Size([1, 3, 512, 512])
tensor(0.0083, device='cuda:0') tensor(5.2498, device='cuda:0')
tensor(0.0099, device='cuda:0') tensor(59.4267, device='cuda:0')
Processing: dc-ae-f64c128-mix-1.0
encoded: torch.Size([1, 128, 8, 8])
decoded: torch.Size([1, 3, 512, 512])
tensor(0.0018, device='cuda:0') tensor(2.4982, device='cuda:0')
tensor(0.0076, device='cuda:0') tensor(64.4642, device='cuda:0')
Processing: dc-ae-f128c512-in-1.0
encoded: torch.Size([1, 512, 4, 4])
decoded: torch.Size([1, 3, 512, 512])
tensor(0.0020, device='cuda:0') tensor(2.9806, device='cuda:0')
tensor(0.0103, device='cuda:0') tensor(54.9113, device='cuda:0')
Processing: dc-ae-f128c512-mix-1.0
encoded: torch.Size([1, 512, 4, 4])
decoded: torch.Size([1, 3, 512, 512])
tensor(0.0058, device='cuda:0') tensor(5.2506, device='cuda:0')
tensor(0.0144, device='cuda:0') tensor(69.1719, device='cuda:0')
The left values are the absmax difference from the original encoded latent/decoded image, and the right values are the abssum difference. We typically try to keep the absmax difference below 1e-3 for most of our integrations, because in pixel space a latent difference of 0.001 is equivalent to 0.255, which, when converted to uint8, is a difference of 0.
Here, I am seeing a slightly higher difference than what I would consider okay, which could be because of a mistake on my part in the diffusers adaptation. After spending some time on this, I haven't been able to find any difference compared to the original implementation, at least not from just looking at the code. I will keep looking for what went wrong for some more time, but any help/insights would be super helpful!
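For reference, here is a small sketch of how the two decoded outputs can be compared directly in uint8 pixel space (it assumes `decoded` and `original_decoded` are the [-1, 1] tensors produced by the scripts in this thread):

```python
# Compare two decoded outputs after mapping the [-1, 1] range to uint8 pixels.
import torch

def uint8_pixel_diff(a: torch.Tensor, b: torch.Tensor) -> int:
    to_uint8 = lambda t: ((t.clamp(-1, 1) + 1) / 2 * 255).round().to(torch.uint8)
    # cast to int before subtracting to avoid uint8 wrap-around
    return (to_uint8(a).int() - to_uint8(b).int()).abs().max().item()

# uint8_pixel_diff(decoded, original_decoded) is typically 0 or 1 for small diffs
```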
To generate the latent/image tensors to compare to, I use the code from the original implementation examples:
original code
from efficientvit.ae_model_zoo import DCAE_HF
from PIL import Image
import torch
import torchvision.transforms as transforms
from torchvision.utils import save_image
from efficientvit.apps.utils.image import DMCrop
import numpy as np
model_names = [
    "dc-ae-f32c32-sana-1.0",
    "dc-ae-f32c32-in-1.0",
    "dc-ae-f32c32-mix-1.0",
    "dc-ae-f64c128-in-1.0",
    "dc-ae-f64c128-mix-1.0",
    "dc-ae-f128c512-in-1.0",
    "dc-ae-f128c512-mix-1.0",
]
@torch.no_grad()
def main():
    for model_name in model_names:
        print("Processing:", model_name)
        dc_ae = DCAE_HF.from_pretrained(f"mit-han-lab/{model_name}")
        device = torch.device("cuda")
        dc_ae = dc_ae.to(device).eval()
        image = Image.open("/home/aryan/work/diffusers/inputs/astronaut.jpg").resize((512, 512))
        image = np.array(image)
        image = torch.from_numpy(image)
        image = image / 127.5 - 1.0
        image = image.unsqueeze(0).permute(0, 3, 1, 2).contiguous().to("cuda")
        encoded = dc_ae.encode(image)
        torch.save(encoded, f"original-encoded-{model_name.replace('.', '-')}.pt")
        print("encoded:", encoded.shape)
        decoded = dc_ae.decode(encoded)
        torch.save(decoded, f"original-decoded-{model_name.replace('.', '-')}.pt")
        print("decoded:", decoded.shape)
        save_image((decoded + 1) / 2, f"original-output-{model_name.replace('.', '-')}.png")

main()
For convenience, and to make it easier for others from the team to test, I have uploaded the current version of converted checkpoints to https://huggingface.co/a-r-r-o-w/dcae-diffusers/tree/main/. This will be deleted later, or we can move them to your organization.
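If anyone wants to try the converted checkpoints, something along these lines should work; note that the subfolder name below is only an assumption about the repo layout and may need adjusting:

```python
# Sketch for loading one of the temporarily uploaded converted checkpoints.
import torch
from diffusers import AutoencoderDC

ae = AutoencoderDC.from_pretrained(
    "a-r-r-o-w/dcae-diffusers",
    subfolder="dc-ae-f32c32-sana-1.0-diffusers",  # assumed subfolder name
    torch_dtype=torch.float32,
).to("cuda")
```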
Even though we seem to have some numerical differences, I believe the conversion is mostly correct. Here is the visual comparison:
(Visual comparison grid of reconstructions: Original, F32C32-SANA, F32C32-in, F32C32-mix, F64C128-in, F64C128-mix, F128C512-in, F128C512-mix.)
Actually, we have a LiteMLA processor here: https://github.com/huggingface/diffusers/blob/996606edb19aa0a0ccc1f65fd75c41f33e3a229b/src/diffusers/models/attention_processor.py#L5038
You can check if this one is helpful. @a-r-r-o-w @yiyixuxu
@lawrence-cj Thanks, that helps! I see that the transformer also uses LiteMLA. Initially, I missed this, so I moved it into the same file as the autoencoder, but I think it's better suited to the attention.py file now (not completely sure, I will consult @DN6 on how to go about it). Will make this update along with splitting the QKV into individual projection layers and push the update shortly.
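For anyone following along, the QKV split during conversion can be sketched like this; this is not the PR's actual conversion code, and it assumes the fused weight stacks q, k, and v along the output dimension:

```python
# Illustrative sketch: split a fused 1x1-conv QKV weight into three
# separate linear projection weights (assumed q/k/v ordering along dim 0).
import torch

def split_fused_qkv(conv_weight: torch.Tensor):
    # conv_weight has shape (3 * inner_dim, in_dim, 1, 1) for a fused qkv 1x1 conv
    weight = conv_weight.flatten(1)            # -> (3 * inner_dim, in_dim)
    to_q, to_k, to_v = weight.chunk(3, dim=0)  # one (inner_dim, in_dim) slice each
    return to_q, to_k, to_v                    # usable as nn.Linear weight tensors
```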
@a-r-r-o-w Hi Aryan, shall we follow the same output format as other autoencoders? Junsong once asked me to follow Encode, Decode.
@chenjy2003 Yes, that sounds good! We do need those output wrappers for compatibility across the implementations.
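For context, such output wrappers in diffusers usually look something like the sketch below; the exact class and field names used for DC-AE here are illustrative, not necessarily the final ones:

```python
# Rough sketch of diffusers-style output wrappers for encode/decode.
from dataclasses import dataclass

import torch
from diffusers.utils import BaseOutput

@dataclass
class EncoderOutput(BaseOutput):
    latent: torch.Tensor

@dataclass
class DecoderOutput(BaseOutput):
    sample: torch.Tensor
```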
On another note, I've updated the code to split the convolution QKV into individual nn.Linear projections. After making this change, the absmax and abssum are as follows:
Processing: dc-ae-f32c32-sana-1.0
encoded: torch.Size([1, 32, 16, 16])
decoded: torch.Size([1, 3, 512, 512])
tensor(0.1624, device='cuda:0') tensor(38.6361, device='cuda:0')
tensor(0.0244, device='cuda:0') tensor(166.1487, device='cuda:0')
Processing: dc-ae-f32c32-in-1.0
encoded: torch.Size([1, 32, 16, 16])
decoded: torch.Size([1, 3, 512, 512])
tensor(0.0035, device='cuda:0') tensor(3.1268, device='cuda:0')
tensor(0.0080, device='cuda:0') tensor(46.8220, device='cuda:0')
Processing: dc-ae-f32c32-mix-1.0
encoded: torch.Size([1, 32, 16, 16])
decoded: torch.Size([1, 3, 512, 512])
tensor(0.0017, device='cuda:0') tensor(1.7227, device='cuda:0')
tensor(0.0062, device='cuda:0') tensor(53.6814, device='cuda:0')
Processing: dc-ae-f64c128-in-1.0
encoded: torch.Size([1, 128, 8, 8])
decoded: torch.Size([1, 3, 512, 512])
tensor(0.0070, device='cuda:0') tensor(5.2894, device='cuda:0')
tensor(0.0098, device='cuda:0') tensor(58.5385, device='cuda:0')
Processing: dc-ae-f64c128-mix-1.0
encoded: torch.Size([1, 128, 8, 8])
decoded: torch.Size([1, 3, 512, 512])
tensor(0.0021, device='cuda:0') tensor(2.8276, device='cuda:0')
tensor(0.0098, device='cuda:0') tensor(67.3011, device='cuda:0')
Processing: dc-ae-f128c512-in-1.0
encoded: torch.Size([1, 512, 4, 4])
decoded: torch.Size([1, 3, 512, 512])
tensor(0.0024, device='cuda:0') tensor(3.0521, device='cuda:0')
tensor(0.0107, device='cuda:0') tensor(57.2206, device='cuda:0')
Processing: dc-ae-f128c512-mix-1.0
encoded: torch.Size([1, 512, 4, 4])
decoded: torch.Size([1, 3, 512, 512])
tensor(0.0062, device='cuda:0') tensor(5.4188, device='cuda:0')
tensor(0.0064, device='cuda:0') tensor(68.8127, device='cuda:0')
I think the differences are in a tolerable range except for the Sana VAE. Still investigating the cause (visually, however, the quality looks similar to the original checkpoints). I suspect it is a missing upcast somewhere or the alternate implementation of the norms, but I won't be sure until I debug this layerwise.
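One way to do that layerwise comparison (a generic sketch, not code from this PR) is to register forward hooks on both implementations and diff the captured activations:

```python
# Generic layerwise-debugging sketch: capture intermediate activations with
# forward hooks so two implementations can be compared on the same input.
import torch

def capture_activations(model: torch.nn.Module, store: dict) -> list:
    handles = []
    for name, module in model.named_modules():
        def hook(_module, _inputs, output, name=name):
            if isinstance(output, torch.Tensor):
                store[name] = output.detach()
        handles.append(module.register_forward_hook(hook))
    return handles

# Usage idea: fill one dict per model on the same input, then diff the stored
# tensors layer by layer (module names differ, so a name mapping is needed).
```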
cc @DN6 I think we could benefit from having .from_single_file loading for the autoencoders from this PR, because the original checkpoints have a good amount of downloads.
@lawrence-cj Please feel free to mark this as "Ready for review" and we can try to merge this by the weekend if all is well.
> I think the differences are in a tolerable range except for the Sana VAE. Still investigating the cause (visually, however, the quality looks similar to the original checkpoints). I suspect it is a missing upcast somewhere or the alternate implementation of the norms, but I won't be sure until I debug this layerwise.
@a-r-r-o-w What precision are you using? I ran into a problem once when using FP16. Inference with BF16 or FP32 works great.
> Please feel free to mark this as "Ready for review" and we can try to merge this by the weekend if all is well.
Cool cool. Let's do it!
> @a-r-r-o-w What precision are you using? I ran into a problem once when using FP16. Inference with BF16 or FP32 works great.
These reported numbers are from FP32 inference, but the BF16 values differ a bit more. Still investigating what's causing this. Differences like 0.0064 (for f128c512-mix, as an example) in the final tensor are not really a problem: since the 0 to 255 pixel range is spread over -1 to 1, it is equivalent to 0.0064 / 2 * 255.0 ~ 0.8, which when cast to uint8 will be a difference of either 0 or 1 in pixel values.
But the absmax difference is not a very good metric for comparing the results here. Something like SSIM is high between the original AE and the diffusers version, so I think everything is good.
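For reference, that SSIM check can be done along these lines (a sketch that assumes the PNGs saved by the scripts above and scikit-image, which is not a diffusers dependency):

```python
# SSIM between the original and diffusers reconstructions for one checkpoint;
# file names follow the scripts earlier in this thread.
import numpy as np
from PIL import Image
from skimage.metrics import structural_similarity

original = np.array(Image.open("original-output-dc-ae-f32c32-sana-1-0.png").convert("RGB"))
converted = np.array(Image.open("output-dcae-dc-ae-f32c32-sana-1-0.png").convert("RGB"))
print("SSIM:", structural_similarity(original, converted, channel_axis=-1))
```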
Cool. Then let's try to figure out later whether the results stay consistent when used together with Sana.
Loading from the original checkpoints is also supported now by performing the conversion on-the-fly!
test code
import numpy as np
import torch
from diffusers import AutoencoderDC
from diffusers.utils import load_image
from diffusers.utils.testing_utils import enable_full_determinism
from PIL import Image
enable_full_determinism()
model_names = [
    "dc-ae-f32c32-sana-1.0",
    "dc-ae-f32c32-in-1.0",
    "dc-ae-f32c32-mix-1.0",
    "dc-ae-f64c128-in-1.0",
    "dc-ae-f64c128-mix-1.0",
    "dc-ae-f128c512-in-1.0",
    "dc-ae-f128c512-mix-1.0",
]
@torch.no_grad()
def main(model_name: str):
print("Processing:", model_name, '\n')
ae = AutoencoderDC.from_single_file(
f"https://huggingface.co/mit-han-lab/{model_name}/model.safetensors",
original_config=f"https://huggingface.co/mit-han-lab/{model_name}/resolve/main/config.json"
)
ae = ae.to("cuda")
model_name = model_name.replace(".", "-")
image = load_image("inputs/astronaut.jpg").resize((512, 512))
image = np.array(image)
image = torch.from_numpy(image)
image = image / 127.5 - 1.0
image = image.unsqueeze(0).permute(0, 3, 1, 2).to("cuda")
encoded = ae.encode(image)
print("encoded:", encoded.shape)
decoded = ae.decode(encoded)
print("decoded:", decoded.shape)
output = decoded[0].permute(1, 2, 0)
output = (output + 1) / 2.0 * 255.0
output = output.clamp(0.0, 255.0)
output = output.detach().cpu().numpy().astype(np.uint8)
output = Image.fromarray(output)
output.save(f"output-dcae-{model_name}.png")
original_encoded = torch.load(f"/home/aryan/work/efficientvit/original-encoded-{model_name}.pt", weights_only=True)
original_decoded = torch.load(f"/home/aryan/work/efficientvit/original-decoded-{model_name}.pt", weights_only=True)
encoded_diff = encoded - original_encoded
decoded_diff = decoded - original_decoded
print(encoded_diff.abs().max(), encoded_diff.abs().sum())
print(decoded_diff.abs().max(), decoded_diff.abs().sum())
print()
for model_name in model_names:
main(model_name)
cc @DN6 for single file review
Most of the work was done by @chenjy2003 and @a-r-r-o-w. Thank you all for your hard and awesome work. @chenjy2003 Junyu, please check the overall PR and decide if it's ok with you.
Ohh my bad, I got confused. Thanks @chenjy2003!!!! Let us know if the PR is ok to merge :)
Thanks for the hard work! Let me double-check this PR.