
Question on "fake_context_parallel_forward" in diffusers implementation

Open wu-qing-157 opened this issue 1 year ago • 3 comments

Hi, thanks for releasing the powerful text-to-video model. I noticed something odd in the diffusers implementation of AutoEncoderKLCogvideoX.

    def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
        kernel_size = self.time_kernel_size
        if kernel_size > 1:
            cached_inputs = (
                [self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
            )
            inputs = torch.cat(cached_inputs + [inputs], dim=2)
        return inputs

    def _clear_fake_context_parallel_cache(self):
        del self.conv_cache
        self.conv_cache = None

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        inputs = self.fake_context_parallel_forward(inputs)

        self._clear_fake_context_parallel_cache()
        # Note: we could move these to the CPU for a lower maximum memory usage, but it's only a few
        # hundred megabytes, so let's not do it for now
        self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()

        padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
        inputs = F.pad(inputs, padding_2d, mode="constant", value=0)

        output = self.conv(inputs)
        return output

If I call AutoEncoderKLCogvideoX.encode multiple times to encode multiple videos, the conv_cache is not cleared between the encode calls (this even causes an error when the batch sizes don't match!). If I understand correctly, this behavior leaks some information from the previous video into the current one. I'm wondering what the purpose of this conv_cache is, and whether the current diffusers implementation handles it incorrectly.
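To make the issue concrete, here is a minimal torch-free sketch of the caching mechanism described above. A causal temporal conv with kernel size k needs k-1 extra frames in front: on the first call they are copies of the first frame; afterwards the last k-1 input frames are cached and reused, which is exactly what leaks state between unrelated videos when the cache is not cleared. All names here are illustrative, not the actual diffusers API.

```python
# Hypothetical stand-in for the conv_cache logic, operating on plain lists of
# "frames" instead of tensors, to show the cross-call leak.
class FakeContextParallelCache:
    def __init__(self, kernel_size: int):
        self.kernel_size = kernel_size
        self.conv_cache = None  # last kernel_size - 1 frames from the previous call

    def pad(self, frames: list) -> list:
        k = self.kernel_size
        if k > 1:
            if self.conv_cache is not None:
                front = self.conv_cache          # reuses the previous call's frames!
            else:
                front = [frames[0]] * (k - 1)    # replicate-pad with the first frame
            frames = front + frames
            # cache the tail for a potential next (chunked) call
            self.conv_cache = frames[-(k - 1):]
        return frames

    def clear(self):
        self.conv_cache = None


cache = FakeContextParallelCache(kernel_size=3)
video_a = ["a0", "a1", "a2"]
video_b = ["b0", "b1", "b2"]

first = cache.pad(video_a)  # ['a0', 'a0', 'a0', 'a1', 'a2'] — replicate padding
leaky = cache.pad(video_b)  # ['a1', 'a2', 'b0', 'b1', 'b2'] — video_a frames leak in

cache.clear()
clean = cache.pad(video_b)  # ['b0', 'b0', 'b0', 'b1', 'b2'] — correct after clearing
```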

wu-qing-157 avatar Sep 06 '24 10:09 wu-qing-157

For the moment, you need to call _clear_fake_context_parallel_cache() to clear the cache after each encode/decode call.

a-r-r-o-w avatar Sep 07 '24 10:09 a-r-r-o-w

What's this cache aiming for? Does it mean I can call the encode multiple times (split on the n_frame dimension) to lower maximum GPU memory requirements while getting the same results?

wu-qing-157 avatar Sep 10 '24 22:09 wu-qing-157

> What's this cache aiming for? Does it mean I can call the encode multiple times (split on the n_frame dimension) to lower maximum GPU memory requirements while getting the same results?

I guess this is meant to implement a causal conv. The 3D conv kernel is 3x3x3, so if you want to keep the temporal size you need to pad it; the code uses the cache to pad before the first frame.
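If that reading is right, it also answers the earlier question: with a causal cache, processing a video in temporal chunks should give the same result as processing it in one call. A minimal sketch (all names hypothetical, with a toy "conv" that just sums each window of frames):

```python
# Toy causal temporal conv: each output frame depends only on itself and the
# KERNEL - 1 frames before it, so a cache carried across chunked calls
# reproduces the single-pass result exactly.
KERNEL = 3

def causal_conv(frames, cache):
    # replicate-pad on the very first call, otherwise prepend the cached frames
    front = cache if cache is not None else [frames[0]] * (KERNEL - 1)
    padded = front + frames
    out = [sum(padded[i:i + KERNEL]) for i in range(len(frames))]
    new_cache = padded[-(KERNEL - 1):]
    return out, new_cache

frames = [1, 2, 3, 4, 5, 6]

# one-shot pass over all frames
full, _ = causal_conv(frames, None)

# two chunks, carrying the cache across calls
out1, cache = causal_conv(frames[:3], None)
out2, _ = causal_conv(frames[3:], cache)

assert full == out1 + out2  # chunked processing matches the full pass
```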

DidiD1 avatar Oct 30 '24 08:10 DidiD1