
Question about the implementation

Open · csgeekhuang opened this issue 4 months ago · 4 comments

I want to ask: when the next row begins to generate, why do we only need the last token of the previous row? What about the other tokens in that row that have not been generated yet? Do we use zero? [Image]

csgeekhuang · Aug 14 '25 09:08

This is exactly the motivation of our paper. We do not need these tokens due to the strong locality of images. As for the last token in the previous row, it is required to initialize the generation of the new row, following a next-token prediction paradigm.

ThisisBillhe · Aug 14 '25 09:08
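
To make the locality argument concrete, here is a minimal, hypothetical sketch (not the ZipAR source; the grid size, window size, and function name are assumptions) of a locality-restricted attention pattern over a raster-order token grid: a token at (row, col) attends only to the already-generated prefix of its own row and to a small neighborhood of columns in the row above.

    import torch

    def local_attention_mask(rows: int, cols: int, window: int) -> torch.Tensor:
        """Boolean mask over raster-order positions; True means 'may attend'.

        Illustrative only: token (r, c) sees the generated prefix of its own row
        plus a +/- `window` neighborhood of columns in row r - 1.
        """
        n = rows * cols
        mask = torch.zeros(n, n, dtype=torch.bool)
        for r in range(rows):
            for c in range(cols):
                q = r * cols + c
                # causal prefix within the current row
                mask[q, r * cols : q + 1] = True
                # local neighborhood in the previous row
                if r > 0:
                    lo, hi = max(0, c - window), min(cols, c + window + 1)
                    mask[q, (r - 1) * cols + lo : (r - 1) * cols + hi] = True
        return mask

    mask = local_attention_mask(rows=4, cols=16, window=8)
    print(mask.shape)                       # torch.Size([64, 64])
    print(mask[17].nonzero().squeeze(-1))   # keys visible to token (row 1, col 1)

This is only meant to illustrate why distant tokens in earlier rows can be dropped from attention without losing much, given the strong spatial locality of image tokens.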

So why do we need only the last token instead of 2 or more, since there are still a lot of tokens in the previous row that have not been predicted? Also, I think the next-token paradigm does not mean we need only the last token, but all of the preceding tokens, so I think "it is required to initialize the generation of the new row, following a next-token prediction paradigm" may be a little problematic.

csgeekhuang · Aug 14 '25 09:08

In the next-token prediction paradigm, we need to input the last token as the query so that we can generate the next token. As for the rest of the prior tokens, whether to attend to them or not is optional during the generation of new tokens. For example, the attention-sink approach proposes attending only to the first few tokens.

ThisisBillhe · Aug 14 '25 09:08
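
To make the "last token as query" point concrete, here is a small, self-contained sketch (plain PyTorch, not the ZipAR code; the shapes and the window/sink pattern are assumptions) of one decoding step with a KV cache: only the newest token supplies the query, earlier tokens are present only through their cached keys and values, and a boolean mask decides which of them are attended to.

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    d, n_cached = 16, 10                       # head dim, number of already-generated tokens

    k_cache = torch.randn(1, 1, n_cached, d)   # cached keys of earlier tokens
    v_cache = torch.randn(1, 1, n_cached, d)   # cached values of earlier tokens
    q_new = torch.randn(1, 1, 1, d)            # query comes only from the single newest token

    # which cached positions the new token may attend to (True = attend);
    # e.g. full causal, a sliding window, or an attention-sink pattern
    attend = torch.zeros(1, 1, 1, n_cached, dtype=torch.bool)
    attend[..., -4:] = True                    # sliding window: the 4 most recent tokens
    attend[..., :1] = True                     # sink-style: always keep the very first token

    out = F.scaled_dot_product_attention(q_new, k_cache, v_cache, attn_mask=attend)
    print(out.shape)                           # torch.Size([1, 1, 1, 16]) -- one output per new query

Swapping the mask between full-causal, sliding-window, and sink-style patterns changes which prior tokens contribute, but it never changes the fact that only one query token is fed per generated token.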

Thank you for your kind response! One more question about the code:

        # forward input preparation
        new_input_ids.append(input_ids[:, idx_in_input_ids].unsqueeze(-1))
        new_position_ids.append(global_idx)
        local_position_ids.append(idx_in_input_ids)
        new_input_ids = torch.cat(new_input_ids, dim=1)
        new_position_ids = torch.tensor(new_position_ids, device=input_ids.device).unsqueeze_(0)
        local_position_ids = torch.tensor(local_position_ids, device=input_ids.device).unsqueeze_(0)
        num_new_tokens = new_input_ids.shape[1]

        ## model forward
        if cfg_scale > 1.0:
            x_combined = torch.cat([new_input_ids, new_input_ids])
            logits, _ = model(x_combined, cond_idx=None, input_pos=new_position_ids[0])

I notice that the input_pos used in the forward pass is actually the global index (the true index), so there should still be some positions that are None (not yet generated):

        # mask during inference
        bs = token_embeddings.shape[0]
        mask = self.causal_mask[:bs, None, input_pos]
        h = self.tok_dropout(token_embeddings)
        self.freqs_cis = self.freqs_cis

The mask matrix is used in the traditional way, so how do you handle the tokens that have not been generated yet? E.g., for token 8.0 we initialize from token 7.15, but what about token 7.14 and token 7.13, since I see the window size is just 8?

csgeekhuang · Aug 14 '25 09:08
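
For context on the quoted snippet: `mask = self.causal_mask[:bs, None, input_pos]` follows the common gpt-fast / LlamaGen pattern, where `causal_mask` is a precomputed lower-triangular boolean matrix over global positions, and indexing it with `input_pos` picks out, for each query decoded in this step, the row that permits attention only to positions at or before that query's global position. A minimal standalone illustration (the sizes and position values below are made up):

    import torch

    max_seq_len, bs = 12, 1

    # precomputed lower-triangular mask over global positions (True = may attend)
    causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))
    causal_mask = causal_mask.unsqueeze(0).expand(bs, -1, -1)   # [bs, max_seq_len, max_seq_len]

    # suppose this step decodes two tokens whose global positions are 3 and 7
    input_pos = torch.tensor([3, 7])
    mask = causal_mask[:bs, None, input_pos]    # [bs, 1, num_new_tokens, max_seq_len]

    print(mask.shape)
    print(mask[0, 0, 0].int())   # query at global position 3 may attend to positions 0..3
    print(mask[0, 0, 1].int())   # query at global position 7 may attend to positions 0..7

Everything after a query's own global position is always masked out by the selected row; what the cache holds for earlier positions that have not been generated yet is exactly the open question above.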