
Video MiniGPT4

pixeli99 opened this issue 1 year ago · 8 comments

Firstly, thanks for your interesting work.

For MiniGPT-4, could this be realized directly with video embeddings? Something like:

```python
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_output = self.Qformer.bert(
    query_embeds=query_tokens,
    encoder_hidden_states=image_embeds,
    encoder_attention_mask=image_atts,
    return_dict=True,
)

# b = batch size, t = num_frames; the Q-Former sees frames as a flat batch:
# [bs * num_frames, 32, 768] -> [bs, num_frames, 32, 768] -> [bs, num_frames * 32, 768]
video_out = self.perceive(
    query_output.last_hidden_state.view(b, t, query_tokens.shape[-2], query_tokens.shape[-1])
).flatten(1, 2)
inputs_llama = self.llama_proj(video_out)
```
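To sanity-check the shape bookkeeping, here is a NumPy stand-in (random arrays in place of the real Q-Former outputs; the batch/frame sizes are made up) for the `view` + `flatten(1, 2)` step:

```python
import numpy as np

bs, num_frames, num_queries, dim = 2, 4, 32, 768

# the Q-Former runs per frame, so frames are stacked into the batch dim
last_hidden_state = np.random.randn(bs * num_frames, num_queries, dim)

# [bs * num_frames, 32, 768] -> [bs, num_frames, 32, 768]
per_frame = last_hidden_state.reshape(bs, num_frames, num_queries, dim)

# flatten(1, 2): concatenate every frame's query tokens along the sequence axis
video_tokens = per_frame.reshape(bs, num_frames * num_queries, dim)
print(video_tokens.shape)  # (2, 128, 768)
```

So without any resampling, the LLaMA input grows linearly with the number of frames, which is what motivates the resampler below.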

As for `self.perceive`, maybe a simple attention module will do, just like Flamingo's PerceiverResampler:

```python
import torch
from torch import nn
from einops import rearrange, repeat

# PerceiverAttention and FeedForward as in lucidrains' flamingo-pytorch
class PerceiverResampler(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head=64,
        heads=8,
        num_latents=64,
        num_media_embeds=4,
        ff_mult=4
    ):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        self.media_pos_emb = nn.Parameter(torch.randn(num_media_embeds, 1, dim))

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                FeedForward(dim=dim, mult=ff_mult)
            ]))

        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        # accept a single image/frame: [b, n, d] -> [b, 1, n, d]
        if x.ndim == 3:
            x = rearrange(x, 'b n d -> b 1 n d')

        times = x.shape[1]
        x = x + self.media_pos_emb[:times]

        # one set of learned latents, broadcast over batch and media dims
        latents = repeat(self.latents, 'n d -> b m n d', b=x.shape[0], m=x.shape[1])

        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        return self.norm(latents)
```
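The key property is that the latents attend to the frame tokens, so the output length is fixed by `num_latents` regardless of video length. A minimal NumPy sketch of that single cross-attention step (one head, no learned projections beyond random stand-in weights, dimensions made up for illustration):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def perceiver_cross_attention(x, latents, Wq, Wk, Wv):
    # queries come from the latents, keys/values from the media tokens x,
    # so the output has one row per latent, not per input token
    q = latents @ Wq                              # [m, d]
    k = x @ Wk                                    # [n, d]
    v = x @ Wv                                    # [n, d]
    scores = q @ k.T / np.sqrt(q.shape[-1])       # [m, n]
    return softmax(scores) @ v                    # [m, d]

rng = np.random.default_rng(0)
d, n_frame_tokens, m_latents = 16, 8 * 32, 4      # e.g. 8 frames x 32 query tokens
x = rng.standard_normal((n_frame_tokens, d))
latents = rng.standard_normal((m_latents, d))
Wq, Wk, Wv = (rng.standard_normal((d, d)) for _ in range(3))

out = perceiver_cross_attention(x, latents, Wq, Wk, Wv)
print(out.shape)  # (4, 16): set by the latents, independent of video length
```

With 8 frames the naive concatenation above would feed 256 tokens to the LLM; the resampler caps it at `num_latents`.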

I don't have enough GPUs to verify this idea. Maybe it is very naive; I just put it here and hope it inspires some interested friends.

pixeli99 · Apr 26 '23