PainlessInferenceAcceleration

Changing naive attention to SDPA gives wrong result for batched llama example

Open · learning-chip opened this issue 11 months ago · 3 comments

I attempted to swap in SDPA (PyTorch's scaled dot-product attention, which can dispatch to FlashAttention) for batched llama, by simply changing self._attn() to self._sdp_attn() inside LlamaAttention.forward():

https://github.com/alipay/PainlessInferenceAcceleration/blob/6280cb2f097ba0bc6bc423ab910b9de7ddbe3bf2/pia/lookahead/models/llama/modeling_llama_batch.py#L372-L375

https://github.com/alipay/PainlessInferenceAcceleration/blob/6280cb2f097ba0bc6bc423ab910b9de7ddbe3bf2/pia/lookahead/models/llama/modeling_llama_batch.py#L404-L407

where _sdp_attn is defined as:

https://github.com/alipay/PainlessInferenceAcceleration/blob/6280cb2f097ba0bc6bc423ab910b9de7ddbe3bf2/pia/lookahead/models/llama/modeling_llama_batch.py#L327-L329
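For context, here is a standalone sketch of the comparison I expected to hold (shapes and variable names are illustrative, not the repository's exact code; assumes PyTorch >= 2.0): with the same additive float mask, F.scaled_dot_product_attention should be numerically equivalent to a naive scores-plus-mask-plus-softmax implementation.

```python
import torch
import torch.nn.functional as F

def naive_attn(q, k, v, mask):
    # Reference path: explicit scores, additive mask, softmax, weighted sum.
    scores = q @ k.transpose(-1, -2) / (q.size(-1) ** 0.5)
    if mask is not None:
        scores = scores + mask
    return torch.softmax(scores, dim=-1) @ v

def sdp_attn(q, k, v, mask):
    # Same computation via the fused SDPA entry point; mask is an additive
    # float tensor broadcastable to (batch, heads, q_len, kv_len).
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

torch.manual_seed(0)
b, h, t, d = 2, 4, 6, 8
q, k, v = (torch.randn(b, h, t, d) for _ in range(3))
# Causal additive mask: 0 where allowed, large negative where blocked.
mask = torch.triu(torch.full((t, t), -1e9), diagonal=1)

print(torch.allclose(naive_attn(q, k, v, mask), sdp_attn(q, k, v, mask), atol=1e-4))
```

On toy inputs like this the two agree to within floating-point tolerance, which is why I expected the swap to be a drop-in change.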

However, the modified model generates wrong results. The original llama_batch_example.py gives:

lookahead:False time:3.326s speed:35.5token/s response:["I'm here to help you.\nI'm here to help you with any questions or problems you might have. I'm a highly advanced AI language model, so I can provide information, answer questions, and even help you with your daily tasks.\n\nIs there something specific you would like to", "about a chicken crossing the road\n\nSure! Here's a classic one:\n\nWhy did the chicken cross the road?\n\nTo get to the other side... of the bar!\n\nI hope you found that one amusing!</s></s></s></s></s></s></s></s></s></s>"]
...

The modified model gives:

lookahead:False time:3.271s speed:39.1token/s response:['the    “ nobody nobody     “ nobody   “ nobody  “ nobody   “ nobody  “ nobody  “ nobody  “ nobody  “ nobody  “ nobody  “ nobody  “ nobody  “ nobody  “ nobody  “ nobody  “ nobody  “ nobody  “ nobody ', 'nobody “ nobody “ nobody “ nobody to nobody. Unterscheidung the. Unterscheidung nobody. Unterscheidung ( ,  “ nobody, MS nobody, MS nobodyMS nobodyMS nobodyMS nobodyMS nobodyMS nobody. Unterscheidung,MS nobodyMS nobody. Unterscheidung,MS nobodyMS nobodyMS nobodyMS nobody. UnterscheidungMS nobodyMS']

So is LlamaAttention._attn() doing something extra beyond standard scaled dot-product attention?
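One guess on my part (not verified against the repository's code): if the SDPA path ends up running with is_causal=True, or without the batch's padding mask folded into attn_mask, SDPA silently ignores left-padding, and batched generation degrades into exactly this kind of garbage. A minimal sketch of the failure mode, with assumed shapes:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, h, t, d = 2, 4, 6, 8
q, k, v = (torch.randn(b, h, t, d) for _ in range(3))

# Combined additive mask: causal everywhere, plus left-padding (first two
# key positions blocked) for the second sequence in the batch.
mask = torch.triu(torch.full((t, t), -1e9), diagonal=1).expand(b, 1, t, t).clone()
mask[1, :, :, :2] = -1e9

with_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
causal_only = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Sequence 0 (no padding) matches; sequence 1 does not, because
# is_causal=True knows nothing about the padding positions.
print(torch.allclose(with_mask[0], causal_only[0], atol=1e-4))
print(torch.allclose(with_mask[1], causal_only[1], atol=1e-4))
```

If _attn() is building a combined causal-plus-padding additive mask internally, then a _sdp_attn() that drops or rebuilds that mask differently would explain the divergence only for batched inputs.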

learning-chip · Mar 05 '24 20:03