Feature Request: add attention utility functions like flax
It would be nice to split the MultiheadAttention module code into utility functions for computing attention weights and multihead attention without the projection layers. Flax does this, and it would be nice to have it here as well.
https://github.com/google/flax/blob/main/flax/linen/attention.py#L40-L187
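For concreteness, here is a minimal sketch of the kind of split being requested, loosely mirroring Flax's `dot_product_attention_weights` / `dot_product_attention` pair. The function names and signatures are assumptions for illustration, not part of the Equinox API:

```python
import jax.nn as jnn
import jax.numpy as jnp


def dot_product_attention_weights(query, key, mask=None):
    # Hypothetical utility, modelled on Flax's API; not Equinox's.
    # query: (..., q_len, num_heads, head_dim)
    # key:   (..., kv_len, num_heads, head_dim)
    query = query / jnp.sqrt(query.shape[-1])
    # logits: (..., num_heads, q_len, kv_len)
    logits = jnp.einsum("...qhd,...khd->...hqk", query, key)
    if mask is not None:
        # Masked-out positions get a very negative logit before softmax.
        logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)
    return jnn.softmax(logits, axis=-1)


def dot_product_attention(query, key, value, mask=None):
    # Attention without the input/output projection layers.
    # value: (..., kv_len, num_heads, head_dim)
    weights = dot_product_attention_weights(query, key, mask=mask)
    # output: (..., q_len, num_heads, head_dim)
    return jnp.einsum("...hqk,...khd->...qhd", weights, value)
```

The point of the split is that MultiheadAttention could then be a thin wrapper (projections plus a call to these functions), while users who already have projected query/key/value arrays, or who want custom masking or weight inspection, can call the utilities directly.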
Sounds reasonable. I'd be happy to accept a PR doing this.
Hey, I'd be interested in picking this up as a first issue if that's alright!
Go for it!