nanoGPT

attn_mask for inference

david-waterworth opened this issue 2 years ago • 4 comments

This is more of a question for my understanding. I understand that at training time every sequence has the same fixed length (no padding), so the attention mask can be built from a lower-triangular matrix, and when torch 2.0 is available the model can use the fast scaled_dot_product_attention with is_causal=True and attn_mask=None.
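A minimal sketch of the two training-time paths described above, assuming PyTorch 2.0+ and random q, k, v tensors shaped (B, n_head, T, head_size) as in nanoGPT's CausalSelfAttention. With no padding, the fused call with is_causal=True and the manual path that applies a lower-triangular mask before the softmax should agree up to floating-point error.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, n_head, T, hs = 2, 4, 8, 16  # batch, heads, sequence length, head size
q = torch.randn(B, n_head, T, hs)
k = torch.randn(B, n_head, T, hs)
v = torch.randn(B, n_head, T, hs)

# Fast path: the fused kernel applies the causal mask internally.
y_fast = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)

# Manual path: explicit lower-triangular (causal) mask, then softmax.
att = (q @ k.transpose(-2, -1)) / math.sqrt(hs)
causal = torch.tril(torch.ones(T, T, dtype=torch.bool))
att = att.masked_fill(~causal, float("-inf"))
att = F.softmax(att, dim=-1)
y_manual = att @ v

print(torch.allclose(y_fast, y_manual, atol=1e-5))  # True
```

Because every row has the same length, one (T, T) triangle is shared by the whole batch, which is why no per-sequence mask has to be passed in during training.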

I'm not sure about inference, though. There doesn't appear to be any way of passing an attn_mask. Are you always expecting singleton batches at inference? For example, in GPT.forward:

def forward(self, idx, targets=None):
    ...
    # inference-time mini-optimization: only forward the lm_head on the very last position
    logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
    loss = None
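A small illustration of the "using list [-1] to preserve the time dim" comment, using a random stand-in for x and a toy lm_head with a made-up vocabulary size of 11: indexing with the list [-1] keeps a length-1 time dimension, while the integer -1 drops it.

```python
import torch

B, T, C = 3, 5, 8  # batch, time, channels
x = torch.randn(B, T, C)

last_keep = x[:, [-1], :]  # list index: time dim kept, shape (B, 1, C)
last_drop = x[:, -1, :]    # integer index: time dim dropped, shape (B, C)

lm_head = torch.nn.Linear(C, 11, bias=False)  # toy head, vocab size 11
logits = lm_head(last_keep)                   # shape (B, 1, 11)

print(tuple(last_keep.shape))  # (3, 1, 8)
print(tuple(last_drop.shape))  # (3, 8)
print(tuple(logits.shape))     # (3, 1, 11)
```

Keeping the time dimension means the inference logits have the same (B, T, vocab) layout as the training logits (with T = 1), so nanoGPT's generate can keep reading them with logits[:, -1, :] either way.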

This returns the logits for the last position along the time dimension for each sequence in the input batch. But if the batch holds more than one sequence, and the sequences have different lengths (and are therefore padded), then for the shorter sequences you're returning the logits at a padding position, a token that probably wasn't even seen during training.
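A toy illustration of the problem, plus one way to pick the last real position per row. It assumes right padding, a hypothetical filler id PAD = 0 (nanoGPT defines no padding token), and that the caller tracks each sequence's real length in a hypothetical lengths tensor. Instead of real hidden states, x simply carries each position's token id so the selected values are easy to read.

```python
import torch

PAD = 0  # hypothetical filler id; nanoGPT itself defines no padding token
# Two right-padded sequences with real lengths 5 and 3.
idx = torch.tensor([
    [5, 6, 7, 8, 9],
    [3, 4, 2, PAD, PAD],
])
lengths = torch.tensor([5, 3])
B, T = idx.shape

# Stand-in for the final hidden states x of shape (B, T, C) with C = 1:
# each position just carries its own token id.
x = idx.unsqueeze(-1).float()

naive = x[:, [-1], :]  # what forward() does today: always the last column
print(naive[:, 0, 0].tolist())  # [9.0, 0.0] -> row 1 lands on padding

# Gather the last real (non-pad) position of each row instead.
last_pos = lengths - 1                                        # (B,)
gather_idx = last_pos.view(B, 1, 1).expand(B, 1, x.size(-1))  # (B, 1, C)
last = x.gather(1, gather_idx)                                # (B, 1, C), time dim kept
print(last[:, 0, 0].tolist())  # [9.0, 2.0]
```

The gather keeps the (B, 1, C) shape that lm_head expects, so it could replace x[:, [-1], :] in place if the lengths were available inside forward.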

It shouldn't matter if you only ever run inference with batch_size=1, but in the general case it seems necessary to add an attn_mask parameter to the forward of CausalSelfAttention. During training attn_mask would be None; at inference, when it's not None, the slow path would be required. Also, instead of x[:, [-1], :], the model would need to extract the last unmasked position for each sequence in the batch.
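A minimal sketch of what such an attn_mask could look like, assuming right padding and known per-sequence lengths; build_attn_mask is a hypothetical helper, not part of nanoGPT. scaled_dot_product_attention itself accepts a boolean attn_mask (True means "may attend"), so the mask could also be passed straight through rather than falling back to the manual path. Whether PyTorch can still use its fastest kernel with an explicit mask depends on the version and backend.

```python
import torch
import torch.nn.functional as F


def build_attn_mask(lengths: torch.Tensor, T: int) -> torch.Tensor:
    """Hypothetical helper: boolean mask (B, 1, T, T), True = query may attend to key.

    Assumes right padding: row b holds lengths[b] real tokens, then padding.
    """
    causal = torch.tril(torch.ones(T, T, dtype=torch.bool))       # (T, T)
    key_is_real = torch.arange(T)[None, :] < lengths[:, None]      # (B, T)
    # query i may see key j iff j <= i and key j is a real token
    return causal[None, None, :, :] & key_is_real[:, None, None, :]


torch.manual_seed(0)
B, n_head, T, hs = 2, 2, 5, 8
lengths = torch.tensor([5, 3])
q, k, v = (torch.randn(B, n_head, T, hs) for _ in range(3))

mask = build_attn_mask(lengths, T)  # broadcasts over the head dimension
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Sanity check: the real positions of the short row match causal attention
# over that row's unpadded tokens alone.
n = int(lengths[1])
y_ref = F.scaled_dot_product_attention(
    q[1:, :, :n], k[1:, :, :n], v[1:, :, :n], is_causal=True
)
print(torch.allclose(y[1:, :, :n], y_ref, atol=1e-5))  # True
```

With right padding every query row, including the pad rows, still sees at least one real key, so the softmax never runs over an all-masked row.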

david-waterworth, Apr 23 '23 07:04

So, you're right about your concerns, though not exactly. I spent much less time on nanoGPT from an inference standpoint. Calculating and passing in an attention mask is one way to do it. There are other ways too, but they could require adjustments in the training phase. The current code works fine even with batch size > 1, as long as the sequences are "aligned", e.g. if you take a single prompt but want to generate 10 samples in parallel from it. I do agree that for now it would be a good idea to at least have some asserts/warnings in the code around this, because it's a bit of a sharp edge.
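A minimal sketch of the "aligned" case, with made-up prompt ids and random stand-in logits in place of a real model call. One prompt repeated n_samples times gives a batch with no padding, and each sampling step appends exactly one token to every row (the torch.multinomial / torch.cat pattern nanoGPT's generate uses), so the rows stay the same length throughout even though their contents diverge.

```python
import torch

torch.manual_seed(0)
prompt = torch.tensor([[10, 20, 30, 40]])  # made-up token ids, shape (1, T)
n_samples = 10
vocab_size = 50257

# Every row is the same prompt: one shared length, no padding needed.
idx = prompt.repeat(n_samples, 1)  # (10, 4)

for _ in range(3):
    # Stand-in for the model's last-position logits, shape (n_samples, vocab_size).
    logits = torch.randn(n_samples, vocab_size)
    probs = torch.softmax(logits, dim=-1)
    idx_next = torch.multinomial(probs, num_samples=1)  # (10, 1)
    idx = torch.cat((idx, idx_next), dim=1)             # every row grows by one token

print(tuple(idx.shape))  # (10, 7)
```

Since the rows never differ in length, the default causal mask and x[:, [-1], :] are correct for each of them without any extra masking.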

karpathy, Apr 23 '23 16:04

Thanks, I didn't think of the use case of generating in parallel from the same sequence; yes, that would certainly work without masking.

I guess generation on padded sequences of different lengths gets a bit complicated anyway. In my case I'm more interested in the logits: I'm working on vector search and want to know whether I get better representations from a causal language model than from a masked language model.

david-waterworth, Apr 23 '23 22:04

Actually, it looks like the norm for GPT-like models is to left-pad the sequences. That way you can still perform batch generation and easily select the last logit using x[:, [-1], :].
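A minimal sketch of left padding, assuming a hypothetical left_pad helper, a filler id of 0, and explicitly tracked lengths. Left padding makes the last column a real token in every row, so x[:, [-1], :] works unchanged. Under a causal-only mask, though, the real tokens still attend to the pad positions on their left, so a padding mask is still needed if the pads must not influence the outputs. Letting each pad position attend to itself is one choice (an assumption here) that keeps pad rows from having no visible key at all.

```python
import torch
import torch.nn.functional as F

PAD = 0  # hypothetical filler id; real tokens are identified by `lengths`, not this value


def left_pad(seqs, pad=PAD):
    """Hypothetical helper: left-pad token lists to a common length."""
    T = max(len(s) for s in seqs)
    idx = torch.tensor([[pad] * (T - len(s)) + s for s in seqs])
    lengths = torch.tensor([len(s) for s in seqs])
    return idx, lengths


idx, lengths = left_pad([[5, 6, 7, 8, 9], [3, 4, 2]])
print(idx.tolist())         # [[5, 6, 7, 8, 9], [0, 0, 3, 4, 2]]
print(idx[:, -1].tolist())  # [9, 2] -> the last column is a real token in every row

# Causal mask combined with a key-padding mask: real keys start at T - length.
B, T = idx.shape
key_is_real = torch.arange(T)[None, :] >= (T - lengths)[:, None]  # (B, T)
causal = torch.tril(torch.ones(T, T, dtype=torch.bool))
mask = causal[None, None] & key_is_real[:, None, None, :]          # (B, 1, T, T)
mask = mask | torch.eye(T).bool()[None, None]  # pad rows see only themselves

torch.manual_seed(0)
n_head, hs = 2, 8
q, k, v = (torch.randn(B, n_head, T, hs) for _ in range(3))
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Real positions of the short row match causal attention over its unpadded tokens.
s = T - int(lengths[1])  # number of pad positions in row 1
y_ref = F.scaled_dot_product_attention(
    q[1:, :, s:], k[1:, :, s:], v[1:, :, s:], is_causal=True
)
print(torch.allclose(y[1:, :, s:], y_ref, atol=1e-5))  # True
```

This sketch leaves position embeddings alone. nanoGPT adds learned absolute position embeddings over torch.arange(0, t), so left padding also shifts each real token's position (row 1's first real token sits at position 2, not 0); one common approach is to derive position ids from the padding mask, which would require a change to GPT.forward.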

david-waterworth, Apr 26 '23 00:04