Add Multi-Head Attention support for Vitis
Description
This PR adds support for Multi-Head Attention using either Keras or PyTorch with the Vitis backend in `io_parallel` mode.
Tests have been added for both Keras and PyTorch parsing.
Credit is due to @Ethan0Jiang and @LostEcho365 (Zhixing Jiang and Dennis Yin) for their original implementation and Keras parsing support; my contributions were implementing PyTorch support and adding unit tests. (Here's a link to their pre-print.) The original code authors have given permission for their code to be merged into hls4ml.
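For reference, a minimal Keras usage sketch is shown below. The model, shapes, output directory, and configuration choices are illustrative only, and the exact hls4ml calls may differ slightly between versions:

```python
import tensorflow as tf
from tensorflow.keras import layers

import hls4ml

# Toy model with a single self-attention layer; shapes are illustrative
seq_len, embed_dim = 8, 16
inp = layers.Input(shape=(seq_len, embed_dim))
out = layers.MultiHeadAttention(num_heads=2, key_dim=8)(inp, inp)  # query == value (self-attention)
model = tf.keras.Model(inp, out)

config = hls4ml.utils.config_from_keras_model(model, granularity='name')
hls_model = hls4ml.converters.convert_from_keras_model(
    model,
    hls_config=config,
    backend='Vitis',
    io_type='io_parallel',
    output_dir='hls4ml_prj_mha_keras',  # hypothetical project directory
)
```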
There are some important notes for PyTorch (TODO: add documentation to this effect); see the sketch after this list:
- Need to set `batch_first=True` when instantiating `nn.MultiheadAttention` so that the inputs match up (`(batch_size, seq_len, embed_dim)` instead of `(seq_len, batch_size, embed_dim)`).
- Need to set `channels_last_conversion='off'` when calling `config_from_pytorch_model()` since batch-first PyTorch and Keras use the same input shape.
- Keras lets you call `MultiHeadAttention` using just two inputs (or even just one input for self-attention), but PyTorch insists that you give it all three of `query`, `key`, and `value`; hls4ml currently only supports the case where `key` and `value` are the same; thus, you must give PyTorch the same data for the second input and the third input.
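Putting those notes together, a minimal PyTorch sketch might look like the following. The module, shapes, and output directory are illustrative, and the exact `config_from_pytorch_model()` / `convert_from_pytorch_model()` signatures (in particular how multiple input shapes are passed) may vary between hls4ml versions:

```python
import torch.nn as nn

import hls4ml

seq_len, embed_dim, num_heads = 8, 16, 2


class MHAWrapper(nn.Module):
    def __init__(self):
        super().__init__()
        # batch_first=True so inputs are (batch_size, seq_len, embed_dim)
        self.mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    def forward(self, query, key, value):
        # hls4ml only supports key == value, so callers must pass the same
        # data for the second and third inputs
        out, _ = self.mha(query, key, value)
        return out


model = MHAWrapper().eval()

# channels_last_conversion='off' because batch-first PyTorch already uses the
# same (batch_size, seq_len, embed_dim) layout as Keras
config = hls4ml.utils.config_from_pytorch_model(
    model,
    [(seq_len, embed_dim)] * 3,  # one shape per input: query, key, value (assumed format)
    granularity='name',
    channels_last_conversion='off',
)
hls_model = hls4ml.converters.convert_from_pytorch_model(
    model,
    hls_config=config,
    backend='Vitis',
    io_type='io_parallel',
    output_dir='hls4ml_prj_mha_pytorch',  # hypothetical project directory
)
```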
Type of change
- [x] New feature (non-breaking change which adds functionality)
- [x] A new research paper code implementation
Tests
Two unit tests added: `test/pytest/test_multiheadattention.py` and `test/pytest/test_multiheadattention_pytorch.py`
Checklist
- [x] I have read the guidelines for contributing.
- [x] I have commented my code, particularly in hard-to-understand areas.
- [ ] I have made corresponding changes to the documentation.
- [x] My changes generate no new warnings.
- [ ] I have installed and run `pre-commit` on the files I edited or added.
- [x] I have added tests that prove my fix is effective or that my feature works.
Thank you so much for merging it into main!
pre-commit.ci autofix
Hi @rianbrooksflynn! Great work on the Multi-Head Attention implementation.
Could you consider adding usage examples (e.g., `examples/multihead_attention_keras.py` and `examples/multihead_attention_pytorch.py`) to help users understand how to properly use this feature?
The examples could demonstrate the important PyTorch requirements you mentioned (`batch_first=True`, `channels_last_conversion='off'`, same key/value inputs) and basic Keras usage.
Thanks!
As far as I can tell, masking (e.g., causal masking) is not supported here. Would it be OK if I build on top of this PR and add it?