StreamDiffusion
Readme Example with tensorrt: expected input[2, 4, 64, 64] to have 3 channels, but got 4 channels instead
I am using the img2img example given in the README, but the TensorRT compilation does not succeed. I get the following error:
RuntimeError: Given groups=1, weight of size [64, 3, 3, 3], expected input[2, 4, 64, 64] to have 3 channels, but got 4 channels instead
I suspect the problem is in the VAE decoder compilation step. The Wrapper's code seems to be fine, but it doesn't use accelerate_with_tensorrt, and I'm not familiar enough with this part of the code to see where the problem is.
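To illustrate why I suspect the VAE stage: a conv with a weight of size [64, 3, 3, 3] expects 3 input channels, i.e. image space, so it looks as if 4-channel latents are being routed into an encoder-style graph instead of the decoder. A rough standalone shape check (plain diffusers/PyTorch, separate from the script below; the 512x512 / batch-2 sizes are just assumptions from my setup):

import torch
from diffusers import AutoencoderTiny

vae = AutoencoderTiny.from_pretrained("madebyollin/taesd").to("cuda", dtype=torch.float16)

# The TAESD encoder expects 3-channel images ...
image = torch.randn(2, 3, 512, 512, device="cuda", dtype=torch.float16)
latents = vae.encode(image).latents        # -> [2, 4, 64, 64]

# ... and the decoder expects 4-channel latents.
image_out = vae.decode(latents).sample     # -> [2, 3, 512, 512]
print(latents.shape, image_out.shape)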
example code:
import torch
from diffusers import AutoencoderTiny, StableDiffusionPipeline
from diffusers.utils import load_image
from streamdiffusion import StreamDiffusion
from streamdiffusion.acceleration.tensorrt import accelerate_with_tensorrt
from streamdiffusion.image_utils import postprocess_image
# You can load any model using diffusers' StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained("KBlueLeaf/kohaku-v2.1").to(
    device=torch.device("cuda"),
    dtype=torch.float16,
)
# Wrap the pipeline in StreamDiffusion
stream = StreamDiffusion(
    pipe,
    t_index_list=[32, 45],
    torch_dtype=torch.float16,
)
# If the loaded model is not LCM, merge LCM
stream.load_lcm_lora()
stream.fuse_lora()
# Use Tiny VAE for further acceleration
stream.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd").to(device=pipe.device, dtype=pipe.dtype)
# Enable acceleration
stream = accelerate_with_tensorrt(
    stream, "engines", max_batch_size=2,
)
prompt = "1girl with dog hair, thick frame glasses"
# Prepare the stream
stream.prepare(prompt)
# Prepare image
init_image = load_image("assets/img2img_example.png").resize((512, 512))
# Warmup >= len(t_index_list) x frame_buffer_size
for _ in range(2):
    stream(init_image)
# Run the stream infinitely
while True:
    x_output = stream(init_image)
    postprocess_image(x_output, output_type="pil")[0].show()
    input_response = input("Press Enter to continue or type 'stop' to exit: ")
    if input_response == "stop":
        break
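In case it helps with bisecting, these are the only variations I can think of to isolate the failing part, keeping the rest of the script identical (untested sketch of the relevant lines):

# Variation A: keep the pipeline's original VAE (skip the AutoencoderTiny swap above),
# then compile with TensorRT as before, to see if the error is specific to the TAESD engine build.
stream = accelerate_with_tensorrt(stream, "engines", max_batch_size=2)

# Variation B: keep the Tiny VAE but skip the TensorRT step entirely,
# to confirm the rest of the pipeline runs without it.
# stream = accelerate_with_tensorrt(stream, "engines", max_batch_size=2)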