[Feature Request] Sparse tensor support for `attn_bias` in `memory_efficient_attention`

Open davidbuterez opened this issue 8 months ago • 1 comment

🚀 Feature

I think that it would be very helpful to allow sparse tensors to be used as masks (attn_bias) in the memory_efficient_attention function.

Motivation

As the documentation notes:

The most common use case for an attention bias is to contain only zeros and negative infinities, which forms a mask so that some queries only attend to some keys.

For custom attention masks, it is very common to have only two values: negative infinity and zero (or 0 and 1 when using boolean masks). The shape of the mask is dictated by the sequence length: for a long sequence of, say, 50000 tokens, the mask has shape [*, 50000, 50000]. However, this tensor can be extremely sparse. In my use cases it is common that only about x out of the x^2 values are non-zero, where x is the sequence length.
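As a rough sketch, such a mask can be represented with PyTorch's sparse COO format by storing only the allowed (query, key) positions. The index pattern below is made up purely for illustration, and the convention that unstored entries mean "masked out" is an assumption an attention kernel would have to adopt, since COO treats unstored entries as zero:

import torch

seq_len = 50000
# Allowed (query, key) pairs; a diagonal pattern here, purely for illustration.
idx = torch.stack([torch.arange(seq_len), torch.arange(seq_len)])
# Store a bias of 0.0 at the allowed positions only.
mask_sparse = torch.sparse_coo_tensor(
    idx, torch.zeros(seq_len), size=(seq_len, seq_len)
).coalesce()

# Dense equivalent: seq_len * seq_len * 4 bytes ~= 10 GB of float32,
# versus ~seq_len stored values (plus their indices) for the sparse version.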

The difference in memory between dense and sparse tensors is illustrated by the following example:

import torch

# Sparse COO mask saved earlier: 100000 x 100000 with only ~100000 non-zero entries.
t_sparse = torch.load('tensor_sparse.pt')

print(t_sparse.dtype)
> torch.float32

print(t_sparse.shape)
> torch.Size([100000, 100000])

print(t_sparse.coalesce().values().shape)
> torch.Size([99999])

# Peak GPU memory with only the sparse tensor resident.
torch.cuda.synchronize()
max_memory = torch.cuda.max_memory_allocated() // 2 ** 20
print(f"Peak memory use: {max_memory}MB")
> Peak memory use: 6MB

# Materialising the dense equivalent immediately exhausts GPU memory:
t_sparse_dense = t_sparse.to_dense()

---------------------------------------------------------------------------
OutOfMemoryError                          Traceback (most recent call last)
<ipython-input-6-a8a6660d8ae5> in <cell line: 1>()
----> 1 t_sparse_dense = t_sparse.to_dense()

OutOfMemoryError: CUDA out of memory. Tried to allocate 37.25 GiB (GPU 0; 14.75 GiB total capacity; 1.91 MiB already allocated; 14.14 GiB free; 22.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
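For context on the numbers above: the dense equivalent is 100000 × 100000 float32 values, i.e. 10^10 × 4 bytes ≈ 37.25 GiB, exactly the allocation the traceback reports, whereas the sparse COO representation stores only the 99999 non-zero values plus their indices, on the order of a few megabytes and consistent with the 6 MB peak measured earlier.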

Pitch

The main memory bottleneck in my current use case is the allocation of custom masks, which scales quadratically with sequence length. However, these masks contain only two values and are usually very sparse, so I think it would be great and impactful to be able to pass the attention bias in a sparse format. Similar functionality is supposedly available in vanilla PyTorch (https://github.com/pytorch/pytorch/pull/104042), but it currently does not work as intended, and I have observed better performance from the xFormers implementation of memory-efficient attention.
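To make the pitch concrete, here is a minimal sketch of the desired call, assuming memory_efficient_attention could accept a sparse COO tensor for attn_bias (it cannot today; that is exactly what this request asks for). The shapes and index pattern are made up for illustration:

import torch
import xformers.ops as xops

B, M, H, K = 1, 50000, 8, 64  # batch, sequence length, heads, head dim
q = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)
k = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)
v = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)

# Allowed (query, key) pairs; just the diagonal here, purely for illustration.
idx = torch.stack([torch.arange(M, device="cuda")] * 2)
attn_bias_sparse = torch.sparse_coo_tensor(
    idx, torch.zeros(M, device="cuda", dtype=torch.float16), size=(M, M)
).coalesce()

# Desired usage (does NOT work today): pass the sparse bias directly instead of
# materialising a dense [M, M] tensor of zeros and -inf.
out = xops.memory_efficient_attention(q, k, v, attn_bias=attn_bias_sparse)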

Alternatives

This feature is implemented in PyTorch for torch.nn.functional.scaled_dot_product_attention (https://github.com/pytorch/pytorch/pull/104042). However, it currently does not work correctly (details in the comments).
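For comparison, the PyTorch route referenced above would look roughly like the sketch below, assuming the sparse-mask support from that PR worked as advertised (per the linked discussion, it currently does not); again, shapes and the index pattern are illustrative only:

import torch
import torch.nn.functional as F

B, H, M, K = 1, 8, 50000, 64  # scaled_dot_product_attention expects [batch, heads, seq, head dim]
q = torch.randn(B, H, M, K, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, M, K, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, M, K, device="cuda", dtype=torch.float16)

# Sparse additive mask: 0.0 stored at allowed positions (diagonal here, for illustration only).
idx = torch.stack([torch.arange(M, device="cuda")] * 2)
attn_mask = torch.sparse_coo_tensor(
    idx, torch.zeros(M, device="cuda", dtype=torch.float16), size=(M, M)
).coalesce()

# Intended usage per the linked PR; in practice this does not yet behave correctly.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)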

davidbuterez • Oct 17 '23 19:10