LLMs-from-scratch
Question about implementation of CausalAttention class (3.5.3 Implementing a compact causal self-attention class)
Hi @rasbt,
This notebook contains the following implementation of CausalAttention:

```python
import torch
import torch.nn as nn


class CausalAttention(nn.Module):

    def __init__(self, d_in, d_out, block_size, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)  # New
        self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))  # New

    def forward(self, x):
        b, num_tokens, d_in = x.shape  # New batch dimension b
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1, 2)  # Changed transpose
        attn_scores.masked_fill_(  # New, _ ops are in-place
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)  # New

        context_vec = attn_weights @ values
        return context_vec
```
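For reference, a minimal usage sketch of this class (the tensor shapes, seed, and hyperparameter values below are my own illustrative choices, not taken from the notebook):

```python
import torch

torch.manual_seed(123)

block_size = 6                   # maximum sequence length the mask is pre-allocated for
batch = torch.randn(2, 6, 3)     # (b, num_tokens, d_in); here num_tokens == block_size

ca = CausalAttention(d_in=3, d_out=2, block_size=block_size, dropout=0.0)
context_vecs = ca(batch)
print(context_vecs.shape)        # torch.Size([2, 6, 2])
```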
I have a question - why do we need the following two lines in the `forward()` method implementation:

```python
def forward(self, x):
    b, num_tokens, d_in = x.shape  # New batch dimension b
    ...
    attn_scores.masked_fill_(  # New, _ ops are in-place
        self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
    ...
```
Can we remove the first line and just replace the second line with the following code:

```python
attn_scores.masked_fill_(self.mask.bool(), -torch.inf)
```

As I understand it, `num_tokens` always equals `block_size`, and we provide the `block_size` value as an argument, so neither calculating `x.shape` nor indexing with `[:num_tokens, :num_tokens]` is required.
Is it correct?
Thank you.
Probably, for the same reason, slicing with `[:num_tokens, :num_tokens]` is also not required in the `MultiHeadAttention` class in section "3.6.1 Stacking multiple single-head attention layers".
Hello @labdmitriy, I think that the reason for calculating `num_tokens` in the `forward()` method is to ensure that the mask is applied only up to the actual number of tokens in the input sequence. This is important because the mask might be larger than the input sequence if it is pre-allocated based on a fixed block size.
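To make that concrete, here is a small standalone sketch (the shapes are my own illustrative choices) of a mask pre-allocated at `block_size` being sliced down to the actual sequence length before it is applied:

```python
import torch

block_size = 6   # maximum supported sequence length
num_tokens = 4   # actual length of the current input sequence

# Pre-allocated upper-triangular mask of shape (block_size, block_size)
mask = torch.triu(torch.ones(block_size, block_size), diagonal=1)

# Slice it down to the current sequence length before applying it
attn_scores = torch.randn(2, num_tokens, num_tokens)  # (b, num_tokens, num_tokens)
attn_scores.masked_fill_(mask.bool()[:num_tokens, :num_tokens], -torch.inf)

print(mask.shape)                                   # torch.Size([6, 6])
print(mask.bool()[:num_tokens, :num_tokens].shape)  # torch.Size([4, 4])
```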
Hello @ahmedDaoudi-u, thank you for your response.
But what is the purpose of slicing with `num_tokens` if it always equals `block_size` in this implementation (as it is dimension 1 of the inputs)?
Thanks for bringing this up! Regarding removing the `:num_tokens` slicing from `self.mask.bool()[:num_tokens, :num_tokens]`: that's unfortunately not possible, as @ahmedDaoudi-u mentioned. E.g., in Ch05, we are using an LLM with a block size of 1024, but the input text may be shorter during training and/or inference, and this would then not work. Hence, we need to truncate the mask to the actual size of the input that the LLM sees rather than the maximal input length that the LLM could support.
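A quick sketch of that failure mode (the 1024 block size and the short prompt length are just example numbers I picked): with a shorter input, the unsliced `(block_size, block_size)` mask is not broadcastable to the `(b, num_tokens, num_tokens)` score tensor, so `masked_fill_` raises a shape error, while the sliced mask works:

```python
import torch

block_size = 1024   # maximum context length the LLM supports
num_tokens = 6      # e.g. a short prompt during inference

mask = torch.triu(torch.ones(block_size, block_size), diagonal=1)
attn_scores = torch.randn(2, num_tokens, num_tokens)

# Without slicing: the (1024, 1024) mask cannot be broadcast to (2, 6, 6)
try:
    attn_scores.masked_fill_(mask.bool(), -torch.inf)
except RuntimeError as e:
    print("Unsliced mask fails:", e)

# With slicing: the (6, 6) mask broadcasts over the batch dimension as intended
attn_scores.masked_fill_(mask.bool()[:num_tokens, :num_tokens], -torch.inf)
print("Sliced mask works:", attn_scores.shape)
```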
Let me know if you have any follow up concerns and questions, this is an interesting topic, and I'm happy to discuss!
@rasbt and @ahmedDaoudi-u, thank you for the explanations; I will probably return with additional questions while exploring Chapter 5 :)
Sure @rasbt.