
bidirectional attention or causal attention for embedding?

yonxie opened this issue 1 year ago • 5 comments

You mention that bidirectional attention is used for the embedding task, but it appears that you only use the last hidden states from the pretrained LLM to generate embeddings. Is the final projection the only bidirectional part?

yonxie · Mar 15 '24 23:03

The last hidden state is produced via bidirectional attention in the model itself.

Muennighoff · Mar 15 '24 23:03
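To make "bidirectional attention in the model itself" concrete, here is a toy sketch (not the repo's code) contrasting the causal mask used for generation with the all-visible mask used for the embedding pass; the chosen mask applies inside every transformer layer:

```python
import torch

seq_len, dim = 4, 8
q = k = torch.randn(seq_len, dim)
scores = q @ k.T / dim**0.5

# Causal mask: token i attends only to tokens <= i (used for generation).
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
# Bidirectional mask: every token attends to every token (used for embeddings).
bidirectional = torch.ones(seq_len, seq_len, dtype=torch.bool)

causal_attn = torch.softmax(scores.masked_fill(~causal, float("-inf")), dim=-1)
bidir_attn = torch.softmax(scores.masked_fill(~bidirectional, float("-inf")), dim=-1)
```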

Hi, I'm currently trying to train GritLM with Gemma 2B to generate embeddings. While reviewing the training script for Mistral 7B, I noticed the use of bidirectional attention with attn='bbcc'. For embeddings, would it be more advantageous to train with 'bbcc' or 'cccc'?

However, when I tried to use attn='bbcc' with Gemma, I encountered an error: TypeError: GemmaModel.forward() received an unexpected keyword argument 'is_causal'. To fix this, I commented out the following line in gritlm.py:

```python
if (self.attn is not None) and (self.attn[:2] == 'bb'):
    inputs["is_causal"] = False
```

Is this correct?

Hisarlik · Apr 09 '24 16:04

bbcc is better, and commenting out that line will make it equivalent to cccc, so it's not a good idea. Also see https://github.com/ContextualAI/gritlm/issues/24

Muennighoff · Apr 09 '24 16:04
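For reference, a minimal sketch of the dispatch that the quoted line implements (illustrative only, not the repo's exact code): the first two characters of attn govern the embedding forward pass and the last two govern generation, so 'bbcc' means bidirectional embeddings and causal generation, while dropping the is_causal line means the flag is never sent and embeddings stay causal, i.e. effectively 'cccc'.

```python
# Illustrative sketch of the attn-mode dispatch, not the repo's exact code.
def build_embedding_inputs(tokenized: dict, attn: str = "bbcc") -> dict:
    inputs = dict(tokenized)
    # First two characters control the embedding pass:
    # "bb" -> request bidirectional attention for this forward call.
    if (attn is not None) and (attn[:2] == "bb"):
        # The underlying model's forward() must accept this kwarg, which is why
        # the repo ships a patched Mistral modeling file
        # (scripts/modeling_mistral_gritlm.py). Stock GemmaModel.forward() does
        # not accept it, hence the TypeError above.
        inputs["is_causal"] = False
    # The last two characters ("cc") leave generation causal, so nothing is set here.
    return inputs
```

So for Gemma the fix would be an adapted modeling file analogous to the Mistral one, rather than removing the flag.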

Hi @Muennighoff, amazing work! I have the same confusion as @yonxie. I can see here that you apply a final pooling. You mentioned that "The last hidden state is produced via bidirectional attention in the model itself". Would you mind pointing out where this is done?

I was also looking at the query-doc caching example on page 63. To reuse the key-value cache (if I understand correctly, the key-values are produced during a forward pass using bidirectional attention), does that mean GRITLM functions as a prefix-LM with two independent prefixes during RAG?

Vincent-Li-9701 · Apr 17 '24 04:04

Sorry for the confusion. I mean that inside the model, bidirectional attention is applied in every transformer layer. The attention mask for that is created here: https://github.com/ContextualAI/gritlm/blob/47b7fe6c7109ba46b82b68c37d32aa9a8bf010c5/scripts/modeling_mistral_gritlm.py#L1018

The pooling that you point to is then applied to the final hidden state returned from the model to remove the sequence length dimension.
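A minimal sketch of that pooling step, assuming mean pooling over non-padding tokens (the actual pooling method is configurable; this function is illustrative, not the repo's exact implementation):

```python
import torch

def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # last_hidden_state: (batch, seq_len, hidden), produced with bidirectional attention
    # attention_mask:    (batch, seq_len), 1 for real tokens, 0 for padding
    mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state)
    summed = (last_hidden_state * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts  # (batch, hidden): the sequence-length dimension is removed
```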

if I understand correctly, the key-values are produced during a forward pass using bidirectional attention

Yes

does that mean GRITLM functions as a prefix-LM with two independent prefixes during RAG?

The two caches (or prefixes, if you will) are concatenated and have not attended to one another (maybe this is what you mean by independent). You may find it helpful to look at this code example: https://github.com/ContextualAI/gritlm?tab=readme-ov-file#caching

Muennighoff · Apr 17 '24 14:04
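To illustrate the "two independent prefixes" point: the query and document caches are built in separate bidirectional forward passes, so neither segment has attended to the other; generated tokens then attend to both concatenated caches plus previously generated tokens. A toy mask of the effective attention pattern (purely illustrative, not the repo's implementation):

```python
import torch

q_len, d_len, gen_len = 3, 4, 2                      # query, document, generated tokens
total = q_len + d_len + gen_len
mask = torch.zeros(total, total, dtype=torch.bool)   # True = may attend

# Query tokens attended bidirectionally, but only within the query, when its cache was built.
mask[:q_len, :q_len] = True
# Document tokens likewise attended only within the document.
mask[q_len:q_len + d_len, q_len:q_len + d_len] = True
# Generated tokens attend causally to both caches and to earlier generated tokens.
for i in range(gen_len):
    pos = q_len + d_len + i
    mask[pos, :pos + 1] = True

print(mask.int())
```

The block structure is why the concatenated caches behave like a prefix-LM whose two prefix segments never attended to one another.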