
VAE decoding becomes very slow when the batch_number in the decode method is large.


Expected Behavior

When a large batch of latent images is sent to the VAE for decoding, the speed should be normal: about 0.1 seconds per image (at the 512-pixel level).

Actual Behavior

When there is a large amount of GPU VRAM (48GB), batch_number becomes large (e.g. 21) and decoding becomes very slow: about 0.8 seconds per image (at the 512-pixel level).

Steps to Reproduce

To find the slow step, I added some log statements to the decode method in sd.py:

    def decode(self, samples_in):
        try:
            # Estimate the VRAM needed to decode one latent of this shape.
            memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
            model_management.load_models_gpu([self.patcher], memory_required=memory_used)
            free_memory = model_management.get_free_memory(self.device)
            logging.info("after free_memory")
            # batch_number grows with free VRAM and has no upper bound.
            batch_number = int(free_memory / memory_used)
            batch_number = max(1, batch_number)

            pixel_samples = torch.empty((samples_in.shape[0], self.output_channels) + tuple(map(lambda a: a * self.upscale_ratio, samples_in.shape[2:])), device=self.output_device)
            logging.info("after create empty pixel_samples")
            # Decode batch_number latents per forward pass.
            for x in range(0, samples_in.shape[0], batch_number):
                logging.info(f"x:{x} start decode one image.")
                samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
                logging.info(f"x:{x} finish data to decode on batch_number:{batch_number}")
                pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
                logging.info(f"x:{x} finish decode one image.")
        except model_management.OOM_EXCEPTION as e:
            logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
            if len(samples_in.shape) == 3:
                pixel_samples = self.decode_tiled_1d(samples_in)
            else:
                pixel_samples = self.decode_tiled_(samples_in)
        logging.info(f" before move to device{self.output_device}.")
        pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
        logging.info(f" after move to device{self.output_device}.")
        return pixel_samples

After adding these logs, I found the problem: the timings below show that one large batch decodes far more slowly per image than several small ones.

Debug Logs

On an L20 GPU with 48GB of VRAM:
When batch_number is 21:

2024-10-10 10:09:09,541 - execution.py:260 - DEBUG - VAEDecode node got all inputs and start >>>
2024-10-10 10:09:09,542 - sd.py:322 - INFO - after free_memory
2024-10-10 10:09:09,542 - sd.py:327 - INFO - after create empty pixel_samples
2024-10-10 10:09:09,542 - sd.py:329 - INFO - x:0 start decode one image.
2024-10-10 10:09:09,544 - sd.py:331 - INFO - x:0 finish data to decode on batch_number:21
2024-10-10 10:09:23,952 - sd.py:333 - INFO - x:0 finish decode one image.
2024-10-10 10:09:23,952 - sd.py:340 - INFO -  before move to devicecpu.
2024-10-10 10:09:23,952 - sd.py:342 - INFO -  after move to devicecpu.
2024-10-10 10:09:23,952 - execution.py:411 - DEBUG - VAEDecode recursived all inputs and executed ...

With batch_number set to 21, it takes 14.4 seconds to decode 21 images, averaging 0.7 seconds per image, which is very slow.

When batch_number is forced to 5:

2024-10-10 10:24:55,147 - sd.py:329 - INFO - x:275 start decode one image.
2024-10-10 10:24:55,147 - sd.py:331 - INFO - x:275 finish data to decode on batch_number:5
2024-10-10 10:24:55,538 - sd.py:333 - INFO - x:275 finish decode one image.
2024-10-10 10:24:55,538 - sd.py:329 - INFO - x:280 start decode one image.
2024-10-10 10:24:55,538 - sd.py:331 - INFO - x:280 finish data to decode on batch_number:5
2024-10-10 10:24:55,928 - sd.py:333 - INFO - x:280 finish decode one image.
2024-10-10 10:24:55,929 - sd.py:329 - INFO - x:285 start decode one image.

With batch_number set to 5, it takes 0.4 seconds for 5 images, averaging 0.08 seconds per image.
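
To reproduce the measurement without the extra logging, something like the following can be used. This is a minimal sketch, not ComfyUI code: it assumes a loaded vae object with the vae_dtype, device, and first_stage_model attributes used in the snippet above, and calls torch.cuda.synchronize so the GPU work is actually included in the timing:

    import time

    import torch

    @torch.no_grad()
    def time_decode(vae, latents, batch):
        # Hypothetical helper (not part of ComfyUI): decode in chunks of
        # `batch` latents, using the same slicing as decode() above, and
        # report the average wall-clock time per image.
        torch.cuda.synchronize()
        start = time.time()
        for x in range(0, latents.shape[0], batch):
            samples = latents[x:x + batch].to(vae.vae_dtype).to(vae.device)
            vae.first_stage_model.decode(samples)
        torch.cuda.synchronize()
        print(f"batch {batch}: {(time.time() - start) / latents.shape[0]:.3f} s per image")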

On a 4090 GPU with 24GB of VRAM, forcing batch_number to 20 also slows down decoding, which suggests that batch_number needs an upper limit.
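
A minimal sketch of such a cap, replacing the two batch_number lines in decode() above. MAX_DECODE_BATCH is a hypothetical tuning constant, not an existing ComfyUI setting, and the best value likely depends on the GPU:

    # Cap the batch so plentiful free VRAM no longer produces huge,
    # slow forward passes. MAX_DECODE_BATCH is an assumed constant.
    MAX_DECODE_BATCH = 8

    batch_number = int(free_memory / memory_used)
    batch_number = max(1, min(batch_number, MAX_DECODE_BATCH))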



### Other

_No response_

zhangp365, Oct 11 '24