stable-diffusion-webui-tensorrt
SDXL Support
Hello
Is SDXL support planned, as SDXL is slow on most computers?
Kind regards, Timon Käch
You can already try it, but you need to modify the code.
Speed went from 6.67 it/s up to 12.10 it/s at width 960, height 1024, 21 steps.
1. Export the UNet to ONNX with the new method:
```python
import os

import torch

from modules import sd_hijack, sd_unet, shared, devices


def export_current_unet_to_onnx(filename, opset_version=17):
    # Dummy inputs for tracing. SDXL uses a 77x2048 text context and takes
    # an extra vector conditioning `y` of size 2816.
    x = torch.randn(1, 4, 16, 16).to(devices.device, devices.dtype)
    timesteps = torch.zeros((1,)).to(devices.device, devices.dtype) + 500
    context = torch.randn(1, 77, 2048).to(devices.device, devices.dtype)
    y = torch.randn(1, 2816).to(devices.device, devices.dtype)

    # Gradient checkpointing gets in the way of ONNX tracing, so disable it
    # on every submodule.
    def disable_checkpoint(self):
        if getattr(self, 'use_checkpoint', False):
            self.use_checkpoint = False
        if getattr(self, 'checkpoint', False):
            self.checkpoint = False

    shared.sd_model.model.diffusion_model.apply(disable_checkpoint)

    sd_unet.apply_unet("None")
    sd_hijack.model_hijack.apply_optimizations('None')

    os.makedirs(os.path.dirname(filename), exist_ok=True)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    shared.sd_model.model.diffusion_model = shared.sd_model.model.diffusion_model.to(device)

    with devices.autocast():
        torch.onnx.export(
            shared.sd_model.model.diffusion_model,
            (x, timesteps, context, y),
            filename,
            export_params=True,
            opset_version=opset_version,
            do_constant_folding=True,
            input_names=['x', 'timesteps', 'context', 'y'],
            output_names=['output'],
            dynamic_axes={
                'x': {0: 'batch_size', 2: 'height', 3: 'width'},
                'timesteps': {0: 'batch_size'},
                'context': {0: 'batch_size', 1: 'sequence_length'},
                'y': {0: 'batch_size'},
                'output': {0: 'batch_size'},
            },
        )

    sd_hijack.model_hijack.apply_optimizations()
    sd_unet.apply_unet()
```
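A quick way to sanity-check the export is to confirm the graph really has the SDXL inputs, including the new `y`. This assumes the `onnx` package is available in the venv, which isn't guaranteed in a stock webui install:

```python
import onnx

model = onnx.load("models/Unet-onnx/ttt.onnx")  # path matching the trtexec command in step 6
print([i.name for i in model.graph.input])
# expected: ['x', 'timesteps', 'context', 'y']
```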
2. Hijack the UNetModel forward in /modules/sd_hijack.py so the SDXL code path uses UNetModel_forwardy:
```python
...
if not hasattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'):
    ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward

ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward

# Same save-and-replace for the sgm package (the SDXL code path), but
# pointing at the y-aware forward:
if not hasattr(sgm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'):
    sgm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = sgm.modules.diffusionmodules.openaimodel.UNetModel.forward

sgm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forwardy


def undo_hijack(self, m):
    if type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
        m.cond_stage_model = m.cond_stage_model.wrapped
    elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
        m.cond_stage_model = m.cond_stage_model.wrapped

        model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
        if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
            model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
    elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
        m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
        m.cond_stage_model = m.cond_stage_model.wrapped

    undo_optimizations()
    undo_weighted_forward(m)

    self.apply_circular(False)
    self.layers = None
    self.clip = None

    # Restore both saved originals on undo.
    ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui
    sgm.modules.diffusionmodules.openaimodel.UNetModel.forward = sgm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui
...
```
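For anyone unfamiliar with how these hijacks work, here is a minimal, self-contained sketch of the save-once/patch/restore pattern used above. The class and function names are illustrative only, not from the webui codebase:

```python
class UNetModel:
    def forward(self, x):
        return f"original({x})"

def patched_forward(self, x):
    return f"patched({x})"

# Save the original exactly once, guarded by hasattr, so a repeated hijack
# cannot overwrite the saved copy with an already-patched forward.
if not hasattr(UNetModel, 'copy_of_forward'):
    UNetModel.copy_of_forward = UNetModel.forward
UNetModel.forward = patched_forward

print(UNetModel().forward("x"))   # patched(x)

# undo_hijack restores the saved original the same way:
UNetModel.forward = UNetModel.copy_of_forward
print(UNetModel().forward("x"))   # original(x)
```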
3. Add the y-aware forward dispatch in modules/sd_unet.py:
```python
...
class SdUnet(torch.nn.Module):
    def forward(self, x, timesteps, context, *args, **kwargs):
        raise NotImplementedError()

    def activate(self):
        pass

    def deactivate(self):
        pass


def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
    if current_unet is not None:
        return current_unet.forward(x, timesteps, context, *args, **kwargs)

    return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)


# The same dispatch method works here, extended with SDXL's extra `y`:
def UNetModel_forwardy(self, x, timesteps=None, context=None, y=None, **kwargs):
    if current_unet is not None:
        return current_unet.forward(x, timesteps, context, y, **kwargs)

    return sgm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, y, **kwargs)
...
```
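To make the calling contract concrete, here is a hypothetical stand-in for the SdUnet implementation that `current_unet` points at; in the real extension that role is played by the TensorRT unet class in trt.py (step 4):

```python
import torch

class DummySdxlUnet(torch.nn.Module):
    def forward(self, x, timesteps, context, y=None, **kwargs):
        # A real implementation runs the TensorRT engine; this just returns
        # a tensor of the right shape to show the expected signature.
        return torch.zeros_like(x)

current_unet = DummySdxlUnet()
out = current_unet.forward(
    torch.randn(1, 4, 128, 120),   # latents for a 960x1024 image (h/8 x w/8)
    torch.zeros(1),                # timesteps
    torch.randn(1, 77, 2048),      # SDXL text context
    torch.randn(1, 2816),          # SDXL vector conditioning y
)
print(out.shape)  # torch.Size([1, 4, 128, 120])
```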
4. Update the forward in extensions/stable-diffusion-webui-tensorrt/scripts/trt.py:
```python
    # inside the SdUnet subclass in trt.py
    def forward(self, x, timesteps, context, *args, **kwargs):
        a, b, c, d = x.shape
        # print(x.shape, timesteps.shape, context.shape)
        if a == 1:
            self.infer({"x": x, "timesteps": timesteps, "context": context})
            return self.buffers["output"].to(dtype=x.dtype, device=devices.device)
        else:
            # The engine profile is built with batch size 1, so split larger
            # batches into single items and concatenate the outputs.
            images = []
            for i in range(a):
                with contextlib.suppress(Exception):
                    s = x[i].unsqueeze(0)
                    t = timesteps[i].unsqueeze(0)
                    c = context[i].unsqueeze(0)
                    if args is not None and len(args) != 0:
                        # SDXL path: the extra `y` conditioning arrives in args[0]
                        y = args[0][i].unsqueeze(0)
                        self.infer({"x": s, "timesteps": t, "context": c, "y": y})
                    else:
                        self.infer({"x": s, "timesteps": t, "context": c})
                    images.append(self.buffers["output"].to(dtype=x.dtype, device=devices.device))
            return torch.cat(images, dim=0)
```
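Why the per-item loop matters: the trtexec command in step 6 builds the engine with the batch dimension pinned to 1 in both minShapes and maxShapes, so batches larger than 1 must be split and the results concatenated. A stand-alone sketch of just that splitting logic, with a stand-in function in place of the engine call:

```python
import torch

def infer_single(x, timesteps, context, y=None):
    return x * 2  # stand-in for copying out self.buffers["output"]

def forward_batched(x, timesteps, context, y=None):
    outputs = []
    for i in range(x.shape[0]):
        outputs.append(infer_single(
            x[i].unsqueeze(0),
            timesteps[i].unsqueeze(0),
            context[i].unsqueeze(0),
            y[i].unsqueeze(0) if y is not None else None,
        ))
    return torch.cat(outputs, dim=0)

out = forward_batched(torch.randn(2, 4, 128, 120), torch.zeros(2), torch.randn(2, 77, 2048))
print(out.shape)  # torch.Size([2, 4, 128, 120])
```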
5. Fix the "found two devices" problems, as shown in the sketch below.
You need to find them one by one and add model.to(devices.device), or the easy way is to use model.cuda(). There are maybe 3-4 places that need modifying.
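The error being hunted looks like this, and the fix is always the same shape (a minimal sketch, assuming a CUDA device is present):

```python
import torch

model = torch.nn.Linear(4, 4)          # modules are created on CPU by default
x = torch.randn(1, 4, device='cuda')

# model(x) here raises:
#   RuntimeError: Expected all tensors to be on the same device,
#   but found at least two devices, cuda:0 and cpu!
model = model.to('cuda')               # or model.cuda()
print(model(x).device)                 # cuda:0
```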
6. Convert the ONNX model to a TensorRT engine. My command:
"{full_path}/trtexec" --onnx="{full_path}/models/Unet-onnx/ttt.onnx" --saveEngine="{full_path}/models/Unet-trt/ttt.trt" --minShapes=x:1x4x64x64,context:1x77x2048,timesteps:1 --maxShapes=x:1x4x128x120,context:1x77x2048,timesteps:1 --fp16
Hey, thank you so much for the fast answer. Will try it out soon. Is 1024x1024 not possible? Only 960x1024?
I can't say for sure. maxShapes=x:1x4x128x120 is the limit; you can't go above that size. If you use maxShapes=x:1x4x128x128, trtexec pops up an error. Since the latent is 1/8 of the pixel resolution, 128x120 corresponds to 1024x960.
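For reference, the shape cap translates directly into the maximum image size:

```python
# x latent shape is (batch, 4, height/8, width/8)
for h, w in [(128, 120), (128, 128)]:
    print(f"latent {h}x{w} -> image {h * 8}x{w * 8}")
# latent 128x120 -> image 1024x960   (fits the maxShapes above)
# latent 128x128 -> image 1024x1024  (trtexec errors when building with this maxShape)
```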