PainlessInferenceAcceleration
Changing naive attention to SDPA gives wrong results for batched llama example
I attempted to swap in FlashAttention for the batched llama model by simply changing `self._attn()` to `self._sdp_attn()` inside `LlamaAttention.forward()`:
https://github.com/alipay/PainlessInferenceAcceleration/blob/6280cb2f097ba0bc6bc423ab910b9de7ddbe3bf2/pia/lookahead/models/llama/modeling_llama_batch.py#L372-L375
https://github.com/alipay/PainlessInferenceAcceleration/blob/6280cb2f097ba0bc6bc423ab910b9de7ddbe3bf2/pia/lookahead/models/llama/modeling_llama_batch.py#L404-L407
where `_sdp_attn` is defined as:
https://github.com/alipay/PainlessInferenceAcceleration/blob/6280cb2f097ba0bc6bc423ab910b9de7ddbe3bf2/pia/lookahead/models/llama/modeling_llama_batch.py#L327-L329
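For context, assuming `_sdp_attn` wraps `torch.nn.functional.scaled_dot_product_attention` (I have not verified the repo's exact implementation), the swap looks roughly like the sketch below. One detail worth flagging: SDPA treats a float `attn_mask` as an additive bias on the logits, but a bool mask as keep (`True`) / drop (`False`), so the mask built for the naive path must match whichever convention is passed in.

```python
import torch
import torch.nn.functional as F

def sdp_attn(query, key, value, attention_mask=None):
    # Sketch of an SDPA-based replacement for a naive attention loop
    # (hypothetical helper, not the repo's actual _sdp_attn).
    # NOTE: F.scaled_dot_product_attention treats a float attn_mask as
    # ADDITIVE (added to the logits) and a bool mask as True = attend.
    # If the naive path builds its mask with a different convention
    # (e.g. 1.0 = attend / 0.0 = masked), passing it through unchanged
    # silently un-masks everything.
    return F.scaled_dot_product_attention(
        query, key, value,
        attn_mask=attention_mask,
        dropout_p=0.0,
        is_causal=attention_mask is None,  # only safe when no custom mask
    )

B, H, T, D = 1, 2, 4, 8
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))
causal = torch.tril(torch.ones(T, T, dtype=torch.bool))  # bool: True = attend
out = sdp_attn(q, k, v, causal)
print(out.shape)  # torch.Size([1, 2, 4, 8])
```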
However, the modified model generates wrong results. The original `llama_batch_example.py` gives:
lookahead:False time:3.326s speed:35.5token/s response:["I'm here to help you.\nI'm here to help you with any questions or problems you might have. I'm a highly advanced AI language model, so I can provide information, answer questions, and even help you with your daily tasks.\n\nIs there something specific you would like to", "about a chicken crossing the road\n\nSure! Here's a classic one:\n\nWhy did the chicken cross the road?\n\nTo get to the other side... of the bar!\n\nI hope you found that one amusing!</s></s></s></s></s></s></s></s></s></s>"]
...
The modified model gives:
lookahead:False time:3.271s speed:39.1token/s response:['the “ nobody nobody “ nobody “ nobody “ nobody “ nobody “ nobody “ nobody “ nobody “ nobody “ nobody “ nobody “ nobody “ nobody “ nobody “ nobody “ nobody “ nobody “ nobody ', 'nobody “ nobody “ nobody “ nobody to nobody. Unterscheidung the. Unterscheidung nobody. Unterscheidung ( , “ nobody, MS nobody, MS nobodyMS nobodyMS nobodyMS nobodyMS nobodyMS nobody. Unterscheidung,MS nobodyMS nobody. Unterscheidung,MS nobodyMS nobodyMS nobodyMS nobody. UnterscheidungMS nobodyMS']
So is `LlamaAttention._attn()` doing something extra beyond standard scaled-dot-product attention?
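For what it's worth, a mask-convention mismatch alone is enough to produce degenerate output like the above: if the additive float mask expected by one path is handed to the other as a 1.0/0.0 "keep" matrix, nothing gets masked and future positions leak into every token. A standalone NumPy sketch of the effect (not the repo's code, just a guess at the failure mode):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def naive_attn(q, k, v, additive_mask):
    # Standard scaled dot-product attention with an ADDITIVE float mask.
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(q.shape[-1])
    return softmax(scores + additive_mask) @ v

rng = np.random.default_rng(0)
B, H, T, D = 1, 2, 4, 8
q, k, v = (rng.standard_normal((B, H, T, D)) for _ in range(3))

# The same causal mask in two encodings:
keep = np.tril(np.ones((T, T), dtype=bool))   # boolean: True = attend
additive = np.where(keep, 0.0, -np.inf)       # additive: 0 = attend, -inf = masked

ok = naive_attn(q, k, v, additive)
# Feeding the boolean mask where a float additive mask is expected silently
# reinterprets True/False as +1.0/+0.0 logit biases -- no masking happens.
bad = naive_attn(q, k, v, keep.astype(np.float64))

print(np.allclose(ok, bad))  # False: future positions leak into every token
```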