
Readme Example with tensorrt: expected input[2, 4, 64, 64] to have 3 channels, but got 4 channels instead

Open Jannchie opened this issue 2 years ago • 0 comments

I am using the img2img example given in the README, but the TensorRT compilation does not succeed. I get the following error:

RuntimeError: Given groups=1, weight of size [64, 3, 3, 3], expected input[2, 4, 64, 64] to have 3 channels, but got 4 channels instead

I suspect the problem is in the VAE decoder compilation step. The Wrapper code seems to be fine, and it doesn't use accelerate_with_tensorrt, but I'm not very familiar with this part of the code and I can't see where the problem is.
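
To make the shapes in the error easier to read, here is a minimal sketch in plain Python. It is not StreamDiffusion code: `explain_conv_mismatch` is a hypothetical helper written for this illustration, and it only assumes PyTorch's usual Conv2d layouts (weight as [out_channels, in_channels, kH, kW] with groups=1, input as [batch, channels, height, width]). The "looks like a latent" and "expects RGB" flags are an interpretation of the shapes, not a confirmed diagnosis.

```python
def explain_conv_mismatch(weight_shape, input_shape):
    """Describe a Conv2d channel mismatch from the shapes in the error message.

    weight_shape: [out_channels, in_channels, kH, kW] (groups=1, as in the error)
    input_shape:  [batch, channels, height, width]
    """
    out_ch, expected_in, k_h, k_w = weight_shape
    batch, actual_in, height, width = input_shape
    return {
        "expected_in_channels": expected_in,
        "actual_in_channels": actual_in,
        "mismatch": expected_in != actual_in,
        # Stable Diffusion v1 latents have 4 channels at 1/8 of the image
        # resolution, so a 512x512 image corresponds to a 4x64x64 latent.
        "input_looks_like_sd_latent": actual_in == 4
        and (height, width) == (512 // 8, 512 // 8),
        # A 3-channel input is what an RGB image tensor would have.
        "layer_expects_rgb_image": expected_in == 3,
    }


# Shapes copied from the RuntimeError above.
report = explain_conv_mismatch([64, 3, 3, 3], [2, 4, 64, 64])
for key, value in report.items():
    print(f"{key}: {value}")
```

Read this way, the failing layer expects a 3-channel input, while the tensor it receives has the shape of a batch of two 4-channel 64x64 latents. That may mean a latent is being passed to a layer that takes images, but this has not been verified against the library code.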

Example code:

import torch
from diffusers import AutoencoderTiny, StableDiffusionPipeline
from diffusers.utils import load_image

from streamdiffusion import StreamDiffusion
from streamdiffusion.acceleration.tensorrt import accelerate_with_tensorrt
from streamdiffusion.image_utils import postprocess_image

# You can load any model using diffusers' StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained("KBlueLeaf/kohaku-v2.1").to(
    device=torch.device("cuda"),
    dtype=torch.float16,
)

# Wrap the pipeline in StreamDiffusion
stream = StreamDiffusion(
    pipe,
    t_index_list=[32, 45],
    torch_dtype=torch.float16,
)

# If the loaded model is not LCM, merge LCM
stream.load_lcm_lora()
stream.fuse_lora()
# Use Tiny VAE for further acceleration
stream.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd").to(device=pipe.device, dtype=pipe.dtype)
# Enable acceleration
stream = accelerate_with_tensorrt(
    stream, "engines", max_batch_size=2,
)


prompt = "1girl with dog hair, thick frame glasses"
# Prepare the stream
stream.prepare(prompt)

# Prepare image
init_image = load_image("assets/img2img_example.png").resize((512, 512))

# Warmup >= len(t_index_list) x frame_buffer_size
for _ in range(2):
    stream(init_image)

# Run the stream infinitely
while True:
    x_output = stream(init_image)
    postprocess_image(x_output, output_type="pil")[0].show()
    input_response = input("Press Enter to continue or type 'stop' to exit: ")
    if input_response == "stop":
        break

Jannchie · Dec 26 '23 02:12