Dimensionality of keys and values for Attention
I have two questions about the key and value calculation in Attention (and similarly for KNNAttention).
The relevant line is: https://github.com/lucidrains/memorizing-transformers-pytorch/blob/83fa1479d6f7881dd977fbff55681e709e3b250e/memorizing_transformers_pytorch/memorizing_transformers_pytorch.py#L135
- Why is there only one Linear layer `to_kv`, instead of 2 linear layers `to_k` and `to_v`?
- Why is the last dimension `dim_head * 2`? I get that the `* 2` is for both k and v, but what about `dim_head`? I thought q, k, v should all have the same final dimension (i.e. `inner_dim == dim_head * heads`). My understanding is that this means that either a) there is only 1 attention head, or b) for all heads, k and v are shared. Is there a reason this is done, or am I misunderstanding?
In your Attention class for Performer, q, k, v all have the same dimensions.
Thanks in advance!
I guess this commit cites the paper that does one-headed attention: https://github.com/lucidrains/memorizing-transformers-pytorch/commit/9f77fd5e4e449d70c02b9cd25a98e1d5ef5f0a72
@manestay yup, one-headed keys/values comes from an old Noam Shazeer paper that is seeing a resurgence in usage by LLMs such as AlphaCode and the 540B-parameter PaLM model. It is usually used to save on the number of keys/values cached during inference, but it is a good fit here since we don't need to keep track of `heads` times as many faiss indices.
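In other words, the queries still get `heads` heads while a single key/value head is shared across all of them. A rough sketch of the idea in isolation (class name, defaults, and shapes here are illustrative, not the repo's exact code):

```python
import torch
from torch import nn, einsum
from einops import rearrange

class OneHeadedKVAttention(nn.Module):
    # queries get `heads` heads; a single key/value head is shared across them
    def __init__(self, dim, dim_head = 64, heads = 8):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads
        self.to_q = nn.Linear(dim, inner_dim, bias = False)      # per-head queries
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)  # one shared k and one shared v
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        q = rearrange(self.to_q(x), 'b n (h d) -> b h n d', h = self.heads)
        k, v = self.to_kv(x).chunk(2, dim = -1)                  # each (b, n, dim_head)

        sim = einsum('b h i d, b j d -> b h i j', q * self.scale, k)
        attn = sim.softmax(dim = -1)
        out = einsum('b h i j, b j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
```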
Thanks! What about this question: Why is there only one Linear layer `to_kv`, instead of 2 linear layers `to_k` and `to_v`?
@manestay it comes out to be faster if you do one matrix multiplication and then break it up later
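Concretely, the fused projection looks something like this (dimensions picked only for illustration):

```python
import torch
from torch import nn

dim, dim_head = 512, 64
x = torch.randn(2, 1024, dim)

# one fused projection, split afterwards -- a single larger matmul
to_kv = nn.Linear(dim, dim_head * 2, bias = False)
k, v = to_kv(x).chunk(2, dim = -1)   # each of shape (2, 1024, dim_head)

# same output shapes as two separate projections,
# just with the two weight matrices kept as one
to_k = nn.Linear(dim, dim_head, bias = False)
to_v = nn.Linear(dim, dim_head, bias = False)
k_sep, v_sep = to_k(x), to_v(x)
```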
I see, thanks!
An unrelated question, just to confirm my understanding, regarding the following line:
https://github.com/lucidrains/memorizing-transformers-pytorch/blob/83fa1479d6f7881dd977fbff55681e709e3b250e/memorizing_transformers_pytorch/knn_memory.py#L153
Can I ask why we have `num_indices` (i.e. `batch_size`) `KNN` objects for each `KNNMemory`? What does each `KNN` hold that is different from the other ones? And how does this interact with `KNNMemory.add` and `KNNMemory.search`, which add each key to, and search each query against, a different `KNN`?
Thanks in advance again @lucidrains
~To provide more context on the `KNNMemory.add` function, here's an example of my understanding:~
https://github.com/lucidrains/memorizing-transformers-pytorch/blob/83fa1479d6f7881dd977fbff55681e709e3b250e/memorizing_transformers_pytorch/knn_memory.py#L201
~Suppose we have in the batch (size = 4) the key vectors corresponding to the sentences:~
`<s> hello`
`<s> goodbye and`
`<s> ok .`
`<s>`
~When the above line is called, it will add each key to a different KNN. I don't get why this is the case -- don't we want all of the keys in the memory? If a later batch has a query vector corresponding to `<s> goodbye and good`, it seems that the most relevant entry is in the 2nd KNN, but we have no guarantee that this query will be in the 2nd position of the batch.~
~Is my above understanding correct? If so, then I don't see why we have multiple memories, as the paper did not mention that. If not, then please correct me. Thank you.~
EDIT: I see your comment https://github.com/lucidrains/memorizing-transformers-pytorch/issues/1#issuecomment-1086901581 . You said that "they are doing separate indexed memories per batch, which is why I have to instantiate the number of faiss indices equal to the batch size." I guess I am missing something fundamental in my understanding of the whole memorizing transformers approach, since I don't see where they are doing that. Can you point me to the place in the original text? Sorry for asking so many questions.
@manestay I think Figure 3 in the paper answers your question. Each batch position appears to stream a document, to maintain a consistent within-document memory. A batch of size B then contains chunks from B distinct documents that each have their own memory.
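Conceptually, the per-row bookkeeping amounts to something like this toy sketch (using raw faiss directly rather than the repo's `KNN` / `KNNMemory` wrappers; the helper functions are purely illustrative):

```python
import numpy as np
import faiss  # assumes faiss-cpu (or faiss-gpu) is installed

dim, batch_size, topk = 64, 4, 3

# one flat index per batch row: each row streams chunks from its own document,
# so its memory stays separate from the other rows' memories
indices = [faiss.IndexFlatIP(dim) for _ in range(batch_size)]

def add(keys):                              # keys: (batch, seq, dim)
    for index, row_keys in zip(indices, keys):
        index.add(np.ascontiguousarray(row_keys, dtype = np.float32))

def search(queries):                        # queries: (batch, seq, dim)
    all_ids = []
    for index, row_queries in zip(indices, queries):
        _, ids = index.search(np.ascontiguousarray(row_queries, dtype = np.float32), topk)
        all_ids.append(ids)
    return np.stack(all_ids)                # (batch, seq, topk), ids local to each row's memory

add(np.random.randn(batch_size, 16, dim))
print(search(np.random.randn(batch_size, 2, dim)).shape)  # -> (4, 2, 3)
```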
Okay, I see what you're saying. I was having some trouble interpreting this figure, but your explanation makes sense. Thanks!
I guess I was confused by the `train.py` script in this repo, since it doesn't handle anything document-level; it just loads enwik8 sequentially in chunks. But I do see that this is a WIP repo, so maybe that is yet to be implemented. Appreciated once again!