Question about the implementation
I want to ask: when the next row begins to generate, why do we only need the last token of the previous row? What about the other tokens in that row that have not been generated yet? Are they set to zero?
This is exactly the motivation of our paper. We do not need these tokens due to the strong locality of images. As for the last token in the previous row, it is required to initialize the generation of the new row, following a next-token prediction paradigm.
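For intuition, here is a tiny standalone sketch (not the paper's code) of why, under row-major (raster) ordering, the last token of the previous row is the immediate predecessor of a new row's first token; the row width of 16 is just an illustrative assumption.

# Tiny illustration (not the paper's code): with row-major ordering over a token grid,
# the token immediately before the first token of row r is the last token of row r-1,
# so it is the natural query for starting the new row.
W = 16  # hypothetical row width in tokens

def flat_index(row, col, width=W):
    # Map a (row, col) grid position to its global index in raster order.
    return row * width + col

prev_global = flat_index(8, 0) - 1      # global index right before token (8, 0)
print(divmod(prev_global, W))           # -> (7, 15): the last token of row 7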
So why do we only need the last token instead of 2 or more, since there are still a lot of tokens in the previous row that have not been predicted? Also, I think the next-token paradigm does not mean we only need the last token, but rather all of the preceding tokens, so the statement "it is required to initialize the generation of the new row, following a next-token prediction paradigm" may be a little problematic.
In the next-token prediction paradigm, we need to input the last token as the query so that we can generate the next token. As for the rest of the prior tokens, whether to attend to them is optional during the generation of new tokens. For example, the attention-sink approach proposes attending only to the first few tokens.
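As a rough illustration of this point (a minimal sketch, not the repository's code; names such as toy_decode_step and sink_window_mask are made up), the query for one decoding step is just the last token's hidden state, and a mask decides which of the cached prior tokens are attended to:

import torch
import torch.nn.functional as F

def sink_window_mask(cache_len, num_sink=4, window=8):
    # Attention-sink-style mask over cached positions: keep the first `num_sink`
    # tokens plus the most recent `window` tokens, drop everything in between.
    keep = torch.zeros(cache_len, dtype=torch.bool)
    keep[:num_sink] = True
    keep[max(0, cache_len - window):] = True
    return keep

def toy_decode_step(q_last, k_cache, v_cache, attend_mask=None):
    # One next-token step: the query is ONLY the last token's hidden state.
    # q_last: [B, 1, D]; k_cache, v_cache: [B, T, D]; attend_mask: [T] bool or None.
    scores = q_last @ k_cache.transpose(1, 2) / k_cache.shape[-1] ** 0.5   # [B, 1, T]
    if attend_mask is not None:
        scores = scores.masked_fill(~attend_mask[None, None, :], float("-inf"))
    attn = F.softmax(scores, dim=-1)
    return attn @ v_cache   # [B, 1, D], fed to the head that predicts the next token

# 20 cached tokens; the new token is produced from the last token's query alone,
# while attention over the cache is limited to 4 sink tokens + the 8 most recent.
B, T, D = 1, 20, 16
q_last = torch.randn(B, 1, D)
k_cache, v_cache = torch.randn(B, T, D), torch.randn(B, T, D)
out = toy_decode_step(q_last, k_cache, v_cache, attend_mask=sink_window_mask(T))
print(out.shape)   # torch.Size([1, 1, 16])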
Thank you for your kind response! Still one question. Code:

# forward input preparation
new_input_ids.append(input_ids[:, idx_in_input_ids].unsqueeze(-1))
new_position_ids.append(global_idx)
local_position_ids.append(idx_in_input_ids)
new_input_ids = torch.cat(new_input_ids, dim=1)
new_position_ids = torch.tensor(new_position_ids, device=input_ids.device).unsqueeze_(0)
local_position_ids = torch.tensor(local_position_ids, device=input_ids.device).unsqueeze_(0)
num_new_tokens = new_input_ids.shape[1]
## model forward
if cfg_scale > 1.0:
    x_combined = torch.cat([new_input_ids, new_input_ids])
    logits, _ = model(x_combined, cond_idx=None, input_pos=new_position_ids[0])
I notice that the input_pos used in the forward process is actually the global index (the true index), so there should still be some positions that are None (not yet generated).
# mask during inference
bs = token_embeddings.shape[0]
mask = self.causal_mask[:bs, None, input_pos]
h = self.tok_dropout(token_embeddings)
self.freqs_cis = self.freqs_cis
And the mask matrix is used in the traditional way, so how can you process the tokens that have not been generated? E.g., for token (8, 0) we initialize with token (7, 15), but what about tokens (7, 14) and (7, 13), since I see the window size is just 8?
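For context on the indexing being asked about, here is a small standalone sketch (assuming a standard lower-triangular causal_mask of shape [bs, T, T], as in LlamaGen-style inference code; the numbers are made up), which only illustrates what the mask slice selects, not how the repository handles ungenerated positions:

import torch

max_seq_len, bs = 16, 2
# causal_mask[b, i, j] == True means a query at global position i may attend to position j.
causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))
causal_mask = causal_mask.unsqueeze(0).repeat(bs, 1, 1)          # [bs, T, T]

# Global positions of the tokens fed in this forward pass (made-up values).
input_pos = torch.tensor([7, 8, 9])

# Same indexing pattern as `mask = self.causal_mask[:bs, None, input_pos]`:
mask = causal_mask[:bs, None, input_pos]                         # [bs, 1, 3, T]
print(mask.shape)                                                # torch.Size([2, 1, 3, 16])

# Each selected row permits attention to every global position <= input_pos[k];
# whether those cache slots already hold generated tokens is exactly the question above.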