Proposed feature: Multi-Head Attention with O(sqrt(N)) memory
The recent paper by Rabe and Staats of the Google Research team, "Self-attention Does Not Need O(n²) Memory", provides a numerically stable implementation of multi-head attention that requires only O(√n) memory, achieved by never storing the full self-attention matrix. The computation is chunked (rather than done all at once) to keep it efficient on a TPU, and gradient checkpointing is used so that memory doesn't become quadratic during backpropagation. The chunk size is configurable because the best memory-vs-speed tradeoff may depend on the device.
This allows for much larger sequence lengths with transformer models and/or reduced memory usage for existing models.
This should have the same API as (or as close as possible to) the MultiHeadAttention layer in upstream TensorFlow.
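For reference, here is a rough sketch (my own simplification, in eager TensorFlow) of the core chunked, numerically stable softmax idea. The function name, single-head/no-mask setup, and default chunk size are my assumptions; it only chunks over keys/values, so it doesn't show the query chunking and gradient checkpointing that give the full O(√n) bound:

```python
import tensorflow as tf

def chunked_attention(q, k, v, key_chunk_size=1024):
    """Single-head scaled dot-product attention computed over key/value chunks.

    q: [batch, q_len, d]; k, v: [batch, kv_len, d]. A running max, numerator,
    and denominator give a numerically stable streaming softmax, so the full
    [q_len, kv_len] score matrix is never materialized.
    """
    scale = 1.0 / tf.sqrt(tf.cast(tf.shape(q)[-1], q.dtype))
    kv_len = k.shape[1]  # assumes a statically known key length (eager sketch)

    num = tf.zeros_like(q)                           # running softmax numerator
    den = tf.zeros(tf.shape(q)[:-1], dtype=q.dtype)  # running softmax denominator
    run_max = tf.fill(tf.shape(q)[:-1],
                      tf.constant(float('-inf'), dtype=q.dtype))

    for start in range(0, kv_len, key_chunk_size):
        k_c = k[:, start:start + key_chunk_size]
        v_c = v[:, start:start + key_chunk_size]
        scores = tf.einsum('bqd,bkd->bqk', q, k_c) * scale   # [batch, q_len, chunk]
        new_max = tf.maximum(run_max, tf.reduce_max(scores, axis=-1))
        p = tf.exp(scores - new_max[..., None])
        correction = tf.exp(run_max - new_max)               # rescale earlier sums
        num = num * correction[..., None] + tf.einsum('bqk,bkd->bqd', p, v_c)
        den = den * correction + tf.reduce_sum(p, axis=-1)
        run_max = new_max

    return num / den[..., None]
```

The actual layer would wrap this in the usual per-head query/key/value/output projections, support masking and dropout, and use something like tf.recompute_grad over chunks so the backward pass stays within the same memory bound.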
Relevant information
- Are you willing to contribute it (yes/no): yes, I've begun work on my branch
If you wish to contribute, then read the requirements for new contributions in CONTRIBUTING.md.
- Are you willing to maintain it going forward? (yes/no): I have a limited number of hours per week I can dedicate, but I could maintain it or help maintain it.
- Is there a relevant academic paper? (if so, where): Yes, https://arxiv.org/pdf/2112.05682.pdf
- Does the relevant academic paper exceed 50 citations? (yes/no): No, it's a recent preprint, so it hasn't accumulated citations yet.
- Is there already an implementation in another framework? (if so, where): Yes, the reference implementation is given in the paper in JAX. There is also an unofficial PyTorch implementation here.
- Was it part of tf.contrib? (if so, where): no
Which API type would this fall under (layer, metric, optimizer, etc.)? This would be a new layer.
Who will benefit from this feature?
I see two main use cases:
- NLP researchers, particularly those training large transformer-based models on hardware that's memory constrained.
- ML engineers who primarily wish to use this to deploy transformer models on memory-constrained devices. The vanilla MultiHeadAttention layers would be swapped for these layers and the weights transferred; both layers should then give the same results at inference time (sketched below).
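A rough illustration of that swap-and-transfer workflow, using two stock Keras MultiHeadAttention layers as stand-ins (the proposed layer doesn't exist yet; the shapes and tolerance are arbitrary):

```python
import tensorflow as tf

# Two layers with identical configs; the second stands in for the proposed
# memory-efficient layer, which would be constructed with the same arguments.
original = tf.keras.layers.MultiHeadAttention(num_heads=8, key_dim=64)
replacement = tf.keras.layers.MultiHeadAttention(num_heads=8, key_dim=64)

x = tf.random.normal([2, 128, 512])
_ = original(x, x)      # build the variables of both layers
_ = replacement(x, x)

# Transfer the trained weights into the replacement layer.
replacement.set_weights(original.get_weights())

# With matching configs and transferred weights, outputs should agree
# (up to small numerical differences from the chunked computation).
tf.debugging.assert_near(original(x, x), replacement(x, x), atol=1e-4)
```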
TODO: Add a test to make sure this layer is convertible to TFLite?
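Something along these lines (a hypothetical helper, exercised against a stock layer for now) could cover that TODO; whether conversion actually succeeds will depend on TFLite's support for the ops the layer uses (e.g. Einsum):

```python
import tensorflow as tf

def assert_tflite_convertible(attention_layer, seq_len=128, d_model=512):
    """Hypothetical test helper: wrap the layer in a small functional model
    and check that TFLite conversion produces a non-empty flatbuffer."""
    inputs = tf.keras.Input(shape=(seq_len, d_model))
    model = tf.keras.Model(inputs, attention_layer(inputs, inputs))
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    tflite_bytes = converter.convert()  # raises if an op is not convertible
    assert len(tflite_bytes) > 0
```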
Generally we require a bit of stabilization for research papers, so we wait for >50 citations.
You could also check whether they are interested at https://github.com/keras-team/keras-nlp
Understandable, thanks! It looks like it's early days for keras-nlp, so I think I'll keep this on my branch for now, but I'll check back in with them later (and here again if/when the paper gets more citations).
TensorFlow Addons is transitioning to a minimal maintenance and release mode. New features will not be added to this repository. For more information, please see our public messaging on this decision: TensorFlow Addons Wind Down
Please consider sending feature requests / contributions to other repositories in the TF community with charters similar to TFA's: Keras, Keras-CV, Keras-NLP