Proposed feature: Multi-Head Attention with O(sqrt(N)) memory

Open singularperturbation opened this issue 3 years ago • 3 comments

The recent paper by Rabe and Staats on the Google Research team, "Self-attention Does Not Need O(n²) Memory", provides a numerically stable implementation of multi-headed attention that requires O(√n) memory by never storing the full self-attention matrix. There are some tricks using chunked computation (processing queries and keys in chunks rather than one element at a time) to make this efficient on a TPU, and gradient checkpointing is used so that memory doesn't become quadratic during backpropagation. The chunk size is configurable because the "best" tradeoff for memory vs. computation speed may depend on the device.

This allows for much larger sequence lengths with transformer models and/or reduced memory usage for existing models.
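As a rough illustration of the chunked, numerically stable computation described above, here is a minimal pure-Python sketch for a single query vector. It is not the paper's reference implementation (which is in JAX). The names `naive_attention` and `chunked_attention` and the `chunk_size` default are made up for this example, and the 1/√d scaling, multiple heads, batching, and gradient checkpointing are all left out. The idea is to keep only a running maximum, a running softmax denominator, and a running weighted sum of values, rescaling them whenever a chunk raises the maximum.

```python
import math


def naive_attention(q, keys, values):
    """Reference version: materializes every score before the softmax."""
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(dim)]


def chunked_attention(q, keys, values, chunk_size=2):
    """Illustrative sketch: visits keys/values chunk by chunk and keeps
    only running statistics, so at most chunk_size scores exist at once."""
    dim = len(values[0])
    running_max = -math.inf      # largest score seen so far
    running_sum = 0.0            # sum of exp(score - running_max)
    running_out = [0.0] * dim    # sum of exp(score - running_max) * value
    for start in range(0, len(keys), chunk_size):
        k_chunk = keys[start:start + chunk_size]
        v_chunk = values[start:start + chunk_size]
        scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in k_chunk]
        new_max = max(running_max, max(scores))
        # Rescale what was accumulated under the old maximum.
        scale = math.exp(running_max - new_max)
        running_sum *= scale
        running_out = [o * scale for o in running_out]
        for s, v in zip(scores, v_chunk):
            w = math.exp(s - new_max)
            running_sum += w
            for d in range(dim):
                running_out[d] += w * v[d]
        running_max = new_max
    return [o / running_sum for o in running_out]


# Example inputs (arbitrary numbers); one key yields a very large score
# to exercise the max-subtraction trick.
q = [0.5, -1.0, 2.0]
keys = [[1.0, 0.0, 2.0], [0.3, 0.7, -1.0], [500.0, 0.0, 0.0],
        [-2.0, 1.0, 0.5], [0.1, 0.2, 0.3]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
print(naive_attention(q, keys, values))
print(chunked_attention(q, keys, values, chunk_size=2))
```

Both functions should agree up to floating-point rounding for any `chunk_size`; a larger chunk trades more temporary memory for fewer rescaling steps, which is the memory vs. speed tradeoff mentioned above.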

This should have the same API (or as similar as possible) as the multi-headed attention layer in upstream TensorFlow.

Relevant information

  • Are you willing to contribute it (yes/no): yes, I've begun work on my branch. (If you wish to contribute, then read the requirements for new contributions in CONTRIBUTING.md.)
  • Are you willing to maintain it going forward? (yes/no): I have a limited number of total hours per week I can dedicate, but could maintain it or help to maintain it.
  • Is there a relevant academic paper? (if so, where): Yes, https://arxiv.org/pdf/2112.05682.pdf
  • Does the relevant academic paper exceed 50 citations? (yes/no): No, it's a preprint so hasn't been cited yet.
  • Is there already an implementation in another framework? (if so, where): yes, the reference implementation is in the paper, in JAX. There's also an unofficial PyTorch implementation here.
  • Was it part of tf.contrib? (if so, where): no

Which API type would this fall under (layer, metric, optimizer, etc.)

This would be a new layer.

Who will benefit with this feature?

I see two main use cases:

  • NLP researchers, particularly those training large transformer-based models on hardware that's memory constrained.
  • ML Engineers who primarily wish to use this to deploy transformer models on memory-constrained devices. The vanilla MultiHeadAttention layers would be swapped with these layers, and the weights could be transferred. Then both layers should give the same results at inference time.

TODO: Add a test to make sure this layer is tflite-able?

singularperturbation, Jan 24 '22 00:01

Generally we require a little bit of stabilization in research papers so we wait for >50 citations.

bhack, Jan 24 '22 00:01

You could also try to check if they are interested at https://github.com/keras-team/keras-nlp

bhack, Jan 24 '22 00:01

Understandable, thanks! It looks like early days for keras-nlp so I think I'll just keep this on my branch for now, but I'll check back in on them later (and here again when/if the paper gets more citations).

singularperturbation, Jan 24 '22 03:01

TensorFlow Addons is transitioning to a minimal maintenance and release mode. New features will not be added to this repository. For more information, please see our public messaging on this decision: TensorFlow Addons Wind Down

Please consider sending feature requests / contributions to other repositories in the TF community with a similar charter to TFA: Keras, Keras-CV, Keras-NLP

seanpmorgan, Mar 01 '23 04:03