
Add Multi-Head Attention support for Vitis

rianbrooksflynn opened this pull request 11 months ago

Description

This PR adds support for Multi-Head Attention using either Keras or PyTorch with the Vitis backend in io_parallel mode.
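For reference, here is a minimal Keras sketch of the flow this PR enables (the layer sizes and the two-input call are illustrative, and the conversion calls assume hls4ml's standard Keras API):

```python
# Illustrative sketch: convert a Keras MultiHeadAttention model
# with the Vitis backend in io_parallel mode.
from tensorflow import keras
from tensorflow.keras import layers
import hls4ml

seq_len, embed_dim, num_heads = 8, 16, 2

query = keras.Input(shape=(seq_len, embed_dim))
kv = keras.Input(shape=(seq_len, embed_dim))
# Two-input call: key defaults to value when only two inputs are given
out = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim // num_heads)(query, kv)
model = keras.Model(inputs=[query, kv], outputs=out)

config = hls4ml.utils.config_from_keras_model(model, granularity='name')
hls_model = hls4ml.converters.convert_from_keras_model(
    model,
    hls_config=config,
    backend='Vitis',
    io_type='io_parallel',
)
hls_model.compile()
```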

Tests have been added for both Keras and PyTorch parsing.

Credit is due to @Ethan0Jiang and @LostEcho365 (Zhixing Jiang and Dennis Yin) for their original implementation and Keras parsing support; my contributions were implementing PyTorch support and adding unit tests. (Here's a link to their pre-print.) The original code authors have given permission for their code to be merged into hls4ml.

There are some important notes for PyTorch (TODO: add documentation to this effect):

  • Set batch_first=True when instantiating nn.MultiheadAttention so the input shape matches Keras: (batch_size, seq_len, embed_dim) instead of (seq_len, batch_size, embed_dim).
  • Set channels_last_conversion='off' when calling config_from_pytorch_model(), since batch-first PyTorch and Keras already use the same input shape.
  • Keras lets you call MultiHeadAttention with just two inputs (or even one, for self-attention), but PyTorch requires all three of query, key, and value. hls4ml currently only supports the case where key and value are the same, so you must pass PyTorch the same data for the second and third inputs (see the sketch after this list).
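A minimal PyTorch sketch illustrating all three notes above (the layer sizes are illustrative, and the exact config_from_pytorch_model / convert_from_pytorch_model signatures may vary slightly between hls4ml versions):

```python
import torch.nn as nn
import hls4ml

seq_len, embed_dim, num_heads = 8, 16, 2

class Attention(nn.Module):
    def __init__(self):
        super().__init__()
        # batch_first=True: inputs are (batch_size, seq_len, embed_dim), as in Keras
        self.mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    def forward(self, q, kv):
        # hls4ml only supports key == value, so the same tensor is passed twice
        out, _ = self.mha(q, kv, kv)
        return out

model = Attention().eval()

# channels_last_conversion='off': batch-first PyTorch already matches the Keras layout
config = hls4ml.utils.config_from_pytorch_model(
    model,
    [(seq_len, embed_dim), (seq_len, embed_dim)],
    channels_last_conversion='off',
)
hls_model = hls4ml.converters.convert_from_pytorch_model(
    model,
    hls_config=config,
    backend='Vitis',
    io_type='io_parallel',
)
hls_model.compile()
```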

Type of change

  • [x] New feature (non-breaking change which adds functionality)
  • [x] A new research paper code implementation

Tests

Two unit tests added: test/pytest/test_multiheadattention.py and test/pytest/test_multiheadattention_pytorch.py

Checklist

  • [x] I have read the guidelines for contributing.
  • [x] I have commented my code, particularly in hard-to-understand areas.
  • [ ] I have made corresponding changes to the documentation.
  • [x] My changes generate no new warnings.
  • [ ] I have installed and run pre-commit on the files I edited or added.
  • [x] I have added tests that prove my fix is effective or that my feature works.

rianbrooksflynn, Jan 14 '25 14:01

Thank you so much for merging it into main!

Ethan0Jiang, Jan 14 '25 14:01

pre-commit.ci autofix

rianbrooksflynn, Jan 14 '25 14:01

Hi @rianbrooksflynn! Great work on the Multi-Head Attention implementation.

Could you consider adding usage examples (e.g., examples/multihead_attention_keras.py and examples/multihead_attention_pytorch.py) to help users understand how to properly use this feature?

The examples could demonstrate the important PyTorch requirements you mentioned (batch_first=True, channels_last_conversion='off', same key/value inputs) and basic Keras usage.

Thanks!

mahadkhaliq, Sep 25 '25 22:09

As far as I can tell, masking (e.g., causal masking) is not supported here. Would it be OK if I built on top of this PR and added it?

porridgewithraisins, Sep 30 '25 08:09