ComfyUI-DiffusersStableCascade
High VRAM use - ComfyUI won't unload models
Because the diffusers pipelines are loaded directly, ComfyUI's model manager doesn't know about them and can't unload them. There are probably better ways to deal with this, and once ComfyUI adds a native Stable Cascade implementation it shouldn't matter. But to run this on my 12GB GPU, I had to unload each model between phases. There's likely a cleaner approach; I'm still pretty new to ComfyUI development, so this was my solution for now. Since loading/unloading between phases takes some time anyway, this worked well for running batches, and because the Stage C latents are small I could generate 3-4 images at a time without trouble.
```python
import gc

import torch
from torchvision.transforms import ToTensor
from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline

import comfy.model_management


# Method of the sampler node class, shown here with the imports it needs.
def process(self, width, height, seed, steps, guidance_scale, prompt,
            negative_prompt, batch_size, decoder_steps, image=None):
    # Free anything ComfyUI itself has loaded before we start.
    comfy.model_management.unload_all_models()
    torch.manual_seed(seed)
    device = comfy.model_management.get_torch_device()

    # Load the prior (Stage C) on demand.
    if getattr(self, 'prior', None) is None:
        self.prior = StableCascadePriorPipeline.from_pretrained(
            "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16
        ).to(device)
    prior_output = self.prior(
        image=image,
        prompt=prompt,
        height=height,
        width=width,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_images_per_prompt=batch_size,
        num_inference_steps=steps,
    )

    # Unload the prior before the decoder loads, to fit in 12GB.
    self.prior = None
    gc.collect()

    # Load the decoder (Stage B) on demand.
    if getattr(self, 'decoder', None) is None:
        self.decoder = StableCascadeDecoderPipeline.from_pretrained(
            "stabilityai/stable-cascade", torch_dtype=torch.float16
        ).to(device)
    decoder_output = self.decoder(
        image_embeddings=prior_output.image_embeddings.half(),
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=0.0,
        output_type="pil",
        num_inference_steps=decoder_steps,
    ).images

    # Unload the decoder as well once the PIL images are out.
    self.decoder = None
    gc.collect()

    # Stack into ComfyUI's IMAGE layout: [batch, height, width, channel], floats in [0, 1].
    tensors = [ToTensor()(img) for img in decoder_output]
    batch_tensor = torch.stack(tensors).permute(0, 2, 3, 1).cpu()
    return (batch_tensor, image)
```
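One refinement worth considering (my assumption, not part of the original code): gc.collect() only drops the Python references, and PyTorch keeps the freed VRAM in its caching allocator, so tools like nvidia-smi will still report it as used. Emptying the CUDA cache after collection hands it back:

```python
import gc
import torch

def free_vram():
    # Drop unreferenced tensors first, then release the blocks that
    # PyTorch's caching allocator is still holding. ComfyUI wraps a
    # similar pattern in comfy.model_management.soft_empty_cache(),
    # which could be used instead to stay inside its API.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
```

Calling this where the method above calls gc.collect() should make the freed memory visible immediately between the prior and decoder phases.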
Where should the unloading logic go? Maybe an "unload models" node would be useful.
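A minimal sketch of what that node could look like, assuming ComfyUI's standard custom-node conventions (INPUT_TYPES / RETURN_TYPES / NODE_CLASS_MAPPINGS); the class name, category, and pass-through IMAGE socket are my own choices, not anything from this repo:

```python
import gc
import torch
import comfy.model_management

class UnloadModels:
    @classmethod
    def INPUT_TYPES(cls):
        # Take and return an IMAGE so the node can sit inline in a graph,
        # forcing the unload to happen at that point in execution order.
        return {"required": {"image": ("IMAGE",)}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "unload"
    CATEGORY = "utils"

    def unload(self, image):
        # Drop everything ComfyUI's own model manager has loaded...
        comfy.model_management.unload_all_models()
        # ...then collect stray references and release cached VRAM.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return (image,)

NODE_CLASS_MAPPINGS = {"UnloadModels": UnloadModels}
```

The caveat from the top of this issue still applies, though: unload_all_models() only frees models ComfyUI's manager knows about, so pipelines loaded directly through diffusers would still need to be released by the node that holds them.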