OLMo
PyTorch scaled_dot_product_attention doesn't support broadcast for grouped-query-attention
❓ The question
I tried implementing grouped-query attention in this pull request, but it seems that PyTorch's scaled_dot_product_attention doesn't support the kind of broadcasting we'd need for this. Revisit if/when this gets fixed on PyTorch's end.
TODO: decide ourselves how to broadcast the K and V tensors to match the Q shape when using grouped-query attention. To be revisited...
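For reference, a minimal sketch of one possible workaround: manually expand the K/V heads to match the number of query heads before calling scaled_dot_product_attention, so no broadcasting is required on PyTorch's side. This is not OLMo's actual implementation; the shapes and names (n_heads, n_kv_heads, group_size) are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes (assumptions, not OLMo's config).
batch, seq_len, head_dim = 2, 16, 64
n_heads, n_kv_heads = 8, 2  # grouped-query attention: 8 query heads share 2 KV heads

q = torch.randn(batch, n_heads, seq_len, head_dim)
k = torch.randn(batch, n_kv_heads, seq_len, head_dim)
v = torch.randn(batch, n_kv_heads, seq_len, head_dim)

# Repeat each KV head (n_heads // n_kv_heads) times along the head dimension
# so K and V have the same number of heads as Q.
group_size = n_heads // n_kv_heads
k = k.repeat_interleave(group_size, dim=1)
v = v.repeat_interleave(group_size, dim=1)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 16, 64])
```

The trade-off is that repeat_interleave materializes the expanded K/V tensors, so it costs extra memory compared with true broadcasting inside the attention kernel.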
We apologize for the delay in our response. In order to help surface current, unresolved issues, we are closing tickets prior to February 29. Please reopen your ticket if you are continuing to experience this issue. Thank you!