Question on "fake_context_parallel_forward" in diffusers implementation
Hi, thanks for releasing this powerful text-to-video model. I noticed something strange in the diffusers implementation of AutoEncoderKLCogvideoX:
```python
def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
    kernel_size = self.time_kernel_size
    if kernel_size > 1:
        cached_inputs = (
            [self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
        )
        inputs = torch.cat(cached_inputs + [inputs], dim=2)
    return inputs

def _clear_fake_context_parallel_cache(self):
    del self.conv_cache
    self.conv_cache = None

def forward(self, inputs: torch.Tensor) -> torch.Tensor:
    inputs = self.fake_context_parallel_forward(inputs)
    self._clear_fake_context_parallel_cache()
    # Note: we could move these to the cpu for a lower maximum memory usage but its only a few
    # hundred megabytes and so let's not do it for now
    self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
    padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
    inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
    output = self.conv(inputs)
    return output
```
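To make the caching behavior concrete, here is a minimal pure-Python sketch of the logic above, using a 1-D "video" (one value per frame) and a uniform averaging kernel in place of the learned causal 3-D convolution. The names (`CachedConv`, `causal_conv`) are illustrative, not diffusers API: the point is that the cache lets chunked processing along the frame dimension reproduce a single full-length pass.

```python
KERNEL_SIZE = 3  # stands in for self.time_kernel_size above


def causal_conv(frames):
    """Valid 1-D convolution over time with a uniform averaging kernel."""
    return [sum(frames[i:i + KERNEL_SIZE]) / KERNEL_SIZE
            for i in range(len(frames) - KERNEL_SIZE + 1)]


class CachedConv:
    """Pure-Python mimic of the fake_context_parallel_forward caching."""

    def __init__(self):
        self.conv_cache = None

    def forward(self, frames):
        if self.conv_cache is None:
            # First chunk: replicate the first frame kernel_size - 1 times.
            padded = [frames[0]] * (KERNEL_SIZE - 1) + frames
        else:
            # Later chunks: the cached trailing frames provide the context.
            padded = self.conv_cache + frames
        # Cache the trailing frames for a possible next chunk.
        self.conv_cache = padded[-(KERNEL_SIZE - 1):]
        return causal_conv(padded)


video = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

# Full pass in one call.
full = CachedConv().forward(video)

# Same video split into two chunks; the cache bridges the chunk boundary.
conv = CachedConv()
chunked = conv.forward(video[:3]) + conv.forward(video[3:])

assert chunked == full  # chunked processing matches the full pass
```

Note that after either run `conv_cache` still holds the last frames of the video just processed; feeding an unrelated video next without clearing it is exactly the cross-video leak described below.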
If I call AutoEncoderKLCogvideoX.encode multiple times to encode multiple videos, the conv_cache is not cleared between the encode calls (it even raises an error when the batch sizes don't match!).
If I understand correctly, this behavior leaks some information from the previous video into the current video. I'm wondering what this conv_cache is for, and whether the current diffusers implementation handles it incorrectly.
For the moment, you need to call this method to clear the cache after each encode/decode call.
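A hedged sketch of that workaround: walk the model's submodules and clear every cache before encoding an unrelated video. The helper `clear_conv_caches` and the `_Fake*` stand-in classes below are illustrative, not diffusers API; with the real model you would pass the AutoEncoderKLCogvideoX instance, whose causal-conv submodules define `_clear_fake_context_parallel_cache` as shown in the snippet above.

```python
def clear_conv_caches(model) -> int:
    """Call _clear_fake_context_parallel_cache on every submodule that has it."""
    cleared = 0
    for module in model.modules():
        if hasattr(module, "_clear_fake_context_parallel_cache"):
            module._clear_fake_context_parallel_cache()
            cleared += 1
    return cleared


class _FakeCausalConv:
    """Stand-in for the causal conv layer shown above (illustration only)."""

    def __init__(self):
        self.conv_cache = "leftover frames from the previous video"

    def _clear_fake_context_parallel_cache(self):
        del self.conv_cache
        self.conv_cache = None


class _FakeModel:
    """Stand-in exposing a modules() iterator like torch.nn.Module."""

    def __init__(self):
        self.convs = [_FakeCausalConv(), _FakeCausalConv()]

    def modules(self):
        return [self] + self.convs


model = _FakeModel()
n_cleared = clear_conv_caches(model)  # call this between independent videos
caches = [c.conv_cache for c in model.convs]
```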
What is this cache aiming for? Does it mean I can call encode multiple times (split along the n_frame dimension) to lower the maximum GPU memory requirement while getting the same results?
I guess this is meant to implement a 'causal conv': the 3D conv kernel is 3x3x3, so if you want to keep the temporal size you have to pad it. The code uses the cache as the padding before the first frame.