
How to convert a VAE?

l3utterfly opened this issue 9 months ago · 8 comments

I'm trying to convert a Stable Diffusion 1.5 VAE from safetensors to OnnxStream. I'm using code similar to this: https://github.com/AeroX2/safetensor2onnx2txt/blob/main/safetensor2onnx.py

The conversion process runs fine and I get the OnnxStream txt files, but the converted model produces wrong results compared to your provided VAE decoder.

In particular, I noticed your VAE decoder has a 1D conv at the very start, which seems to be a quantisation step. How did you convert your VAE decoder?

I tested using your UNet model + your VAE decoder, and the result is correct. But using your UNet + MY VAE decoder produces an image whose colours seem messed up. So I'm wondering: what's the difference here? Are there any extra steps/networks built into your converted VAE decoder?

l3utterfly · Mar 10 '25

No, absolutely nothing special: torch.onnx.export + onnxsim_large_model + onnx2txt (in this order).
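For illustration, a minimal sketch of those three steps, assuming plain onnxsim's Python API stands in for onnxsim_large_model (file names here are placeholders):

```python
# Minimal sketch of the export pipeline, assuming onnxsim's Python API.
# onnxsim_large_model and onnx2txt are the OnnxStream-specific tools;
# the file names are placeholders.
import onnx
from onnxsim import simplify

model = onnx.load("vaed.onnx")          # step 1 output: torch.onnx.export
simplified, ok = simplify(model)        # step 2: graph simplification
assert ok, "onnxsim could not validate the simplified model"
onnx.save(simplified, "vaed_simplified.onnx")
# step 3: run onnx2txt on vaed_simplified.onnx to produce the
# model.txt files that OnnxStream loads
```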

Can you share the model you are trying to convert and especially the code that calls torch.onnx.export?

Vito

vitoplantamura · Mar 10 '25

This is the code that exports from a safetensors file (args.input is the path to the safetensors file):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_single_file(args.input)

latent = int(int(args.width) / 8)

# Export VAE decoder
dummy_input = (torch.randn(1, 4, latent, latent),)
input_names = ["latent_sample"]
output_names = ["sample"]
torch.onnx.export(pipe.vae.decoder, dummy_input, args.output, verbose=False,
                  input_names=input_names, output_names=output_names,
                  opset_version=14, do_constant_folding=True, export_params=True)
```

I'm using the base model: https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/tree/main

The file: v1-5-pruned-emaonly.safetensors

After exporting to ONNX using the script above, I run onnxsim_large and onnx2txt, in that order.

l3utterfly · Mar 11 '25

To add to the above, the same process works well for converting UNet models:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_single_file(args.input)

latent = int(int(args.width) / 8)

dummy_input = (torch.randn(1, 4, latent, latent), torch.randn(1), torch.randn(1, 77, 768))
input_names = ["sample", "timestep", "encoder_hidden_states"]
output_names = ["out_sample"]

torch.onnx.export(pipe.unet, dummy_input, args.output, verbose=False,
                  input_names=input_names, output_names=output_names,
                  opset_version=14, do_constant_folding=True, export_params=True)
```

I'm not sure what the difference is with the VAE models.

l3utterfly · Mar 11 '25

I found the code I originally used to export the VAE model:

```python
import torch
import torch.nn as nn
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

class VAED(nn.Module):
    def __init__(self, vae):
        super(VAED, self).__init__()
        self.vae = vae

    def forward(self, latents):
        self.vae.enable_slicing()
        self.vae.enable_tiling()
        # vae.decode runs the full decode path of the AutoencoderKL,
        # not just the decoder submodule
        image = self.vae.decode(latents, return_dict=False)[0]  # / self.vae.config.scaling_factor
        return image

with torch.no_grad():
    dummy_input = (torch.randn(1, 4, 128, 128),)
    input_names = ["latent_sample"]
    output_names = ["sample"]

    torch.onnx.export(VAED(pipe.vae), dummy_input, "/home/vito/Downloads/VAED/vaed.onnx", verbose=False,
                      input_names=input_names, output_names=output_names,
                      opset_version=14, do_constant_folding=True)
```

Unfortunately I have no way to test it at the moment. Probably the only thing to change is the size of dummy_input.
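Following that suggestion, a hedged sketch of the only change that should be needed, reusing the latent-size calculation (and args.width) from the earlier export script:

```python
# Hypothetical adaptation: derive the latent size from the target
# image width, as in the earlier export script (latent side = width / 8).
latent = int(int(args.width) / 8)                   # e.g. 512 -> 64
dummy_input = (torch.randn(1, 4, latent, latent),)  # replaces the fixed 128x128
```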

Vito

vitoplantamura · Mar 11 '25

This works! The model converted with your script differs from mine in the same way your provided decoder does (the added 1D conv layer at the top).

I'm not too familiar with the inner workings of diffusers; would you know why your script produces different results from mine? Is it the enable_slicing/tiling calls in the forward method? What is the significance of those?

l3utterfly · Mar 12 '25

I think the first thing to do to understand the reason is to compare the two model.txt files, specifically searching for different, missing, or extra operations at the beginning or end of the file.
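For instance, a minimal sketch of that comparison, printing the first and last operations of each model.txt (the file paths are hypothetical placeholders):

```python
# Minimal sketch of the suggested comparison: print the head and tail
# of two OnnxStream model.txt files. Paths are hypothetical.
def head_tail(path, n=5):
    with open(path) as f:
        lines = [line.rstrip() for line in f if line.strip()]
    return lines[:n], lines[-n:]

mine_head, mine_tail = head_tail("my_vae/model.txt")
ref_head, ref_tail = head_tail("provided_vae/model.txt")
print("mine, first ops:", *mine_head, sep="\n  ")
print("ref,  first ops:", *ref_head, sep="\n  ")
print("mine, last ops:", *mine_tail, sep="\n  ")
print("ref,  last ops:", *ref_tail, sep="\n  ")
```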

As for enable_slicing/tiling, I don't remember why I added those two calls. To find out, it should be enough to study the HF Diffusers implementation (unfortunately I don't have the time to do that right now).

Vito

vitoplantamura · Mar 13 '25

@l3utterfly, hello, would you mind if I added support for your SD 1.5 tiled decoder? (As in this model: https://huggingface.co/l3utterfly/sd-onnxstream/blob/main/dreamshaper_8-layla5_4_0.zip.) It's very cool because it reduces memory consumption 2 or 3 times, depending on the --rpi option, allowing the fp16 decoder to run on 1 GB devices.

With best regards.

vmobilis · Apr 20 '25

@vmobilis Sure! Let me know if there is anything I can do to help!

l3utterfly · Apr 21 '25